All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.xgboost4j.java.BoosterWrapper Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package ai.h2o.xgboost4j.java;

import hex.tree.xgboost.util.BoosterHelper;

import java.util.Map;

/**
 * Thin delegating wrapper around a single {@link Booster} instance.
 *
 * <p>The class exists to expose package private methods: it is declared alongside
 * {@code Booster} and re-exports the operations below through public methods.
 */
public class BoosterWrapper {

    /** The wrapped booster; assigned exactly once in the constructor. */
    private final Booster booster;

    /**
     * Creates the wrapped booster, either restored from serialized bytes or built from scratch.
     *
     * <p>If {@code checkpointBoosterBytes} is non-null, the booster is loaded from those bytes
     * via {@code BoosterHelper.loadModel} and {@code params} are then applied with
     * {@code setParams}; {@code train} and {@code valid} are not used on this path.
     * Otherwise a new booster is created from {@code params} with {@code train} (plus
     * {@code valid}, when non-null) as its cache matrices, followed by
     * {@code loadRabitCheckpoint()}. On both paths {@code saveRabitCheckpoint()} is called
     * before the constructor returns.
     *
     * @param checkpointBoosterBytes serialized booster to start from, or {@code null} to
     *                               create a new booster
     * @param params                 booster parameters, passed through unchanged
     * @param train                  training matrix used as a cache matrix for a new booster
     * @param valid                  optional validation matrix; may be {@code null}
     * @throws XGBoostError propagated from the underlying booster calls
     */
    public BoosterWrapper(
        byte[] checkpointBoosterBytes,
        Map params,
        DMatrix train,
        DMatrix valid
    ) throws XGBoostError {
        if (checkpointBoosterBytes == null) {
            // No checkpoint supplied: build a new booster over the provided matrices.
            final DMatrix[] cacheMats;
            if (valid == null) {
                cacheMats = new DMatrix[]{train};
            } else {
                cacheMats = new DMatrix[]{train, valid};
            }
            booster = Booster.newBooster(params, cacheMats);
            booster.loadRabitCheckpoint();
        } else {
            // Checkpoint supplied: restore the booster first, then apply the parameters.
            booster = BoosterHelper.loadModel(checkpointBoosterBytes);
            booster.setParams(params);
        }
        // Both paths finish by saving a Rabit checkpoint.
        booster.saveRabitCheckpoint();
    }

    /**
     * Delegates to {@code Booster.update} for the given matrix and iteration.
     *
     * @param dtrain training matrix
     * @param iter   iteration number handed to the booster
     * @throws XGBoostError propagated from the booster
     */
    public void update(DMatrix dtrain, int iter) throws XGBoostError {
        booster.update(dtrain, iter);
    }

    /**
     * Evaluates the booster on the training matrix and, when present, the validation matrix.
     *
     * <p>The matrices are labelled {@code "train"} and {@code "valid"} respectively; only the
     * {@code "train"} entry is passed when {@code dvalid} is {@code null}.
     *
     * @param dtrain training matrix
     * @param dvalid validation matrix, or {@code null} to evaluate on training data only
     * @param iter   iteration number handed to the booster
     * @return the string returned by {@code Booster.evalSet}
     * @throws XGBoostError propagated from the booster
     */
    public String evalSet(DMatrix dtrain, DMatrix dvalid, int iter) throws XGBoostError {
        final DMatrix[] matrices;
        final String[] names;
        if (dvalid != null) {
            matrices = new DMatrix[]{dtrain, dvalid};
            names = new String[]{"train", "valid"};
        } else {
            matrices = new DMatrix[]{dtrain};
            names = new String[]{"train"};
        }
        return booster.evalSet(matrices, names, iter);
    }

    /**
     * Delegates to {@code Booster.saveRabitCheckpoint()}.
     *
     * @throws XGBoostError propagated from the booster
     */
    public void saveRabitCheckpoint() throws XGBoostError {
        booster.saveRabitCheckpoint();
    }

    /**
     * Returns the result of {@code Booster.toByteArray()} for the wrapped booster.
     *
     * @return the booster's byte-array form
     * @throws XGBoostError propagated from the booster
     */
    public byte[] toByteArray() throws XGBoostError {
        return booster.toByteArray();
    }

    /**
     * Delegates to {@code Booster.dispose()}.
     *
     * <p>NOTE(review): presumably the wrapper must not be used afterwards; confirm against
     * {@code Booster}.
     */
    public void dispose() {
        booster.dispose();
    }

    /**
     * Returns the wrapped booster itself (the same instance, not a copy).
     *
     * @return the underlying {@link Booster}
     */
    public Booster getBooster() {
        return booster;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy