All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteDense Maven / Gradle / Ivy

There is a newer version: 1.0.0-M1.1
Show newest version
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;

/**
 * Training for A3C in the discrete domain.
 *
 * <p>Every constructor of this class ends up handing an {@code IActorCritic} to the
 * {@code A3CDiscrete} superclass constructor; the overloads differ only in how that
 * actor-critic is obtained (passed in directly, built by a factory, or built by a factory
 * created from a network configuration).
 *
 * <p>Original author's note: the Separate version is used specifically because the model is
 * too small to have enough benefit by sharing layers. Overloads taking the
 * {@code ActorCriticFactoryCompGraph} variants are offered as well.
 *
 * <p>NOTE(review): {@code MDP} and {@code A3CDiscrete} are used as raw types here, while the
 * file imports {@code Encodable} and {@code DiscreteSpace} without using them. Presumably
 * generic type arguments were lost somewhere; confirm against the upstream source.
 *
 * @author rubenfiszel ([email protected]) on 8/8/16.
 */
public class A3CDiscreteDense extends A3CDiscrete {

    /**
     * Creates the learner from an already available actor-critic.
     *
     * <p>All arguments are forwarded unchanged to the {@code A3CDiscrete} constructor.
     *
     * @param mdp the MDP to train on
     * @param actorCritic the actor-critic handed to the superclass
     * @param conf the A3C configuration
     * @param dataManager the data manager handed to the superclass
     */
    public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf,
                    DataManager dataManager) {
        super(mdp, actorCritic, conf, dataManager);
    }

    /**
     * Creates the learner with an actor-critic built by the given separate-network factory.
     *
     * <p>The factory is called with the shape of the MDP's observation space and the size of
     * its action space.
     *
     * @param mdp the MDP to train on; also supplies the observation shape and action count
     * @param factory the factory that builds the actor-critic
     * @param conf the A3C configuration
     * @param dataManager the data manager handed to the superclass
     */
    public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory,
                    A3CConfiguration conf, DataManager dataManager) {
        this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
                        dataManager);
    }

    /**
     * Creates the learner from a network configuration, using a new
     * {@code ActorCriticFactorySeparateStdDense} built from that configuration.
     *
     * @param mdp the MDP to train on
     * @param netConf the configuration passed to the {@code ActorCriticFactorySeparateStdDense}
     * @param conf the A3C configuration
     * @param dataManager the data manager handed to the superclass
     */
    public A3CDiscreteDense(MDP mdp,
                    ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
                    DataManager dataManager) {
        this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
    }

    /**
     * Creates the learner with an actor-critic built by the given
     * {@code ActorCriticFactoryCompGraph}.
     *
     * <p>The factory is called with the shape of the MDP's observation space and the size of
     * its action space.
     *
     * @param mdp the MDP to train on; also supplies the observation shape and action count
     * @param factory the factory that builds the actor-critic
     * @param conf the A3C configuration
     * @param dataManager the data manager handed to the superclass
     */
    public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
                    A3CConfiguration conf, DataManager dataManager) {
        this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
                        dataManager);
    }

    /**
     * Creates the learner from a network configuration, using a new
     * {@code ActorCriticFactoryCompGraphStdDense} built from that configuration.
     *
     * @param mdp the MDP to train on
     * @param netConf the configuration passed to the {@code ActorCriticFactoryCompGraphStdDense}
     * @param conf the A3C configuration
     * @param dataManager the data manager handed to the superclass
     */
    public A3CDiscreteDense(MDP mdp,
                    ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
                    DataManager dataManager) {
        this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy