com.sri.ai.praise.sgsolver.solver.HOGMQueryRunner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aic-praise Show documentation
Show all versions of aic-praise Show documentation
SRI International's AIC PRAiSE (Probabilistic Reasoning As Symbolic Evaluation) Library (for Java 1.8+)
/*
* Copyright (c) 2015, SRI International
* All rights reserved.
* Licensed under the The BSD 3-Clause License;
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at:
*
* http://opensource.org/licenses/BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* Neither the name of the aic-praise nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
* COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
* OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.sri.ai.praise.sgsolver.solver;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.antlr.v4.runtime.RecognitionException;
import com.google.common.annotations.Beta;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.api.Parser;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.grinder.api.Context;
import com.sri.ai.grinder.helper.GrinderUtil;
import com.sri.ai.praise.model.v1.HOGMSortDeclaration;
import com.sri.ai.praise.model.v1.HOGModelException;
import com.sri.ai.praise.model.v1.hogm.antlr.HOGMParserWrapper;
import com.sri.ai.praise.model.v1.hogm.antlr.ParsedHOGModel;
import com.sri.ai.praise.model.v1.hogm.antlr.UnableToParseAllTheInputError;
import com.sri.ai.praise.sgsolver.solver.ExpressionFactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.FactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.InferenceForFactorGraphAndEvidence;
/**
 * Runs one or more HOGM queries against a single HOGM model and collects a
 * {@link HOGMQueryResult} per query: either the computed marginal or the list of
 * errors encountered while validating, parsing or solving it.
 * <p>
 * {@link #cancelQuery()} is intended to be called while {@link #query()} is running;
 * the fields it shares with {@code query()} are {@code volatile} so the cancellation
 * is visible to the querying thread.
 */
@Beta
public class HOGMQueryRunner {
    // HOGM model source text shared by every query.
    private final String model;
    // Query term source texts, answered in order by query().
    private final List<String> queries = new ArrayList<>();
    // Set by cancelQuery(), read by query(); volatile for cross-thread visibility.
    private volatile boolean canceled = false;
    // Inferencer of the most recent query that reached inference (null until then).
    private volatile InferenceForFactorGraphAndEvidence inferencer = null;

    /**
     * Creates a runner for a single query.
     *
     * @param model the HOGM model source text.
     * @param query the query term source text.
     */
    public HOGMQueryRunner(String model, String query) {
        this(model, Collections.singletonList(query));
    }

    /**
     * Creates a runner for several queries against the same model.
     *
     * @param model   the HOGM model source text.
     * @param queries the query term source texts; copied, answered in order.
     */
    public HOGMQueryRunner(String model, List<String> queries) {
        this.model = model;
        this.queries.addAll(queries);
    }

    /**
     * Answers every query in order.
     * <p>
     * For each query, the model and query texts are checked for emptiness, the model
     * is parsed (once, then reused), the query term is parsed and, unless the runner
     * has been canceled, inference is run. A query that produced errors yields a result
     * carrying those errors; a successfully solved query yields a result carrying its
     * marginal. A query skipped because of cancellation yields no result.
     *
     * @return the results, in query order (possibly fewer than the queries if canceled).
     */
    public List<HOGMQueryResult> query() {
        List<HOGMQueryResult> result = new ArrayList<>();
        // Parsed by the first query that passes validation, then reused by later queries.
        ParsedHOGModel parsedModel = null;
        for (String query : queries) {
            long startQuery = System.currentTimeMillis();
            List<HOGMQueryError> errors = new ArrayList<>();
            // Per query, so an error result never reports a previous query's expression.
            Expression queryExpr = null;
            // Which text is being parsed, so a RecognitionException is attributed correctly.
            HOGMQueryError.Context parsing = HOGMQueryError.Context.MODEL;
            try {
                if (model == null || model.trim().isEmpty()) {
                    errors.add(new HOGMQueryError(HOGMQueryError.Context.MODEL, "Model not specified", 0, 0, 0));
                }
                if (query == null || query.trim().isEmpty()) {
                    errors.add(new HOGMQueryError(HOGMQueryError.Context.QUERY, "Query not specified", 0, 0, 0));
                }
                if (errors.isEmpty()) {
                    HOGMParserWrapper parser = new HOGMParserWrapper();
                    if (parsedModel == null) {
                        parsedModel = parser.parseModel(model, new QueryErrorListener(HOGMQueryError.Context.MODEL, errors));
                    }
                    parsing = HOGMQueryError.Context.QUERY;
                    queryExpr = parser.parseTerm(query, new QueryErrorListener(HOGMQueryError.Context.QUERY, errors));
                    if (errors.isEmpty()) {
                        FactorsAndTypes factorsAndTypes = new ExpressionFactorsAndTypes(parsedModel);
                        if (!canceled) {
                            InferenceForFactorGraphAndEvidence currentInferencer =
                                    new InferenceForFactorGraphAndEvidence(factorsAndTypes, false, null, true, null);
                            // Publish before solving so cancelQuery() can interrupt this inference.
                            inferencer = currentInferencer;
                            // Restart the clock so the reported time covers inference only.
                            startQuery = System.currentTimeMillis();
                            Expression marginal = currentInferencer.solve(queryExpr);
                            result.add(new HOGMQueryResult(query, queryExpr, parsedModel, marginal, System.currentTimeMillis() - startQuery));
                        }
                    }
                }
            }
            catch (RecognitionException re) {
                errors.add(toQueryError(parsing, re));
            }
            catch (UnableToParseAllTheInputError utpai) {
                errors.add(new HOGMQueryError(utpai));
            }
            catch (HOGModelException me) {
                addModelErrors(me, errors);
            }
            catch (Throwable t) {
                // Unexpected: report it as a query error instead of letting it escape.
                errors.add(new HOGMQueryError(t));
            }
            if (!errors.isEmpty()) {
                result.add(new HOGMQueryResult(query, queryExpr, parsedModel, errors, System.currentTimeMillis() - startQuery));
            }
        }
        return result;
    }

    /**
     * Converts a parser {@link RecognitionException} into a query error, tolerating a
     * missing offending token (position information then defaults to 0).
     */
    private static HOGMQueryError toQueryError(HOGMQueryError.Context context, RecognitionException re) {
        int line = 0;
        int start = 0;
        int end = 0;
        if (re.getOffendingToken() != null) {
            line = re.getOffendingToken().getLine();
            start = re.getOffendingToken().getStartIndex();
            end = re.getOffendingToken().getStopIndex();
        }
        return new HOGMQueryError(context, re.getMessage(), line, start, end);
    }

    /**
     * Appends one MODEL-context query error per error in {@code me}, describing the
     * offending statement and, when it differs, the source text it was derived from.
     */
    private static void addModelErrors(HOGModelException me, List<HOGMQueryError> errors) {
        me.getErrors().forEach(modelError -> {
            String inStatement = modelError.getInStatementInfo().statement.toString();
            String inSource = modelError.getInStatementInfo().sourceText;
            String inSubStatement = modelError.getMessage();
            String inInfo;
            if (inSubStatement == null || inSubStatement.isEmpty() || inSubStatement.equals(inSource)) {
                inInfo = " in '" + inStatement + "'";
            }
            else {
                inInfo = " ('" + inSubStatement + "') in '" + inStatement + "'";
            }
            // Mention the source text only if it differs from the statement beyond spaces/semicolons.
            if (!inSource.replace(" ", "").replace(";", "").equals(inStatement.replace(" ", ""))) {
                inInfo = inInfo + " derived from '" + inSource + "'";
            }
            errors.add(new HOGMQueryError(HOGMQueryError.Context.MODEL,
                    modelError.getErrorType().formattedMessage() + inInfo,
                    modelError.getInStatementInfo().line,
                    modelError.getInStatementInfo().startIndex,
                    modelError.getInStatementInfo().endIndex));
        });
    }

    /**
     * Post-processes an answer for display. If {@code forQuery} has the built-in Boolean
     * type, occurrences of the query in {@code answer} are replaced by {@code true}, the
     * result is simplified within the query context and then re-parsed from its string
     * form (so numeric values get the parser's precision). Otherwise {@code answer} is
     * returned unchanged.
     *
     * @throws IllegalStateException if no query has reached inference yet.
     */
    public Expression simplifyAnswer(Expression answer, Expression forQuery) {
        Expression result = answer;
        Context context = getQueryContext();
        if (HOGMSortDeclaration.IN_BUILT_BOOLEAN.getName().equals(GrinderUtil.getType(forQuery, context))) {
            result = result.replaceAllOccurrences(forQuery, Expressions.TRUE, context);
            result = simplifyWithinQueryContext(result);
            // Re-parse so numeric values have the correct precision; the re-parsed
            // expression must be the one returned.
            result = Expressions.parse(result.toString());
        }
        return result;
    }

    /**
     * Returns a context carrying the type information of the last query's inferencer.
     *
     * @throws IllegalStateException if no query has reached inference yet.
     */
    public Context getQueryContext() {
        return requireInferencer().makeContextWithTypeInformation();
    }

    /**
     * Simplifies {@code expr} using the last query's inferencer.
     *
     * @throws IllegalStateException if no query has reached inference yet.
     */
    public Expression simplifyWithinQueryContext(Expression expr) {
        return requireInferencer().simplify(expr);
    }

    /** Returns the current inferencer, failing with a clear message if there is none. */
    private InferenceForFactorGraphAndEvidence requireInferencer() {
        InferenceForFactorGraphAndEvidence current = inferencer;
        if (current == null) {
            throw new IllegalStateException("No query has reached inference yet; call query() first");
        }
        return current;
    }

    /**
     * Requests cancellation: queries not yet at inference will be skipped, and the
     * inference currently in progress (if any) is interrupted.
     */
    public void cancelQuery() {
        canceled = true;
        // Read the volatile field once so the null check and the call see the same object.
        InferenceForFactorGraphAndEvidence current = inferencer;
        if (current != null) {
            current.interrupt();
        }
    }

    /**
     * Parser error listener that records each reported syntax error as a
     * {@link HOGMQueryError} tagged with the given context (model or query).
     */
    protected class QueryErrorListener implements Parser.ErrorListener {
        final HOGMQueryError.Context context;
        final List<HOGMQueryError> errors;

        QueryErrorListener(HOGMQueryError.Context context, List<HOGMQueryError> errors) {
            this.context = context;
            this.errors = errors;
        }

        @Override
        public void parseError(Object offendingSymbol, int line, int charPositionInLine, String msg, Exception e) {
            int start = 0;
            int end = 0;
            // instanceof is false for null, so no separate null check is needed.
            if (e instanceof RecognitionException) {
                RecognitionException re = (RecognitionException) e;
                if (re.getOffendingToken() != null) {
                    start = re.getOffendingToken().getStartIndex();
                    end = re.getOffendingToken().getStopIndex();
                }
            }
            // Keep the reported range well-formed.
            if (start > end) {
                start = end;
            }
            errors.add(new HOGMQueryError(context,
                    "Error at line " + line + " column " + charPositionInLine + " - " + msg,
                    line,
                    start,
                    end
            ));
        }
    }
}