All downloads are free. Search and download functionality uses the official Maven repository.

ai.libs.jaicore.search.gui.plugins.mcts.dng.DNGMCTSPluginModel Maven / Gradle / Ivy

package ai.libs.jaicore.search.gui.plugins.mcts.dng;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

import ai.libs.jaicore.graphvisualizer.plugin.ASimpleMVCPluginModel;

/**
 *
 * @author fmohr
 *
 * @param 
 *            The node type class.
 */
public class DNGMCTSPluginModel extends ASimpleMVCPluginModel {

	private String currentlySelectedNode = "0";
	private final Map parents = new HashMap<>();
	private final Map> listsOfKnownSuccessors = new HashMap<>();
	private final Map> listOfObersvationsPerNode = new HashMap<>();
	private final Map>> observedQValues = new HashMap<>();
	private final Map> observedUpdates = new HashMap<>();

	@Override
	public void clear() {
		this.getView().clear();
	}

	public void setCurrentlySelectedNode(final String currentlySelectedNode) {
		this.currentlySelectedNode = currentlySelectedNode;
		this.getView().clear();
		this.getView().update();
	}

	public String getCurrentlySelectedNode() {
		return this.currentlySelectedNode;
	}

	public void addObservation(final String node, final double score) {
		this.listOfObersvationsPerNode.computeIfAbsent(node, n -> new ArrayList<>()).add(score);
	}

	public void setNodeStats(final DNGQSample update) {
		if (update == null) {
			throw new IllegalArgumentException("Cannot process NULL update");
		}
		String node = update.getNode();
		if (!this.listsOfKnownSuccessors.containsKey(node)) {
			throw new IllegalArgumentException("Cannot receive update for an unknown node. Make sure that Rollout events are processed!");
		}
		this.observedQValues.computeIfAbsent(node, n -> new HashMap<>()).computeIfAbsent(update.getSuccessor(), n2 -> new ArrayList<>()).add(update.getScore());
		if (node.equals(this.getCurrentlySelectedNode())) {
			this.getView().update();
		}
	}

	public void setNodeStats(final DNGBeliefUpdate update) {
		if (update == null) {
			throw new IllegalArgumentException("Cannot process NULL update");
		}
		String node = update.getNode();
		this.observedUpdates.computeIfAbsent(node, n -> new ArrayList<>()).add(update);
		if (node.equals(this.getCurrentlySelectedNode())) {
			this.getView().update();
		}
	}

	public Map> getQValuesOfNode(final String node) {
		return this.observedQValues.get(node);
	}

	public Map> getQValuesOfSelectedNode() {
		return this.observedQValues.get(this.getCurrentlySelectedNode());
	}

	public Map> getListsOfKnownSuccessors() {
		return this.listsOfKnownSuccessors;
	}

	public List getListOfKnownSuccessorsOfCurrentlySelectedNode() {
		return this.listsOfKnownSuccessors.get(this.getCurrentlySelectedNode());
	}

	public Map getParents() {
		return this.parents;
	}

	public String getParentOfCurrentNode() {
		return this.parents.get(this.getCurrentlySelectedNode());
	}

	public Map> getObservedMuValues() {
		return this.observedUpdates;
	}

	public List getObservedMuValuesOfCurrentlySelectedNode() {
		return this.observedUpdates.get(this.getCurrentlySelectedNode());
	}

	public Map> getListOfObersvationsPerNode() {
		return this.listOfObersvationsPerNode;
	}

	public DescriptiveStatistics getObservationStatisticsOfNode(final String node) {
		DescriptiveStatistics stats = new DescriptiveStatistics();
		for (double val : this.listOfObersvationsPerNode.get(node)) {
			stats.addValue(val);
		}
		return stats;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy