// hex.tree.gbm.GBMModel (Maven / Gradle / Ivy)
package hex.tree.gbm;
import hex.*;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.*;
import hex.util.EffectiveParametersUtils;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.TwoDimTable;
import java.util.*;
/**
 * Gradient Boosting Machine model.
 *
 * <p>Holds the GBM-specific parameters ({@link GBMParameters}) and output ({@link GBMOutput}) and
 * implements full, incremental and staged (per-iteration) scoring, POJO/MOJO export, feature
 * interaction extraction and Friedman-Popescu's H statistic on top of the shared-tree model.
 */
public class GBMModel extends SharedTreeModelWithContributions
    implements Model.StagedPredictions, FeatureInteractionsCollector, FriedmanPopescusHCollector {

  /** GBM training parameters; the constructor installs the GBM defaults. */
  public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
    /** Learning rate (shrinkage). Default 0.1. */
    public double _learn_rate;
    /** Learning-rate annealing factor. Default 1.0. */
    public double _learn_rate_annealing;
    /** Column sampling rate; any value other than 1.0 turns column sampling on. Default 1.0. */
    public double _col_sample_rate;
    /** Maximum absolute leaf-node prediction. Default {@code Double.MAX_VALUE}. */
    public double _max_abs_leafnode_pred;
    /** Prediction noise bandwidth. Default 0. */
    public double _pred_noise_bandwidth;
    /**
     * Monotone constraints as (column name, direction) pairs. Only the sign of the value is used;
     * 0 means "no constraint" (see {@link #constraints(Frame)}).
     */
    public KeyValue[] _monotone_constraints;
    /** Interaction constraints as groups of column names (see {@link #interactionConstraints(Frame)}). */
    public String[][] _interaction_constraints;

    /** Creates parameters with GBM defaults, overriding some inherited shared-tree defaults. */
    public GBMParameters() {
      super();
      _learn_rate = 0.1;
      _learn_rate_annealing = 1.0;
      _col_sample_rate = 1.0;
      _sample_rate = 1.0;
      _ntrees = 50;
      _max_depth = 5;
      _max_abs_leafnode_pred = Double.MAX_VALUE;
      _pred_noise_bandwidth = 0;
    }

    /** Column sampling is used if the shared-tree settings ask for it or {@code _col_sample_rate != 1.0}. */
    @Override
    public boolean useColSampling() {
      return super.useColSampling() || _col_sample_rate != 1.0;
    }

    public String algoName() { return "GBM"; }
    public String fullName() { return "Gradient Boosting Machine"; }
    public String javaName() { return GBMModel.class.getName(); }

    @Override
    public boolean forceStrictlyReproducibleHistograms() {
      // If monotone constraints are enabled, use strictly reproducible histograms: the values we
      // calculate are not subject to the reduced-precision logic in DHistogram, so the
      // "float trick" cannot be applied.
      return usesMonotoneConstraints();
    }

    /**
     * Returns true when monotone constraints are configured. When none are configured, the test
     * hook {@link #emptyConstraints(int)} can still force constraint handling by returning non-null.
     */
    private boolean usesMonotoneConstraints() {
      if (areMonotoneConstraintsEmpty())
        return emptyConstraints(0) != null;
      return true;
    }

    /** True when {@code _monotone_constraints} is null or has no entries. */
    private boolean areMonotoneConstraintsEmpty() {
      return _monotone_constraints == null || _monotone_constraints.length == 0;
    }

    /**
     * Builds the monotone constraints for the columns of {@code f}.
     *
     * @param f frame whose column layout the constraint vector follows
     * @return constraints with -1/+1 for decreasing/increasing columns and 0 elsewhere; when no
     *     constraints are configured, the result of {@link #emptyConstraints(int)} (null unless
     *     overridden in tests)
     * @throws IllegalStateException if a constrained column name is not present in {@code f}
     */
    public Constraints constraints(Frame f) {
      if (areMonotoneConstraintsEmpty()) {
        return emptyConstraints(f.numCols());
      }
      int[] cs = new int[f.numCols()];
      for (KeyValue spec : _monotone_constraints) {
        if (spec.getValue() == 0)
          continue; // 0 = unconstrained; no need to resolve the column
        int col = f.find(spec.getKey());
        if (col < 0) {
          throw new IllegalStateException("Invalid constraint specification, column '" + spec.getKey() + "' doesn't exist.");
        }
        cs[col] = spec.getValue() < 0 ? -1 : 1;
      }
      // Only these distributions get the bounded variant of the constraints.
      boolean useBounds = _distribution == DistributionFamily.gaussian ||
          _distribution == DistributionFamily.bernoulli ||
          _distribution == DistributionFamily.tweedie ||
          _distribution == DistributionFamily.quasibinomial ||
          _distribution == DistributionFamily.multinomial ||
          _distribution == DistributionFamily.quantile;
      return new Constraints(cs, DistributionFactory.getDistribution(this), useBounds);
    }

    // Allows overriding the behavior in tests (e.g. create empty constraints and test execution as if constraints were used).
    Constraints emptyConstraints(int nCols) {
      return null;
    }

    /** Resolves {@code _interaction_constraints} against the column names of {@code frame}. */
    public GlobalInteractionConstraints interactionConstraints(Frame frame){
      return new GlobalInteractionConstraints(this._interaction_constraints, frame.names());
    }

    /** Creates the root-branch constraints, allowing every column index that {@code ics} allows. */
    public BranchInteractionConstraints initialInteractionConstraints(GlobalInteractionConstraints ics){
      return new BranchInteractionConstraints(ics.getAllAllowedColumnIndices());
    }
  }

  /** GBM model output: shared-tree output plus the class count and quasibinomial domain handling. */
  public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
    /** Class labels reported by {@link #classNames()} when the distribution is quasibinomial. */
    public String[] _quasibinomialDomains;
    boolean _quasibinomial;
    int _nclasses;

    /** Number of classes as reported by the builder ({@code b.nclasses()}) at construction time. */
    public int nclasses() {
      return _nclasses;
    }

    public GBMOutput(GBM b) {
      super(b);
      _quasibinomial = b._parms._distribution == DistributionFamily.quasibinomial;
      _nclasses = b.nclasses();
    }

    /** Returns {@link #_quasibinomialDomains} for quasibinomial models, the inherited class names otherwise. */
    @Override
    public String[] classNames() {
      String [] res = super.classNames();
      if(_quasibinomial){
        return _quasibinomialDomains;
      }
      return res;
    }
  }

  public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
    super(selfKey,parms,output);
  }

  /** Resolves "auto"-style parameters that can be decided before the output is set up. */
  @Override
  public void initActualParamValues() {
    super.initActualParamValues();
    EffectiveParametersUtils.initFoldAssignment(_parms);
    EffectiveParametersUtils.initHistogramType(_parms);
    EffectiveParametersUtils.initCategoricalEncoding(_parms, Parameters.CategoricalEncodingScheme.Enum);
  }

  /** Resolves parameters (stopping metric, distribution) that depend on the response type. */
  public void initActualParamValuesAfterOutputSetup(int nclasses, boolean isClassifier) {
    EffectiveParametersUtils.initStoppingMetric(_parms, isClassifier);
    EffectiveParametersUtils.initDistribution(_parms, nclasses);
  }

  /** Returns a contributions task bound to this model; the {@code model} argument is not used. */
  @Override
  protected ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model) {
    return new ScoreContributionsTask(this);
  }

  // NOTE: the method name (including "Soring") is defined by the superclass and cannot be renamed here.
  @Override
  protected ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, ContributionsOptions options) {
    return new ScoreContributionsSortingTask(model, options);
  }

  /**
   * Scores {@code frame} after every boosting iteration.
   *
   * <p>A copy of the frame is adapted to the training schema. The result has one numeric column per
   * tree, named by {@code makeAllTreeColumnNames()}.
   */
  @Override
  public Frame scoreStagedPredictions(Frame frame, Key destination_key) {
    Frame adaptFrm = new Frame(frame);
    adaptTestForTrain(adaptFrm, true, false);
    final String[] names = makeAllTreeColumnNames();
    final int outputcols = names.length;
    return new StagedPredictionsTask(this)
        .doAll(outputcols, Vec.T_NUM, adaptFrm)
        .outputFrame(destination_key, names, null);
  }

  /**
   * For every row, emits the transformed prediction after each boosting iteration. There is one
   * output value per non-null tree key, in iteration order and class order within an iteration.
   */
  private static class StagedPredictionsTask extends MRTask {
    private final Key _modelKey;
    private transient GBMModel _model;
    // Trees fetched once per node in setupLocal(), indexed like _model._output._treeKeys;
    // an entry is null exactly when the corresponding tree key is null.
    private transient CompressedTree[][] _trees;

    private StagedPredictionsTask(GBMModel model) {
      _modelKey = model._key;
    }

    @Override
    protected void setupLocal() {
      _model = _modelKey.get();
      assert _model != null;
      // Fetch every tree once here rather than once per row in map().
      Key[][] treeKeys = _model._output._treeKeys;
      _trees = new CompressedTree[treeKeys.length][];
      for (int tidx = 0; tidx < treeKeys.length; tidx++) {
        Key[] keys = treeKeys[tidx];
        _trees[tidx] = new CompressedTree[keys.length];
        for (int i = 0; i < keys.length; i++) {
          if (keys[i] != null)
            _trees[tidx][i] = DKV.get(keys[i]).get();
        }
      }
    }

    @Override
    public void map(Chunk chks[], NewChunk[] nc) {
      double[] input = new double[chks.length];
      // Single-output models keep their value at index 0; otherwise class values start at index 1,
      // which is the layout score0Probabilities expects.
      int contribOffset = _model._output.nclasses() == 1 ? 0 : 1;
      for (int row = 0; row < chks[0]._len; row++) {
        for (int i = 0; i < chks.length; i++)
          input[i] = chks[i].atd(row);
        // contribs accumulates raw tree outputs across iterations. preds is a scratch copy that
        // score0Probabilities / score0PostProcessSupervised transform in place.
        double[] contribs = new double[contribOffset + _model._output.nclasses()];
        double[] preds = new double[contribs.length];
        int col = 0;
        for (int tidx = 0; tidx < _trees.length; tidx++) {
          CompressedTree[] trees = _trees[tidx];
          for (int i = 0; i < trees.length; i++) {
            if (trees[i] != null)
              contribs[contribOffset + i] += trees[i].score(input, _model._output._domains);
            preds[contribOffset + i] = contribs[contribOffset + i];
          }
          _model.score0Probabilities(preds, 0);
          _model.score0PostProcessSupervised(preds, input);
          for (int i = 0; i < trees.length; i++) {
            if (trees[i] != null)
              nc[col++].addNum(preds[contribOffset + i]);
          }
        }
        assert (col == nc.length);
      }
    }
  }

  /**
   * Scores one row incrementally, continuing from raw sums cached in workspace columns.
   *
   * <p>When {@code sii._startTree == 0} the predictions start from zero. Otherwise the previous sums
   * are read from the workspace columns. Trees from {@code sii._startTree} up to the last tree are
   * then added, the updated raw sums are written back to the workspace, and finally probabilities
   * and post-processing are applied to {@code preds}.
   */
  @Override
  protected final double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
    assert _output.nfeatures() == tmp.length;
    for (int i = 0; i < tmp.length; i++)
      tmp[i] = chks[i].atd(row_in_chunk);
    if (sii._startTree == 0)
      Arrays.fill(preds,0);
    else
      for (int i = 0; i < sii._workspaceColCnt; i++)
        preds[sii._predsAryOffset + i] = chks[sii._workspaceColIdx + i].atd(row_in_chunk);
    score0(tmp, preds, offset, sii._startTree, _output._treeKeys.length);
    // Persist the raw (pre-link) sums before they are transformed below.
    for (int i = 0; i < sii._workspaceColCnt; i++)
      chks[sii._workspaceColIdx + i].set(row_in_chunk, preds[sii._predsAryOffset + i]);
    score0Probabilities(preds, offset);
    score0PostProcessSupervised(preds, tmp);
    return preds;
  }

  /**
   * Scores one row with the first {@code ntrees} trees. The superclass accumulates the raw tree
   * outputs (the f_k(x) of Algorithm 10.4), which are then converted by
   * {@link #score0Probabilities(double[], double)}.
   */
  @Override protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/], double offset, int ntrees) {
    super.score0(data, preds, offset, ntrees); // These are f_k(x) in Algorithm 10.4
    return score0Probabilities(preds, offset);
  }

  /**
   * Converts raw tree sums in {@code preds} into the model's output scale, in place.
   *
   * <ul>
   *   <li>Binomial-type distributions: {@code preds[1]} is the raw score. It becomes the class-0
   *       probability and {@code preds[2]} becomes the class-1 probability.</li>
   *   <li>Multinomial (and custom with more than 2 classes): the per-class scores are rescaled
   *       with {@code GenModel.GBM_rescale}.</li>
   *   <li>Otherwise (regression): {@code preds[0]} is mapped through the inverse link.</li>
   * </ul>
   *
   * @return {@code preds}, for chaining
   */
  private double[] score0Probabilities(double preds[/*nclasses+1*/], double offset) {
    if (_parms._distribution == DistributionFamily.bernoulli
        || _parms._distribution == DistributionFamily.quasibinomial
        || _parms._distribution == DistributionFamily.modified_huber
        || (_parms._distribution == DistributionFamily.custom && _output.nclasses() == 2)) { // custom distribution could be also binomial
      double f = preds[1] + _output._init_f + offset; //Note: the single binomial tree's raw score is stored in preds[1]
      preds[2] = DistributionFactory.getDistribution(_parms).linkInv(f);
      preds[1] = 1.0 - preds[2];
    } else if (_parms._distribution == DistributionFamily.multinomial // Kept the initial prediction for binomial
        || (_parms._distribution == DistributionFamily.custom && _output.nclasses() > 2) ) { // custom distribution could be also multinomial
      if (_output.nclasses() == 2) { //1-tree optimization for binomial
        preds[1] += _output._init_f + offset; //offset is not yet allowed, but added here to be future-proof
        preds[2] = -preds[1];
      }
      hex.genmodel.GenModel.GBM_rescale(preds);
    } else { //Regression
      double f = preds[0] + _output._init_f + offset;
      preds[0] = DistributionFactory.getDistribution(_parms).linkInv(f);
    }
    return preds;
  }

  /** Builds the POJO writer from a locally fetched copy of all trees. */
  @Override
  protected SharedTreePojoWriter makeTreePojoWriter() {
    CompressedForest compressedForest = new CompressedForest(_output._treeKeys, _output._domains);
    CompressedForest.LocalCompressedForest localCompressedForest = compressedForest.fetch();
    return new GbmPojoWriter(this, localCompressedForest._trees);
  }

  @Override
  public GbmMojoWriter getMojo() {
    return new GbmMojoWriter(this);
  }

  /**
   * Number of boosting iterations actually stored in the model. This is the number of tree-key rows
   * used by all other scoring paths here. The requested {@code _parms._ntrees} is not guaranteed to
   * match it (e.g. if training stopped early).
   */
  private int builtTreeCount() {
    return _output._treeKeys.length;
  }

  /** Trees per boosting iteration: one per class when there are more than 2 classes, otherwise one. */
  private int treesPerIteration() {
    return this._output._nclasses > 2 ? this._output._nclasses : 1;
  }

  /**
   * Collects feature interactions over every tree of the model and merges them.
   *
   * @param maxInteractionDepth upper bound on interaction depth
   * @param maxTreeDepth upper bound on tree depth to traverse
   * @param maxDeepening upper bound on deepening
   * @return the merged interactions of all built trees
   */
  public FeatureInteractions getFeatureInteractions(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
    FeatureInteractions featureInteractions = new FeatureInteractions();
    final int ntrees = builtTreeCount();
    final int nclasses = treesPerIteration();
    for (int i = 0; i < ntrees; i++) {
      for (int j = 0; j < nclasses; j++) {
        FeatureInteractions currentTreeFeatureInteractions = new FeatureInteractions();
        SharedTreeSubgraph tree = this.getSharedTreeSubgraph(i, j);
        List interactionPath = new ArrayList<>();
        Set memo = new HashSet<>();
        FeatureInteractions.collectFeatureInteractions(tree.rootNode, interactionPath, 0, 0, 1, 0, 0, currentTreeFeatureInteractions,
            memo, maxInteractionDepth, maxTreeDepth, maxDeepening, i, true);
        featureInteractions.mergeWith(currentTreeFeatureInteractions);
      }
    }
    return featureInteractions;
  }

  @Override
  public TwoDimTable[][] getFeatureInteractionsTable(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
    return FeatureInteractions.getFeatureInteractionsTable(this.getFeatureInteractions(maxInteractionDepth,maxTreeDepth,maxDeepening));
  }

  /**
   * Computes Friedman and Popescu's H statistic for {@code vars} on {@code frame}, using every built
   * tree and {@code _parms._learn_rate}.
   */
  @Override
  public double getFriedmanPopescusH(Frame frame, String[] vars) {
    final int ntrees = builtTreeCount();
    final int nclasses = treesPerIteration();
    SharedTreeSubgraph[][] sharedTreeSubgraphs = new SharedTreeSubgraph[ntrees][nclasses];
    for (int i = 0; i < ntrees; i++) {
      for (int j = 0; j < nclasses; j++) {
        sharedTreeSubgraphs[i][j] = this.getSharedTreeSubgraph(i, j);
      }
    }
    return FriedmanPopescusH.h(frame, vars, this._parms._learn_rate, sharedTreeSubgraphs);
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy