All downloads are free. Search and download functionality uses the official Maven repository.

de.fau.cs.osr.utils.visitor.VisitorStackController Maven / Gradle / Ivy

/**
 * Copyright 2011 The Open Source Research Group,
 *                University of Erlangen-Nürnberg
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.fau.cs.osr.utils.visitor;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.lang.StringUtils;

/**
 * Dispatches the visitation of nodes to a stack of visitors.
 * 
 * For the runtime class of each visited node the controller resolves, via
 * reflection, at most one public {@code visit(Baton, X)} method per visitor in
 * the stack, where {@code X} is the node's class or its most specific matching
 * supertype. The resulting {@link VisitChain} is stored in a named
 * {@link Cache}. A cache is shared by all controllers registered under the
 * same name and is only valid for stacks that have the same visitor classes
 * in the same order.
 * 
 * At visitation time the chain is walked in stack order. Disabled visitors
 * are skipped, and the {@link Baton} code set by each visit method decides
 * whether the result is handed to the next visitor, re-dispatched or returned.
 */
public abstract class VisitorStackController
{
	/**
	 * When true, every invoked visit method and every re-dispatched node is
	 * printed to stderr.
	 */
	public static boolean DEBUG = false;
	
	private static final String VISIT_METHOD_NAME = "visit";
	
	private static final Class BATON_CLASS = Baton.class;
	
	/**
	 * Registry of named caches. Only accessed from the synchronized static
	 * methods below.
	 */
	private static final Map CACHES = new HashMap();
	
	// =========================================================================
	
	/**
	 * Returns the cache registered under {@code name}. If no such cache exists
	 * yet, one is created with a load factor of 0.6, a lower capacity of 256
	 * and an upper capacity of 384.
	 * 
	 * @throws IncompatibleVisitorStackDefinition
	 *             If a cache with this name exists but was created for a
	 *             stack with different visitor classes or a different order.
	 */
	public static  Cache getOrRegisterCache(
			String name,
			List> visitorStack) throws IncompatibleVisitorStackDefinition
	{
		return getOrRegisterCache(name, visitorStack, .6f, 256, 384);
	}
	
	/**
	 * Returns the cache registered under {@code name}, or creates and
	 * registers a new one for the given visitor stack.
	 * 
	 * @param name
	 *            Name under which the cache is registered.
	 * @param visitorStack
	 *            The visitor stack the cache is meant for. Only the classes
	 *            of the visitors and their order are recorded.
	 * @param loadFactor
	 *            Load factor of the underlying map. Only used when the cache
	 *            is created.
	 * @param lowerCapacity
	 *            Initial capacity of the underlying map, and the number of
	 *            entries kept after a sweep. Only used when the cache is
	 *            created.
	 * @param upperCapacity
	 *            A sweep is triggered when an insertion grows the cache beyond
	 *            this many entries. Only used when the cache is created.
	 * @return The existing or newly created cache.
	 * @throws IncompatibleVisitorStackDefinition
	 *             If a cache with this name exists but was created for a
	 *             stack with different visitor classes or a different order.
	 */
	public static synchronized  Cache getOrRegisterCache(
			String name,
			List> visitorStack,
			float loadFactor,
			int lowerCapacity,
			int upperCapacity) throws IncompatibleVisitorStackDefinition
	{
		Cache cache = CACHES.get(name);
		if (cache == null)
		{
			cache = new Cache(visitorStack, loadFactor, lowerCapacity, upperCapacity);
			CACHES.put(name, cache);
		}
		else
		{
			cache.verifyDefinition(visitorStack);
		}
		return cache;
	}
	
	/**
	 * Removes the cache registered under {@code name} from the registry.
	 * Controllers that already hold a reference to that cache keep using it.
	 * 
	 * @return {@code true} if a cache was registered under that name.
	 */
	public static synchronized boolean dropCache(String name)
	{
		return (CACHES.remove(name) != null);
	}
	
	// =========================================================================
	
	/** Shared cache of resolved visit chains for this controller's stack. */
	private final Cache cache;
	
	/** All visitors, in stack order. */
	private StackedVisitorInterface[] visitorStack;
	
	/**
	 * Same layout as {@code visitorStack}, except that the slot of a disabled
	 * visitor holds {@code null}.
	 */
	private StackedVisitorInterface[] enabledVisitors;
	
	/** Baton of the current visitation. Replaced on every call to go(). */
	private Baton baton;
	
	// =========================================================================
	
	/**
	 * Creates a controller that uses the cache registered under
	 * {@code cacheName}. If no such cache exists, one is registered with
	 * default parameters.
	 * 
	 * @throws IncompatibleVisitorStackDefinition
	 *             If the named cache exists but was created for an
	 *             incompatible visitor stack.
	 */
	protected VisitorStackController(
			String cacheName,
			List> visitorStack) throws IncompatibleVisitorStackDefinition
	{
		this(getOrRegisterCache(cacheName, visitorStack), visitorStack);
	}
	
	/**
	 * Creates a controller for the given visitor stack that uses the given
	 * cache. Initially all visitors are enabled.
	 * 
	 * @throws NullPointerException
	 *             If the stack contains a {@code null} visitor.
	 * @throws IllegalArgumentException
	 *             If the stack contains the same visitor more than once, as
	 *             determined by {@code equals}/{@code hashCode}.
	 * @throws IncompatibleVisitorStackDefinition
	 *             If the stack does not match the cache's definition.
	 */
	protected VisitorStackController(
			Cache cache,
			List> visitorStack) throws IncompatibleVisitorStackDefinition
	{
		for (StackedVisitorInterface visitor : visitorStack)
		{
			if (visitor == null)
				throw new NullPointerException("Visitor stack contains null entries!");
		}
		
		@SuppressWarnings({ "unchecked", "rawtypes" })
		int tmp = new HashSet(visitorStack).size();
		if (tmp != visitorStack.size())
			throw new IllegalArgumentException("Visitor stack contains duplicates!");
		
		cache.verifyDefinition(visitorStack);
		
		@SuppressWarnings({ "unchecked" })
		StackedVisitorInterface[] stackArray = new StackedVisitorInterface[visitorStack.size()];
		
		this.cache = cache;
		this.visitorStack = visitorStack.toArray(stackArray);
		this.enabledVisitors = Arrays.copyOf(this.visitorStack, this.visitorStack.length);
	}
	
	// =========================================================================
	
	/**
	 * Returns the stack position of {@code visitor}, compared by identity, or
	 * -1 if it is not part of the stack.
	 */
	public int indexOfVisitor(StackedVisitorInterface visitor)
	{
		for (int i = 0; i < visitorStack.length; ++i)
		{
			if (visitorStack[i] == visitor)
				return i;
		}
		return -1;
	}
	
	/**
	 * Replaces the visitor at stack position {@code i}. The replacement must
	 * have exactly the class recorded in the cache definition for that
	 * position. If the position is currently enabled, the replacement is
	 * enabled too; otherwise it stays disabled.
	 * 
	 * @throws NullPointerException
	 *             If {@code visitor} is {@code null}.
	 * @throws IllegalArgumentException
	 *             If the replacement's class differs from the recorded class.
	 */
	public void setVisitor(int i, StackedVisitorInterface visitor)
	{
		if (visitor == null)
			throw new NullPointerException();
		if (cache.cacheDef.visitorStackDef[i] != visitor.getClass())
			throw new IllegalArgumentException("Replacement visitor's class does not match the replaced visitor's class");
		visitorStack[i] = visitor;
		if (isVisitorEnabled(i))
			enabledVisitors[i] = visitor;
	}
	
	/** Returns the visitor at stack position {@code i}, enabled or not. */
	public StackedVisitorInterface getVisitor(int i)
	{
		return visitorStack[i];
	}
	
	/**
	 * Returns the visitor at stack position {@code i}, or {@code null} if that
	 * position is disabled.
	 */
	public StackedVisitorInterface getEnabledVisitor(int i)
	{
		return enabledVisitors[i];
	}
	
	/** Returns whether the visitor at stack position {@code i} is enabled. */
	public boolean isVisitorEnabled(int i)
	{
		return (getEnabledVisitor(i) != null);
	}
	
	/** Disables the visitor at stack position {@code i}. */
	public void disableVisitor(int i)
	{
		enabledVisitors[i] = null;
	}
	
	/** (Re-)enables the visitor at stack position {@code i}. */
	public void enableVisitor(int i)
	{
		enabledVisitors[i] = visitorStack[i];
	}
	
	/** Enables or disables the visitor at stack position {@code i}. */
	public void setVisitorEnabled(int i, boolean enable)
	{
		if (enable)
			enableVisitor(i);
		else
			disableVisitor(i);
	}
	
	// =========================================================================
	
	/**
	 * Start visitation at the given node.
	 * 
	 * The method first runs the {@code before} hooks of the enabled visitors,
	 * which may transform the node or disable their visitor. It then
	 * dispatches the resulting node with a fresh {@link Baton}. Finally, the
	 * original {@code node} (not the transformed one) and the result are
	 * passed through the {@code after} hooks.
	 * 
	 * NOTE(review): visitors disabled by {@code before} are not re-enabled
	 * when this method returns. Confirm this is intended for controllers that
	 * are reused for several visitations.
	 * 
	 * @param node
	 *            The node at which the visitation will start.
	 * @return The result of the visitation. If the visit() method for the given
	 *         node doesn't return a value, null is returned.
	 */
	public Object go(T node)
	{
		T startNode = (T) before(node);
		
		this.baton = new Baton();
		Object result = resolveAndVisit(startNode);
		
		return after(node, result);
	}
	
	/**
	 * Runs the {@code before} hook of each enabled visitor in stack order.
	 * Each visitor receives the node returned by the previous one. A visitor
	 * whose hook returns {@code null} is disabled, and the node is passed on
	 * unchanged.
	 * 
	 * @return The node after all transformations.
	 */
	protected T before(T node)
	{
		T transformed = node;
		for (int i = 0; i < visitorStack.length; ++i)
		{
			if (isVisitorEnabled(i))
			{
				T result = getEnabledVisitor(i).before(transformed);
				if (result == null)
				{
					disableVisitor(i);
				}
				else
				{
					transformed = result;
				}
			}
		}
		return transformed;
	}
	
	/**
	 * Runs the {@code after} hook of each enabled visitor in stack order.
	 * Each hook may replace the result handed on to the next one.
	 * 
	 * @return The result after all hooks ran.
	 */
	protected Object after(T node, Object result)
	{
		for (int i = 0; i < visitorStack.length; ++i)
		{
			if (isVisitorEnabled(i))
				result = getEnabledVisitor(i).after(node, result);
		}
		return result;
	}
	
	// =========================================================================
	
	/**
	 * Called when no visitor in the stack has a visit method matching the
	 * runtime class of {@code node}.
	 */
	protected abstract Object visitNotFound(T node);
	
	/**
	 * Called with the exception thrown by a visit method, unless that
	 * exception is already a {@link VisitingException}. The default
	 * implementation wraps it in a {@link VisitingException} and throws that.
	 */
	protected Object handleVisitingException(T node, Throwable cause)
	{
		throw new VisitingException(node, cause);
	}
	
	// =========================================================================
	
	/**
	 * Dispatches {@code node} to the visit chain for its runtime class. The
	 * chain is resolved and cached on first use.
	 * 
	 * @return The chain's result, or the result of {@code visitNotFound} if no
	 *         visitor in the stack has a matching visit method.
	 * @throws VisitingException
	 *             If a visit method threw one. Otherwise, whatever
	 *             {@link #handleVisitingException} throws for an exception
	 *             raised by a visit method.
	 * @throws VisitNotFoundException
	 *             Propagated unchanged.
	 * @throws VisitorException
	 *             Wrapping any other exception raised while resolving or
	 *             invoking the chain.
	 */
	protected Object resolveAndVisit(T node)
	{
		Class nClass = node.getClass();
		
		// Lookup key: chains are equal iff their node classes are identical.
		VisitChain key = new VisitChain(nClass);
		VisitChain visiChain = cache.get(key);
		try
		{
			if (visiChain == null)
			{
				visiChain = buildVisitChain(key);
				cache.put(visiChain);
			}
			
			if (visiChain.isEmpty())
			{
				return visitNotFound(node);
			}
			else
			{
				return visiChain.invokeChain(baton, this, node);
			}
		}
		catch (InvocationTargetException e)
		{
			// Exception thrown by the visit method itself
			Throwable cause = e.getCause();
			if (cause instanceof VisitingException)
				throw (VisitingException) cause;
			
			return handleVisitingException(node, cause);
		}
		catch (VisitingException e)
		{
			throw e;
		}
		catch (VisitNotFoundException e)
		{
			throw e;
		}
		catch (Exception e)
		{
			throw new VisitorException(node, e);
		}
	}
	
	/**
	 * Builds the chain for the node class of {@code key}. It holds one link
	 * for every visitor in the stack that has a matching visit method.
	 * Disabled visitors are included, because the chain is cached and the
	 * enabled state is only checked when the chain is invoked.
	 */
	private VisitChain buildVisitChain(VisitChain key) throws SecurityException, NoSuchMethodException
	{
		Class nClass = key.getNodeClass();
		
		List chain = new ArrayList();
		
		for (int i = 0; i < visitorStack.length; ++i)
		{
			Class vClass = visitorStack[i].getClass();
			
			Method method = findVisit(vClass, nClass);
			if (method != null)
				chain.add(new Link(i, method));
		}
		
		return new VisitChain(key, chain);
	}
	
	/**
	 * Finds the public {@code visit(Baton, X)} method of {@code vClass} whose
	 * parameter type {@code X} is the most specific type matching
	 * {@code nClass}.
	 * 
	 * The type hierarchy of {@code nClass} is searched breadth first,
	 * superclass before interfaces. Once a type matches, its own supertypes
	 * are not examined.
	 * 
	 * @return The method, or {@code null} if no type in the hierarchy matches.
	 * @throws MultipleVisitMethodsMatchException
	 *             If two matching types are not related by assignability
	 *             (raised while sorting the candidates).
	 */
	private static Method findVisit(final Class vClass, final Class nClass) throws NoSuchMethodException, SecurityException
	{
		Method method = null;
		
		List> candidates = new ArrayList>();
		
		// Do a breadth first search in the hierarchy
		Queue> work = new ArrayDeque>();
		work.add(nClass);
		while (!work.isEmpty())
		{
			Class workItem = work.remove();
			try
			{
				method = vClass.getMethod(VISIT_METHOD_NAME, BATON_CLASS, workItem);
				candidates.add(workItem);
			}
			catch (NoSuchMethodException e)
			{
				// Consider non-interface classes first
				Class superclass = workItem.getSuperclass();
				if (superclass != null)
					work.add(superclass);
				for (Class i : workItem.getInterfaces())
					work.add(i);
			}
		}
		
		if (!candidates.isEmpty())
		{
			// Most specific type first
			Collections.sort(candidates, new Comparator>()
			{
				@Override
				public int compare(Class arg0, Class arg1)
				{
					if (arg0 == arg1)
					{
						return 0;
					}
					else if (arg0.isAssignableFrom(arg1))
					{
						return +1;
					}
					else if (arg1.isAssignableFrom(arg0))
					{
						return -1;
					}
					else
					{
						throw new MultipleVisitMethodsMatchException(vClass, nClass, arg0, arg1);
					}
				}
			});
			
			method = vClass.getMethod(VISIT_METHOD_NAME, BATON_CLASS, candidates.get(0));
		}
		
		return method;
	}
	
	// =========================================================================
	
	/**
	 * The resolved visit methods for one node class, in stack order. Chains
	 * are identified solely by their node class (see equals/hashCode). The
	 * natural ordering is by recency of use, which is inconsistent with
	 * equals.
	 */
	protected static final class VisitChain
			implements
				Comparable
	{
		/**
		 * Source of use stamps shared by all chains. Not updated atomically.
		 */
		private static long useCounter = 0;
		
		/** Stamp of the last use; -1 if never used. */
		private long lastUse = -1;
		
		private final Class nodeClass;
		
		private final Link[] chain;
		
		/**
		 * Creates a lookup key for {@code nClass}. A key has no links
		 * ({@code chain} is {@code null}) and must only be used for cache
		 * lookups.
		 */
		public VisitChain(Class nClass)
		{
			this.nodeClass = nClass;
			this.chain = null;
		}
		
		/**
		 * Creates a chain with the node class of {@code chain} (usually a
		 * lookup key) and the given links.
		 */
		public VisitChain(VisitChain chain, List links)
		{
			this.nodeClass = chain.nodeClass;
			this.chain = links.toArray(new Link[links.size()]);
		}
		
		public Class getNodeClass()
		{
			return nodeClass;
		}
		
		/** Returns whether no visitor in the stack can visit this node class. */
		public boolean isEmpty()
		{
			return (chain.length == 0);
		}
		
		/**
		 * Invokes the visit methods of this chain on {@code node}, skipping
		 * disabled visitors. After each invocation the baton code decides
		 * how to proceed:
		 * 
		 * - a {@code null} result ends the chain and is returned;
		 * - REDISPATCH ends the chain, re-dispatching the result if it is an
		 *   instance of this chain's node class;
		 * - CONTINUE_SAME_TYPE_OR_REDISPATCH and
		 *   CONTINUE_ASSIGNABLE_TYPE_OR_REDISPATCH continue with the result
		 *   if its type satisfies the condition; otherwise they behave like
		 *   REDISPATCH;
		 * - CONTINUE_SAME_REF, CONTINUE_SAME_TYPE and CONTINUE_ASSIGNABLE_TYPE
		 *   continue with the result if the condition holds; otherwise they
		 *   end the chain and return the result;
		 * - SKIP ends the chain and returns the result.
		 * 
		 * Type conditions compare against the original {@code node}. The
		 * same-reference condition compares against the node handed to the
		 * current visitor.
		 * 
		 * NOTE(review): because the chain is looked up by
		 * {@code node.getClass()}, {@code nodeClass} always equals the node's
		 * runtime class. As a consequence, the redispatch inside
		 * CONTINUE_ASSIGNABLE_TYPE_OR_REDISPATCH can never fire, and results
		 * of an unrelated node type are returned instead of being
		 * re-dispatched. Confirm which class was meant for the
		 * redispatch test.
		 * 
		 * @return The last visitor's result, or {@code node} itself if no
		 *         visitor of this chain is enabled.
		 */
		@SuppressWarnings({ "rawtypes" })
		public Object invokeChain(
				Baton baton,
				VisitorStackController controller,
				Object node) throws IllegalArgumentException, IllegalAccessException, InvocationTargetException
		{
			touch();
			
			// this method must only be called on non-empty chains
			if (isEmpty())
				throw new AssertionError();
			
			Object visitNext = node;
			// If there are no enabled visitors just return the node itself
			Object result = node;
			
			StackedVisitorInterface[] enabledVisitors = controller.enabledVisitors;
			
			int i = 0;
			chainIter: while (true)
			{
				StackedVisitorInterface visitor = enabledVisitors[chain[i].visitorIndex];
				if (visitor != null)
				{
					if (DEBUG)
						System.err.println(chain[i].method + ": " + StringUtils.abbreviate(visitNext.toString(), 32));
					result = chain[i].method.invoke(visitor, baton, visitNext);
					
					// We must always query the code to reset it, even if result == null
					int batonCode = baton.queryAndResetCode();
					
					if (result == null)
						break chainIter;
					
					switch (batonCode)
					{
						case Baton.REDISPATCH:
							// Re-dispatch if instance of nodeClass
							if (nodeClass.isInstance(result))
								result = redispatch(controller, result);
							
							// Leave chain
							break chainIter;
						
						case Baton.CONTINUE_SAME_TYPE_OR_REDISPATCH:
							if (node.getClass() != result.getClass())
							{
								// Re-dispatch if instance of nodeClass
								if (nodeClass.isInstance(result))
									result = redispatch(controller, result);
								
								// Leave chain
								break chainIter;
							}
							else
							{
								// Continue to next visitor
								break;
							}
							
						case Baton.CONTINUE_ASSIGNABLE_TYPE_OR_REDISPATCH:
							if (!node.getClass().isInstance(result))
							{
								// Re-dispatch if instance of nodeClass
								if (nodeClass.isInstance(result))
									result = redispatch(controller, result);
								
								// Leave chain
								break chainIter;
							}
							else
							{
								// Continue to next visitor
								break;
							}
							
						case Baton.CONTINUE_SAME_REF:
							if (visitNext != result)
							{
								// Leave chain
								break chainIter;
							}
							else
							{
								// Continue to next visitor
								break;
							}
							
						case Baton.CONTINUE_SAME_TYPE:
							if (node.getClass() != result.getClass())
							{
								// Leave chain
								break chainIter;
							}
							else
							{
								// Continue to next visitor
								break;
							}
							
						case Baton.CONTINUE_ASSIGNABLE_TYPE:
							if (!node.getClass().isInstance(result))
							{
								// Leave chain
								break chainIter;
							}
							else
							{
								// Continue to next visitor
								break;
							}
							
						case Baton.SKIP:
							// Leave chain
							break chainIter;
						
						default:
							throw new AssertionError(batonCode);
					}
					
					++i;
					if (i >= chain.length)
						break chainIter;
					
					visitNext = result;
				}
				else
				{
					// Disabled visitor: skip its link
					++i;
					if (i >= chain.length)
						break;
				}
			}
			
			return result;
		}
		
		/** Dispatches {@code visitNext} anew through the controller. */
		@SuppressWarnings({ "rawtypes", "unchecked" })
		private Object redispatch(
				VisitorStackController controller,
				Object visitNext)
		{
			if (DEBUG)
				System.err.println(StringUtils.abbreviate(visitNext.toString(), 32));
			return controller.resolveAndVisit(visitNext);
		}
		
		/**
		 * Marks this chain as the most recently used one. The shared counter
		 * is neither synchronized nor atomic, so concurrent calls may hand
		 * out equal stamps. The recency order is therefore approximate.
		 */
		public void touch()
		{
			lastUse = ++useCounter;
		}
		
		@Override
		public int hashCode()
		{
			final int prime = 31;
			int result = 1;
			result = prime * result + nodeClass.hashCode();
			return result;
		}
		
		/** Two chains are equal iff their node classes are identical. */
		@Override
		public boolean equals(Object obj)
		{
			if (this == obj)
				return true;
			// Also rejects null instead of failing with a ClassCastException/NPE
			if (!(obj instanceof VisitChain))
				return false;
			VisitChain other = (VisitChain) obj;
			return nodeClass == other.nodeClass;
		}
		
		/** Orders chains from least to most recently used. */
		@Override
		public int compareTo(VisitChain o)
		{
			// Read each stamp once; touch() may update them concurrently.
			long mine = lastUse;
			long theirs = o.lastUse;
			// Equal stamps are possible (racing touch() calls, or comparing a
			// chain with itself), so 0 must be returned for them to satisfy
			// the Comparable contract.
			return (mine < theirs) ? -1 : ((mine == theirs) ? 0 : +1);
		}
	}
	
	// =========================================================================
	
	/** A visit method together with the stack position of its visitor. */
	private static final class Link
	{
		private final int visitorIndex;
		
		private final Method method;
		
		public Link(int visitorIndex, Method method)
		{
			this.visitorIndex = visitorIndex;
			this.method = method;
		}
	}
	
	// =========================================================================
	
	/**
	 * Cache of resolved visit chains for one visitor stack definition. When an
	 * insertion grows the cache beyond {@code upperCapacity} entries, the
	 * least recently used chains are evicted until {@code lowerCapacity}
	 * entries remain.
	 */
	public static final class Cache
	{
		private final int lowerCapacity;
		
		private final int upperCapacity;
		
		private final CacheDefinition cacheDef;
		
		private final ConcurrentHashMap cache;
		
		private Cache(List> visitorStack,
				float loadFactor,
				int lowerCapacity,
				int upperCapacity)
		{
			this.lowerCapacity = lowerCapacity;
			this.upperCapacity = upperCapacity;
			this.cacheDef = new CacheDefinition(visitorStack);
			this.cache = new ConcurrentHashMap(lowerCapacity, loadFactor);
		}
		
		/**
		 * @throws IncompatibleVisitorStackDefinition
		 *             If {@code visitorStack} does not have the same visitor
		 *             classes, in the same order, as this cache's definition.
		 */
		private void verifyDefinition(
				List> visitorStack) throws IncompatibleVisitorStackDefinition
		{
			if (!new CacheDefinition(visitorStack).equals(cacheDef))
				throw new IncompatibleVisitorStackDefinition("Incompatible visitor stack");
		}
		
		/** Returns the cached chain for the node class of {@code key}, or null. */
		private VisitChain get(VisitChain key)
		{
			return cache.get(key);
		}
		
		/**
		 * Inserts {@code chain} unless a chain for the same node class is
		 * already cached.
		 * 
		 * @return The previously cached chain, or {@code chain} if it was
		 *         inserted.
		 */
		private synchronized VisitChain put(VisitChain chain)
		{
			VisitChain cached = cache.putIfAbsent(chain, chain);
			if (cached != null)
			{
				return cached;
			}
			else
			{
				// Make sure the target is not swept from the cache ...
				chain.touch();
				if (cache.size() > upperCapacity)
					sweepCache();
				
				return chain;
			}
		}
		
		/**
		 * Evicts the least recently used chains until {@code lowerCapacity}
		 * entries remain. Does nothing unless the cache holds more than
		 * {@code upperCapacity} entries.
		 */
		private synchronized void sweepCache()
		{
			if (cache.size() <= upperCapacity)
				return;
			
			VisitChain keys[] = new VisitChain[cache.size()];
			
			Enumeration keysEnum = cache.keys();
			
			int i = 0;
			while (i < keys.length && keysEnum.hasMoreElements())
				keys[i++] = keysEnum.nextElement();
			
			int length = i;
			
			// Sort on a snapshot of the use stamps. touch() runs on other
			// threads without holding this lock, and sorting on values that
			// change mid-sort can make Arrays.sort fail with "Comparison
			// method violates its general contract".
			final long[] stamps = new long[length];
			Integer[] order = new Integer[length];
			for (int k = 0; k < length; ++k)
			{
				stamps[k] = keys[k].lastUse;
				order[k] = k;
			}
			
			// Least recently used first
			Arrays.sort(order, new Comparator<Integer>()
			{
				@Override
				public int compare(Integer a, Integer b)
				{
					long sa = stamps[a];
					long sb = stamps[b];
					return (sa < sb) ? -1 : ((sa == sb) ? 0 : +1);
				}
			});
			
			int to = length - lowerCapacity;
			for (int j = 0; j < to; ++j)
				cache.remove(keys[order[j]]);
		}
	}
	
	// =========================================================================
	
	/**
	 * Identifies a visitor stack by the classes of its visitors, in stack
	 * order.
	 */
	private static final class CacheDefinition
	{
		/** Precomputed hash over the visitor classes. */
		private final int hash;
		
		private final Class[] visitorStackDef;
		
		public CacheDefinition(
				List> visitorStack)
		{
			@SuppressWarnings("rawtypes")
			Class[] visitorStackDef = new Class[visitorStack.size()];
			int hash = 0;
			
			int i = 0;
			for (StackedVisitorInterface visitor : visitorStack)
			{
				visitorStackDef[i] = visitor.getClass();
				hash = hash * 13 + visitorStackDef[i].hashCode() * 17;
				++i;
			}
			
			this.hash = hash;
			this.visitorStackDef = visitorStackDef;
		}
		
		@Override
		public int hashCode()
		{
			return hash;
		}
		
		@Override
		public boolean equals(Object obj)
		{
			if (this == obj)
				return true;
			if (obj == null)
				return false;
			if (getClass() != obj.getClass())
				return false;
			CacheDefinition other = (CacheDefinition) obj;
			if (!Arrays.equals(visitorStackDef, other.visitorStackDef))
				return false;
			return true;
		}
	}
}