All Downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.uplift.UpliftDRF Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.uplift;

import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.*;
import org.apache.log4j.Logger;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class UpliftDRF extends SharedTree {

    /** Class-scoped log4j logger for this builder. */
    private static final Logger LOG = Logger.getLogger(UpliftDRF.class);

    /**
     * Split-metric choices exposed for uplift trees.
     *
     * <p>NOTE(review): judging only by the constant names, these appear to select
     * KL divergence, chi-squared and Euclidean distance, with AUTO deferring to a
     * default; confirm against the split-scoring code. Constant order is part of the
     * enum's identity (ordinals) and must not be changed.
     */
    public enum UpliftMetricType {
        AUTO,
        KL,
        ChiSquared,
        Euclidean
    }

    /**
     * Creates a builder from user-supplied parameters; this is the constructor used
     * when a build is requested over HTTP. Runs the cheap validation pass right away.
     *
     * @param parms uplift DRF parameters
     */
    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms) {
        super(parms);
        init(false);
    }

    /**
     * Creates a builder whose resulting model will be stored under {@code key};
     * runs the cheap validation pass right away.
     *
     * @param parms uplift DRF parameters
     * @param key   destination key forwarded to the superclass
     */
    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms, Key key) {
        super(parms, key);
        init(false);
    }

    /**
     * Creates a builder attached to an existing {@link Job}; runs the cheap
     * validation pass right away.
     *
     * @param parms uplift DRF parameters
     * @param job   job forwarded to the superclass
     */
    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms, Job job) {
        super(parms, job);
        init(false);
    }

    /**
     * Creates a builder with default parameters. No validation pass is run here.
     *
     * @param startup_once forwarded to the superclass (presumably the one-time
     *                     startup registration flag — confirm in ModelBuilder)
     */
    public UpliftDRF(boolean startup_once) {
        super(new UpliftDRFModel.UpliftDRFParameters(), startup_once);
    }

    /** Reports that MOJO export is not available for this builder. */
    @Override
    public boolean haveMojo() { return false; }

    /** Reports that POJO export is not available for this builder. */
    @Override
    public boolean havePojo() { return false; }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return a new single-element array holding {@link ModelCategory#BinomialUplift}
     */
    @Override
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = { ModelCategory.BinomialUplift };
        return categories;
    }
    
    /**
     * Supplies the driver that performs training; per the original note, the
     * training job is started on an F/J thread.
     *
     * @return a new {@code UpliftDRFDriver}
     */
    @Override
    protected Driver trainModelImpl() {
        final Driver driver = new UpliftDRFDriver();
        return driver;
    }

    /** Opts out of scoring the model while it still has zero trees. */
    @Override
    public boolean scoreZeroTrees() {
        return false;
    }

    /**
     * Variable importance is currently disabled: it is not straightforward to
     * attribute the contribution of an uplift split to individual features.
     */
    @Override
    public boolean providesVarImp() {
        return false;
    }

    /** Initialize the ModelBuilder, validating all arguments and preparing the
     *  training frame.  This call is expected to be overridden in the subclasses
     *  and each subclass will start with "super.init();".  This call is made
     *  by the front-end whenever the GUI is clicked, and needs to be fast;
     *  heavy-weight prep needs to wait for the trainModel() call.
     *
     *  <p>Uplift DRF additionally checks the following:
     *  <ul>
     *    <li>{@code mtries} must be -1, -2, or a count in [1, #columns of the training frame).</li>
     *    <li>A binary response and a treatment column are required.</li>
     *    <li>Offsets, weights, cross-validation, custom distribution/metric functions and
     *        early stopping are rejected as not yet supported.</li>
     *  </ul>
     *
     *  @param expensive forwarded to {@code super.init}; the checks in this override do not
     *                   depend on it
     */
    @Override public void init(boolean expensive) {
        super.init(expensive);
        final int mtries = _parms._mtries;
        // -1 means sqrt(features), -2 means all features; any other value is an explicit count.
        final boolean mtriesIsSentinel = mtries == -1 || mtries == -2;
        if (!mtriesIsSentinel && mtries < 1) {
            error("_mtries", "mtries must be -1 (converted to sqrt(features)) or -2 (All features) or >= 1 but it is " + mtries);
        } else if (!mtriesIsSentinel && _train != null) {
            // Only reached for mtries >= 1, so an invalid value is not reported twice.
            int ncols = _train.numCols(); // ncols includes the response
            if (mtries >= ncols)
                error("_mtries","Computed mtries should be -1 or -2 or in interval [1,"+ncols+"[ but it is " + mtries);
        }
        if (_parms._sample_rate == 1f && _valid == null)
            warn("_sample_rate", "Sample rate is 100% and no validation dataset. There are no out-of-bag data to compute error estimates on the training data!");
        if (hasOffsetCol())
            error("_offset_column", "Offsets are not yet supported for Uplift DRF.");
        if (hasWeightCol())
            error("_weight_column", "Weights are not yet supported for Uplift DRF.");
        if (hasFoldCol())
            error("_fold_column", "Cross-validation is not yet supported for Uplift DRF.");
        if (_parms._nfolds > 0)
            error("_nfolds", "Cross-validation is not yet supported for Uplift DRF.");
        if (_nclass == 1)
            error("_distribution", "UpliftDRF currently support binomial classification problems only.");
        // Enum constants are compared by identity: null-safe, unlike _distribution.equals(...).
        if (_nclass > 2 || _parms._distribution == DistributionFamily.multinomial)
            error("_distribution", "UpliftDRF currently does not support multinomial distribution.");
        if (_parms._treatment_column == null)
            error("_treatment_column", "The treatment column has to be defined.");
        if (_parms._custom_distribution_func != null)
            error("_custom_distribution_func", "The custom distribution is not yet supported for Uplift DRF.");
        if (_parms._custom_metric_func != null)
            error("_custom_metric_func", "The custom metric is not yet supported for Uplift DRF.");
        if (_parms._stopping_metric != ScoreKeeper.StoppingMetric.AUTO)
            error("_stopping_metric", "The early stopping is not yet supported for Uplift DRF.");
        if (_parms._stopping_rounds != 0)
            error("_stopping_rounds", "The early stopping is not yet supported for Uplift DRF.");
    }
    
    

    // ----------------------
    private class UpliftDRFDriver extends Driver {

        /** Out-of-bag scoring is always enabled for uplift DRF. */
        @Override
        protected boolean doOOBScoring() { return true; }

        /**
         * Derives the per-tree and per-split column-sampling sizes, then marks, for every
         * row with a non-missing response, the work column of that row's class with 1.
         *
         * @throws IllegalArgumentException if {@code _mtry_per_tree} or {@code _mtry} falls
         *         outside {@code [1, _ncols]}
         */
        @Override protected void initializeModelSpecifics() {
            _mtry_per_tree = Math.max(1, (int) (_parms._col_sample_rate_per_tree * _ncols));
            if (_mtry_per_tree < 1 || _mtry_per_tree > _ncols)
                throw new IllegalArgumentException("Computed mtry_per_tree should be in interval <1," + _ncols + "> but it is " + _mtry_per_tree);

            switch (_parms._mtries) {
                case -2:
                    // Use every column at each split, regardless of columns dropped during training.
                    _mtry = _ncols;
                    break;
                case -1:
                    // Classification: sqrt(_ncols); regression: _ncols / 3 (never below 1).
                    _mtry = isClassifier() ? Math.max((int) Math.sqrt(_ncols), 1) : Math.max(_ncols / 3, 1);
                    break;
                default:
                    _mtry = _parms._mtries;
            }
            if (_mtry < 1 || _mtry > _ncols)
                throw new IllegalArgumentException("Computed mtry should be in interval <1," + _ncols + "> but it is " + _mtry);

            new MRTask() {
                @Override public void map(Chunk[] cs) {
                    final Chunk response = chk_resp(cs);
                    for (int r = 0; r < response._len; r++) {
                        if (!response.isNA(r)) {
                            chk_work(cs, (int) response.at8(r)).set(r, 1L);
                        }
                    }
                }
            }.doAll(_train);
        }

        // --------------------------------------------------------------------------
        /**
         * Builds the next uplift tree pair and appends it to the model.
         *
         * <p>{@code growTrees} grows slot 0 as the treatment tree. It then replaces slot 1 with
         * a deep copy of slot 0 that carries control predictions in its leaves.
         *
         * @return always {@code false}; this driver never requests early termination
         */
        @Override protected boolean buildNextKTrees() {
            // One slot per class; for the binomial case enforced by init(): 0 = treatment, 1 = control.
            final DTree[] ktrees = new DTree[_nclass];

            // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
            final int[] leafs = new int[_nclass];

            // Assign rows to nodes - fill the "NIDs" column(s)
            growTrees(ktrees, leafs, _rand);

            // Move rows into the final leaf rows - fill "Tree" and OUT_BAG_TREES columns and zap the NIDs column.
            // Executed for its side effects only; the returned task was never read, so it is not kept.
            new UpliftCollectPreds(ktrees, leafs).doAll(_train, _parms._build_tree_one_node);

            // Grow the model by the new trees
            _model._output.addKTrees(ktrees);

            return false; // never stop early
        }

        /**
         * Grows one uplift tree layer by layer, then attaches leaves to a treatment copy and a
         * control copy of it.
         *
         * <p>Steps, in order:
         * <ol>
         *   <li>For each class k whose {@code _model._output._distribution[k]} is non-zero, create
         *       {@code ktrees[k]} with an undecided root. The root histograms go in
         *       {@code hcs[k][0]}, and all trees share one seed drawn from {@code rand}.</li>
         *   <li>Row-sample with a {@code Sample} task over tree 0's NIDs vector and the response.</li>
         *   <li>Call {@code buildLayer} on {@code ktrees[0]} up to {@code _parms._max_depth} times.
         *       Stop early when it returns {@code null} (no new splits).</li>
         *   <li>Overwrite {@code ktrees[1]} with a deep copy of {@code ktrees[0]}. Slot 0 becomes the
         *       treatment tree and slot 1 the control tree.</li>
         *   <li>Append leaves to both trees: treatment leaves get {@code predTreatment(i)} and
         *       control leaves get {@code predControl(i)}.</li>
         * </ol>
         *
         * <p>Per the original note, this assumes the "Work" columns are filled with
         * horizontalized (0/1) class memberships per row (or a copy of the regression response).
         * Writing {@code ktrees[1]} requires {@code _nclass >= 2}; {@code init()} rejects non-binomial
         * responses.
         *
         * @param ktrees output array; slots 0 and 1 are written
         * @param leafs  output; {@code leafs[0]} is set to the treatment tree's node count before
         *               leaves are appended
         * @param rand   source of the seed shared by all trees
         */
        private void growTrees(DTree[] ktrees, int[] leafs, Random rand) {
            // Initial set of histograms.  All trees; one leaf per tree (the root
            // leaf); all columns
            DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];
            
            // Adjust real bins for the top-levels
            int adj_nbins = Math.max(_parms._nbins_top_level,_parms._nbins);

            // Use for all k-trees the same seed. NOTE: this is only to make a fair
            // view for all k-trees
            long rseed = rand.nextLong();
            // Initially setup as-if an empty-split had just happened.
            // NOTE(review): the tree created here for k == 1 is overwritten below
            // (ktrees[1] = new DTree(ktrees[0])), so that work is discarded.
            for (int k = 0; k < _nclass; k++) {
                if (_model._output._distribution[k] != 0) { // Ignore missing classes
                    ktrees[k] = new DTree(_train, _ncols, _mtry, _mtry_per_tree, rseed, _parms);
                    new DTree.UndecidedNode(ktrees[k], -1, DHistogram.initialHist(_train, _ncols, adj_nbins, hcs[k][0], rseed, _parms, getGlobalSplitPointsKeys(), null, false, null), null, null); // The "root" node
                }
            }

            // Sample - mark the lines by putting 'OUT_OF_BAG' into nid() vector.
            // NOTE(review): if _distribution[0] == 0, ktrees[0] is still null here, and
            // treeTr.len() below would throw a NullPointerException. The result `s` is not
            // used further in this method.
            Sample s = new Sample(ktrees[0], _parms._sample_rate, _parms._sample_rate_per_class).dfork(null,new Frame(vec_nids(_train,0),vec_resp(_train)), _parms._build_tree_one_node).getResult();

            // ----
            // One Big Loop till the ktrees are of proper depth.
            // Adds a layer to the trees each pass; only ktrees[0] is passed to buildLayer.
            int depth=0;
            for( ; depth<_parms._max_depth; depth++ ) {
                hcs = buildLayer(_train, _parms._nbins, ktrees[0], leafs, hcs, _parms._build_tree_one_node);
                
                // If we did not make any new splits, then the tree is split-to-death
                if( hcs == null ) break;
            }
            // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
            // LeafNodes to hold predictions.
            DTree treeTr = ktrees[0];
            ktrees[1] = new DTree(ktrees[0]); // make a deep copy of the tree to assign control prediction to the leaves
            DTree treeCt = ktrees[1]; 
            // The bound is captured before leaves are appended, so the new leaf nodes
            // (nid >= leaf) are not visited by this loop.
            int leaf = leafs[0] = treeTr.len();
            for (int nid = 0; nid < leaf; nid++) {
                if (treeTr.node(nid) instanceof DTree.DecidedNode) { // Should be the same for treatment and control tree
                    DTree.DecidedNode dnTr = treeTr.decided(nid); // Treatment tree node
                    DTree.DecidedNode dnCt = treeCt.decided(nid); // Control tree node
                    if (dnTr._split == null) { // No decision here, no row should have this NID now
                        if (nid == 0) { // Handle the trivial non-splitting tree
                            // NOTE(review): the treatment root leaf uses _priorClassDist[1] but the
                            // control root leaf uses _priorClassDist[0]; confirm index 0 is the intended
                            // control prediction rather than, e.g., the same class-1 prior.
                            DTree.LeafNode lnTr = new DTree.LeafNode(treeTr, -1, 0);
                            lnTr._pred = (float) (_model._output._priorClassDist[1]);
                            DTree.LeafNode lnCt = new DTree.LeafNode(treeCt, -1, 0);
                            lnCt._pred = (float) (_model._output._priorClassDist[0]);
                        }
                        continue;
                    }
                    // Turn every terminal child slot into a leaf in both trees, at the same index.
                    for (int i = 0; i < dnTr._nids.length; i++) {
                        int cnid = dnTr._nids[i];
                        if (cnid == -1 || // Bottomed out (predictors or responses known constant)
                                treeTr.node(cnid) instanceof DTree.UndecidedNode || // Or chopped off for depth
                                (treeTr.node(cnid) instanceof DTree.DecidedNode &&  // Or not possible to split
                                        ((DTree.DecidedNode) treeTr.node(cnid))._split == null)) {
                            DTree.LeafNode lnTr = new DTree.LeafNode(treeTr, nid);
                            lnTr._pred = (float) dnTr.predTreatment(i);  // Set prediction into the treatment leaf
                            dnTr._nids[i] = lnTr.nid(); // Mark a leaf here for treatment 
                            DTree.LeafNode lnCt = new DTree.LeafNode(treeCt, nid);
                            lnCt._pred = (float) dnCt.predControl(i);  // Set prediction into the control leaf
                            dnCt._nids[i] = lnCt.nid(); // Mark a leaf here for control
                        }
                    }
                } 
            }
        }

        // Collect and write predictions into leaves.
        private class UpliftCollectPreds extends MRTask {
            /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
            /* @OUT */ double allRows;    // number of all OOB rows (sampled by this tree)
            
            UpliftCollectPreds(DTree trees[], int leafs[]) { _trees=trees;}
            @Override public void map( Chunk[] chks ) {
                final Chunk    y       = chk_resp(chks); // Response
                final Chunk   oobt     = chk_oobt(chks); // Out-of-bag rows counter over all trees
                final Chunk   weights  = hasWeightCol() ? chk_weight(chks) : new C0DChunk(1, chks[0]._len); // Out-of-bag rows counter over all trees
                // Iterate over all rows
                for( int row=0; row colHeaders = new ArrayList<>();
        List colTypes = new ArrayList<>();
        List colFormat = new ArrayList<>();
        colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
        colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
        colHeaders.add("Number of Trees"); colTypes.add("long"); colFormat.add("%d");
        colHeaders.add("Training AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
        colHeaders.add("Training AUUC"); colTypes.add("double"); colFormat.add("%.5f");
        colHeaders.add("Training AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
        colHeaders.add("Training Qini value"); colTypes.add("double"); colFormat.add("%.5f");
        if (hasCustomMetric) {
            colHeaders.add("Training Custom"); colTypes.add("double"); colFormat.add("%.5f");
        }

        if (_output._validation_metrics != null) {
            colHeaders.add("Validation AUUC nbins"); colTypes.add("int"); colFormat.add("%d");
            colHeaders.add("Validation AUUC"); colTypes.add("double"); colFormat.add("%.5f");
            colHeaders.add("Validation AUUC normalized"); colTypes.add("double"); colFormat.add("%.5f");
            colHeaders.add("Validation Qini value"); colTypes.add("double"); colFormat.add("%.5f");
            if (hasCustomMetric) {
                colHeaders.add("Validation Custom"); colTypes.add("double"); colFormat.add("%.5f");
            }
        }

        int rows = 0;
        for( int i = 0; i<_scored_train.length; i++ ) {
            if (i != 0 && Double.isNaN(_scored_train[i]._AUUC) && (_scored_valid == null || Double.isNaN(_scored_valid[i]._AUUC))) continue;
            rows++;
        }
        TwoDimTable table = new TwoDimTable(
                "Scoring History", null,
                new String[rows],
                colHeaders.toArray(new String[0]),
                colTypes.toArray(new String[0]),
                colFormat.toArray(new String[0]),
                "");
        int row = 0;
        for( int i = 0; i<_scored_train.length; i++ ) {
            if (i != 0 && Double.isNaN(_scored_train[i]._AUUC) && (_scored_valid == null || Double.isNaN(_scored_valid[i]._AUUC))) continue;
            int col = 0;
            DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
            table.set(row, col++, fmt.print(_training_time_ms[i]));
            table.set(row, col++, PrettyPrint.msecs(_training_time_ms[i] - job.start_time(), true));
            table.set(row, col++, i);
            ScoreKeeper st = _scored_train[i];
            table.set(row, col++, st._auuc_nbins);
            table.set(row, col++, st._AUUC);
            table.set(row, col++, st._auuc_normalized);
            table.set(row, col++, st._qini);
            if (hasCustomMetric) table.set(row, col++, st._custom_metric);

            if (_output._validation_metrics != null) {
                st = _scored_valid[i];
                table.set(row, col++, st._auuc_nbins);
                table.set(row, col++, st._AUUC);
                table.set(row, col++, st._auuc_normalized);
                table.set(row, col++, st._qini);
                if (hasCustomMetric) table.set(row, col++, st._custom_metric);
            }
            row++;
        }
        return table;
    }

    /**
     * Supplies the uplift-specific scoring hooks.
     *
     * @return a new {@code UpliftScoreExtension}
     */
    @Override
    protected UpliftScoreExtension makeScoreExtension() {
        final UpliftScoreExtension extension = new UpliftScoreExtension();
        return extension;
    }

    /**
     * Scoring hooks for uplift models.
     *
     * <p>The per-row prediction is the difference of entries 1 and 2 of the distribution
     * vector. The response complement is the index returned by {@code _output.treatmentIdx()}.
     */
    private static class UpliftScoreExtension extends Score.ScoreExtension {

        /** Explicit public no-arg constructor, preserved from the original. */
        public UpliftScoreExtension() {
        }

        /**
         * Computes the scalar prediction for one row.
         *
         * @param cdist per-row distribution vector; only indices 1 and 2 are read
         *              (NOTE(review): presumably treatment and control predictions — confirm)
         * @return {@code cdist[1] - cdist[2]}
         */
        @Override
        protected double getPrediction(double[] cdist) {
            final double minuend = cdist[1];
            final double subtrahend = cdist[2];
            return minuend - subtrahend;
        }

        /**
         * Reports which extra column index scoring needs alongside the response.
         *
         * @param m the model being scored
         * @return a one-element array holding {@code m._output.treatmentIdx()}
         */
        @Override
        protected int[] getResponseComplements(SharedTreeModel m) {
            final int treatmentIdx = m._output.treatmentIdx();
            return new int[]{treatmentIdx};
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy