All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.maltparserx.parser.guide.instance.AtomicModel Maven / Gradle / Ivy

package org.maltparserx.parser.guide.instance;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Formatter;

import org.maltparserx.core.exception.MaltChainedException;
import org.maltparserx.core.feature.FeatureVector;
import org.maltparserx.core.feature.function.FeatureFunction;
import org.maltparserx.core.feature.function.Modifiable;
import org.maltparserx.core.syntaxgraph.DependencyStructure;
import org.maltparserx.ml.LearningMethod;
import org.maltparserx.parser.guide.ClassifierGuide;
import org.maltparserx.parser.guide.GuideException;
import org.maltparserx.parser.guide.Model;
import org.maltparserx.parser.history.action.SingleDecision;


/**
 * An instance model backed by exactly one learner ({@link LearningMethod}) and one
 * feature vector.
 *
 * <p>The learner is created reflectively in the constructor (see {@link #initMethod()}) and is
 * released (set to {@code null}) by {@link #train()}, {@link #moveAllInstances} and
 * {@link #terminate()}. Every learner-backed operation invoked after that point fails with a
 * {@link GuideException} ("The learner cannot be found. ").
 *
 * @author Johan Hall
 * @since 1.0
 */
public class AtomicModel implements InstanceModel {
	private Model parent;
	private String modelName;
	private FeatureVector featureVector;
	private int index;
	private int frequency = 0;
	private LearningMethod method;

	
	/**
	 * Constructs an atomic model.
	 * 
	 * <p>The model name is derived from the parent's model name: {@code <parent>.} when
	 * {@code index} is -1, otherwise {@code <parent>.<index as three zero-padded digits>.}.
	 * When the guide is in batch mode, {@code index} is -1 and an information file writer is
	 * available, the learner's textual description is written to that writer and flushed.
	 * 
	 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
	 * or the master divide model) and n is number of divide models.
	 * @param features the feature vector used by the atomic model.
	 * @param parent the parent guide model.
	 * @throws MaltChainedException if the learner cannot be instantiated or the learner settings
	 * cannot be written to the information file.
	 */
	public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
		setParent(parent);
		setIndex(index);
		if (index == -1) {
			setModelName(parent.getModelName()+".");
		} else {
			// String.format yields the same text as a throw-away Formatter without leaving
			// a Closeable Formatter instance unclosed.
			setModelName(parent.getModelName()+"."+String.format("%03d", index)+".");
		}
		setFeatures(features);
		setFrequency(0);
		initMethod();
		if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
			try {
				getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
				getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
			} catch (IOException e) {
				throw new GuideException("Could not write learner settings to the information file. ", e);
			}
		}
	}
	
	/**
	 * Returns the learner, failing if it has not been created or has already been released.
	 * 
	 * <p>This is an explicit null check so that a {@code NullPointerException} raised
	 * <em>inside</em> the learner is not misreported as a missing learner.
	 * 
	 * @return the current learner, never {@code null}
	 * @throws MaltChainedException (a {@link GuideException}) if there is no learner
	 */
	private LearningMethod requireMethod() throws MaltChainedException {
		if (method == null) {
			throw new GuideException("The learner cannot be found. ");
		}
		return method;
	}
	
	/**
	 * Passes a training instance (the decision plus this model's feature vector) to the learner.
	 * 
	 * @param decision the decision to add as an instance
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public void addInstance(SingleDecision decision) throws MaltChainedException {
		requireMethod().addInstance(decision, featureVector);
	}

	
	/**
	 * Tells the learner that no more instances will be added.
	 * 
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public void noMoreInstances() throws MaltChainedException {
		requireMethod().noMoreInstances();
	}
	
	/**
	 * Forwards the end-of-sentence notification to the learner.
	 * 
	 * @param dependencyGraph the dependency structure handed on to the learner
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
		requireMethod().finalizeSentence(dependencyGraph);
	}

	/**
	 * Asks the learner to predict the given decision from this model's feature vector.
	 * No guide-mode check is made here; the call is forwarded in any mode.
	 * 
	 * @param decision the decision handed to the learner
	 * @return the value returned by the learner's {@code predict}
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public boolean predict(SingleDecision decision) throws MaltChainedException {
		return requireMethod().predict(featureVector, decision);
	}

	/**
	 * Asks the learner to predict the given decision and, when the learner reports
	 * {@code true}, returns the feature vector that was used.
	 * No guide-mode check is made here; the call is forwarded in any mode.
	 * 
	 * @param decision the decision handed to the learner
	 * @return this model's feature vector if the learner's {@code predict} returned {@code true},
	 * otherwise {@code null}
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
		if (requireMethod().predict(featureVector, decision)) {
			return featureVector;
		}
		return null;
	}
	
	/**
	 * Returns this model's feature vector without consulting the learner.
	 * 
	 * @return the feature vector ({@code null} after {@link #terminate()})
	 * @throws MaltChainedException declared for interface compatibility; not thrown by this implementation
	 */
	public FeatureVector extract() throws MaltChainedException {
		return featureVector;
	}
	
	/**
	 * Terminates the learner (if any) and drops the references to the learner, the feature
	 * vector and the parent model. Calling it again is harmless.
	 * 
	 * @throws MaltChainedException if the learner fails while terminating
	 */
	public void terminate() throws MaltChainedException {
		if (method != null) {
			method.terminate();
			method = null;
		}
		featureVector = null;
		parent = null;
	}
	
	/**
	 * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
	 * This method is used by the feature divide model to sum up all model below a certain threshold.
	 * 
	 * <p>The divide feature's value is first set to this model's index; afterwards this model's
	 * learner is terminated and released.
	 * 
	 * @param model the destination atomic model 
	 * @param divideFeature the divide feature; must also implement {@link Modifiable}
	 * @param divideFeatureIndexVector the divide feature index vector
	 * @throws MaltChainedException if the learner or any argument is missing, if the divide
	 * feature is not modifiable, or if the learner fails
	 */
	public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList divideFeatureIndexVector) throws MaltChainedException {
		if (method == null) {
			throw new GuideException("The learner cannot be found. ");
		} else if (model == null) {
			throw new GuideException("The guide model cannot be found. ");
		} else if (divideFeature == null) {
			throw new GuideException("The divide feature cannot be found. ");
		} else if (divideFeatureIndexVector == null) {
			throw new GuideException("The divide feature index vector cannot be found. ");
		} else if (!(divideFeature instanceof Modifiable)) {
			// Report a wrong feature type the same way as the other argument problems
			// instead of letting the cast below fail with a ClassCastException.
			throw new GuideException("The divide feature cannot be modified. ");
		}
		((Modifiable)divideFeature).setFeatureValue(index);
		method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
		method.terminate();
		method = null;
	}
	
	/**
	 * Invokes the train() of the learning method, then terminates and releases the learner.
	 * 
	 * @throws MaltChainedException if there is no learner or the learner fails
	 */
	public void train() throws MaltChainedException {
		final LearningMethod learner = requireMethod();
		learner.train(featureVector);
		learner.terminate();
		method = null;
	}
	
	/**
	 * Initialize the learning method according to the option --learner-method.
	 * 
	 * <p>The learner class is read from the configuration option ("guide", "learner") and
	 * instantiated through its {@code (InstanceModel, Integer)} constructor, receiving this
	 * model and the learner mode that corresponds to the guide mode.
	 * 
	 * @throws MaltChainedException if the learner class has no matching constructor or cannot
	 * be instantiated
	 */
	public void initMethod() throws MaltChainedException {
		Class clazz = (Class)getGuide().getConfiguration().getOptionValue("guide", "learner");
		Class[] argTypes = { org.maltparserx.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
		Object[] arguments = new Object[2];
		arguments[0] = this;
		// NOTE(review): for any guide mode other than CLASSIFY or BATCH the mode argument
		// stays null and is passed to the learner constructor as such - confirm intended.
		if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
			arguments[1] = LearningMethod.CLASSIFY;
		} else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
			arguments[1] = LearningMethod.BATCH;
		} 

		try {	
			Constructor constructor = clazz.getConstructor(argTypes);
			this.method = (LearningMethod)constructor.newInstance(arguments);
		} catch (NoSuchMethodException e) {
			throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
		} catch (InstantiationException e) {
			throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
		} catch (IllegalAccessException e) {
			throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
		} catch (InvocationTargetException e) {
			throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
		}
	}
	
	
	
	/**
	 * Returns the parent guide model
	 * 
	 * @return the parent guide model
	 * @throws MaltChainedException if there is no parent (for example after {@link #terminate()})
	 */
	public Model getParent() throws MaltChainedException {
		if (parent == null) {
			throw new GuideException("The atomic model can only be used by a parent model. ");
		}
		return parent;
	}

	/**
	 * Sets the parent guide model
	 * 
	 * @param parent the parent guide model
	 */
	protected void setParent(Model parent) {
		this.parent = parent;
	}

	/**
	 * Returns the name of the atomic model
	 * 
	 * @return the name of the atomic model
	 */
	public String getModelName() {
		return modelName;
	}

	/**
	 * Sets the name of the atomic model
	 * 
	 * @param modelName the name of the atomic model
	 */
	protected void setModelName(String modelName) {
		this.modelName = modelName;
	}

	/**
	 * Returns the feature vector used by this atomic model
	 * 
	 * @return a feature vector object
	 */
	public FeatureVector getFeatures() {
		return featureVector;
	}

	/**
	 * Sets the feature vector used by the atomic model.
	 * 
	 * @param features a feature vector object
	 */
	protected void setFeatures(FeatureVector features) {
		this.featureVector = features;
	}

	/**
	 * Returns the classifier guide of the parent model.
	 * 
	 * <p>The parent is dereferenced without a check, so this fails with a
	 * {@code NullPointerException} once the parent has been cleared by {@link #terminate()}.
	 * 
	 * @return the guide reported by the parent model
	 */
	public ClassifierGuide getGuide() {
		return parent.getGuide();
	}
	
	/**
	 * Returns the index of the atomic model
	 * 
	 * @return the index of the atomic model
	 */
	public int getIndex() {
		return index;
	}

	/**
	 * Sets the index of the model (-1..n), where -1 is a special value.
	 * 
	 * @param index index value (-1..n) of the atomic model
	 */
	protected void setIndex(int index) {
		this.index = index;
	}

	/**
	 * Returns the frequency (number of instances)
	 * 
	 * @return the frequency (number of instances)
	 */
	public int getFrequency() {
		return frequency;
	}
	
	/**
	 * Increase the frequency by 1. If the parent is itself an instance model, its frequency
	 * is increased first.
	 */
	public void increaseFrequency() {
		if (parent instanceof InstanceModel) {
			((InstanceModel)parent).increaseFrequency();
		}
		frequency++;
	}
	
	/**
	 * Decrease the frequency by 1. If the parent is itself an instance model, its frequency
	 * is decreased first.
	 */
	public void decreaseFrequency() {
		if (parent instanceof InstanceModel) {
			((InstanceModel)parent).decreaseFrequency();
		}
		frequency--;
	}
	/**
	 * Sets the frequency (number of instances)
	 * 
	 * @param frequency (number of instances)
	 */
	protected void setFrequency(int frequency) {
		this.frequency = frequency;
	} 
	
	/**
	 * Returns a learner object
	 * 
	 * @return a learner object, or {@code null} once the learner has been released by
	 * {@link #train()}, {@link #moveAllInstances} or {@link #terminate()}
	 */
	public LearningMethod getMethod() {
		return method;
	}
	
	
	/**
	 * Returns the learner's textual description, or an empty string when the learner has
	 * already been released (the unguarded call would otherwise throw a
	 * {@code NullPointerException} on a trained or terminated model).
	 * 
	 * @see java.lang.Object#toString()
	 */
	@Override
	public String toString() {
		final StringBuilder sb = new StringBuilder();
		if (method != null) {
			sb.append(method.toString());
		}
		return sb.toString();
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy