All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.XGBoostMojoWriter Maven / Gradle / Ivy

package hex.tree.xgboost;

import hex.Model;
import hex.ModelMojoWriter;
import hex.glm.GLMModel;
import hex.isotonic.IsotonicRegressionModel;
import hex.tree.CalibrationHelper;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * MOJO support for XGBoost model.
 */
public class XGBoostMojoWriter extends ModelMojoWriter {

  @SuppressWarnings("unused")  // Called through reflection in ModelBuildersHandler
  public XGBoostMojoWriter() {}

  public XGBoostMojoWriter(XGBoostModel model) {
    super(model);
  }

  @Override public String mojoVersion() {
    return "1.10";
  }

  @Override
  protected void writeModelData() throws IOException {
    writeblob("boosterBytes", this.model.model_info()._boosterBytes);
    byte[] auxNodeWeightBytes = this.model.model_info().auxNodeWeightBytes();
    if (auxNodeWeightBytes != null) {
      writeblob("auxNodeWeights", auxNodeWeightBytes);
    }
    writekv("nums", model._output._nums);
    writekv("cats", model._output._cats);
    writekv("cat_offsets", model._output._catOffsets);
    writekv("use_all_factor_levels", model._output._useAllFactorLevels);
    writekv("sparse", model._output._sparse);
    writekv("booster", model._parms._booster.toString());
    writekv("ntrees", model._output._ntrees);
    writeblob("feature_map", model.model_info().getFeatureMap().getBytes(Charset.forName("UTF-8")));
    writekv("use_java_scoring_by_default", true);
    if (model._output.isCalibrated()) {
      final CalibrationHelper.CalibrationMethod calibMethod = model._output.getCalibrationMethod();
      final Model calibModel = model._output.calibrationModel();
      writekv("calib_method", calibMethod.getId());
      switch (calibMethod) {
        case PlattScaling:
          double[] beta = ((GLMModel) calibModel).beta();
          assert beta.length == model._output.nclasses(); // n-1 coefficients + 1 intercept
          writekv("calib_glm_beta", beta);
          break;
        case IsotonicRegression:
          IsotonicRegressionModel isotonic = (IsotonicRegressionModel) calibModel;
          write(isotonic.toIsotonicCalibrator());
          break;
        default:
          throw new UnsupportedOperationException("MOJO is not (yet) support for calibration model " + calibMethod);
      }
    }
    writekv("has_offset", model._output.hasOffset());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy