All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy Maven / Gradle / Ivy

package com.github.chen0040.rl.actionselection;

import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;


/**
 * Created by xschen on 9/27/2015 0027.
 */
public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy {
    public static final String EPSILON = "epsilon";
    private Random random = new Random();

    @Override
    public Object clone(){
        EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy();
        clone.copy(this);
        return clone;
    }

    public void copy(EpsilonGreedyActionSelectionStrategy rhs){
        random = rhs.random;
        for(Map.Entry entry : rhs.attributes.entrySet()){
            attributes.put(entry.getKey(), entry.getValue());
        }
    }

    @Override
    public boolean equals(Object obj){
        if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){
            EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj;
            if(epsilon() != rhs.epsilon()) return false;
           // if(!random.equals(rhs.random)) return false;
            return true;
        }
        return false;
    }

    private double epsilon(){
        return Double.parseDouble(attributes.get(EPSILON));
    }

    public EpsilonGreedyActionSelectionStrategy(){
        epsilon(0.1);
    }

    public EpsilonGreedyActionSelectionStrategy(HashMap attributes){
        super(attributes);
    }

    private void epsilon(double value){
        attributes.put(EPSILON, "" + value);
    }

    public EpsilonGreedyActionSelectionStrategy(Random random){
        this.random = random;
        epsilon(0.1);
    }

    @Override
    public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
        if(random.nextDouble() < 1- epsilon()){
            return model.actionWithMaxQAtState(stateId, actionsAtState);
        }else{
            int actionId = random.nextInt(model.getActionCount());
            double Q = model.getQ(stateId, actionId);
            return new IndexValue(actionId, Q);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy