// ai.libs.jaicore.search.algorithms.mdp.mcts.GraphBasedMDP (Maven / Gradle / Ivy)
package ai.libs.jaicore.search.algorithms.mdp.mcts;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;

import org.api4.java.ai.graphsearch.problem.IPathSearchWithPathEvaluationsInput;
import org.api4.java.ai.graphsearch.problem.implicit.graphgenerator.IPathGoalTester;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.PathEvaluationException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.api4.java.datastructure.graph.implicit.ILazySuccessorGenerator;
import org.api4.java.datastructure.graph.implicit.INewNodeDescription;
import org.api4.java.datastructure.graph.implicit.ISingleRootGenerator;
import org.api4.java.datastructure.graph.implicit.ISuccessorGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.model.ILazyRandomizableSuccessorGenerator;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.probleminputs.IMDP;
public class GraphBasedMDP implements IMDP, ILoggingCustomizable {
private static final int MAX_SUCCESSOR_CACHE_SIZE = 100;
private final IPathSearchWithPathEvaluationsInput graph;
private final N root;
private final ISuccessorGenerator succGen;
private final IPathGoalTester goalTester;
private final Map> backPointers = new HashMap<>();
private Logger logger = LoggerFactory.getLogger(GraphBasedMDP.class);
private final Map> successorCache = new HashMap<>();
private final boolean lazy;
private final ILazySuccessorGenerator lazySuccGen;
public GraphBasedMDP(final IPathSearchWithPathEvaluationsInput graph) {
super();
this.graph = graph;
this.root = ((ISingleRootGenerator) this.graph.getGraphGenerator().getRootGenerator()).getRoot();
this.succGen = graph.getGraphGenerator().getSuccessorGenerator();
this.goalTester = graph.getGoalTester();
this.lazySuccGen = this.succGen instanceof ILazySuccessorGenerator ? (ILazySuccessorGenerator) this.succGen : null;
this.lazy = this.lazySuccGen != null;
}
@Override
public N getInitState() {
return this.root;
}
@Override
public boolean isMaximizing() { // path searches are, by definition, always minimization problems
return false;
}
@Override
public Collection getApplicableActions(final N state) throws InterruptedException {
this.logger.debug("Computing applicable actions.");
Collection> successors = this.succGen.generateSuccessors(state);
Collection actions = new ArrayList<>();
Map cache = new HashMap<>();
if (Thread.interrupted()) {
throw new InterruptedException("The computation of applicable actions has been interrupted.");
}
long lastInterruptCheck = System.currentTimeMillis();
for (INewNodeDescription succ : successors) {
A action = succ.getArcLabel();
actions.add(action);
cache.put(action, succ.getTo());
long now = System.currentTimeMillis();
if (lastInterruptCheck < now - 10) {
lastInterruptCheck = now;
if (Thread.interrupted()) {
throw new InterruptedException("The computation of applicable actions has been interrupted.");
}
}
if (this.backPointers.containsKey(succ.getTo())) {
Pair backpointer = this.backPointers.get(succ.getTo());
boolean sameParent = backpointer.getX().equals(state);
boolean sameAction = backpointer.getY().equals(action);
if (!sameParent || !sameAction) {
N otherNode = null;
for (N key : this.backPointers.keySet()) {
now = System.currentTimeMillis();
if (lastInterruptCheck < now - 10) {
lastInterruptCheck = now;
if (Thread.interrupted()) {
throw new InterruptedException("The computation of applicable actions has been interrupted.");
}
}
if (key.equals(succ.getTo())) {
otherNode = key;
break;
}
}
throw new IllegalStateException("Reaching state " + succ.getTo() + " on a second way, which must not be the case in trees!\n\t1st way: " + backpointer.getX() + "; " + backpointer.getY() + "\n\t2nd way: " + state + "; "
+ action + "\n\ttoString of existing node: " + otherNode + "\n\tSame parent: " + sameParent + "\n\tSame Action: " + sameAction);
}
}
this.logger.debug("Setting backpointer from {} to {}", succ.getTo(), state);
this.backPointers.put(succ.getTo(), new Pair<>(state, action));
}
/* clear the cache if we have too many entries */
if (this.successorCache.size() > MAX_SUCCESSOR_CACHE_SIZE) {
this.successorCache.clear();
}
this.successorCache.put(state, cache);
return actions;
}
@Override
public Map getProb(final N state, final A action) throws InterruptedException {
/* first determine the successor node (either by cache or by constructing the successors again) */
N successor = null;
if (this.successorCache.containsKey(state) && this.successorCache.get(state).containsKey(action)) {
successor = this.successorCache.get(state).get(action);
} else {
Optional> succOpt = this.succGen.generateSuccessors(state).stream().filter(nd -> nd.getArcLabel().equals(action)).findAny();
if (!succOpt.isPresent()) {
this.logger.error("THERE IS NO SUCCESSOR REACHABLE WITH ACTION {} IN THE MDP!", action);
return null;
}
successor = succOpt.get().getTo();
}
/* now equip this successor with probability 1 */
Map out = new HashMap<>();
out.put(successor, 1.0);
return out;
}
@Override
public double getProb(final N state, final A action, final N successor) throws InterruptedException {
return this.getProb(state, action).containsKey(successor) ? 1 : 0.0;
}
@Override
public Double getScore(final N state, final A action, final N successor) throws PathEvaluationException, InterruptedException {
/* now build the whole path using the back-pointer map */
this.logger.info("Getting score for SAS-triple ({}, {}, {})", state, action, successor);
N cur = successor;
List nodes = new ArrayList<>();
List arcs = new ArrayList<>();
nodes.add(cur);
while (cur != this.root) {
Pair parentEdge = this.backPointers.get(cur);
if (parentEdge == null) {
throw new NullPointerException("No back pointer defined for non-root node " + cur); // this is INTENTIONALLY not done with Object.requireNonNull, because the string shall not be evaluated otherwise!
}
cur = parentEdge.getX();
nodes.add(0, cur);
arcs.add(0, parentEdge.getY());
}
ILabeledPath path = new SearchGraphPath<>(nodes, arcs);
/* check whether path is a goal path */
if (!this.goalTester.isGoal(path)) { // in the MDP-view of a node, partial paths do not yield a reward but only full paths.
boolean isTerminal = this.isTerminalState(path.getHead());
if (isTerminal) {
this.logger.debug("Found dead end! Returning null.");
return null;
}
this.logger.info("Path {} is not a goal path, returning 0.0", path);
return 0.0;
}
this.logger.info("Path is a goal path, invoking path evaluator.");
double score = this.graph.getPathEvaluator().evaluate(path).doubleValue();
this.logger.info("Obtained score {} for path", score);
return score;
}
@Override
public boolean isTerminalState(final N state) throws InterruptedException {
if (this.lazy) {
this.logger.debug("Determining terminal state condition for lazy graph generator.");
return !this.lazySuccGen.getIterativeGenerator(state).hasNext();
} else {
return this.succGen.generateSuccessors(state).isEmpty();
}
}
@Override
public String getLoggerName() {
return this.logger.getName();
}
@Override
public void setLoggerName(final String name) {
this.logger = LoggerFactory.getLogger(name);
if (this.succGen instanceof ILoggingCustomizable) {
this.logger.info("Setting logger of successor generator to {}.gg", name);
((ILoggingCustomizable) this.succGen).setLoggerName(name + ".gg");
}
if (this.goalTester instanceof ILoggingCustomizable) {
((ILoggingCustomizable) this.goalTester).setLoggerName(name + ".gt");
}
if (this.graph.getPathEvaluator() instanceof ILoggingCustomizable) {
((ILoggingCustomizable) this.graph.getPathEvaluator()).setLoggerName(name + ".pe");
}
}
@Override
public A getUniformlyRandomApplicableAction(final N state, final Random random) throws InterruptedException {
if (this.succGen instanceof ILazyRandomizableSuccessorGenerator) {
INewNodeDescription ne = ((ILazyRandomizableSuccessorGenerator) this.succGen).getIterativeGenerator(state, random).next();
if (this.successorCache.size() > MAX_SUCCESSOR_CACHE_SIZE) {
this.successorCache.clear();
}
this.successorCache.computeIfAbsent(state, n -> new HashMap<>()).put(ne.getArcLabel(), ne.getTo());
this.backPointers.put(ne.getTo(), new Pair<>(state, ne.getArcLabel()));
return ne.getArcLabel();
}
this.logger.debug("The successor generator {} does not support lazy AND randomized successor generation. Now computing all successors and drawing one at random.", this.succGen.getClass());
Collection actions = this.getApplicableActions(state);
if (actions.isEmpty()) {
throw new IllegalArgumentException("The given node has no successors: " + state);
}
return SetUtil.getRandomElement(actions, random);
}
@Override
public boolean isActionApplicableInState(final N state, final A action) throws InterruptedException {
if (this.successorCache.containsKey(state) && this.successorCache.get(state).containsKey(action)) {
return true;
}
return this.getApplicableActions(state).contains(action);
}
}