All downloads are free. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.syntax.UnaryClosureComputer Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.syntax;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Factory;

/**
 * Assumes the type V is hashable
 * 
 * @author adampauls
 * 
 * @param 
 */
public class UnaryClosureComputer
{

	

	public static class Edge
	{
		
		@Override
		public int hashCode()
		{
			final int prime = 31;
			int result = 1;
			
			result = prime * result + ((child == null) ? 0 : child.hashCode());
			result = prime * result + ((parent == null) ? 0 : parent.hashCode());
			return result;
		}

		@Override
		public boolean equals(Object obj)
		{
			if (this == obj) return true;
			if (obj == null) return false;
			if (getClass() != obj.getClass()) return false;
			Edge other = (Edge) obj;
			
			if (child == null)
			{
				if (other.child != null) return false;
			}
			else if (!child.equals(other.child)) return false;
			if (parent == null)
			{
				if (other.parent != null) return false;
			}
			else if (!parent.equals(other.parent)) return false;
			return true;
		}

		public void setParent(V parent)
		{
			this.parent = parent;
		}

		public void setChild(V child)
		{
			this.child = child;
		}


		private V parent;

		private V child;

		private double score;

		private Edge(V parent, V child)
		{
			this.parent = parent;
			this.child = child;
		}

		public V getParent()
		{
			return parent;
		}

		public V getChild()
		{
			return child;
		}

		public double getScore()
		{
			return score;
		}
		
		
		public void setScore(double d)
		{
			score = d;
		}

		
		
		
	}
	
	private Factory unaryRuleFactory = new Factory()
	{

		public Edge newInstance(Object... args)
		{
			return new Edge(args[0], args[1]);
		}
	};
	

	
	
	Map>> closedUnaryRulesByChild = new HashMap>>();

	Map>> closedUnaryRulesByParent = new HashMap>>();

	Map, List> pathMap = new HashMap, List>();
	
	Set> unaryRules = new HashSet>();




	private boolean sumInsteadOfMultipy;
	
	/**
	 * First is parent, second is child;
	 * 
	 * @return
	 */
	public Map>> getAllClosedRulesByChildren()
	{
		return closedUnaryRulesByChild;
	}

	public List> getClosedUnaryRulesByChild(V child)
	{
		return CollectionUtils.getValueList(closedUnaryRulesByChild, child);
	}

	public List> getClosedUnaryRulesByParent(V parent)
	{
		return CollectionUtils.getValueList(closedUnaryRulesByParent, parent);
	}

	public List getPath(Edge unaryRule)
	{
		return pathMap.get(unaryRule);
	}

	@Override
	public String toString()
	{
		StringBuilder sb = new StringBuilder();
		for (V parent : closedUnaryRulesByParent.keySet())
		{
			for (Edge unaryRule : getClosedUnaryRulesByParent(parent))
			{
				List path = getPath(unaryRule);
				// if (path.size() == 2) continue;
				sb.append(unaryRule);
				sb.append("  ");
				sb.append(path);
				sb.append("\n");
			}
		}
		return sb.toString();
	}

	public UnaryClosureComputer(boolean sumInsteadOfMultiply)
	{
		this.sumInsteadOfMultipy = sumInsteadOfMultiply;
	}

	public void add(V parent, V child, double score)
	{
		final Edge edge = new Edge(parent, child);
		edge.setScore(score);
		unaryRules.add(edge);
	}

	public void solve()
	{
		Map, List> closureMap = computeUnaryClosure(unaryRules);
		for (Edge unaryRule : closureMap.keySet())
		{
			addUnary(unaryRule, closureMap.get(unaryRule));
		}
	}

	

	private void addUnary(Edge unaryRule, List path)
	{
		CollectionUtils.addToValueList(closedUnaryRulesByChild, unaryRule.getChild(), unaryRule);
		CollectionUtils.addToValueList(closedUnaryRulesByParent, unaryRule.getParent(), unaryRule);
		pathMap.put(unaryRule, path);
	}

	private Map, List> computeUnaryClosure(Collection> unaryRules)
	{

		Map, V> intermediateStates = new HashMap, V>();
		Counter> pathCosts = new Counter>();
		Map>> closedUnaryRulesByChild = new HashMap>>();
		Map>> closedUnaryRulesByParent = new HashMap>>();

		Set states = new HashSet();

		for (Edge unaryRule : unaryRules)
		{
			relax(pathCosts, intermediateStates, closedUnaryRulesByChild, closedUnaryRulesByParent, unaryRule, null, unaryRule.getScore());
			states.add(unaryRule.getParent());
			states.add(unaryRule.getChild());
		}

		for (V intermediateState : states)
		{
			List> incomingRules = closedUnaryRulesByChild.get(intermediateState);
			List> outgoingRules = closedUnaryRulesByParent.get(intermediateState);
			if (incomingRules == null || outgoingRules == null) continue;
			for (Edge incomingRule : incomingRules)
			{
				for (Edge outgoingRule : outgoingRules)
				{
					Edge rule = unaryRuleFactory.newInstance(incomingRule.getParent(), outgoingRule.getChild());
					double newScore = combinePathCosts(pathCosts, incomingRule, outgoingRule);
					relax(pathCosts, intermediateStates, closedUnaryRulesByChild, closedUnaryRulesByParent, rule, intermediateState, newScore);
				}
			}
		}

		for (V state : states)
			
		{
			Edge selfLoopRule = unaryRuleFactory.newInstance(state, state);
			relax(pathCosts, intermediateStates, closedUnaryRulesByChild, closedUnaryRulesByParent, selfLoopRule, null, 0.0);
		}

		Map, List> closureMap = new HashMap, List>();

		for (Edge unaryRule : pathCosts.keySet())
		{
			unaryRule.setScore(pathCosts.getCount(unaryRule));
			List path = extractPath(unaryRule, intermediateStates);
			closureMap.put(unaryRule, path);
		}

		
		return closureMap;

	}

	/**
	 * @param pathCosts
	 * @param incomingRule
	 * @param outgoingRule
	 * @return
	 */
	private double combinePathCosts(Counter> pathCosts, Edge incomingRule, Edge outgoingRule)
	{
		return this.sumInsteadOfMultipy ? (pathCosts.getCount(incomingRule) + pathCosts.getCount(outgoingRule)) : (pathCosts.getCount(incomingRule) * pathCosts
			.getCount(outgoingRule));
	}

	private List extractPath(Edge unaryRule, Map, V> intermediateStates)
	{
		List path = new ArrayList();
		path.add(unaryRule.getParent());
		V intermediateState = intermediateStates.get(unaryRule);
		if (intermediateState != null)
		{
			List parentPath = extractPath(unaryRuleFactory.newInstance(unaryRule.getParent(), intermediateState), intermediateStates);
			for (int i = 1; i < parentPath.size() - 1; i++)
			{
				V state = parentPath.get(i);
				path.add(state);
			}
			path.add(intermediateState);
			List childPath = extractPath(unaryRuleFactory.newInstance(intermediateState, unaryRule.getChild()), intermediateStates);
			for (int i = 1; i < childPath.size() - 1; i++)
			{
				V state = childPath.get(i);
				path.add(state);
			}
		}
		if (path.size() == 1 && unaryRule.getParent() == unaryRule.getChild()) return path;
		path.add(unaryRule.getChild());
		return path;
	}

	private void relax(Counter> pathCosts, Map, V> intermediateStates, Map>> closedUnaryRulesByChild,
		Map>> closedUnaryRulesByParent, Edge unaryRule, V intermediateState, double newScore)
	{
		if (intermediateState != null && (intermediateState.equals(unaryRule.getParent()) || intermediateState.equals(unaryRule.getChild()))) return;
		boolean isNewRule = !pathCosts.containsKey(unaryRule);
		double oldScore = (isNewRule ? Double.NEGATIVE_INFINITY : pathCosts.getCount(unaryRule));
		if (oldScore > newScore) return;
		if (isNewRule)
		{
			CollectionUtils.addToValueList(closedUnaryRulesByChild, unaryRule.getChild(), unaryRule);
			CollectionUtils.addToValueList(closedUnaryRulesByParent, unaryRule.getParent(), unaryRule);
		}
		pathCosts.setCount(unaryRule, newScore);
		intermediateStates.put(unaryRule, intermediateState);
	}

	public double getProb(V parent, V child)
	{
		if (parent == child) return 0.0;
		final List> byParent = closedUnaryRulesByParent.get(parent);
		if (byParent == null) return Double.POSITIVE_INFINITY;
		int childIndex = byParent.indexOf(unaryRuleFactory.newInstance(parent, child));
		if (childIndex < 0) return Double.POSITIVE_INFINITY;
		final Edge unaryRule = byParent.get(childIndex);

		return unaryRule.getScore();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy