package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.tree.RegTreeNodeStat;
import hex.*;
import hex.genmodel.algos.tree.*;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.FriedmanPopescusH;
import hex.tree.CalibrationHelper;
import hex.tree.xgboost.predict.*;
import hex.tree.xgboost.util.PredictConfiguration;
import hex.util.EffectiveParametersUtils;
import org.apache.log4j.Logger;
import water.*;
import water.codegen.CodeGeneratorPipeline;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.*;

import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static hex.genmodel.algos.xgboost.XGBoostMojoModel.ObjectiveType;
import static hex.tree.xgboost.XGBoost.makeDataInfo;
import static hex.tree.xgboost.util.GpuUtils.hasGPU;
import static water.H2O.OptArgs.SYSTEM_PROP_PREFIX;

public class XGBoostModel extends Model<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>
        implements SharedTreeGraphConverter, Model.LeafNodeAssignment, Model.Contributions, FeatureInteractionsCollector, Model.UpdateAuxTreeWeights, FriedmanPopescusHCollector {

  private static final Logger LOG = Logger.getLogger(XGBoostModel.class);

  private static final String PROP_VERBOSITY = SYSTEM_PROP_PREFIX + "xgboost.verbosity";
  private static final String PROP_NTHREAD = SYSTEM_PROP_PREFIX + "xgboost.nthreadMax";

  private XGBoostModelInfo model_info;

  public XGBoostModelInfo model_info() { return model_info; }

  public static class XGBoostParameters extends Model.Parameters implements Model.GetNTrees, CalibrationHelper.ParamsWithCalibration {
    public enum TreeMethod {
      auto, exact, approx, hist
    }
    public enum GrowPolicy {
      depthwise, lossguide
    }
    public enum Booster {
      gbtree, gblinear, dart
    }
    public enum DartSampleType {
      uniform, weighted
    }
    public enum DartNormalizeType {
      tree, forest
    }
    public enum DMatrixType {
      auto, dense, sparse
    }
    public enum Backend {
      auto, gpu, cpu
    }
    public enum FeatureSelector {
      cyclic, shuffle, random, greedy, thrifty
    }
    public enum Updater {
      gpu_hist, shotgun, coord_descent, gpu_coord_descent,
    }

    // H2O GBM options
    public boolean _quiet_mode = true;

    public int _ntrees = 50; // Number of trees in the final model. Grid Search, comma sep values:50,100,150,200
    /**
     * @deprecated will be removed in 3.30.0.1, use _ntrees
     */
    public int _n_estimators; // This doesn't seem to be used anywhere... (not in clients)

    public int _max_depth = 6; // Maximum tree depth. Grid Search, comma sep values:5,7

    public double _min_rows = 1;
    public double _min_child_weight = 1;

    public double _learn_rate = 0.3;
    public double _eta = 0.3;

    public double _learn_rate_annealing = 1;

    public double _sample_rate = 1.0;
    public double _subsample = 1.0;

    public double _col_sample_rate = 1.0;
    public double _colsample_bylevel = 1.0;
    public double _colsample_bynode = 1.0;

    public double _col_sample_rate_per_tree = 1.0; //fraction of columns to sample for each tree
    public double _colsample_bytree = 1.0;

    public KeyValue[] _monotone_constraints;
    public String[][] _interaction_constraints;

    public float _max_abs_leafnode_pred = 0;
    public float _max_delta_step = 0;

    public int _score_tree_interval = 0; // score every so many trees (no matter what)
    public int _initial_score_interval = 4000; // replaces the formerly hard-coded 4000 ms interval used for scoring during the first 4 seconds
    public int _score_interval = 4000; // replaces the formerly hard-coded 4000 ms interval between scoring iterations
    public float _min_split_improvement = 0;
    public float _gamma;

    // Runtime options
    public int _nthread = -1;
    public String _save_matrix_directory; // dump the xgboost matrix to this directory
    public boolean _build_tree_one_node = false; // force to run on single node

    // LightGBM specific (only for grow_policy == lossguide)
    public int _max_bins = 256;
    public int _max_leaves = 0;

    // XGBoost specific options
    public TreeMethod _tree_method = TreeMethod.auto;
    public GrowPolicy _grow_policy = GrowPolicy.depthwise;
    public Booster _booster = Booster.gbtree;
    public DMatrixType _dmatrix_type = DMatrixType.auto;
    public float _reg_lambda = 1;
    public float _reg_alpha = 0;
    public float _scale_pos_weight = 1;

    // Platt scaling (by default)
    public boolean _calibrate_model;
    public Key<Frame> _calibration_frame;
    public CalibrationHelper.CalibrationMethod _calibration_method = CalibrationHelper.CalibrationMethod.AUTO;

    // Dart specific (booster == dart)
    public DartSampleType _sample_type = DartSampleType.uniform;
    public DartNormalizeType _normalize_type = DartNormalizeType.tree;
    public float _rate_drop = 0;
    public boolean _one_drop = false;
    public float _skip_drop = 0;
    public int[] _gpu_id; // which GPU to use
    public Backend _backend = Backend.auto;

    // GBLinear specific (booster == gblinear)
    // lambda and alpha are also supported for gbtree
    public FeatureSelector _feature_selector = FeatureSelector.cyclic;
    public int _top_k;
    public Updater _updater;

    public String _eval_metric;
    public boolean _score_eval_metric_only;

    public String algoName() { return "XGBoost"; }
    public String fullName() { return "XGBoost"; }
    public String javaName() { return XGBoostModel.class.getName(); }

    @Override
    public long progressUnits() {
      return _ntrees;
    }

    /**
     * Finds parameter settings that are not supported on the GPU backend.
     * If any are present, the CPU backend should be used instead of the GPU.
     * @return map of parameter name -> parameter value
     */
    Map<String, Object> gpuIncompatibleParams() {
      Map<String, Object> incompat = new HashMap<>();
      if (!(TreeMethod.auto == _tree_method || TreeMethod.hist == _tree_method) && Booster.gblinear != _booster) {
        incompat.put("tree_method", "Only auto and hist are supported tree_method on GPU backend.");
      }
      if (_max_depth > 15 || _max_depth < 1) {
        incompat.put("max_depth", _max_depth + ". Max depth must be greater than 0 and lower than 16 for GPU backend.");
      }
      if (_grow_policy == GrowPolicy.lossguide)
        incompat.put("grow_policy", GrowPolicy.lossguide); // See PUBDEV-5302 (param.grow_policy != TrainParam::kLossGuide Loss guided growth policy not supported. Use CPU algorithm.)
      return incompat;
    }
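
    /*
     * Illustrative note (not in the original source): with _max_depth = 20 and
     * _grow_policy = GrowPolicy.lossguide, gpuIncompatibleParams() returns two
     * entries, and getActualBackend(...) below logs them as the reason for
     * falling back to Backend.cpu.
     */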

    Map<String, Integer> monotoneConstraints() {
      if (_monotone_constraints == null || _monotone_constraints.length == 0) {
        return Collections.emptyMap();
      }
      Map<String, Integer> constraints = new HashMap<>(_monotone_constraints.length);
      for (KeyValue constraint : _monotone_constraints) {
        final double val = constraint.getValue();
        if (val == 0) {
          continue;
        }
        if (constraints.containsKey(constraint.getKey())) {
          throw new IllegalStateException("Duplicate definition of constraint for feature '" + constraint.getKey() + "'.");
        }
        final int direction = val < 0 ? -1 : 1;
        constraints.put(constraint.getKey(), direction);
      }
      return constraints;
    }
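
    /*
     * Sketch of typical usage (feature names and values are hypothetical):
     *
     *   XGBoostParameters p = new XGBoostParameters();
     *   p._monotone_constraints = new KeyValue[]{
     *       new KeyValue("age", 1),     // monotone increasing
     *       new KeyValue("price", -0.5) // any negative value maps to direction -1
     *   };
     *   Map<String, Integer> dirs = p.monotoneConstraints(); // {age=1, price=-1}
     */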

    @Override
    public int getNTrees() {
      return _ntrees;
    }

    @Override
    public Frame getCalibrationFrame() {
      return _calibration_frame != null ? _calibration_frame.get() : null;
    }

    @Override
    public boolean calibrateModel() {
      return _calibrate_model;
    }

    @Override
    public CalibrationHelper.CalibrationMethod getCalibrationMethod() {
      return _calibration_method;
    }

    @Override
    public void setCalibrationMethod(CalibrationHelper.CalibrationMethod calibrationMethod) {
      _calibration_method = calibrationMethod;
    }

    @Override
    public Parameters getParams() {
      return this;
    }

    static String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = { 
        "_tree_method", "_grow_policy", "_booster", "_sample_rate", "_max_depth", "_min_rows" 
    };

  }

  @Override
  public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    switch(_output.getModelCategory()) {
      case Binomial:    return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type);
      case Regression:  return new ModelMetricsRegression.MetricBuilderRegression();
      default: throw H2O.unimpl();
    }
  }

  public XGBoostModel(Key<XGBoostModel> selfKey, XGBoostParameters parms, XGBoostOutput output, Frame train, Frame valid) {
    super(selfKey, parms, output);
    final DataInfo dinfo = makeDataInfo(train, valid, _parms);
    DKV.put(dinfo);
    setDataInfoToOutput(dinfo);
    model_info = new XGBoostModelInfo(parms, dinfo);
  }

  @Override
  public void initActualParamValues() {
    super.initActualParamValues();
    EffectiveParametersUtils.initFoldAssignment(_parms);
    _parms._backend = getActualBackend(_parms, true);
    _parms._tree_method = getActualTreeMethod(_parms);
    EffectiveParametersUtils.initCalibrationMethod(_parms);
  }

  public static XGBoostParameters.TreeMethod getActualTreeMethod(XGBoostParameters p) {
    // tree_method parameter is evaluated according to:
    // https://github.com/h2oai/xgboost/blob/96f61fb3be8c4fa0e160dd6e82677dfd96a5a9a1/src/gbm/gbtree.cc#L127 
    // + we don't use external-memory data matrix feature in h2o 
    // + https://github.com/h2oai/h2o-3/blob/b68e544d8dac3c5c0ed16759e6bf7e8288573ab5/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoostModel.java#L348
    if (p._tree_method == XGBoostParameters.TreeMethod.auto) {
      if (p._backend == XGBoostParameters.Backend.gpu) {
        return XGBoostParameters.TreeMethod.hist;
      } else if (H2O.getCloudSize() > 1) {
        if (p._monotone_constraints != null && p._booster != XGBoostParameters.Booster.gblinear && p._backend != XGBoostParameters.Backend.gpu) {
          return XGBoostParameters.TreeMethod.hist;
        } else {
          return XGBoostParameters.TreeMethod.approx;
        }
      } else if (p.train() != null && p.train().numRows() >= (4 << 20)) {
        return XGBoostParameters.TreeMethod.approx;
      } else {
        return XGBoostParameters.TreeMethod.exact;
      }
    } else {
      return p._tree_method;
    }
  }
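
  /*
   * Worked resolution of tree_method == auto (cases follow the code above):
   *   - GPU backend                                       -> hist
   *   - multi-node + monotone constraints + gbtree/dart   -> hist
   *   - multi-node otherwise                              -> approx
   *   - single node, >= 4 << 20 (~4.2M) training rows     -> approx
   *   - single node, smaller data                         -> exact
   */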

  public void initActualParamValuesAfterOutputSetup(boolean isClassifier, int nclasses) {
    EffectiveParametersUtils.initStoppingMetric(_parms, isClassifier);
    EffectiveParametersUtils.initCategoricalEncoding(_parms, Parameters.CategoricalEncodingScheme.OneHotInternal);
    EffectiveParametersUtils.initDistribution(_parms, nclasses);
    _parms._dmatrix_type = _output._sparse ? XGBoostModel.XGBoostParameters.DMatrixType.sparse : XGBoostModel.XGBoostParameters.DMatrixType.dense;
  }
  
  public static XGBoostParameters.Backend getActualBackend(XGBoostParameters p, boolean verbose) {
    Consumer<String> log = verbose ? LOG::info : LOG::debug;
    if (p._backend == XGBoostParameters.Backend.auto || p._backend == XGBoostParameters.Backend.gpu) {
      if (H2O.getCloudSize() > 1 && !p._build_tree_one_node && !XGBoost.allowMultiGPU()) {
        log.accept("GPU backend not supported in distributed mode. Using CPU backend.");
        return XGBoostParameters.Backend.cpu;
      } else if (! p.gpuIncompatibleParams().isEmpty()) {
        log.accept("GPU backend not supported for the choice of parameters (" + p.gpuIncompatibleParams() + "). Using CPU backend.");
        return XGBoostParameters.Backend.cpu;
      } else if (hasGPU(H2O.CLOUD.members()[0], p._gpu_id)) {
        log.accept("Using GPU backend (gpu_id: " + Arrays.toString(p._gpu_id) + ").");
        return XGBoostParameters.Backend.gpu;
      } else {
        log.accept("No GPU (gpu_id: " + Arrays.toString(p._gpu_id) + ") found. Using CPU backend.");
        return XGBoostParameters.Backend.cpu;
      }
    } else {
      log.accept("Using CPU backend.");
      return XGBoostParameters.Backend.cpu;
    }
  }
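
  /*
   * Minimal usage sketch (assumes a running H2O cloud; gpu_id left at default):
   *
   *   XGBoostParameters p = new XGBoostParameters();
   *   p._backend = XGBoostParameters.Backend.auto;
   *   XGBoostParameters.Backend b = XGBoostModel.getActualBackend(p, true);
   *   // b == Backend.gpu only with GPU-compatible parameters, a visible GPU,
   *   // and a single node (or multi-GPU explicitly allowed); otherwise Backend.cpu.
   */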
  
  public static Map<String, Object> createParamsMap(XGBoostParameters p, int nClasses, String[] coefNames) {
    Map<String, Object> params = new HashMap<>();

    // Common parameters with H2O GBM
    if (p._n_estimators != 0) {
      LOG.info("Using user-provided parameter n_estimators instead of ntrees.");
      params.put("nround", p._n_estimators);
      p._ntrees = p._n_estimators;
    } else {
      params.put("nround", p._ntrees);
      p._n_estimators = p._ntrees;
    }
    if (p._eta != 0.3) {
      params.put("eta", p._eta);
      p._learn_rate = p._eta;
    } else {
      params.put("eta", p._learn_rate);
      p._eta = p._learn_rate;
    }
    params.put("max_depth", p._max_depth);
    if (System.getProperty(PROP_VERBOSITY) != null) {
      params.put("verbosity", System.getProperty(PROP_VERBOSITY));
    } else {
      params.put("silent", p._quiet_mode);
    }
    if (p._subsample != 1.0) {
      params.put("subsample", p._subsample);
      p._sample_rate = p._subsample;
    } else {
      params.put("subsample", p._sample_rate);
      p._subsample = p._sample_rate;
    }
    if (p._colsample_bytree != 1.0) {
      params.put("colsample_bytree", p._colsample_bytree);
      p._col_sample_rate_per_tree = p._colsample_bytree;
    } else {
      params.put("colsample_bytree", p._col_sample_rate_per_tree);
      p._colsample_bytree = p._col_sample_rate_per_tree;
    }
    if (p._colsample_bylevel != 1.0) {
      params.put("colsample_bylevel", p._colsample_bylevel);
      p._col_sample_rate = p._colsample_bylevel;
    } else {
      params.put("colsample_bylevel", p._col_sample_rate);
      p._colsample_bylevel = p._col_sample_rate;
    }
    if (p._colsample_bynode != 1.0) {
      params.put("colsample_bynode", p._colsample_bynode);
    }    
    if (p._max_delta_step != 0) {
      params.put("max_delta_step", p._max_delta_step);
      p._max_abs_leafnode_pred = p._max_delta_step;
    } else {
      params.put("max_delta_step", p._max_abs_leafnode_pred);
      p._max_delta_step = p._max_abs_leafnode_pred;
    }
    params.put("seed", (int)(p._seed % Integer.MAX_VALUE));

    // XGBoost specific options
    params.put("grow_policy", p._grow_policy.toString());
    if (p._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
      params.put("max_bin", p._max_bins);
      params.put("max_leaves", p._max_leaves);
    }
    params.put("booster", p._booster.toString());
    if (p._booster == XGBoostParameters.Booster.dart) {
      params.put("sample_type", p._sample_type.toString());
      params.put("normalize_type", p._normalize_type.toString());
      params.put("rate_drop", p._rate_drop);
      params.put("one_drop", p._one_drop ? "1" : "0");
      params.put("skip_drop", p._skip_drop);
    }
    if (p._booster == XGBoostParameters.Booster.gblinear) {
      params.put("feature_selector", p._feature_selector.toString());
      params.put("top_k", p._top_k);
    }
    XGBoostParameters.Backend actualBackend = getActualBackend(p, true);
    XGBoostParameters.TreeMethod actualTreeMethod = getActualTreeMethod(p);
    if (actualBackend == XGBoostParameters.Backend.gpu) {
      if (p._gpu_id != null && p._gpu_id.length > 0) {
        params.put("gpu_id", p._gpu_id[0]);
      } else {
        params.put("gpu_id", 0);
      }
      // we are setting updater rather than tree_method here to keep CPU predictor, which is faster
      if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
        LOG.info("Using gpu_coord_descent updater."); 
        params.put("updater", XGBoostParameters.Updater.gpu_coord_descent.toString());
      } else {
        LOG.info("Using gpu_hist tree method.");
        params.put("max_bin", p._max_bins);
        params.put("tree_method", XGBoostParameters.Updater.gpu_hist.toString());
      }
    } else if (p._booster == XGBoostParameters.Booster.gblinear && p._updater == null) {
      LOG.info("Using coord_descent updater.");
      params.put("updater", XGBoostParameters.Updater.coord_descent.toString());
    } else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto &&
        p._monotone_constraints != null) {
      LOG.info("Using hist tree method for distributed computation with monotone_constraints.");
      params.put("tree_method", actualTreeMethod.toString());
      params.put("max_bin", p._max_bins);
    } else {
      LOG.info("Using " + p._tree_method.toString() + " tree method.");
      params.put("tree_method", actualTreeMethod.toString());
      if (p._tree_method == XGBoostParameters.TreeMethod.hist) {
        params.put("max_bin", p._max_bins);
      }
    }
    if (p._updater != null) {
      LOG.info("Using user-provided updater.");
      params.put("updater", p._updater.toString());
    }
    if (p._min_child_weight != 1) {
      LOG.info("Using user-provided parameter min_child_weight instead of min_rows.");
      params.put("min_child_weight", p._min_child_weight);
      p._min_rows = p._min_child_weight;
    } else {
      params.put("min_child_weight", p._min_rows);
      p._min_child_weight = p._min_rows;
    }
    if (p._gamma != 0) {
      LOG.info("Using user-provided parameter gamma instead of min_split_improvement.");
      params.put("gamma", p._gamma);
      p._min_split_improvement = p._gamma;
    } else {
      params.put("gamma", p._min_split_improvement);
      p._gamma = p._min_split_improvement;
    }

    params.put("lambda", p._reg_lambda);
    params.put("alpha", p._reg_alpha);
    if (p._scale_pos_weight != 1)
      params.put("scale_pos_weight", p._scale_pos_weight);

    // objective function
    if (nClasses == 2) {
      params.put("objective", ObjectiveType.BINARY_LOGISTIC.getId());
    } else if (nClasses == 1) {
      if (p._distribution == DistributionFamily.gamma) {
        params.put("objective", ObjectiveType.REG_GAMMA.getId());
      } else if (p._distribution == DistributionFamily.tweedie) {
        params.put("objective", ObjectiveType.REG_TWEEDIE.getId());
        params.put("tweedie_variance_power", p._tweedie_power);
      } else if (p._distribution == DistributionFamily.poisson) {
        params.put("objective", ObjectiveType.COUNT_POISSON.getId());
      } else if (p._distribution == DistributionFamily.gaussian || p._distribution == DistributionFamily.AUTO) {
        params.put("objective", ObjectiveType.REG_SQUAREDERROR.getId());
      } else {
        throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
      }
    } else {
      params.put("objective", ObjectiveType.MULTI_SOFTPROB.getId());
      params.put("num_class", nClasses);
    }
    assert ObjectiveType.fromXGBoost((String) params.get("objective")) != null;

    // evaluation metric
    if (p._eval_metric != null) {
      params.put("eval_metric", p._eval_metric);
    }

    final int nthreadMax = getMaxNThread();
    final int nthread = p._nthread != -1 ? Math.min(p._nthread, nthreadMax) : nthreadMax;
    if (nthread < p._nthread) {
      LOG.warn("Requested nthread=" + p._nthread + " but the cluster has only " + nthreadMax + " available. " +
              "Training will use nthread=" + nthread + " instead of the user-specified value.");
    }
    params.put("nthread", nthread);

    Map<String, Integer> monotoneConstraints = p.monotoneConstraints();
    if (! monotoneConstraints.isEmpty()) {
      int constraintsUsed = 0;
      StringBuilder sb = new StringBuilder();
      sb.append("(");
      for (String coef : coefNames) {
        final String direction;
        if (monotoneConstraints.containsKey(coef)) {
          direction = monotoneConstraints.get(coef).toString();
          constraintsUsed++;
        } else {
          direction = "0";
        }
        sb.append(direction);
        sb.append(",");
      }
      sb.replace(sb.length()-1, sb.length(), ")");
      params.put("monotone_constraints", sb.toString());
      assert constraintsUsed == monotoneConstraints.size();
    }
    
    String[][] interactionConstraints = p._interaction_constraints;
    if (interactionConstraints != null && interactionConstraints.length > 0) {
      if (!p._categorical_encoding.equals(Parameters.CategoricalEncodingScheme.OneHotInternal)) {
        throw new IllegalArgumentException("Interaction constraints are not supported for categorical encoding = " + p._categorical_encoding + ". Interaction constraints are available only for ``AUTO`` (``one_hot_internal`` or ``OneHotInternal``) categorical encoding.");
      }
      params.put("interaction_constraints", createInteractions(interactionConstraints, coefNames, p));
    }
    
    LOG.info("XGBoost Parameters:");
    for (Map.Entry<String, Object> s : params.entrySet()) {
      LOG.info(" " + s.getKey() + " = " + s.getValue());
    }
    LOG.info("");
    return Collections.unmodifiableMap(params);
  }
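
  /*
   * Illustrative call (coefficient names are hypothetical):
   *
   *   Map<String, Object> params =
   *       XGBoostModel.createParamsMap(p, 2, new String[]{"age", "price"});
   *   // For nClasses == 2 the objective is binary logistic, e.g.
   *   // params.get("objective") -> "binary:logistic"
   */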
  
  private static String createInteractions(String[][] interaction_constraints, String[] coefNames, XGBoostParameters params){
    StringBuilder sb = new StringBuilder();
    sb.append("[");
    for (String[] list : interaction_constraints) {
      sb.append("[");
      for (String item : list) {
        if (item.equals(params._response_column)) {
          throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as the response column and cannot be used in an interaction.");
        }
        if (item.equals(params._weights_column)) {
          throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as the weights column and cannot be used in an interaction.");
        }
        if (item.equals(params._fold_column)) {
          throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as the fold column and cannot be used in an interaction.");
        }
        if (params._ignored_columns != null && ArrayUtils.find(params._ignored_columns, item) != -1) {
          throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is set in ignored columns and cannot be used in an interaction.");
        }
        // look the name up among the coefficient names (exact match or one-hot encoded prefix)
        int start = ArrayUtils.findWithPrefix(coefNames, item);
        if (start == -1) {
          throw new IllegalArgumentException("'interaction_constraints': Column with name '" + item + "' is not in the frame.");
        } else if (start > -1) { // exact position found - no encoding
          sb.append(start).append(",");
        } else { // first occurrence of the name as a prefix - one-hot encoding
          start = -start - 2;
          assert coefNames[start].startsWith(item) : "The column name should be found correctly.";
          // append the indices of all encoded levels sharing the prefix
          int end = start;
          while (end < coefNames.length && coefNames[end].startsWith(item)) {
            sb.append(end).append(",");
            end++;
          }
        }
      }
      sb.replace(sb.length() - 1, sb.length(), "],");
    }
    sb.replace(sb.length() - 1, sb.length(), "]");
    return sb.toString();
  }
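
  /*
   * Sketch of the emitted index format (names hypothetical). With
   * coefNames = {"age", "color.red", "color.blue", "price"} and
   * interaction_constraints = {{"age", "price"}, {"color"}}, the result is
   * "[[0,3],[1,2]]" - the one-hot encoded levels of "color" expand into the
   * whole run of columns sharing that prefix.
   */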

  public static BoosterParms createParams(XGBoostParameters p, int nClasses, String[] coefNames) {
    return BoosterParms.fromMap(createParamsMap(p, nClasses, coefNames));
  }

  /** Performs deep clone of given model.  */
  protected XGBoostModel deepClone(Key<XGBoostModel> result) {
    XGBoostModel newModel = IcedUtils.deepCopy(this);
    newModel._key = result;
    // Do not clone model metrics
    newModel._output.clearModelMetrics(false);
    newModel._output._training_metrics = null;
    newModel._output._validation_metrics = null;
    return newModel;
  }
  
  static int getMaxNThread() {
    if (System.getProperty(PROP_NTHREAD) != null) {
      return Integer.getInteger(PROP_NTHREAD);
    } else {
      int maxNodesPerHost = 1;
      Set<String> checkedNodes = new HashSet<>();
      for (H2ONode node : H2O.CLOUD.members()) {
        String nodeHost = node.getIp();
        if (!checkedNodes.contains(nodeHost)) {
          checkedNodes.add(nodeHost);
          long cnt = Stream.of(H2O.CLOUD.members()).filter(h -> h.getIp().equals(nodeHost)).count();
          if (cnt > maxNodesPerHost) {
            maxNodesPerHost = (int) cnt;
          }
        }
      }
      return Math.max(1, H2O.ARGS.nthreads / maxNodesPerHost);
    }
  }
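
  /*
   * Worked example (assumed topology): a 4-node cloud where two JVMs share one
   * host and H2O.ARGS.nthreads == 16 gives maxNodesPerHost == 2, so XGBoost
   * gets Math.max(1, 16 / 2) == 8 threads, unless the
   * SYSTEM_PROP_PREFIX + "xgboost.nthreadMax" system property is set, which
   * wins outright.
   */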

  @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
    ab.putKey(model_info.getDataInfoKey());
    ab.putKey(model_info.getAuxNodeWeightsKey());
    return super.writeAll_impl(ab);
  }

  @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
    ab.getKey(model_info.getDataInfoKey(), fs);
    ab.getKey(model_info.getAuxNodeWeightsKey(), fs);
    return super.readAll_impl(ab, fs);
  }

  @Override
  public XGBoostMojoWriter getMojo() {
    return new XGBoostMojoWriter(this);
  }

  private ModelMetrics makeMetrics(Frame data, Frame originalData, boolean isTrain, String description) {
    LOG.debug("Making metrics: " + description);
    return new XGBoostModelMetrics(_output, data, originalData, isTrain, this, CFuncRef.from(_parms._custom_metric_func)).compute();
  }

  final void doScoring(Frame train, Frame trainOrig, CustomMetric trainCustomMetric,
                       Frame valid, Frame validOrig, CustomMetric validCustomMetric) {
    ModelMetrics mm = makeMetrics(train, trainOrig, true, "Metrics reported on training frame");
    _output._training_metrics = mm;
    if (trainCustomMetric == null) {
      _output._scored_train[_output._ntrees].fillFrom(mm, mm._custom_metric);
    } else {
      _output._scored_train[_output._ntrees].fillFrom(mm, trainCustomMetric);
    }
    addModelMetrics(mm);
    // Optional validation part
    if (valid != null) {
      mm = makeMetrics(valid, validOrig, false, "Metrics reported on validation frame");
      _output._validation_metrics = mm;
      if (validCustomMetric == null) {
        _output._scored_valid[_output._ntrees].fillFrom(mm, mm._custom_metric);
      } else {
        _output._scored_valid[_output._ntrees].fillFrom(mm, validCustomMetric);
      }
      addModelMetrics(mm);
    }
  }

  @Override
  protected Frame postProcessPredictions(Frame adaptedFrame, Frame predictFr, Job j) {
    return CalibrationHelper.postProcessPredictions(predictFr, j, _output);
  }

  @Override
  protected double[] score0(double[] data, double[] preds) {
    return score0(data, preds, 0.0);
  }

  @Override // per row scoring is slow and should be avoided!
  public double[] score0(final double[] data, final double[] preds, final double offset) {
    final DataInfo di = model_info.dataInfo();
    assert di != null;
    MutableOneHotEncoderFVec row = new MutableOneHotEncoderFVec(di, _output._sparse);
    row.setInput(data);
    Predictor predictor = makePredictor(true);
    float[] out;
    if (_output.hasOffset()) {
      out = predictor.predict(row, (float) offset);
    } else if (offset != 0) {
      throw new UnsupportedOperationException("Unsupported: offset != 0");
    } else {
      out = predictor.predict(row);
    }
    return XGBoostMojoModel.toPreds(data, out, preds, _output.nclasses(), _output._priorClassDist, defaultThreshold());
  }
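
  /*
   * Hypothetical single-row call (values must follow the adapted training
   * frame's column order):
   *
   *   double[] row = ...;  // one value per input column
   *   double[] preds = new double[model._output.nclasses() + 1];
   *   model.score0(row, preds); // preds[0] = predicted class/value, preds[1..] = class probabilities
   */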

  @Override
  protected XGBoostBigScorePredict setupBigScorePredict(BigScore bs) {
    return setupBigScorePredict(false);
  }

  public XGBoostBigScorePredict setupBigScorePredict(boolean isTrain) {
    DataInfo di = model_info().scoringInfo(isTrain); // outside of the training phase this always uses the validation scoring info
    return PredictConfiguration.useJavaScoring() ? setupBigScorePredictJava(di) : setupBigScorePredictNative(di);
  }

  private XGBoostBigScorePredict setupBigScorePredictNative(DataInfo di) {
    BoosterParms boosterParms = XGBoostModel.createParams(_parms, _output.nclasses(), di.coefNames());
    return new XGBoostNativeBigScorePredict(model_info, _parms, _output, di, boosterParms, defaultThreshold());
  }

  private XGBoostBigScorePredict setupBigScorePredictJava(DataInfo di) {
    return new XGBoostJavaBigScorePredict(model_info, _output, di, _parms, defaultThreshold());
  }
  
  public XGBoostVariableImportance setupVarImp() {
    if (PredictConfiguration.useJavaScoring()) {
      return new XGBoostJavaVariableImportance(model_info);
    } else {
      return new XGBoostNativeVariableImportance(_key, model_info.getFeatureMap());
    }
  }

  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
    return scoreContributions(frame, destination_key, null, new ContributionsOptions());
  }

  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, ContributionsOptions options) {
    Frame adaptFrm = new Frame(frame);
    adaptTestForTrain(adaptFrm, true, false);

    DataInfo di = model_info().dataInfo();
    assert di != null;
    final String[] featureContribNames = ContributionsOutputFormat.Compact.equals(options._outputFormat) ? 
            _output.features() : di.coefNames();
    final String[] outputNames = ArrayUtils.append(featureContribNames, "BiasTerm");

    if (options.isSortingRequired()) {
      final ContributionComposer contributionComposer = new ContributionComposer();
      int topNAdjusted = contributionComposer.checkAndAdjustInput(options._topN, featureContribNames.length);
      int bottomNAdjusted = contributionComposer.checkAndAdjustInput(options._bottomN, featureContribNames.length);

      int outputSize = Math.min((topNAdjusted+bottomNAdjusted)*2, featureContribNames.length*2);
      String[] names = new String[outputSize+1];
      byte[] types = new byte[outputSize+1];
      String[][] domains = new String[outputSize+1][outputNames.length];

      composeScoreContributionTaskMetadata(names, types, domains, featureContribNames, options);

      return new PredictTreeSHAPSortingTask(di, model_info(), _output, options)
              .withPostMapAction(JobUpdatePostMap.forJob(j))
              .doAll(types, adaptFrm)
              .outputFrame(destination_key, names, domains);
    }

    return new PredictTreeSHAPTask(di, model_info(), _output, options)
            .withPostMapAction(JobUpdatePostMap.forJob(j))
            .doAll(outputNames.length, Vec.T_NUM, adaptFrm)
            .outputFrame(destination_key, outputNames, null);
  }
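
  /*
   * Illustrative use (key name hypothetical): TreeSHAP contributions with one
   * column per feature plus "BiasTerm"; per-row contributions sum to the raw
   * (margin) prediction.
   *
   *   Frame contribs = model.scoreContributions(testFrame, Key.make("contribs"));
   */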


  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j, ContributionsOptions options, Frame backgroundFrame) {
    Log.info("Starting contributions calculation for " + this._key + "...");
    try (Scope.Safe s = Scope.safe(frame, backgroundFrame)) {
      Frame contributions;
      if (null == backgroundFrame) {
        contributions = scoreContributions(frame, destination_key, j, options);
      } else {
        Frame adaptedFrame = adaptFrameForScore(frame, false);
        DKV.put(adaptedFrame);
        Frame adaptedBgFrame = adaptFrameForScore(backgroundFrame, false);
        DKV.put(adaptedBgFrame);

        DataInfo di = model_info().dataInfo();
        assert di != null;
        final String[] featureContribNames = ContributionsOutputFormat.Compact.equals(options._outputFormat) ?
                _output.features() : di.coefNames();
        final String[] outputNames = ArrayUtils.append(featureContribNames, "BiasTerm");


        contributions = new PredictTreeSHAPWithBackgroundTask(di, model_info(), _output, options,
                adaptedFrame, adaptedBgFrame, options._outputPerReference, options._outputSpace)
                .runAndGetOutput(j, destination_key, outputNames);
      }
      return Scope.untrack(contributions);
    } finally {
      Log.info("Finished contributions calculation for " + this._key + "...");
    }
  }
  
  @Override
  public UpdateAuxTreeWeightsReport updateAuxTreeWeights(Frame frame, String weightsColumn) {
    if (weightsColumn == null) {
      throw new IllegalArgumentException("Weights column name is not defined");
    }
    Frame adaptFrm = new Frame(frame);
    Vec weights = adaptFrm.remove(weightsColumn);
    if (weights == null) {
      throw new IllegalArgumentException("Input frame doesn't contain weights column `" + weightsColumn + "`");
    }
    adaptTestForTrain(adaptFrm, true, false);
    // keep features only and re-introduce weights column at the end of the frame
    Frame featureFrm = new Frame(_output.features(), frame.vecs(_output.features()));
    featureFrm.add(weightsColumn, weights);

    DataInfo di = model_info().dataInfo();
    assert di != null;

    double[][] nodeWeights = new UpdateAuxTreeWeightsTask(_parms._distribution, di, model_info(), _output)
            .doAll(featureFrm)
            .getNodeWeights();
    AuxNodeWeights auxNodeWeights = new AuxNodeWeights(model_info().getAuxNodeWeightsKey(), nodeWeights);
    DKV.put(auxNodeWeights);

    UpdateAuxTreeWeightsReport report = new UpdateAuxTreeWeightsReport();
    report._warn_classes = new int[0];
    report._warn_trees = new int[0];
    for (int treeId = 0; treeId < nodeWeights.length; treeId++) {
      if (nodeWeights[treeId] == null)
        continue;
      for (double w : nodeWeights[treeId]) {
        if (w == 0) {
          report._warn_trees = ArrayUtils.append(report._warn_trees, treeId);
          report._warn_classes = ArrayUtils.append(report._warn_classes, 0);
          break;
        }
      }
    }
    return report;
  }

  @Override
  public Frame scoreLeafNodeAssignment(
      Frame frame, LeafNodeAssignmentType type, Key<Frame> destination_key
  ) {
    AssignLeafNodeTask task = AssignLeafNodeTask.make(model_info.scoringInfo(false), _output, model_info._boosterBytes, type);
    Frame adaptFrm = new Frame(frame);
    adaptTestForTrain(adaptFrm, true, false);
    return task.execute(adaptFrm, destination_key);
  }
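
  /*
   * Illustrative call (key name hypothetical): per-row leaf assignments, one
   * column per tree.
   *
   *   Frame leaves = model.scoreLeafNodeAssignment(
   *       testFrame, LeafNodeAssignmentType.Node_ID, Key.make("leaves"));
   */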

  private void setDataInfoToOutput(DataInfo dinfo) {
    _output.setNames(dinfo._adaptedFrame.names(), dinfo._adaptedFrame.typesStr());
    _output._domains = dinfo._adaptedFrame.domains();
    _output._nums = dinfo._nums;
    _output._cats = dinfo._cats;
    _output._catOffsets = dinfo._catOffsets;
    _output._useAllFactorLevels = dinfo._useAllFactorLevels;
  }

  @Override
  protected Futures remove_impl(Futures fs, boolean cascade) {
    DataInfo di = model_info().dataInfo();
    if (di != null) {
      di.remove(fs);
    }
    AuxNodeWeights anw = model_info().auxNodeWeights();
    if (anw != null) {
      anw.remove(fs);
    }
    if (_output._calib_model != null)
      _output._calib_model.remove(fs);
    return super.remove_impl(fs, cascade);
  }

  @Override
  public SharedTreeGraph convert(final int treeNumber, final String treeClassName) {
    GradBooster booster = XGBoostJavaMojoModel
            .makePredictor(model_info._boosterBytes, model_info.auxNodeWeightBytes())
            .getBooster();
    if (!(booster instanceof GBTree)) {
      throw new IllegalArgumentException("XGBoost model is not backed by a tree-based booster. Booster class is " + 
              booster.getClass().getCanonicalName());
    }

    final RegTree[][] groupedTrees = ((GBTree) booster).getGroupedTrees();
    final int treeClass = getXGBoostClassIndex(treeClassName);
    if (treeClass >= groupedTrees.length) {
      throw new IllegalArgumentException(String.format("Given XGBoost model does not have given class '%s'.", treeClassName));
    }

    final RegTree[] treesInGroup = groupedTrees[treeClass];

    if (treeNumber >= treesInGroup.length || treeNumber < 0) {
      throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", treesInGroup.length));
    }

    final RegTreeNode[] treeNodes = treesInGroup[treeNumber].getNodes();
    final RegTreeNodeStat[] treeNodeStats = treesInGroup[treeNumber].getStats();
    assert treeNodes.length >= 1;

    SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
    final SharedTreeSubgraph sharedTreeSubgraph = sharedTreeGraph.makeSubgraph(_output._training_metrics._description);

    final XGBoostUtils.FeatureProperties featureProperties = XGBoostUtils.assembleFeatureNames(model_info.dataInfo()); // XGBoost's usage of one-hot encoding assumed
    constructSubgraph(treeNodes, treeNodeStats, sharedTreeSubgraph.makeRootNode(), 0, sharedTreeSubgraph, featureProperties, true); // Root node is at index 0
    return sharedTreeGraph;
  }
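
  /*
   * Example: materialize one boosted tree as a SharedTreeGraph for inspection.
   *
   *   SharedTreeGraph g = model.convert(0, null); // first tree; class may be null for regression/binomial
   *   SharedTreeSubgraph t = g.subgraphArray.get(0);
   */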

  private static void constructSubgraph(final RegTreeNode[] xgBoostNodes, final RegTreeNodeStat[] xgBoostNodeStats, final SharedTreeNode sharedTreeNode,
                                        final int nodeIndex, final SharedTreeSubgraph sharedTreeSubgraph,
                                        final XGBoostUtils.FeatureProperties featureProperties, boolean inclusiveNA) {
    final RegTreeNode xgBoostNode = xgBoostNodes[nodeIndex];
    final RegTreeNodeStat xgBoostNodeStat = xgBoostNodeStats[nodeIndex];
    // Not testing for NaNs, as SharedTreeNode uses NaNs as default values.
    // No domain set, as the structure mimics XGBoost's tree, which is numeric-only
    if (featureProperties._oneHotEncoded[xgBoostNode.getSplitIndex()]) {
      // Shared tree model uses < to the left and >= to the right. Transforming one-hot encoded categoricals
      // from 0 to 1 makes them fit the current split description logic
      sharedTreeNode.setSplitValue(1.0F);
    } else {
      sharedTreeNode.setSplitValue(xgBoostNode.getSplitCondition());
    }
    sharedTreeNode.setPredValue(xgBoostNode.getLeafValue());
    sharedTreeNode.setInclusiveNa(inclusiveNA);
    sharedTreeNode.setNodeNumber(nodeIndex);
    sharedTreeNode.setGain(xgBoostNodeStat.getGain());
    sharedTreeNode.setWeight(xgBoostNodeStat.getCover());
    
    if (!xgBoostNode.isLeaf()) {
      sharedTreeNode.setCol(xgBoostNode.getSplitIndex(), featureProperties._names[xgBoostNode.getSplitIndex()]);
      constructSubgraph(xgBoostNodes, xgBoostNodeStats, sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode),
              xgBoostNode.getLeftChildIndex(), sharedTreeSubgraph, featureProperties, xgBoostNode.default_left());
      constructSubgraph(xgBoostNodes, xgBoostNodeStats, sharedTreeSubgraph.makeRightChildNode(sharedTreeNode),
          xgBoostNode.getRightChildIndex(), sharedTreeSubgraph, featureProperties, !xgBoostNode.default_left());
    }
  }

  @Override
  public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
    return convert(treeNumber, treeClass); // options are currently not applicable to in-H2O conversion
  }

  private int getXGBoostClassIndex(final String treeClass) {
    final ModelCategory modelCategory = _output.getModelCategory();
    if (ModelCategory.Regression.equals(modelCategory) && treeClass != null && !treeClass.isEmpty()) {
      throw new IllegalArgumentException("There should be no tree class specified for regression.");
    }
    if (treeClass == null || treeClass.isEmpty()) {
      // Binomial & regression problems do not require tree class to be specified, as there is only one available.
      // Such class is selected automatically for the user.
      switch (modelCategory) {
        case Binomial:
        case Regression:
          return 0;
        default:
          // If the user does not specify the tree class explicitly and there are multiple options to choose from,
          // throw an error.
          throw new IllegalArgumentException(String.format("Model category '%s' requires tree class to be specified.",
                  modelCategory));
      }
    }

    final String[] domain = _output._domains[_output._domains.length - 1];
    final int treeClassIndex = ArrayUtils.find(domain, treeClass);

    if (ModelCategory.Binomial.equals(modelCategory) && treeClassIndex != 0) {
      throw new IllegalArgumentException(String.format("For binomial XGBoost model, only one tree for class %s has been built.", domain[0]));
    } else if (treeClassIndex < 0) {
      throw new IllegalArgumentException(String.format("No such class '%s' in tree.", treeClass));
    }

    return treeClassIndex;
  }

  @Override
  public boolean isFeatureUsedInPredict(String featureName) {
    int featureIdx = ArrayUtils.find(_output._varimp._names, featureName);
    if (featureIdx == -1 && _output._catOffsets.length > 1) { // feature is possibly categorical
      featureIdx = ArrayUtils.find(_output._names, featureName);
      if (featureIdx == -1 || !_output._column_types[featureIdx].equals("Enum")) return false;
      for (int i = 0; i < _output._varimp._names.length; i++) {
        if (_output._varimp._names[i].startsWith(featureName.concat(".")) && _output._varimp._varimp[i] != 0){
          return true;
        }
      }
      return false;
    }
    return featureIdx != -1 && _output._varimp._varimp[featureIdx] != 0d;
  }

  //--------------------------------------------------------------------------------------------------------------------
  // Serialization into a POJO
  //--------------------------------------------------------------------------------------------------------------------

  @Override
  protected boolean toJavaCheckTooBig() {
    return _output == null || _output._ntrees * _parms._max_depth > 1000;
  }

  @Override protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
    sb.nl();
    sb.ip("public boolean isSupervised() { return true; }").nl();
    sb.ip("public int nclasses() { return ").p(_output.nclasses()).p("; }").nl();
    return sb;
  }
  
  @Override
  protected void toJavaPredictBody(
      SBPrintStream sb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode
  ) {
    final String namePrefix = JCodeGen.toJavaId(_key.toString());
    Predictor p = makePredictor(false);
    XGBoostPojoWriter.make(p, namePrefix, _output, defaultThreshold()).renderJavaPredictBody(sb, fileCtx);
  }

  public FeatureInteractions getFeatureInteractions(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {

    FeatureInteractions featureInteractions = new FeatureInteractions();
    
    for (int i = 0; i < this._parms._ntrees; i++) {
      FeatureInteractions currentTreeFeatureInteractions = new FeatureInteractions();
      SharedTreeGraph sharedTreeGraph = convert(i, null);
      assert sharedTreeGraph.subgraphArray.size() == 1;
      SharedTreeSubgraph tree = sharedTreeGraph.subgraphArray.get(0);
      List<SharedTreeNode> interactionPath = new ArrayList<>();
      Set<String> memo = new HashSet<>();
      
      FeatureInteractions.collectFeatureInteractions(tree.rootNode, interactionPath, 0, 0, 1, 0, 0,
              currentTreeFeatureInteractions, memo, maxInteractionDepth, maxTreeDepth, maxDeepening, i, false);
      featureInteractions.mergeWith(currentTreeFeatureInteractions);
    }
    
    return featureInteractions;
  }

  @Override
  public TwoDimTable[][] getFeatureInteractionsTable(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
    return FeatureInteractions.getFeatureInteractionsTable(this.getFeatureInteractions(maxInteractionDepth,maxTreeDepth,maxDeepening));
  }

  Predictor makePredictor(boolean scoringOnly) {
    return PredictorFactory.makePredictor(model_info._boosterBytes, model_info.auxNodeWeightBytes(), scoringOnly);
  }

  protected Frame removeSpecialNNonNumericColumns(Frame frame) {
    Frame adaptFrm = new Frame(frame);
    adaptTestForTrain(adaptFrm, true, false);
    // remove non-feature columns
    adaptFrm.remove(_parms._response_column);
    adaptFrm.remove(_parms._fold_column);
    adaptFrm.remove(_parms._weights_column);
    adaptFrm.remove(_parms._offset_column);
    // remove non-numeric columns
    int numCols = adaptFrm.numCols()-1;
    for (int index=numCols; index>=0; index--) {
      if (!adaptFrm.vec(index).isNumeric())
        adaptFrm.remove(index);
    }
    return adaptFrm;
  }
  
  @Override
  public double getFriedmanPopescusH(Frame frame, String[] vars) {
    Frame adaptFrm = removeSpecialNNonNumericColumns(frame);

    for (int colId = 0; colId < adaptFrm.numCols(); colId++) {
      Vec col = adaptFrm.vec(colId);
      if (col.isBad()) {
        throw new UnsupportedOperationException(
                "Error calculating H statistic: column " + adaptFrm.name(colId) + " is missing.");
      }
      if (!col.isNumeric()) {
        throw new UnsupportedOperationException(
                "Error calculating H statistic: column " + adaptFrm.name(colId) + " is not numeric.");
      }
    }

    int nclasses = this._output.nclasses() > 2 ? this._output.nclasses() : 1;
    SharedTreeSubgraph[][] sharedTreeSubgraphs = new SharedTreeSubgraph[this._parms._ntrees][nclasses];
    for (int i = 0; i < this._parms._ntrees; i++) {
      for (int j = 0; j < nclasses; j++) {
        SharedTreeGraph graph = this.convert(i, this._output.classNames()[j]);
        assert graph.subgraphArray.size() == 1;
        sharedTreeSubgraphs[i][j] = graph.subgraphArray.get(0);
      }
    }

    return FriedmanPopescusH.h(adaptFrm, vars, this._parms._learn_rate, sharedTreeSubgraphs);
  }


}