All downloads are free. Search and download functionality uses the official Maven repository.

edu.emory.mathcs.nlp.component.doc.DOCState Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.component.doc;

import edu.emory.mathcs.nlp.component.template.eval.AccuracyEval;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.feature.FeatureItem;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.state.NLPState;
import edu.emory.mathcs.nlp.learning.util.LabelMap;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public class DOCState> extends NLPState
{
	/** Node whose feature map holds the document label; set in the constructor to index 1 of the first node array. */
	protected N         key_node;
	/** Feature key under which the label is read from / written to {@code key_node}. */
	protected String    feat_key;
	/** Expected (oracle) label; compared against {@link #getLabel()} in {@link #evaluate}. */
	protected String    oracle;
	/** Becomes {@code true} once {@link #next} has assigned a label; reported by {@link #isTerminate()}. */
	protected boolean   terminate;
	/**
	 * Result of {@code getNonStopWords(document)}, returned by {@code getDocument(true)}.
	 * NOTE(review): the generic type argument appears to have been stripped from this listing; restore from upstream.
	 */
	protected List non_stopwords;
	/** Scores passed in through {@link #next} / {@link #setPredictionScores}; stored by reference, not copied. */
	protected float[]   prediction_scores;
	
	/**
	 * Creates a document-level state whose label is stored as feature {@code key}
	 * on node index 1 of the first node array in {@code document}.
	 * <p>
	 * NOTE(review): index 0 of each node array appears to be skipped by design
	 * (getNonStopWords also starts at i=1); presumably a root/placeholder node — confirm.
	 * NOTE(review): generic type arguments appear stripped from this listing
	 * ({@code document} is a raw {@code List} yet its elements are indexed as arrays);
	 * restore them from the upstream source.
	 * NOTE(review): calls the overridable public method getNonStopWords from the constructor.
	 *
	 * @param document the document as a list of node arrays (presumably one array per sentence — TODO confirm)
	 * @param key      feature key under which the document label is stored on the key node
	 */
	public DOCState(List document, String key)
	{
		super(document);
		feat_key = key;
		// The label is attached to index 1 of the first node array.
		key_node = document.get(0)[1];
		// Precompute the filtered view that getDocument(true) returns.
		non_stopwords = getNonStopWords(document);
		reinit();
	}
	
	/**
	 * Builds, per node array of {@code document}, a list of nodes starting at index 1.
	 * <p>
	 * Presumably keeps only non-stopword nodes and converts each list back to an
	 * array (the unchecked suppression and the java.lang.reflect.Array import hint at
	 * generic array creation) — the body is truncated in this listing, TODO confirm.
	 *
	 * @param document the document as a list of node arrays
	 * @return the filtered document (exact shape not visible here)
	 */
	@SuppressWarnings("unchecked")
	public List getNonStopWords(List document)
	{
		List nonstop = new ArrayList<>();
		N node;
		
		for (N[] nodes : document)
		{
			List sen = new ArrayList<>();
			
			// NOTE(review): this listing is corrupted from here on. The loop bound and body,
			// the end of getNonStopWords, any members declared in between, and the header of
			// getDocument(boolean) were lost during extraction. Restore from upstream before compiling.
			for (int i=1; i getDocument(boolean excludeStopwords)
	{
		// Returns the precomputed filtered view, or the full document via getDocument()
		// (presumably inherited from NLPState — not defined in this listing).
		return excludeStopwords ? non_stopwords : getDocument();
	}
	
	/**
	 * Reads the document label currently attached to the key node.
	 *
	 * @return the value of feature {@code feat_key} on {@code key_node}
	 */
	public String getLabel()
	{
		final N node = key_node;
		return node.getFeat(feat_key);
	}
	
	/**
	 * Attaches {@code label} to the key node under feature {@code feat_key}.
	 *
	 * @param label the label to store
	 */
	public void setLabel(String label)
	{
		this.key_node.putFeat(this.feat_key, label);
	}
	
	/**
	 * Returns the score array recorded for this state.
	 *
	 * @return the stored array itself (not a copy); {@code null} until scores are set
	 */
	public float[] getPredictionScores()
	{
		return this.prediction_scores;
	}

	/**
	 * Records the score array for this state.
	 * The reference is kept as-is; later changes to the caller's array are visible here.
	 *
	 * @param scores the scores to store
	 */
	public void setPredictionScores(float[] scores)
	{
		prediction_scores = scores;
	}
	
	/**
	 * Returns the node whose feature map carries the document label.
	 *
	 * @return the key node chosen in the constructor
	 */
	public N getKeyNode()
	{
		return this.key_node;
	}
	
	/**
	 * Does not resolve feature items to nodes; always yields {@code null}.
	 * NOTE(review): presumably document-level features are extracted without
	 * positional node lookup — confirm against the feature template.
	 *
	 * @param item the feature item (ignored)
	 * @return always {@code null}
	 */
	@Override
	public N getNode(FeatureItem item)
	{
		return null;
	}

//	============================== TRANSITION ==============================
	
	/**
	 * Commits a single decision: stores the label of the best class (top2[0]),
	 * keeps the score array, and marks the state as finished.
	 *
	 * @param map    maps class indices to label strings
	 * @param top2   class indices; only element 0 (the best) is used
	 * @param scores the score array to retain
	 */
	@Override
	public void next(LabelMap map, int[] top2, float[] scores)
	{
		final String best = map.getLabel(top2[0]);
		setLabel(best);
		setPredictionScores(scores);
		this.terminate = true;
	}

	/**
	 * Reports whether a label has been committed by {@link #next}.
	 *
	 * @return the current termination flag
	 */
	@Override
	public boolean isTerminate()
	{
		return this.terminate;
	}

	/**
	 * Adds one observation to an accuracy evaluator: 1 correct out of 1 when the
	 * oracle label equals the current label, otherwise 0 out of 1.
	 * NOTE(review): throws NullPointerException if {@code oracle} is null, and
	 * ClassCastException if {@code eval} is not an {@link AccuracyEval}.
	 *
	 * @param eval the evaluator to update; must be an AccuracyEval
	 */
	@Override
	public void evaluate(Eval eval)
	{
		final boolean match = oracle.equals(getLabel());
		final AccuracyEval accuracy = (AccuracyEval) eval;
		accuracy.add(match ? 1 : 0, 1);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy