
org.jpmml.xgboost.Learner Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-xgboost Show documentation
Show all versions of pmml-xgboost Show documentation
JPMML XGBoost to PMML converter
The newest version!
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-XGBoost
*
* JPMML-XGBoost is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-XGBoost is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-XGBoost. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.xgboost;
import java.io.DataInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBReader;
import com.devsmart.ubjson.UBValue;
import com.google.common.collect.Iterables;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.HasContinuousDomain;
import org.dmg.pmml.Interval;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.Value;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.FieldUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.TreeModelPruner;
import org.jpmml.model.visitors.VisitorBattery;
import org.jpmml.xgboost.visitors.TreeModelCompactor;
public class Learner implements BinaryLoadable, JSONLoadable, UBJSONLoadable {
// Learner header values. loadBinary(XGBoostDataInput) reads the first nine fields
// from the stream in exactly this declaration order.

// After loading a version 1.0+ model, holds the result of obj.probToMargin(...) applied to the stored value
private float base_score;
private int num_feature;
private int num_class;
// Non-zero when the binary stream carries a string map after the gradient booster (see loadBinary)
private int contain_extra_attrs;
// Non-zero when a pre-1.0 binary stream carries a trailing string vector of metrics (see loadBinary)
private int contain_eval_metrics;
// Accepted ranges: 0..2 for the binary format, 1..2 for the (UB)JSON format
private int major_version;
private int minor_version;
// Clamped to at least 1 by loadBinary; set to 1 by loadUBJSON when the property is absent
private int num_target;
// Read by loadBinary only; the visible (UB)JSON loader does not set it
private int base_score_estimated;
// Objective function, resolved from its name via parseObjective(...)
private ObjFunction obj;
// Gradient booster, resolved from its name via parseGradientBooster(...)
private GBTree gbtree;
// Stays null unless a loader finds attributes; loadUBJSON copies only "best_iteration" and "best_score"
private Map attributes = null;
// Among the visible loaders, only loadUBJSON sets the feature names and feature types
private String[] feature_names = null;
private String[] feature_types = null;
// Among the visible loaders, only loadBinary sets the metrics, and only for pre-1.0 streams
private String[] metrics = null;
/**
 * Creates an empty learner.
 *
 * <p>
 * No state is initialized here; the fields are assigned by the load methods of this class.
 * </p>
 */
public Learner(){
	super();
}
/**
 * Reads the learner from a binary model stream.
 *
 * <p>
 * The header values are read first, followed by the objective function name,
 * the gradient booster name and body, and the optional trailing sections.
 * </p>
 *
 * @param input The data input to read from.
 *
 * @throws IOException If reading from the data input fails.
 * @throws IllegalArgumentException If the major version is outside the 0..2 range.
 */
@Override
public void loadBinary(XGBoostDataInput input) throws IOException {
	this.base_score = input.readFloat();
	this.num_feature = input.readInt();
	this.num_class = input.readInt();
	this.contain_extra_attrs = input.readInt();
	this.contain_eval_metrics = input.readInt();

	this.major_version = input.readInt();
	this.minor_version = input.readInt();

	if(this.major_version < 0 || this.major_version > 2){
		throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
	}

	// A stored target count below one is treated as a single target
	this.num_target = Math.max(input.readInt(), 1);
	this.base_score_estimated = input.readInt();

	input.readReserved(25);

	String name_obj = input.readString();

	this.obj = parseObjective(name_obj);

	// Starting from 1.0.0, the base score is saved as an untransformed value.
	// Older versions keep the value exactly as it was read.
	if(this.major_version >= 1){
		// Adding 0f turns a negative zero result into positive zero
		this.base_score = this.obj.probToMargin(this.base_score) + 0f;
	}

	String name_gbm = input.readString();

	this.gbtree = parseGradientBooster(name_gbm);
	this.gbtree.loadBinary(input);

	if(this.contain_extra_attrs != 0){
		this.attributes = input.readStringMap();
	}

	// The remaining sections are only read from pre-1.0 streams
	if(this.major_version >= 1){
		return;
	}

	if(this.obj instanceof PoissonRegression){

		try {
			// The string is read to advance the stream; its value is not kept
			input.readString();
		} catch(EOFException ignored){
			// Best effort: reaching the end of the stream here is tolerated
		}
	}

	if(this.contain_eval_metrics != 0){
		this.metrics = input.readStringVector();
	}
}
/**
 * Loads the learner from a Gson object.
 *
 * <p>
 * The object is converted to a UBJSON value, and the loading is delegated to {@link #loadUBJSON(UBObject)}.
 * </p>
 *
 * @param root The JSON object to load from.
 */
@Override
public void loadJSON(JsonObject root){
	UBObject ubRoot = GsonUtil.toUBValue(root).asObject();

	loadUBJSON(ubRoot);
}
/**
 * Loads the learner from a UBJSON object.
 *
 * <p>
 * Reads the "version" property, then the model parameters, the objective function
 * and the gradient booster from the "learner" property, and finally the optional
 * "attributes", "feature_names" and "feature_types" properties.
 * </p>
 *
 * @param root The UBJSON object that holds the "version" and "learner" properties.
 *
 * @throws IllegalArgumentException If the "version" property is missing, holds fewer than two components,
 * or its major component is outside the 1..2 range.
 */
@Override
public void loadUBJSON(UBObject root){

	if(!root.containsKey("version")){
		throw new IllegalArgumentException("Property \"version\" not found among " + root.keySet());
	}

	int[] version = UBJSONUtil.toIntArray(root.get("version"));

	// Both the major and the minor component are read below; fail with a clear message instead of an index error
	if(version == null || version.length < 2){
		throw new IllegalArgumentException("Property \"version\" must hold at least two components, got " + Arrays.toString(version));
	}

	this.major_version = version[0];
	this.minor_version = version[1];

	if(this.major_version < 1 || this.major_version > 2){
		throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
	}

	UBObject learner = root.get("learner").asObject();

	UBObject learnerModelParam = learner.get("learner_model_param").asObject();

	this.base_score = learnerModelParam.get("base_score").asFloat32();
	this.num_feature = learnerModelParam.get("num_feature").asInt();
	this.num_class = learnerModelParam.get("num_class").asInt();

	// The target count is optional; a single target is assumed when it is absent
	if(learnerModelParam.containsKey("num_target")){
		this.num_target = learnerModelParam.get("num_target").asInt();
	} else {
		this.num_target = 1;
	}

	UBObject objective = learner.get("objective").asObject();

	String name_obj = objective.get("name").asString();

	this.obj = parseObjective(name_obj);

	// Starting from 1.0.0, the base score is saved as an untransformed value
	this.base_score = this.obj.probToMargin(this.base_score) + 0f;

	UBObject gradientBooster = learner.get("gradient_booster").asObject();

	String name_gbm = gradientBooster.get("name").asString();

	this.gbtree = parseGradientBooster(name_gbm);
	this.gbtree.loadUBJSON(gradientBooster);

	if(learner.containsKey("attributes")){
		UBObject attributes = learner.get("attributes").asObject();

		this.attributes = new HashMap<>();

		// Only these two attributes are copied over
		String[] keys = {"best_iteration", "best_score"};
		for(String key : keys){

			if(attributes.containsKey(key)){
				this.attributes.put(key, attributes.get(key).asString());
			}
		}
	}

	if(learner.containsKey("feature_names")){
		this.feature_names = UBJSONUtil.toStringArray(learner.get("feature_names"));
	}

	if(learner.containsKey("feature_types")){
		this.feature_types = UBJSONUtil.toStringArray(learner.get("feature_types"));
	}
}
/**
 * Loads the learner from a binary stream that may start with a serialization header and/or a "binf" header.
 *
 * <p>
 * When there is no serialization header, the stream must end right after the learner.
 * </p>
 *
 * NOTE(review): the type DIS is not declared in the visible part of this file;
 * the body calls both readLong() and read() on it - confirm its declaration.
 *
 * @param is The stream to read from.
 * @param charset The charset name that is passed on to the data input.
 *
 * @throws IOException If reading fails, if the value following the serialization header is negative,
 * or if unexpected data follows the learner.
 */
public void loadBinary(DIS is, String charset) throws IOException {
	boolean hasSerializationHeader = consumeHeader(is, XGBoostUtil.SERIALIZATION_HEADER);

	if(hasSerializationHeader){
		long offset = is.readLong();

		if(offset < 0L){
			throw new IOException("Expected a non-negative offset after the serialization header, got " + offset);
		}
	}

	// The outcome does not matter here; presumably the call skips the header bytes when they are present
	consumeHeader(is, XGBoostUtil.BINF_HEADER);

	try(XGBoostDataInput input = new XGBoostDataInput(is, charset)){
		loadBinary(input);

		// The end-of-stream check is skipped for streams that started with the serialization header
		if(!hasSerializationHeader){
			int eof = is.read();

			if(eof != -1){
				throw new IOException("Expected the end of the stream after the model, got byte " + eof);
			}
		}
	}
}
/**
 * Loads the learner from a JSON stream.
 *
 * <p>
 * The document is parsed in full, and the object to load is then located by following the JSON path.
 * </p>
 *
 * @param is The stream to read from.
 * @param charset The charset name; "UTF-8" is used when <code>null</code>.
 * @param jsonPath Dot-separated property names that lead from the document root to the object to load.
 * A leading "$" element is skipped.
 *
 * @throws IOException If reading fails, or if the stream still holds data after loading.
 * @throws IllegalArgumentException If a property named by the JSON path does not exist.
 */
@SuppressWarnings("deprecation")
public void loadJSON(InputStream is, String charset, String jsonPath) throws IOException {
	JsonParser parser = new JsonParser();

	if(charset == null){
		charset = "UTF-8";
	}

	try(Reader reader = new InputStreamReader(is, charset)){
		JsonElement element = parser.parse(reader);

		JsonObject object = element.getAsJsonObject();

		String[] names = jsonPath.split("\\.");
		for(int i = 0; i < names.length; i++){
			String name = names[i];

			// A leading "$" stands for the document root
			if(i == 0 && ("$").equals(name)){
				continue;
			}

			JsonElement childElement = object.get(name);
			if(childElement == null){
				throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
			}

			object = childElement.getAsJsonObject();
		}

		loadJSON(object);

		int eof = is.read();
		if(eof != -1){
			throw new IOException("Expected the end of the stream after the JSON document, got byte " + eof);
		}
	}
}
/**
 * Loads the learner from a UBJSON stream.
 *
 * <p>
 * One value is read from the stream, and the object to load is then located by following the JSON path.
 * </p>
 *
 * @param is The stream to read from.
 * @param jsonPath Dot-separated property names that lead from the document root to the object to load.
 * A leading "$" element is skipped.
 *
 * @throws IOException If reading fails, or if the stream still holds data after loading.
 * @throws IllegalArgumentException If a property named by the JSON path does not exist.
 */
public void loadUBJSON(InputStream is, String jsonPath) throws IOException {

	try(UBReader reader = new UBReader(is)){
		UBObject object = reader.read().asObject();

		String[] names = jsonPath.split("\\.");
		for(int i = 0; i < names.length; i++){
			String name = names[i];

			// A leading "$" stands for the document root
			if(i == 0 && ("$").equals(name)){
				continue;
			}

			UBValue childValue = object.get(name);
			if(childValue == null){
				throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
			}

			object = childValue.asObject();
		}

		loadUBJSON(object);

		int eof = is.read();
		if(eof != -1){
			throw new IOException("Expected the end of the stream after the UBJSON document, got byte " + eof);
		}
	}
}
/**
 * Builds a feature map from the loaded feature names and feature types.
 *
 * <p>
 * The name at each position is paired with the type at the same position.
 * </p>
 *
 * @return The feature map, or <code>null</code> if the feature names or the feature types have not been loaded.
 */
public FeatureMap encodeFeatureMap(){
	String[] names = this.feature_names;
	String[] types = this.feature_types;

	if(names == null || types == null){
		return null;
	}

	FeatureMap featureMap = new FeatureMap();

	int index = 0;
	for(String name : names){
		featureMap.addEntry(name, types[index]);

		index++;
	}

	return featureMap;
}
/**
 * Builds a schema that consists of the encoded label and the encoded features.
 *
 * @param targetName The name of the target; "_target" is used when <code>null</code>.
 * @param targetCategories The target categories, passed on to the label encoding as-is.
 * @param featureMap The feature map that encodes the features for this learner.
 * @param encoder The encoder that is shared by the label, the features and the schema.
 *
 * @return The schema.
 */
public Schema encodeSchema(String targetName, List targetCategories, FeatureMap featureMap, XGBoostEncoder encoder){
	String labelName = (targetName != null) ? targetName : "_target";

	// The label is encoded before the features
	Label label = encodeLabel(labelName, targetCategories, encoder);

	List features = featureMap.encodeFeatures(this, encoder);

	return new Schema(encoder, label, features);
}
public Label encodeLabel(String targetName, List targetCategories, XGBoostEncoder encoder){
if(this.num_target == 1){
return this.obj.encodeLabel(targetName, targetCategories, encoder);
} else
if(this.num_target >= 2){
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy