ai.libs.jaicore.search.algorithms.mdp.mcts.uct.UCBPolicy

package ai.libs.jaicore.search.algorithms.mdp.mcts.uct;

import java.util.Map;
import java.util.Map.Entry;

import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;

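/**
 * Tree policy that scores actions with the UCB1 rule: the empirical mean reward of an action
 * plus an exploration bonus of c * sqrt(ln(n) / n_a), where n is the number of visits to the
 * node, n_a the number of times the action has been pulled, and c the exploration constant
 * (sqrt(2) by default). When minimizing instead of maximizing, the exploration bonus is negated.
 */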
public class UCBPolicy<T, A> extends AUpdatingPolicy<T, A> implements ILoggingCustomizable {

	private String loggerName;
	private Logger logger = LoggerFactory.getLogger(UCBPolicy.class);
	private double explorationConstant;

	public UCBPolicy(final double gamma, final double explorationConstant, final boolean maximize) {
		super(gamma, maximize);
		this.explorationConstant = explorationConstant;
	}

	public UCBPolicy(final double gamma, final boolean maximize) {
		this(gamma, Math.sqrt(2), maximize);
	}

	@Override
	public String getLoggerName() {
		return this.loggerName;
	}

	@Override
	public void setLoggerName(final String name) {
		this.loggerName = name;
		super.setLoggerName(name + "._updating");
		this.logger = LoggerFactory.getLogger(name);
	}

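	/**
	 * Returns the average reward observed for the given action in the given node. If the node has
	 * no label yet or the action has never been pulled, the worst possible value is returned
	 * (-Double.MAX_VALUE when maximizing, +Double.MAX_VALUE when minimizing).
	 */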
	public double getEmpiricalMean(final T node, final A action) {
		NodeLabel nodeLabel = this.getLabelOfNode(node);
		if (nodeLabel == null || nodeLabel.getNumPulls(action) == 0) {
			return (this.isMaximize() ? -1 : 1) * Double.MAX_VALUE;
		}
		int timesThisActionHasBeenChosen = nodeLabel.getNumPulls(action);
		return nodeLabel.getAccumulatedRewardsOfAction(action) / timesThisActionHasBeenChosen;
	}

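	/**
	 * Returns the exploration bonus c * sqrt(ln(visits) / pulls) for the given action, negated when
	 * minimizing. Actions without a label or without any pulls again receive the worst possible value.
	 */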
	public double getExplorationTerm(final T node, final A action) {
		NodeLabel nodeLabel = this.getLabelOfNode(node);
		if (nodeLabel == null || nodeLabel.getNumPulls(action) == 0) {
			return (this.isMaximize() ? -1 : 1) * Double.MAX_VALUE;
		}
		int timesThisActionHasBeenChosen = nodeLabel.getNumPulls(action);
		return (this.isMaximize() ? 1 : -1) * this.explorationConstant * Math.sqrt(Math.log(nodeLabel.getVisits()) / timesThisActionHasBeenChosen);
	}

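	/**
	 * Computes the UCB1 score of the action, i.e., the empirical mean reward plus the signed
	 * exploration term. Virgin actions (never pulled in this node) receive the worst possible score.
	 */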
	@Override
	public double getScore(final T node, final A action) {
		NodeLabel nodeLabel = this.getLabelOfNode(node);
		if (nodeLabel == null || nodeLabel.isVirgin(action)) {
			return (this.isMaximize() ? -1 : 1) * Double.MAX_VALUE;
		}
		int timesThisActionHasBeenChosen = nodeLabel.getNumPulls(action);
		double averageScoreForThisAction = nodeLabel.getAccumulatedRewardsOfAction(action) / timesThisActionHasBeenChosen;
		double explorationTerm = (this.isMaximize() ? 1 : -1) * this.explorationConstant * Math.sqrt(Math.log(nodeLabel.getVisits()) / timesThisActionHasBeenChosen);
		double score = averageScoreForThisAction + explorationTerm;
		this.logger.trace("Computed UCB score {} = {} + {} * {} * sqrt(log({})/{}). That is, exploration term is {}", score, averageScoreForThisAction, this.isMaximize() ? 1 : -1, this.explorationConstant, nodeLabel.getVisits(),
				timesThisActionHasBeenChosen, explorationTerm);
		return score;
	}

	public double getExplorationConstant() {
		return this.explorationConstant;
	}

	public void setExplorationConstant(final double explorationConstant) {
		this.explorationConstant = explorationConstant;
	}

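	/**
	 * Picks the action with the best score from the given map: the highest score when maximizing,
	 * the lowest when minimizing. An empty score map is rejected with an exception.
	 */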
	@Override
	public A getActionBasedOnScores(final Map<A, Double> scores) {
		A choice = null;
		if (scores.isEmpty()) {
			throw new IllegalArgumentException("An empty set of scored actions has been given to UCB to decide!");
		}
		this.logger.debug("Getting action for scores {}", scores);
		double best = (this.isMaximize() ? (-1) : 1) * Double.MAX_VALUE;
		for (Entry<A, Double> entry : scores.entrySet()) {
			A action = entry.getKey();
			double score = entry.getValue();
			if (choice == null || (this.isMaximize() && (score > best) || !this.isMaximize() && (score < best))) {
				this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", choice, action, best);
				best = score;
				choice = action;
			} else {
				this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", action, score, best);
			}
		}
		if (choice == null) {
			throw new IllegalStateException("UCB would return NULL action, which must not be the case!");
		}
		return choice;
	}
}
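
Below is a minimal usage sketch (not part of the original source). It exercises only the constructors, the exploration-constant accessors, and the logger hook defined in this file; the node labels and reward updates that drive getScore are maintained by the surrounding AUpdatingPolicy/MCTS machinery and are not shown. The node/action type parameters (String, String), the class name UCBPolicyExample, and the logger name are placeholders.

import ai.libs.jaicore.search.algorithms.mdp.mcts.uct.UCBPolicy;

public class UCBPolicyExample {

	public static void main(final String[] args) {
		// Maximizing policy with gamma (discount factor) 1.0 and the default exploration constant sqrt(2).
		UCBPolicy<String, String> policy = new UCBPolicy<>(1.0, true);
		policy.setLoggerName("example.mcts.ucb");

		// The exploration constant can also be chosen explicitly via the three-argument constructor ...
		UCBPolicy<String, String> greedier = new UCBPolicy<>(0.99, 0.5, true);
		// ... or adjusted later through the setter.
		greedier.setExplorationConstant(0.25);

		System.out.println("Default exploration constant: " + policy.getExplorationConstant());
	}
}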