All downloads are free. Search and download functionality uses the official Maven repository.

ml.dmlc.xgboost4j.java.XGBoostScoreTask Maven / Gradle / Ivy

package ml.dmlc.xgboost4j.java;

import hex.*;
import hex.tree.xgboost.*;
import water.*;
import water.fvec.*;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static water.fvec.Vec.T_CAT;
import static water.fvec.Vec.T_NUM;

/**
 * MRTask that scores an XGBoost model over a {@link Frame}.
 *
 * <p>Each {@link #map} call converts its chunks into a {@code DMatrix}, predicts with a booster
 * deserialized from the shared model info and writes the predictions into the output chunks.
 * When metrics are requested it also feeds every row into a {@link ModelMetrics.MetricBuilder},
 * and {@link #reduce} merges these builders.
 */
public class XGBoostScoreTask extends MRTask<XGBoostScoreTask> {

    private final XGBoostModelInfo _sharedmodel;
    private final XGBoostOutput _output;
    private final XGBoostModel.XGBoostParameters _parms;
    private final BoosterParms _boosterParms;
    private final boolean _computeMetrics;
    // Index of the weights column in the scored frame, or -1 when there is no weights column.
    private final int _weightsChunkId;
    private final Model _model;
    private final double _threshold;

    // Per-task metric accumulator: created in map() when metrics are requested, merged in reduce().
    private ModelMetrics.MetricBuilder _metricBuilder;

    /** Result of {@link #runScoreTask}: the prediction frame and, when requested, its metrics. */
    public static class XGBoostScoreTaskResult {
        public Frame preds;
        public ModelMetrics mm;
    }

    /**
     * Scores {@code data} with the given model and optionally computes model metrics.
     *
     * @param sharedmodel    model info holding the serialized booster and its DataInfo
     * @param output         model output; {@code nclasses()} selects regression/binomial/multinomial
     * @param parms          model parameters (response, weights and fold column names)
     * @param destinationKey key of the prediction frame to create
     * @param data           frame to score (adapted to the model)
     * @param originalData   original frame, passed through to {@code makeModelMetrics}
     * @param computeMetrics whether to compute {@link ModelMetrics}
     * @param m              the model, passed through to the metric builder
     * @return the prediction frame (first column named "predict") and metrics, or
     *         {@code mm == null} when {@code computeMetrics} is false
     */
    public static XGBoostScoreTaskResult runScoreTask(XGBoostModelInfo sharedmodel,
                                                      XGBoostOutput output,
                                                      XGBoostModel.XGBoostParameters parms,
                                                      Key destinationKey,
                                                      Frame data,
                                                      Frame originalData,
                                                      boolean computeMetrics,
                                                      Model m) {
        BoosterParms boosterParms = XGBoostModel.createParams(parms, output.nclasses(), sharedmodel.dataInfo().coefNames());
        XGBoostScoreTask task = new XGBoostScoreTask(sharedmodel,
                output,
                parms,
                boosterParms,
                computeMetrics,
                data.find(parms._weights_column),
                m).doAll(outputTypes(output), data);

        final String[] names = Model.makeScoringNames(output);
        final Frame preds = task.outputFrame(destinationKey, names, makeDomains(output, names));

        XGBoostScoreTaskResult res = new XGBoostScoreTaskResult();
        if (computeMetrics) {
            final int nclasses = output.nclasses();
            if (nclasses == 1) {
                // Regression: metrics are built from the single prediction column.
                res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, new Frame(preds.vec(0)));
            } else if (nclasses == 2) {
                // Binomial: metrics are built from the probability column at index 2.
                res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, new Frame(preds.vec(2)));
            } else {
                // Multinomial: metrics are built from all class-probability columns (drop "predict").
                Frame pp = new Frame(preds);
                pp.remove(0);
                Scope.enter();
                try {
                    res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, pp);
                } finally {
                    // Always balance Scope.enter(), even when metric construction fails.
                    Scope.exit();
                }
            }
        }
        res.preds = preds;
        assert "predict".equals(preds.name(0));

        return res;
    }

    /**
     * Returns the column types of the prediction frame. Column 0 holds the prediction. For
     * classification it is followed by one probability column per class.
     */
    private static byte[] outputTypes(XGBoostOutput output) {
        if (output.nclasses() == 1) {
            return new byte[]{T_NUM};
        } else if (output.nclasses() == 2) {
            return new byte[]{T_CAT, T_NUM, T_NUM};
        } else {
            byte[] types = new byte[output.nclasses() + 1];
            Arrays.fill(types, T_NUM);
            return types;
        }
    }

    /**
     * Returns the column domains of the prediction frame: {@code null} for regression. Otherwise
     * only the "predict" column (index 0) gets the class names as its domain.
     */
    private static String[][] makeDomains(XGBoostOutput output, String[] names) {
        if (output.nclasses() == 1) {
            return null;
        }
        String[][] domains = new String[names.length][];
        domains[0] = output.classNames();
        return domains;
    }

    private XGBoostScoreTask(final XGBoostModelInfo sharedmodel,
                             final XGBoostOutput output,
                             final XGBoostModel.XGBoostParameters parms,
                             final BoosterParms boosterParms,
                             final boolean computeMetrics,
                             final int weightsChunkId,
                             final Model model) {
        _sharedmodel = sharedmodel;
        _output = output;
        _parms = parms;
        _boosterParms = boosterParms;
        _computeMetrics = computeMetrics;
        _weightsChunkId = weightsChunkId;
        _model = model;
        _threshold = Model.defaultThreshold(_output);
    }

    /**
     * Constructs a MetricBuilder based on the parameters of the response variable.
     *
     * @param responseClassesNum Number of classes found in response variable
     * @param responseDomain     Specific domains in response variable
     * @return An instance of {@link hex.ModelMetrics.MetricBuilder} corresponding to given response variable type
     */
    private static ModelMetrics.MetricBuilder createMetricsBuilder(final int responseClassesNum, final String[] responseDomain) {
        switch (responseClassesNum) {
            case 1:
                return new ModelMetricsRegression.MetricBuilderRegression();
            case 2:
                return new ModelMetricsBinomial.MetricBuilderBinomial(responseDomain);
            default:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(responseClassesNum, responseDomain);
        }
    }

    /** Raw booster predictions for one set of chunks, plus the labels read into the DMatrix. */
    private static class ScoreResult {
        float[][] _preds;
        float[] _labels;
    }

    /**
     * Converts {@code cs} to a DMatrix and predicts on it with a freshly deserialized booster.
     *
     * @return the predictions and labels, or {@code null} when the DMatrix has no rows
     * @throws IllegalStateException if XGBoost fails to score or Rabit fails to shut down
     */
    private static ScoreResult scoreChunkExt(final XGBoostModelInfo sharedmodel, final XGBoostModel.XGBoostParameters parms,
                                             final BoosterParms boosterParms, final XGBoostOutput output,
                                             final Frame fr, final Chunk[] cs) {
        DMatrix data = null;
        Booster booster = null;
        try {
            Map<String, String> rabitEnv = new HashMap<>();
            // Rabit has to be initialized as parts of booster.predict() are using Rabit
            // This might be fixed in future versions of XGBoost
            Rabit.init(rabitEnv);

            data = XGBoostUtils.convertChunksToDMatrix(
                    sharedmodel._dataInfoKey,
                    cs,
                    fr.find(parms._response_column),
                    -1, // not used for preds
                    fr.find(parms._fold_column),
                    output._sparse);

            // No local chunks for this frame
            if (data.rowNum() == 0) {
                return null;
            }

            // Initialize Booster
            booster = sharedmodel.deserializeBooster();
            booster.setParams(boosterParms.get());

            // Predict
            ScoreResult result = new ScoreResult();
            result._preds = booster.predict(data);
            result._labels = data.getLabel();
            return result;
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
        } finally {
            BoosterHelper.dispose(booster, data);
            try {
                Rabit.shutdown();
            } catch (XGBoostError xgBoostError) {
                throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
            }
        }
    }

    /**
     * Returns the raw booster predictions for {@code cs}, or an empty array when there is nothing
     * to score.
     */
    public static float[][] scoreChunk(final XGBoostModelInfo sharedmodel, final XGBoostModel.XGBoostParameters parms,
                                       final BoosterParms boosterParms, final XGBoostOutput output,
                                       final Frame fr, final Chunk[] cs) {
        ScoreResult r = scoreChunkExt(sharedmodel, parms, boosterParms, output, fr, cs);
        return r == null ? new float[0][] : r._preds;
    }

    /** Weight of row {@code row}: read from the weights column when present, otherwise 1. */
    private double rowWeight(final Chunk[] cs, final int row) {
        return _weightsChunkId != -1 ? cs[_weightsChunkId].atd(row) : 1;
    }

    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        _metricBuilder = _computeMetrics ? createMetricsBuilder(_output.nclasses(), _output.classNames()) : null;

        final ScoreResult r = scoreChunkExt(_sharedmodel, _parms, _boosterParms, _output, _fr, cs);
        if (r == null) {
            return;
        }

        final int len = cs[0]._len;
        final float[] yact = new float[1];
        if (_output.nclasses() == 1) {
            // Regression: a single prediction column.
            final double[] currentPred = new double[1];
            for (int i = 0; i < len; ++i) {
                currentPred[0] = r._preds[i][0];
                ncs[0].addNum(r._preds[i][0]);
                if (_computeMetrics) {
                    yact[0] = r._labels[i];
                    _metricBuilder.perRow(currentPred, yact, rowWeight(cs, i), 0, _model);
                }
            }
        } else if (_output.nclasses() == 2) {
            // Binomial: the booster yields p(class 1); emit [label, 1 - p, p].
            final double[] row = new double[3];
            for (int i = 0; i < len; ++i) {
                final double p = r._preds[i][0];
                row[1] = 1 - p;
                row[2] = p;
                row[0] = hex.genmodel.GenModel.getPrediction(row, _output._priorClassDist, null, _threshold);

                ncs[0].addNum(row[0]);
                ncs[1].addNum(row[1]);
                ncs[2].addNum(row[2]);

                if (_computeMetrics) {
                    yact[0] = r._labels[i];
                    _metricBuilder.perRow(row, yact, rowWeight(cs, i), 0, _model);
                }
            }
        } else {
            // Multinomial: one probability per class, followed by the predicted label in column 0.
            final double[] row = MemoryManager.malloc8d(ncs.length);
            for (int i = 0; i < len; ++i) {
                for (int j = 1; j < row.length; ++j) {
                    double val = r._preds[i][j - 1];
                    ncs[j].addNum(val);
                    row[j] = val;
                }
                row[0] = hex.genmodel.GenModel.getPrediction(row, _output._priorClassDist, null, _threshold);
                ncs[0].addNum(row[0]);
                if (_computeMetrics) {
                    yact[0] = r._labels[i];
                    _metricBuilder.perRow(row, yact, rowWeight(cs, i), 0, _model);
                }
            }
        }
    }

    @Override
    public void reduce(XGBoostScoreTask mrt) {
        super.reduce(mrt);
        if (!_computeMetrics) {
            return;
        }
        // Tolerate a side that never built a metric builder instead of throwing NPE.
        if (_metricBuilder == null) {
            _metricBuilder = mrt._metricBuilder;
        } else if (mrt._metricBuilder != null) {
            _metricBuilder.reduce(mrt._metricBuilder);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy