All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.rl4j.mdp.gym.GymEnv Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta
Show newest version
package org.deeplearning4j.rl4j.mdp.gym;


import org.deeplearning4j.gym.Client;
import org.deeplearning4j.gym.ClientFactory;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.mdp.MDP;

/**
 * @author rubenfiszel ([email protected]) 7/12/16.
 *
 * Wrapper over the client of gym-java-client
 *
 */
public class GymEnv> implements MDP {

    final public static String GYM_MONITOR_DIR = "/tmp/gym-dqn";

    final private Client client;
    final private String envId;
    final private boolean render;
    final private boolean monitor;
    private ActionTransformer actionTransformer = null;
    private boolean done = false;

    public GymEnv(String envId, boolean render, boolean monitor) {
        this.client = ClientFactory.build(envId, render);
        this.envId = envId;
        this.render = render;
        this.monitor = monitor;
        if (monitor)
            client.monitorStart(GYM_MONITOR_DIR, true, false);
    }

    public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
        this(envId, render, monitor);
        actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
    }


    public ObservationSpace getObservationSpace() {
        return client.getObservationSpace();
    }

    public AS getActionSpace() {
        if (actionTransformer == null)
            return client.getActionSpace();
        else
            return (AS) actionTransformer;
    }

    public StepReply step(A action) {
        StepReply stepRep = client.step(action);
        done = stepRep.isDone();
        return stepRep;
    }

    public boolean isDone() {
        return done;
    }

    public O reset() {
        done = false;
        return client.reset();
    }


    public void upload(String apiKey) {
        client.upload(GYM_MONITOR_DIR, apiKey);
    }

    public void close() {
        if (monitor)
            client.monitorClose();
    }

    public GymEnv newInstance() {
        return new GymEnv(envId, render, monitor);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy