// hex.tree.ReconstructTreeState Maven / Gradle / Ivy
package hex.tree;
import java.util.Arrays;
import java.util.Random;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
/**
* Computes per-tree scores over all trees and rows of a chunk and
* reconstructs the per-tree prediction columns and the {@code oobt}
* column of the given frame.
*
* <p>For every tree, each scored row with a non-zero prediction has that
* prediction folded into the matching tree column
* ({@code _st.chk_tree(chks, c)}). When out-of-bag (OOB) scoring is enabled,
* only rows considered out-of-bag for a tree are scored by it, and the
* {@code oobt} column accumulates the row weight once per such tree.</p>
*/
/* package */ public class ReconstructTreeState extends DTreeScorer {
// Threshold for the OOB test. The surviving fragment of that test compares a
// random draw against it (">= _rate"). Presumably this is the per-tree row
// sample rate; TODO confirm against the corrupted line in map().
/* @IN */ final protected double _rate;
// When false, rows are scored by every tree and predictions are summed. When
// true, only OOB rows are scored (see the branch on rowIsOOB in map()).
/* @IN */ final protected boolean _OOBEnabled;
/**
 * @param ncols   number of columns copied into the per-row data buffer (forwarded to DTreeScorer)
 * @param nclass  number of classes; map() treats 1 as regression
 * @param st      the SharedTree builder (forwarded to DTreeScorer); map() reads column accessors from it
 * @param rate    threshold used by the OOB test in map()
 * @param cforest the compressed forest to score (forwarded to DTreeScorer)
 * @param oob     whether to restrict scoring to out-of-bag rows
 */
public ReconstructTreeState(int ncols, int nclass, SharedTree st, double rate, CompressedForest cforest, boolean oob) {
super(ncols,nclass,st,cforest);
_rate = rate;
_OOBEnabled = oob;
}
/**
 * Scores the rows of this chunk with every tree and writes the results back
 * into the tree-prediction columns and the {@code oobt} column.
 *
 * <p>{@code preds} has {@code _nclass + 1} slots; the prediction for class
 * {@code c} is read from {@code preds[1 + c]}. For regression
 * ({@code _nclass == 1}), {@code preds[0]} is copied into {@code preds[1]}
 * so the same loop handles it.</p>
 *
 * <p>Clears {@code _st} before returning.</p>
 */
@Override public void map(Chunk[] chks) {
double[] data = new double[_ncols];
double [] preds = new double[_nclass+1];
int ntrees = ntrees();
// Without a weight column, fall back to C0DChunk(1, len), presumably a
// constant weight of 1 for every row.
Chunk weight = _st.hasWeightCol() ? _st.chk_weight(chks) : new C0DChunk(1, chks[0]._len);
Chunk oobt = _st.chk_oobt(chks);
// NOTE(review): resp is not used in the visible code; presumably it was used
// in the part of the next line that was lost.
Chunk resp = _st.chk_resp(chks);
// NOTE(review): the next line is corrupted and does not compile. Text
// between "tidx<ntrees..." and "...>= _rate" looks like it was stripped as
// HTML markup. Missing pieces:
//   - the tree-loop header;
//   - the row loop over this chunk;
//   - the declarations of rng, row and w (used below);
//   - the start of the rowIsOOB expression, which ends in ">= _rate".
// Restore it from the upstream h2o-3 source rather than reconstructing by hand.
// The four closing braces after the scoring branch imply the two lost loops.
for( int tidx=0; tidx= _rate;
if( !_OOBEnabled || rowIsOOB) {
// Make a prediction
for (int i=0;i<_ncols;i++) data[i] = chks[i].atd(row);
Arrays.fill(preds, 0);
score0(data, preds, tidx);
if (_nclass==1) preds[1]=preds[0]; // Only for regression, keep consistency
// Write tree predictions
for (int c=0;c<_nclass;c++) { // over all class
double prediction = preds[1+c];
// Only non-zero predictions are written; a zero leaves the column unchanged.
if (preds[1+c] != 0) {
Chunk ctree = _st.chk_tree(chks, c);
// Weight accumulated in oobt so far for this row, before this tree adds w below.
double wcount = oobt.atd(row);
// OOB classification: keep a running average. The previous value is
// scaled by wcount, the new prediction is added, and the sum is divided
// by the new total weight.
// The (float) cast binds to the numerator only; the division is done in double.
// NOTE(review): the prediction is not multiplied by w. This is a true
// weighted mean only when w == 1. Confirm whether prediction*w was intended.
if (_OOBEnabled && _nclass >= 2)
ctree.set(row, (float) (ctree.atd(row)*wcount + prediction)/(wcount+w)); //store avg prediction
else
// Regression, or OOB disabled: predictions are summed across trees.
ctree.set(row, (float) (ctree.atd(row) + prediction));
}
}
// Mark oob row and add this row's weight to oobt. After all trees, oobt
// holds the summed weight of the trees voting for this row (the tree count
// when w == 1).
if (rowIsOOB)
oobt.set(row, oobt.atd(row)+w);
}
}
}
// Drop the SharedTree reference once this chunk is done. Presumably this keeps
// it out of the serialized task result; TODO confirm.
_st = null;
}
/**
 * Returns the random generator used for the OOB test of one tree on one chunk.
 *
 * @param ts   the trees built for one iteration (presumably one per class)
 * @param cidx index of the chunk being processed
 * @return {@code ts[0].rngForChunk(cidx)}, because all trees in the set share one generator
 */
private Random rngForTree(CompressedTree[] ts, int cidx) {
return ts[0].rngForChunk(cidx); // k-class set of trees shares the same random number
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy