All downloads are free. Search and download functionality uses the official Maven repository.

hex.schemas.XGBoostV3 Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.schemas;

import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel.XGBoostParameters;
import water.api.API;
import water.api.schemas3.ModelParametersSchemaV3;


/**
 * Version-3 REST schema for the XGBoost model builder.
 *
 * <p>The class itself carries no members of its own; everything of interest lives in the nested
 * {@link XGBoostParametersV3}, which declares the parameters exposed through the API.
 *
 * <p>NOTE(review): the generic arguments below bind the builder ({@link XGBoost}), this schema,
 * and its parameters schema. They replace a raw {@code ModelBuilderSchema} supertype; confirm the
 * arity and order against the declaration of {@code ModelBuilderSchema}.
 */
public class XGBoostV3 extends ModelBuilderSchema<XGBoost, XGBoostV3, XGBoostV3.XGBoostParametersV3> {

  /**
   * Parameter schema for XGBoost.
   *
   * <p>Each public field annotated with {@link API} is one parameter; the annotation supplies the
   * help text, the allowed values for enum-typed parameters, the visibility level, and whether
   * the parameter is gridable.
   *
   * <p>Several parameters are declared in pairs whose help text marks them as aliases of each
   * other ("same as ..."), e.g. {@code learn_rate}/{@code eta}. How the two values of a pair are
   * reconciled is not decided in this class.
   *
   * <p>NOTE(review): the generic arguments replace a raw {@code ModelParametersSchemaV3}
   * supertype; confirm them against the declaration of {@code ModelParametersSchemaV3}.
   */
  public static final class XGBoostParametersV3
      extends ModelParametersSchemaV3<XGBoostParameters, XGBoostParametersV3> {

    /**
     * Ordered list of parameter names for this schema: first the parameters that are not declared
     * in this class (plus {@code seed} and {@code quiet_mode}, which are), then the
     * XGBoost-specific ones.
     *
     * <p>NOTE(review): {@code n_estimators} is declared as a field below but is not listed here —
     * confirm whether that omission is intentional before changing it.
     */
    public static String[] fields = new String[] {
        "model_id",
        "training_frame",
        "validation_frame",
        "nfolds",
        "keep_cross_validation_predictions",
        "keep_cross_validation_fold_assignment",
        "score_each_iteration",
        "fold_assignment",
        "fold_column",
        "response_column",
        "ignored_columns",
        "ignore_const_cols",
        "offset_column",
        "weights_column",
        "stopping_rounds",
        "stopping_metric",
        "stopping_tolerance",
        "max_runtime_secs",
        "seed",
        "distribution",
        "tweedie_power",
        "categorical_encoding",
        "quiet_mode",

        // model specific
        "ntrees",
        "max_depth",
        "min_rows", "min_child_weight",
        "learn_rate", "eta",
//        "learn_rate_annealing", //disabled for now

        "sample_rate", "subsample",
        "col_sample_rate", "colsample_bylevel",
        "col_sample_rate_per_tree", "colsample_bytree",
        "max_abs_leafnode_pred", "max_delta_step",

        "score_tree_interval",
        "min_split_improvement", "gamma",

        //lightgbm only
        "max_bins",
        "max_leaves",
        "min_sum_hessian_in_leaf",
        "min_data_in_leaf",

        //dart
        "sample_type",
        "normalize_type",
        "rate_drop",
        "one_drop",
        "skip_drop",

        //xgboost only
        "tree_method",
        "grow_policy",
        "booster",
        "reg_lambda",
        "reg_alpha",
        "dmatrix_type",
        "backend",
        "gpu_id"
    };

    // ---- Number of trees (alias pair) ----

    @API(help="(same as n_estimators) Number of trees.", gridable = true)
    public int ntrees;
    @API(help="(same as ntrees) Number of trees.", gridable = true)
    public int n_estimators;

    // ---- Tree shape ----

    @API(help="Maximum tree depth.", gridable = true)
    public int max_depth;

    // Alias pair: min_rows / min_child_weight
    @API(help="(same as min_child_weight) Fewest allowed (weighted) observations in a leaf.", gridable = true)
    public double min_rows;
    @API(help="(same as min_rows) Fewest allowed (weighted) observations in a leaf.", gridable = true)
    public double min_child_weight;

    // ---- Learning rate (alias pair) ----

    @API(help="(same as eta) Learning rate (from 0.0 to 1.0)", gridable = true)
    public double learn_rate;
    @API(help="(same as learn_rate) Learning rate (from 0.0 to 1.0)", gridable = true)
    public double eta;

    // Learning-rate annealing is disabled for now (matches the commented-out entry in `fields`).
//    @API(help="Scale the learning rate by this factor after each tree (e.g., 0.99 or 0.999) ", level = API.Level.secondary, gridable = true)
//    public double learn_rate_annealing;

    // ---- Row and column sampling (alias pairs) ----

    @API(help = "(same as subsample) Row sample rate per tree (from 0.0 to 1.0)", gridable = true)
    public double sample_rate;
    @API(help = "(same as sample_rate) Row sample rate per tree (from 0.0 to 1.0)", gridable = true)
    public double subsample;

    @API(help="(same as colsample_bylevel) Column sample rate (from 0.0 to 1.0)", gridable = true)
    public double col_sample_rate;
    @API(help="(same as col_sample_rate) Column sample rate (from 0.0 to 1.0)", gridable = true)
    public double colsample_bylevel;

    @API(help = "(same as colsample_bytree) Column sample rate per tree (from 0.0 to 1.0)", level = API.Level.secondary, gridable = true)
    public double col_sample_rate_per_tree;
    @API(help = "(same as col_sample_rate_per_tree) Column sample rate per tree (from 0.0 to 1.0)", level = API.Level.secondary, gridable = true)
    public double colsample_bytree;

    // ---- Leaf prediction bound (alias pair) ----

    @API(help="(same as max_delta_step) Maximum absolute value of a leaf node prediction", level = API.Level.expert, gridable = true)
    public float max_abs_leafnode_pred;
    @API(help="(same as max_abs_leafnode_pred) Maximum absolute value of a leaf node prediction", level = API.Level.expert, gridable = true)
    public float max_delta_step;

    // ---- Scoring and reproducibility ----

    @API(help="Score the model after every so many trees. Disabled if set to 0.", level = API.Level.secondary, gridable = false)
    public int score_tree_interval;

    @API(help = "Seed for pseudo random number generator (if applicable)", gridable = true)
    public long seed;

    // ---- Split threshold (alias pair) ----

    @API(help="(same as gamma) Minimum relative improvement in squared error reduction for a split to happen", level = API.Level.secondary, gridable = true)
    public float min_split_improvement;
    @API(help="(same as min_split_improvement) Minimum relative improvement in squared error reduction for a split to happen", level = API.Level.secondary, gridable = true)
    public float gamma;

    // ---- Parameters documented as applying to tree_method=hist only ----

    @API(help = "For tree_method=hist only: maximum number of bins", level = API.Level.expert, gridable = true)
    public int max_bins;

    @API(help = "For tree_method=hist only: maximum number of leaves", level = API.Level.secondary, gridable = true)
    public int max_leaves;

    @API(help = "For tree_method=hist only: the mininum sum of hessian in a leaf to keep splitting", level = API.Level.expert, gridable = true)
    public float min_sum_hessian_in_leaf;

    @API(help = "For tree_method=hist only: the mininum data in a leaf to keep splitting", level = API.Level.expert, gridable = true)
    public float min_data_in_leaf;

    // ---- Algorithm selection; allowed values are listed in each annotation ----

    @API(help="Tree method", values = { "auto", "exact", "approx", "hist"}, level = API.Level.secondary, gridable = true)
    public XGBoostParameters.TreeMethod tree_method;

    @API(help="Grow policy - depthwise is standard GBM, lossguide is LightGBM", values = { "depthwise", "lossguide"}, level = API.Level.secondary, gridable = true)
    public XGBoostParameters.GrowPolicy grow_policy;

    @API(help="Booster type", values = { "gbtree", "gblinear", "dart"}, level = API.Level.expert, gridable = true)
    public XGBoostParameters.Booster booster;

    // ---- Regularization ----

    @API(help = "L2 regularization", level = API.Level.expert, gridable = true)
    public float reg_lambda;

    @API(help = "L1 regularization", level = API.Level.expert, gridable = true)
    public float reg_alpha;

    // no special support for missing value right now - missing value are handled by XGBoost internally
    //@API(help="Missing Value Handling", values = { "mean_imputation", "skip"}, level = API.Level.expert, gridable = true)
    //public XGBoostParameters.MissingValuesHandling missing_values_handling;

    @API(help="Enable quiet mode", level = API.Level.expert, gridable = false)
    public boolean quiet_mode;

    // ---- Parameters documented as applying to booster=dart only ----

    @API(help="For booster=dart only: sample_type", values = { "uniform", "weighted"}, level = API.Level.expert, gridable = true)
    public XGBoostParameters.DartSampleType sample_type;

    @API(help="For booster=dart only: normalize_type", values = { "tree", "forest"}, level = API.Level.expert, gridable = true)
    public XGBoostParameters.DartNormalizeType normalize_type;

    @API(help="For booster=dart only: rate_drop (0..1)", level = API.Level.expert, gridable = true)
    public float rate_drop;

    @API(help="For booster=dart only: one_drop", level = API.Level.expert, gridable = true)
    public boolean one_drop;

    @API(help="For booster=dart only: skip_drop (0..1)", level = API.Level.expert, gridable = true)
    public float skip_drop;

    // ---- Data layout and execution backend ----

    @API(help="Type of DMatrix. For sparse, NAs and 0 are treated equally.", values = { "auto", "dense", "sparse" }, level = API.Level.secondary, gridable = true)
    public XGBoostParameters.DMatrixType dmatrix_type;

    @API(help="Backend. By default (auto), a GPU is used if available.", values = { "auto", "gpu", "cpu" }, level = API.Level.expert, gridable = true)
    public XGBoostParameters.Backend backend;

    @API(help="Which GPU to use. ", level = API.Level.expert, gridable = false)
    public int gpu_id;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy