All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.converter.ModelUtil Maven / Gradle / Ivy

There is a newer version: 1.5.10
Show newest version
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-Converter
 *
 * JPMML-Converter is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Converter is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.converter;

import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.google.common.collect.Iterables;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Entity;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelVerification;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Row;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.VerificationFields;

public class ModelUtil {

	private ModelUtil(){
	}

	static
	public MiningSchema createMiningSchema(Label label){
		MiningSchema miningSchema = new MiningSchema();

		if(label != null){
			FieldName name = label.getName();

			if(name != null){
				MiningField miningField = createMiningField(name, MiningField.UsageType.TARGET);

				miningSchema.addMiningFields(miningField);
			}
		}

		return miningSchema;
	}

	static
	public MiningField createMiningField(FieldName name){
		return createMiningField(name, null);
	}

	static
	public MiningField createMiningField(FieldName name, MiningField.UsageType usageType){
		MiningField miningField = new MiningField(name)
			.setUsageType(usageType);

		return miningField;
	}

	static
	public Targets createRescaleTargets(Number slope, Number intercept, ContinuousLabel continuousLabel){
		FieldName name = continuousLabel.getName();

		Target target = new Target()
			.setField(name);

		boolean rescaled = false;

		if(slope != null && !ValueUtil.isOne(slope)){
			target.setRescaleFactor(slope);

			rescaled = true;
		} // End if

		if(intercept != null && !ValueUtil.isZeroLike(intercept)){
			target.setRescaleConstant(intercept);

			rescaled = true;
		} // End if

		if(!rescaled){
			return null;
		}

		Targets targets = new Targets()
			.addTargets(target);

		return targets;
	}

	static
	public Output ensureOutput(Model model){
		Output output = model.getOutput();

		if(output == null){
			output = new Output();

			model.setOutput(output);
		}

		return output;
	}

	static
	public Output createPredictedOutput(FieldName name, OpType opType, DataType dataType, Transformation... transformations){
		Output output = new Output();

		OutputField outputField = new OutputField(name, dataType)
			.setOpType(opType)
			.setResultFeature(ResultFeature.PREDICTED_VALUE)
			.setFinalResult(false);

		output.addOutputFields(outputField);

		for(Transformation transformation : transformations){
			outputField = new OutputField(transformation.getName(outputField.getName()), transformation.getDataType(outputField.getDataType()))
				.setOpType(transformation.getOpType(outputField.getOpType()))
				.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
				.setFinalResult(transformation.isFinalResult())
				.setExpression(transformation.createExpression(new FieldRef(outputField.getName())));

			output.addOutputFields(outputField);
		}

		return output;
	}

	static
	public Output createProbabilityOutput(MathContext mathContext, CategoricalLabel categoricalLabel){
		DataType dataType = DataType.DOUBLE;

		if((MathContext.FLOAT).equals(mathContext)){
			dataType = DataType.FLOAT;
		}

		return createProbabilityOutput(dataType, categoricalLabel);
	}

	static
	public Output createProbabilityOutput(DataType dataType, CategoricalLabel categoricalLabel){
		Output output = new Output();

		List outputFields = output.getOutputFields();
		outputFields.addAll(createProbabilityFields(dataType, categoricalLabel.getValues()));

		return output;
	}

	static
	public OutputField createAffinityField(DataType dataType, Object value){
		return createAffinityField(FieldName.create("affinity(" + value + ")"), dataType, value);
	}

	static
	public OutputField createAffinityField(FieldName name, DataType dataType, Object value){
		OutputField outputField = new OutputField(name, dataType)
			.setOpType(OpType.CONTINUOUS)
			.setResultFeature(ResultFeature.AFFINITY)
			.setValue(value);

		return outputField;
	}

	static
	public List createAffinityFields(DataType dataType, List> entities){
		return entities.stream()
			.map(entity -> createAffinityField(dataType, entity.getId()))
			.collect(Collectors.toList());
	}

	static
	public OutputField createEntityIdField(FieldName name){
		OutputField outputField = new OutputField(name, DataType.STRING)
			.setOpType(OpType.CATEGORICAL)
			.setResultFeature(ResultFeature.ENTITY_ID);

		return outputField;
	}

	static
	public OutputField createPredictedField(FieldName name, DataType dataType, OpType opType){
		OutputField outputField = new OutputField(name, dataType)
			.setOpType(opType)
			.setResultFeature(ResultFeature.PREDICTED_VALUE);

		return outputField;
	}

	static
	public OutputField createProbabilityField(DataType dataType, Object value){
		return createProbabilityField(FieldName.create("probability(" + value + ")"), dataType, value);
	}

	static
	public OutputField createProbabilityField(FieldName name, DataType dataType, Object value){
		OutputField outputField = new OutputField(name, dataType)
			.setOpType(OpType.CONTINUOUS)
			.setResultFeature(ResultFeature.PROBABILITY)
			.setValue(value);

		return outputField;
	}

	static
	public List createProbabilityFields(DataType dataType, List values){
		return values.stream()
			.map(value -> createProbabilityField(dataType, value))
			.collect(Collectors.toList());
	}

	static
	public MathContext simplifyMathContext(MathContext mathContext){
		return (MathContext.DOUBLE).equals(mathContext) ? null : mathContext;
	}

	static
	public VerificationField createVerificationField(FieldName name){
		String tagName = name.getValue();

		// Replace "function(arg)" with "function_arg"
		Matcher matcher = ModelUtil.FUNCTION_INVOCATION.matcher(tagName);
		if(matcher.matches()){
			tagName = (matcher.group(1) + "_" + matcher.group(2));
		}

		VerificationField verificationField = new VerificationField()
			.setField(name)
			.setColumn("data:" + XMLUtil.createTagName(tagName));

		return verificationField;
	}

	static
	public ModelVerification createModelVerification(Map> data){
		VerificationFields verificationFields = new VerificationFields()
			.addVerificationFields(Iterables.toArray(data.keySet(), VerificationField.class));

		InlineTable inlineTable = PMMLUtil.createInlineTable(VerificationField::getColumn, data);

		List rows = inlineTable.getRows();

		ModelVerification modelVerification = new ModelVerification(verificationFields, inlineTable)
			.setRecordCount(rows.size());

		return modelVerification;
	}

	private static final Pattern FUNCTION_INVOCATION = Pattern.compile("^(.+)\\((.+)\\)$");
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy