All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.libs.jaicore.search.algorithms.standard.mcts.MCTSPathSearch Maven / Gradle / Ivy

package ai.libs.jaicore.search.algorithms.standard.mcts;

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import org.api4.java.ai.graphsearch.problem.IPathSearchWithPathEvaluationsInput;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.events.result.ISolutionCandidateFoundEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.event.IEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.eventbus.Subscribe;

import ai.libs.jaicore.basic.algorithm.AlgorithmFinishedEvent;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.GraphBasedMDP;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTS;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSFactory;
import ai.libs.jaicore.search.algorithms.mdp.mcts.MCTSIterationCompletedEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.EvaluatedSearchSolutionCandidateFoundEvent;
import ai.libs.jaicore.search.core.interfaces.AOptimalPathInORGraphSearch;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;

/**
 *
 * @author Felix Mohr
 *
 * @param 
 *            Problem type
 * @param 
 *            Type of states (nodes)
 * @param 
 *            Type of actions
 */
public class MCTSPathSearch, N, A> extends AOptimalPathInORGraphSearch {

	private Logger logger = LoggerFactory.getLogger(MCTSPathSearch.class);
	private final GraphBasedMDP mdp;
	private final MCTS mcts;
	private final Set hashCodesOfReturnedPaths = new HashSet<>();

	public MCTSPathSearch(final I problem, final MCTSFactory mctsFactory) {
		super(problem);
		this.mdp = new GraphBasedMDP<>(problem);
		this.mcts = mctsFactory.getAlgorithm(this.mdp);
		this.mcts.registerListener(new Object() {

			@Subscribe
			public void receiveMCTSEvent(final IAlgorithmEvent e) {
				if (!(e instanceof AlgorithmInitializedEvent) && !(e instanceof AlgorithmFinishedEvent)) {
					MCTSPathSearch.this.post(e); // forward everything
				}
			}
		});
	}

	@Override
	public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
		switch (this.getState()) {
		case CREATED:

			/* initialize MCTS */
			this.mdp.setLoggerName(this.getLoggerName() + ".mdp");
			IEvent mctsInitEvent;
			do {
				mctsInitEvent = this.mcts.next();
			} while (!(mctsInitEvent instanceof AlgorithmInitializedEvent));
			return this.activate();
		case ACTIVE:

			/* if MCTS has finished, terminate */
			if (this.mcts.getState() != EAlgorithmState.ACTIVE) {
				return this.terminate();
			}

			/* keep asking for playouts until a playout is returned that is a solution path */
			IAlgorithmEvent e;
			while (!((e = this.mcts.nextWithException()) instanceof AlgorithmFinishedEvent)) {

				/* if this is not an event that declares the completion of an iteration, ignore it */
				if (!(e instanceof MCTSIterationCompletedEvent)) {
					continue;
				}

				/* form a path object and return a respective event */
				MCTSIterationCompletedEvent ce = (MCTSIterationCompletedEvent) e;
				double overallScore = SetUtil.sum(ce.getScores());
				this.logger.info("Registered rollout with score {}. Updating best seen solution correspondingly.", overallScore);
				EvaluatedSearchGraphPath path = new EvaluatedSearchGraphPath<>(ce.getRollout(), overallScore);

				/* only if the roll-out is a goal path, emit a success event */
				if (this.getGoalTester().isGoal(path)) {
					this.updateBestSeenSolution(path);
					int hashCode = path.hashCode();
					if (this.hashCodesOfReturnedPaths.contains(hashCode)) {
						this.logger.info("Skipping (and supressing) previously found solution with hash code {}", hashCode);
						continue;
					}
					this.hashCodesOfReturnedPaths.add(hashCode);
					ISolutionCandidateFoundEvent> event = new EvaluatedSearchSolutionCandidateFoundEvent<>(this, path);
					this.post(event);
					return event;
				}
			}
			return this.terminate();

		default:
			throw new IllegalStateException();
		}
	}

	@Override
	public void setTimeout(final Timeout to) {
		long toInSeconds = to.seconds();
		if (toInSeconds < 2) {
			throw new IllegalArgumentException("Cannot run MCTS with a timeout of less than 2 seconds.");
		}
		super.setTimeout(to);
		this.mcts.setTimeout(new Timeout(to.seconds() - 1, TimeUnit.SECONDS));
	}

	@Override
	public void cancel() {
		super.cancel();
		this.mcts.cancel(); // forwarding cancel
	}

	@Override
	public void setLoggerName(final String name) {
		super.setLoggerName(name + "._algorithm");
		this.logger = LoggerFactory.getLogger(name);
		this.mcts.setLoggerName(name + ".mcts");
	}

	@Override
	public String getLoggerName() {
		return this.logger.getName();
	}

	public IMDP getMdp() {
		return this.mdp;
	}

	public MCTS getMcts() {
		return this.mcts;
	}

	public int getNumberOfNodesInMemory() {
		return this.mcts.getNumberOfNodesInMemory();
	}
}