All downloads are free. Search and download functionality uses the official Maven repository.

org.daisy.common.xpath.saxon.ReflexiveExtensionFunctionProvider Maven / Gradle / Ivy

package org.daisy.common.xpath.saxon;

import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.xpath.XPath;

import org.daisy.common.saxon.SaxonHelper;
import org.daisy.common.saxon.SaxonInputValue;

import org.w3c.dom.Element;
import org.w3c.dom.Node;

import net.sf.saxon.dom.DocumentBuilderImpl;
import net.sf.saxon.dom.ElementOverNodeInfo;
import net.sf.saxon.dom.NodeOverNodeInfo;
import net.sf.saxon.expr.XPathContext;
import net.sf.saxon.lib.ExtensionFunctionCall;
import net.sf.saxon.lib.ExtensionFunctionDefinition;
import net.sf.saxon.ma.arrays.ArrayItem;
import net.sf.saxon.ma.map.HashTrieMap;
import net.sf.saxon.ma.map.KeyValuePair;
import net.sf.saxon.ma.map.MapItem;
import net.sf.saxon.om.Item;
import net.sf.saxon.om.NodeInfo;
import net.sf.saxon.om.Sequence;
import net.sf.saxon.om.SequenceIterator;
import net.sf.saxon.om.StructuredQName;
import net.sf.saxon.s9api.XdmNode;
import net.sf.saxon.trans.XPathException;
import net.sf.saxon.type.ValidationException;
import net.sf.saxon.value.AnyURIValue;
import net.sf.saxon.value.BooleanValue;
import net.sf.saxon.value.DecimalValue;
import net.sf.saxon.value.FloatValue;
import net.sf.saxon.value.IntegerValue;
import net.sf.saxon.value.ObjectValue;
import net.sf.saxon.value.SequenceType;
import net.sf.saxon.value.StringValue;
import net.sf.saxon.xpath.XPathFactoryImpl;

/**
 * Poor man's implementation of Saxon's reflexive
 * extension function mechanism, which is only available in the PE and EE versions.
 */
public abstract class ReflexiveExtensionFunctionProvider implements ExtensionFunctionProvider {

	private final List definitions;

	public Collection getDefinitions() {
		return definitions;
	}

	protected ReflexiveExtensionFunctionProvider(Class definition) {
		definitions = new ArrayList<>();
		Collection names = new ArrayList<>();
		if (!Modifier.isAbstract(definition.getModifiers())) {
			for (Constructor constructor : definition.getConstructors()) {
				if (Modifier.isPublic(constructor.getModifiers())) {
					if (names.contains("new"))
						throw new IllegalArgumentException("function overloading not supported");
					else {
						names.add("new");
						definitions.add(extensionFunctionDefinitionFromMethod(constructor));
					}
				}
			}
		}
		for (Method method : definition.getDeclaredMethods()) {
			if (Modifier.isPublic(method.getModifiers())) {
				if (names.contains(method.getName()))
					throw new IllegalArgumentException("function overloading not supported");
				else {
					names.add(method.getName());
					definitions.add(extensionFunctionDefinitionFromMethod(method));
				}
			}
		}
	}

	private ExtensionFunctionDefinition extensionFunctionDefinitionFromMethod(Executable method)
			throws IllegalArgumentException {
		assert method instanceof Constructor || method instanceof Method;
		if (method.isVarArgs())
			throw new IllegalArgumentException(); // vararg functions not supported
		else {
			Class declaringClass = method.getDeclaringClass();
			boolean isInnerClass = Arrays.stream(getClass().getClasses()).anyMatch(declaringClass::equals);
			boolean isConstructor = method instanceof Constructor;
			boolean isStatic = isConstructor || Modifier.isStatic(method.getModifiers());
			boolean requiresXPath = false;
			boolean requiresDocumentBuilder = false;
			for (Class t : method.getParameterTypes()) {
				if (t.equals(XPath.class)) {
					if (requiresXPath)
						throw new IllegalArgumentException(); // only one XPath argument allowed
					requiresXPath = true;
				} else if (t.equals(DocumentBuilder.class)) {
					if (requiresDocumentBuilder)
						throw new IllegalArgumentException(); // only one DocumentBuilder argument allowed
					requiresDocumentBuilder = true;
				}
			}
			Type[] javaArgumentTypes; { // arguments of Java method/constructor
				Type[] types = method.getGenericParameterTypes();
				javaArgumentTypes = isConstructor && isInnerClass
					// in case of constructor of inner class, first element is type of enclosing instance
					? Arrays.copyOfRange(types, 1, types.length)
					: types;
			}
			SequenceType[] argumentTypes = new SequenceType[ // arguments of XPath function
				javaArgumentTypes.length
				+ (isStatic ? 0 : 1)
				- (requiresXPath ? 1 : 0)
				- (requiresDocumentBuilder ? 1 : 0)
			];
			int i = 0;
			if (!isStatic)
				argumentTypes[i++] = SequenceType.SINGLE_ITEM; // must be special wrapper item
			for (Type t : javaArgumentTypes)
				if (!t.equals(XPath.class) && !t.equals(DocumentBuilder.class))
					argumentTypes[i++] = sequenceTypeFromType(t);
			SequenceType resultType = isConstructor
				? SequenceType.SINGLE_ITEM // special wrapper item
				: sequenceTypeFromType(((Method)method).getGenericReturnType(), Collections.singleton(declaringClass));
			return new ExtensionFunctionDefinition() {
				@Override
				public SequenceType[] getArgumentTypes() {
					return argumentTypes;
				}
				@Override
				public StructuredQName getFunctionQName() {
					return new StructuredQName(declaringClass.getSimpleName(),
					                           declaringClass.getName(),
					                           isConstructor ? "new" : method.getName());
				}
				@Override
				public SequenceType getResultType(SequenceType[] arg0) {
					return resultType;
				}
				@Override
				public ExtensionFunctionCall makeCallExpression() {
					return new ExtensionFunctionCall() {
						@Override
						public Sequence call(XPathContext ctxt, Sequence[] args) throws XPathException {
							try {
								if (args.length != argumentTypes.length)
									throw new IllegalArgumentException(); // should not happen
								int i = 0;
								Object instance = null;
								if (!isStatic) {
									Item item = getSingleItem(args[i++]);
									if (!(item instanceof ObjectValue))
										throw new IllegalArgumentException(
											"Expected ObjectValue<" + declaringClass.getSimpleName() + ">" + ", but got: " + item);
									instance = ((ObjectValue)item).getObject();
								}
								Object[] javaArgs = new Object[ // arguments passed to Method.invoke() in addition to instance,
								                                // or to Constructor.newInstance()
									method.getParameterCount()
								];
								int j = 0;
								if (isConstructor && isInnerClass)
									// in case of constructor of inner class, first argument is enclosing instance
									javaArgs[j++] = ReflexiveExtensionFunctionProvider.this;
								for (Type type : javaArgumentTypes) {
									if (type.equals(XPath.class))
										javaArgs[j++] = new XPathFactoryImpl(ctxt.getConfiguration()).newXPath();
									else if (type.equals(DocumentBuilder.class)) {
										DocumentBuilderImpl b = new DocumentBuilderImpl();
										b.setConfiguration(ctxt.getConfiguration());
										javaArgs[j++] = b;
									} else if (type instanceof ParameterizedType
									           && ((ParameterizedType)type).getRawType().equals(Iterator.class))
										javaArgs[j++] = iteratorFromSequence(
											args[i++],
											((ParameterizedType)type).getActualTypeArguments()[0]);
									else if (type instanceof ParameterizedType
									           && ((ParameterizedType)type).getRawType().equals(Optional.class)) {
										Optional item = getOptionalItem(args[i++]);
										javaArgs[j++] = item.isPresent()
											? Optional.of(objectFromItem(item.get(), ((ParameterizedType)type).getActualTypeArguments()[0]))
											: Optional.empty();
									} else if (type.equals(URI.class)) {
										// could be empty because sequenceTypeFromType() returned OPTIONAL_ANY_URI
										Optional item = getOptionalItem(args[i++]);
										if (!item.isPresent())
											throw new IllegalArgumentException("Expected xs:anyURI but got an empty sequence");
										javaArgs[j++] = objectFromItem(item.get(), type);
									} else
										javaArgs[j++] = objectFromItem(getSingleItem(args[i++]), type);
								}
								Object result; {
									try {
										if (isConstructor)
											result = ((Constructor)method).newInstance(javaArgs);
										else
											result = ((Method)method).invoke(instance, javaArgs);
									} catch (InstantiationException|InvocationTargetException e) {
										throw new XPathException(e.getCause());
									} catch (IllegalAccessException e) {
										throw new RuntimeException(); // should not happen
									}
								}
								if (result instanceof Optional)
									return SaxonHelper.sequenceFromObject(((Optional)result).orElse(null));
								else
									return SaxonHelper.sequenceFromObject(result, Collections.singleton(declaringClass));
							} catch (RuntimeException e) {
								throw new XPathException("Unexpected error in " + getFunctionQName().getClarkName(), e);
							}
						}
					};
				}
			};
		}
	}

	private static SequenceType sequenceTypeFromType(Type type) throws IllegalArgumentException {
		return sequenceTypeFromType(type, null);
	}

	/**
	 * @param externalObjectClasses Classes of objects that are wrapped in a XPath value.
	 */
	private static SequenceType sequenceTypeFromType(Type type, Set> externalObjectClasses) throws IllegalArgumentException {
		if (externalObjectClasses != null && externalObjectClasses.contains(type))
			return SequenceType.SINGLE_ITEM; // special wrapper item
		else if (type.equals(Void.TYPE))
			return SequenceType.EMPTY_SEQUENCE;
		else if (type.equals(String.class))
			return SequenceType.SINGLE_STRING;
		else if (type.equals(Integer.class)
		         || type.equals(int.class)
		         || type.equals(Long.class)
		         || type.equals(long.class))
			return SequenceType.SINGLE_INTEGER;
		else if (type.equals(Float.class)
		         || type.equals(float.class))
			return SequenceType.SINGLE_FLOAT;
		else if (type.equals(BigDecimal.class))
			return SequenceType.SINGLE_DECIMAL;
		else if (type.equals(Boolean.class)
		         || type.equals(boolean.class))
			return SequenceType.SINGLE_BOOLEAN;
		else if (type.equals(URI.class))
			return SequenceType.OPTIONAL_ANY_URI; // SINGLE_ANY_URI
		else if (type.equals(Element.class) || type.equals(Node.class))
			return SequenceType.SINGLE_NODE;
		else if (type.equals(Object.class))
			return SequenceType.SINGLE_ITEM;
		else if (type instanceof ParameterizedType) {
			Type rawType = ((ParameterizedType)type).getRawType();
			if (rawType.equals(Optional.class)) {
				Type itemType = ((ParameterizedType)type).getActualTypeArguments()[0];
				if (itemType.equals(Node.class) || itemType.equals(Element.class))
					return SequenceType.OPTIONAL_NODE;
				else if (itemType.equals(String.class))
					return SequenceType.OPTIONAL_STRING;
				else if (itemType.equals(URI.class))
					return SequenceType.OPTIONAL_ANY_URI;
			} else if (rawType.equals(Iterator.class)) {
				Type itemType = ((ParameterizedType)type).getActualTypeArguments()[0];
				if (itemType.equals(Node.class) || itemType.equals(Element.class))
					return SequenceType.NODE_SEQUENCE;
				else if (itemType.equals(String.class))
					return SequenceType.STRING_SEQUENCE;
			} else if (rawType.equals(List.class)) {
				Type itemType = ((ParameterizedType)type).getActualTypeArguments()[0];
				sequenceTypeFromType(itemType, externalObjectClasses);
				return ArrayItem.SINGLE_ARRAY_TYPE;
			} else if (rawType.equals(Map.class)) {
				Type keyType = ((ParameterizedType)type).getActualTypeArguments()[0];
				Type valueType = ((ParameterizedType)type).getActualTypeArguments()[1];
				if (keyType.equals(String.class)) {
					sequenceTypeFromType(valueType, externalObjectClasses);
					return HashTrieMap.SINGLE_MAP_TYPE;
				}
			}
		}
		throw new IllegalArgumentException("Unsupported type: " + type);
	}

	private static Iterator iteratorFromSequence(Sequence sequence, Type itemType) throws XPathException {
		if (itemType instanceof Class)
			return iteratorFromSequence(sequence, (Class)itemType);
		else
			throw new IllegalArgumentException();
	}

	@SuppressWarnings("unchecked") // safe casts
	private static  Iterator iteratorFromSequence(Sequence sequence, Class itemType) throws XPathException {
		if (itemType.equals(Node.class)) {
			List list = new ArrayList<>();
			SequenceIterator iterator = sequence.iterate();
			Item next;
			while ((next = iterator.next()) != null)
				list.add(objectFromItem(next, XdmNode.class));
			return list.isEmpty()
				? (Iterator)Collections.EMPTY_LIST.iterator()
				: (Iterator)new SaxonInputValue(list.iterator()).asNodeIterator();
		} else {
			List list = new ArrayList<>();
			SequenceIterator iterator = sequence.iterate();
			Item next;
			while ((next = iterator.next()) != null)
				list.add(objectFromItem(next, itemType));
			return list.iterator();
		}
	}

	private static List listFromArrayItem(ArrayItem array, Type itemType) throws XPathException {
		List list = new ArrayList<>();
		for (Sequence s : array)
			list.add(objectFromItem(getSingleItem(s), itemType));
		return list;
	}

	private static Map mapFromMapItem(MapItem item, Type itemType) throws XPathException {
		Map map = new HashMap<>();
		for (KeyValuePair kv : item)
			map.put(kv.key.getStringValue(), objectFromItem(getSingleItem(kv.value), itemType));
		return map;
	}

	private static Object objectFromItem(Item item, Type type) throws XPathException {
		if (type instanceof Class)
			return objectFromItem(item, (Class)type);
		else if (type instanceof ParameterizedType) {
			Type rawType = ((ParameterizedType)type).getRawType();
			if (rawType.equals(List.class)) {
				if (item instanceof ArrayItem)
					return listFromArrayItem(
						(ArrayItem)item,
						((ParameterizedType)type).getActualTypeArguments()[0]);
			} else if (rawType.equals(Map.class)) {
				if (item instanceof MapItem)
					return mapFromMapItem(
						(MapItem)item,
						((ParameterizedType)type).getActualTypeArguments()[1]);
			}
		}
		throw new IllegalArgumentException();
	}

	@SuppressWarnings("unchecked") // safe casts
	private static  T objectFromItem(Item item, Class type) {
		if (type.equals(XdmNode.class))
			if (item instanceof NodeInfo)
				return (T)new XdmNode((NodeInfo)item);
			else
				throw new IllegalArgumentException();
		else if (type.equals(Element.class))
			if (item instanceof NodeInfo)
				return (T)ElementOverNodeInfo.wrap((NodeInfo)item);
			else
				throw new IllegalArgumentException();
		else if (type.equals(Node.class))
			if (item instanceof NodeInfo)
				return (T)NodeOverNodeInfo.wrap((NodeInfo)item);
			else
				throw new IllegalArgumentException();
		else if (type.equals(String.class))
			if (item instanceof StringValue)
				return (T)(String)((StringValue)item).getStringValue();
			else
				throw new IllegalArgumentException();
		else if (type.equals(Integer.class) || type.equals(int.class))
			if (item instanceof IntegerValue)
				return (T)(Integer)((IntegerValue)item).asBigInteger().intValue();
			else
				throw new IllegalArgumentException();
		else if (type.equals(Long.class) || type.equals(long.class))
			if (item instanceof IntegerValue)
				return (T)(Long)((IntegerValue)item).asBigInteger().longValue();
			else
				throw new IllegalArgumentException();
		else if (type.equals(Float.class) || type.equals(float.class))
			if (item instanceof FloatValue)
				return (T)(Float)((FloatValue)item).getFloatValue();
			else
				throw new IllegalArgumentException();
		else if (type.equals(BigDecimal.class))
			if (item instanceof DecimalValue)
				try {
					return (T)(BigDecimal)((DecimalValue)item).getDecimalValue();
				} catch (ValidationException e) {
					throw new RuntimeException(e); // should not happen
				}
			else
				throw new IllegalArgumentException();
		else if (type.equals(Boolean.class))
			if (item instanceof BooleanValue)
				return (T)(Boolean)((BooleanValue)item).getBooleanValue();
			else
				throw new IllegalArgumentException();
		else if (type.equals(URI.class))
			if (item instanceof AnyURIValue)
				try {
					return (T)(new URI((String)((StringValue)item).getStringValue()));
				} catch (URISyntaxException e) {
					throw new IllegalArgumentException(e); // should not happen
				}
			else
				throw new IllegalArgumentException();
		else if (type.equals(Object.class))
			if (item instanceof IntegerValue)
				return (T)(Object)((IntegerValue)item).asBigInteger().longValue();
			else if (item instanceof StringValue)
				return (T)(Object)((StringValue)item).getStringValue();
			else
				throw new IllegalArgumentException();
		else
			throw new IllegalArgumentException();
	}

	private static Item getSingleItem(Sequence sequence) throws XPathException {
		SequenceIterator iterator = sequence.iterate();
		Item item = iterator.next();
		if (item == null)
			throw new IllegalArgumentException();
		if (iterator.next() != null)
			throw new IllegalArgumentException();
		return item;
	}

	private static Optional getOptionalItem(Sequence sequence) throws XPathException {
		SequenceIterator iterator = sequence.iterate();
		Item item = iterator.next();
		if (iterator.next() != null)
			throw new IllegalArgumentException();
		return Optional.ofNullable(item);
	}
}