All downloads are free. Search and download functionality uses the official Maven repository.
Please wait. This can take a few minutes...
Downloading a project requires a lot of resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
hex.tree.xgboost.task.XGBoostUpdater Maven / Gradle / Ivy
package hex.tree.xgboost.task;
import hex.tree.xgboost.BoosterParms;
import ai.h2o.xgboost4j.java.*;
import hex.tree.xgboost.EvalMetric;
import org.apache.log4j.Logger;
import water.H2O;
import water.Key;
import water.nbhm.NonBlockingHashMap;
import water.util.Log;
import java.util.Map;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;
public class XGBoostUpdater extends Thread {
private static final Logger LOG = Logger.getLogger(XGBoostUpdater.class);
private static final long WORK_START_TIMEOUT_SECS = 5 * 60; // Each Booster task should start before this timer expires
private static final long INACTIVE_CHECK_INTERVAL_SECS = 60;
private static final NonBlockingHashMap updaters = new NonBlockingHashMap<>();
private final Key _modelKey;
private final DMatrix _trainMat;
private final DMatrix _validMat;
private final BoosterParms _boosterParms;
private final String _evalMetricSpec;
private final byte[] _checkpointBoosterBytes;
private final Map _rabitEnv;
private volatile SynchronousQueue> _in;
private volatile SynchronousQueue _out;
private BoosterWrapper _booster;
private volatile EvalMetric _evalMetric;
private XGBoostUpdater(
Key modelKey, DMatrix trainMat, DMatrix validMat, BoosterParms boosterParms,
byte[] checkpointBoosterBytes, Map rabitEnv
) {
super("XGBoostUpdater-" + modelKey);
_modelKey = modelKey;
_trainMat = trainMat;
_validMat = validMat;
_boosterParms = boosterParms;
_checkpointBoosterBytes = checkpointBoosterBytes;
_rabitEnv = rabitEnv;
_evalMetricSpec = (String) _boosterParms.get().get("eval_metric");
_in = new SynchronousQueue<>();
_out = new SynchronousQueue<>();
}
@Override
public void run() {
try {
Rabit.init(_rabitEnv);
while (! interrupted()) {
BoosterCallable> task = _in.take();
Object result = task.call();
_out.put(result);
}
} catch (InterruptedException e) {
XGBoostUpdater self = updaters.get(_modelKey);
if (self != null) {
LOG.error("Updater thread was interrupted while it was still registered, name=" + self.getName(), e);
} else {
LOG.debug("Updater thread interrupted.", e);
}
Thread.currentThread().interrupt();
} catch (XGBoostError e) {
LOG.error("XGBoost training iteration failed", e);
} finally {
_in = null; // Will throw NPE if used wrong
_out = null;
updaters.remove(_modelKey);
try {
_trainMat.dispose();
if (_validMat != null) {
_validMat.dispose();
}
if (_booster != null)
_booster.dispose();
} catch (Exception e) {
LOG.warn("Failed to dispose of training matrix/booster", e);
}
try {
Rabit.shutdown();
} catch (Exception xgBoostError) {
LOG.warn("Rabit shutdown during update failed", xgBoostError);
}
}
}
@SuppressWarnings("unchecked")
private T invoke(BoosterCallable callable) throws InterruptedException {
final SynchronousQueue> inQ = _in;
if (inQ == null)
throw new IllegalStateException("Updater is inactive on node " + H2O.SELF);
if (! inQ.offer(callable, WORK_START_TIMEOUT_SECS, TimeUnit.SECONDS))
throw new IllegalStateException("XGBoostUpdater couldn't start work on task " + callable + " in " + WORK_START_TIMEOUT_SECS + "s.");
SynchronousQueue> outQ;
int i = 0;
while ((outQ = _out) != null) {
i++;
T result = (T) outQ.poll(INACTIVE_CHECK_INTERVAL_SECS, TimeUnit.SECONDS);
if (result != null) {
return result;
} else if (i > 5) {
LOG.warn(String.format("XGBoost task of type '%s' is taking unexpectedly long, it didn't finish in %d seconds.",
callable, INACTIVE_CHECK_INTERVAL_SECS * i));
}
}
throw new IllegalStateException("Cannot perform booster operation: updater is inactive on node " + H2O.SELF);
}
private class UpdateBooster implements BoosterCallable {
private final int _tid;
private UpdateBooster(int tid) { _tid = tid; }
@Override
public Booster call() throws XGBoostError {
if ((_booster == null) && _tid == 0) {
_booster = new BoosterWrapper(_checkpointBoosterBytes, _boosterParms.get(), _trainMat, _validMat);
_evalMetric = computeEvalMetric();
// Force Booster initialization; we can call any method that does "lazy init"
byte[] boosterBytes = _booster.toByteArray();
LOG.info("Initial Booster created, size=" + boosterBytes.length);
} else {
// Do one iteration
assert _booster != null;
_booster.update(_trainMat, _tid);
_evalMetric = computeEvalMetric();
_booster.saveRabitCheckpoint();
}
return _booster.getBooster();
}
private EvalMetric computeEvalMetric() throws XGBoostError {
if (_evalMetricSpec == null) {
return null;
}
final String evalMetricVal = _booster.evalSet(_trainMat, _validMat, _tid);
return parseEvalMetric(evalMetricVal);
}
@Override
public String toString() {
return "Boosting Iteration (tid=" + _tid + ")";
}
}
private EvalMetric parseEvalMetric(String evalMetricVal) {
return parseEvalMetric(_evalMetricSpec, _validMat != null, evalMetricVal);
}
static EvalMetric parseEvalMetric(String evalMetricSpec, boolean hasValid, String evalMetricVal) {
final String[] parts = evalMetricVal.split("\t");
final int expectedParts = hasValid ? 3 : 2;
if (parts.length != expectedParts) {
Log.err("Evaluation metric cannot be parsed, unexpected number of elements. Value: '" + evalMetricSpec + "'.");
return EvalMetric.empty(evalMetricSpec);
}
double trainVal, validVal = Double.NaN;
trainVal = parseEvalMetricPart(parts[1]);
if (hasValid) {
validVal = parseEvalMetricPart(parts[2]);
}
return new EvalMetric(evalMetricSpec, trainVal, validVal);
}
static double parseEvalMetricPart(String evalMetricVal) {
final int sepPos = evalMetricVal.lastIndexOf(":");
if (sepPos >= 0) {
String valStr = evalMetricVal.substring(sepPos + 1).trim();
try {
return Double.parseDouble(valStr);
} catch (Exception e) {
Log.err("Failed to parse value of evaluation metric: '" + evalMetricVal + "'.", e);
}
}
return Double.NaN;
}
private class SerializeBooster implements BoosterCallable {
@Override
public byte[] call() throws XGBoostError {
return _booster.toByteArray();
}
@Override
public String toString() {
return "SerializeBooster";
}
}
byte[] getBoosterBytes() {
try {
return invoke(new SerializeBooster());
} catch (InterruptedException e) {
throw new IllegalStateException("Failed to serialize Booster - operation was interrupted", e);
}
}
EvalMetric getEvalMetric() {
return _evalMetric;
}
Booster doUpdate(int tid) {
try {
return invoke(new UpdateBooster(tid));
} catch (InterruptedException e) {
throw new IllegalStateException("Boosting iteration failed - operation was interrupted", e);
}
}
static XGBoostUpdater make(Key modelKey, DMatrix trainMat, DMatrix validMat, BoosterParms boosterParms,
byte[] checkpoint, Map rabitEnv) {
XGBoostUpdater updater = new XGBoostUpdater(modelKey, trainMat, validMat, boosterParms, checkpoint, rabitEnv);
updater.setUncaughtExceptionHandler(LoggingExceptionHandler.INSTANCE);
if (updaters.putIfAbsent(modelKey, updater) != null)
throw new IllegalStateException("XGBoostUpdater for modelKey=" + modelKey + " already exists!");
return updater;
}
static void terminate(Key modelKey) {
XGBoostUpdater updater = updaters.remove(modelKey);
if (updater == null)
LOG.debug("XGBoostUpdater for modelKey=" + modelKey + " was already clean-up on node " + H2O.SELF);
else
updater.interrupt();
}
static XGBoostUpdater getUpdater(Key modelKey) {
XGBoostUpdater updater = updaters.get(modelKey);
if (updater == null) {
throw new IllegalStateException("XGBoostUpdater for modelKey=" + modelKey + " was not found!");
}
return updater;
}
private interface BoosterCallable {
E call() throws XGBoostError;
}
private static class LoggingExceptionHandler implements UncaughtExceptionHandler {
private static LoggingExceptionHandler INSTANCE = new LoggingExceptionHandler();
@Override
public void uncaughtException(Thread t, Throwable e) {
LOG.error("Uncaught exception in " + t.getName(), e);
}
}
}