All downloads are free. The search and download functionality uses the official Maven repository.

com.github.chen0040.rl.learning.qlearn.QAgent Maven / Gradle / Ivy

package com.github.chen0040.rl.learning.qlearn;

import com.github.chen0040.rl.utils.IndexValue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Objects;
import java.util.Random;
import java.util.Set;


/**
 * Created by xschen on 9/27/2015 0027.
 *
 * Agent that remembers its current state, its previous state and the action taken at
 * the previous state, and delegates action selection and learning updates to a
 * {@link QLearner}.
 */
public class QAgent implements Serializable{
    private QLearner learner;
    private int currentState;
    private int prevState;

    /** action taken at prevState */
    private int prevAction;

    /**
     * @return the state the agent is currently in
     */
    public int getCurrentState(){
        return currentState;
    }

    /**
     * @return the state the agent was in before the last update, or -1 right after {@link #start(int)}
     */
    public int getPrevState(){
        return prevState;
    }

    /**
     * @return the action taken at the previous state, or -1 right after {@link #start(int)}
     */
    public int getPrevAction(){
        return prevAction;
    }

    /**
     * Places the agent in the given state and resets the previous state and previous
     * action to -1.
     *
     * @param currentState the state the agent starts in
     */
    public void start(int currentState){
        this.currentState = currentState;
        this.prevAction = -1;
        this.prevState = -1;
    }

    /**
     * Asks the learner to select an action for the current state.
     *
     * @return the learner's selection for the current state
     */
    public IndexValue selectAction(){
        return learner.selectAction(currentState);
    }

    /**
     * Asks the learner to select an action for the current state, forwarding the given
     * action set to the learner.
     *
     * @param actionsAtState the action set handed to the learner together with the current state
     * @return the learner's selection for the current state
     */
    public IndexValue selectAction(Set actionsAtState){
        return learner.selectAction(currentState, actionsAtState);
    }

    /**
     * Same as {@link #update(int, int, Set, double)} with a {@code null} action set.
     *
     * @param actionTaken the action taken at the current state
     * @param newState the state reached after taking the action
     * @param immediateReward the reward passed on to the learner
     */
    public void update(int actionTaken, int newState, double immediateReward){
        update(actionTaken, newState, null, immediateReward);
    }

    /**
     * Feeds the transition to the learner and then advances the agent: the current state
     * becomes the previous state, {@code actionTaken} becomes the previous action and
     * {@code newState} becomes the current state.
     *
     * @param actionTaken the action taken at the current state
     * @param newState the state reached after taking the action
     * @param actionsAtNewState the action set forwarded to the learner for {@code newState}; may be {@code null}
     * @param immediateReward the reward passed on to the learner
     */
    public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){

        // The learner must see the transition before the state bookkeeping is shifted.
        learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward);

        prevState = currentState;
        prevAction = actionTaken;

        currentState = newState;
    }

    /**
     * Replaces the current learner with a {@link QLambdaLearner} constructed from it and
     * configured with the given lambda.
     *
     * @param lambda the value passed to {@link QLambdaLearner#setLambda}
     */
    public void enableEligibilityTrace(double lambda){
        QLambdaLearner acll = new QLambdaLearner(learner);
        acll.setLambda(lambda);
        learner = acll;
    }

    /**
     * @return the learner backing this agent; {@code null} if none has been set
     */
    public QLearner getLearner(){
        return learner;
    }

    /**
     * @param learner the learner this agent should delegate to
     */
    public void setLearner(QLearner learner){
        this.learner = learner;
    }

    /**
     * Creates an agent backed by a new {@link QLearner} built from the given arguments.
     */
    public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
        learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ);
    }

    /**
     * Creates an agent that delegates to the given learner.
     *
     * @param learner the learner to use
     */
    public QAgent(QLearner learner){
        this.learner = learner;
    }

    /**
     * Creates an agent backed by a new {@link QLearner} built from the given counts.
     */
    public QAgent(int stateCount, int actionCount){
        learner = new QLearner(stateCount, actionCount);
    }

    /**
     * Creates an agent without a learner; one must be supplied through
     * {@link #setLearner(QLearner)} before the agent is used.
     */
    public QAgent(){

    }

    /**
     * Creates an independent copy of this agent.
     *
     * The copy is produced by a serialization round trip: an agent created with the
     * no-arg constructor has no learner for {@link #copy(QAgent)} to write into, so the
     * copy cannot be built by calling {@code copy} on a fresh instance.
     *
     * @return a copy of this agent
     * @throws IllegalStateException if the agent cannot be serialized or deserialized
     */
    public QAgent makeCopy(){
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
                out.writeObject(this);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return (QAgent) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Unable to copy QAgent via serialization", e);
        }
    }

    /**
     * Copies the learner contents and the state bookkeeping of {@code rhs} into this agent.
     *
     * @param rhs the agent to copy from
     * @throws NullPointerException if this agent has no learner to copy into
     */
    public void copy(QAgent rhs){
        Objects.requireNonNull(learner, "QAgent.copy requires this agent to have a learner; call setLearner first");
        learner.copy(rhs.learner);
        prevAction = rhs.prevAction;
        prevState = rhs.prevState;
        currentState = rhs.currentState;
    }

    /**
     * Two agents are equal when their previous action, previous state, current state and
     * learners are equal. Agents without a learner are only equal to other agents
     * without a learner.
     */
    @Override
    public boolean equals(Object obj){
        if(this == obj){
            return true;
        }
        if(!(obj instanceof QAgent)){
            return false;
        }
        QAgent rhs = (QAgent)obj;
        return prevAction == rhs.prevAction
                && prevState == rhs.prevState
                && currentState == rhs.currentState
                && Objects.equals(learner, rhs.learner);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. Only the integer fields are
     * hashed, so the result does not depend on how the learner implements
     * {@code hashCode}.
     */
    @Override
    public int hashCode(){
        return Objects.hash(prevAction, prevState, currentState);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy