All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.SharedTreeMojoWriter Maven / Gradle / Ivy

package hex.tree;

import hex.ModelMojoWriter;
import hex.glm.GLMModel;
import water.DKV;
import water.Key;
import water.Value;
import water.exceptions.H2OKeyNotFoundArgumentException;

import java.io.IOException;
import java.util.Locale;

/**
 * Shared Mojo definition file for DRF and GBM models.
 */
/**
 * Shared MOJO writer for tree-based models (DRF and GBM).
 *
 * <p>Writes the tree-ensemble metadata, optional Platt-scaling calibration coefficients, and one
 * binary blob per stored {@link CompressedTree}, plus an optional auxiliary blob per tree.
 *
 * @param <M> concrete tree model type
 * @param <P> parameters type of the model
 * @param <O> output type of the model
 */
public abstract class SharedTreeMojoWriter<
      M extends SharedTreeModel<M, P, O>,
      P extends SharedTreeModel.SharedTreeParameters,
      O extends SharedTreeModel.SharedTreeOutput
    > extends ModelMojoWriter<M, P, O> {

  /** No-argument constructor; does not set a model. */
  public SharedTreeMojoWriter() {}

  /**
   * Creates a writer for the given model.
   *
   * @param model the tree model to serialize
   */
  public SharedTreeMojoWriter(M model) {
    super(model);
  }

  /**
   * Writes the tree-specific part of the MOJO.
   *
   * <p>Emits the key/value entries {@code n_trees} and {@code n_trees_per_class}. When a
   * calibration model is present, it also emits {@code calib_method = "platt"} and
   * {@code calib_glm_beta}. For every stored tree it writes the blob
   * {@code trees/t<class:02d>_<tree:03d>.bin}. When auxiliary tree keys exist, it also writes the
   * matching {@code trees/t<class:02d>_<tree:03d>_aux.bin}. Entry names are always formatted with
   * ASCII digits, regardless of the JVM default locale.
   *
   * @throws IOException if writing to the MOJO output fails
   */
  @Override
  protected void writeModelData() throws IOException {
    assert model._output._treeKeys.length == model._output._ntrees;
    int nclasses = model._output.nclasses();
    // With the binomial optimization only one tree per iteration is stored for 2-class problems.
    int ntreesPerClass = model.binomialOpt() && nclasses == 2 ? 1 : nclasses;
    writekv("n_trees", model._output._ntrees);
    writekv("n_trees_per_class", ntreesPerClass);
    if (model._output._calib_model != null) {
      GLMModel calibModel = model._output._calib_model;
      double[] beta = calibModel.beta();
      assert beta.length == nclasses; // n-1 coefficients + 1 intercept
      writekv("calib_method", "platt");
      writekv("calib_glm_beta", beta);
    }
    for (int i = 0; i < model._output._ntrees; i++) {
      for (int j = 0; j < ntreesPerClass; j++) {
        CompressedTree ct = lookupTree(model._output._treeKeys[i][j]);
        if (ct == null) {
          // A (tree, class) slot with no stored tree is deliberately skipped rather than
          // reported as an error.
          continue;
        }
        assert ct._nclass == nclasses;
        // assume ct._seed is useless and need not be persisted
        // Locale.ROOT keeps the digits in the entry name ASCII under every default locale.
        writeblob(String.format(Locale.ROOT, "trees/t%02d_%03d.bin", j, i), ct._bits);

        if (model._output._treeKeysAux != null) {
          CompressedTree aux = lookupTree(model._output._treeKeysAux[i][j]);
          if (aux != null) {
            assert aux._nclass == -1;
            writeblob(String.format(Locale.ROOT, "trees/t%02d_%03d_aux.bin", j, i), aux._bits);
          }
        }
      }
    }
  }

  /**
   * Fetches a compressed tree from the DKV.
   *
   * @param key tree key; may be {@code null}
   * @return the stored tree, or {@code null} if {@code key} is {@code null} or has no DKV value
   */
  private static CompressedTree lookupTree(Key<?> key) {
    if (key == null) {
      return null;
    }
    Value val = DKV.get(key);
    if (val == null) {
      return null;
    }
    return val.get();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy