package hex.tree.drf;


import java.util.Arrays;
import java.util.Random;

import static hex.genmodel.GenModel.getPrediction;

import hex.tree.CompressedForest;
import hex.tree.CompressedTree;

import hex.tree.SharedTree;
import water.Iced;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.ModelUtils;
import static water.util.RandomUtils.getRNG;

/** Score a given tree model and preserve per-tree errors in the form of votes (for classification)
 * or squared error (for regression).
 *
 * This differs from the Model.score() function since this MR task
 * uses an inverted loop: the outer loop runs over all trees and the inner loop over all rows in a chunk.
 */
public class TreeMeasuresCollector extends MRTask<TreeMeasuresCollector> {
  /* @IN */ final private CompressedForest _cforest;
  /* @IN */ final private float     _rate;
  /* @IN */ final private int       _var;
  /* @IN */ final private boolean   _oob;
  /* @IN */ final private int       _ncols;
  /* @IN */ final private int       _nclasses;
  /* @IN */ final private boolean   _classification;
  /* @IN */ final private double   _threshold;
  final private SharedTree _st;

  /* @INOUT */ private final int _ntrees;
  /* @OUT */ private double [/*ntrees*/] _votes; // Number of correct votes per tree (for classification only)
  /* @OUT */ private double [/*ntrees*/] _nrows; // Number of scored rows per tree (for classification/regression)
  /* @OUT */ private float[/*ntrees*/] _sse;   // Sum of squared errors per tree (for regression only)

  /* Intermediate */
  private transient CompressedForest.LocalCompressedForest _forest;

  private TreeMeasuresCollector(CompressedForest cforest, int nclasses, int ncols, float rate, int variable, double threshold, SharedTree st) {
    assert cforest._treeKeys.length > 0;
    assert nclasses == cforest._treeKeys[0].length;
    _cforest = cforest;
    _ncols = ncols;
    _rate = rate; _var = variable;
    _oob = true; _ntrees = cforest._treeKeys.length;
    _nclasses = nclasses;
    _classification = (nclasses>1);
    _threshold = threshold;
    _st = st;
  }
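
  // A minimal usage sketch (comment only; names are hypothetical). The collector is normally
  // driven from DRF variable-importance code, which owns the CompressedForest `cf`, the
  // SharedTree builder `st`, and the adapted training Frame `fr` whose leading _ncols columns
  // are the predictors. Run once un-permuted, once per column permuted, then compare:
  //
  //   TreeVotes base = new TreeMeasuresCollector(cf, nclasses, ncols, sampleRate,
  //       /*variable*/ -1, threshold, st).doAll(fr).resultVotes();  // baseline OOB votes
  //   TreeVotes perm = new TreeMeasuresCollector(cf, nclasses, ncols, sampleRate,
  //       /*variable*/ j, threshold, st).doAll(fr).resultVotes();   // votes with column j permuted
  //   double[] impAndSd = perm.imp(base); // mean per-tree accuracy drop and its std. deviation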

  @Override
  protected void setupLocal() {
    _forest = _cforest.fetch();
  }

  public static class ShuffleTask extends MRTask<ShuffleTask> {

    @Override public void map(Chunk ic, Chunk oc) {
      if (ic._len==0) return;
      // Each vector is shuffled in the same way
      Random rng = getRNG(seed(ic.cidx()));
      oc.set(0,ic.atd(0));
      for (int row=1; row<ic._len; row++) {
        int j = rng.nextInt(row+1); // pick a random position in <0,row>
        // "inside-out" Fisher-Yates shuffle; expand the vector into doubles via atd/set
        if (j!=row) oc.set(row, oc.atd(j));
        oc.set(j, ic.atd(row));
      }
    }

    public static long seed(int cidx) { return (0xe031e74f321f7e29L + ((long)cidx << 32L)); }

    public static Vec shuffle(Vec ivec) {
      Vec ovec = ivec.makeZero();
      new ShuffleTask().doAll(ivec, ovec);
      return ovec;
    }
  }

  @Override public void map(Chunk[] chks) {
    double[] data = new double[_ncols];
    double[] preds = new double[_nclasses+1];
    Chunk cresp = _st.chk_resp(chks);
    Chunk weights = _st.hasWeightCol() ? _st.chk_weight(chks) : new C0DChunk(1, chks[0]._len);
    int   nrows = cresp._len;
    int   [] oob = new int[2+Math.round((1f-_rate)*nrows*1.2f+0.5f)]; // preallocate: ~expected OOB count (1-rate)*nrows, padded by 20% plus slack
    int   [] soob = null;

    // Prepare output data
    _nrows      = new double[_ntrees];
    _votes      = _classification ? new double[_ntrees] : null;
    _sse        = _classification ? null : new float[_ntrees];
    long seedForOob = ShuffleTask.seed(cresp.cidx()); // seed for shuffling oob samples
    // Start iteration
    for( int tidx=0; tidx<_ntrees; tidx++) { // tree
      // OOB RNG for this tree
      Random rng = rngForTree(_forest._trees[tidx], cresp.cidx());
      // Collect OOB rows and permute them
      oob = ModelUtils.sampleOOBRows(nrows, _rate, rng, oob); // reuse the same array for sampling
      int oobcnt = oob[0]; // Get number of sample rows
      if (_var>=0) {
        if (soob==null || soob.length < oobcnt) soob = new int[oobcnt];
        ArrayUtils.shuffleArray(oob, oobcnt, soob, seedForOob, 1); // Shuffle array and copy results into soob
      }
      for(int j = 1; j < 1+oobcnt; j++) {
        int row = oob[j];
        double w = weights.atd(row);
        if (cresp.isNA(row)) continue; // we cannot deal with this row anyhow
        if (w==0) continue;
        // Do scoring:
        // - prepare a row data
        for (int i=0;i<_ncols;i++) data[i] = chks[i].atd(row); // copy the row's predictor values
        // - permute variable
        if (_var>=0) data[_var] = chks[_var].atd(soob[j-1]);
        else assert soob==null;
        // - score data
        Arrays.fill(preds, 0);
        // - score only the tree
        _forest.scoreTree(data, preds, tidx);
        // - derive a prediction
        if (_classification) {
          int pred = getPrediction(preds, null /*FIXME: should use model's _priorClassDistribution*/, data, _threshold);
          int actu = (int) cresp.at8(row);
          // assert preds[pred] > 0 : "There should be a vote for at least one class.";
          // - collect only correct votes
          if (pred == actu) _votes[tidx]+=w;
        } else { /* regression */
          double pred = preds[0]; // regression prediction is stored in preds[0]
          double actu = cresp.atd(row);
          _sse[tidx] += (actu-pred)*(actu-pred);
        }
        // - collect rows which were used for voting
        _nrows[tidx]+=w;
        //if (_var<0) System.err.println("VARIMP OOB row: " + (cresp._start+row) + " : " + Arrays.toString(data) + " tree/actu: " + pred + "/" + actu);
      }
    }
  }

  @Override public void reduce( TreeMeasuresCollector t ) { ArrayUtils.add(_votes,t._votes); ArrayUtils.add(_nrows, t._nrows); ArrayUtils.add(_sse, t._sse); }

  public TreeVotes resultVotes() { return new TreeVotes(_votes, _nrows, _ntrees); }
  public TreeSSE   resultSSE  () { return new TreeSSE  (_sse,   _nrows, _ntrees); }

  private Random rngForTree(CompressedTree[] ts, int cidx) {
    return _oob ? ts[0].rngForChunk(cidx) : new DummyRandom(); // the k per-class trees share one RNG stream per chunk
  }

  /* For bulk scoring
  public static TreeVotes collect(TreeModel tmodel, Frame f, int ncols, float rate, int variable) {
    CompressedTree[][] trees = new CompressedTree[tmodel.ntrees()][];
    for (int tidx = 0; tidx < tmodel.ntrees(); tidx++) trees[tidx] = tmodel.ctree(tidx);
    return new TreeVotesCollector(trees, tmodel.nclasses(), ncols, rate, variable).doAll(f).result();
  }*/

  /** RNG stub used when scoring without OOB sampling: {@code nextFloat()} always returns 1.0f,
   * so the sampler treats every row as out-of-bag and all rows are scored. */
  private static final class DummyRandom extends Random {
    @Override public final float nextFloat() { return 1.0f; }
  }

  /** A simple holder for set of different tree measurements. */
  public static abstract class TreeMeasures<T extends TreeMeasures> extends Iced {
    /** Actual number of trees whose measurements are stored in this object. */
    protected int _ntrees;
    /** Number of processed rows per tree. */
    protected double[/*ntrees*/]   _nrows;

    public TreeMeasures(int initialCapacity) { _nrows = new double[initialCapacity]; }
    public TreeMeasures(double[] nrows, int ntrees) { _nrows = nrows; _ntrees = ntrees;}
    /** Returns number of rows which were used during voting per individual tree. */
    public final double[] nrows() { return _nrows; }
    /** Returns number of voting predictors */
    public final int    npredictors() { return _ntrees; }
    /** Returns the accuracy of an individual tree. */
    public abstract double accuracy(int tidx);
    public final double[] accuracy() {
      double[] r = new double[_ntrees];
      // Collect per-tree accuracies
      for (int tidx=0; tidx<_ntrees; tidx++) r[tidx] = accuracy(tidx);
      return r;
    }
    /** Compute variable importance with respect to the given measurements.
     * The given {@code T} object holds measurements over non-shuffled data;
     * this object holds measurements over shuffled data.
     *
     * @param right individual tree measurements computed over non-shuffled data.
     * @return computed importance and its standard deviation
     */
    public abstract double[/*2*/] imp(T right);

    public abstract T append(T t);
  }

  /** A class holding tree votes. */
  public static class TreeVotes extends TreeMeasures<TreeVotes> {
    /** Number of correct votes per tree */
    private double[/*ntrees*/]   _votes;

    public TreeVotes(int initialCapacity) {
      super(initialCapacity);
      _votes = new double[initialCapacity];
    }
    public TreeVotes(double[] votes, double[] nrows, int ntrees) {
      super(nrows, ntrees);
      _votes = votes;
    }
    /** Returns the number of correct votes per tree. */
    public final double[] votes() { return _votes; }

    /** Returns the accuracy of an individual tree. */
    @Override public final double accuracy(int tidx)  {
      assert tidx < _nrows.length && tidx < _votes.length;
      return (_votes[tidx]) / _nrows[tidx];
    }

    /** Compute variable importance with respect to the given votes.
     * The given {@link TreeVotes} object represents correct votes over non-shuffled data;
     * this object represents votes over shuffled data.
     *
     * @param right individual tree votes computed over non-shuffled data.
     * @return computed importance and its standard deviation
     */
    @Override public final double[/*2*/] imp(TreeVotes right) {
      assert npredictors() == right.npredictors();
      int ntrees = npredictors();
      double imp = 0;
      double sd  = 0;
      // Over all trees
      for (int tidx = 0; tidx < ntrees; tidx++) {
        assert right.nrows()[tidx] == nrows()[tidx];
        double delta = ((double) (right.votes()[tidx] - votes()[tidx])) / nrows()[tidx];
        imp += delta;
        sd  += delta * delta;
      }
      double av = imp / ntrees;
      double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
      return new double[] { av, csd};
    }
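
    // In formula form, with d_t = (rightVotes_t - shuffledVotes_t) / nrows_t for tree t, the
    // method above reports
    //   importance = mean(d_t)
    //   csd        = sqrt( (mean(d_t^2) - mean(d_t)^2) / ntrees )
    // i.e. the spread is the standard error of the mean per-tree accuracy drop (using the
    // biased 1/ntrees variance estimate).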

    /** Append one tree's votes to the list of trees. */
    public TreeVotes append(double rightVotes, double allRows) {
      assert _votes.length > _ntrees && _votes.length == _nrows.length : "TreeVotes inconsistency!";
      _votes[_ntrees] = rightVotes;
      _nrows[_ntrees] = allRows;
      _ntrees++;
      return this;
    }

    @Override public TreeVotes append(final TreeVotes tv) {
      for (int i=0; i<tv.npredictors(); i++)
        append(tv._votes[i], tv._nrows[i]);
      return this;
    }
  }

  /** A simple holder of per-tree SSE. */
  public static class TreeSSE extends TreeMeasures<TreeSSE> {
    /** SSE per tree */
    private float[/*ntrees*/]   _sse;

    public TreeSSE(int initialCapacity) {
      super(initialCapacity);
      _sse = new float[initialCapacity];
    }
    public TreeSSE(float[] sse, double[] nrows, int ntrees) {
      super(nrows, ntrees);
      _sse = sse;
    }
    /** Returns per-tree MSE (SSE divided by the number of scored rows). */
    @Override public double accuracy(int tidx) {
      return _sse[tidx] / _nrows[tidx];
    }
    @Override public double[] imp(TreeSSE right) {
      assert npredictors() == right.npredictors();
      int ntrees = npredictors();
      double imp = 0;
      double sd  = 0;
      // Over all trees
      for (int tidx = 0; tidx < ntrees; tidx++) {
        assert right.nrows()[tidx] == nrows()[tidx]; // check that we iterate over same OOB rows
        double delta = ((double) (_sse[tidx] - right._sse[tidx])) / nrows()[tidx];
        imp += delta;
        sd  += delta * delta;
      }
      double av = imp / ntrees;
      double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
      return new double[] { av, csd };
    }
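
    // Note the sign flip relative to TreeVotes.imp(): here d_t = (shuffledSSE_t - baselineSSE_t) / nrows_t,
    // i.e. the mean increase in per-row squared error after permuting the variable, so larger
    // still means more important; csd is computed the same way as for votes.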
    @Override public TreeSSE append(TreeSSE t) {
      for (int i=0; i<t.npredictors(); i++)
        append(t._sse[i], t._nrows[i]);
      return this;
    }
    /** Append one tree's SSE to the list of trees. */
    public TreeSSE append(float sse, double allRows) {
      assert _sse.length > _ntrees && _sse.length == _nrows.length : "TreeSSE inconsistency!";
      _sse  [_ntrees] = sse;
      _nrows[_ntrees] = allRows;
      _ntrees++;
      return this;
    }
  }

  public static TreeVotes asVotes(TreeMeasures tm) { return (TreeVotes) tm; }
  public static TreeSSE   asSSE  (TreeMeasures tm) { return (TreeSSE)   tm; }
}



