All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.translator.mining.ModelChainTranslator Maven / Gradle / Ivy

There is a newer version: 1.3.8
Show newest version
/*
 * Copyright (c) 2019 Villu Ruusmann
 *
 * This file is part of JPMML-Transpiler
 *
 * JPMML-Transpiler is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Transpiler is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Transpiler.  If not, see .
 */
package org.jpmml.translator.mining;

import java.util.List;
import java.util.Objects;

import com.google.common.collect.Iterables;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.MissingElementException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.XPathUtil;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.ValueFactoryRef;
import org.jpmml.translator.ValueMapBuilder;
import org.jpmml.translator.regression.RegressionModelTranslator;

public class ModelChainTranslator extends MiningModelTranslator {

	public ModelChainTranslator(PMML pmml, MiningModel miningModel){
		super(pmml, miningModel);

		MiningFunction miningFunction = miningModel.requireMiningFunction();
		switch(miningFunction){
			case CLASSIFICATION:
				break;
			default:
				throw new UnsupportedAttributeException(miningModel, miningFunction);
		}

		MathContext mathContext = miningModel.getMathContext();

		Segmentation segmentation = miningModel.requireSegmentation();

		Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();
		switch(multipleModelMethod){
			case MODEL_CHAIN:
				break;
			default:
				throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
		}

		List segments = segmentation.requireSegments();

		List regressorSegments = segments.subList(0, segments.size() - 1);
		for(Segment regressorSegment : regressorSegments){
			@SuppressWarnings("unused")
			True _true = regressorSegment.requirePredicate(True.class);
			Model model = regressorSegment.requireModel();

			MiningFunction modelMiningFunction = model.requireMiningFunction();
			switch(modelMiningFunction){
				case REGRESSION:
					break;
				default:
					throw new UnsupportedAttributeException(model, modelMiningFunction);
			}

			MathContext modelMathContext = model.getMathContext();
			if(!Objects.equals(mathContext, modelMathContext)){
				throw new UnsupportedAttributeException(model, modelMathContext);
			}

			checkMiningSchema(model);

			Output modelOutput = model.getOutput();
			if(modelOutput == null){
				throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(Output.class)), model);
			} // End if

			if(modelOutput.hasOutputFields()){
				List outputFields = modelOutput.getOutputFields();

				if(outputFields.size() != 1){
					throw new UnsupportedElementException(modelOutput);
				}

				OutputField outputField = Iterables.getOnlyElement(outputFields);

				ResultFeature resultFeature = outputField.getResultFeature();
				switch(resultFeature){
					case PREDICTED_VALUE:
						break;
					default:
						throw new UnsupportedAttributeException(outputField, resultFeature);
				}
			} else

			{
				throw new MissingElementException(modelOutput, org.dmg.pmml.PMMLElements.OUTPUT_OUTPUTFIELDS);
			}

			@SuppressWarnings("unused")
			ModelTranslator modelTranslator = newModelTranslator(model);
		}

		{
			Segment classifierSegment = segments.get(segments.size() - 1);

			@SuppressWarnings("unused")
			True _true = classifierSegment.requirePredicate(True.class);
			RegressionModel regressionModel = classifierSegment.requireModel(RegressionModel.class);

			MiningFunction modelMiningFunction = regressionModel.requireMiningFunction();
			switch(modelMiningFunction){
				case CLASSIFICATION:
					break;
				default:
					throw new UnsupportedAttributeException(regressionModel, modelMiningFunction);
			}

			MathContext modelMathContext = regressionModel.getMathContext();
			if(!Objects.equals(mathContext, modelMathContext)){
				throw new UnsupportedAttributeException(regressionModel, modelMathContext);
			}

			checkMiningSchema(regressionModel);
			checkLocalTransformations(regressionModel);
			checkTargets(regressionModel);

			List regressionTables = regressionModel.getRegressionTables();
			for(RegressionTable regressionTable : regressionTables){

				if(regressionTable.hasNumericPredictors()){
					List numericPredictors = regressionTable.getNumericPredictors();

					if(numericPredictors.size() > 1){
						throw new InvalidElementException(regressionTable);
					}
				} // End if

				if(regressionTable.hasCategoricalPredictors() || regressionTable.hasPredictorTerms()){
					throw new UnsupportedElementException(regressionTable);
				}
			}
		}
	}

	@Override
	public JMethod translateClassifier(TranslationContext context){
		MiningModel miningModel = getModel();

		Segmentation segmentation = miningModel.requireSegmentation();

		JMethod evaluateMethod = createEvaluatorMethod(Classification.class, segmentation, true, context);

		try {
			context.pushScope(new MethodScope(evaluateMethod));

			translateSegmentation(segmentation, context);
		} finally {
			context.popScope();
		}

		return evaluateMethod;
	}

	private void translateSegmentation(Segmentation segmentation, TranslationContext context){
		MiningModel miningModel = getModel();

		List segments = segmentation.requireSegments();

		List regressorSegments = segments.subList(0, segments.size() - 1);
		for(Segment regressorSegment : regressorSegments){
			Model model = regressorSegment.requireModel();

			Output modelOutput = model.getOutput();

			OutputField outputField = Iterables.getOnlyElement(modelOutput.getOutputFields());

			ModelTranslator modelTranslator = newModelTranslator(model);

			TargetField targetField = modelTranslator.getTargetField();

			JMethod evaluateMethod = modelTranslator.translateRegressor(context);

			JInvocation methodInvocation = createEvaluatorMethodInvocation(evaluateMethod, context);

			ValueBuilder valueBuilder = new ValueBuilder(context)
				.declare(IdentifierUtil.create("value", outputField.requireName()), methodInvocation);

			Target target = targetField.getTarget();
			if(target != null){
				translateRegressorTarget(model, target, valueBuilder);
			}

			pullUpDerivedFields(miningModel, model);
		}

		ValueMapBuilder valueMapBuilder = new ValueMapBuilder(context)
			.construct("values");

		{
			Segment classifierSegment = segments.get(segments.size() - 1);

			RegressionModel regressionModel = classifierSegment.requireModel(RegressionModel.class);

			List regressionTables = regressionModel.getRegressionTables();
			for(RegressionTable regressionTable : regressionTables){
				List numericPredictors = regressionTable.getNumericPredictors();

				Number intercept = regressionTable.requireIntercept();

				JExpression valueExpr;

				NumericPredictor numericPredictor = Iterables.getFirst(numericPredictors, null);
				if(numericPredictor != null){
					valueExpr = context.getVariable(IdentifierUtil.create("value", numericPredictor.requireField()));

					Number coefficient = numericPredictor.requireCoefficient();
					if(coefficient.doubleValue() != 1d){
						valueExpr = context.invoke(valueExpr, "multiply", coefficient);
					} // End if

					if(intercept.doubleValue() != 0d){
						valueExpr = context.invoke(valueExpr, "add", intercept);
					}
				} else

				{
					ValueFactoryRef valueFactoryRef = context.getValueFactoryVariable();

					if(intercept.doubleValue() != 0d){
						valueExpr = valueFactoryRef.newValue(PMMLObjectUtil.createExpression(intercept, context));
					} else

					{
						valueExpr = valueFactoryRef.newValue();
					}
				}

				valueMapBuilder.update("put", regressionTable.getTargetCategory(), valueExpr);
			}

			RegressionModelTranslator.computeClassification(valueMapBuilder, regressionModel, context);

			pullUpOutputFields(miningModel, regressionModel);
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy