All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.task.XGBoostUpdateTask Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.xgboost.task;

import ai.h2o.xgboost4j.java.Booster;
import hex.tree.xgboost.EvalMetric;
import org.apache.log4j.Logger;
import water.*;

/**
 * Task that asks the {@link XGBoostUpdater} registered for this task's model key to perform
 * one update, and that afterwards lets the caller retrieve the resulting booster bytes and
 * evaluation metric from the node returned by the inherited {@code getBoosterNode()}.
 *
 * <p>Both getters read directly from the local updater when the booster node is
 * {@link H2O#SELF}; otherwise they run a small {@link DTask} on the booster node through
 * {@link RPC} and return the value that task collected.
 */
public class XGBoostUpdateTask extends AbstractXGBoostTask {

    private static final Logger LOG = Logger.getLogger(XGBoostUpdateTask.class);

    /** Value handed to {@code XGBoostUpdater.doUpdate(int)} when this task executes. */
    private final int _tid;

    /**
     * Creates an update task derived from the given setup task.
     *
     * @param setupTask setup task passed on to the {@link AbstractXGBoostTask} constructor
     * @param tid       value later passed to {@code XGBoostUpdater.doUpdate(int)}
     */
    public XGBoostUpdateTask(XGBoostSetupTask setupTask, int tid) {
        super(setupTask);
        _tid = tid;
    }

    /**
     * Runs one update on the updater looked up by {@code _modelKey}.
     *
     * @throws IllegalStateException if the update returns a {@code null} {@link Booster}
     */
    @Override
    protected void execute() {
        Booster booster = XGBoostUpdater.getUpdater(_modelKey).doUpdate(_tid);
        if (booster == null) {
            throw new IllegalStateException("Boosting iteration didn't produce a valid Booster.");
        }
    }

    /**
     * Returns the booster bytes held by the updater on the booster node.
     *
     * @return the bytes reported by {@code XGBoostUpdater.getBoosterBytes()}, obtained locally
     *         or via a blocking RPC to the booster node
     */
    public byte[] getBoosterBytes() {
        final H2ONode boosterNode = getBoosterNode();
        final byte[] boosterBytes;
        if (H2O.SELF.equals(boosterNode)) {
            // The updater lives on this node; no remote call is needed.
            boosterBytes = XGBoostUpdater.getUpdater(_modelKey).getBoosterBytes();
        } else {
            LOG.debug("Booster will be retrieved from a remote node, node=" + boosterNode);
            FetchBoosterTask t = new FetchBoosterTask(_modelKey);
            boosterBytes = new RPC<>(boosterNode, t).call().get()._boosterBytes;
        }
        return boosterBytes;
    }

    /**
     * Returns the evaluation metric held by the updater on the booster node.
     *
     * @return the value reported by {@code XGBoostUpdater.getEvalMetric()}, obtained locally
     *         or via a blocking RPC to the booster node
     */
    public EvalMetric getEvalMetric() {
        final H2ONode boosterNode = getBoosterNode();
        final EvalMetric evalMetric;
        if (H2O.SELF.equals(boosterNode)) {
            // The updater lives on this node; no remote call is needed.
            evalMetric = XGBoostUpdater.getUpdater(_modelKey).getEvalMetric();
        } else {
            LOG.debug("CustomMetric will be retrieved from a remote node, node=" + boosterNode);
            FetchEvalMetricTask t = new FetchEvalMetricTask(_modelKey);
            evalMetric = new RPC<>(boosterNode, t).call().get()._evalMetric;
        }
        return evalMetric;
    }

    /**
     * Remote task that copies the booster bytes of the updater registered under a model key
     * into {@link #_boosterBytes} so the caller can read them from the returned task.
     */
    private static class FetchBoosterTask extends DTask<FetchBoosterTask> {
        private final Key _modelKey;

        // OUT
        private byte[] _boosterBytes;

        private FetchBoosterTask(Key modelKey) {
            _modelKey = modelKey;
        }

        /** Reads the booster bytes from the updater and then signals completion. */
        @Override
        public void compute2() {
            _boosterBytes = XGBoostUpdater.getUpdater(_modelKey).getBoosterBytes();
            tryComplete();
        }
    }

    /**
     * Remote task that copies the evaluation metric of the updater registered under a model
     * key into {@link #_evalMetric} so the caller can read it from the returned task.
     */
    private static class FetchEvalMetricTask extends DTask<FetchEvalMetricTask> {
        private final Key _modelKey;

        // OUT
        private EvalMetric _evalMetric;

        private FetchEvalMetricTask(Key modelKey) {
            _modelKey = modelKey;
        }

        /** Reads the evaluation metric from the updater and then signals completion. */
        @Override
        public void compute2() {
            _evalMetric = XGBoostUpdater.getUpdater(_modelKey).getEvalMetric();
            tryComplete();
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy