All downloads are FREE. The search and download functionality uses the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it how often you want.
hex.tree.xgboost.predict.XGBoostNativeBigScoreChunkPredict Maven / Gradle / Ivy
package hex.tree.xgboost.predict;
import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.tree.xgboost.*;
import hex.tree.xgboost.util.BoosterHelper;
import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoostError;
import org.apache.log4j.Logger;
import water.fvec.Chunk;
import water.fvec.Frame;
import java.util.HashMap;
/**
 * Chunk-level scorer backed by the native XGBoost {@link Booster}.
 *
 * <p>The constructor eagerly runs a {@code PREDICT} pass over the chunks it is given and keeps
 * the resulting per-row predictions; {@link #score0} then serves individual rows from that
 * cached result. {@link #predict} and {@link #predictContrib} run a fresh native scoring pass
 * over whatever chunks they are handed.
 */
public class XGBoostNativeBigScoreChunkPredict implements XGBoostPredictContrib, Model.BigScoreChunkPredict {

    private static final Logger LOG = Logger.getLogger(XGBoostNativeBigScoreChunkPredict.class);

    /** Threshold forwarded to {@code XGBoostMojoModel.toPreds} in {@link #score0}. */
    private final double _threshold;
    /** Index of the response column in the scored frame, or -1 when it is absent or its vec is bad. */
    private final int _responseIndex;
    /** Index of the offset column in the scored frame as returned by {@code Frame.find}. */
    private final int _offsetIndex;
    private final XGBoostModelInfo _modelInfo;
    private final XGBoostModel.XGBoostParameters _parms;
    private final DataInfo _dataInfo;
    private final BoosterParms _boosterParms;
    private final XGBoostOutput _output;
    /** Predictions computed in the constructor, indexed by row within the chunk. */
    private final float[][] _preds;

    /**
     * Creates the scorer and immediately computes predictions for {@code chks}.
     *
     * @param modelInfo    holder of the serialized booster ({@code _boosterBytes})
     * @param parms        model parameters; supplies the response and offset column names
     * @param di           data info used when converting chunks to a {@link DMatrix}
     * @param boosterParms parameters applied to the booster after it is loaded
     * @param threshold    threshold passed through to {@code XGBoostMojoModel.toPreds}
     * @param output       model output; supplies the sparsity flag and the number of classes
     * @param fr           frame being scored, used to resolve the response/offset column indices
     * @param chks         chunks to score
     * @throws IllegalStateException if native scoring fails with an {@link XGBoostError}
     */
    public XGBoostNativeBigScoreChunkPredict(
        XGBoostModelInfo modelInfo,
        XGBoostModel.XGBoostParameters parms,
        DataInfo di,
        BoosterParms boosterParms,
        double threshold,
        XGBoostOutput output,
        Frame fr,
        Chunk[] chks
    ) {
        _modelInfo = modelInfo;
        _parms = parms;
        _dataInfo = di;
        _boosterParms = boosterParms;
        _threshold = threshold;
        _output = output;
        // Resolve the response column once. A column that is missing from the frame is handled
        // like a bad one (-1) rather than dereferencing a null vec; this mirrors how the offset
        // column below simply ends up with whatever Frame.find reports.
        final int responseIndex = fr.find(_parms._response_column);
        if (responseIndex < 0 || fr.vec(responseIndex).isBad()) {
            _responseIndex = -1;
        } else {
            _responseIndex = responseIndex;
        }
        _offsetIndex = fr.find(_parms._offset_column);
        _preds = scoreChunk(chks, XGBoostPredict.OutputType.PREDICT);
    }

    /**
     * Produces the prediction for a single row from the predictions cached by the constructor.
     *
     * <p>The row's chunk values are first copied into {@code tmp}; the cached native prediction
     * for {@code row_in_chunk} is then converted by {@code XGBoostMojoModel.toPreds}. The
     * {@code offset} argument is not read by this implementation.
     *
     * @param chks         chunks holding the row; the first {@code tmp.length} chunks are read
     * @param offset       unused here
     * @param row_in_chunk row index within the chunk
     * @param tmp          scratch array filled with the row's values
     * @param preds        output array handed to {@code toPreds}
     * @return the value returned by {@code XGBoostMojoModel.toPreds}
     */
    @Override
    public double[] score0(Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
        for (int i = 0; i < tmp.length; i++) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        return XGBoostMojoModel.toPreds(tmp, _preds[row_in_chunk], preds, _output.nclasses(), null, _threshold);
    }

    /**
     * Scores the given chunks with output type {@code PREDICT_CONTRIB_APPROX}.
     *
     * @param cs chunks to score
     * @return the native result, or an empty array when there is nothing to score
     */
    @Override
    public float[][] predictContrib(Chunk[] cs) {
        return scoreChunk(cs, XGBoostPredict.OutputType.PREDICT_CONTRIB_APPROX);
    }

    /**
     * Scores the given chunks with output type {@code PREDICT}.
     *
     * @param cs chunks to score
     * @return the native result, or an empty array when there is nothing to score
     */
    @Override
    public float[][] predict(Chunk[] cs) {
        return scoreChunk(cs, XGBoostPredict.OutputType.PREDICT);
    }

    /**
     * Converts the chunks to a {@link DMatrix}, loads the booster and runs the requested
     * prediction. The booster and matrix are disposed and Rabit is shut down on every exit path.
     *
     * @param cs         chunks to score
     * @param outputType which native prediction to run
     * @return the native result; an empty array when the matrix has no rows or the native call
     *         returned {@code null}
     * @throws IllegalStateException         if XGBoost reports an {@link XGBoostError}
     * @throws UnsupportedOperationException if {@code outputType} is not handled
     */
    private float[][] scoreChunk(final Chunk[] cs, final XGBoostPredict.OutputType outputType) {
        DMatrix data = null;
        Booster booster = null;
        try {
            // Rabit has to be initialized as parts of booster.predict() are using Rabit
            // This might be fixed in future versions of XGBoost
            Rabit.init(new HashMap<>());
            data = XGBoostUtils.convertChunksToDMatrix(
                _dataInfo,
                cs,
                _responseIndex,
                _output._sparse,
                _offsetIndex
            );
            // No local chunks for this frame
            if (data.rowNum() == 0) {
                return new float[0][];
            }
            // Initialize Booster
            booster = BoosterHelper.loadModel(_modelInfo._boosterBytes);
            booster.setParams(_boosterParms.get());
            // 0 is passed as the tree limit (xgboost4j documents 0 as "use all trees")
            final int treeLimit = 0;
            // Predict
            final float[][] preds;
            switch (outputType) {
                case PREDICT:
                    preds = booster.predict(data, false, treeLimit);
                    break;
                case PREDICT_CONTRIB_APPROX:
                    preds = booster.predictContrib(data, treeLimit);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported output type: " + outputType);
            }
            return preds == null ? new float[0][] : preds;
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
        } finally {
            BoosterHelper.dispose(booster, data);
            try {
                Rabit.shutdown();
            } catch (XGBoostError xgBoostError) {
                // Shutdown failure is logged only, so it cannot mask the scoring result or error.
                LOG.error("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
            }
        }
    }

    /** No-op: native resources are released at the end of each scoring pass. */
    @Override
    public void close() {
    }
}