All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.drf.OOBScorer Maven / Gradle / Ivy

package hex.tree.drf;

import java.util.Arrays;
import java.util.Random;

import hex.tree.*;
import hex.tree.DTreeScorer;
import water.*;
import water.fvec.Chunk;

/**
 * Computing oob scores over all trees and rows
 * and reconstructing ntree_id, oobt fields in given frame.
 *
 * 

It prepares voter per tree and also marks * rows which were consider out-of-bag.

*/ /* package */ class OOBScorer extends DTreeScorer { /* @IN */ final protected float _rate; public OOBScorer(int ncols, int nclass, float rate, Key[][] treeKeys) { super(ncols,nclass,treeKeys); _rate = rate; } @Override public void map(Chunk[] chks) { double[] data = new double[_ncols]; double [] preds = new double[_nclass+1]; int ntrees = _trees.length; Chunk coobt = chk_oobt(chks); Chunk cys = chk_resp(chks); for( int tidx=0; tidx= _rate || Double.isNaN(cys.atd(row)) ) { // Make a prediction for (int i=0;i<_ncols;i++) data[i] = chks[i].atd(row); Arrays.fill(preds, 0); score0(data, preds, _trees[tidx]); if (_nclass==1) preds[1]=preds[0]; // Only for regression, keep consistency // Write tree predictions for (int c=0;c<_nclass;c++) { // over all class double prediction = preds[1+c]; if (preds[1+c] != 0) { Chunk ctree = chk_tree(chks, c); long count = coobt.at8(row); if (_nclass >= 2) ctree.set(row, (float) (ctree.atd(row)*count + prediction)/(count+1)); //store avg prediction else ctree.set(row, (float) (ctree.atd(row) + prediction)); } } // Mark oob row and store number of trees voting for this row (only for regression) coobt.set(row, coobt.atd(row)+1); } } } } private Random rngForTree(CompressedTree[] ts, int cidx) { return ts[0].rngForChunk(cidx); // k-class set of trees shares the same random number } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy