All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.translator.ModelTranslator Maven / Gradle / Ivy

There is a newer version: 1.3.8
Show newest version
/*
 * Copyright (c) 2019 Villu Ruusmann
 *
 * This file is part of JPMML-Transpiler
 *
 * JPMML-Transpiler is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Transpiler is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Transpiler.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.translator;

import java.lang.reflect.TypeVariable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JClassAlreadyExistsException;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JType;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Target;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.IndexableUtil;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelManager;
import org.jpmml.evaluator.PMMLManager;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueFactoryFactory;
import org.jpmml.evaluator.java.JavaModel;
import org.jpmml.model.PMMLException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.visitors.ActiveFieldFinder;
import org.jpmml.model.visitors.FieldResolver;

abstract
public class ModelTranslator extends ModelManager {

	public ModelTranslator(PMML pmml, M model){
		super(pmml, model);

		MathContext mathContext = model.getMathContext();
		switch(mathContext){
			case FLOAT:
			case DOUBLE:
				break;
			default:
				throw new UnsupportedAttributeException(model, mathContext);
		}
	}

	public JExpression translate(TranslationContext context){
		M model = getModel();

		JDefinedClass javaModelClazz = PMMLObjectUtil.createMemberClass(Modifiers.PUBLIC_STATIC_FINAL, IdentifierUtil.create(JavaModel.class.getSimpleName(), model), context);

		javaModelClazz._extends(JavaModel.class);

		Set activeFieldNames = context.getActiveFieldNames();
		activeFieldNames.clear();

		try {
			context.pushOwner(javaModelClazz);

			createEvaluateMethod(context);
		} catch(PMMLException pe){
			throw pe.ensureContext(model);
		} finally {
			context.popOwner();
		}

		JWrappedExpression expression = new JWrappedExpression(context._new(javaModelClazz));

		TranslatedModel translatedModel = new TranslatedModel(model)
			.setExpression(expression)
			.setActiveFields(new LinkedHashSet<>(activeFieldNames));

		context.addTranslation(model, translatedModel);

		return expression;
	}

	public void createEvaluateMethod(TranslationContext context){
		M model = getModel();

		MiningFunction miningFunction = model.requireMiningFunction();
		switch(miningFunction){
			case REGRESSION:
				{
					JMethod regressorMethod = translateRegressor(context);

					createEvaluateRegressionMethod(regressorMethod, context);
				}
				break;
			case CLASSIFICATION:
				{
					JMethod classifierMethod = translateClassifier(context);

					createEvaluateClassificationMethod(classifierMethod, context);
				}
				break;
			default:
				throw new UnsupportedAttributeException(model, miningFunction);
		}
	}

	public JMethod translateRegressor(TranslationContext context){
		throw new UnsupportedOperationException();
	}

	public JMethod translateClassifier(TranslationContext context){
		throw new UnsupportedOperationException();
	}

	public JMethod createEvaluateRegressionMethod(JMethod evaluateMethod, TranslationContext context){
		M model = getModel();

		TargetField targetField = getTargetField();

		JMethod evaluateRegressionMethod = createEvaluatorMethod("evaluateRegression", context);

		try {
			context.pushScope(new MethodScope(evaluateRegressionMethod));

			JInvocation methodInvocation = createEvaluatorMethodInvocation(evaluateMethod, context);

			JType valueClazz = context.ref(Value.class);

			if(!((evaluateMethod.type()).erasure()).equals(valueClazz)){
				methodInvocation = context.getValueFactoryVariable().newValue(methodInvocation);
			}

			ValueBuilder valueBuilder = new ValueBuilder(context)
				.declare("value", methodInvocation);

			Target target = targetField.getTarget();
			if(target != null){
				translateRegressorTarget(model, target, valueBuilder);

				// XXX
				model.setTargets(null);
			}

			JVar valueVar = valueBuilder.getVariable();

			context._return(context.staticInvoke(Collections.class, "singletonMap", targetField.getName(), valueVar.invoke("getValue")));
		} finally {
			context.popScope();
		}

		return evaluateRegressionMethod;
	}

	public JMethod createEvaluateClassificationMethod(JMethod evaluateMethod, TranslationContext context){
		M model = getModel();

		TargetField targetField = getTargetField();

		JMethod evaluateClassificationMethod = createEvaluatorMethod("evaluateClassification", context);

		try {
			context.pushScope(new MethodScope(evaluateClassificationMethod));

			ClassificationBuilder classificationBuilder = new ClassificationBuilder(context)
				.declare("classification", createEvaluatorMethodInvocation(evaluateMethod, context))
				.computeResult(targetField.getDataType());

			context._return(context.staticInvoke(Collections.class, "singletonMap", context.constantFieldName(targetField.getName()), classificationBuilder));
		} finally {
			context.popScope();
		}

		return evaluateClassificationMethod;
	}

	public FieldInfoMap getFieldInfos(Set bodyObjects){
		PMML pmml = getPMML();
		M model = getModel();

		MiningSchema miningSchema = model.requireMiningSchema();

		Map> bodyFields = new HashMap<>();

		Visitor fieldResolver = new FieldResolver(){

			@Override
			public VisitorAction visit(PMMLObject object){

				if(bodyObjects.contains(object)){
					Model parent = (Model)getParent();

					Collection> fields = getFields();

					for(Field field : fields){
						String name = field.requireName();

						Field previousField = bodyFields.put(name, field);
						if((previousField != null) && (previousField != field)){
							throw new IllegalArgumentException(name);
						}
					}

					// XXX
					return VisitorAction.SKIP;
				}

				return super.visit(object);
			}
		};
		fieldResolver.applyTo(pmml);

		FieldInfoMap result = new FieldInfoMap();

		Set names = ActiveFieldFinder.getFieldNames(bodyObjects.toArray(new PMMLObject[bodyObjects.size()]));
		for(String name : names){
			Field field = bodyFields.get(name);

			result.create(field);
		}

		FunctionInvocationContext context = new FunctionInvocationContext(){

			@Override
			public DefineFunction getDefineFunction(String name){
				return ModelTranslator.this.getDefineFunction(name);
			}
		};

		Collection fieldInfos = new ArrayList<>(result.values());
		for(FieldInfo fieldInfo : fieldInfos){
			enhanceFieldInfo(fieldInfo, miningSchema, bodyFields, result, context);
		}

		return result;
	}

	public ArrayInfoMap getArrayInfos(){
		ArrayInfoMap result = new ArrayInfoMap();

		// XXX
		Pattern pattern = Pattern.compile("^(.+)\\_(\\d+)");

		Matcher matcher = null;

		List inputFields = getInputFields();
		for(InputField inputField : inputFields){
			String name = inputField.getFieldName();

			if(matcher == null){
				matcher = pattern.matcher(name);
			} else

			{
				matcher.reset(name);
			} // End if

			if(matcher.matches()){
				String arrayName = matcher.group(1);
				Integer arrayIndex = Integer.parseInt(matcher.group(2));

				DataField dataField = getDataField(name);

				ArrayInfo arrayInfo = result.get(arrayName);
				if(arrayInfo == null){
					arrayInfo = result.create(arrayName);
				}

				arrayInfo.setElement(arrayIndex, dataField);
			}
		}

		return result;
	}

	public Object[] getTargetCategories(){
		TargetField targetField = getTargetField();

		List categories = targetField.getCategories();

		return categories.toArray(new Object[categories.size()]);
	}

	protected void declareArrayFields(Collection arrayInfos){
		PMML pmml = getPMML();

		DataDictionary dataDictionary = pmml.requireDataDictionary();

		for(ArrayInfo arrayInfo : arrayInfos){
			DataField dataField = new DataField(arrayInfo.getName(), arrayInfo.getOpType(), arrayInfo.getDataType());

			dataDictionary.addDataFields(dataField);
		}

		try {
			java.lang.reflect.Field dataFieldsField = PMMLManager.class.getDeclaredField("dataFields");
			if(!dataFieldsField.isAccessible()){
				dataFieldsField.setAccessible(true);
			}

			dataFieldsField.set(this, ImmutableMap.copyOf(IndexableUtil.buildMap(dataDictionary.getDataFields())));
		} catch(ReflectiveOperationException roe){
			throw new RuntimeException(roe);
		}
	}

	static
	public  ValueFactory getValueFactory(Model model){
		MathContext mathContext = model.getMathContext();

		switch(mathContext){
			case FLOAT:
			case DOUBLE:
				ValueFactoryFactory valueFactoryFactory = ValueFactoryFactory.newInstance();

				// XXX
				return (ValueFactory)valueFactoryFactory.newValueFactory(mathContext);
			default:
				throw new UnsupportedAttributeException(model, mathContext);
		}
	}

	static
	public void translateRegressorTarget(Model model, Target target, ValueBuilder valueBuilder){
		MathContext mathContext = model.getMathContext();

		Number rescaleFactor = target.getRescaleFactor();
		if(rescaleFactor != null && rescaleFactor.doubleValue() != 1d){
			valueBuilder.update("multiply", rescaleFactor);
		}

		Number rescaleConstant = target.getRescaleConstant();
		if(rescaleConstant != null && rescaleConstant.doubleValue() != 0d){
			valueBuilder.update("add", rescaleConstant);
		}

		Target.CastInteger castInteger = target.getCastInteger();
		if(castInteger != null){
			throw new UnsupportedAttributeException(target, castInteger);
		}
	}

	static
	public JMethod createEvaluatorMethod(String name, TranslationContext context){
		JDefinedClass owner = context.getOwner();

		JMethod method = owner.method(Modifiers.PUBLIC_FINAL, context.genericRef(Map.class, String.class, context.wildcard()), name);
		method.annotate(Override.class);

		JTypeVar numberTypeVar = method.generify(MethodScope.TYPEVAR_NUMBER, Number.class);

		method.param(context.genericRef(ValueFactory.class, numberTypeVar), Scope.VAR_VALUEFACTORY);
		method.param(EvaluationContext.class, Scope.VAR_CONTEXT);

		return method;
	}

	static
	public JMethod createEvaluatorMethod(Class type, PMMLObject object, boolean withValueFactory, TranslationContext context){
		return createEvaluatorMethod(type, IdentifierUtil.create("evaluate" + (object.getClass()).getSimpleName(), object), withValueFactory, context);
	}

	static
	public JMethod createEvaluatorMethod(Class type, List objects, boolean withValueFactory, TranslationContext context){
		PMMLObject object = Iterables.getFirst(objects, null);

		return createEvaluatorMethod(type, IdentifierUtil.create("evaluate" + (object.getClass()).getSimpleName() + "List", object), withValueFactory, context);
	}

	static
	private JMethod createEvaluatorMethod(Class type, String name, boolean withValueFactory, TranslationContext context){
		JDefinedClass owner = context.getOwner();

		JMethod method = owner.method(Modifiers.PRIVATE_STATIC_FINAL, type, name);

		if(withValueFactory){
			JTypeVar numberTypeVar = method.generify(MethodScope.TYPEVAR_NUMBER, Number.class);

			TypeVariable[] typeVariables = type.getTypeParameters();
			if(typeVariables.length == 1){
				method.type(context.genericRef(type, numberTypeVar));
			} else

			if(typeVariables.length == 2){
				method.type(context.genericRef(type, Object.class, numberTypeVar));
			}

			method.param(context.genericRef(ValueFactory.class, numberTypeVar), Scope.VAR_VALUEFACTORY);
		}

		method.param(ensureArgumentsType(context), Scope.VAR_ARGUMENTS);

		return method;
	}

	static
	public JInvocation createEvaluatorMethodInvocation(JMethod method, TranslationContext context){
		JInvocation invocation = JExpr.invoke(method);

		List params = method.params();
		for(JVar param : params){
			String name = param.name();

			JExpression arg;

			switch(name){
				case Scope.VAR_ARGUMENTS:
					try {
						arg = (context.getArgumentsVariable()).getExpression();
					} catch(IllegalArgumentException iae){
						arg = context._new(ensureArgumentsType(context), (context.getContextVariable()).getExpression());
					}
					break;
				case Scope.VAR_CONTEXT:
					arg = (context.getContextVariable()).getExpression();
					break;
				case Scope.VAR_VALUEFACTORY:
					arg = (context.getValueFactoryVariable()).getExpression();
					break;
				default:
					throw new IllegalArgumentException(name);
			}

			invocation = invocation.arg(arg);
		}

		return invocation;
	}

	static
	public JDefinedClass ensureArgumentsType(TranslationContext context){
		JDefinedClass owner = context.getOwner(JavaModel.class);

		JDefinedClass argumentsClazz = JCodeModelUtil.getNestedClass(owner, "Arguments");
		if(argumentsClazz != null){
			return argumentsClazz;
		}

		argumentsClazz = PMMLObjectUtil.createMemberClass(Modifiers.PUBLIC_STATIC_FINAL, "Arguments", context);

		JFieldVar contextVar = argumentsClazz.field(Modifiers.PRIVATE, EvaluationContext.class, "context");

		JMethod constructor = argumentsClazz.constructor(Modifiers.PUBLIC);

		JVar contextParam = constructor.param(EvaluationContext.class, "context");

		JBlock block = constructor.body();

		block.assign(JExpr.refthis(contextVar.name()), contextParam);

		return argumentsClazz;
	}

	static
	public JDefinedClass ensureFunctionalInterface(Class returnType, TranslationContext context){
		JDefinedClass owner = context.getOwner(JavaModel.class);

		JDefinedClass funcInterface = JCodeModelUtil.getNestedClass(owner, "JavaModelFunction");
		if(funcInterface != null){
			return funcInterface;
		}

		try {
			funcInterface = owner._interface("JavaModelFunction");
		} catch(JClassAlreadyExistsException jcaee){
			throw new IllegalArgumentException(jcaee);
		}

		funcInterface.annotate(FunctionalInterface.class);

		JClass argumentsClazz = ensureArgumentsType(context);

		JMethod method = funcInterface.method(Modifiers.PUBLIC_ABSTRACT, returnType, "apply");

		method.param(argumentsClazz, "arguments");

		return funcInterface;
	}

	static
	private void enhanceFieldInfo(FieldInfo fieldInfo, MiningSchema miningSchema, Map> bodyFields, FieldInfoMap fieldInfos, FunctionInvocationContext context){
		Field field = fieldInfo.getField();

		if(field instanceof DerivedField){
			DerivedField derivedField = (DerivedField)field;

			Expression expression = derivedField.requireExpression();

			FunctionInvocation functionInvocation = FunctionInvocationUtil.match(expression, context);

			if(functionInvocation instanceof FunctionInvocation.Ref){
				FunctionInvocation.Ref ref = (FunctionInvocation.Ref)functionInvocation;

				String fieldName = ref.getField();

				FieldInfo refFieldInfo = fieldInfos.get(fieldName);
				if(refFieldInfo == null){
					Field refField = bodyFields.get(fieldName);

					refFieldInfo = fieldInfos.create(refField);

					enhanceFieldInfo(refFieldInfo, miningSchema, bodyFields, fieldInfos, context);
				}

				fieldInfo.setRef(refFieldInfo);

				functionInvocation = null;
			}

			fieldInfo.setFunctionInvocation(functionInvocation);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy