All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.emory.mathcs.nlp.component.dep.DEPLabelCandidate Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.component.dep;

import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.learning.util.MLUtils;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;

import java.io.Serializable;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public class DEPLabelCandidate> implements Serializable
{
	private static final long serialVersionUID = 5579863219050051216L;
	
	private IntSet X_SHIFT;
	private IntSet NO_X;
	private IntSet NO_SHIFT_OR_PASS;
	private IntSet NOT_NO_REDUCE;
	private IntSet LEFT_ARC;
	private IntSet RIGHT_ARC;
	
	public DEPLabelCandidate()
	{
		X_SHIFT          = new IntOpenHashSet();
		NO_X             = new IntOpenHashSet();
		NO_SHIFT_OR_PASS = new IntOpenHashSet();
		NOT_NO_REDUCE    = new IntOpenHashSet();
		LEFT_ARC         = new IntOpenHashSet();
		RIGHT_ARC        = new IntOpenHashSet();
	}

	public void add(String label, int index)
	{
		DEPLabel lb = new DEPLabel(label);
		
		if (lb.isList(DEPState.LIST_SHIFT))
			X_SHIFT.add(index);
		
		if (lb.isArc(DEPState.ARC_LEFT))
			LEFT_ARC.add(index);
		else if (lb.isArc(DEPState.ARC_RIGHT))
			RIGHT_ARC.add(index);
		else
		{
			NO_X.add(index);
			if (!lb.isList(DEPState.LIST_REDUCE)) NO_SHIFT_OR_PASS.add(index);
		}
		
		if (!(lb.isArc(DEPState.ARC_NO) && lb.isList(DEPState.LIST_REDUCE)))
			NOT_NO_REDUCE.add(index);
	}
	
	public IntSet get(N stack, N input)
	{
		if (stack.getID() == 0)
			return X_SHIFT;
		
		if (stack.isDescendantOf(input))
			return NO_X;
		
		if (input.isDescendantOf(stack))
			return stack.hasDependencyHead() ? NO_X : NO_SHIFT_OR_PASS;
		
		if (!stack.hasDependencyHead())
			return NOT_NO_REDUCE;

		return null;
	}
	
	public int[] getLabelIndices(N stack, N input, float[] scores)
	{
		IntSet candidates = get(stack, input);
		return MLUtils.argmax2(scores, candidates);
	}
	
	public IntSet getLeftArcs()
	{
		return LEFT_ARC;
	}
	
	public IntSet getRightArcs()
	{
		return RIGHT_ARC;
	}
	
	public void remap(Int2IntMap map)
	{
		X_SHIFT          = remap(map, X_SHIFT);
		NO_X             = remap(map, NO_X);
		NO_SHIFT_OR_PASS = remap(map, NO_SHIFT_OR_PASS);
		NOT_NO_REDUCE    = remap(map, NOT_NO_REDUCE);
		LEFT_ARC         = remap(map, LEFT_ARC);
		RIGHT_ARC        = remap(map, RIGHT_ARC);
	}
	
	private IntSet remap(Int2IntMap map, IntSet set)
	{
		IntSet nset = new IntOpenHashSet();
		
		for (int i : set)
		{
			if (map.containsKey(i))
				nset.add(map.get(i));
		}
		
		return nset;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy