All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.modality.rl.env.RlEnv Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.rl.env;

import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** A reinforcement learning environment that agents can observe and act upon. */
public interface RlEnv extends AutoCloseable {

    /** Puts the environment back into its default state. */
    void reset();

    /**
     * Gets the observation that describes the environment's current state.
     *
     * @return the observation describing the current state of the environment
     */
    NDList getObservation();

    /**
     * Gets the actions that are currently available in the environment.
     *
     * @return the actions that can currently be taken in the environment
     */
    ActionSpace getActionSpace();

    /**
     * Performs the given action in this environment, advancing it by one step.
     *
     * @param action the action to perform
     * @param training true if the step is taken during training
     * @return the {@link Step} holding the result of the action
     */
    Step step(NDList action, boolean training);

    /**
     * Resets the environment and then plays it through until a step reports that it is done.
     *
     * <p>On every iteration the agent is asked for an action, the action is applied through
     * {@link #step(NDList, boolean)}, and the reward of the resulting step is added to the total.
     *
     * @param agent the agent used to choose the actions
     * @param training true to run while training. When training, the steps will be recorded
     * @return the total reward, summed over every step taken
     */
    default float runEnvironment(RlAgent agent, boolean training) {
        reset();

        float rewardSum = 0;
        Step latest;
        do {
            // At least one step is always taken before the done flag is consulted.
            latest = step(agent.chooseAction(this, training), training);
            rewardSum += latest.getReward().getFloat();
        } while (!latest.isDone());
        return rewardSum;
    }

    /**
     * Gets a batch of steps from the environment {@link ai.djl.modality.rl.ReplayBuffer}.
     *
     * @return a batch of steps from the environment {@link ai.djl.modality.rl.ReplayBuffer}
     */
    Step[] getBatch();

    /**
     * {@inheritDoc}
     *
     * <p>Redeclared here without a {@code throws} clause.
     */
    @Override
    void close();

    /** The record produced by taking a single step in the environment. */
    interface Step extends AutoCloseable {

        /**
         * Gets the observation describing the state before the action was taken.
         *
         * @return the observation detailing the state before the action
         */
        NDList getPreObservation();

        /**
         * Gets the action that was taken in this step.
         *
         * @return the action taken
         */
        NDList getAction();

        /**
         * Gets the observation describing the state after the action was taken.
         *
         * @return the observation detailing the state after the action
         */
        NDList getPostObservation();

        /**
         * Gets the actions that are available once this step has been taken.
         *
         * @return the available actions after the step
         */
        ActionSpace getPostActionSpace();

        /**
         * Gets the reward that was given for the action.
         *
         * @return the reward given for the action
         */
        NDArray getReward();

        /**
         * Reports whether the environment is finished or can still accept further actions.
         *
         * @return true if the environment is finished and can no longer accept further actions
         */
        boolean isDone();

        /**
         * {@inheritDoc}
         *
         * <p>Redeclared here without a {@code throws} clause.
         */
        @Override
        void close();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy