All downloads are free. The search and download functionality uses the official Maven repository.

com.github.chen0040.rl.models.QModel Maven / Gradle / Ivy

package com.github.chen0040.rl.models;


import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.utils.Matrix;
import com.github.chen0040.rl.utils.Vec;
import lombok.Getter;
import lombok.Setter;

import java.util.*;


/**
 * Tabular Q model: holds one Q value and one learning rate for every (state_id, action_id) pair,
 * together with the discount factor.
 * Q is known as the quality of state-action combination, note that it is different from utility of a state
 *
 * @author xschen
 * 9/27/2015 0027.
 */
@Getter
@Setter
public class QModel {
    /**
    *  Q value for (state_id, action_id) pair
    *  Q is known as the quality of state-action combination, note that it is different from utility of a state
    */
    private Matrix Q;
    /**
    *  $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id)
    */
    private Matrix alphaMatrix;

    /**
     * discount factor
     */
    private double gamma = 0.7;

    private int stateCount;
    private int actionCount;

    /**
     * Creates a model with a {@code stateCount x actionCount} Q matrix and learning-rate matrix.
     * Every Q entry is set to {@code initialQ} and every learning rate to 0.1.
     *
     * @param stateCount  number of states (rows)
     * @param actionCount number of actions (columns)
     * @param initialQ    value assigned to every Q entry
     */
    public QModel(int stateCount, int actionCount, double initialQ){
        this.stateCount = stateCount;
        this.actionCount = actionCount;
        Q = new Matrix(stateCount,actionCount);
        alphaMatrix = new Matrix(stateCount, actionCount);
        Q.setAll(initialQ);
        alphaMatrix.setAll(0.1);
    }

    /**
     * Creates a model whose Q entries all start at 0.1.
     *
     * @param stateCount  number of states (rows)
     * @param actionCount number of actions (columns)
     */
    public QModel(int stateCount, int actionCount){
        this(stateCount, actionCount, 0.1);
    }

    /**
     * Creates an empty model: both matrices are {@code null} and the counts are 0 until they are
     * populated, e.g. through {@link #copy(QModel)}.
     */
    public QModel(){

    }

    /**
     * Two models are equal when their discount factor, state/action counts, Q matrix and
     * learning-rate matrix are all equal (two {@code null} matrices count as equal).
     * The discount factor is compared with {@link Double#compare(double, double)} so that the
     * relation stays reflexive even for NaN and agrees with {@link #hashCode()}.
     *
     * @param rhs object to compare against
     * @return {@code true} when {@code rhs} is a {@code QModel} with the same content
     */
    @Override
    public boolean equals(Object rhs){
        if(this == rhs) return true;
        // instanceof already rejects null
        if(!(rhs instanceof QModel)) return false;

        QModel other = (QModel)rhs;
        return Double.compare(gamma, other.gamma) == 0
                && stateCount == other.stateCount
                && actionCount == other.actionCount
                && Objects.equals(Q, other.Q)
                && Objects.equals(alphaMatrix, other.alphaMatrix);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. Only the scalar fields take part: the
     * matrices are left out because this class cannot guarantee that their type overrides
     * {@code hashCode} consistently with its {@code equals}.
     *
     * @return hash of the discount factor and the state/action counts
     */
    @Override
    public int hashCode(){
        return Objects.hash(gamma, stateCount, actionCount);
    }

    /**
     * Creates a new model and fills it through {@link #copy(QModel)}.
     *
     * @return a copy of this model
     */
    public QModel makeCopy(){
        QModel clone = new QModel();
        clone.copy(this);
        return clone;
    }

    /**
     * Overwrites this model with the content of {@code rhs}. Non-null matrices are duplicated
     * with their own {@code makeCopy()}; null matrices are copied as null.
     *
     * @param rhs model to copy from
     */
    public void copy(QModel rhs){
        gamma = rhs.gamma;
        stateCount = rhs.stateCount;
        actionCount = rhs.actionCount;
        Q = rhs.Q==null ? null : rhs.Q.makeCopy();
        alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy();
    }


    /**
     * @param stateId  row index into the Q matrix
     * @param actionId column index into the Q matrix
     * @return the Q value stored for the pair
     */
    public double getQ(int stateId, int actionId){
        return Q.get(stateId, actionId);
    }


    /**
     * Stores a Q value for a (state, action) pair.
     *
     * @param stateId  row index into the Q matrix
     * @param actionId column index into the Q matrix
     * @param Qij      value to store
     */
    public void setQ(int stateId, int actionId, double Qij){
        Q.set(stateId, actionId, Qij);
    }


    /**
     * @param stateId  row index into the learning-rate matrix
     * @param actionId column index into the learning-rate matrix
     * @return the learning rate stored for the pair
     */
    public double getAlpha(int stateId, int actionId){
        return alphaMatrix.get(stateId, actionId);
    }


    /**
     * Sets every entry of the learning-rate matrix to the same value.
     *
     * @param defaultAlpha learning rate applied to all (state, action) pairs
     */
    public void setAlpha(double defaultAlpha) {
        this.alphaMatrix.setAll(defaultAlpha);
    }


    /**
     * Looks up the best action for a state by delegating to {@code indexWithMaxValue} on the
     * state's row of the Q matrix.
     *
     * @param stateId        state whose Q row is inspected
     * @param actionsAtState candidate action ids, passed through unchanged to the row vector
     * @return whatever {@code indexWithMaxValue} reports for the candidates
     */
    public IndexValue actionWithMaxQAtState(int stateId, Set<Integer> actionsAtState){
        Vec rowVector = Q.rowAt(stateId);
        return rowVector.indexWithMaxValue(actionsAtState);
    }

    /** Sets every Q entry to {@code initialQ}. */
    private void reset(double initialQ){
        Q.setAll(initialQ);
    }


    /**
     * Randomly picks an action for a state with probability proportional to its Q value
     * (cumulative-sum / roulette-wheel selection). Despite the method name, the Q values are
     * used as weights directly and are not exponentiated.
     *
     * <p>If no cumulative weight reaches the drawn threshold, e.g. because there are no candidate
     * actions or because negative Q values make the cumulative sums non-monotonic, a
     * default-constructed {@link IndexValue} is returned.
     *
     * @param stateId        state whose Q row supplies the weights
     * @param actionsAtState candidate action ids; {@code null} means all ids in
     *                       {@code [0, actionCount)}
     * @param random         source of randomness
     * @return the selected action id together with its Q value
     */
    public IndexValue actionWithSoftMaxQAtState(int stateId, Set<Integer> actionsAtState, Random random) {
        Vec rowVector = Q.rowAt(stateId);

        // Snapshot the candidates into an indexable list; the caller's set is never modified.
        List<Integer> actions;
        if(actionsAtState == null){
            actions = new ArrayList<>();
            for(int i=0; i < actionCount; ++i){
                actions.add(i);
            }
        } else {
            actions = new ArrayList<>(actionsAtState);
        }

        // acc[i] is the running total of the Q values of actions[0..i].
        double sum = 0;
        double[] acc = new double[actions.size()];
        for(int i=0; i < actions.size(); ++i){
            sum += rowVector.get(actions.get(i));
            acc[i] = sum;
        }

        double r = random.nextDouble() * sum;

        // The first action whose cumulative weight reaches the threshold wins.
        IndexValue result = new IndexValue();
        for(int i=0; i < actions.size(); ++i){
            if(acc[i] >= r){
                int actionId = actions.get(i);
                result.setIndex(actionId);
                result.setValue(rowVector.get(actionId));
                break;
            }
        }

        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy