All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sri.ai.praise.lang.grounded.model.HOGModelGrounding Maven / Gradle / Ivy

Go to download

SRI International's AIC PRAiSE (Probabilistic Reasoning As Symbolic Evaluation) Library (for Java 1.8+)

There is a newer version: 1.3.2
Show newest version
/*
 * Copyright (c) 2015, SRI International
 * All rights reserved.
 * Licensed under the The BSD 3-Clause License;
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 * 
 * http://opensource.org/licenses/BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 
 * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 
 * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * 
 * Neither the name of the aic-praise nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.sri.ai.praise.lang.grounded.model;

import static com.sri.ai.expresso.helper.Expressions.isNumber;
import static com.sri.ai.expresso.helper.Expressions.makeSymbol;
import static com.sri.ai.util.Util.list;
import static com.sri.ai.util.Util.myAssert;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

import com.google.common.annotations.Beta;
import com.google.common.base.Function;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.grinder.api.Context;
import com.sri.ai.grinder.library.FunctorConstants;
import com.sri.ai.grinder.sgdpll.api.ConstraintTheory;
import com.sri.ai.praise.model.v1.HOGMSortDeclaration;
import com.sri.ai.praise.sgsolver.solver.ExpressionFactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.FactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.InferenceForFactorGraphAndEvidence;
import com.sri.ai.util.base.BinaryFunction;
import com.sri.ai.util.base.TernaryProcedure;
import com.sri.ai.util.base.Triple;
import com.sri.ai.util.math.MixedRadixNumber;
import com.sri.ai.util.math.Rational;

@Beta
public class HOGModelGrounding {

	// TODO: this class needs to be cleaned up to use the unified functionalities of expresso Types.
	
	public static boolean useContextSensitiveGrounding = true;
	
	public interface Listener {
		//
		// Preamble information
		void numberGroundVariables(int number);
		void groundVariableCardinality(int variableIndex, int cardinality);
		void numberFactors(int number);
		void factorParticipants(int factorIndex, int[] variableIndexes);
		//
		// Function tables
		void factorValue(int numberFactorValues, boolean isFirstValue, boolean isLastValue, Rational value);
		//
		// Evidence
		void evidence(int variableIndex, int valueIndex);
		//
		// Indicates grounding complete
		void groundingComplete();
	}
	
	public static void ground(FactorsAndTypes factorsAndTypes, List evidence, Listener listener) {
		if (factorsAndTypes.getMapFromNonUniquelyNamedConstantNameToTypeName().size() > 0) {
			throw new IllegalArgumentException("Constants cannot be grounded");
		}
		Map>> randomVariableNameToTypeSizeAndUniqueConstants = createRandomVariableNameToTypeSizeAndUniqueConstantsMap(factorsAndTypes);
		Map randomVariableIndexes = new LinkedHashMap<>();
		AtomicInteger atomicVariableIndex = new AtomicInteger(-1);
		listener.numberGroundVariables(randomVariableNameToTypeSizeAndUniqueConstants.size());
		randomVariableNameToTypeSizeAndUniqueConstants.entrySet().forEach(entry -> {
			randomVariableIndexes.put(entry.getKey(), atomicVariableIndex.addAndGet(1));
			listener.groundVariableCardinality(atomicVariableIndex.get(), entry.getValue().second);
		});
		
		Map> typeToValues = createTypeToValuesMap(factorsAndTypes, randomVariableNameToTypeSizeAndUniqueConstants);
		Map newUniqueConstantToTypeMap = createGroundedUniqueConstantToTypeMap(typeToValues);
	    
		InferenceForFactorGraphAndEvidence inferencer = makeInferencer(factorsAndTypes, newUniqueConstantToTypeMap);
		
		Context context = inferencer.makeContextWithTypeInformation();
		
		listener.numberFactors(factorsAndTypes.getFactors().size());
		int factorIndex = 0;
		for (Expression factor : factorsAndTypes.getFactors()) {
	    	ArrayList randomVariablesInFactor = new ArrayList<>(Expressions.getSubExpressionsSatisfying(factor, randomVariableNameToTypeSizeAndUniqueConstants::containsKey));
	    	if (randomVariablesInFactor.size() == 0) {
	    		throw new IllegalArgumentException("Factor contains no random variables: "+factor);
	    	}
	    	
	    	int[] participantVariableIndexes = new int[randomVariablesInFactor.size()];
	    	for (int i = 0; i < randomVariablesInFactor.size(); i++) {
	    		Expression randomVariable = randomVariablesInFactor.get(i);
	    		participantVariableIndexes[i] = randomVariableIndexes.get(randomVariable);
	    	}
	    	listener.factorParticipants(factorIndex, participantVariableIndexes);

	    	if (!useContextSensitiveGrounding) {
	    		fullGrounding(
	    				factor,
	    				randomVariablesInFactor,
	    				listener,
	    				randomVariableNameToTypeSizeAndUniqueConstants,
	    				typeToValues,
	    				inferencer,
	    				context);
	    	}
	    	else {
	    		contextSensitiveGrounding(
	    				factor,
	    				randomVariablesInFactor,
	    				listener,
	    				randomVariableNameToTypeSizeAndUniqueConstants,
	    				typeToValues,
	    				inferencer,
	    				context);
	    	}

	    	factorIndex++;
		}
		
		// Handle the evidence
		for (Expression evidenceAssignment : evidence) {
			if (Expressions.isFunctionApplicationWithArguments(evidenceAssignment)) {
				// TODO - add support for 'not ' and 'variable = value' and 'value = variable'
				throw new UnsupportedOperationException("Function application of evidence currently not supported: "+evidenceAssignment);
			}
			else if (Expressions.isSymbol(evidenceAssignment)) {
				int evidenceVariableIndex = randomVariableIndexes.get(evidenceAssignment);
				int evidenceValueIndex    = typeToValues.get(randomVariableNameToTypeSizeAndUniqueConstants.get(evidenceAssignment).first).indexOf(Expressions.TRUE);
				listener.evidence(evidenceVariableIndex, evidenceValueIndex);
			}
		}
		
		listener.groundingComplete();
	}

	/**
	 * Provides an appropriate {@link InferenceForFactorGraphAndEvidence} object.
	 * @param factorsAndTypes
	 * @param newUniqueConstantToTypeMap
	 * @return
	 */
	private static InferenceForFactorGraphAndEvidence makeInferencer(FactorsAndTypes factorsAndTypes, Map newUniqueConstantToTypeMap) {
		ExpressionFactorsAndTypes groundedFactorsAndTypesInformation = 
				new ExpressionFactorsAndTypes(
						Collections.emptyList(), // factors
						factorsAndTypes.getMapFromRandomVariableNameToTypeName(),
						factorsAndTypes.getMapFromNonUniquelyNamedConstantNameToTypeName(),
						newUniqueConstantToTypeMap,
						factorsAndTypes.getMapFromCategoricalTypeNameToSizeString(),
						list()); // additional types
		InferenceForFactorGraphAndEvidence inferencer = new InferenceForFactorGraphAndEvidence(groundedFactorsAndTypesInformation, false, null, true, null);
		return inferencer;
	}

	/**
	 * @param randomVariableNameToTypeSizeAndUniqueConstants
	 * @param randomVariablesInFactor
	 * @return
	 */
	private static BinaryFunction makeFunctionFromVariableIndexValueIndexToValue(
			Map>> randomVariableNameToTypeSizeAndUniqueConstants,
			ArrayList randomVariablesInFactor,
			Map> typeToValues) {
		
		return (variableIndex, valueIndex)
		-> {
			Expression variable = randomVariablesInFactor.get(variableIndex.intValue());
			Expression type = randomVariableNameToTypeSizeAndUniqueConstants.get(variable).first;
			return typeToValues.get(type).get(valueIndex);
		};
	}

	/**
	 * @param randomVariableNameToTypeSizeAndUniqueConstants
	 * @param randomVariablesInFactor
	 * @return
	 */
	private static Function makeFunctionFromVariableIndexToDomainSize(
			Map>> randomVariableNameToTypeSizeAndUniqueConstants,
			ArrayList randomVariablesInFactor) {
		
		return (variableIndex)
		-> {
			Expression variable = randomVariablesInFactor.get(variableIndex.intValue());
			Triple>
			typeCardinalityAndConstants
			= randomVariableNameToTypeSizeAndUniqueConstants.get(variable);
			return typeCardinalityAndConstants.second;
		};
	}

	/**
	 * @param factor
	 * @param randomVariablesInFactor
	 * @param listener
	 * @param randomVariableNameToTypeSizeAndUniqueConstants
	 * @param typeToValues
	 * @param inferencer
	 * @param context
	 */
	private static void fullGrounding(Expression factor, List randomVariablesInFactor, Listener listener, Map>> randomVariableNameToTypeSizeAndUniqueConstants, Map> typeToValues, InferenceForFactorGraphAndEvidence inferencer, Context context) {
		int[] radices                    = new int[randomVariablesInFactor.size()];
		List> factorRandomVariableTypeValues = new ArrayList<>();
		for (int i = 0; i < randomVariablesInFactor.size(); i++) {
			Expression randomVariable = randomVariablesInFactor.get(i);
			Expression type               = randomVariableNameToTypeSizeAndUniqueConstants.get(randomVariable).first;
			radices[i]                    = randomVariableNameToTypeSizeAndUniqueConstants.get(randomVariable).second;
			factorRandomVariableTypeValues.add(typeToValues.get(type));
		}
		
		boolean didIncrement     = true;
		MixedRadixNumber mrn     = new MixedRadixNumber(BigInteger.ZERO, radices);
		int numberFactorValues   = mrn.getMaxAllowedValue().intValue()+1;
		do {
			Expression groundedFactor = factor;
			for (int i = 0; i < randomVariablesInFactor.size(); i++) {
				int valueIndex = mrn.getCurrentNumeralValue(i);
				groundedFactor = groundedFactor.replaceAllOccurrences(randomVariablesInFactor.get(i), factorRandomVariableTypeValues.get(i).get(valueIndex), context);
			}  		
			Expression value = inferencer.simplify(groundedFactor, context);
			//				Expression value = inferencer.evaluate(groundedFactor);
			if (!Expressions.isNumber(value)) {
				throw new IllegalStateException("Unable to compute a number for the grounded factor ["+groundedFactor+"], instead got:"+value);
			}

			boolean isFirstValue = mrn.getValue().intValue() == 0;
			boolean isLastValue = mrn.getValue().intValue() == numberFactorValues - 1;
			listener.factorValue(numberFactorValues, isFirstValue, isLastValue, value.rationalValue());

			if (didIncrement = mrn.canIncrement()) {
				mrn.increment();
			}				
		} while (didIncrement);
	}

	//
	// PRIVATE
	//
	private static Map>> createRandomVariableNameToTypeSizeAndUniqueConstantsMap(FactorsAndTypes factorsAndTypes) {
		Map>> result = new LinkedHashMap<>();
		factorsAndTypes.getMapFromRandomVariableNameToTypeName().entrySet().forEach(entry -> {
			Expression       randomVariableName = Expressions.parse(entry.getKey());
			Expression       type               = Expressions.parse(entry.getValue());
			int              size               = 0;			
			List uniqueConstants    = new ArrayList<>();
			if (Expressions.hasFunctor(type, FunctorConstants.FUNCTION_TYPE)) {
				throw new UnsupportedOperationException("Relational random variables, "+randomVariableName+", are currently not supported.");
			}
			else if (Expressions.hasFunctor(type, HOGMSortDeclaration.IN_BUILT_INTEGER.getName()) && type.numberOfArguments() == 2) {
				size = (type.get(1).intValueExact() - type.get(0).intValueExact()) + 1;
			}
			else if (type.hasFunctor(FunctorConstants.INTEGER_INTERVAL) && type.numberOfArguments() == 2) {
				size = (type.get(1).intValueExact() - type.get(0).intValueExact()) + 1;
			}
			else {
				String sizeString = factorsAndTypes.getMapFromCategoricalTypeNameToSizeString().get(type);
				if (sizeString == null) {
					throw new IllegalArgumentException("Size of sort " + type + " is unknown");
				}
				size = Integer.parseInt(sizeString);
				factorsAndTypes.getMapFromUniquelyNamedConstantNameToTypeName()
					.entrySet().stream()
					.filter(uniqueConstantAndTypeEntry -> uniqueConstantAndTypeEntry.getValue().equals(entry.getValue()))
					.forEach(uniqueConstantAndTypeEntry -> uniqueConstants.add(Expressions.parse(uniqueConstantAndTypeEntry.getKey())));
			}
			result.put(randomVariableName, new Triple<>(type, size, uniqueConstants));
		});
		return result;
	}
	
	private static Map> createTypeToValuesMap(FactorsAndTypes factorsAndTypes, Map>> randomVariableNameToTypeExpressionTypeSizeAndUniqueConstants) {
		Map> typeToValuesMap = new LinkedHashMap<>();
		
		randomVariableNameToTypeExpressionTypeSizeAndUniqueConstants.values().forEach(typeSizeAndUniqueConstants -> {
			Expression       type            = typeSizeAndUniqueConstants.first;
		    Integer          size            = typeSizeAndUniqueConstants.second;
		    List uniqueConstants = typeSizeAndUniqueConstants.third;
			// random variables can share type information
			if (!typeToValuesMap.containsKey(type)) {
				List values = new ArrayList<>();
				// Is a numeric range
				if (Expressions.hasFunctor(type, HOGMSortDeclaration.IN_BUILT_INTEGER.getName()) && type.numberOfArguments() == 2) {
					int startInclusive = type.get(0).intValueExact();
					int endInclusive   = type.get(1).intValueExact();
					IntStream.rangeClosed(startInclusive, endInclusive).sequential().forEach(value -> values.add(Expressions.makeSymbol(value)));	
				}
				else {
					if (HOGMSortDeclaration.IN_BUILT_BOOLEAN.getName().equals(type)) {
						values.addAll(HOGMSortDeclaration.IN_BUILT_BOOLEAN.getAssignedConstants());
					}
					else if (type.hasFunctor(FunctorConstants.INTEGER_INTERVAL) && type.numberOfArguments() == 2) {
						int firstValue;
						int lastValue;
						try {
							firstValue = type.get(0).intValue();
							lastValue = type.get(1).intValue();
							for (int i = firstValue; i != lastValue + 1; i++) {
								values.add(makeSymbol(i));
							}
						}
						catch (Error e) {
							throw new Error("Integer interval can only be grounded if it has fixed bounds, but got " + type);
						}
					}
					else {
						// Is a sort name
						values.addAll(uniqueConstants);
						for (int i = uniqueConstants.size() + 1; i <= size; i++) {
							values.add(Expressions.makeSymbol(type.toString().toLowerCase() + "_" + i));
						}
					}
				}
				typeToValuesMap.put(type, values);
			}
		});
		
		return typeToValuesMap;
	}
	
	private static Map createGroundedUniqueConstantToTypeMap(Map> typeToValues) {
		Map result = new LinkedHashMap<>();
		
		typeToValues.entrySet().stream()
			.filter(entry -> Expressions.isSymbol(entry.getKey()))
			.forEach(sortEntry -> {
				sortEntry.getValue().forEach(constant -> result.put(constant.toString(), sortEntry.getKey().toString()));
			});
		
		return result;
	}
	
	/**
	 * @param factor
	 * @param randomVariablesInFactor
	 * @param listener
	 * @param randomVariableNameToTypeSizeAndUniqueConstants
	 * @param typeToValues TODO
	 * @param inferencer
	 * @param context
	 */
	private static void contextSensitiveGrounding(Expression factor, ArrayList randomVariablesInFactor, Listener listener, Map>> randomVariableNameToTypeSizeAndUniqueConstants, Map> typeToValues, InferenceForFactorGraphAndEvidence inferencer, Context context) {
		Function fromVariableIndexToDomainSize = 
				makeFunctionFromVariableIndexToDomainSize(randomVariableNameToTypeSizeAndUniqueConstants, randomVariablesInFactor);
		int numberFactorValues = 
				numberOfAssignmentsForVariablesStartingAt(0, randomVariablesInFactor.size(), fromVariableIndexToDomainSize);
		
		contextSensitiveGroundingFrom(
				0,
				factor, // starting from first variable
				randomVariablesInFactor, // variables to be used
				makeFunctionFromVariableIndexValueIndexToValue(randomVariableNameToTypeSizeAndUniqueConstants, randomVariablesInFactor, typeToValues),
				fromVariableIndexToDomainSize,
				inferencer.getConstraintTheory(),
				true, // first time this variable is being iterated (it happens only once)
				true, // last time this variable is being iterated (it happens only once)
				(isFirstValue, isLastValue, value)
				-> listener.factorValue(numberFactorValues, isFirstValue, isLastValue, value.rationalValue()),
				context);
	}

	private static void contextSensitiveGroundingFrom(
			int variableIndex,
			Expression expression,
			ArrayList variables,
			BinaryFunction fromVariableIndexAndValueIndexToValue,
			Function fromVariableIndexToDomainSize,
			ConstraintTheory constraintTheory,
			boolean firstIterationForVariable,
			boolean lastIterationForVariable,
			TernaryProcedure recordValue,
			Context context) {
		
		Expression variable = variables.get(variableIndex);
		boolean isLastVariable = variableIndex == variables.size() - 1;
		int numberOfVariableValues = fromVariableIndexToDomainSize.apply(variableIndex);
		
		for (int variableValueIndex = 0; variableValueIndex != numberOfVariableValues; variableValueIndex++) {
			boolean thisVariableIsAtItsFirstValue = variableValueIndex == 0;
			boolean thisVariableIsAtItsLastValue = variableValueIndex == numberOfVariableValues - 1;
			Expression value = fromVariableIndexAndValueIndexToValue.apply(variableIndex, variableValueIndex);
			Expression expressionWithReplacedValue = expression.replaceAllOccurrences(variable, value, context);
			Expression simplifiedExpression = constraintTheory.simplify(expressionWithReplacedValue, context);
			
			boolean expressionIsSimplifiedToConstant =
					isLastVariable || simplifiedExpression.getSyntacticFormType().equals("Symbol");

			if (expressionIsSimplifiedToConstant) {
				myAssert( () -> isNumber(simplifiedExpression) , () -> "Expression being grounded has been simplified to a symbol that is not a numerical constant: " + simplifiedExpression);

				int numberOfTimesThisValueMustBeWritten
				= numberOfAssignmentsForVariablesStartingAt(variableIndex + 1, variables.size(), fromVariableIndexToDomainSize);
				
				for (int i = 0; i != numberOfTimesThisValueMustBeWritten; i++) {
					boolean isFirstOverallValue = 
							firstIterationForVariable && 
							thisVariableIsAtItsFirstValue &&
							i == 0;
					
					boolean isLastOverallValue =
							lastIterationForVariable &&
							thisVariableIsAtItsLastValue &&
							i == numberOfTimesThisValueMustBeWritten - 1;
					
					recordValue.apply(isFirstOverallValue, isLastOverallValue, simplifiedExpression);
				}
			}
			else {
				boolean firstIterationForNextVariable
				= firstIterationForVariable && thisVariableIsAtItsFirstValue;
				
				boolean lastIterationForNextVariable
				= lastIterationForVariable && thisVariableIsAtItsLastValue;
				
				contextSensitiveGroundingFrom(
						variableIndex + 1,
						simplifiedExpression,
						variables,
						fromVariableIndexAndValueIndexToValue,
						fromVariableIndexToDomainSize,
						constraintTheory,
						firstIterationForNextVariable,
						lastIterationForNextVariable,
						recordValue,
						context);
			}
		}
	}

	private static int numberOfAssignmentsForVariablesStartingAt(int variableIndex, int numberOfVariables, Function domainSize) {
		if (variableIndex == numberOfVariables) {
			return 1;
		}
		else {
			int result = 1;
			for (int i = variableIndex; i != numberOfVariables; i++) {
				Integer variableDomainSize = domainSize.apply(i);
				myAssert( () -> variableDomainSize != 0, () -> "Variable domain size cannot be zero");
				result *= variableDomainSize;
			}
			return result;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy