biz.k11i.xgboost.Predictor Maven / Gradle / Ivy
The newest version!
package biz.k11i.xgboost;
import biz.k11i.xgboost.config.PredictorConfiguration;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.spark.SparkModelParam;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
/**
 * Predicts using the Xgboost model.
 *
 * <p>An instance is built from a binary model stream. The header of the stream is
 * parsed by {@link #readParam(ModelReader)}, after which the objective function and
 * the gradient booster named in the model are instantiated and the booster loads
 * the remaining model data.
 */
public class Predictor implements Serializable {
    private ModelParam mparam;
    private SparkModelParam sparkModelParam;
    private String name_obj;
    private String name_gbm;
    private ObjFunction obj;
    private GradBooster gbm;

    /**
     * Instantiates with the Xgboost model, using the default configuration.
     *
     * @param in input stream
     * @throws IOException If an I/O error occurs
     */
    public Predictor(InputStream in) throws IOException {
        this(in, null);
    }

    /**
     * Instantiates with the Xgboost model
     *
     * @param in            input stream
     * @param configuration configuration; {@code null} selects
     *                      {@link PredictorConfiguration#DEFAULT}
     * @throws IOException If an I/O error occurs
     */
    public Predictor(InputStream in, PredictorConfiguration configuration) throws IOException {
        if (configuration == null) {
            configuration = PredictorConfiguration.DEFAULT;
        }

        ModelReader reader = new ModelReader(in);

        readParam(reader);
        initObjFunction(configuration);
        initObjGbm();

        gbm.loadModel(reader, mparam.saved_with_pbuffer != 0);
    }

    /**
     * Reads the model header: the optional format signature, the base score, the
     * number of features, the remaining {@link ModelParam} fields and finally the
     * names of the objective function and of the gradient booster.
     *
     * <p>Three layouts are recognised from the first eight bytes:
     * <ul>
     * <li>the old format, which starts with the signature "binf";</li>
     * <li>a stream starting with {@code 00 05} followed by {@code _cls_} or
     * {@code _reg_}, treated as a model generated by xgboost4j-spark;</li>
     * <li>anything else, where the first eight bytes already are the base score
     * and the number of features.</li>
     * </ul>
     *
     * @param reader reader positioned at the beginning of the model
     * @throws IOException If an I/O error occurs
     */
    void readParam(ModelReader reader) throws IOException {
        byte[] first4Bytes = reader.readByteArray(4);
        byte[] next4Bytes = reader.readByteArray(4);

        float base_score;
        int num_feature;

        if (hasBinfSignature(first4Bytes)) {
            // Old model file format has a signature "binf" (62 69 6e 66)
            base_score = reader.asFloat(next4Bytes);
            num_feature = reader.readUnsignedInt();

        } else {
            // Model generated by xgboost4j-spark?
            String modelType = detectSparkModelType(first4Bytes, next4Bytes);

            if (modelType != null) {
                // The last header byte is the high byte of a 16-bit length; mask it so
                // that a value >= 0x80 is not sign-extended into a negative length.
                // NOTE(review): readByteAsInt() is assumed to return the low byte as an
                // unsigned value (0-255) - confirm against ModelReader.
                int len = ((next4Bytes[3] & 0xff) << 8) + reader.readByteAsInt();
                String featuresCol = reader.readUTF(len);

                this.sparkModelParam = new SparkModelParam(modelType, featuresCol, reader);

                base_score = reader.readFloat();
                num_feature = reader.readUnsignedInt();

            } else {
                // No known signature: the eight bytes already read are the parameters.
                base_score = reader.asFloat(first4Bytes);
                num_feature = reader.asUnsignedInt(next4Bytes);
            }
        }

        mparam = new ModelParam(base_score, num_feature, reader);

        name_obj = reader.readString();
        name_gbm = reader.readString();
    }

    /**
     * Tells whether the given four bytes are the old-format signature "binf".
     */
    private static boolean hasBinfSignature(byte[] first4Bytes) {
        return first4Bytes[0] == 0x62
                && first4Bytes[1] == 0x69
                && first4Bytes[2] == 0x6e
                && first4Bytes[3] == 0x66;
    }

    /**
     * Inspects the first eight bytes of the stream for the xgboost4j-spark markers.
     *
     * @return {@link SparkModelParam#MODEL_TYPE_CLS} or
     * {@link SparkModelParam#MODEL_TYPE_REG} when the corresponding marker is
     * present, otherwise {@code null}
     */
    private static String detectSparkModelType(byte[] first4Bytes, byte[] next4Bytes) {
        // Common prefix: 00 05 followed by '_'
        if (first4Bytes[0] != 0x00 || first4Bytes[1] != 0x05 || first4Bytes[2] != 0x5f) {
            return null;
        }

        if (first4Bytes[3] == 0x63
                && next4Bytes[0] == 0x6c
                && next4Bytes[1] == 0x73
                && next4Bytes[2] == 0x5f) {
            // "cls_": classification model
            return SparkModelParam.MODEL_TYPE_CLS;
        }

        if (first4Bytes[3] == 0x72
                && next4Bytes[0] == 0x65
                && next4Bytes[1] == 0x67
                && next4Bytes[2] == 0x5f) {
            // "reg_": regression model
            return SparkModelParam.MODEL_TYPE_REG;
        }

        return null;
    }

    /**
     * Selects the objective function: the one supplied by the configuration if any,
     * otherwise the one registered under the name stored in the model.
     *
     * @param configuration configuration, must not be {@code null}
     */
    void initObjFunction(PredictorConfiguration configuration) {
        obj = configuration.getObjFunction();

        if (obj == null) {
            obj = ObjFunction.fromName(name_obj);
        }
    }

    /**
     * Creates the gradient booster named in the model and passes it the number of
     * classes.
     *
     * <p>An objective function that has already been selected (for example a custom
     * one coming from the configuration) is kept; the objective function is looked
     * up by name only when none has been set yet.
     */
    void initObjGbm() {
        if (obj == null) {
            obj = ObjFunction.fromName(name_obj);
        }

        gbm = GradBooster.Factory.createGradBooster(name_gbm);
        gbm.setNumClass(mparam.num_class);
    }

    /**
     * Generates predictions for given feature vector.
     *
     * @param feat feature vector
     * @return prediction values
     */
    public double[] predict(FVec feat) {
        return predict(feat, false);
    }

    /**
     * Generates predictions for given feature vector.
     *
     * @param feat          feature vector
     * @param output_margin whether to only predict margin value instead of transformed prediction
     * @return prediction values
     */
    public double[] predict(FVec feat, boolean output_margin) {
        return predict(feat, output_margin, 0);
    }

    /**
     * Generates predictions for given feature vector.
     *
     * @param feat          feature vector
     * @param output_margin whether to only predict margin value instead of transformed prediction
     * @param ntree_limit   limit the number of trees used in prediction
     * @return prediction values
     */
    public double[] predict(FVec feat, boolean output_margin, int ntree_limit) {
        double[] preds = predictRaw(feat, ntree_limit);
        if (!output_margin) {
            return obj.predTransform(preds);
        }
        return preds;
    }

    /**
     * Computes the margin values: the booster output with the base score added to
     * every element. The array returned by the booster is modified in place.
     *
     * @param feat        feature vector
     * @param ntree_limit limit the number of trees used in prediction
     * @return margin values
     */
    double[] predictRaw(FVec feat, int ntree_limit) {
        double[] preds = gbm.predict(feat, ntree_limit);
        for (int i = 0; i < preds.length; i++) {
            preds[i] += mparam.base_score;
        }
        return preds;
    }

    /**
     * Generates a prediction for given feature vector.
     *
     * <p>This method only works when the model outputs single value.
     *
     * @param feat feature vector
     * @return prediction value
     */
    public double predictSingle(FVec feat) {
        return predictSingle(feat, false);
    }

    /**
     * Generates a prediction for given feature vector.
     *
     * <p>This method only works when the model outputs single value.
     *
     * @param feat          feature vector
     * @param output_margin whether to only predict margin value instead of transformed prediction
     * @return prediction value
     */
    public double predictSingle(FVec feat, boolean output_margin) {
        return predictSingle(feat, output_margin, 0);
    }

    /**
     * Generates a prediction for given feature vector.
     *
     * <p>This method only works when the model outputs single value.
     *
     * @param feat          feature vector
     * @param output_margin whether to only predict margin value instead of transformed prediction
     * @param ntree_limit   limit the number of trees used in prediction
     * @return prediction value
     */
    public double predictSingle(FVec feat, boolean output_margin, int ntree_limit) {
        double pred = predictSingleRaw(feat, ntree_limit);
        if (!output_margin) {
            return obj.predTransform(pred);
        }
        return pred;
    }

    /**
     * Computes the single margin value: the booster output plus the base score.
     *
     * @param feat        feature vector
     * @param ntree_limit limit the number of trees used in prediction
     * @return margin value
     */
    double predictSingleRaw(FVec feat, int ntree_limit) {
        return gbm.predictSingle(feat, ntree_limit) + mparam.base_score;
    }

    /**
     * Predicts leaf index of each tree.
     *
     * @param feat feature vector
     * @return leaf indexes
     */
    public int[] predictLeaf(FVec feat) {
        return predictLeaf(feat, 0);
    }

    /**
     * Predicts leaf index of each tree.
     *
     * @param feat        feature vector
     * @param ntree_limit limit
     * @return leaf indexes
     */
    public int[] predictLeaf(FVec feat, int ntree_limit) {
        return gbm.predictLeaf(feat, ntree_limit);
    }

    /**
     * Returns the parameters read from a model generated by xgboost4j-spark.
     *
     * @return the Spark model parameters, or {@code null} when the model header did
     * not carry the xgboost4j-spark markers
     */
    public SparkModelParam getSparkModelParam() {
        return sparkModelParam;
    }

    /**
     * Returns number of class.
     *
     * @return number of class
     */
    public int getNumClass() {
        return mparam.num_class;
    }

    /**
     * Parameters.
     */
    static class ModelParam implements Serializable {
        /* \brief global bias */
        final float base_score;
        /* \brief number of features */
        final /* unsigned */ int num_feature;
        /* \brief number of class, if it is multi-class classification */
        final int num_class;
        /*! \brief whether the model itself is saved with pbuffer */
        final int saved_with_pbuffer;
        /*! \brief reserved field */
        final int[] reserved;

        /**
         * Stores the two values already decoded by the caller and reads the remaining
         * fields from the reader, in this order: number of classes, pbuffer flag and
         * 30 reserved integers.
         *
         * @param base_score  global bias
         * @param num_feature number of features
         * @param reader      reader positioned right after the number of features
         * @throws IOException If an I/O error occurs
         */
        ModelParam(float base_score, int num_feature, ModelReader reader) throws IOException {
            this.base_score = base_score;
            this.num_feature = num_feature;
            this.num_class = reader.readInt();
            this.saved_with_pbuffer = reader.readInt();
            this.reserved = reader.readIntArray(30);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy