All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.emory.mathcs.nlp.component.dep.DEPParser Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.component.dep;

import edu.emory.mathcs.nlp.common.treebank.DEPTagEn;
import edu.emory.mathcs.nlp.component.template.OnlineComponent;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.MLUtils;
import it.unimi.dsi.fastutil.ints.IntSet;

import java.io.InputStream;
import java.util.List;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public class DEPParser> extends OnlineComponent>
{
	private static final long serialVersionUID = 7031031976396726276L;
	private DEPLabelCandidate label_candidates;

	public DEPParser()
	{
		super(false);
		label_candidates = new DEPLabelCandidate<>(); 
	}
	
	public DEPParser(InputStream configuration)
	{
		super(false, configuration);
		label_candidates = new DEPLabelCandidate<>();
	}
	
//	============================== ABSTRACT ==============================
	
	@Override
	public Eval createEvaluator()
	{
		return new DEPEval();
	}
	
	@Override
	protected DEPState initState(N[] nodes)
	{
		return new DEPState<>(nodes);
	}
	
	@Override
	protected DEPState initState(List document) {return null;}
	
//	============================== LABELS ==============================
	
	@Override
	protected void putLabel(String label, int index)
	{
		label_candidates.add(label, index);
	}
	
	@Override
	protected int[] getPrediction(DEPState state, float[] scores)
	{
		return label_candidates.getLabelIndices(state.getStack(), state.getInput(), scores);
	}
	
	public DEPLabelCandidate getLabelCandidates()
	{
		return label_candidates;
	}
	
//	============================== POST-PROCESS ==============================

	@Override
	protected void postProcess(DEPState state)
	{
		N[] nodes = state.getNodes();
		DEPTriple max;
		N node;
		
		for (int i=1; i state, DEPTriple max, N[] nodes, int currID, int dir)
	{
		IntSet labels = (dir > 0) ? label_candidates.getLeftArcs() : label_candidates.getRightArcs();
		N head, node = nodes[currID];
		int yhat, window = 0;
		float[] scores;
		FeatureVector x;
		
		for (int headID=currID+dir; 0 <= headID&&headID < nodes.length; headID+=dir)
		{
			if (++window > 5) break;
			head = nodes[headID];

			if (!head.isDescendantOf(node))
			{
				if (dir > 0)	state.reset(currID, headID);
				else			state.reset(headID, currID);
				
				x = feature_template.createFeatureVector(state, isTrain());
				scores = optimizer.scores(x);
				yhat = MLUtils.argmax(scores, labels);
				if (max.score < scores[yhat]) max.set(headID, yhat, scores[yhat]);	
			}
		}
	}
	
	class DEPTriple
	{
		int    headId;
		int    yhat;
		double score;
		
		public DEPTriple()
		{
			set(-1, -1, -Double.MAX_VALUE);
		}
		
		public void set(int headID, int yhat, double score)
		{
			this.headId = headID;
			this.yhat   = yhat;
			this.score  = score;
		}
		
		public boolean isNull()
		{
			return headId < 0;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy