All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.emory.mathcs.nlp.component.template.OnlineComponent Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.component.template;

import edu.emory.mathcs.nlp.component.template.config.NLPConfig;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.state.NLPState;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.component.template.util.NLPFlag;
import edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.Instance;
import edu.emory.mathcs.nlp.learning.util.MLUtils;

import java.io.InputStream;
import java.io.Serializable;
import java.util.List;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public abstract class OnlineComponent, S extends NLPState> implements NLPComponent, Serializable
{
	private static final long serialVersionUID = 59819173578703335L;
	protected FeatureTemplate feature_template;
	protected boolean              document_based;
	protected OnlineOptimizer      optimizer;
	
	// for training and development
	protected transient HyperParameter hyper_parameter;
	protected transient NLPConfig   config;
	protected transient NLPFlag        flag;
	protected transient Eval           eval;

//	============================== CONSTRUCTORS ==============================
	
	public OnlineComponent(boolean document)
	{
		setDocumentBased(document);
	}
	
	public OnlineComponent(boolean document, InputStream configuration)
	{
		this(document);
		setConfiguration(configuration);
	}

//	============================== GETTERS/SETTERS ==============================
	
	public OnlineOptimizer getOptimizer()
	{
		return optimizer;
	}
	
	public void setOptimizer(OnlineOptimizer optimizer)
	{
		this.optimizer = optimizer;
	}
	
	public HyperParameter getHyperParameter()
	{
		return hyper_parameter;
	}

	public void setHyperParameter(HyperParameter hyperparameter)
	{
		hyper_parameter = hyperparameter;
	}
	
	public FeatureTemplate getFeatureTemplate()
	{
		return feature_template;
	}

	public void setFeatureTemplate(FeatureTemplate template)
	{
		feature_template = template;
	}
	
	/** {@link #config} and {@link #hyper_parameter} must not be null. */
	public void initFeatureTemplate()
	{
		feature_template = new FeatureTemplate<>(config.getFeatureTemplateElement(), getHyperParameter());
	}
	
	public Eval getEval()
	{
		return eval;
	}
	
	public void setEval(Eval eval)
	{
		this.eval = eval;
	}
	
	public NLPFlag getFlag()
	{
		return flag;
	}
	
	public void setFlag(NLPFlag flag)
	{
		this.flag = flag;
		
		if (flag == NLPFlag.EVALUATE && eval == null)
			setEval(createEvaluator());
	}
	
	public NLPConfig getConfiguration()
	{
		return config;
	}
	
	public void setConfiguration(NLPConfig config)
	{
		this.config = config;
	}
	
	public NLPConfig setConfiguration(InputStream in)
	{
		NLPConfig config = new NLPConfig(in);
		setConfiguration(config);
		return config;
	}
	
	public boolean isDocumentBased()
	{
		return document_based;
	}
	
	public void setDocumentBased(boolean document)
	{
		document_based = document;
	}
	
//	============================== FLAGS ==============================
	
	public boolean isTrain()
	{
		return flag == NLPFlag.TRAIN;
	}
	
	public boolean isDecode()
	{
		return flag == NLPFlag.DECODE;
	}
	
	public boolean isEvaluate()
	{
		return flag == NLPFlag.EVALUATE;
	}
	
//	============================== PROCESS ==============================
	
	@Override
	public void process(N[] nodes)
	{
		process(initState(nodes));
	}
	
	@Override
	public void process(List document)
	{
		if (document_based) process(initState(document));
		else for (N[] nodes : document) process(nodes);
	}
	
	/** Process the sequence of the nodes given the state. */
	public S process(S state)
	{
		if (!isDecode() && !state.saveOracle()) return state;
		int[] top2 = {0,-1};
		Instance instance;
		FeatureVector x;
		float[] scores;
		String label;

		while (!state.isTerminate())
		{
			x = feature_template.createFeatureVector(state, isTrain());
			
			if (isTrain())
			{
				label = state.getOracle();
				instance = new Instance(label, x);
				optimizer.train(instance);
				scores = instance.getScores();
				putLabel(instance.getStringLabel(), instance.getGoldLabel());
				top2[0] = hyper_parameter.getLOLS().chooseGold() ? instance.getGoldLabel() : getPrediction(state, scores)[0];
			}
			else
			{
				scores = optimizer.scores(x);
				top2 = getPrediction(state, scores);
			}
			
			state.next(optimizer.getLabelMap(), top2, scores);
		}
		
		if (isDecode() || isEvaluate())
		{
			postProcess(state);
			if (isEvaluate()) state.evaluate(eval);
		}
		
		return state;
	}
	
//	============================== HELPERS ==============================

	protected int[] getPrediction(S state, float[] scores)
	{
		return MLUtils.argmax2(scores);
	}
	
	protected void putLabel(String label, int index) {}
	
	/** @return the processing state for the input nodes. */
	protected abstract S initState(N[] nodes);
	
	/** @return the processing state for the input document. */
	protected abstract S initState(List document);
	
	public abstract Eval createEvaluator();
	
	/** Post-processes if necessary. */
	protected abstract void postProcess(S state);
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy