// com.sri.ai.praise.evaluate.run.Evaluation
// Part of aic-praise: SRI International's AIC PRAiSE (Probabilistic Reasoning As Symbolic Evaluation) Library (for Java 1.8+)
/*
* Copyright (c) 2015, SRI International
* All rights reserved.
* Licensed under the The BSD 3-Clause License;
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at:
*
* http://opensource.org/licenses/BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* Neither the name of the aic-praise nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
* COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
* OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.sri.ai.praise.evaluate.run;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
import com.sri.ai.praise.evaluate.solver.SolverEvaluator;
import com.sri.ai.praise.evaluate.solver.SolverEvaluatorConfiguration;
import com.sri.ai.praise.evaluate.solver.SolverEvaluatorProbabilityEvidenceResult;
import com.sri.ai.praise.model.common.io.ModelPage;
import com.sri.ai.praise.model.common.io.PagedModelContainer;
import com.sri.ai.praise.sgsolver.solver.ExpressionFactorsAndTypes;
/**
* Class responsible for running an evaluation of 1 or more solvers on a given problem set.
*
* @author oreilly
*
*/
public class Evaluation {
// NOTE: Based on http://www.hlt.utdallas.edu/~vgogate/uai14-competition/information.html
/**
 * The kinds of inference task an evaluation can be configured for.
 */
public enum Type {
    /** Computing the partition function and probability of evidence. */
    PR,

    /** Computing the marginal probability distribution over variable(s) given evidence. */
    MAR,

    /** Computing the most likely assignment to all variables given evidence (also known as MPE, Most Probable Explanation). */
    MAP,

    /** Computing the most likely assignment to a subset of variables given evidence (Marginal MAP). */
    MMAP
}
/**
 * Settings for one evaluation: the inference type to run, the working directory
 * handed to each solver configuration, and how many times each solver call is
 * repeated so that timings can be averaged.
 */
public static class Configuration {
    private final Evaluation.Type type;
    private final File workingDirectory;
    private final int numberRunsToAverageOver;

    /**
     * @param type
     *        the kind of inference the solvers are asked to perform.
     * @param workingDirectory
     *        the directory set on every solver configuration before the solver is used.
     * @param numberRunsToAverageOver
     *        how many times each solver call is repeated; reported times are the average over these runs.
     */
    public Configuration(Evaluation.Type type, File workingDirectory, int numberRunsToAverageOver) {
        this.type = type;
        this.workingDirectory = workingDirectory;
        this.numberRunsToAverageOver = numberRunsToAverageOver;
    }

    /** @return the kind of inference being evaluated. */
    public Evaluation.Type getType() {
        return this.type;
    }

    /** @return the directory given to the solvers as their working directory. */
    public File getWorkingDirectory() {
        return this.workingDirectory;
    }

    /** @return the number of repeated runs each reported time is averaged over. */
    public int getNumberRunsToAverageOver() {
        return this.numberRunsToAverageOver;
    }
}
/**
 * Callback through which an evaluation reports its progress, any failure, and
 * the lines of the generated report.
 */
public interface Listener {
/** Receives a human-readable progress message. */
void notification(String notification);
/** Receives an exception that was caught while the evaluation was running. */
void notificationException(Exception ex);
/** Receives one comma-separated report line: the header line first, then one line per evaluated model/query. */
void csvResultOutput(String csvLine);
}
/**
 * Runs every configured solver against every default query of every model page
 * in the container and reports the results as CSV lines through the listener.
 * <p>
 * A burn-in call on the first query of the first page is made for each solver
 * before the report is generated. Any exception raised along the way is handed
 * to {@link Listener#notificationException(Exception)} rather than rethrown, and
 * the closing timing notification is sent in either case.
 * <p>
 * Report fields that contain a comma, a double quote or a line break (for
 * example the comma-joined domain sizes) are quoted so that every value stays
 * in the column announced by the header line.
 *
 * @param configuration
 *        the inference type, working directory and number of runs to average over.
 * @param modelsToEvaluateContainer
 *        the model pages (and their default queries) to evaluate.
 * @param solverConfigurations
 *        one configuration per solver to instantiate and evaluate.
 * @param evaluationListener
 *        receives progress notifications, exceptions and the CSV report lines.
 */
public void evaluate(Evaluation.Configuration configuration, PagedModelContainer modelsToEvaluateContainer, List<SolverEvaluatorConfiguration> solverConfigurations, Evaluation.Listener evaluationListener) {
    // Note, varying domain sizes etc... is achieved by creating variants of a base model in the provided paged model container
    long evaluationStart = System.currentTimeMillis();
    try {
        List<SolverEvaluator> solvers = instantiateSolvers(solverConfigurations, configuration.getWorkingDirectory());

        // Do an initial burn in to ensure any OS caching etc... occurs so as to even out times across runs
        ModelPage burnInModel = modelsToEvaluateContainer.getPages().get(0);
        String burnInQuery = burnInModel.getDefaultQueriesToRun().get(0);
        evaluationListener.notification("Starting solver burn in based on '"+modelsToEvaluateContainer.getName()+" - "+burnInModel.getName() + " : " + burnInQuery+"'");
        for (SolverEvaluator solver : solvers) {
            SolverCallResult solverResult = callSolver(configuration, solver, burnInModel, burnInQuery);
            evaluationListener.notification("Burn in for "+solver.getConfiguration().getName()+" complete. Average inference time = "+toDurationString(solverResult.averageInferenceTimeInMilliseconds));
        }

        // Output the report header line
        StringJoiner csvLine = new StringJoiner(",");
        csvLine.add("Problem");
        csvLine.add("Inference Type");
        csvLine.add("Domain Size(s)");
        csvLine.add("# runs values averaged over");
        for (SolverEvaluator solver : solvers) {
            String solverName = solver.getConfiguration().getName();
            csvLine.add("Solver");
            csvLine.add(toCsvField("Result for "+solverName));
            csvLine.add(toCsvField("Inference ms. for "+solverName));
            csvLine.add("HH:MM:SS.");
            csvLine.add(toCsvField("Translation ms. for "+solverName));
            csvLine.add("HH:MM:SS.");
        }
        evaluationListener.notification("Starting to generate Evaluation Report");
        evaluationListener.csvResultOutput(csvLine.toString());

        // Now evaluate each of the model-query-solver combinations.
        for (ModelPage model : modelsToEvaluateContainer.getPages()) {
            String domainSizes = getDomainSizes(model.getModel());
            for (String query : model.getDefaultQueriesToRun()) {
                csvLine = new StringJoiner(",");
                String problemName = modelsToEvaluateContainer.getName()+" - "+model.getName() + " : " + query;
                evaluationListener.notification("Starting to evaluate "+problemName);
                csvLine.add(toCsvField(problemName));
                csvLine.add(configuration.getType().name());
                // The domain sizes are themselves comma separated, so they must be quoted to remain one cell.
                csvLine.add(toCsvField(domainSizes));
                csvLine.add(""+configuration.getNumberRunsToAverageOver());
                for (SolverEvaluator solver : solvers) {
                    SolverCallResult solverResult = callSolver(configuration, solver, model, query);
                    csvLine.add(toCsvField(solver.getConfiguration().getName()));
                    csvLine.add(solverResult.failed ? "FAILED" : ""+solverResult.answer);
                    csvLine.add(""+solverResult.averageInferenceTimeInMilliseconds);
                    csvLine.add(toDurationString(solverResult.averageInferenceTimeInMilliseconds));
                    csvLine.add(""+solverResult.averagelTranslationTimeInMilliseconds);
                    csvLine.add(toDurationString(solverResult.averagelTranslationTimeInMilliseconds));
                    evaluationListener.notification("Solver "+solver.getConfiguration().getName()+" took an average inference time of "+toDurationString(solverResult.averageInferenceTimeInMilliseconds)+" to solve "+problemName);
                }
                evaluationListener.csvResultOutput(csvLine.toString());
            }
        }
    }
    catch (Exception ex) {
        evaluationListener.notificationException(ex);
    }
    long evaluationEnd = System.currentTimeMillis();
    evaluationListener.notification("Evaluation took "+toDurationString(evaluationEnd - evaluationStart)+ " to run to completion.");
}

/**
 * Renders a value as a single CSV field: values containing a comma, a double
 * quote or a line break are wrapped in double quotes with embedded quotes
 * doubled; all other values are returned unchanged. A null value is rendered
 * as the text "null", matching what {@link StringJoiner#add(CharSequence)} does.
 */
private static String toCsvField(String value) {
    String text = String.valueOf(value);
    boolean needsQuoting = text.indexOf(',') >= 0
            || text.indexOf('"') >= 0
            || text.indexOf('\n') >= 0
            || text.indexOf('\r') >= 0;
    if (!needsQuoting) {
        return text;
    }
    return "\"" + text.replace("\"", "\"\"") + "\"";
}
@SuppressWarnings("unchecked")
private List instantiateSolvers(List solverConfigurations, File workingDirectory) {
List result = new ArrayList<>(solverConfigurations.size());
for (SolverEvaluatorConfiguration configuration : solverConfigurations) {
Class extends SolverEvaluator> clazz;
try {
clazz = (Class extends SolverEvaluator>) Class.forName(configuration.getImplementationClassName());
}
catch (ClassNotFoundException cnfe) {
throw new IllegalArgumentException("Unable to find "+SolverEvaluator.class.getName()+" implementation class: "+configuration.getImplementationClassName(), cnfe);
}
try {
SolverEvaluator solver = clazz.newInstance();
configuration.setWorkingDirectory(workingDirectory);
solver.setConfiguration(configuration);
result.add(solver);
}
catch (Throwable t) {
throw new IllegalStateException("Unable to instantiate instance of "+clazz.getName(), t);
}
}
return result;
}
/**
 * Calls the solver on the given model and query the configured number of times
 * and returns the answer together with the inference and translation times
 * averaged (by integer division) over those runs.
 * <p>
 * The result is marked as failed as soon as any run yields no probability of
 * evidence; the answer holds the value from the last run that produced one.
 *
 * @param configuration
 *        supplies the inference type and the number of runs to average over.
 * @param solver
 *        the solver to call.
 * @param model
 *        the model page to solve.
 * @param query
 *        the query to run against the model.
 * @return the answer, failure flag and averaged timings.
 * @throws IllegalArgumentException
 *         if the configured number of runs is not positive (no average can be computed).
 * @throws UnsupportedOperationException
 *         if the configured inference type is anything other than {@link Type#PR}.
 * @throws Exception
 *         whatever the solver itself throws.
 */
private SolverCallResult callSolver(Evaluation.Configuration configuration, SolverEvaluator solver, ModelPage model, String query)
        throws Exception {
    int numberOfRuns = configuration.getNumberRunsToAverageOver();
    if (numberOfRuns <= 0) {
        // Without this check zero runs fails with a bare division by zero and a negative
        // count silently reports the default answer as if it were a real result.
        throw new IllegalArgumentException("Number of runs to average over must be greater than 0 but was "+numberOfRuns);
    }
    if (configuration.getType() != Type.PR) {
        throw new UnsupportedOperationException(configuration.getType().name()+" type evaluations are currently not supported");
    }

    SolverCallResult result = new SolverCallResult();
    long sumOfTotalInferenceTimeInMilliseconds = 0L;
    long sumOfTotalTranslationTimeInMilliseconds = 0L;
    for (int i = 0; i < numberOfRuns; i++) {
        SolverEvaluatorProbabilityEvidenceResult prResult = solver.solveProbabilityEvidence(model.getName()+" - "+query,
                model.getLanguage(), model.getModel(), query);
        sumOfTotalInferenceTimeInMilliseconds += prResult.getTotalInferenceTimeInMilliseconds();
        sumOfTotalTranslationTimeInMilliseconds += prResult.getTotalTranslationTimeInMilliseconds();
        if (prResult.getProbabilityOfEvidence() == null) {
            result.failed = true;
        }
        else {
            result.answer = prResult.getProbabilityOfEvidence().doubleValue();
        }
    }
    result.averageInferenceTimeInMilliseconds = sumOfTotalInferenceTimeInMilliseconds / numberOfRuns;
    result.averagelTranslationTimeInMilliseconds = sumOfTotalTranslationTimeInMilliseconds / numberOfRuns;
    return result;
}
/**
 * Lists the cardinality of each additional type declared by the model.
 *
 * @param model
 *        the model text to inspect.
 * @return the cardinalities joined by commas, or an empty string when the model has no additional types.
 */
private String getDomainSizes(String model) {
    ExpressionFactorsAndTypes parsedModel = new ExpressionFactorsAndTypes(model);
    List<String> cardinalities = new ArrayList<>();
    // TODO - this only works with HOGM models and inequality integer types currently
    for (com.sri.ai.expresso.api.Type additionalType : parsedModel.getAdditionalTypes()) {
        cardinalities.add(String.valueOf(additionalType.cardinality().intValueExact()));
    }
    return String.join(",", cardinalities);
}
/**
 * Formats a duration given in milliseconds as {@code <hours>h<minutes>m<seconds>.<millis>s}.
 * <p>
 * The millisecond part is always written with three digits so that it reads as
 * a decimal fraction of a second (1005 ms is {@code 0h0m1.005s}, not
 * {@code 0h0m1.5s}). A negative duration is written with a single leading
 * minus sign.
 *
 * @param duration
 *        the duration in milliseconds.
 * @return the formatted duration.
 */
private String toDurationString(long duration) {
    final long millisPerSecond = 1000L;
    final long millisPerMinute = 60L * millisPerSecond;
    final long millisPerHour = 60L * millisPerMinute;

    // Take the magnitude of each component separately rather than of the whole value,
    // which avoids overflow when negating Long.MIN_VALUE.
    long hours = Math.abs(duration / millisPerHour);
    long minutes = Math.abs((duration % millisPerHour) / millisPerMinute);
    long seconds = Math.abs((duration % millisPerMinute) / millisPerSecond);
    long milliseconds = Math.abs(duration % millisPerSecond);

    StringBuilder text = new StringBuilder();
    if (duration < 0) {
        text.append('-');
    }
    text.append(hours).append('h');
    text.append(minutes).append('m');
    text.append(seconds).append('.');
    // Left-pad the milliseconds to three digits.
    if (milliseconds < 100) {
        text.append('0');
    }
    if (milliseconds < 10) {
        text.append('0');
    }
    text.append(milliseconds).append('s');
    return text.toString();
}
/**
 * Mutable holder for the outcome of calling one solver on one model/query:
 * the answer, whether any run failed, and the averaged timings.
 */
class SolverCallResult {
    /** The solver's answer; remains -1 when no run produced one. */
    double answer = -1.0;

    /** True once a run has returned no answer. */
    boolean failed;

    /** Inference time in milliseconds, averaged over the runs. */
    long averageInferenceTimeInMilliseconds;

    /** Translation time in milliseconds, averaged over the runs. */
    long averagelTranslationTimeInMilliseconds;
}
}