All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.rexp.BinaryTreeConverter Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see .
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.ScoreProbability;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNames;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;

public class BinaryTreeConverter extends TreeModelConverter {

	private MiningFunction miningFunction = null;

	private Map featureIndexes = new LinkedHashMap<>();


	public BinaryTreeConverter(S4Object binaryTree){
		super(binaryTree);
	}

	@Override
	public void encodeSchema(RExpEncoder encoder){
		S4Object binaryTree = getObject();

		S4Object responses = (S4Object)binaryTree.getAttribute("responses");
		RGenericVector tree = binaryTree.getGenericAttribute("tree");

		encodeResponse(responses, encoder);
		encodeVariableList(tree, encoder);
	}

	@Override
	public TreeModel encodeModel(Schema schema){
		S4Object binaryTree = getObject();

		RGenericVector tree = binaryTree.getGenericAttribute("tree");

		Output output;

		switch(this.miningFunction){
			case REGRESSION:
				output = new Output();
				break;
			case CLASSIFICATION:
				CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

				output = ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel);
				break;
			default:
				throw new IllegalArgumentException();
		}

		output.addOutputFields(ModelUtil.createEntityIdField(FieldNames.NODE_ID, DataType.STRING));

		TreeModel treeModel = encodeTreeModel(tree, schema)
			.setOutput(output);

		return treeModel;
	}

	private void encodeResponse(S4Object responses, RExpEncoder encoder){
		RGenericVector variables = responses.getGenericAttribute("variables");
		RBooleanVector is_nominal = responses.getBooleanAttribute("is_nominal");
		RGenericVector levels = responses.getGenericAttribute("levels");

		RStringVector variableNames = variables.names();

		String variableName = variableNames.asScalar();

		DataField dataField;

		Boolean categorical = is_nominal.getElement(variableName);
		if((Boolean.TRUE).equals(categorical)){
			this.miningFunction = MiningFunction.CLASSIFICATION;

			RExp targetVariable = variables.getElement(variableName);

			RStringVector targetVariableClass = RExpUtil.getClassNames(targetVariable);

			RStringVector targetCategories = levels.getStringElement(variableName);

			dataField = encoder.createDataField(variableName, OpType.CATEGORICAL, RExpUtil.getDataType(targetVariableClass.asScalar()), targetCategories.getValues());
		} else

		if((Boolean.FALSE).equals(categorical)){
			this.miningFunction = MiningFunction.REGRESSION;

			dataField = encoder.createDataField(variableName, OpType.CONTINUOUS, DataType.DOUBLE);
		} else

		{
			throw new IllegalArgumentException();
		}

		encoder.setLabel(dataField);
	}

	private void encodeVariableList(RGenericVector tree, RExpEncoder encoder){
		RBooleanVector terminal = tree.getBooleanElement("terminal");
		RGenericVector psplit = tree.getGenericElement("psplit");
		RGenericVector left = tree.getGenericElement("left");
		RGenericVector right = tree.getGenericElement("right");

		if((Boolean.TRUE).equals(terminal.asScalar())){
			return;
		}

		RNumberVector splitpoint = psplit.getNumericElement("splitpoint");
		RStringVector variableName = psplit.getStringElement("variableName");

		String name = variableName.asScalar();

		DataField dataField = encoder.getDataField(name);
		if(dataField == null){

			if(splitpoint instanceof RIntegerVector){
				RStringVector levels = splitpoint.getStringAttribute("levels");

				dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues());
			} else

			if(splitpoint instanceof RDoubleVector){
				dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
			} else

			{
				throw new IllegalArgumentException();
			}

			encoder.addFeature(dataField);

			this.featureIndexes.put(name, this.featureIndexes.size());
		}

		encodeVariableList(left, encoder);
		encodeVariableList(right, encoder);
	}

	private TreeModel encodeTreeModel(RGenericVector tree, Schema schema){
		Node root = encodeNode(tree, True.INSTANCE, schema);

		TreeModel treeModel = new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
			.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

		return treeModel;
	}

	private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema){
		RIntegerVector nodeId = tree.getIntegerElement("nodeID");
		RBooleanVector terminal = tree.getBooleanElement("terminal");
		RGenericVector psplit = tree.getGenericElement("psplit");
		RGenericVector ssplits = tree.getGenericElement("ssplits");
		RDoubleVector prediction = tree.getDoubleElement("prediction");
		RGenericVector left = tree.getGenericElement("left");
		RGenericVector right = tree.getGenericElement("right");

		Integer id = nodeId.asScalar();

		if((Boolean.TRUE).equals(terminal.asScalar())){
			Node result = new LeafNode(null, predicate)
				.setId(id);

			return encodeScore(result, prediction, schema);
		}

		RNumberVector splitpoint = psplit.getNumericElement("splitpoint");
		RStringVector variableName = psplit.getStringElement("variableName");

		if(!ssplits.isEmpty()){
			throw new IllegalArgumentException();
		}

		Predicate leftPredicate;
		Predicate rightPredicate;

		String name = variableName.asScalar();

		Integer index = this.featureIndexes.get(name);
		if(index == null){
			throw new IllegalArgumentException();
		}

		Feature feature = schema.getFeature(index);

		if(feature instanceof CategoricalFeature){
			CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

			List values = categoricalFeature.getValues();
			List splitValues = (List)splitpoint.getValues();

			leftPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, true));
			rightPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, false));
		} else

		{
			ContinuousFeature continuousFeature = feature.toContinuousFeature();

			Number value = splitpoint.asScalar();

			leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
			rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
		}

		Node leftChild = encodeNode(left, leftPredicate, schema);
		Node rightChild = encodeNode(right, rightPredicate, schema);

		Node result = new BranchNode(null, predicate)
			.setId(id)
			.addNodes(leftChild, rightChild);

		return result;
	}

	private Node encodeScore(Node node, RDoubleVector probabilities, Schema schema){

		switch(this.miningFunction){
			case REGRESSION:
				return encodeRegressionScore(node, probabilities);
			case CLASSIFICATION:
				return encodeClassificationScore(node, probabilities, schema);
			default:
				throw new IllegalArgumentException();
		}
	}

	static
	private  List selectValues(List values, List splits, boolean left){

		if(values.size() != splits.size()){
			throw new IllegalArgumentException();
		}

		List result = new ArrayList<>();

		for(int i = 0; i < values.size(); i++){
			E value = values.get(i);
			Integer split = splits.get(i);

			boolean append;

			if(left){
				append = (split == 1);
			} else

			{
				append = (split == 0);
			} // End if

			if(append){
				result.add(value);
			}
		}

		return result;
	}

	static
	private Node encodeRegressionScore(Node node, RDoubleVector probabilities){
		Double probability = probabilities.asScalar();

		node.setScore(probability);

		return node;
	}

	static
	private Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema){
		CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

		SchemaUtil.checkSize(probabilities.size(), categoricalLabel);

		node = new ClassifierNode(node);

		List scoreDistributions = node.getScoreDistributions();

		Double maxProbability = null;

		for(int i = 0; i < categoricalLabel.size(); i++){
			Object value = categoricalLabel.getValue(i);
			Double probability = probabilities.getValue(i);

			if(maxProbability == null || (maxProbability).compareTo(probability) < 0){
				node.setScore(value);

				maxProbability = probability;
			}

			ScoreDistribution scoreDistribution = new ScoreProbability(value, null, probability);

			scoreDistributions.add(scoreDistribution);
		}

		return node;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy