All downloads are free. Search and download functionality uses the official Maven repository.

com.github.cschen1205.navigator.minefield.agents.TDFalconNavAgent Maven / Gradle / Ivy

package com.github.cschen1205.navigator.minefield.agents;

import com.github.cschen1205.falcon.*;
import com.github.cschen1205.navigator.minefield.env.MineField;

import java.util.Set;

/**
 * Minefield navigation agent driven by a temporal-difference FALCON learner ({@link TDFalcon}).
 *
 * <p>The {@code state}, {@code actions}, {@code newState} and {@code reward} fields used here are
 * not declared in this class; they come from {@link FalconNavAgent}. This class feeds them into
 * {@link TDFalcon#learnQ} and {@link TDFalcon#selectActionId}. It also supplies a
 * {@link QValueProvider} that injects fixed Q-values for mine/target outcomes reported by the
 * {@link MineField}.
 *
 * <p>Originally created by cschen1205 on 9/29/2015.
 */
public class TDFalconNavAgent extends FalconNavAgent {
    /**
     * Offset subtracted from a FALCON action id before it is passed to
     * {@link MineField#willHitMine} / {@link MineField#willHitTarget}.
     * NOTE(review): presumably maps learner action ids onto the maze's move indices — confirm.
     */
    private static final int MAZE_ACTION_OFFSET = 2;

    /** The underlying learner; replaced by {@link #enableEligibilityTrace()}. */
    private TDFalcon ai;

    /**
     * Discount factor explicitly requested through {@link #setQGamma(double)}, or {@code null}
     * if it was never called. Kept here so the value survives {@link #enableEligibilityTrace()},
     * which swaps in a new learner instance.
     */
    private Double qGammaOverride;

    /**
     * When {@code true}, every Q-value query answered by {@link #createQInject(MineField)}
     * returns the agent's current {@code reward} instead of checking mine/target outcomes.
     */
    public boolean useImmediateRewardAsQ;

    /**
     * Creates an agent backed by a {@link TDFalcon} built from {@code config}.
     *
     * @param config          configuration for the underlying learner
     * @param id              agent id, forwarded to {@link FalconNavAgent}
     * @param numSonarInput   forwarded unchanged to {@link FalconNavAgent}
     * @param numAVSonarInput forwarded unchanged to {@link FalconNavAgent}
     * @param numBearingInput forwarded unchanged to {@link FalconNavAgent}
     * @param numRangeInput   forwarded unchanged to {@link FalconNavAgent}
     */
    public TDFalconNavAgent(FalconConfig config, int id, int numSonarInput, int numAVSonarInput, int numBearingInput, int numRangeInput) {
        super(id, numSonarInput, numAVSonarInput, numBearingInput, numRangeInput);
        ai = new TDFalcon(config);
    }

    /**
     * Creates an agent backed by a {@link TDFalcon} built from {@code config} using the given
     * TD {@code method}.
     *
     * @param config          configuration for the underlying learner
     * @param id              agent id, forwarded to {@link FalconNavAgent}
     * @param method          TD update method passed to the {@link TDFalcon} constructor
     * @param numSonarInput   forwarded unchanged to {@link FalconNavAgent}
     * @param numAVSonarInput forwarded unchanged to {@link FalconNavAgent}
     * @param numBearingInput forwarded unchanged to {@link FalconNavAgent}
     * @param numRangeInput   forwarded unchanged to {@link FalconNavAgent}
     */
    public TDFalconNavAgent(FalconConfig config, int id, TDMethod method, int numSonarInput, int numAVSonarInput, int numBearingInput, int numRangeInput) {
        super(id, numSonarInput, numAVSonarInput, numBearingInput, numRangeInput);
        ai = new TDFalcon(config, method);
    }

    /** Delegates to {@link TDFalcon#decayQEpsilon()} on the current learner. */
    public void decayQEpsilon() {
        ai.decayQEpsilon();
    }

    /**
     * Performs one Q-learning update for the last transition
     * ({@code state} --{@code actions}--&gt; {@code newState}, with {@code reward}).
     *
     * @param maze the environment, used for the feasible actions at the new state and for
     *             injected terminal Q-values
     */
    @Override
    public void learn(final MineField maze) {
        // Raw Set kept to match the (not visible) declared type of getFeasibleActions.
        Set feasibleActionAtNewState = getFeasibleActions(maze);
        ai.learnQ(state, actions, newState, feasibleActionAtNewState, reward, createQInject(maze));
    }

    /**
     * Builds the Q-value injector passed to the learner.
     *
     * <p>The returned provider answers queries as follows:
     * <ul>
     *   <li>If {@link #useImmediateRewardAsQ} is set, it returns the agent's current {@code reward}.
     *       That field is read when the query runs, not when the provider is created.</li>
     *   <li>For a prospective action ({@code isNextAction}), it returns 0.0 if the move would hit
     *       a mine, or 1.0 if it would hit the target.</li>
     *   <li>Otherwise, it returns 0.0 if the agent has already hit a mine, or 1.0 if it has
     *       reached the target.</li>
     *   <li>In every other case it returns {@link QValue#Invalid()}.</li>
     * </ul>
     *
     * @param maze environment queried for mine/target outcomes
     * @return a provider bound to this agent and {@code maze}
     */
    protected QValueProvider createQInject(final MineField maze) {
        return new QValueProvider() {
            @Override
            public QValue queryQValue(double[] queryState, int actionTaken, boolean isNextAction) {
                if (useImmediateRewardAsQ) {
                    return new QValue(reward);
                }

                if (isNextAction) {
                    final int mazeAction = actionTaken - MAZE_ACTION_OFFSET;
                    if (maze.willHitMine(getId(), mazeAction)) {
                        return new QValue(0.0);
                    }
                    if (maze.willHitTarget(getId(), mazeAction)) {
                        return new QValue(1.0);
                    }
                } else {
                    if (maze.isHitMine(getId())) {
                        return new QValue(0.0);
                    }
                    if (maze.isHitTarget(getId())) {
                        return new QValue(1.0); // case: reached target
                    }
                }

                // No terminal outcome known here.
                return QValue.Invalid();
            }
        };
    }

    /**
     * Asks the learner to choose an action among those currently feasible in {@code maze}.
     *
     * @param maze the environment
     * @return the action id chosen by {@link TDFalcon#selectActionId}
     */
    @Override
    public int selectValidAction(final MineField maze) {
        Set feasibleActions = getFeasibleActions(maze);
        return ai.selectActionId(state, feasibleActions, createQInject(maze));
    }

    /** @return the number of entries in the current learner's {@code nodes} collection */
    @Override
    public int getNodeCount(){
        return ai.nodes.size();
    }

    /**
     * Sets the learner's discount factor. The value is remembered and re-applied if
     * {@link #enableEligibilityTrace()} later replaces the learner.
     *
     * @param QGamma the discount factor to assign to the learner
     */
    public void setQGamma(double QGamma) {
        this.qGammaOverride = QGamma;
        this.ai.QGamma = QGamma;
    }

    /**
     * Replaces the current learner with a {@link TDLambdaFalcon}. The new learner uses the same
     * config and TD method.
     *
     * <p>Anything accumulated in the previous learner instance (for example its {@code nodes})
     * is not carried over. A discount factor set explicitly through
     * {@link #setQGamma(double)} is re-applied to the new learner, so the result no longer
     * depends on the order of the two calls.
     */
    public void enableEligibilityTrace(){
        this.ai = new TDLambdaFalcon(ai.getConfig(), ai.method);
        if (qGammaOverride != null) {
            this.ai.QGamma = qGammaOverride;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy