org.jpmml.lightgbm.Tree Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jpmml-lightgbm Show documentation
Show all versions of jpmml-lightgbm Show documentation
Java library and command-line application for converting LightGBM models to PMML
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-LightGBM
*
* JPMML-LightGBM is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-LightGBM is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
 * along with JPMML-LightGBM. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.lightgbm;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
public class Tree {
// Number of leaves; the per-internal-node arrays below are read with num_leaves_ - 1 entries (see load)
private int num_leaves_;
// Number of categorical splits; cat_boundaries_ and cat_threshold_ are only read when this is > 0
private int num_cat_;
// Per internal node: child references. Non-negative values are internal node indexes; negative values
// are treated as leaves (encodeNode checks index >= 0 and uses ~index) — NOTE(review): LightGBM convention, confirm
private int[] left_child_;
private int[] right_child_;
// Per internal node: split feature, used as an index into the schema's feature list (schema.getFeature)
private int[] split_feature_real_;
// Per internal node: split threshold; for binary categorical splits it is converted to an int category index
private double[] threshold_;
// Per internal node: split flags, decoded via hasCategoricalMask / hasDefaultLeftMask
private int[] decision_type_;
// Per leaf (num_leaves_ entries): output value and count — presumably training record counts; confirm
private double[] leaf_value_;
private int[] leaf_count_;
// Per internal node: node value and count
private double[] internal_value_;
private int[] internal_count_;
// Categorical split tables (num_cat_ + 1 boundaries; threshold words read as unsigned ints)
// NOTE(review): presumably boundaries delimit per-split ranges within cat_threshold_ — confirm
private int[] cat_boundaries_;
private long[] cat_threshold_;
public void load(Section section){
	/*
	 * Populates this tree from one parsed tree section of a LightGBM model file.
	 *
	 * Arrays describing split (internal) nodes are read with num_leaves - 1 entries,
	 * arrays describing leaves with num_leaves entries. The categorical split tables
	 * are only read when the section declares at least one categorical split.
	 */
	this.num_leaves_ = section.getInt("num_leaves");
	this.num_cat_ = section.getInt("num_cat");

	// A binary tree with N leaves has N - 1 split nodes
	int splitCount = this.num_leaves_ - 1;
	int leafCount = this.num_leaves_;

	this.left_child_ = section.getIntArray("left_child", splitCount);
	this.right_child_ = section.getIntArray("right_child", splitCount);
	this.split_feature_real_ = section.getIntArray("split_feature", splitCount);
	this.threshold_ = section.getDoubleArray("threshold", splitCount);
	this.decision_type_ = section.getIntArray("decision_type", splitCount);

	this.leaf_value_ = section.getDoubleArray("leaf_value", leafCount);
	this.leaf_count_ = section.getIntArray("leaf_count", leafCount);

	this.internal_value_ = section.getDoubleArray("internal_value", splitCount);
	this.internal_count_ = section.getIntArray("internal_count", splitCount);

	// No categorical splits, so there are no categorical tables to read
	if(this.num_cat_ <= 0){
		return;
	}

	this.cat_boundaries_ = section.getIntArray("cat_boundaries", this.num_cat_ + 1);
	// NOTE(review): -1 presumably means "length not known in advance" — confirm against Section
	this.cat_threshold_ = section.getUnsignedIntArray("cat_threshold", -1);
}
/**
 * Encodes this tree as a PMML regression {@link TreeModel}.
 *
 * <p>The model is declared with binary splits and the {@code defaultChild}
 * missing value strategy.</p>
 *
 * @param predicateManager passed through to {@link #encodeNode} for creating split predicates
 * @param schema supplies the label (for the mining schema) and the split features
 * @return the encoded tree model
 */
public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
	// encodeNode treats non-negative indexes as internal nodes and negative ones as leaves (~index).
	// A single-leaf tree has no internal nodes (its split arrays are empty), so its root is leaf 0,
	// i.e. index ~0. Starting from internal node 0 would index past the end of the split arrays.
	// This mirrors LightGBM's Tree::Predict, which returns leaf_value_[0] when num_leaves_ <= 1.
	// NOTE(review): assumes encodeNode's leaf branch resolves ~index against the leaf arrays — confirm.
	int rootIndex = (this.num_leaves_ > 1 ? 0 : ~0);

	Node root = encodeNode(True.INSTANCE, predicateManager, new CategoryManager(), rootIndex, schema);

	TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

	return treeModel;
}
public Node encodeNode(Predicate predicate, PredicateManager predicateManager, CategoryManager categoryManager, int index, Schema schema){
Integer id = Integer.valueOf(~index);
// Non-leaf (aka internal) node
if(index >= 0){
Feature feature = schema.getFeature(this.split_feature_real_[index]);
double threshold_ = this.threshold_[index];
int decision_type_ = this.decision_type_[index];
CategoryManager leftCategoryManager = categoryManager;
CategoryManager rightCategoryManager = categoryManager;
Predicate leftPredicate;
Predicate rightPredicate;
boolean defaultLeft = hasDefaultLeftMask(decision_type_);
if(feature instanceof BinaryFeature){
BinaryFeature binaryFeature = (BinaryFeature)feature;
if(hasCategoricalMask(decision_type_)){
throw new IllegalArgumentException("Expected a false (off) categorical split mask for binary feature " + FeatureUtil.getName(binaryFeature) + ", got true (on)");
} // End if
if(threshold_ != 0.5d){
throw new IllegalArgumentException("Expected 0.5 as a threshold value for binary feature " + FeatureUtil.getName(binaryFeature) + ", got " + threshold_);
}
Object value = binaryFeature.getValue();
leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
} else
if(feature instanceof BinaryCategoricalFeature){
BinaryCategoricalFeature binaryCategoricalFeature = (BinaryCategoricalFeature)feature;
if(!hasCategoricalMask(decision_type_)){
throw new IllegalArgumentException("Expected a true (on) categorical split mask for binary categorical feature " + FeatureUtil.getName(binaryCategoricalFeature) + ", got false (off)");
}
FieldName name = binaryCategoricalFeature.getName();
List> values = binaryCategoricalFeature.getValues();
int cat_idx = ValueUtil.asInt(threshold_);
List