// org.maltparserx.parser.guide.decision.SeqDecisionModel (Maven / Gradle / Ivy)
package org.maltparserx.parser.guide.decision;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import org.maltparserx.core.exception.MaltChainedException;
import org.maltparserx.core.feature.FeatureModel;
import org.maltparserx.core.feature.FeatureVector;
import org.maltparserx.core.syntaxgraph.DependencyStructure;
import org.maltparserx.parser.DependencyParserConfig;
import org.maltparserx.parser.guide.ClassifierGuide;
import org.maltparserx.parser.guide.GuideException;
import org.maltparserx.parser.guide.instance.AtomicModel;
import org.maltparserx.parser.guide.instance.FeatureDivideModel;
import org.maltparserx.parser.guide.instance.InstanceModel;
import org.maltparserx.parser.history.action.GuideDecision;
import org.maltparserx.parser.history.action.MultipleDecision;
import org.maltparserx.parser.history.action.SingleDecision;
import org.maltparserx.parser.history.container.TableContainer.RelationToNextDecision;
/**
*
* @author Johan Hall
* @since 1.1
**/
public class SeqDecisionModel implements DecisionModel {
private ClassifierGuide guide;
private String modelName;
private FeatureModel featureModel;
private InstanceModel instanceModel;
private int decisionIndex;
private DecisionModel prevDecisionModel;
private DecisionModel nextDecisionModel;
private String branchedDecisionSymbols;
public SeqDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
this.branchedDecisionSymbols = "";
setGuide(guide);
setFeatureModel(featureModel);
setDecisionIndex(0);
setModelName("sdm"+decisionIndex);
setPrevDecisionModel(null);
}
public SeqDecisionModel(ClassifierGuide guide, DecisionModel prevDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
this.branchedDecisionSymbols = branchedDecisionSymbol;
} else {
this.branchedDecisionSymbols = "";
}
setGuide(guide);
setFeatureModel(prevDecisionModel.getFeatureModel());
setDecisionIndex(prevDecisionModel.getDecisionIndex() + 1);
setPrevDecisionModel(prevDecisionModel);
if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
setModelName("sdm"+decisionIndex+branchedDecisionSymbols);
} else {
setModelName("sdm"+decisionIndex);
}
}
public void updateFeatureModel() throws MaltChainedException {
featureModel.update();
}
public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
if (instanceModel != null) {
instanceModel.finalizeSentence(dependencyGraph);
}
if (nextDecisionModel != null) {
nextDecisionModel.finalizeSentence(dependencyGraph);
}
}
public void noMoreInstances() throws MaltChainedException {
if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
throw new GuideException("The decision model could not create it's model. ");
}
if (instanceModel != null) {
instanceModel.noMoreInstances();
instanceModel.train();
}
if (nextDecisionModel != null) {
nextDecisionModel.noMoreInstances();
}
}
public void terminate() throws MaltChainedException {
if (instanceModel != null) {
instanceModel.terminate();
instanceModel = null;
}
if (nextDecisionModel != null) {
nextDecisionModel.terminate();
nextDecisionModel = null;
}
}
public void addInstance(GuideDecision decision) throws MaltChainedException {
if (decision instanceof SingleDecision) {
throw new GuideException("A sequantial decision model expect a sequence of decisions, not a single decision. ");
}
featureModel.update();
final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
if (instanceModel == null) {
initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
}
instanceModel.addInstance(singleDecision);
if (singleDecision.continueWithNextDecision() && decisionIndex+1 < decision.numberOfDecisions()) {
if (nextDecisionModel == null) {
initNextDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), branchedDecisionSymbols);
}
nextDecisionModel.addInstance(decision);
}
}
public boolean predict(GuideDecision decision) throws MaltChainedException {
if (decision instanceof SingleDecision) {
throw new GuideException("A sequantial decision model expect a sequence of decisions, not a single decision. ");
}
featureModel.update();
final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
if (instanceModel == null) {
initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
}
boolean success = instanceModel.predict(singleDecision);
if (singleDecision.continueWithNextDecision() && decisionIndex+1 < decision.numberOfDecisions()) {
if (nextDecisionModel == null) {
initNextDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), branchedDecisionSymbols);
}
success = nextDecisionModel.predict(decision) && success;
}
return success;
}
public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
if (decision instanceof SingleDecision) {
throw new GuideException("A sequantial decision model expect a sequence of decisions, not a single decision. ");
}
featureModel.update();
final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
if (instanceModel == null) {
initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
}
FeatureVector fv = instanceModel.predictExtract(singleDecision);
if (singleDecision.continueWithNextDecision() && decisionIndex+1 < decision.numberOfDecisions()) {
if (nextDecisionModel == null) {
initNextDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), branchedDecisionSymbols);
}
nextDecisionModel.predictExtract(decision);
}
return fv;
}
public FeatureVector extract() throws MaltChainedException {
featureModel.update();
return instanceModel.extract(); // TODO handle many feature vectors
}
public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
if (decision instanceof SingleDecision) {
throw new GuideException("A sequantial decision model expect a sequence of decisions, not a single decision. ");
}
boolean success = false;
final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
// TODO develop different strategies for resolving which kBestlist that should be used
if (nextDecisionModel != null && singleDecision.continueWithNextDecision()) {
success = nextDecisionModel.predictFromKBestList(decision);
}
if (!success) {
success = singleDecision.updateFromKBestList();
if (success && singleDecision.continueWithNextDecision() && decisionIndex+1 < decision.numberOfDecisions()) {
if (nextDecisionModel == null) {
initNextDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), branchedDecisionSymbols);
}
nextDecisionModel.predict(decision);
}
}
return success;
}
public ClassifierGuide getGuide() {
return guide;
}
public String getModelName() {
return modelName;
}
public FeatureModel getFeatureModel() {
return featureModel;
}
public int getDecisionIndex() {
return decisionIndex;
}
public DecisionModel getPrevDecisionModel() {
return prevDecisionModel;
}
public DecisionModel getNextDecisionModel() {
return nextDecisionModel;
}
private void setPrevDecisionModel(DecisionModel prevDecisionModel) {
this.prevDecisionModel = prevDecisionModel;
}
private void setNextDecisionModel(DecisionModel nextDecisionModel) {
this.nextDecisionModel = nextDecisionModel;
}
private void setFeatureModel(FeatureModel featureModel) {
this.featureModel = featureModel;
}
private void setDecisionIndex(int decisionIndex) {
this.decisionIndex = decisionIndex;
}
private void setModelName(String modelName) {
this.modelName = modelName;
}
private void setGuide(ClassifierGuide guide) {
this.guide = guide;
}
private void initInstanceModel(String subModelName) throws MaltChainedException {
FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
if (fv == null) {
fv = featureModel.getFeatureVector(subModelName);
}
if (fv == null) {
fv = featureModel.getMainFeatureVector();
}
DependencyParserConfig c = guide.getConfiguration();
if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
instanceModel = new AtomicModel(-1, fv, this);
} else {
instanceModel = new FeatureDivideModel(fv, this);
}
}
private void initNextDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
Class> decisionModelClass = null;
if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
decisionModelClass = org.maltparserx.parser.guide.decision.SeqDecisionModel.class;
} else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
decisionModelClass = org.maltparserx.parser.guide.decision.BranchedDecisionModel.class;
} else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
decisionModelClass = org.maltparserx.parser.guide.decision.OneDecisionModel.class;
}
if (decisionModelClass == null) {
throw new GuideException("Could not find an appropriate decision model for the relation to the next decision");
}
try {
Class>[] argTypes = { org.maltparserx.parser.guide.ClassifierGuide.class, org.maltparserx.parser.guide.decision.DecisionModel.class,
java.lang.String.class };
Object[] arguments = new Object[3];
arguments[0] = getGuide();
arguments[1] = this;
arguments[2] = branchedDecisionSymbol;
Constructor> constructor = decisionModelClass.getConstructor(argTypes);
setNextDecisionModel((DecisionModel)constructor.newInstance(arguments));
} catch (NoSuchMethodException e) {
throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
} catch (InstantiationException e) {
throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
} catch (IllegalAccessException e) {
throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
} catch (InvocationTargetException e) {
throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
}
}
public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append(modelName + ", ");
sb.append(nextDecisionModel.toString());
return sb.toString();
}
}