All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.gbm.GbmPojoWriter Maven / Gradle / Ivy

package hex.tree.gbm;

import hex.Distribution;
import hex.DistributionFactory;
import hex.LinkFunction;
import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import water.util.SBPrintStream;

class GbmPojoWriter extends SharedTreePojoWriter {

    private final double _init_f;
    private final boolean _balance_classes;
    private final DistributionFamily _distribution_family;
    private final LinkFunction _link_function;

    GbmPojoWriter(GBMModel model, CompressedTree[][] trees) {
        super(model._key, model._output, model.getGenModelEncoding(), model.binomialOpt(),
                trees, model._output._treeStats);
        _init_f = model._output._init_f;
        _balance_classes = model._parms._balance_classes;
        Distribution distribution = DistributionFactory.getDistribution(model._parms);
        _distribution_family = distribution._family;
        _link_function = distribution._linkFunction;
    }

    GbmPojoWriter(Model model, CategoricalEncoding encoding,
                  boolean binomialOpt, CompressedTree[][] trees,
                  double initF, boolean balanceClasses,
                  DistributionFamily distributionFamily, LinkFunction linkFunction) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
        _init_f = initF;
        _balance_classes = balanceClasses;
        _distribution_family = distributionFamily;
        _link_function = linkFunction;
    }

    // Note: POJO scoring code doesn't support per-row offsets (the scoring API would need to be changed to pass in offsets)
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        // Preds are filled in from the trees, but need to be adjusted according to
        // the loss function.
        if (_distribution_family == DistributionFamily.bernoulli
                || _distribution_family == DistributionFamily.quasibinomial
                || _distribution_family == DistributionFamily.modified_huber
        ) {
            body.ip("preds[2] = preds[1] + ").p(_init_f).p(";").nl();
            body.ip("preds[2] = " + _link_function.linkInvString("preds[2]") + ";").nl();
            body.ip("preds[1] = 1.0-preds[2];").nl();
            if (_balance_classes)
                body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
            body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + _output.defaultThreshold() + ");").nl();
            return;
        }
        if (_output.nclasses() == 1) { // Regression
            body.ip("preds[0] += ").p(_init_f).p(";").nl();
            body.ip("preds[0] = " + _link_function.linkInvString("preds[0]") + ";").nl();
            return;
        }
        if (_output.nclasses() == 2) { // Kept the initial prediction for binomial
            body.ip("preds[1] += ").p(_init_f).p(";").nl();
            body.ip("preds[2] = - preds[1];").nl();
        }
        body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
        if (_balance_classes)
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + _output.defaultThreshold() + ");").nl();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy