All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.uplift.UpliftDRFModel Maven / Gradle / Ivy

package hex.tree.uplift;

import hex.*;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.util.EffectiveParametersUtils;
import water.H2O;
import water.Key;

/**
 * Uplift Distributed Random Forest model.
 *
 * <p>At scoring time the per-tree outputs accumulated by the superclass in {@code preds[1]} and
 * {@code preds[2]} are averaged over the forest, and {@code preds[0]} is set to their difference.
 */
public class UpliftDRFModel extends SharedTreeModel {

    /** Training parameters for {@link UpliftDRFModel}. */
    public static class UpliftDRFParameters extends SharedTreeModel.SharedTreeParameters {
        public String algoName() { return "UpliftDRF"; }
        public String fullName() { return "Uplift Distributed Random Forest"; }
        public String javaName() { return UpliftDRFModel.class.getName(); }

        // NOTE(review): not read anywhere in this file; presumably kept for parity with DRF's
        // parameter of the same name -- confirm whether uplift training uses it.
        public boolean _binomial_double_trees = false;

        /**
         * Split metric for uplift trees. AUTO is presumably resolved to a concrete metric by
         * {@code EffectiveParametersUtils.initUpliftMetric} (see initActualParamValues).
         */
        public enum UpliftMetricType { AUTO, KL, ChiSquared, Euclidean }
        public UpliftMetricType _uplift_metric = UpliftMetricType.AUTO;

        // Number of columns to use per split. The default depends on the algorithm and
        // problem (classification/regression); the meaning of the -2 sentinel is defined
        // elsewhere -- confirm in the tree builder.
        public int _mtries = -2;

        /** Creates parameters with Uplift DRF specific defaults. */
        public UpliftDRFParameters() {
            super();
            // Set Uplift DRF specific defaults (can differ from SharedTreeModel's defaults)
            _max_depth = 20;
            _min_rows = 1;
            _treatment_column = "treatment";
        }

        /**
         * Reports two progress units per tree.
         *
         * @return {@code 2 * _ntrees}, computed in long arithmetic so it cannot overflow int
         */
        @Override
        public long progressUnits() {
            return _ntrees * 2L;
        }
    }

    /** Model output for {@link UpliftDRFModel}; always reports the BinomialUplift category. */
    public static class UpliftDRFOutput extends SharedTreeModelWithContributions.SharedTreeOutput {
        public UpliftDRFOutput(UpliftDRF b) { super(b); }

        /** @return always {@link ModelCategory#BinomialUplift} */
        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.BinomialUplift;
        }

        /** @return always {@code false}: an uplift model is not treated as a binomial classifier */
        @Override
        public boolean isBinomialClassifier() {
            return false;
        }
    }

    /**
     * Creates the model.
     *
     * @param selfKey key under which the model is stored
     * @param parms   training parameters
     * @param output  model output
     */
    public UpliftDRFModel(Key selfKey, UpliftDRFParameters parms, UpliftDRFOutput output ) {
        super(selfKey, parms, output);
    }

    /**
     * Resolves effective parameter values: histogram type, categorical encoding (Enum is passed
     * as the encoding argument) and the uplift metric.
     */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initHistogramType(_parms);
        EffectiveParametersUtils.initCategoricalEncoding(_parms, Parameters.CategoricalEncodingScheme.Enum);
        EffectiveParametersUtils.initUpliftMetric(_parms);
    }

    /** @return always {@code false}; the binomial optimization is disabled for uplift models. */
    @Override public boolean binomialOpt() { return false; }

    /**
     * Scores a single row.
     *
     * <p>After the superclass has scored the trees, {@code preds[1]} and {@code preds[2]} are
     * divided by the model's tree count to turn per-tree sums into averages. {@code preds[0]} is
     * then set to {@code preds[1] - preds[2]}. Assumes the superclass accumulates per-tree sums
     * into those two slots -- confirm against SharedTreeModel.
     *
     * @param data   the row to score
     * @param preds  output array, filled in place; must have at least 3 slots
     * @param offset passed through to the superclass
     * @param ntrees passed through to the superclass
     * @return {@code preds}
     */
    @Override protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        super.score0(data, preds, offset, ntrees);
        int N = _output._ntrees;
        // With no trees there is nothing to average; dividing by zero would turn both
        // values, and therefore their difference, into NaN/Infinity.
        if (N >= 1) {
            preds[1] /= N;
            preds[2] /= N;
        }
        preds[0] = preds[1] - preds[2];
        return preds;
    }

    /**
     * Creates the metric builder for this model.
     *
     * @param domain response domain forwarded to the builder
     * @return a {@link ModelMetricsBinomialUplift.MetricBuilderBinomialUplift}
     * @throws RuntimeException from {@code H2O.unimpl()} if the model category is not BinomialUplift
     */
    @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        if (_output.getModelCategory() == ModelCategory.BinomialUplift) {
            return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain);
        }
        throw H2O.unimpl();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy