All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.java.XGBoostUpdateTask Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel;
import water.*;
import water.util.Log;

public class XGBoostUpdateTask extends AbstractXGBoostTask {

    private final int _tid;

    public XGBoostUpdateTask(XGBoostSetupTask setupTask, int tid) {
        super(setupTask);
        _tid = tid;
    }

    @Override
    protected void execute() {
        Booster booster = XGBoostUpdater.getUpdater(_modelKey).doUpdate(_tid);
        if (booster == null)
            throw new IllegalStateException("Boosting iteration didn't produce a valid Booster.");
    }

    // This is called from driver
    public byte[] getBoosterBytes() {
        final H2ONode boosterNode = getBoosterNode();
        final byte[] boosterBytes;
        if (H2O.SELF.equals(boosterNode)) {
            boosterBytes = XGBoostUpdater.getUpdater(_modelKey).getBoosterBytes();
        } else {
            Log.debug("Booster will be retrieved from a remote node, node=" + boosterNode);
            FetchBoosterTask t = new FetchBoosterTask(_modelKey);
            boosterBytes = new RPC<>(boosterNode, t).call().get()._boosterBytes;
        }
        return boosterBytes;
    }

    private static class FetchBoosterTask extends DTask {
        private final Key _modelKey;

        // OUT
        private byte[] _boosterBytes;

        private FetchBoosterTask(Key modelKey) {
            _modelKey = modelKey;
        }

        @Override
        public void compute2() {
            _boosterBytes = XGBoostUpdater.getUpdater(_modelKey).getBoosterBytes();
            tryComplete();
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy