All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.libs.mlplan.core.MLPlan Maven / Gradle / Ivy

The newest version!
package ai.libs.mlplan.core;

import java.io.IOException;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import org.api4.java.ai.graphsearch.problem.IPathSearchInput;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.IAlgorithm;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IEvent;
import org.api4.java.common.reconstruction.IReconstructible;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.eventbus.Subscribe;

import ai.libs.hasco.builder.TwoPhaseHASCOBuilder;
import ai.libs.hasco.builder.forwarddecomposition.HASCOViaFDBuilder;
import ai.libs.hasco.core.HASCO;
import ai.libs.hasco.core.HASCOSolutionCandidate;
import ai.libs.hasco.core.events.HASCOSolutionEvent;
import ai.libs.hasco.core.events.TwoPhaseHASCOPhaseSwitchEvent;
import ai.libs.hasco.twophase.TwoPhaseHASCO;
import ai.libs.hasco.twophase.TwoPhaseHASCOConfig;
import ai.libs.hasco.twophase.TwoPhaseSoftwareConfigurationProblem;
import ai.libs.jaicore.basic.MathExt;
import ai.libs.jaicore.basic.algorithm.AAlgorithm;
import ai.libs.jaicore.basic.algorithm.AlgorithmFinishedEvent;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.algorithm.EAlgorithmState;
import ai.libs.jaicore.basic.reconstruction.ReconstructionUtil;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.components.api.IComponentInstance;
import ai.libs.jaicore.components.exceptions.ComponentInstantiationFailedException;
import ai.libs.jaicore.components.optimizingfactory.OptimizingFactory;
import ai.libs.jaicore.components.optimizingfactory.OptimizingFactoryProblem;
import ai.libs.jaicore.components.serialization.ComponentSerialization;
import ai.libs.jaicore.ml.core.dataset.DatasetUtil;
import ai.libs.jaicore.ml.core.evaluation.evaluator.factory.LearnerEvaluatorConstructionFailedException;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.mlplan.core.events.ClassifierFoundEvent;
import ai.libs.mlplan.core.events.MLPlanPhaseSwitchedEvent;
import ai.libs.mlplan.multiclass.IMLPlanClassifierConfig;

public class MLPlan>> extends AAlgorithm, L> implements ILoggingCustomizable {

	/** Logger for controlled output. */
	private Logger logger = LoggerFactory.getLogger(MLPlan.class);
	private String loggerName;

	private final ComponentSerialization serializer = new ComponentSerialization();
	private L selectedClassifier;
	private double internalValidationErrorOfSelectedClassifier;
	private IComponentInstance componentInstanceOfSelectedClassifier;

	private final IMLPlanBuilder builder;
	private TwoPhaseHASCOBuilder twoPhaseHASCOFactory;
	private OptimizingFactory, Double> optimizingFactory;

	private PipelineEvaluator classifierEvaluatorForSearch;
	private PipelineEvaluator classifierEvaluatorForSelection;

	private boolean buildSelectedClasifierOnGivenData = true;
	private final long seed;

	private long timestampAlgorithmStart;
	private boolean maintainReconstructibility = true;

	protected MLPlan(final IMLPlanBuilder builder, final ILabeledDataset data) { // ML-Plan has a package visible constructor, because it should only be constructed using a builder
		super(builder.getAlgorithmConfig(), data);

		/* sanity checks */
		if (builder.getSearchSpaceConfigFile() == null || !builder.getSearchSpaceConfigFile().exists()) {
			throw new IllegalArgumentException("The search space configuration file must be set in MLPlanBuilder, and it must be set to a file that exists!");
		}
		Objects.requireNonNull(builder.getLearnerFactory(), "The learner factory must be set in MLPlanBuilder!");
		if (builder.getRequestedInterface() == null || builder.getRequestedInterface().isEmpty()) {
			throw new IllegalArgumentException("No requested HASCO interface defined!");
		}

		if (this.getConfig().getTimeout().seconds() <= this.getConfig().precautionOffset()) {
			throw new IllegalArgumentException("Illegal timeout configuration. The precaution offset must be strictly smaller than the specified timeout.");
		}

		/* store builder and data for main algorithm */
		this.builder = builder;
		Objects.requireNonNull(this.getInput());
		if (this.getInput().isEmpty()) {
			throw new IllegalArgumentException("Cannot run ML-Plan on empty dataset.");
		}
		this.seed = this.builder.getAlgorithmConfig().seed();
		if (this.getInput() instanceof IReconstructible) {
			this.maintainReconstructibility = ReconstructionUtil.areInstructionsNonEmptyIfReconstructibilityClaimed(this.getInput());
			if (!this.maintainReconstructibility) {
				this.logger.warn("The dataset claims to be reconstructible, but it does not carry any instructions. ML-Plan will not add reconstruction instructions.");
			}
		} else {
			this.maintainReconstructibility = false;
		}
	}

	@Override
	public IAlgorithmEvent nextWithException() throws AlgorithmException, InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException {
		switch (this.getState()) {
		case CREATED:
			this.setTimeoutPrecautionOffset(1000); // for this routine, only consider a precaution of 1s
			this.logger.info("Starting an ML-Plan instance. Timeout precaution is {}ms", this.getTimeoutPrecautionOffset());
			this.timestampAlgorithmStart = System.currentTimeMillis();
			this.setDeadline(); // algorithm execution starts NOW, set deadline

			/* check number of CPUs assigned */
			if (this.getConfig().cpus() < 1) {
				throw new IllegalStateException("Cannot generate search where number of CPUs is " + this.getConfig().cpus());
			}

			/* set up exact splits */
			double portionForSelection = this.getConfig().dataPortionForSelection();
			Pair, ILabeledDataset> split = MLPlanUtil.getDataForSearchAndSelection(this.getInput(), portionForSelection, new Random(this.getConfig().seed()),
					this.builder.getSearchSelectionDatasetSplitter(), this.logger);
			ILabeledDataset dataShownToSearch = split.getX();
			ILabeledDataset dataShownToSelection = split.getY();

			/* check that class proportions are maintained */
			if (this.logger.isDebugEnabled()) {
				this.logger.debug("Class distribution is {}. Original class distribution was {}", DatasetUtil.getLabelCounts(dataShownToSearch), DatasetUtil.getLabelCounts(this.getInput()));
			}

			/* check that reconstructibility is preserved */
			if (this.maintainReconstructibility && ((IReconstructible) dataShownToSearch).getConstructionPlan().getInstructions().isEmpty()) {
				throw new IllegalStateException("Reconstructibility instructions have been lost in search/selection-split!");
			}

			/* dynamically compute blow-ups */
			if (Double.isNaN(this.getConfig().expectedBlowupInSelection())) {
				double blowUpInSelectionPhase = (1 + portionForSelection) * 1.5; // assume super-linear runime increase
				this.getConfig().setProperty(TwoPhaseHASCOConfig.K_BLOWUP_SELECTION, String.valueOf(blowUpInSelectionPhase));
				this.logger.info("No expected blow-up for selection phase has been defined. Automatically configuring {}", blowUpInSelectionPhase);
			}
			if (!this.buildSelectedClasifierOnGivenData) {
				this.getConfig().setProperty(TwoPhaseHASCOConfig.K_BLOWUP_POSTPROCESS, String.valueOf(0));
				this.logger.info("Selected classifier won't be built, so now blow-up is calculated.");
			} else if (Double.isNaN(this.getConfig().expectedBlowupInPostprocessing())) {
				double blowUpInPostprocessing = ((1.0 +  portionForSelection) / 0.8)  * 1.5; // the 1.5 are for a supposed super-linear runtime increase
				this.getConfig().setProperty(TwoPhaseHASCOConfig.K_BLOWUP_POSTPROCESS, String.valueOf(blowUpInPostprocessing));
				this.logger.info("No expected blow-up for postprocessing phase has been defined. Automatically configuring {}", blowUpInPostprocessing);
			}

			/* setup the pipeline evaluators */
			this.logger.debug("Setting up the pipeline evaluators.");
			Pair evaluators;
			try {
				evaluators = MLPlanUtil.getPipelineEvaluators(this.builder.getLearnerEvaluationFactoryForSearchPhase(), this.builder.getMetricForSearchPhase(), this.builder.getLearnerEvaluationFactoryForSelectionPhase(),
						this.builder.getMetricForSelectionPhase(), new Random(this.seed), dataShownToSearch, dataShownToSelection, this.builder.getSafeGuardFactory(), this.builder.getLearnerFactory(),
						this.getConfig().getTimeoutForCandidateEvaluation());
			} catch (LearnerEvaluatorConstructionFailedException e2) {
				throw new AlgorithmException("Could not create the evaluators.", e2);
			}
			this.classifierEvaluatorForSearch = evaluators.getX();
			this.classifierEvaluatorForSelection = evaluators.getY();
			this.classifierEvaluatorForSearch.registerListener(this); // events will be forwarded
			if (this.classifierEvaluatorForSearch.getSafeGuard() != null) {
				this.classifierEvaluatorForSearch.getSafeGuard().registerListener(this);
			}
			if (this.classifierEvaluatorForSelection != null) {
				this.classifierEvaluatorForSelection.registerListener(this); // events will be forwarded
			}

			/* communicate the parameters with which ML-Plan will run */
			if (this.logger.isInfoEnabled()) {
				this.logger.info(
						"Starting ML-Plan with the following setup:\n\tDataset: {}\n\tCPUs: {}\n\tTimeout: {}s\n\tRemaining Time after initialization: {}s\n\tTimeout Precaution Offset: {}s\n\tTimeout for single candidate evaluation: {}s\n\tTimeout for node evaluation: {}s\n\tRandom Completions per node evaluation: {}\n\tPortion of data for selection phase: {}%\n\tData points used during search: {}\n\tData points used during selection: {}\n\tPipeline evaluation during search: {}\n\tPipeline evaluation during selection: {}\n\tBlow-ups are {} for selection phase and {} for post-processing phase.",
						this.getInput().getRelationName(), this.getConfig().cpus(), this.getTimeout().seconds(), this.getRemainingTimeToDeadline().seconds(), this.getConfig().precautionOffset(),
						this.getConfig().timeoutForCandidateEvaluation() / 1000, this.getConfig().timeoutForNodeEvaluation() / 1000, this.getConfig().numberOfRandomCompletions(),
						MathExt.round(this.getConfig().dataPortionForSelection() * 100, 2), dataShownToSearch.size(), dataShownToSelection != null ? dataShownToSelection.size() : 0, this.classifierEvaluatorForSearch.getBenchmark(),
								this.classifierEvaluatorForSelection != null ? this.classifierEvaluatorForSelection.getBenchmark() : null, this.getConfig().expectedBlowupInSelection(), this.getConfig().expectedBlowupInPostprocessing());
			}

			/* create 2-phase software configuration problem */
			this.logger.debug("Creating 2-phase software configuration problem.");
			TwoPhaseSoftwareConfigurationProblem problem = null;
			try {
				problem = new TwoPhaseSoftwareConfigurationProblem(this.builder.getSearchSpaceConfigFile(), this.builder.getRequestedInterface(), this.classifierEvaluatorForSearch, this.classifierEvaluatorForSelection);
			} catch (IOException e1) {
				throw new AlgorithmException("Could not create the 2-phase configuration problem with search space file \"" + this.builder.getSearchSpaceConfigFile() + "\" and required interface " + this.builder.getRequestedInterface(),
						e1);
			}

			/* create 2-phase HASCO */
			this.logger.info("Creating the twoPhaseHASCOFactory.");
			OptimizingFactoryProblem optimizingFactoryProblem = new OptimizingFactoryProblem<>(this.builder.getLearnerFactory(), problem);
			HASCOViaFDBuilder hascoFactory = this.builder.getHASCOFactory();
			this.twoPhaseHASCOFactory = new TwoPhaseHASCOBuilder<>(hascoFactory);
			this.twoPhaseHASCOFactory.setConfig(this.getConfig().copy(TwoPhaseHASCOConfig.class)); // instantiate 2-Phase-HASCO with a config COPY to not have config changes in 2-Phase-HASCO impacts on the MLPlan configuration
			this.optimizingFactory = new OptimizingFactory<>(optimizingFactoryProblem, this.twoPhaseHASCOFactory);
			if (this.loggerName != null) {
				this.logger.info("Setting logger of {} to {}.optimizingfactory", this.optimizingFactory.getClass().getName(), this.loggerName);
				this.optimizingFactory.setLoggerName(this.loggerName + ".optimizingfactory");
			}
			final double dataPortionUsedForSelection = this.getConfig().dataPortionForSelection();
			this.optimizingFactory.registerListener(new Object() {
				@Subscribe
				public void receiveEventFromFactory(final IEvent event) throws InterruptedException {
					if (event instanceof AlgorithmInitializedEvent || event instanceof AlgorithmFinishedEvent) {
						return;
					}
					if (event instanceof TwoPhaseHASCOPhaseSwitchEvent) {
						MLPlan.this.post(new MLPlanPhaseSwitchedEvent(MLPlan.this));
					} else if (event instanceof HASCOSolutionEvent) {
						HASCOSolutionCandidate solution = ((HASCOSolutionEvent) event).getSolutionCandidate();
						try {
							MLPlan.this.logger.info("Received new solution {} with score {} and evaluation time {}ms", MLPlan.this.serializer.serialize(solution.getComponentInstance()), solution.getScore(),
									solution.getTimeToEvaluateCandidate());
						} catch (Exception e) {
							MLPlan.this.logger.warn("Could not print log due to exception while preparing the log message.", e);
						}

						if (dataPortionUsedForSelection == 0.0 && solution.getScore() < MLPlan.this.internalValidationErrorOfSelectedClassifier) {
							try {
								MLPlan.this.selectedClassifier = MLPlan.this.builder.getLearnerFactory().getComponentInstantiation(solution.getComponentInstance());
								MLPlan.this.internalValidationErrorOfSelectedClassifier = solution.getScore();
								MLPlan.this.componentInstanceOfSelectedClassifier = solution.getComponentInstance();
							} catch (ComponentInstantiationFailedException e) {
								MLPlan.this.logger.error("Could not update selectedClassifier with newly best seen solution due to issues building the classifier from its ComponentInstance description.", e);
							}
						}

						try {
							MLPlan.this.post(new ClassifierFoundEvent(MLPlan.this, solution.getComponentInstance(), MLPlan.this.builder.getLearnerFactory().getComponentInstantiation(solution.getComponentInstance()), solution.getScore(),
									solution.getTimeToEvaluateCandidate()));
						} catch (ComponentInstantiationFailedException e) {
							MLPlan.this.logger.error("An issue occurred while preparing the description for the post of a ClassifierFoundEvent", e);
						}
					} else {
						MLPlan.this.post(event);
					}
				}
			});
			Timeout remainingTimeConsideringPrecaution = new Timeout(this.getRemainingTimeToDeadline().seconds() - this.getConfig().precautionOffset(), TimeUnit.SECONDS);
			this.logger.info("Initializing the optimization factory with timeout {}.", remainingTimeConsideringPrecaution);
			this.optimizingFactory.setTimeout(remainingTimeConsideringPrecaution);
			this.optimizingFactory.init();
			AlgorithmInitializedEvent event = this.activate();
			this.logger.info("Started and activated ML-Plan.");
			return event;

		case ACTIVE:

			/* train the classifier returned by the optimizing factory */
			long startOptimizationTime = System.currentTimeMillis();
			try {
				this.selectedClassifier = this.optimizingFactory.call();
				this.logger.info("2-Phase-HASCO has chosen classifier {}, which will now be built on the entire data given, i.e. {} data points.", this.selectedClassifier, this.getInput().size());
			} catch (AlgorithmException | InterruptedException | AlgorithmExecutionCanceledException | AlgorithmTimeoutedException e) {
				this.terminate(); // send the termination event
				throw e;
			}
			this.internalValidationErrorOfSelectedClassifier = this.optimizingFactory.getPerformanceOfObject();
			this.componentInstanceOfSelectedClassifier = this.optimizingFactory.getComponentInstanceOfObject();
			if (this.buildSelectedClasifierOnGivenData) {
				long startBuildTime = System.currentTimeMillis();
				try {
					this.selectedClassifier.fit(this.getInput());
				} catch (Exception e) {
					throw new AlgorithmException("Training the classifier failed!", e);
				}
				long endBuildTime = System.currentTimeMillis();
				this.logger.info(
						"Selected model has been built on entire dataset. Build time of chosen model was {}ms. Total construction time was {}ms ({}ms of that on preparation and {}ms on essential optimization). The chosen classifier is: {}",
						endBuildTime - startBuildTime, endBuildTime - this.timestampAlgorithmStart, startOptimizationTime - this.timestampAlgorithmStart, endBuildTime - startOptimizationTime, this.selectedClassifier);
			} else {
				this.logger.info("Selected model has not been built, since model building has been disabled. Total construction time was {}ms.", System.currentTimeMillis() - startOptimizationTime);
			}
			return this.terminate();

		default:
			throw new IllegalStateException("Cannot do anything in state " + this.getState());
		}

	}

	@Override
	public L call() throws AlgorithmException, InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException {
		while (this.hasNext()) {
			this.nextWithException();
		}
		return this.selectedClassifier;
	}

	@Override
	public void setLoggerName(final String name) {
		this.loggerName = name;
		super.setLoggerName(name + "._algorithm");
		this.logger.info("Switching logger name to {}", name);
		this.logger = LoggerFactory.getLogger(name);
		this.logger.info("Activated ML-Plan logger {}. Now setting logger of twoPhaseHASCO to {}.2phasehasco", name, name);
		if (this.optimizingFactory != null) {
			this.logger.info("Setting logger of {} to {}.optimizingfactory", this.optimizingFactory.getClass().getName(), this.loggerName);
			this.optimizingFactory.setLoggerName(this.loggerName + ".optimizingfactory");
		} else {
			this.logger.debug("Optimizingfactory has not been set yet, so not customizing its logger.");
		}

		this.serializer.setLoggerName(name + ".ser");
		this.logger.info("Switched ML-Plan logger to {}", name);
	}

	public void setPortionOfDataForPhase2(final double portion) {
		this.getConfig().setProperty(IMLPlanClassifierConfig.SELECTION_PORTION, String.valueOf(portion));
	}

	@Override
	public String getLoggerName() {
		return this.loggerName;
	}

	@Override
	public IMLPlanClassifierConfig getConfig() {
		return (IMLPlanClassifierConfig) super.getConfig();
	}

	public void setRandomSeed(final int seed) {
		this.getConfig().setProperty(TwoPhaseHASCOConfig.K_RANDOM_SEED, String.valueOf(seed));
	}

	public L getSelectedClassifier() {
		return this.selectedClassifier;
	}

	public IComponentInstance getComponentInstanceOfSelectedClassifier() {
		return this.componentInstanceOfSelectedClassifier;
	}

	public IPathSearchInput getSearchProblemInputGenerator() {
		this.initializeIfNotDone();
		return ((TwoPhaseHASCO) this.optimizingFactory.getOptimizer()).getGraphSearchInput();
	}

	public double getInternalValidationErrorOfSelectedClassifier() {
		return this.internalValidationErrorOfSelectedClassifier;
	}

	@Override
	public synchronized void cancel() {
		this.logger.info("Received cancel. First canceling optimizer, then invoking general shutdown.");
		this.optimizingFactory.cancel();
		this.logger.debug("Now canceling main ML-Plan routine");
		super.cancel();
		assert this.isCanceled() : "Canceled-flag is not positive at the end of the cancel routine!";
		this.logger.info("Completed cancellation of ML-Plan. Cancel status is {}", this.isCanceled());
	}

	public OptimizingFactory, Double> getOptimizingFactory() {
		return this.optimizingFactory;
	}

	public HASCO getHASCO() {
		this.initializeIfNotDone();
		return ((TwoPhaseHASCO) this.optimizingFactory.getOptimizer()).getHasco();
	}

	public IAlgorithm getSearch() {
		this.initializeIfNotDone();
		return this.getHASCO().getSearch();
	}

	private void initializeIfNotDone() {
		if (this.getState() == EAlgorithmState.CREATED) {
			this.next(); // initialize
		}
	}

	public PipelineEvaluator getClassifierEvaluatorForSearch() {
		return this.classifierEvaluatorForSearch;
	}

	public PipelineEvaluator getClassifierEvaluatorForSelection() {
		return this.classifierEvaluatorForSelection;
	}

	@Subscribe
	public void receiveEvent(final IEvent e) {
		this.post(e);
	}

	public TwoPhaseHASCOBuilder getTwoPhaseHASCOFactory() {
		return this.twoPhaseHASCOFactory;
	}

	public boolean isBuildSelectedClasifierOnGivenData() {
		return this.buildSelectedClasifierOnGivenData;
	}

	public void setBuildSelectedClasifierOnGivenData(final boolean buildSelectedClasifierOnGivenData) {
		this.buildSelectedClasifierOnGivenData = buildSelectedClasifierOnGivenData;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy