package hex.tree.drf;
import java.util.Arrays;
import java.util.Random;
import static hex.genmodel.GenModel.getPrediction;
import hex.tree.CompressedForest;
import hex.tree.CompressedTree;
import hex.tree.SharedTree;
import water.Iced;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.ModelUtils;
import static water.util.RandomUtils.getRNG;
/** Score a given tree model and preserve per-tree errors in the form of correct votes (for classification)
* or sum of squared errors (for regression).
*
* This differs from the Model.score() function in that the MR task iterates in inverse order:
* first over all trees, then over all rows in a chunk.
* A usage sketch appears in a comment block further below.
*/
public class TreeMeasuresCollector extends MRTask<TreeMeasuresCollector> {
/* @IN */ final private CompressedForest _cforest;
/* @IN */ final private float _rate;
/* @IN */ final private int _var;
/* @IN */ final private boolean _oob;
/* @IN */ final private int _ncols;
/* @IN */ final private int _nclasses;
/* @IN */ final private boolean _classification;
/* @IN */ final private double _threshold;
final private SharedTree _st;
/* @INOUT */ private final int _ntrees;
/* @OUT */ private double [/*ntrees*/] _votes; // Number of correct votes per tree (for classification only)
/* @OUT */ private double [/*ntrees*/] _nrows; // Number of scored rows per tree (for classification/regression)
/* @OUT */ private float[/*ntrees*/] _sse; // Sum of squared errors per tree (for regression only)
/* Intermediate */
private transient CompressedForest.LocalCompressedForest _forest;
private TreeMeasuresCollector(CompressedForest cforest, int nclasses, int ncols, float rate, int variable, double threshold, SharedTree st) {
assert cforest._treeKeys.length > 0;
assert nclasses == cforest._treeKeys[0].length;
_cforest = cforest;
_ncols = ncols;
_rate = rate; _var = variable;
_oob = true; _ntrees = cforest._treeKeys.length;
_nclasses = nclasses;
_classification = (nclasses>1);
_threshold = threshold;
_st = st;
}
@Override
protected void setupLocal() {
_forest = _cforest.fetch();
}
public static class ShuffleTask extends MRTask<ShuffleTask> {
@Override public void map(Chunk ic, Chunk oc) {
if (ic._len==0) return;
// Each vector is shuffled in the same way
Random rng = getRNG(seed(ic.cidx()));
oc.set(0,ic.atd(0));
for (int row=1; row<ic._len; row++) { // inside-out Fisher-Yates shuffle: oc becomes a random permutation of ic
int j = rng.nextInt(row+1); // inclusive upper bound <0,row>
// Arghhh: expand the vector into double
if (j!=row) oc.set(row, oc.atd(j));
oc.set(j, ic.atd(row));
}
}
public static long seed(int cidx) { return (0xe031e74f321f7e29L + ((long)cidx << 32L)); }
public static Vec shuffle(Vec ivec) {
Vec ovec = ivec.makeZero();
new ShuffleTask().doAll(ivec, ovec);
return ovec;
}
}
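/* Illustrative usage sketch (not part of the original sources): ShuffleTask.shuffle produces
a row-permuted copy of a single column, which is the building block for permutation-based
variable importance. The frame `fr` and the column name "x" are assumed names, shown only to
demonstrate the call pattern.
Vec original = fr.vec("x");                    // one predictor column
Vec shuffled = ShuffleTask.shuffle(original);  // new Vec, permuted chunk by chunk
// ... score against the shuffled column, then release the temporary Vec
shuffled.remove();
*/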
@Override public void map(Chunk[] chks) {
double[] data = new double[_ncols];
double[] preds = new double[_nclasses+1];
Chunk cresp = _st.chk_resp(chks);
Chunk weights = _st.hasWeightCol() ? _st.chk_weight(chks) : new C0DChunk(1, chks[0]._len);
int nrows = cresp._len;
int [] oob = new int[2+Math.round((1f-_rate)*nrows*1.2f+0.5f)]; // preallocate
int [] soob = null;
// Prepare output data
_nrows = new double[_ntrees];
_votes = _classification ? new double[_ntrees] : null;
_sse = _classification ? null : new float[_ntrees];
long seedForOob = ShuffleTask.seed(cresp.cidx()); // seed for shuffling oob samples
// Start iteration
for( int tidx=0; tidx<_ntrees; tidx++) { // tree
// OOB RNG for this tree
Random rng = rngForTree(_forest._trees[tidx], cresp.cidx());
// Collect OOB rows and permute them
oob = ModelUtils.sampleOOBRows(nrows, _rate, rng, oob); // reuse the same array for sampling
int oobcnt = oob[0]; // Get number of sample rows
if (_var>=0) {
if (soob==null || soob.length < oobcnt) soob = new int[oobcnt];
ArrayUtils.shuffleArray(oob, oobcnt, soob, seedForOob, 1); // Shuffle array and copy results into soob
}
for(int j = 1; j < 1+oobcnt; j++) {
int row = oob[j];
double w = weights.atd(row);
if (cresp.isNA(row)) continue; // we cannot deal with this row anyhow
if (w==0) continue;
// Do scoring:
// - prepare a row data
for (int i=0;i<_ncols;i++) data[i] = chks[i].atd(row); // 1+i - one free is expected by prediction
// - permute variable
if (_var>=0) data[_var] = chks[_var].atd(soob[j-1]);
else assert soob==null;
// - score data
Arrays.fill(preds, 0);
// - score only the tree
_forest.scoreTree(data, preds, tidx);
// - derive a prediction
if (_classification) {
int pred = getPrediction(preds, null /*FIXME: should use model's _priorClassDistribution*/, data, _threshold);
int actu = (int) cresp.at8(row);
// assert preds[pred] > 0 : "There should be a vote for at least one class.";
// - collect only correct votes
if (pred == actu) _votes[tidx]+=w;
} else { /* regression */
double pred = preds[0]; // Important!
double actu = cresp.atd(row);
_sse[tidx] += (actu-pred)*(actu-pred);
}
// - collect rows which were used for voting
_nrows[tidx]+=w;
//if (_var<0) System.err.println("VARIMP OOB row: " + (cresp._start+row) + " : " + Arrays.toString(data) + " tree/actu: " + pred + "/" + actu);
}
}
}
@Override public void reduce( TreeMeasuresCollector t ) { ArrayUtils.add(_votes,t._votes); ArrayUtils.add(_nrows, t._nrows); ArrayUtils.add(_sse, t._sse); }
public TreeVotes resultVotes() { return new TreeVotes(_votes, _nrows, _ntrees); }
public TreeSSE resultSSE () { return new TreeSSE (_sse, _nrows, _ntrees); }
private Random rngForTree(CompressedTree[] ts, int cidx) {
return _oob ? ts[0].rngForChunk(cidx) : new DummyRandom(); // the k-class set of trees shares the same random number generator
}
/* For bulk scoring
public static TreeVotes collect(TreeModel tmodel, Frame f, int ncols, float rate, int variable) {
CompressedTree[][] trees = new CompressedTree[tmodel.ntrees()][];
for (int tidx = 0; tidx < tmodel.ntrees(); tidx++) trees[tidx] = tmodel.ctree(tidx);
return new TreeVotesCollector(trees, tmodel.nclasses(), ncols, rate, variable).doAll(f).result();
}*/
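/* Illustrative usage sketch (assumed caller code, not part of this class): the collector is
typically driven once per measurement, iterating tree-first as described in the class javadoc.
Here `cforest`, `st` and `fr` are assumed to be the compressed forest, the owning SharedTree
builder and the adapted training frame (predictors, response, optional weights); a non-negative
`variable` index selects the predictor column to permute, while -1 measures the baseline.
TreeMeasuresCollector tmc =
    new TreeMeasuresCollector(cforest, nclasses, ncols, sampleRate, variable, threshold, st)
        .doAll(fr);
TreeVotes votes = tmc.resultVotes();  // classification: correct votes per tree
// TreeSSE sse = tmc.resultSSE();     // regression: sum of squared errors per tree
*/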
private static final class DummyRandom extends Random {
@Override public final float nextFloat() { return 1.0f; }
}
/** A simple holder for set of different tree measurements. */
public static abstract class TreeMeasures<T extends TreeMeasures> extends Iced {
/** Actual number of trees whose votes are stored in this object */
protected int _ntrees;
/** Number of processed rows per tree. */
protected double[/*ntrees*/] _nrows;
public TreeMeasures(int initialCapacity) { _nrows = new double[initialCapacity]; }
public TreeMeasures(double[] nrows, int ntrees) { _nrows = nrows; _ntrees = ntrees;}
/** Returns the number of rows used during voting, per individual tree. */
public final double[] nrows() { return _nrows; }
/** Returns number of voting predictors */
public final int npredictors() { return _ntrees; }
/** Returns the accuracy of an individual tree. */
public abstract double accuracy(int tidx);
public final double[] accuracy() {
double[] r = new double[_ntrees];
// Collect per-tree accuracies
for (int tidx=0; tidx<_ntrees; tidx++) r[tidx] = accuracy(tidx);
return r;
}
/** Compute variable importance with respect to the given measurements.
* The given {@link T} object represents measurements over non-shuffled data (correct votes),
* while this object represents measurements over shuffled data.
* A worked sketch of the computation appears in the comment block after this class.
*
* @param right individual tree measurements performed over non-shuffled data.
* @return computed importance and its standard deviation
*/
public abstract double[/*2*/] imp(T right);
public abstract T append(T t);
}
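/* Worked sketch of the imp() computation implemented by the subclasses below (numbers are
assumed, for illustration only). For each tree t the per-row delta is
    delta_t = (right_t - this_t) / nrows_t     (TreeVotes: drop in per-row accuracy)
    delta_t = (this_t - right_t) / nrows_t     (TreeSSE:   increase in per-row squared error)
The reported importance is the mean of the deltas and the reported deviation is the standard
error of that mean:
    imp = (1/ntrees) * sum_t delta_t
    sd  = sqrt( (sum_t delta_t^2 / ntrees - imp^2) / ntrees )
Example: two trees with deltas 0.10 and 0.06 give imp = 0.08 and
sd = sqrt((0.0068 - 0.0064) / 2) ~ 0.014.
*/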
/** A class holding tree votes. */
public static class TreeVotes extends TreeMeasures<TreeVotes> {
/** Number of correct votes per tree */
private double[/*ntrees*/] _votes;
public TreeVotes(int initialCapacity) {
super(initialCapacity);
_votes = new double[initialCapacity];
}
public TreeVotes(double[] votes, double[] nrows, int ntrees) {
super(nrows, ntrees);
_votes = votes;
}
/** Returns number of positive votes per tree. */
public final double[] votes() { return _votes; }
/** Returns the accuracy of an individual tree. */
@Override public final double accuracy(int tidx) {
assert tidx < _nrows.length && tidx < _votes.length;
return (_votes[tidx]) / _nrows[tidx];
}
/** Compute variable importance with respect to the given votes.
* The given {@link TreeVotes} object represents correct votes over non-shuffled data,
* while this object represents votes over shuffled data.
*
* @param right individual tree votes collected over non-shuffled data.
* @return computed importance and its standard deviation
*/
@Override public final double[/*2*/] imp(TreeVotes right) {
assert npredictors() == right.npredictors();
int ntrees = npredictors();
double imp = 0;
double sd = 0;
// Over all trees
for (int tidx = 0; tidx < ntrees; tidx++) {
assert right.nrows()[tidx] == nrows()[tidx];
double delta = ((double) (right.votes()[tidx] - votes()[tidx])) / nrows()[tidx];
imp += delta;
sd += delta * delta;
}
double av = imp / ntrees;
double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
return new double[] { av, csd};
}
/** Append a tree votes to a list of trees. */
public TreeVotes append(double rightVotes, double allRows) {
assert _votes.length > _ntrees && _votes.length == _nrows.length : "TreeVotes inconsistency!";
_votes[_ntrees] = rightVotes;
_nrows[_ntrees] = allRows;
_ntrees++;
return this;
}
@Override public TreeVotes append(final TreeVotes tv) {
for (int i=0; i<tv.npredictors(); i++)
append(tv.votes()[i], tv.nrows()[i]);
return this;
}
}
/** A simple holder serving SSE per tree. */
public static class TreeSSE extends TreeMeasures<TreeSSE> {
/** SSE per tree */
private float[/*ntrees*/] _sse;
public TreeSSE(int initialCapacity) {
super(initialCapacity);
_sse = new float[initialCapacity];
}
public TreeSSE(float[] sse, double[] nrows, int ntrees) {
super(nrows, ntrees);
_sse = sse;
}
@Override public double accuracy(int tidx) {
return _sse[tidx] / _nrows[tidx];
}
@Override public double[] imp(TreeSSE right) {
assert npredictors() == right.npredictors();
int ntrees = npredictors();
double imp = 0;
double sd = 0;
// Over all trees
for (int tidx = 0; tidx < ntrees; tidx++) {
assert right.nrows()[tidx] == nrows()[tidx]; // check that we iterate over same OOB rows
double delta = ((double) (_sse[tidx] - right._sse[tidx])) / nrows()[tidx];
imp += delta;
sd += delta * delta;
}
double av = imp / ntrees;
double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
return new double[] { av, csd };
}
@Override public TreeSSE append(TreeSSE t) {
for (int i=0; i<t.npredictors(); i++)
append(t._sse[i], t._nrows[i]);
return this;
}
public TreeSSE append(float sse, double allRows) {
assert _sse.length > _ntrees && _sse.length == _nrows.length : "TreeSSE inconsistency!";
_sse [_ntrees] = sse;
_nrows[_ntrees] = allRows;
_ntrees++;
return this;
}
}
public static TreeVotes asVotes(TreeMeasures tm) { return (TreeVotes) tm; }
public static TreeSSE asSSE (TreeMeasures tm) { return (TreeSSE) tm; }
}