ml.dmlc.xgboost4j.java.XGBoostScoreTask

package ml.dmlc.xgboost4j.java;

import com.google.common.collect.ObjectArrays;
import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import water.*;
import water.fvec.*;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static water.fvec.Vec.T_CAT;
import static water.fvec.Vec.T_NUM;

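/**
 * Distributed scoring task for H2O's XGBoost integration. Each map() call converts the
 * local Chunks into a DMatrix, reloads the Booster from its serialized form and writes
 * the per-row predictions (plus the response label, used later for metrics) into NewChunks.
 */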
public class XGBoostScoreTask extends MRTask<XGBoostScoreTask> {

    private final XGBoostModelInfo _sharedmodel;
    private final XGBoostOutput _output;
    private final XGBoostModel.XGBoostParameters _parms;

    private byte[] rawBooster;

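    /** Holder for the scoring result: the prediction Frame and, when metrics are computed, the ModelMetrics. */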
    public static class XGBoostScoreTaskResult {
        public Frame preds;
        public ModelMetrics mm;
    }

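    /**
     * Runs the scoring task over {@code data} and assembles the prediction Frame under
     * {@code destinationKey}. The task appends the response as a temporary last column;
     * when {@code computeMetrics} is true it is used to build regression, binomial or
     * multinomial metrics before being dropped from the returned Frame.
     */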
    public static XGBoostScoreTaskResult runScoreTask(XGBoostModelInfo sharedmodel,
                                               XGBoostOutput output,
                                               XGBoostModel.XGBoostParameters parms,
                                               Booster booster,
                                               Key destinationKey,
                                               Frame data,
                                               boolean computeMetrics){
        XGBoostScoreTask task = new XGBoostScoreTask(sharedmodel,
                output,
                parms,
                booster).doAll(outputTypes(output), data);
        String[] names = ObjectArrays.concat(Model.makeScoringNames(output), new String[] {"label"}, String.class);
        Frame preds = task.outputFrame(destinationKey, names, makeDomains(output, names));

        XGBoostScoreTaskResult res = new XGBoostScoreTaskResult();

        Vec resp = preds.lastVec();
        preds.remove(preds.vecs().length - 1);
        if (output.nclasses() == 1) {
            Vec pred = preds.vec(0);
            if (computeMetrics) {
                res.mm = ModelMetricsRegression.make(pred, resp, DistributionFamily.gaussian);
            }
        } else if (output.nclasses() == 2) {
            Vec p1 = preds.vec(2);
            if (computeMetrics) {
                resp.setDomain(output.classNames());
                res.mm = ModelMetricsBinomial.make(p1, resp);
            }
        } else {
            if (computeMetrics) {
                resp.setDomain(output.classNames());
                Frame pp = new Frame(preds);
                pp.remove(0);
                Scope.enter();
                res.mm = ModelMetricsMultinomial.make(pp, resp, resp.toCategoricalVec().domain());
                Scope.exit();
            }
        }

        res.preds = preds;

        if (resp != null) {
            resp.remove();
        }

        if (computeMetrics) {
            assert res.mm != null;
        }
        assert "predict".equals(preds.name(0));

        return res;
    }

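    /**
     * Column types of the task's output Frame: a prediction column, per-class probability
     * columns for classification, and a trailing numeric column carrying the response.
     */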
    private static byte[] outputTypes(XGBoostOutput output) {
        // The last output column is the response; it is dropped before the preds Frame is returned but is needed to build the metrics
        if (output.nclasses() == 1) {
            return new byte[]{T_NUM, T_NUM};
        } else if (output.nclasses() == 2) {
            return new byte[]{T_CAT, T_NUM, T_NUM, T_NUM};
        } else {
            byte[] types = new byte[output.nclasses() + 2];
            Arrays.fill(types, T_NUM);
            return types;
        }
    }

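    /**
     * Domains of the output columns: none for regression, "N"/"Y" placeholders for the
     * binomial predict and label columns, and the model's class names for the multinomial
     * predict column.
     */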
    private static String[][] makeDomains(XGBoostOutput output, String[] names) {
        if (output.nclasses() == 1) {
            return null;
        } else if (output.nclasses() == 2) {
            String[][] domains = new String[4][];
            domains[0] = new String[]{"N", "Y"};
            domains[3] = new String[]{"N", "Y"};
            return domains;
        } else {
            String[][] domains = new String[names.length][];
            domains[0] = output.classNames();
            return domains;
        }
    }

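    /**
     * Keeps the Booster as its raw serialized bytes so the task can be sent across the
     * cluster and the model reloaded on every node that scores local chunks.
     */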
    private XGBoostScoreTask(XGBoostModelInfo sharedmodel,
                             XGBoostOutput output,
                             XGBoostModel.XGBoostParameters parms,
                             Booster booster) {
        _sharedmodel = sharedmodel;
        _output = output;
        _parms = parms;
        this.rawBooster = XGBoost.getRawArray(booster);
    }

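    /**
     * Scores the rows of the local chunks: initializes Rabit, converts the chunks into a
     * DMatrix, reloads the Booster from its raw bytes, predicts, and emits the predictions
     * (plus the response label) into the NewChunks.
     */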
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        try {
            HashMap<String, Object> params = XGBoostModel.createParams(_parms, _output);

            Map<String, String> rabitEnv = new HashMap<>();
            // Rabit has to be initialized because parts of booster.predict() use Rabit.
            // This might be fixed in future versions of XGBoost.
            Rabit.init(rabitEnv);

            DMatrix data = XGBoostUtils.convertChunksToDMatrix(
                    _sharedmodel._dataInfoKey,
                    cs,
                    _fr.find(_parms._response_column),
                    -1, // not used for preds
                    _fr.find(_parms._fold_column),
                    _output._sparse);

            // No local chunks for this frame
            if (data.rowNum() == 0) {
                return;
            }

            Booster booster = null;
            try {
                booster = Booster.loadModel(new ByteArrayInputStream(rawBooster));
                booster.setParams(params);
            } catch (IOException e) {
                throw new IllegalStateException("Failed to load the booster.", e);
            }
            final float[][] preds = booster.predict(data);

            float[] labels = data.getLabel();

            float[] weights = data.getWeight();

            if (_output.nclasses() == 1) {
                double[] dpreds = new double[preds.length];
                for (int j = 0; j < dpreds.length; ++j)
                    dpreds[j] = preds[j][0];
                for (int i = 0; i < cs[0]._len; ++i) {
                    ncs[0].addNum(dpreds[i]);
                    ncs[1].addNum(labels[i]);
                }
            } else if (_output.nclasses() == 2) {
                double[] dpreds = new double[preds.length];

                for (int j = 0; j < dpreds.length; ++j)
                    dpreds[j] = preds[j][0];

                if (weights.length > 0)
                    for (int j = 0; j < dpreds.length; ++j)
                        assert weights[j] == 1.0;

                for (int i = 0; i < cs[0]._len; ++i) {
                    double p = dpreds[i];
                    ncs[1].addNum(1.0d - p);
                    ncs[2].addNum(p);
                    double[] row = new double[]{0, 1 - p, p};
                    double predLab = hex.genmodel.GenModel.getPrediction(row, _output._priorClassDist, null, Model.defaultThreshold(_output));
                    ncs[0].addNum(predLab);

                    ncs[3].addNum(labels[i]);
                }
            } else {
                for (int i = 0; i < cs[0]._len; ++i) {
                    double[] row = new double[ncs.length - 1];
                    for (int j = 1; j < row.length; ++j) {
                        double val = preds[i][j - 1];
                        ncs[j].addNum(val);
                        row[j] = val;
                    }
                    ncs[0].addNum(hex.genmodel.GenModel.getPrediction(row, _output._priorClassDist, null, Model.defaultThreshold(_output)));
                    ncs[ncs.length - 1].addNum(labels[i]);
                }
            }
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
        } finally {
            try {
                Rabit.shutdown();
            } catch (XGBoostError xgBoostError) {
                throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
            }
        }
    }

}



