
org.linqs.psl.runtime.GroundingAPI Maven / Gradle / Ivy
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.runtime;
import org.linqs.psl.config.Config;
import org.linqs.psl.config.Options;
import org.linqs.psl.config.RuntimeOptions;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Partition;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.arithmetic.AbstractGroundArithmeticRule;
import org.linqs.psl.model.rule.logical.AbstractGroundLogicalRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Term;
import org.linqs.psl.reasoner.term.DummyTermStore;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.StringUtils;
import org.linqs.psl.util.Version;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.util.DefaultIndenter;
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.SerializationFeature;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * An interface to PSL's grounding functionality.
 *
 * The rules of a runtime config are grounded against the config's inference data, and the result is returned
 * as a {@link GroundProgram}: a map from atom index to {@link AtomInfo} plus a list of {@link GroundRuleInfo}.
 * A ground program can be serialized to JSON via {@link GroundProgram#toJSON()}.
 */
public final class GroundingAPI extends Runtime {
    private static final Logger log = Logger.getLogger(GroundingAPI.class);

    /**
     * Ground the config file at the given path using a fresh GroundingAPI instance.
     */
    public static GroundProgram groundStatic(String configPath) {
        GroundingAPI api = new GroundingAPI();
        return api.ground(configPath);
    }

    /**
     * Ground an already-parsed config using a fresh GroundingAPI instance.
     */
    public static GroundProgram groundStatic(RuntimeConfig config) {
        GroundingAPI api = new GroundingAPI();
        return api.ground(config);
    }

    /**
     * A static interface specifically meant for methods that provide serialized input and want serialized output
     * (both in the form of JSON).
     *
     * @param jsonConfig the runtime config as a JSON string.
     * @param basePath passed through to RuntimeConfig.fromJSON()
     *     (presumably used to resolve relative paths in the config -- confirm against RuntimeConfig).
     * @return the resulting ground program, serialized as JSON.
     */
    public static String serializedGround(String jsonConfig, String basePath) {
        RuntimeConfig config = RuntimeConfig.fromJSON(jsonConfig, basePath);
        GroundProgram program = groundStatic(config);
        return program.toJSON();
    }

    /**
     * Load the config file at the given path and ground it.
     */
    public GroundProgram ground(String configPath) {
        RuntimeConfig config = RuntimeConfig.fromFile(configPath);
        return ground(config);
    }

    /**
     * Ground the given config inside its own config layer.
     * The layer is always popped and cleanup() (inherited from Runtime) is always called, even if grounding fails.
     */
    public GroundProgram ground(RuntimeConfig config) {
        Config.pushLayer();

        try {
            return groundInternal(config);
        } finally {
            Config.popLayer();
            cleanup();
        }
    }

    /**
     * Apply the config's options, load the inference data, and ground all rules.
     * The data store and database opened here are always closed, even if loading or grounding throws.
     */
    private GroundProgram groundInternal(RuntimeConfig config) {
        // Apply any top-level options found in the config.
        applyTopLevelOptions(config);

        // Specially check if we need to re-init the logger.
        initLogger();

        log.info("PSL Grounding API Version {}", Version.getFull());
        config.validate();

        // Ensure that all atoms are stored (unless overwritten).
        Options.ATOM_STORE_STORE_ALL_ATOMS.set(true);

        // Apply top-level options again after validation (since options may have been changed or added).
        // Because this runs after the line above, an explicit config option can override the atom store setting.
        applyTopLevelOptions(config);

        List<Rule> rules = new ArrayList<Rule>();
        for (Rule rule : config.rules.getRules()) {
            rules.add(rule);
        }

        DataStore dataStore = initDataStore(config);
        try {
            loadData(dataStore, config, RuntimeConfig.KEY_INFER);

            Set<StandardPredicate> closedPredicates = config.getClosedPredicates(RuntimeConfig.KEY_INFER);
            Partition targetPartition = dataStore.getPartition(Runtime.PARTITION_NAME_TARGET);
            Partition observationsPartition = dataStore.getPartition(Runtime.PARTITION_NAME_OBSERVATIONS);

            Database database = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
            try {
                return groundRulesAgainst(rules, database);
            } finally {
                database.close();
            }
        } finally {
            dataStore.close();
        }
    }

    /**
     * Copy every top-level option in the config into the global Config (without logging the access).
     */
    private static void applyTopLevelOptions(RuntimeConfig config) {
        for (Map.Entry<String, String> entry : config.options.entrySet()) {
            Config.setProperty(entry.getKey(), entry.getValue(), false);
        }
    }

    /**
     * Ground all rules against an open database and collect the results into a GroundProgram.
     *
     * The term store created here is always closed, and the ground rule callback (which is installed
     * statically on Grounding) is always reset to null, even if grounding throws.
     *
     * If RuntimeOptions.OUTPUT_ALL_ATOMS is set, the returned atoms are every atom in the atom store;
     * otherwise they are only the atoms referenced by at least one ground rule.
     */
    private GroundProgram groundRulesAgainst(final List<Rule> rules, Database database) {
        final AtomStore atomStore = database.getAtomStore();
        TermStore store = new DummyTermStore(database);

        try {
            final List<GroundRuleInfo> groundRules = new ArrayList<GroundRuleInfo>();

            // A null map tells the mappers not to track used atoms (all atoms are collected afterwards instead).
            final Map<Integer, AtomInfo> usedAtoms;
            if (RuntimeOptions.OUTPUT_ALL_ATOMS.getBoolean()) {
                usedAtoms = null;
            } else {
                usedAtoms = new HashMap<Integer, AtomInfo>();
            }

            Grounding.setGroundRuleCallback(new Grounding.GroundRuleCallback() {
                // Synchronized because groundRules/usedAtoms are plain (non-concurrent) collections;
                // presumably the callback may be invoked from multiple grounding threads.
                @Override
                public synchronized void call(GroundRule groundRule) {
                    int ruleIndex = rules.indexOf(groundRule.getRule());
                    groundRules.add(mapGroundRule(ruleIndex, atomStore, groundRule, usedAtoms));
                }
            });

            try {
                Grounding.groundAll(rules, store);
            } finally {
                // Never leave a stale callback installed, even if grounding failed.
                Grounding.setGroundRuleCallback(null);
            }

            Map<Integer, AtomInfo> groundAtoms = usedAtoms;
            if (groundAtoms == null) {
                groundAtoms = collectAllStoredAtoms(atomStore);
            }

            return new GroundProgram(groundAtoms, groundRules);
        } finally {
            store.close();
        }
    }

    /**
     * Build a map from atom index to AtomInfo covering every atom in the atom store.
     */
    private static Map<Integer, AtomInfo> collectAllStoredAtoms(AtomStore atomStore) {
        Map<Integer, AtomInfo> allAtoms = new HashMap<Integer, AtomInfo>(atomStore.size());
        for (GroundAtom groundAtom : atomStore) {
            allAtoms.put(Integer.valueOf(groundAtom.getIndex()), new AtomInfo(groundAtom));
        }

        return allAtoms;
    }

    /**
     * Look up an atom's index in the atom store.
     * If usedAtoms is non-null, also record the atom under that index (the first recorded AtomInfo is kept).
     */
    private static int lookupAndRecordAtom(AtomStore store, GroundAtom atom, Map<Integer, AtomInfo> usedAtoms) {
        int atomIndex = store.getAtomIndex(atom);

        if (usedAtoms != null) {
            Integer key = Integer.valueOf(atomIndex);
            if (!usedAtoms.containsKey(key)) {
                usedAtoms.put(key, new AtomInfo(atom));
            }
        }

        return atomIndex;
    }

    /**
     * Convert a ground rule into its serializable form.
     * Unweighted rules are given a weight of -1.0.
     *
     * @throws IllegalStateException if the ground rule is neither logical nor arithmetic.
     */
    private GroundRuleInfo mapGroundRule(int ruleIndex, AtomStore store, GroundRule groundRule,
            Map<Integer, AtomInfo> usedAtoms) {
        float weight = -1.0f;
        if (groundRule.getRule().isWeighted()) {
            weight = ((WeightedRule)groundRule.getRule()).getWeight();
        }

        if (groundRule instanceof AbstractGroundLogicalRule) {
            return mapGroundRule(ruleIndex, store, (AbstractGroundLogicalRule)groundRule, weight, usedAtoms);
        } else if (groundRule instanceof AbstractGroundArithmeticRule) {
            return mapGroundRule(ruleIndex, store, (AbstractGroundArithmeticRule)groundRule, weight, usedAtoms);
        }

        throw new IllegalStateException("Unknown rule type: " + groundRule.getClass());
    }

    /**
     * Convert a logical ground rule: operator "|", constant 0,
     * coefficient -1 for each positive atom and +1 for each negative atom (positive atoms first).
     */
    private GroundRuleInfo mapGroundRule(int ruleIndex, AtomStore store, AbstractGroundLogicalRule groundRule, float weight,
            Map<Integer, AtomInfo> usedAtoms) {
        int size = groundRule.size();
        float[] coefficients = new float[size];
        int[] atoms = new int[size];
        int currentAtom = 0;

        // Remember: the negated DNF is tracked, so invert all coefficients.
        for (GroundAtom atom : groundRule.getPositiveAtoms()) {
            coefficients[currentAtom] = -1.0f;
            atoms[currentAtom] = lookupAndRecordAtom(store, atom, usedAtoms);
            currentAtom++;
        }

        for (GroundAtom atom : groundRule.getNegativeAtoms()) {
            coefficients[currentAtom] = 1.0f;
            atoms[currentAtom] = lookupAndRecordAtom(store, atom, usedAtoms);
            currentAtom++;
        }

        return new GroundRuleInfo(ruleIndex, "|", weight, 0.0f, coefficients, atoms);
    }

    /**
     * Convert an arithmetic ground rule, keeping its comparator, constant, and coefficients
     * (the coefficients align with getOrderedAtoms()).
     */
    private GroundRuleInfo mapGroundRule(int ruleIndex, AtomStore store, AbstractGroundArithmeticRule groundRule, float weight,
            Map<Integer, AtomInfo> usedAtoms) {
        GroundAtom[] rawAtoms = groundRule.getOrderedAtoms();
        int[] atoms = new int[rawAtoms.length];

        for (int i = 0; i < rawAtoms.length; i++) {
            atoms[i] = lookupAndRecordAtom(store, rawAtoms[i], usedAtoms);
        }

        return new GroundRuleInfo(ruleIndex, groundRule.getComparator().toString(), weight, groundRule.getConstant(),
                groundRule.getCoefficients(), atoms);
    }

    /**
     * The result of grounding: the atoms involved (keyed by atom store index) and the ground rules.
     * Serialized to JSON with Jackson by toJSON().
     */
    public static final class GroundProgram {
        // Built once; Jackson's ObjectWriter is immutable and safe to share.
        private static final ObjectWriter JSON_WRITER = createJsonWriter();

        public Map<Integer, AtomInfo> atoms;
        public List<GroundRuleInfo> groundRules;

        public GroundProgram(Map<Integer, AtomInfo> atoms, List<GroundRuleInfo> groundRules) {
            this.atoms = atoms;
            this.groundRules = groundRules;
        }

        @Override
        public String toString() {
            return toJSON();
        }

        /**
         * Serialize this program as indented JSON with map entries ordered by key.
         *
         * @throws RuntimeException wrapping the Jackson error if serialization fails.
         */
        public String toJSON() {
            try {
                return JSON_WRITER.writeValueAsString(this);
            } catch (JsonProcessingException ex) {
                throw new RuntimeException(ex);
            }
        }

        private static ObjectWriter createJsonWriter() {
            ObjectMapper mapper = new ObjectMapper();
            mapper.enable(SerializationFeature.INDENT_OUTPUT);
            mapper.enable(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS);

            DefaultPrettyPrinter printer = new DefaultPrettyPrinter().withObjectIndenter(new DefaultIndenter(" ", "\n"));
            return mapper.writer(printer);
        }
    }

    /**
     * A serializable description of a single ground atom:
     * its predicate name, raw argument strings, current value, and whether it is an ObservedAtom.
     */
    public static final class AtomInfo {
        public String predicate;
        public String[] arguments;
        public float value;
        public boolean observed;

        public AtomInfo(GroundAtom atom) {
            predicate = atom.getPredicate().getName();
            value = atom.getValue();
            observed = (atom instanceof ObservedAtom);

            arguments = new String[atom.getArity()];
            Term[] terms = atom.getArguments();
            for (int i = 0; i < terms.length; i++) {
                // Ground atoms carry only constants as arguments.
                arguments[i] = ((Constant)terms[i]).rawToString();
            }
        }
    }

    /**
     * A serializable description of a single ground rule.
     * ruleIndex is the position of the source rule in the config's rule list,
     * weight is -1.0 for unweighted rules, and coefficients[i] applies to atoms[i] (atom store indexes).
     */
    public static final class GroundRuleInfo {
        public int ruleIndex;
        public String operator;
        public float weight;
        public float constant;
        public float[] coefficients;
        public int[] atoms;

        public GroundRuleInfo(int ruleIndex, String operator, float weight, float constant, float[] coefficients, int[] atoms) {
            this.ruleIndex = ruleIndex;
            this.operator = operator;
            this.weight = weight;
            this.constant = constant;
            this.coefficients = coefficients;
            this.atoms = atoms;
        }

        @Override
        public String toString() {
            return String.format(
                    "Rule Type: %s, Weight: %f, Constant: %f, coefficients: [%s], atoms: [%s].",
                    operator, weight, constant,
                    StringUtils.join(", ", coefficients), StringUtils.join(", ", atoms));
        }
    }

    /**
     * Command-line entry point: ground the config file given as the only argument and print the result as JSON.
     */
    public static void main(String[] args) {
        if (args == null || args.length != 1) {
            System.out.println("USAGE: " + GroundingAPI.class.getName() + " <config path>");
            return;
        }

        GroundProgram program = GroundingAPI.groundStatic(args[0]);
        System.out.println(program);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy