package hex.tree;
import hex.*;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLMModel;
import hex.util.LinearAlgebraUtils;
import org.apache.log4j.Logger;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import static hex.genmodel.GenModel.createAuxKey;
import static hex.genmodel.algos.tree.SharedTreeMojoModel.__INTERNAL_MAX_TREE_DEPTH;
import static hex.tree.SharedTree.createModelSummaryTable;
public abstract class SharedTreeModel<
M extends SharedTreeModel<M, P, O>,
P extends SharedTreeModel.SharedTreeParameters,
O extends SharedTreeModel.SharedTreeOutput
> extends Model<M, P, O> implements Model.LeafNodeAssignment, Model.GetMostImportantFeatures, Model.FeatureFrequencies, Model.UpdateAuxTreeWeights {
private static final Logger LOG = Logger.getLogger(SharedTreeModel.class);
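// Usage sketch (hypothetical caller, not part of this class): for a trained model `m`,
//   String[] top5 = m.getMostImportantFeatures(5);
// returns up to n column names ordered by variable importance, or null if importances were not computed.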
@Override
public String[] getMostImportantFeatures(int n) {
if (_output == null) return null;
TwoDimTable vi = _output._variable_importances;
if (vi==null) return null;
n = Math.min(n, vi.getRowHeaders().length);
String[] res = new String[n];
System.arraycopy(vi.getRowHeaders(), 0, res, 0, n);
return res;
}
@Override public ToEigenVec getToEigenVec() { return LinearAlgebraUtils.toEigen; }
public abstract static class SharedTreeParameters extends Model.Parameters implements Model.GetNTrees, CalibrationHelper.ParamsWithCalibration {
public int _ntrees=50; // Number of trees in the final model. Grid Search, comma sep values:50,100,150,200
public int _max_depth = 5; // Maximum tree depth. Grid Search, comma sep values:5,7
public double _min_rows = 10; // Fewest allowed observations in a leaf (in R called 'nodesize'). Grid Search, comma sep values
public int _nbins = 20; // Numerical (real/int) cols: Build a histogram of this many bins, then split at the best point
public int _nbins_cats = 1024; // Categorical (factor) cols: Build a histogram of this many bins, then split at the best point
public double _min_split_improvement = 1e-5; // Minimum relative improvement in squared error reduction for a split to happen
public enum HistogramType {
AUTO, UniformAdaptive, Random, QuantilesGlobal, RoundRobin, UniformRobust;
public static HistogramType[] ROUND_ROBIN_CANDIDATES = {
AUTO, // Note: the inclusion of AUTO means UniformAdaptive has effectively higher chance of being used
UniformAdaptive, Random, QuantilesGlobal
};
}
public HistogramType _histogram_type = HistogramType.AUTO; // What type of histogram to use for finding optimal split points
public double _r2_stopping = Double.MAX_VALUE; // Stop when the r^2 metric equals or exceeds this value
public int _nbins_top_level = 1<<10; //hardcoded maximum top-level number of bins for real-valued columns
public boolean _build_tree_one_node = false;
public int _score_tree_interval = 0; // score every so many trees (no matter what)
public int _initial_score_interval = 4000; // Scoring interval (ms) during the initial phase; replaces the previously hard-coded 4000 ms (score during the first 4 seconds)
public int _score_interval = 4000; // Scoring interval (ms) between subsequent scoring iterations; replaces the previously hard-coded 4000 ms
public double _sample_rate = 0.632; //fraction of rows to sample for each tree
public double[] _sample_rate_per_class; //fraction of rows to sample for each tree, per class
public boolean useRowSampling() {
return _sample_rate < 1 || _sample_rate_per_class != null;
}
// Platt scaling (by default)
public boolean _calibrate_model;
public Key<Frame> _calibration_frame;
public CalibrationHelper.CalibrationMethod _calibration_method = CalibrationHelper.CalibrationMethod.AUTO;
@Override public long progressUnits() { return _ntrees + (_histogram_type==HistogramType.QuantilesGlobal || _histogram_type==HistogramType.RoundRobin ? 1 : 0); }
public double _col_sample_rate_change_per_level = 1.0f; //relative change of the column sampling rate for every level
public double _col_sample_rate_per_tree = 1.0f; //fraction of columns to sample for each tree
public boolean useColSampling() {
return _col_sample_rate_change_per_level != 1.0f || _col_sample_rate_per_tree != 1.0f;
}
public boolean isStochastic() {
return useRowSampling() || useColSampling();
}
public boolean _parallel_main_model_building = false;
public boolean _use_best_cv_iteration = true; // when early stopping is enabled, cv models will pick the iteration that produced the best score instead of the stopping iteration
public String _in_training_checkpoints_dir;
public int _in_training_checkpoints_tree_interval = 1; // save model checkpoint every so many trees (no matter what)
/** Fields which can NOT be modified if checkpoint is specified.
* FIXME: should be defined in Schema API annotation
*/
static final String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = { "_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};
@Override
public int getNTrees() {
return _ntrees;
}
@Override
public Frame getCalibrationFrame() {
return _calibration_frame == null ? null : _calibration_frame.get();
}
@Override
public boolean calibrateModel() {
return _calibrate_model;
}
@Override
public CalibrationHelper.CalibrationMethod getCalibrationMethod() {
return _calibration_method;
}
@Override
public void setCalibrationMethod(CalibrationHelper.CalibrationMethod calibrationMethod) {
_calibration_method = calibrationMethod;
}
@Override
public Parameters getParams() {
return this;
}
/**
* Do we need to enable a strictly deterministic way of building histograms?
*
* Used e.g. when monotone constraints are enabled in GBM; disabled by default
*
* @return true if histograms should be built in deterministic way
*/
public boolean forceStrictlyReproducibleHistograms() {
return false;
}
}
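// Configuration sketch for the parameters above (hypothetical; assumes the concrete GBM subclass
// and a training frame `train` with a response column named "response"):
//   GBMModel.GBMParameters p = new GBMModel.GBMParameters();
//   p._train = train._key;
//   p._response_column = "response";
//   p._ntrees = 100;          // request 100 trees
//   p._max_depth = 6;         // cap tree depth
//   p._sample_rate = 0.8;     // row sampling -> p.isStochastic() == true
//   p._histogram_type = SharedTreeParameters.HistogramType.QuantilesGlobal;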
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
switch(_output.getModelCategory()) {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, _parms._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
default: throw H2O.unimpl();
}
}
public abstract static class SharedTreeOutput extends Model.Output implements Model.GetNTrees, CalibrationHelper.OutputWithCalibration {
/** InitF value (for zero trees)
* f0 = mean(yi) for gaussian
* f0 = log(ybar/(1-ybar)), where ybar = mean(yi), for bernoulli
*
* For GBM bernoulli, the initial prediction for 0 trees is
* p = 1/(1+exp(-f0))
*
* From this, the MSE for 0 trees (the null model) can be computed as:
* mean((yi-p)^2)
* */
public double _init_f;
/** Number of trees actually in the model (as opposed to requested) */
public int _ntrees;
/** More in-depth tree stats */
public final TreeStats _treeStats;
/** Trees get big, so store each one separately in the DKV. */
public Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
public Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeysAux;
public ScoreKeeper[/*ntrees+1*/] _scored_train;
public ScoreKeeper[/*ntrees+1*/] _scored_valid;
public ScoreKeeper[] scoreKeepers() {
ArrayList<ScoreKeeper> skl = new ArrayList<>();
ScoreKeeper[] ska = _validation_metrics != null ? _scored_valid : _scored_train;
for( ScoreKeeper sk : ska )
if (!sk.isEmpty())
skl.add(sk);
return skl.toArray(new ScoreKeeper[skl.size()]);
}
/** Training time */
public long[/*ntrees+1*/] _training_time_ms = {System.currentTimeMillis()};
/**
* Variable importances computed during training
*/
public TwoDimTable _variable_importances;
public VarImp _varimp;
@Override
public TwoDimTable getVariableImportances() {
return _variable_importances;
}
public Model<?, ?, ?> _calib_model;
public SharedTreeOutput( SharedTree b) {
super(b);
_ntrees = 0; // No trees yet
_treeKeys = new Key[_ntrees][]; // No tree keys yet
_treeKeysAux = new Key[_ntrees][]; // No tree keys yet
_treeStats = new TreeStats();
_scored_train = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
_scored_valid = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
_modelClassDist = _priorClassDist;
}
@Override
public TwoDimTable createInputFramesInformationTable(ModelBuilder modelBuilder) {
SharedTreeParameters params = (SharedTreeParameters) modelBuilder._parms;
TwoDimTable table = super.createInputFramesInformationTable(modelBuilder);
table.set(2, 0, "calibration_frame");
table.set(2, 1, params.getCalibrationFrame() != null ? params.getCalibrationFrame().checksum() : -1);
table.set(2, 2, params.getCalibrationFrame() != null ? Arrays.toString(params.getCalibrationFrame().anyVec().espc()) : -1);
return table;
}
@Override
public int getInformationTableNumRows() {
return super.getInformationTableNumRows() + 1;// +1 row for calibration frame
}
// Append next set of K trees
public void addKTrees( DTree[] trees) {
// DEBUG: Print the generated K trees
//SharedTree.printGenerateTrees(trees);
assert nclasses()==trees.length;
// Compress trees and record tree-keys
_treeKeys = Arrays.copyOf(_treeKeys ,_ntrees+1);
_treeKeysAux = Arrays.copyOf(_treeKeysAux ,_ntrees+1);
Key[] keys = _treeKeys[_ntrees] = new Key[trees.length];
Key[] keysAux = _treeKeysAux[_ntrees] = new Key[trees.length];
Futures fs = new Futures();
for( int i=0; i<nclasses(); i++ ) if( trees[i] != null ) {
CompressedTree ct = trees[i].compress(_ntrees, i, _domains);
DKV.put(keys[i] = ct._key, ct, fs);
_treeStats.updateBy(trees[i]); // Update tree shape stats
CompressedTree ctAux = new CompressedTree(trees[i]._abAux.buf(), -1, _ntrees, i);
keysAux[i] = ctAux._key = Key.make(createAuxKey(ct._key.toString()));
DKV.put(ctAux, fs);
}
_ntrees++;
// 1-based for errors; _scored_train[0] is for zero trees, not 1 tree
_scored_train = ArrayUtils.copyAndFillOf(_scored_train, _ntrees+1, new ScoreKeeper());
_scored_valid = _scored_valid != null ? ArrayUtils.copyAndFillOf(_scored_valid, _ntrees+1, new ScoreKeeper()) : null;
_training_time_ms = ArrayUtils.copyAndFillOf(_training_time_ms, _ntrees+1, System.currentTimeMillis());
fs.blockForPending();
}
@Override
public int getNTrees() {
return _ntrees;
}
@Override
public Model<?, ?, ?> calibrationModel() {
return _calib_model;
}
@Override
public void setCalibrationModel(Model, ?, ?> model) {
_calib_model = model;
}
public CompressedTree ctree(int tnum, int knum ) { return _treeKeys[tnum][knum].get(); }
public String toStringTree ( int tnum, int knum ) { return ctree(tnum,knum).toString(this); }
}
public SharedTreeModel(Key selfKey, P parms, O output) {
super(selfKey, parms, output);
}
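// Column-naming scheme (see makeAllTreeColumnNames below): one column per (tree, class) pair that
// holds a tree. E.g. a 3-class model yields "T1.C1", "T1.C2", "T1.C3", "T2.C1", ...; models with a
// single tree per iteration (regression, binomial) yield just "T1", "T2", ...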
protected String[] makeAllTreeColumnNames() {
int classTrees = 0;
for (int i = 0; i < _output._treeKeys[0].length; ++i) {
if (_output._treeKeys[0][i] != null) classTrees++;
}
final int outputcols = _output._treeKeys.length * classTrees;
final String[] names = new String[outputcols];
int col = 0;
for (int tidx = 0; tidx < _output._treeKeys.length; tidx++) {
Key[] keys = _output._treeKeys[tidx];
for (int c = 0; c < keys.length; c++) {
if (keys[c] != null) {
names[col++] = "T" + (tidx + 1) + (keys.length == 1 ? "" : (".C" + (c + 1)));
}
}
}
return names;
}
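/**
* Usage sketch (hypothetical caller; assumes a trained model {@code model} and a scoring frame {@code fr}):
* assigns each observation its leaf node per tree, either as an "LR..." decision-path string or as a numeric node id.
* <pre>{@code
* Frame paths = model.scoreLeafNodeAssignment(fr, LeafNodeAssignmentType.Path, Key.make("paths"));
* Frame ids   = model.scoreLeafNodeAssignment(fr, LeafNodeAssignmentType.Node_ID, Key.make("ids"));
* }</pre>
*/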
@Override
public Frame scoreLeafNodeAssignment(Frame frame, LeafNodeAssignmentType type, Key<Frame> destination_key) {
Frame adaptFrm = new Frame(frame);
adaptTestForTrain(adaptFrm, true, false);
final String[] names = makeAllTreeColumnNames();
AssignLeafNodeTaskBase task = AssignLeafNodeTaskBase.make(_output, type);
return task.execute(adaptFrm, names, destination_key);
}
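/**
* Usage sketch (hypothetical caller; assumes {@code fr} carries a weight column "w"): recomputes the
* per-leaf aggregated weights from the given data and rewrites the auxiliary trees in the DKV.
* <pre>{@code
* UpdateAuxTreeWeightsReport report = model.updateAuxTreeWeights(fr, "w");
* if (report._warn_trees.length > 0) {
*   // some leaves received zero total weight
* }
* }</pre>
*/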
@Override
public UpdateAuxTreeWeightsReport updateAuxTreeWeights(Frame frame, String weightsColumn) {
if (weightsColumn == null) {
throw new IllegalArgumentException("Weights column name is not defined");
}
Frame adaptFrm = new Frame(frame);
Vec weights = adaptFrm.remove(weightsColumn);
if (weights == null) {
throw new IllegalArgumentException("Input frame doesn't contain weights column `" + weightsColumn + "`");
}
adaptTestForTrain(adaptFrm, true, false);
// keep features only and re-introduce weights column at the end of the frame
Frame featureFrm = new Frame(_output.features(), frame.vecs(_output.features()));
featureFrm.add(weightsColumn, weights);
UpdateAuxTreeWeightsTask t = new UpdateAuxTreeWeightsTask(_output).doAll(featureFrm);
UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport report = new UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport();
report._warn_trees = t._warnTrees;
report._warn_classes = t._warnClasses;
return report;
}
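// Allocation-free decision-path tracker (one instance is reused across the rows of a map() call):
// each level of the tree walk writes 'L' or 'R' into a fixed buffer, _pos remembers the last
// right-turn, and terminate() sets the length of the shared BufferedString instead of building a
// new String per row.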
public static class BufStringDecisionPathTracker implements SharedTreeMojoModel.DecisionPathTracker<BufferedString> {
private final byte[] _buf = new byte[__INTERNAL_MAX_TREE_DEPTH];
private final BufferedString _bs = new BufferedString(_buf, 0, 0);
private int _pos = 0;
@Override
public boolean go(int depth, boolean right) {
_buf[depth] = right ? (byte) 'R' : (byte) 'L';
if (right) _pos = depth;
return true;
}
@Override
public BufferedString terminate() {
_bs.setLen(_pos);
_pos = 0;
return _bs;
}
@Override
public BufferedString invalidPath() {
return null;
}
}
private static abstract class AssignLeafNodeTaskBase extends MRTask<AssignLeafNodeTaskBase> {
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
final String[][] _domains;
AssignLeafNodeTaskBase(SharedTreeOutput output) {
_treeKeys = output._treeKeys;
_domains = output._domains;
}
protected abstract void initMap();
protected abstract void assignNode(final int tidx, final int cls, final CompressedTree tree, final double[] input,
final NewChunk out);
@Override
public void map(Chunk[] chks, NewChunk[] ncs) {
double[] input = new double[chks.length];
initMap();
for (int row = 0; row < chks[0]._len; row++) {
for (int i = 0; i < chks.length; i++)
input[i] = chks[i].atd(row);
int col = 0;
for (int tidx = 0; tidx < _treeKeys.length; tidx++) {
Key[] keys = _treeKeys[tidx];
for (int cls = 0; cls < keys.length; cls++) {
Key key = keys[cls];
if (key != null) {
CompressedTree tree = DKV.get(key).get();
assignNode(tidx, cls, tree, input, ncs[col++]);
}
}
}
assert (col == ncs.length);
}
}
protected abstract Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey);
private static AssignLeafNodeTaskBase make(SharedTreeOutput modelOutput, LeafNodeAssignmentType type) {
switch (type) {
case Path:
return new AssignTreePathTask(modelOutput);
case Node_ID:
return new AssignLeafNodeIdTask(modelOutput);
default:
throw new UnsupportedOperationException("Unknown leaf node assignment type: " + type);
}
}
}
private static class AssignTreePathTask extends AssignLeafNodeTaskBase {
private transient BufStringDecisionPathTracker _tr;
private AssignTreePathTask(SharedTreeOutput output) {
super(output);
}
@Override
protected void initMap() {
_tr = new BufStringDecisionPathTracker();
}
@Override
protected void assignNode(int tidx, int cls, CompressedTree tree, double[] input,
NewChunk nc) {
BufferedString pred = tree.getDecisionPath(input, _domains, _tr);
nc.addStr(pred);
}
@Override
protected Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey) {
Frame res = doAll(names.length, Vec.T_STR, adaptFrm).outputFrame(destKey, names, null);
// convert to categorical
Vec vv;
Vec[] nvecs = new Vec[res.vecs().length];
boolean hasInvalidPaths = false;
for (int c = 0; c < res.vecs().length; ++c) {
vv = res.vec(c);
try {
hasInvalidPaths = hasInvalidPaths || vv.naCnt() > 0;
nvecs[c] = vv.toCategoricalVec();
} catch (Exception e) {
VecUtils.deleteVecs(nvecs, c);
throw e;
}
}
res.delete();
res = new Frame(destKey, names, nvecs);
if (destKey != null) {
DKV.put(res);
}
if (hasInvalidPaths) {
LOG.warn("Some of the leaf node assignments were skipped (NA), " +
"only tree-paths up to length 64 are supported.");
}
return res;
}
}
private static class AssignLeafNodeIdTask extends AssignLeafNodeTaskBase {
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _auxTreeKeys;
private AssignLeafNodeIdTask(SharedTreeOutput output) {
super(output);
_auxTreeKeys = output._treeKeysAux;
}
@Override
protected void initMap() {
}
@Override
protected void assignNode(int tidx, int cls, CompressedTree tree, double[] input, NewChunk nc) {
CompressedTree auxTree = _auxTreeKeys[tidx][cls].get();
assert auxTree != null;
final double d = SharedTreeMojoModel.scoreTree(tree._bits, input, true, _domains);
final int nodeId = SharedTreeMojoModel.getLeafNodeId(d, auxTree._bits);
nc.addNum(nodeId, 0);
}
@Override
protected Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey) {
Frame result = doAll(names.length, Vec.T_NUM, adaptFrm).outputFrame(destKey, names, null);
if (result.vec(0).min() < 0) {
LOG.warn("Some of the observations were not assigned a Leaf Node ID (-1), " +
"only tree-paths up to length 64 are supported.");
}
return result;
}
}
private static class UpdateAuxTreeWeightsTask extends MRTask<UpdateAuxTreeWeightsTask> {
// IN
private final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
private final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _auxTreeKeys;
private final String[][] _domains;
// WORKING
private transient int[/*treeId*/][/*classId*/] _maxNodeIds;
// OUT
private double[/*treeId*/][/*classId*/][/*leafNodeId*/] _leafNodeWeights;
private int[] _warnTrees;
private int[] _warnClasses;
private UpdateAuxTreeWeightsTask(SharedTreeOutput output) {
_treeKeys = output._treeKeys;
_auxTreeKeys = output._treeKeysAux;
_domains = output._domains;
}
@Override
protected void setupLocal() {
_maxNodeIds = new int[_auxTreeKeys.length][];
for (int treeId = 0; treeId < _auxTreeKeys.length; treeId++) {
Key[] classAuxTreeKeys = _auxTreeKeys[treeId];
_maxNodeIds[treeId] = new int[classAuxTreeKeys.length];
for (int classId = 0; classId < classAuxTreeKeys.length; classId++) {
if (classAuxTreeKeys[classId] == null) {
_maxNodeIds[treeId][classId] = -1;
continue;
}
CompressedTree tree = classAuxTreeKeys[classId].get();
assert tree != null;
_maxNodeIds[treeId][classId] = tree.findMaxNodeId();
}
}
}
protected void initMap() {
_leafNodeWeights = new double[_maxNodeIds.length][][];
for (int treeId = 0; treeId < _maxNodeIds.length; treeId++) {
int[] classMaxNodeIds = _maxNodeIds[treeId];
_leafNodeWeights[treeId] = new double[classMaxNodeIds.length][];
for (int classId = 0; classId < classMaxNodeIds.length; classId++) {
if (classMaxNodeIds[classId] < 0)
continue;
_leafNodeWeights[treeId][classId] = new double[classMaxNodeIds[classId] + 1];
}
}
}
@Override
public void map(Chunk[] chks) {
double[] input = new double[chks.length - 1];
initMap();
for (int row = 0; row < chks[0]._len; row++) {
double weight = chks[input.length].atd(row);
if (weight == 0 || Double.isNaN(weight))
continue;
for (int i = 0; i < input.length; i++)
input[i] = chks[i].atd(row);
for (int tidx = 0; tidx < _treeKeys.length; tidx++) {
Key[] keys = _treeKeys[tidx];
for (int cls = 0; cls < keys.length; cls++) {
Key key = keys[cls];
if (key != null) {
CompressedTree tree = DKV.get(key).get();
CompressedTree auxTree = _auxTreeKeys[tidx][cls].get();
assert auxTree != null;
final double d = SharedTreeMojoModel.scoreTree(tree._bits, input, true, _domains);
final int nodeId = SharedTreeMojoModel.getLeafNodeId(d, auxTree._bits);
_leafNodeWeights[tidx][cls][nodeId] += weight;
}
}
}
}
}
@Override
public void reduce(UpdateAuxTreeWeightsTask mrt) {
ArrayUtils.add(_leafNodeWeights, mrt._leafNodeWeights);
}
@Override
protected void postGlobal() {
_warnTrees = new int[0];
_warnClasses = new int[0];
Futures fs = new Futures();
for (int treeId = 0; treeId < _leafNodeWeights.length; treeId++) {
double[][] classWeights = _leafNodeWeights[treeId];
for (int classId = 0; classId < classWeights.length; classId++) {
double[] nodeWeights = classWeights[classId];
if (nodeWeights == null)
continue;
CompressedTree auxTree = _auxTreeKeys[treeId][classId].get();
assert auxTree != null;
CompressedTree updatedTree = auxTree.updateLeafNodeWeights(nodeWeights);
assert auxTree._key.equals(updatedTree._key);
DKV.put(updatedTree, fs);
if (updatedTree.hasZeroWeight()) {
_warnTrees = ArrayUtils.append(_warnTrees, treeId);
_warnClasses = ArrayUtils.append(_warnClasses, classId);
}
}
}
fs.blockForPending();
assert _warnTrees.length == _warnClasses.length;
}
}
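/**
* Usage sketch (hypothetical caller): for each observation, counts how many times each feature is
* split on along the decision paths over all trees.
* <pre>{@code
* Frame freq = model.scoreFeatureFrequencies(fr, Key.make("feature_frequencies"));
* // one numeric output column per input feature, named like the adapted input columns
* }</pre>
*/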
@Override
public Frame scoreFeatureFrequencies(Frame frame, Key<Frame> destination_key) {
Frame adaptFrm = new Frame(frame);
adaptTestForTrain(adaptFrm, true, false);
// remove non-feature columns
adaptFrm.remove(_parms._response_column);
adaptFrm.remove(_parms._fold_column);
adaptFrm.remove(_parms._weights_column);
adaptFrm.remove(_parms._offset_column);
if(_parms._treatment_column != null){
adaptFrm.remove(_parms._treatment_column);
}
assert adaptFrm.numCols() == _output.nfeatures();
return new ScoreFeatureFrequenciesTask(_output)
.doAll(adaptFrm.numCols(), Vec.T_NUM, adaptFrm)
.outputFrame(destination_key, adaptFrm.names(), null);
}
private static class ComputeSharedTreesFun extends MrFun<ComputeSharedTreesFun> {
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _auxTreeKeys;
final String[] _names;
final String[][] _domains;
transient SharedTreeSubgraph[/*_ntrees*/][/*_nclass*/] _trees;
ComputeSharedTreesFun(SharedTreeSubgraph[][] trees,
Key<CompressedTree>[][] treeKeys, Key<CompressedTree>[][] auxTreeKeys,
String[] names, String[][] domains) {
_trees = trees;
_treeKeys = treeKeys;
_auxTreeKeys = auxTreeKeys;
_names = names;
_domains = domains;
}
@Override
protected void map(int t) {
for (int c = 0; c < _treeKeys[t].length; c++) {
if (_treeKeys[t][c] == null)
continue;
_trees[t][c] = SharedTreeMojoModel.computeTreeGraph(0, "T",
_treeKeys[t][c].get()._bits, _auxTreeKeys[t][c].get()._bits, _names, _domains);
}
}
}
private static class ScoreFeatureFrequenciesTask extends MRTask<ScoreFeatureFrequenciesTask> {
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
final Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _auxTreeKeys;
final String[][] _domains;
transient SharedTreeSubgraph[/*_ntrees*/][/*_nclass*/] _trees;
ScoreFeatureFrequenciesTask(SharedTreeOutput output) {
_treeKeys = output._treeKeys;
_auxTreeKeys = output._treeKeysAux;
_domains = output._domains;
}
@Override
protected void setupLocal() {
_trees = new SharedTreeSubgraph[_treeKeys.length][];
for (int t = 0; t < _treeKeys.length; t++) {
_trees[t] = new SharedTreeSubgraph[_treeKeys[t].length];
}
MrFun<?> getSharedTreesFun = new ComputeSharedTreesFun(_trees, _treeKeys, _auxTreeKeys, _fr.names(), _domains);
H2O.submitTask(new LocalMR(getSharedTreesFun, _trees.length)).join();
}
@Override
public void map(Chunk[] cs, NewChunk[] ncs) {
double[] input = new double[cs.length];
int[] output = new int[ncs.length];
for (int r = 0; r < cs[0]._len; r++) {
for (int i = 0; i < cs.length; i++)
input[i] = cs[i].atd(r);
Arrays.fill(output, 0);
for (int t = 0; t < _treeKeys.length; t++) {
for (int c = 0; c < _treeKeys[t].length; c++) {
if (_treeKeys[t][c] == null)
continue;
double d = SharedTreeMojoModel.scoreTree(_treeKeys[t][c].get()._bits, input, true, _domains);
String decisionPath = SharedTreeMojoModel.getDecisionPath(d);
SharedTreeNode n = _trees[t][c].walkNodes(decisionPath);
updateStats(n, output);
}
}
for (int i = 0; i < ncs.length; i++) {
ncs[i].addNum(output[i]);
}
}
}
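// Walk from the reached leaf back to the root, incrementing the usage count of every split column
// on the way; the leaf itself carries no split column, so counting starts at its parent.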
private void updateStats(final SharedTreeNode leaf, int[] stats) {
SharedTreeNode n = leaf.getParent();
while (n != null) {
stats[n.getColId()]++;
n = n.getParent();
}
}
}
@Override
protected Frame postProcessPredictions(Frame adaptedFrame, Frame predictFr, Job j) {
return CalibrationHelper.postProcessPredictions(predictFr, j, _output);
}
protected double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
return score0(chks, offset, row_in_chunk, tmp, preds); // by default delegate to non-incremental implementation
}
@Override protected double[] score0(double[] data, double[] preds, double offset) {
return score0(data, preds, offset, _output._treeKeys.length);
}
@Override protected double[] score0(double[/*ncols*/] data, double[/*nclasses+1*/] preds) {
return score0(data, preds, 0.0);
}
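// Layout of preds (as used by the per-tree scoring below): regression accumulates into preds[0];
// for classification, the class-c tree of each iteration accumulates into preds[c+1], leaving
// preds[0] for the final predicted label.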
protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
Arrays.fill(preds,0);
return score0(data, preds, offset, 0, ntrees);
}
protected double[] score0(double[] data, double[] preds, double offset, int startTree, int ntrees) {
// Prefetch trees into the local cache if it is necessary
// Invoke scoring
for( int tidx=startTree; tidx<ntrees; tidx++ )
score0(data, preds, tidx);
return preds;
}
// Score one row against the trees of a single iteration
private void score0(double[] data, double[] preds, int treeIdx) {
Key[] keys = _output._treeKeys[treeIdx];
for( int c = 0; c < keys.length; c++ ) {
if (keys[c] != null) {
double pred = DKV.get(keys[c]).<CompressedTree>get().score(data, _output._domains);
assert (!Double.isInfinite(pred));
preds[keys.length == 1 ? 0 : c + 1] += pred;
}
}
}
/** Performs deep clone of given model. */
protected M deepClone(Key<M> result) {
M newModel = IcedUtils.deepCopy(self());
newModel._key = result;
// Do not clone model metrics
newModel._output.clearModelMetrics(false);
newModel._output._training_metrics = null;
newModel._output._validation_metrics = null;
// Clone trees
Key[][] treeKeys = newModel._output._treeKeys;
for (int i = 0; i < treeKeys.length; i++) {
for (int j = 0; j < treeKeys[i].length; j++) {
if (treeKeys[i][j] == null) continue;
CompressedTree ct = DKV.get(treeKeys[i][j]).get();
CompressedTree newCt = IcedUtils.deepCopy(ct);
newCt._key = CompressedTree.makeTreeKey(i, j);
DKV.put(treeKeys[i][j] = newCt._key,newCt);
}
}
// Clone Aux info
Key[][] treeKeysAux = newModel._output._treeKeysAux;
if (treeKeysAux!=null) {
for (int i = 0; i < treeKeysAux.length; i++) {
for (int j = 0; j < treeKeysAux[i].length; j++) {
if (treeKeysAux[i][j] == null) continue;
CompressedTree ct = DKV.get(treeKeysAux[i][j]).get();
CompressedTree newCt = IcedUtils.deepCopy(ct);
newCt._key = Key.make(createAuxKey(treeKeys[i][j].toString()));
DKV.put(treeKeysAux[i][j] = newCt._key,newCt);
}
}
}
return newModel;
}
@Override protected Futures remove_impl(Futures fs, boolean cascade) {
for (Key[] ks : _output._treeKeys)
for (Key k : ks)
Keyed.remove(k, fs, true);
for (Key[] ks : _output._treeKeysAux)
for (Key k : ks)
Keyed.remove(k, fs, true);
if (_output._calib_model != null)
_output._calib_model.remove(fs);
return super.remove_impl(fs, cascade);
}
/** Write out K/V pairs */
@Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
for (Key[] ks : _output._treeKeys)
for (Key k : ks)
ab.putKey(k);
for (Key[] ks : _output._treeKeysAux)
for (Key k : ks)
ab.putKey(k);
return super.writeAll_impl(ab);
}
@Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
for (Key[] ks : _output._treeKeys)
for (Key k : ks)
ab.getKey(k,fs);
for (Key[] ks : _output._treeKeysAux)
for (Key k : ks)
ab.getKey(k,fs);
return super.readAll_impl(ab,fs);
}
@SuppressWarnings("unchecked") // `M` is really the type of `this`
private M self() { return (M)this; }
/**
* Converts a given tree of the ensemble to a user-understandable representation.
* @param tidx tree index
* @param cls tree class
* @return instance of SharedTreeSubgraph
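*
* <p>Usage sketch (hypothetical caller):
* <pre>{@code
* SharedTreeSubgraph sg = model.getSharedTreeSubgraph(0, 0); // first tree, class 0
* SharedTreeNode root = sg.rootNode; // traverse via getLeftChild()/getRightChild()
* }</pre>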
*/
public SharedTreeSubgraph getSharedTreeSubgraph(final int tidx, final int cls) {
if (tidx < 0 || tidx >= _output._ntrees) {
throw new IllegalArgumentException("Invalid tree index: " + tidx +
". Tree index must be in range [0, " + (_output._ntrees -1) + "].");
}
Key<CompressedTree> treeKey = _output._treeKeysAux[tidx][cls];
if (treeKey == null)
return null;
final CompressedTree auxCompressedTree = treeKey.get();
return _output._treeKeys[tidx][cls].get().toSharedTreeSubgraph(auxCompressedTree, _output._names, _output._domains);
}
@Override
public boolean isFeatureUsedInPredict(String featureName) {
if (featureName.equals(_output.responseName())) return false;
int featureIdx = ArrayUtils.find(_output._varimp._names, featureName);
return featureIdx != -1 && (double) _output._varimp._varimp[featureIdx] != 0d;
}
//--------------------------------------------------------------------------------------------------------------------
// Serialization into a POJO
//--------------------------------------------------------------------------------------------------------------------
public boolean binomialOpt() {
return true;
}
@Override
public CategoricalEncoding getGenModelEncoding() {
switch (_parms._categorical_encoding) {
case AUTO:
case Enum:
case SortByResponse:
return CategoricalEncoding.AUTO;
case OneHotExplicit:
return CategoricalEncoding.OneHotExplicit;
case Binary:
return CategoricalEncoding.Binary;
case EnumLimited:
return CategoricalEncoding.EnumLimited;
case Eigen:
return CategoricalEncoding.Eigen;
case LabelEncoder:
return CategoricalEncoding.LabelEncoder;
default:
return null;
}
}
protected SharedTreePojoWriter makeTreePojoWriter() {
throw new UnsupportedOperationException("POJO is not supported for model " + _parms.algoName() + ".");
}
@Override
protected final PojoWriter makePojoWriter() {
CategoricalEncoding encoding = getGenModelEncoding();
if (encoding == null) {
throw new IllegalArgumentException("Only default, SortByResponse, EnumLimited and 1-hot explicit scheme is supported for POJO/MOJO");
}
return makeTreePojoWriter();
}
}