All Downloads are FREE. Search and download functionalities are using the official Maven repository.

xgboost.sklearn.BoosterUtil Maven / Gradle / Ivy

There is a newer version: 1.8.9
Show newest version
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn.  If not, see .
 */
package xgboost.sklearn;

import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Schema;
import org.jpmml.python.HasArray;
import org.jpmml.xgboost.ByteOrderUtil;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.HasXGBoostOptions;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.ObjFunction;
import pandas.core.BlockManager;
import pandas.core.DataFrame;
import pandas.core.Index;
import sklearn.Estimator;

public class BoosterUtil {

	private BoosterUtil(){
	}

	static
	public  int getNumberOfFeatures(E estimator){
		Learner learner = getLearner(estimator);

		return learner.num_feature();
	}

	static
	public  ObjFunction getObjFunction(E estimator){
		Learner learner = getLearner(estimator);

		return learner.obj();
	}

	static
	public  MiningModel encodeModel(E estimator, Schema schema){
		Booster booster = estimator.getBooster();

		Learner learner = getLearner(estimator);

		Map options = getOptions(booster, learner, estimator);

		Integer ntreeLimit = (Integer)options.get(HasXGBoostOptions.OPTION_NTREE_LIMIT);

		MiningModel miningModel = learner.encodeModel(ntreeLimit, schema);

		return miningModel;
	}

	static
	public  Schema configureSchema(E estimator, Schema schema){
		Booster booster = estimator.getBooster();

		Learner learner = getLearner(estimator);

		Map options = getOptions(booster, learner, estimator);

		Schema xgbSchema = learner.toXGBoostSchema(schema);

		xgbSchema = learner.configureSchema(options, xgbSchema);

		return xgbSchema;
	}

	static
	public  MiningModel configureModel(E estimator, MiningModel miningModel){
		Booster booster = estimator.getBooster();

		Learner learner = getLearner(estimator);

		Map options = getOptions(booster, learner, estimator);

		miningModel = learner.configureModel(options, miningModel);

		return miningModel;
	}

	static
	public PMML encodePMML(Booster booster){
		FeatureMap featureMap = null;

		DataFrame fmap = booster.getFMap();
		if(fmap != null){
			featureMap = parseFMap(fmap);
		}

		Learner learner = booster.getLearner(ByteOrder.nativeOrder(), null);

		Map options = getOptions(booster, learner);

		return learner.encodePMML(options, null, null, featureMap);
	}

	static
	public  PMML encodePMML(E estimator){
		Booster booster = estimator.getBooster();

		Learner learner = getLearner(estimator);

		Map options = getOptions(booster, learner, estimator);

		return learner.encodePMML(options, null, null, null);
	}

	static
	private  Learner getLearner(E estimator){
		Booster booster = estimator.getBooster();

		String byteOrder = (String)estimator.getOption(HasXGBoostOptions.OPTION_BYTE_ORDER, (ByteOrder.nativeOrder()).toString());
		String charset = (String)estimator.getOption(HasXGBoostOptions.OPTION_CHARSET, null);

		return booster.getLearner(ByteOrderUtil.forValue(byteOrder), charset);
	}

	static
	private Map getOptions(Booster booster, Learner learner){
		Map result = new LinkedHashMap<>();

		Integer bestNTreeLimit = booster.getBestNTreeLimit();

		if(bestNTreeLimit == null){
			Integer bestIteration = learner.getBestIteration();

			if(bestIteration != null){
				bestNTreeLimit = (bestIteration + 1);
			}
		}

		result.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, bestNTreeLimit);

		return result;
	}

	static
	private  Map getOptions(Booster booster, Learner learner, E estimator){
		GBTree gbtree = learner.gbtree();

		Map result = new LinkedHashMap<>();

		Integer bestNTreeLimit = booster.getBestNTreeLimit();

		// XGBoost 1.7
		if(bestNTreeLimit == null){
			bestNTreeLimit = (Integer)estimator.getOptionalScalar("best_ntree_limit");
		} // End if

		// XGBoost 2.0+
		if(bestNTreeLimit == null){
			Integer bestIteration = learner.getBestIteration();

			if(bestIteration != null){
				bestNTreeLimit = (bestIteration + 1);
			}
		}

		Integer ntreeLimit = (Integer)estimator.getOption(HasXGBoostOptions.OPTION_NTREE_LIMIT, bestNTreeLimit);
		result.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, ntreeLimit);

		Number missing = (Number)estimator.getOptionalScalar("missing");
		result.put(HasXGBoostOptions.OPTION_MISSING, missing);

		Boolean compact = (Boolean)estimator.getOption(HasXGBoostOptions.OPTION_COMPACT, !gbtree.hasCategoricalSplits());
		Boolean inputFloat = (Boolean)estimator.getOption(HasXGBoostOptions.OPTION_INPUT_FLOAT, null);
		Boolean numeric = (Boolean)estimator.getOption(HasXGBoostOptions.OPTION_NUMERIC, Boolean.TRUE);
		Boolean prune = (Boolean)estimator.getOption(HasXGBoostOptions.OPTION_PRUNE, Boolean.TRUE);

		result.put(HasXGBoostOptions.OPTION_COMPACT, compact);
		result.put(HasXGBoostOptions.OPTION_INPUT_FLOAT, inputFloat);
		result.put(HasXGBoostOptions.OPTION_NUMERIC, numeric);
		result.put(HasXGBoostOptions.OPTION_PRUNE, prune);

		return result;
	}

	static
	private FeatureMap parseFMap(DataFrame fmap){
		BlockManager data = fmap.getData();

		Index columnAxis = data.getColumnAxis();
		Index rowAxis = data.getRowAxis();

		if(!(Arrays.asList("id", "name", "type")).equals(columnAxis.getValues())){
			throw new IllegalArgumentException();
		}

		List blockValues = data.getBlockValues();

		HasArray idColumn = blockValues.get(0);
		HasArray nameTypeColumns = blockValues.get(1);

		List nameTypeContent = nameTypeColumns.getArrayContent();
		int[] nameTypeShape = nameTypeColumns.getArrayShape();

		List nameValues = CMatrixUtil.getRow(nameTypeContent, nameTypeShape[0], nameTypeShape[1], 0);
		List typeValues = CMatrixUtil.getRow(nameTypeContent, nameTypeShape[0], nameTypeShape[1], 1);

		FeatureMap result = new FeatureMap();

		for(int i = 0; i < nameTypeShape[1]; i++){
			String name = (String)nameValues.get(i);
			String type = (String)typeValues.get(i);

			result.addEntry(name, type);
		}

		return result;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy