All downloads are FREE. Search and download functionality uses the official Maven repository.

ai.libs.jaicore.search.algorithms.mdp.mcts.spuct.SPUCBPolicy Maven / Gradle / Ivy

package ai.libs.jaicore.search.algorithms.mdp.mcts.spuct;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import ai.libs.jaicore.search.algorithms.mdp.mcts.uct.UCBPolicy;

public class SPUCBPolicy extends UCBPolicy implements ILoggingCustomizable {
	private String loggerName;
	private Logger logger = LoggerFactory.getLogger(SPUCBPolicy.class);
	private final double bigD;
	private Map, Double> squaredObservations = new HashMap<>();

	public SPUCBPolicy(final double gamma, final double bigD) {
		this(gamma, true, bigD);
	}

	public SPUCBPolicy(final double gamma, final boolean maximize, final double bigD) {
		super(gamma, maximize);
		this.bigD = bigD;
	}

	@Override
	public String getLoggerName() {
		return this.loggerName;
	}

	@Override
	public void setLoggerName(final String name) {
		this.loggerName = name;
		super.setLoggerName(name + "._updating");
		this.logger = LoggerFactory.getLogger(name);
	}

	@Override
	public void updatePath(final ILabeledPath path, final List scores) {
		super.updatePath(path, scores); // careful! the visits stats has already been updated here!
		List nodes = path.getNodes();
		int l = nodes.size();
		double accumulatedScores = 0;
		for (int i = l - 2; i >= 0; i--) {
			NodeLabel nl = this.getLabelOfNode(nodes.get(i));
			if (!Double.isNaN(accumulatedScores) && scores.get(i) != null) {
				accumulatedScores = scores.get(i) + this.getGamma() * accumulatedScores;
			}
			else if (!Double.isNaN(accumulatedScores)) {
				accumulatedScores = Double.NaN;
			}
			this.squaredObservations.put(nl, this.squaredObservations.computeIfAbsent(nl, label -> 0.0) + Math.pow(accumulatedScores, 2));
		}
	}

	@Override
	public double getScore(final N node, final A action) {

		/* get ucb term */
		double ucbMean = super.getEmpiricalMean(node, action);
		double ucbExploration = super.getEmpiricalMean(node, action);
		double ucb = ucbMean + ucbExploration;

		/* get single player term added */
		NodeLabel labelOfNode = this.getLabelOfNode(node);
		int visitsOfChild = labelOfNode.getNumPulls(action); // the t-parameter in the paper
		double squaredResults = this.squaredObservations.containsKey(labelOfNode) ? this.squaredObservations.get(labelOfNode) : 0.0;
		double expectedResults = visitsOfChild * Math.pow(ucbMean, 2);
		double spTerm = (this.isMaximize() ? 1 : -1) * Math.sqrt((squaredResults - expectedResults + this.bigD) / visitsOfChild);
		double score = ucb + spTerm;
		this.logger.debug("Computed score for action {}: {} = {} + {}", action, score, ucb, spTerm);
		return score;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy