org.jpmml.sparkml.model.TreeModelUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-sparkml Show documentation
Show all versions of pmml-sparkml Show documentation
JPMML Apache Spark ML to PMML converter
The newest version!
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml.model;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.visitors.TreeModelCompactor;
public class TreeModelUtil {
/**
 * Not instantiable: this class only exposes static members.
 */
private TreeModelUtil()
{
	// Intentionally empty
}
static
public & HasTreeOptions, M extends Model & DecisionTreeModel> TreeModel encodeDecisionTree(C converter, Schema schema){
PredicateManager predicateManager = new PredicateManager();
ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();
return encodeDecisionTree(converter, predicateManager, scoreDistributionManager, schema);
}
static
public & HasTreeOptions, M extends Model & DecisionTreeModel> TreeModel encodeDecisionTree(C converter, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema){
return encodeDecisionTree(converter, converter.getModel(), predicateManager, scoreDistributionManager, schema);
}
static
public & HasTreeOptions, M extends Model & TreeEnsembleModel, T extends Model & DecisionTreeModel> List encodeDecisionTreeEnsemble(C converter, Schema schema){
PredicateManager predicateManager = new PredicateManager();
ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();
return encodeDecisionTreeEnsemble(converter, predicateManager, scoreDistributionManager, schema);
}
static
public & HasTreeOptions, M extends Model & TreeEnsembleModel, T extends Model & DecisionTreeModel> List encodeDecisionTreeEnsemble(C converter, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema){
M model = converter.getModel();
Schema segmentSchema = schema.toAnonymousSchema();
List treeModels = new ArrayList<>();
T[] trees = model.trees();
for(T tree : trees){
TreeModel treeModel = encodeDecisionTree(converter, tree, predicateManager, scoreDistributionManager, segmentSchema);
treeModels.add(treeModel);
}
return treeModels;
}
static
private & DecisionTreeModel> TreeModel encodeDecisionTree(ModelConverter> converter, M model, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema){
TreeModel treeModel;
if(model instanceof DecisionTreeRegressionModel){
ScoreEncoder scoreEncoder = new ScoreEncoder(){
@Override
public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){
node.setScore(leafNode.prediction());
return node;
}
};
treeModel = encodeTreeModel(MiningFunction.REGRESSION, scoreEncoder, model, predicateManager, schema);
} else
if(model instanceof DecisionTreeClassificationModel){
ScoreEncoder scoreEncoder = new ScoreEncoder(){
private CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
@Override
public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){
ImpurityCalculator impurityCalculator = leafNode.impurityStats();
double[] stats = impurityCalculator.stats();
List recordCounts = new AbstractList(){
@Override
public int size(){
return stats.length;
}
@Override
public Number get(int index){
double stat = stats[index];
return ValueUtil.narrow(stat);
}
};
double totalRecordCount = impurityCalculator.count();
int maxIndex = ValueUtil.asInt(leafNode.prediction());
Object score = this.categoricalLabel.getValue(maxIndex);
node = new ClassifierNode(score, node.requirePredicate())
.setRecordCount(ValueUtil.narrow(totalRecordCount));
scoreDistributionManager.addScoreDistributions(node, this.categoricalLabel.getValues(), recordCounts, null);
return node;
}
};
treeModel = encodeTreeModel(MiningFunction.CLASSIFICATION, scoreEncoder, model, predicateManager, schema);
} else
{
throw new IllegalArgumentException();
}
Boolean compact = (Boolean)converter.getOption(HasTreeOptions.OPTION_COMPACT, Boolean.TRUE);
if(compact != null && compact){
Visitor visitor = new TreeModelCompactor();
visitor.applyTo(treeModel);
}
return treeModel;
}
static
private & DecisionTreeModel> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, M model, PredicateManager predicateManager, Schema schema){
Node root = encodeNode(True.INSTANCE, scoreEncoder, model.rootNode(), predicateManager, new CategoryManager(), schema);
TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
return treeModel;
}
static
private Node encodeNode(Predicate predicate, ScoreEncoder scoreEncoder, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema){
if(sparkNode instanceof org.apache.spark.ml.tree.LeafNode){
org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode)sparkNode;
Node result = new LeafNode(null, predicate);
return scoreEncoder.encode(result, leafNode);
} else
if(sparkNode instanceof org.apache.spark.ml.tree.InternalNode){
org.apache.spark.ml.tree.InternalNode internalNode = (org.apache.spark.ml.tree.InternalNode)sparkNode;
CategoryManager leftCategoryManager = categoryManager;
CategoryManager rightCategoryManager = categoryManager;
Predicate leftPredicate;
Predicate rightPredicate;
Split split = internalNode.split();
Feature feature = schema.getFeature(split.featureIndex());
if(split instanceof ContinuousSplit){
ContinuousSplit continuousSplit = (ContinuousSplit)split;
Double threshold = continuousSplit.threshold();
if(feature instanceof BooleanFeature){
BooleanFeature booleanFeature = (BooleanFeature)feature;
if(threshold != 0.5d){
throw new IllegalArgumentException("Invalid split threshold value " + threshold + " for a boolean feature");
}
leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
} else
{
ContinuousFeature continuousFeature = feature.toContinuousFeature();
DataType dataType = continuousFeature.getDataType();
switch(dataType){
case INTEGER:
threshold = Math.floor(threshold);
break;
default:
break;
}
leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
}
} else
if(split instanceof CategoricalSplit){
CategoricalSplit categoricalSplit = (CategoricalSplit)split;
double[] leftCategories = categoricalSplit.leftCategories();
double[] rightCategories = categoricalSplit.rightCategories();
if(feature instanceof BinaryFeature){
BinaryFeature binaryFeature = (BinaryFeature)feature;
SimplePredicate.Operator leftOperator;
SimplePredicate.Operator rightOperator;
if(Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)){
leftOperator = SimplePredicate.Operator.EQUAL;
rightOperator = SimplePredicate.Operator.NOT_EQUAL;
} else
if(Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)){
leftOperator = SimplePredicate.Operator.NOT_EQUAL;
rightOperator = SimplePredicate.Operator.EQUAL;
} else
{
throw new IllegalArgumentException();
}
Object value = binaryFeature.getValue();
leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
} else
if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
String name = categoricalFeature.getName();
List> values = categoricalFeature.getValues();
if(values.size() != (leftCategories.length + rightCategories.length)){
throw new IllegalArgumentException();
}
java.util.function.Predicate
© 2015 - 2025 Weber Informatics LLC | Privacy Policy