All downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.gym.Client Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta2
Show newest version
package org.deeplearning4j.gym;


import com.mashape.unirest.http.JsonNode;
import lombok.Value;
import org.deeplearning4j.rl4j.space.GymObservationSpace;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.json.JSONObject;

import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * @author rubenfiszel ([email protected]) on 7/6/16.
 *
 * A client represent an active connection to a specific instance of an environment on a rl4j-http-api server.
 * for API specification
 *
 * @param   Observation type
 * @param   Action type
 * @param  Action Space type
 * @see https://github.com/openai/gym-http-api#api-specification
 */
@Value
public class Client> {


    public static String V1_ROOT = "/v1";
    public static String ENVS_ROOT = V1_ROOT + "/envs/";

    public static String MONITOR_START = "/monitor/start/";
    public static String MONITOR_CLOSE = "/monitor/close/";
    public static String CLOSE = "/close/";
    public static String RESET = "/reset/";
    public static String SHUTDOWN = "/shutdown/";
    public static String UPLOAD = "/upload/";
    public static String STEP = "/step/";
    public static String OBSERVATION_SPACE = "/observation_space/";
    public static String ACTION_SPACE = "/action_space/";


    String url;
    String envId;
    String instanceId;
    GymObservationSpace observationSpace;
    AS actionSpace;
    boolean render;


    /**
     * @param url url of the server
     * @return set of all environments running on the server at the url
     */
    public static Set listAll(String url) {
        JSONObject reply = ClientUtils.get(url + ENVS_ROOT);
        return reply.getJSONObject("envs").keySet();
    }

    /**
     * Shutdown the server at the url
     *
     * @param url url of the server
     */
    public static void serverShutdown(String url) {
        ClientUtils.post(url + ENVS_ROOT + SHUTDOWN, new JSONObject());
    }

    /**
     * @return set of all environments running on the same server than this client
     */
    public Set listAll() {
        return listAll(url);
    }

    /**
     * Step the environment by one action
     *
     * @param action action to step the environment with
     * @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information.
     */
    public StepReply step(A action) {
        JSONObject body = new JSONObject()
                .put("action", getActionSpace().encode(action))
                .put("render", render);

        JSONObject reply = ClientUtils.post(url + ENVS_ROOT + instanceId + STEP, body).getObject();

        O observation = observationSpace.getValue(reply, "observation");
        double reward = reply.getDouble("reward");
        boolean done = reply.getBoolean("done");
        JSONObject info = reply.getJSONObject("info");

        return new StepReply(observation, reward, done, info);
    }

    /**
     * Reset the state of the environment and return an initial observation.
     *
     * @return initial observation
     */
    public O reset() {
        JsonNode resetRep = ClientUtils.post(url + ENVS_ROOT + instanceId + RESET, new JSONObject());
        return observationSpace.getValue(resetRep.getObject(), "observation");
    }

    /*
    Present in the doc but not working currently server-side
    public void monitorStart(String directory) {

        JSONObject json = new JSONObject()
                .put("directory", directory);

        monitorStartPost(json);
    }
    */

    /**
     * Start monitoring.
     *
     * @param directory path to directory in which store the monitoring file
     * @param force     clear out existing training data from this directory (by deleting every file prefixed with "openaigym.")
     * @param resume    retain the training data already in this directory, which will be merged with our new data
     */
    public void monitorStart(String directory, boolean force, boolean resume) {
        JSONObject json = new JSONObject()
                .put("directory", directory)
                .put("force", force)
                .put("resume", resume);

        monitorStartPost(json);
    }

    private void monitorStartPost(JSONObject json) {
        ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_START, json);
    }

    /**
     * Flush all monitor data to disk
     */
    public void monitorClose() {
        ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_CLOSE, new JSONObject());
    }

    /**
     * Upload monitoring data to OpenAI servers.
     *
     * @param trainingDir directory that contains the monitoring data
     * @param apiKey      personal OpenAI API key
     * @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running.
     **/
    public void upload(String trainingDir, String apiKey, String algorithmId) {
        JSONObject json = new JSONObject()
                .put("training_dir", trainingDir)
                .put("api_key", apiKey)
                .put("algorithm_id", algorithmId);

        uploadPost(json);
    }

    /**
     * Upload monitoring data to OpenAI servers.
     *
     * @param trainingDir directory that contains the monitoring data
     * @param apiKey      personal OpenAI API key
     */
    public void upload(String trainingDir, String apiKey) {
        JSONObject json = new JSONObject()
                .put("training_dir", trainingDir)
                .put("api_key", apiKey);

        uploadPost(json);
    }

    private void uploadPost(JSONObject json) {
        try {
            ClientUtils.post(url + V1_ROOT + instanceId + UPLOAD, json);
        } catch (RuntimeException e) {
            Logger logger = Logger.getLogger("Client Upload");
            logger.log(Level.SEVERE, "Impossible to upload: Wrong API key?");
        }
    }

    /**
     * Shutdown the server at the same url than this client
     */
    public void serverShutdown() {
        serverShutdown(url);
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy