// org.aika.corpus.SearchNode (aika library)
// NOTE: the two lines of Maven-repository page text that preceded the license header
// ("... Maven / Gradle / Ivy", "Show all versions of aika Show documentation") were
// extraction residue, not Java, and have been converted into this comment.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.aika.corpus;
import org.aika.neuron.Activation.Rounds;
import org.aika.neuron.Activation.SynapseActivation;
import org.aika.neuron.Activation;
import org.aika.corpus.Conflicts.Conflict;
import org.aika.neuron.INeuron.NormWeight;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import static org.aika.neuron.Activation.State.UB;
import static org.aika.neuron.Activation.State.VALUE;
/**
* The {@code SearchNode} class represents a node in the binary search tree that is used to find the optimal
* interpretation for a given document. Each search node possess a refinement (simply a set of interpretation nodes).
* The two options that this search node examines are that the refinement will either part of the final interpretation or not.
* During each search step the activation values in all the neuron activations adjusted such that they reflect the interpretation of the current search path.
* When the search reaches the maximum depth of the search tree and no further refinements exists, a weight is computed evaluating the current search path.
* The search path with the highest weight is used to determine the final interpretation.
*
* Before the search is started a set of initial refinements is generated from the conflicts within the document.
* In other words, if there are no conflicts in a given document, then no search is needed. In this case the final interpretation
* will simply be the set of all interpretation nodes. The initial refinements are then expanded, meaning all interpretation nodes that are consistent
* with this refinement are added to the refinement. The initial refinements are then propagated along the search path as refinement candidates.
*
* @author Lukas Molzberger
*/
public class SearchNode implements Comparable<SearchNode> {

    private static final Logger log = LoggerFactory.getLogger(SearchNode.class);

    /** Hard upper bound on the number of search steps before the search is interrupted. */
    public static int MAX_SEARCH_STEPS = 100000;

    /** Sequential id of this search step, taken from {@code doc.searchNodeIdCounter}. */
    public int id;

    /** Parent search node on whose excluded branch this node lies (may be {@code null}). */
    public SearchNode excludedParent;
    /** Parent search node on whose selected branch this node lies (may be {@code null}). */
    public SearchNode selectedParent;

    /** Visited marker identifying this node; interpretation nodes stamped with it count as covered (see {@link #isCovered(int)}). */
    public int visited;

    /** The refinement examined by this search node: a set of interpretation nodes, expanded via {@link #expandRefinement}. */
    List<InterprNode> refinement;

    /** Marker shared between the selected and excluded examinations of the same refinement; used to prune settled branches. */
    RefMarker marker;

    // Weight contribution of this single search step. The array has two entries; they are indexed
    // both by literal 0/1 and by the State.UB / State.VALUE constants elsewhere in this class —
    // NOTE(review): assumed to be {upper bound, current value} in some order; confirm against INeuron.NormWeight.
    NormWeight[] weightDelta = new NormWeight[] {NormWeight.ZERO_WEIGHT, NormWeight.ZERO_WEIGHT};

    /** Weight accumulated along the selected-parent chain up to and including this node. */
    NormWeight[] accumulatedWeight = new NormWeight[2];

    /** Activation state changes performed by this search step; rolled back/forward via {@link #changeState}. */
    public List<StateChange> modifiedActs = new ArrayList<>();

    /** Candidate search nodes below this node, ordered by {@link #compareTo} (highest accumulated weight first). */
    public TreeSet<SearchNode> candidates = new TreeSet<>();

    /** Coverage status of an interpretation node relative to the current search path. */
    public enum Coverage {
        SELECTED,
        UNKNOWN,
        EXCLUDED
    }

    /**
     * Creates a search node for the given refinement, marks the affected interpretation nodes as
     * selected/excluded, computes the resulting weight delta, and then immediately rolls the
     * activation changes back (they are replayed when this node is actually visited by the search).
     *
     * @param doc        the document being interpreted
     * @param changed    collects the interpretation nodes whose state was changed by the marking
     * @param selParent  parent on the selected branch ({@code null} for the root)
     * @param exclParent parent on the excluded branch ({@code null} for the root)
     * @param ref        the (unexpanded) refinement this node examines
     * @param m          marker shared with the sibling examination of the same refinement
     */
    private SearchNode(Document doc, List<InterprNode> changed, SearchNode selParent, SearchNode exclParent, List<InterprNode> ref, RefMarker m) {
        id = doc.searchNodeIdCounter++;
        visited = doc.visitedCounter++;
        selectedParent = selParent;
        excludedParent = exclParent;
        refinement = expandRefinement(ref, doc.visitedCounter++);

        markSelected(changed, refinement);
        markExcluded(changed, refinement);
        marker = m;

        weightDelta = doc.vQueue.adjustWeight(this, changed);
        if (selectedParent != null) {
            for (int i = 0; i < 2; i++) {
                accumulatedWeight[i] = weightDelta[i].add(selectedParent.accumulatedWeight[i]);
            }
        }
        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Search Step: " + id + " Candidate Weight Delta: " + weightDelta);
            log.info(doc.neuronActivationsToString(true, true, false) + "\n");
        }

        // Undo the activation changes right away; search() restores them when it descends here.
        changeState(StateChange.Mode.OLD);
    }

    /** Collects the refinements along the selected-parent chain into {@code results}. */
    private void collectResults(Collection<InterprNode> results) {
        results.addAll(refinement);
        if (selectedParent != null) selectedParent.collectResults(results);
    }

    /** Creates the root of the search tree, whose refinement is just the bottom interpretation node. */
    public static SearchNode createRootSearchNode(Document doc) {
        List<InterprNode> changed = new ArrayList<>();
        changed.add(doc.bottom);
        return new SearchNode(doc, changed, null, null, Arrays.asList(doc.bottom), null);
    }

    /**
     * Entry point of the interpretation search. Expands the conflict-free root refinement, generates
     * the initial candidates from the document's conflicts, runs the depth-first search, and stores
     * the best interpretation found in {@code doc.bestInterpretation}.
     */
    public void computeBestInterpretation(Document doc) {
        ArrayList<InterprNode> results = new ArrayList<>();
        results.add(doc.bottom);

        doc.selectedSearchNode = null;
        int[] searchSteps = new int[1];

        List<InterprNode> rootRefs = expandRootRefinement(doc);
        refinement = expandRefinement(rootRefs, doc.visitedCounter++);

        // Mark all interpretation nodes as candidates
        markCandidatesRecursiveStep(doc.bottom, true);

        markSelected(null, refinement);
        markExcluded(null, refinement);

        weightDelta = doc.vQueue.adjustWeight(this, rootRefs);
        accumulatedWeight = weightDelta;

        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Root SearchNode:" + toString());
        }

        doc.bottom.storeFinalWeight(doc.visitedCounter++);

        generateInitialCandidates(doc);

        SearchNode child = this.selectCandidate();
        if (child != null) {
            child.search(doc, this, null, searchSteps);
        }

        if (doc.selectedSearchNode != null) {
            doc.selectedSearchNode.reconstructSelectedResult(doc);
            doc.selectedSearchNode.collectResults(results);
        }

        doc.bestInterpretation = results;

        if (doc.interrupted) {
            log.warn("The search for the best interpretation has been interrupted. Too many search steps!");
        }
    }

    /** Replays the winning search path's activation changes and records the finally activated neurons. */
    private void reconstructSelectedResult(Document doc) {
        if (selectedParent != null) selectedParent.reconstructSelectedResult(doc);

        changeState(StateChange.Mode.NEW);

        for (StateChange sc : modifiedActs) {
            Activation act = sc.act;
            if (act.finalState != null && act.finalState.value > 0.0) {
                doc.finallyActivatedNeurons.add(act.key.n.neuron.get());
            }
        }
    }

    /**
     * Recursively explores both branches of this search node: first the branch in which this node's
     * refinement is part of the interpretation (selected), then the branch in which it is not
     * (excluded). Returns the best accumulated norm weight found below this node.
     *
     * @param searchSteps single-element counter shared across the whole search; exceeding
     *                    {@link #MAX_SEARCH_STEPS} sets {@code doc.interrupted}
     */
    private double search(Document doc, SearchNode selectedParent, SearchNode excludedParent, int[] searchSteps) {
        double selectedWeight = 0.0;
        double excludedWeight = 0.0;

        if (searchSteps[0] > MAX_SEARCH_STEPS) {
            doc.interrupted = true;
        }
        searchSteps[0]++;

        markCandidates(selectedParent.candidates);

        markSelected(null, refinement);
        markExcluded(null, refinement);

        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info("Search Step: " + id);
            log.info(toString());
        }

        // Replay this node's activation changes for the duration of the descent.
        changeState(StateChange.Mode.NEW);

        if (Document.OPTIMIZE_DEBUG_OUTPUT) {
            log.info(doc.neuronActivationsToString(true, true, false) + "\n");
        }

        generateNextLevelCandidates(doc, selectedParent, excludedParent);

        if (candidates.size() == 0) {
            // Leaf of the search tree: mark the path's refinements complete where no positive
            // feedback link is left unsatisfied, then compare against the best result so far.
            SearchNode en = this;
            while (en != null) {
                if (en.marker != null && !en.marker.complete) {
                    en.marker.complete = !hasUnsatisfiedPositiveFeedbackLink(en.refinement);
                }
                en = en.selectedParent;
            }

            double accNW = accumulatedWeight[0].getNormWeight();
            double selectedAccNW = doc.selectedSearchNode != null ? doc.selectedSearchNode.accumulatedWeight[0].getNormWeight() : 0.0;

            if (accNW > selectedAccNW) {
                doc.selectedSearchNode = this;
                doc.bottom.storeFinalWeight(doc.visitedCounter++);
            }
        } else {
            SearchNode child = selectCandidate();
            // Skip the selected branch if this refinement is already known to be excluded and complete.
            if (child != null && !(marker.excluded && marker.complete)) {
                selectedWeight = child.search(doc, this, excludedParent, searchSteps);
            }
        }
        changeState(StateChange.Mode.OLD);

        if (doc.interrupted) {
            return 0.0;
        }

        // Excluded branch: continue with the parent's remaining candidates, skipping candidates
        // whose outcome is already settled (selected and complete).
        SearchNode child;
        do {
            child = selectedParent.selectCandidate();
        } while (child != null && marker.selected && child.marker.complete);

        if (child != null) {
            excludedWeight = child.search(doc, selectedParent, this, searchSteps);
        }

        if (selectedWeight >= excludedWeight) {
            marker.selected = true;
            return selectedWeight;
        } else {
            marker.excluded = true;
            return excludedWeight;
        }
    }

    /** Returns true if any node of the refinement has an unsatisfied positive feedback link. */
    private boolean hasUnsatisfiedPositiveFeedbackLink(List<InterprNode> n) {
        for (InterprNode x : n) {
            if (hasUnsatisfiedPositiveFeedbackLink(x)) return true;
        }
        return false;
    }

    /**
     * Checks (and caches on the node) whether {@code n} or one of its parents has a recurrent
     * positive synapse whose output interpretation is not covered as selected.
     */
    private boolean hasUnsatisfiedPositiveFeedbackLink(InterprNode n) {
        if (n.hasUnsatisfiedPosFeedbackLinksCache != null) return n.hasUnsatisfiedPosFeedbackLinksCache;

        for (Activation act : n.getNeuronActivations()) {
            for (SynapseActivation sa : act.neuronOutputs) {
                if (sa.s.key.isRecurrent && sa.s.w > 0.0 && !isCovered(sa.output.key.o.markedSelected)) {
                    n.hasUnsatisfiedPosFeedbackLinksCache = true;
                    return true;
                }
            }
        }

        for (InterprNode pn : n.parents) {
            if (hasUnsatisfiedPositiveFeedbackLink(pn)) {
                n.hasUnsatisfiedPosFeedbackLinksCache = true;
                return true;
            }
        }

        n.hasUnsatisfiedPosFeedbackLinksCache = false;
        return false;
    }

    /** Removes and returns the best remaining candidate, or {@code null} if none are left. */
    private SearchNode selectCandidate() {
        if (candidates.isEmpty()) return null;
        return candidates.pollFirst();
    }

    /** Builds the root's candidate set: one search node per conflict found in the document. */
    public void generateInitialCandidates(Document doc) {
        candidates = new TreeSet<>();
        for (InterprNode cn : collectConflicts(doc)) {
            List<InterprNode> changed = new ArrayList<>();
            candidates.add(new SearchNode(doc, changed, this, null, Arrays.asList(cn), new RefMarker()));
        }
    }

    /**
     * Propagates the parent's still-undecided candidates down to this node, dropping those whose
     * upper-bound weight can no longer beat the currently selected search node.
     */
    public void generateNextLevelCandidates(Document doc, SearchNode selectedParent, SearchNode excludedParent) {
        candidates = new TreeSet<>();

        for (SearchNode pc : selectedParent.candidates) {
            if (!checkSelected(pc.refinement) && !checkExcluded(pc.refinement, doc.visitedCounter++)) {
                List<InterprNode> changed = new ArrayList<>();
                SearchNode c = new SearchNode(doc, changed, this, excludedParent, pc.refinement, pc.marker);
                // Branch-and-bound pruning: keep the candidate only if its upper bound exceeds the best value so far.
                if (doc.selectedSearchNode == null || doc.selectedSearchNode.accumulatedWeight[VALUE].getNormWeight() < c.accumulatedWeight[UB].getNormWeight()) {
                    candidates.add(c);
                }
            }
        }
    }

    /** Returns true if every node of the refinement is already covered as selected. */
    private boolean checkSelected(List<InterprNode> n) {
        for (InterprNode x : n) {
            if (!isCovered(x.markedSelected)) return false;
        }
        return true;
    }

    /** Returns true if any node of the refinement (or one of its ancestors) is covered as excluded. */
    private boolean checkExcluded(List<InterprNode> n, int v) {
        for (InterprNode x : n) {
            if (checkExcluded(x, v)) return true;
        }
        return false;
    }

    private boolean checkExcluded(InterprNode ref, int v) {
        if (ref.visitedCheckExcluded == v) return false;
        ref.visitedCheckExcluded = v;

        if (isCovered(ref.markedExcluded)) return true;

        for (InterprNode pn : ref.parents) {
            if (checkExcluded(pn, v)) return true;
        }

        return false;
    }

    /** Collects, without duplicates, the secondary node of every primary conflict in the document. */
    public static List<InterprNode> collectConflicts(Document doc) {
        List<InterprNode> results = new ArrayList<>();
        int v = doc.visitedCounter++;
        for (InterprNode n : doc.bottom.children) {
            for (Conflict c : n.conflicts.primary.values()) {
                if (c.secondary.visitedCollectConflicts != v) {
                    c.secondary.visitedCollectConflicts = v;
                    results.add(c.secondary);
                }
            }
        }
        return results;
    }

    /** Builds the root refinement: the bottom node plus every conflict-free, non-or child of bottom. */
    private static List<InterprNode> expandRootRefinement(Document doc) {
        ArrayList<InterprNode> tmp = new ArrayList<>();
        tmp.add(doc.bottom);
        for (InterprNode pn : doc.bottom.children) {
            if ((pn.orInterprNodes == null || pn.orInterprNodes.isEmpty()) && pn.conflicts.primary.isEmpty() && pn.conflicts.secondary.isEmpty()) {
                tmp.add(pn);
            }
        }
        return tmp;
    }

    /**
     * Repeatedly adds all interpretation nodes consistent with the refinement until a fixed point
     * is reached (no new nodes were added in the last pass).
     */
    private List<InterprNode> expandRefinement(List<InterprNode> ref, int v) {
        ArrayList<InterprNode> tmp = new ArrayList<>();
        for (InterprNode n : ref) {
            markExpandRefinement(n, v);
            tmp.add(n);
        }

        for (InterprNode n : ref) {
            expandRefinementRecursiveStep(tmp, n, v);
        }

        if (ref.size() == tmp.size()) return tmp;
        else return expandRefinement(tmp, v);
    }

    /** Stamps {@code n} and all its ancestors with the visited id {@code v}. */
    private void markExpandRefinement(InterprNode n, int v) {
        if (n.markedExpandRefinement == v) return;
        n.markedExpandRefinement = v;

        for (InterprNode pn : n.parents) {
            markExpandRefinement(pn, v);
        }
    }

    /** Returns true if {@code n} has a direct conflict that is not yet covered as excluded. */
    private boolean hasUncoveredConflicts(InterprNode n) {
        if (!n.conflicts.hasConflicts()) return false;

        ArrayList<InterprNode> conflicts = new ArrayList<>();
        Conflicts.collectDirectConflicting(conflicts, n);
        for (InterprNode cn : conflicts) {
            if (!isCovered(cn.markedExcluded)) return true;
        }
        return false;
    }

    private void expandRefinementRecursiveStep(Collection<InterprNode> results, InterprNode n, int v) {
        if (n.visitedExpandRefinementRecursiveStep == v) return;
        n.visitedExpandRefinementRecursiveStep = v;

        if (n.refByOrInterprNode != null) {
            for (InterprNode on : n.refByOrInterprNode) {
                if (on.markedExpandRefinement != v && !hasUncoveredConflicts(on) && !isCovered(on.markedSelected)) {
                    markExpandRefinement(on, v);
                    results.add(on);
                }
            }
        }
        for (InterprNode pn : n.parents) {
            if (!pn.isBottom()) {
                expandRefinementRecursiveStep(results, pn, v);
            }
        }

        if (n.isBottom()) return;

        // Expand options that are partially covered by this refinement and partially by an earlier expand node.
        for (InterprNode cn : n.children) {
            // NOTE(review): 'break' aborts the whole child loop on the first already-visited child;
            // a 'continue' would merely skip it — confirm the early abort is intended.
            if (cn.visitedExpandRefinementRecursiveStep == v) break;

            // Check if all parents are either contained in this refinement or an earlier refinement.
            boolean covered = true;
            for (InterprNode cnp : cn.parents) {
                if (cnp.visitedExpandRefinementRecursiveStep != v && !isCovered(cnp.markedSelected)) {
                    covered = false;
                    break;
                }
            }

            if (covered) {
                expandRefinementRecursiveStep(results, cn, v);
            }
        }
    }

    /** Returns the coverage of an interpretation node relative to the current search path. */
    public Coverage getCoverage(InterprNode n) {
        if (isCovered(n.markedSelected)) return Coverage.SELECTED;
        // Nodes that are not marked as candidates of the selected parent cannot become selected anymore.
        if (selectedParent != null && n.markedHasCandidate != selectedParent.visited) return Coverage.EXCLUDED;
        if (isCovered(n.markedExcluded)) return Coverage.EXCLUDED;
        return Coverage.UNKNOWN;
    }

    /**
     * Returns true if the visited stamp {@code g} belongs to this node or one of its selected
     * ancestors. Relies on visited ids increasing monotonically along the parent chain, which
     * allows the walk to stop early once {@code g > n.visited}.
     */
    public boolean isCovered(int g) {
        SearchNode n = this;
        do {
            if (g == n.visited) return true;
            else if (g > n.visited) return false;
            n = n.selectedParent;
        } while (n != null);
        return false;
    }

    private void markSelected(List<InterprNode> changed, List<InterprNode> n) {
        for (InterprNode x : n) {
            markSelected(changed, x);
        }
    }

    /**
     * Marks {@code n} and its parent chain as selected; also marks children that are entirely
     * contained in the selected branch. Returns true if a conflict was detected along the way.
     */
    private boolean markSelected(List<InterprNode> changed, InterprNode n) {
        if (isCovered(n.markedSelected)) return false;

        n.markedSelected = visited;

        if (n.isBottom()) {
            return false;
        }

        for (InterprNode p : n.parents) {
            if (markSelected(changed, p)) return true;
        }

        for (InterprNode c : n.children) {
            // NOTE(review): this tests n.markedSelected, which was just set above, so the condition
            // is always true and every child is skipped; c.markedSelected may have been intended.
            // Left unchanged because altering it would change the search result — confirm upstream.
            if (isCovered(n.markedSelected) || !containedInSelectedBranch(c)) continue;

            if (c.isConflicting(n.doc.visitedCounter++)) return true;

            if (markSelected(changed, c)) return true;
        }

        if (changed != null) changed.add(n);
        return false;
    }

    private void markExcluded(List<InterprNode> changed, List<InterprNode> n) {
        for (InterprNode x : n) {
            markExcluded(changed, x);
        }
    }

    /** Marks every interpretation node conflicting with {@code n} (and their descendants) as excluded. */
    private void markExcluded(List<InterprNode> changed, InterprNode n) {
        List<InterprNode> conflicting = new ArrayList<>();
        Conflicts.collectAllConflicting(conflicting, n, n.doc.visitedCounter++);
        for (InterprNode cn : conflicting) {
            markExcludedRecursiveStep(changed, cn);
        }
    }

    private void markExcludedRecursiveStep(List<InterprNode> changed, InterprNode n) {
        if (isCovered(n.markedExcluded)) return;
        n.markedExcluded = visited;

        for (InterprNode c : n.children) {
            markExcludedRecursiveStep(changed, c);
        }

        // If the or option has two input options and one of them is already excluded, then when the other one is excluded we also have to exclude the or option.
        if (n.linkedByLCS != null) {
            for (InterprNode c : n.linkedByLCS) {
                if (checkOrNodeExcluded(c)) {
                    markExcludedRecursiveStep(changed, c);
                }
            }
        }

        if (changed != null) changed.add(n);
    }

    /** Returns true if every input option of the or-node {@code n} is covered as excluded. */
    private boolean checkOrNodeExcluded(InterprNode n) {
        for (InterprNode on : n.orInterprNodes.values()) {
            if (!isCovered(on.markedExcluded)) {
                return false;
            }
        }
        return true;
    }

    /** Stamps every refinement node of the given candidates (and their ancestors) as candidate-reachable. */
    private void markCandidates(Collection<SearchNode> candidates) {
        for (SearchNode sn : candidates) {
            for (InterprNode ref : sn.refinement) {
                markCandidatesRecursiveStep(ref, false);
            }
        }
    }

    /** Walks children ({@code dir == true}) or parents ({@code dir == false}), stamping markedHasCandidate. */
    private void markCandidatesRecursiveStep(InterprNode n, boolean dir) {
        if (n.markedHasCandidate == visited) return;
        n.markedHasCandidate = visited;

        for (InterprNode pn : dir ? n.children : n.parents) {
            if (!pn.isBottom()) markCandidatesRecursiveStep(pn, dir);
        }
    }

    /** Returns true if all parents of {@code n} are covered as selected. */
    public boolean containedInSelectedBranch(InterprNode n) {
        for (InterprNode p : n.parents) {
            if (!isCovered(p.markedSelected)) return false;
        }
        return true;
    }

    /** Renders the selected-parent chain from the root down to this node. */
    public String pathToString(Document doc) {
        return (selectedParent != null ? selectedParent.pathToString(doc) : "") + " - " + toString(doc);
    }

    /** Renders this node's refinement as the sorted list of its primitive interpretation node ids. */
    public String toString(Document doc) {
        TreeSet<InterprNode> tmp = new TreeSet<>();
        for (InterprNode n : refinement) {
            n.collectPrimitiveNodes(tmp, doc.interprIdCounter++);
        }
        StringBuilder sb = new StringBuilder();
        for (InterprNode n : tmp) {
            sb.append(n.primId);
            sb.append(", ");
        }
        return sb.toString();
    }

    /** Rolls all activation changes of this search step back ({@code OLD}) or forward ({@code NEW}). */
    public void changeState(StateChange.Mode m) {
        for (StateChange sc : modifiedActs) {
            sc.restoreState(m);
        }
    }

    /**
     * Orders search nodes by descending accumulated norm weight, breaking ties by ascending id so
     * the ordering is consistent with equality inside the candidates TreeSet.
     */
    @Override
    public int compareTo(SearchNode c) {
        int r = Double.compare(c.accumulatedWeight[0].getNormWeight(), accumulatedWeight[0].getNormWeight());
        if (r != 0) return r;
        return Integer.compare(id, c.id);
    }

    /**
     * The {@code StateChange} class is used to store the state change of an activation that occurs in each node of
     * the binary search tree. When a candidate refinement is selected during the search, then the activation values of
     * all affected activation objects are adjusted. The changes to the activation values are also propagated through
     * the network. The old state needs to be stored here in order for the search to be able to restore the old network
     * state before following the alternative search branch.
     */
    public static class StateChange {
        public Activation act;

        public Rounds oldRounds;
        public Rounds newRounds;

        public enum Mode { OLD, NEW }

        /**
         * Snapshots the activation's current rounds the first time it is touched within the search
         * step identified by {@code v}; subsequent calls within the same step are no-ops.
         */
        public static void saveOldState(List<StateChange> changes, Activation act, long v) {
            StateChange sc = act.currentStateChange;
            if (sc == null || act.currentStateV != v) {
                sc = new StateChange();
                sc.oldRounds = act.rounds.copy();
                act.currentStateChange = sc;
                act.currentStateV = v;
                sc.act = act;
                if (changes != null) {
                    changes.add(sc);
                }
            }
        }

        /** Records the activation's rounds after the adjustment as the new state. */
        public static void saveNewState(Activation act) {
            StateChange sc = act.currentStateChange;
            sc.newRounds = act.rounds;
        }

        /** Restores either the old or the new rounds snapshot onto the activation. */
        public void restoreState(Mode m) {
            act.rounds = (m == Mode.OLD ? oldRounds : newRounds).copy();
        }
    }

    /** Shared bookkeeping for a refinement: whether its selected/excluded branch won, and whether it is settled. */
    private static class RefMarker {
        public boolean selected;
        public boolean excluded;
        public boolean complete;
    }
}