All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.algos.tree;

import hex.genmodel.PredictContributions;
import hex.genmodel.PredictContributionsFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * Base class for shared-tree MOJO models that can build a {@link PredictContributions}
 * predictor backed by a {@link TreeSHAPPredictor}.
 *
 * <p>Subclasses decide how the assembled tree predictor is wrapped by implementing
 * {@link #getContributionsPredictor(TreeSHAPPredictor)}, and may override
 * {@link #getInitF()} to supply a non-zero initial value.
 */
public abstract class SharedTreeMojoModelWithContributions extends SharedTreeMojoModel implements PredictContributionsFactory {

    /**
     * Forwards all arguments unchanged to the {@link SharedTreeMojoModel} constructor.
     *
     * @param columns        passed through to the superclass
     * @param domains        passed through to the superclass
     * @param responseColumn passed through to the superclass
     */
    protected SharedTreeMojoModelWithContributions(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Builds a contributions predictor covering every tree of this model.
     *
     * <p>One {@link TreeSHAP} is created per subgraph returned by {@code computeGraph(-1)};
     * they are combined into a {@link TreeSHAPEnsemble} together with {@link #getInitF()}
     * and handed to {@link #getContributionsPredictor(TreeSHAPPredictor)}.
     *
     * @return the predictor produced by {@link #getContributionsPredictor(TreeSHAPPredictor)}
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    public PredictContributions makeContributionsPredictor() {
        if (_nclasses > 2) {
            throw new UnsupportedOperationException("Predicting contributions for multinomial classification problems is not yet supported.");
        }
        SharedTreeGraph graph = computeGraph(-1);
        // The row type is double[] to match toInputRow below, which hands the raw
        // double[] input through unchanged.
        List<TreeSHAPPredictor<double[]>> treeSHAPs = new ArrayList<>(graph.subgraphArray.size());
        for (SharedTreeSubgraph tree : graph.subgraphArray) {
            SharedTreeNode[] nodes = tree.getNodes();
            treeSHAPs.add(new TreeSHAP<>(nodes));
        }
        // getInitF() returns a double; it is narrowed to float for the ensemble here.
        TreeSHAPPredictor<double[]> predictor = new TreeSHAPEnsemble<>(treeSHAPs, (float) getInitF());

        return getContributionsPredictor(predictor);
    }

    /**
     * Returns the initial value passed to the tree ensemble.
     *
     * @return {@code 0} unless overridden
     */
    public double getInitF() {
        return 0; // Set to zero by default, which is correct for DRF. However, need to override in GBMMojoModel with correct init_f.
    }

    /**
     * Wraps the assembled tree predictor into the {@link PredictContributions}
     * implementation appropriate for the concrete model.
     *
     * @param treeSHAPPredictor ensemble predictor built by {@link #makeContributionsPredictor()}
     * @return the contributions predictor exposed to callers
     */
    protected abstract PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor);

    /**
     * {@link ContributionsPredictor} for shared-tree models whose input row needs no
     * conversion before being passed on.
     */
    protected static class SharedTreeContributionsPredictor extends ContributionsPredictor<double[]> {

        /**
         * @param model             model supplying {@code _nfeatures} and {@code features()}
         * @param treeSHAPPredictor predictor forwarded to the superclass
         */
        public SharedTreeContributionsPredictor(SharedTreeMojoModel model, TreeSHAPPredictor<double[]> treeSHAPPredictor) {
            // NOTE(review): the "+ 1" presumably reserves a slot for a bias term - confirm
            // against ContributionsPredictor.
            super(model._nfeatures + 1, model.features(), treeSHAPPredictor);
        }

        /**
         * Returns the input array itself; no copy or transformation is made.
         *
         * @param input raw input row
         * @return the same {@code input} instance
         */
        @Override
        protected final double[] toInputRow(double[] input) {
            return input;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy