// hex.tree.xgboost.predict.XGBoostNativeVariableImportance (Maven / Gradle / Ivy)
package hex.tree.xgboost.predict;
import hex.tree.xgboost.util.BoosterHelper;
import hex.tree.xgboost.util.FeatureScore;
import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.XGBoostError;
import org.apache.log4j.Logger;
import water.Key;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.Map;
public class XGBoostNativeVariableImportance implements XGBoostVariableImportance {
// Class-wide logger.
private static final Logger LOG = Logger.getLogger(XGBoostNativeVariableImportance.class);
// Temporary file holding the feature map text; written once by the constructor and removed by cleanup().
private final File featureMapFile;
/**
 * Builds a variable-importance helper and immediately materializes the feature map
 * as a temporary file on local disk.
 *
 * @param modelKey   key of the model; its string form becomes part of the temp file name
 * @param featureMap textual feature map to be written to the temp file
 * @throws RuntimeException if the feature map file cannot be created or written
 */
public XGBoostNativeVariableImportance(Key modelKey, String featureMap) {
    this.featureMapFile = createFeatureMapFile(modelKey, featureMap);
}
/**
 * Writes the feature map into a newly created temporary file and returns that file.
 *
 * <p>The file is named {@code h2o_xgb_<modelKey>*.txt} and is registered for deletion
 * on JVM exit as a fallback in case {@link #cleanup()} is never called.
 *
 * @param modelKey   key of the model; its string form is used as part of the file name prefix
 * @param featureMap feature map text to store in the file
 * @return the temporary file containing the feature map
 * @throws RuntimeException wrapping the {@link IOException} if the file cannot be created or written
 */
private File createFeatureMapFile(Key modelKey, String featureMap) {
    File fmFile = null;
    try {
        fmFile = Files.createTempFile("h2o_xgb_" + modelKey.toString(), ".txt").toFile();
        fmFile.deleteOnExit();
        // Encode explicitly as UTF-8: the no-arg getBytes() uses the platform default charset,
        // which makes the file content differ between hosts for non-ASCII feature names.
        Files.write(fmFile.toPath(), featureMap.getBytes(StandardCharsets.UTF_8));
        return fmFile;
    } catch (IOException e) {
        if (fmFile != null) {
            // Best-effort removal of the partially written file; the result is intentionally
            // ignored because deleteOnExit() above already acts as the fallback.
            fmFile.delete();
        }
        throw new RuntimeException("Cannot generate feature map file", e);
    }
}
/**
 * Deletes the temporary feature map file created by the constructor.
 *
 * <p>A warning is logged only when the file still exists after the delete attempt, so
 * calling this method again after a successful clean-up stays silent.
 */
@Override
public void cleanup() {
    if (featureMapFile == null) {
        return;
    }
    // File.delete() also returns false when the file is already gone; only a file that is
    // still present after the attempt warrants asking for a manual clean-up.
    if (!featureMapFile.delete() && featureMapFile.exists()) {
        LOG.warn("Unable to delete file " + featureMapFile + ". Please do a manual clean-up.");
    }
}
public Map getFeatureScores(byte[] boosterBytes) {
Booster booster = null;
try {
booster = BoosterHelper.loadModel(boosterBytes);
return BoosterHelper.doWithLocalRabit(new BoosterHelper.BoosterOp