All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.xgboost.ObjFunction Maven / Gradle / Ivy

Go to download

Java library and command-line application for converting XGBoost models to PMML

There is a newer version: 1.8.7
Show newest version
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-XGBoost
 *
 * JPMML-XGBoost is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-XGBoost is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-XGBoost.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.xgboost;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/**
 * Abstract base type for objective functions in the XGBoost-to-PMML converter.
 *
 * <p>
 * Concrete subclasses supply the label encoding and the mining model encoding;
 * this class contributes the shared tree ensemble builder
 * {@link #createMiningModel(List, List, float, Integer, Schema)} and a couple of
 * numeric helper functions.
 * </p>
 */
abstract
public class ObjFunction {

	/**
	 * Encodes the label (target) of the model.
	 *
	 * @param targetField The name of the target field.
	 * @param targetCategories The target categories; interpretation is left to the subclass.
	 * @param encoder The encoder that the label is created against.
	 *
	 * @return The encoded label.
	 */
	abstract
	public Label encodeLabel(FieldName targetField, List<?> targetCategories, PMMLEncoder encoder);

	/**
	 * Encodes the trees of a booster as a mining model.
	 *
	 * @param trees The trees of the booster.
	 * @param weights The per-tree weights. May be <code>null</code> in {@link #createMiningModel(List, List, float, Integer, Schema)}; subclasses presumably forward it there.
	 * @param base_score The base score.
	 * @param ntreeLimit The number of leading trees to keep, or <code>null</code> to keep all trees.
	 * @param schema The schema of the model.
	 *
	 * @return The encoded mining model.
	 */
	abstract
	public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema);

	/**
	 * Transforms a probability-scale value to the margin scale.
	 *
	 * <p>
	 * This default implementation is the identity function.
	 * </p>
	 *
	 * @param value The value to transform.
	 *
	 * @return The argument, unchanged.
	 */
	public float probToMargin(float value){
		return value;
	}

	/**
	 * Builds a regression-type mining model that sums (optionally with weights) the predictions of the non-empty trees,
	 * and rescales the result using the base score.
	 *
	 * <p>
	 * Neither argument list is modified. Empty trees are skipped together with their weights.
	 * </p>
	 *
	 * @param trees The trees.
	 * @param weights The per-tree weights, or <code>null</code> if the trees are unweighted. If present, must have the same size as <code>trees</code>.
	 * @param base_score The base score, passed on as the rescaling target.
	 * @param ntreeLimit The number of leading trees (and weights) to keep, or <code>null</code> to keep all of them.
	 * @param schema The schema. Its label must be a {@link ContinuousLabel}.
	 *
	 * @return The mining model.
	 *
	 * @throws IllegalArgumentException If the sizes of <code>trees</code> and <code>weights</code> differ,
	 * or if <code>ntreeLimit</code> is negative or greater than the number of trees.
	 * @throws ClassCastException If the label of the schema is not a {@link ContinuousLabel}.
	 */
	static
	protected MiningModel createMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){

		if(weights != null){

			if(trees.size() != weights.size()){
				throw new IllegalArgumentException("Expected " + trees.size() + " tree weight(s), got " + weights.size());
			}
		} // End if

		if(ntreeLimit != null){

			if(ntreeLimit < 0){
				throw new IllegalArgumentException("Tree limit " + ntreeLimit + " is negative");
			} // End if

			if(ntreeLimit > trees.size()){
				throw new IllegalArgumentException("Tree limit " + ntreeLimit + " is greater than the number of trees");
			}

			trees = trees.subList(0, ntreeLimit);

			if(weights != null){
				weights = weights.subList(0, ntreeLimit);
			}
		}

		ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();

		Schema segmentSchema = schema.toAnonymousSchema();

		PredicateManager predicateManager = new PredicateManager();

		List<TreeModel> treeModels = new ArrayList<>();

		// Collected in lockstep with treeModels, so that the caller's list is never touched
		List<Float> treeWeights = (weights != null ? new ArrayList<Float>() : null);

		// Stays true for as long as every retained weight equals one
		boolean equalWeights = true;

		Iterator<Float> weightIt = (weights != null ? weights.iterator() : null);

		for(RegTree tree : trees){
			// Advance the weight iterator unconditionally in order to stay aligned with the tree
			Float weight = (weightIt != null ? weightIt.next() : null);

			if(tree.isEmpty()){
				continue;
			} // End if

			if(weight != null){
				equalWeights &= ValueUtil.isOne(weight);
			} // End if

			if(treeWeights != null){
				treeWeights.add(weight);
			}

			TreeModel treeModel = tree.encodeTreeModel(predicateManager, segmentSchema);

			treeModels.add(treeModel);
		}

		// A plain sum suffices if there are no weights, or if all of them are one
		Segmentation.MultipleModelMethod multipleModelMethod = (equalWeights ? Segmentation.MultipleModelMethod.SUM : Segmentation.MultipleModelMethod.WEIGHTED_SUM);

		MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
			.setMathContext(MathContext.FLOAT)
			.setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, treeModels, treeWeights))
			.setTargets(ModelUtil.createRescaleTargets(null, base_score, continuousLabel));

		return miningModel;
	}

	/**
	 * Computes <code>-ln((1 / value) - 1)</code>, which is the inverse of the logistic function <code>1 / (1 + exp(-x))</code>.
	 *
	 * @param value The value to transform.
	 *
	 * @return The transformed value, narrowed to <code>float</code>.
	 */
	static
	protected float inverseLogit(float value){
		return (float)-Math.log((1f / value) - 1f);
	}

	/**
	 * Computes the natural logarithm of the argument, which is the inverse of the exponential function.
	 *
	 * @param value The value to transform.
	 *
	 * @return The transformed value, narrowed to <code>float</code>.
	 */
	static
	protected float inverseExp(float value){
		return (float)Math.log(value);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy