
org.deeplearning4j.gym.Client Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of gym-java-client Show documentation
Show all versions of gym-java-client Show documentation
A Java client for OpenAI's Reinforcement Learning Gym
package org.deeplearning4j.gym;
import com.mashape.unirest.http.JsonNode;
import lombok.Value;
import org.deeplearning4j.rl4j.space.GymObservationSpace;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.json.JSONObject;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* @author rubenfiszel ([email protected]) on 7/6/16.
*
* A client represent an active connection to a specific instance of an environment on a rl4j-http-api server.
* for API specification
*
* @param Observation type
* @param Action type
* @param Action Space type
* @see https://github.com/openai/gym-http-api#api-specification
*/
@Value
public class Client> {
public static String V1_ROOT = "/v1";
public static String ENVS_ROOT = V1_ROOT + "/envs/";
public static String MONITOR_START = "/monitor/start/";
public static String MONITOR_CLOSE = "/monitor/close/";
public static String CLOSE = "/close/";
public static String RESET = "/reset/";
public static String SHUTDOWN = "/shutdown/";
public static String UPLOAD = "/upload/";
public static String STEP = "/step/";
public static String OBSERVATION_SPACE = "/observation_space/";
public static String ACTION_SPACE = "/action_space/";
String url;
String envId;
String instanceId;
GymObservationSpace observationSpace;
AS actionSpace;
boolean render;
/**
* @param url url of the server
* @return set of all environments running on the server at the url
*/
public static Set listAll(String url) {
JSONObject reply = ClientUtils.get(url + ENVS_ROOT);
return reply.getJSONObject("envs").keySet();
}
/**
* Shutdown the server at the url
*
* @param url url of the server
*/
public static void serverShutdown(String url) {
ClientUtils.post(url + ENVS_ROOT + SHUTDOWN, new JSONObject());
}
/**
* @return set of all environments running on the same server than this client
*/
public Set listAll() {
return listAll(url);
}
/**
* Step the environment by one action
*
* @param action action to step the environment with
* @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information.
*/
public StepReply step(A action) {
JSONObject body = new JSONObject()
.put("action", getActionSpace().encode(action))
.put("render", render);
JSONObject reply = ClientUtils.post(url + ENVS_ROOT + instanceId + STEP, body).getObject();
O observation = observationSpace.getValue(reply, "observation");
double reward = reply.getDouble("reward");
boolean done = reply.getBoolean("done");
JSONObject info = reply.getJSONObject("info");
return new StepReply(observation, reward, done, info);
}
/**
* Reset the state of the environment and return an initial observation.
*
* @return initial observation
*/
public O reset() {
JsonNode resetRep = ClientUtils.post(url + ENVS_ROOT + instanceId + RESET, new JSONObject());
return observationSpace.getValue(resetRep.getObject(), "observation");
}
/*
Present in the doc but not working currently server-side
public void monitorStart(String directory) {
JSONObject json = new JSONObject()
.put("directory", directory);
monitorStartPost(json);
}
*/
/**
* Start monitoring.
*
* @param directory path to directory in which store the monitoring file
* @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.")
* @param resume retain the training data already in this directory, which will be merged with our new data
*/
public void monitorStart(String directory, boolean force, boolean resume) {
JSONObject json = new JSONObject()
.put("directory", directory)
.put("force", force)
.put("resume", resume);
monitorStartPost(json);
}
private void monitorStartPost(JSONObject json) {
ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_START, json);
}
/**
* Flush all monitor data to disk
*/
public void monitorClose() {
ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_CLOSE, new JSONObject());
}
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
* @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running.
**/
public void upload(String trainingDir, String apiKey, String algorithmId) {
JSONObject json = new JSONObject()
.put("training_dir", trainingDir)
.put("api_key", apiKey)
.put("algorithm_id", algorithmId);
uploadPost(json);
}
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
*/
public void upload(String trainingDir, String apiKey) {
JSONObject json = new JSONObject()
.put("training_dir", trainingDir)
.put("api_key", apiKey);
uploadPost(json);
}
private void uploadPost(JSONObject json) {
try {
ClientUtils.post(url + V1_ROOT + instanceId + UPLOAD, json);
} catch (RuntimeException e) {
Logger logger = Logger.getLogger("Client Upload");
logger.log(Level.SEVERE, "Impossible to upload: Wrong API key?");
}
}
/**
* Shutdown the server at the same url than this client
*/
public void serverShutdown() {
serverShutdown(url);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy