hex.tree.SharedTreeMojoWriter Maven / Gradle / Ivy
package hex.tree;
import hex.Model;
import hex.ModelMojoWriter;
import hex.glm.GLMModel;
import hex.isotonic.IsotonicRegressionModel;
import water.DKV;
import water.Key;
import water.Value;
import java.io.IOException;
/**
* Shared Mojo definition file for DRF and GBM models.
*/
public abstract class SharedTreeMojoWriter<
M extends SharedTreeModel,
P extends SharedTreeModel.SharedTreeParameters,
O extends SharedTreeModel.SharedTreeOutput
> extends ModelMojoWriter {
public SharedTreeMojoWriter() {}
public SharedTreeMojoWriter(M model) {
super(model);
}
@Override
protected void writeModelData() throws IOException {
assert model._output._treeKeys.length == model._output._ntrees;
int nclasses = model._output.nclasses();
int ntreesPerClass = model.binomialOpt() && nclasses == 2 ? 1 : nclasses;
writekv("n_trees", model._output._ntrees);
writekv("n_trees_per_class", ntreesPerClass);
if (model._output.isCalibrated()) {
final CalibrationHelper.CalibrationMethod calibMethod = model._output.getCalibrationMethod();
final Model, ?, ?> calibModel = model._output.calibrationModel();
writekv("calib_method", calibMethod.getId());
switch (calibMethod) {
case PlattScaling:
double[] beta = ((GLMModel) calibModel).beta();
assert beta.length == nclasses; // n-1 coefficients + 1 intercept
writekv("calib_glm_beta", beta);
break;
case IsotonicRegression:
IsotonicRegressionModel isotonic = (IsotonicRegressionModel) calibModel;
write(isotonic.toIsotonicCalibrator());
break;
default:
throw new UnsupportedOperationException("MOJO is not (yet) support for calibration model " + calibMethod);
}
}
writekv("_genmodel_encoding", model.getGenModelEncoding());
String[] origNames = model._output._origNames;
if (origNames != null) {
int nOrigNames = origNames.length;
writekv("_n_orig_names", nOrigNames);
writeStringArray(origNames, "_orig_names");
}
if (model._output._origDomains != null) {
int nOrigDomainValues = model._output._origDomains.length;
writekv("_n_orig_domain_values", nOrigDomainValues);
for (int i=0; i < nOrigDomainValues; i++) {
String[] currOrigDomain = model._output._origDomains[i];
writekv("_m_orig_domain_values_" + i, currOrigDomain == null ? 0 : currOrigDomain.length);
if (currOrigDomain != null) {
writeStringArray(currOrigDomain, "_orig_domain_values_" + i);
}
}
}
writekv("_orig_projection_array", model._output._orig_projection_array);
for (int i = 0; i < model._output._ntrees; i++) {
for (int j = 0; j < ntreesPerClass; j++) {
Key key = model._output._treeKeys[i][j];
Value ctVal = key != null ? DKV.get(key) : null;
if (ctVal == null)
continue; //throw new H2OKeyNotFoundArgumentException("CompressedTree " + key + " not found");
CompressedTree ct = ctVal.get();
// assume ct._seed is useless and need not be persisted
writeblob(String.format("trees/t%02d_%03d.bin", j, i), ct._bits);
if (model._output._treeKeysAux!=null) {
key = model._output._treeKeysAux[i][j];
ctVal = key != null ? DKV.get(key) : null;
if (ctVal != null) {
ct = ctVal.get();
writeblob(String.format("trees/t%02d_%03d_aux.bin", j, i), ct._bits);
}
}
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy