All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.GlobalQuantilesCalc Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree;

import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import water.DKV;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.util.ArrayUtils;

/**
 * Helper class for calculating split points used when histogram type is "QuantilesGlobal"
 */
class GlobalQuantilesCalc {

    /**
     * Calculates split points for histogram type = QuantilesGlobal.
     * <p>
     * Eligible numeric columns of {@code trainFr} (plus the weights column, if any) are gathered into a temporary
     * frame and a {@link Quantile} model is trained on it with probabilities {@code i/N} for {@code i = 0..N-1}.
     * Each column's quantiles are then made unique, limited to the column's [min, max] range and padded to
     * {@code nbins_top_level} points.
     *
     * @param trainFr (adapted) training frame
     * @param weightsColumn name of column containing observation weights (optional)
     * @param priorSplitPoints optional pre-existing split points for some columns; columns with a non-null entry
     *                         are not recomputed and stay {@code null} in the returned array
     * @param N number of quantile probabilities to request (probabilities are {@code i/N}, {@code 0 <= i < N})
     * @param nbins_top_level number of top-level bins
     * @return array of split points indexed like the columns of {@code trainFr}; an entry is {@code null} when no
     *         global split points were computed for that column (skipped column, or at most one distinct quantile)
     */
    static double[][] splitPoints(Frame trainFr, String weightsColumn, 
                                  double[][] priorSplitPoints, final int N, int nbins_top_level) {
        final int[] frToTrain = new int[trainFr.numCols()];
        final Frame fr = collectColumnsForQuantile(trainFr, weightsColumn, priorSplitPoints, frToTrain);
        final double[][] splitPoints = new double[trainFr.numCols()][];
        // Nothing to compute if no column qualified, or the only collected column is the weights column itself.
        if (fr.numCols() == 0 || weightsColumn != null && fr.numCols() == 1 && weightsColumn.equals(fr.name(0))) {
            return splitPoints;
        }
        // The Quantile parameters reference their training frame by key, so publish the temporary frame.
        Key tmpFrameKey = Key.make();
        DKV.put(tmpFrameKey, fr);
        QuantileModel qm = null;
        try {
            QuantileModel.QuantileParameters p = new QuantileModel.QuantileParameters();
            p._train = tmpFrameKey;
            p._weights_column = weightsColumn;
            p._combine_method = QuantileModel.CombineMethod.INTERPOLATE;
            p._probs = new double[N];
            // compute quantiles such that they span from min (inclusive) ... max (exclusive)
            for (int i = 0; i < N; ++i) {
                p._probs[i] = i * 1. / N;
            }
            Job job = new Quantile(p).trainModel();
            try {
                qm = job.get();
            } finally {
                // Clean up the job even when get() throws (e.g. training failed), not only on success.
                job.remove();
            }
            double[][] origQuantiles = qm._output._quantiles;
            // NOTE(review): assumes origQuantiles is indexed like the columns of fr (q -> frToTrain[q]);
            // confirm the Quantile builder preserves the column order of its training frame.
            // NOTE(review): fr also contains the weights column (if any), so that trainFr index may receive
            // split points here as well; presumably callers only read feature columns — confirm.
            for (int q = 0; q < origQuantiles.length; q++) {
                if (origQuantiles[q].length <= 1) {
                    continue;
                }
                final int i = frToTrain[q];
                // make the quantile split points unique and limited to the column's [min, max] range
                final double[] uniquePoints =
                        ArrayUtils.makeUniqueAndLimitToRange(origQuantiles[q], fr.vec(q).min(), fr.vec(q).max());
                if (uniquePoints.length <= 1) {
                    // not enough split points left - fall back to regular binning
                    splitPoints[i] = null;
                } else {
                    // pad the quantiles until we have nbins_top_level bins
                    splitPoints[i] = ArrayUtils.padUniformly(uniquePoints, nbins_top_level);
                }
                assert splitPoints[i] == null || splitPoints[i].length > 1;
            }
            return splitPoints;
        } finally {
            // Only the temporary key is removed; the vecs are shared with trainFr and must survive.
            DKV.remove(tmpFrameKey);
            if (qm != null) {
                qm.delete();
            }
        }
    }

    /**
     * Collects the columns of {@code trainFr} for which global quantiles should be computed into a new frame.
     * The vecs are shared with {@code trainFr} (added by reference), not copied.
     * <p>
     * A column is skipped when it already has prior split points, or when it is non-numeric, categorical, binary
     * or constant. The weights column, when found in {@code trainFr}, is always included at its original position.
     *
     * @param trainFr training frame
     * @param weightsColumn name of column containing observation weights (optional)
     * @param priorSplitPoints optional pre-existing split points indexed like {@code trainFr}; may be {@code null}
     * @param frToTrainMap output array: for every column {@code j} of the returned frame, {@code frToTrainMap[j]}
     *                     is set to that column's index in {@code trainFr}; must hold at least
     *                     {@code trainFr.numCols()} elements
     * @return new frame holding the selected columns, in {@code trainFr} order
     */
    static Frame collectColumnsForQuantile(Frame trainFr, String weightsColumn, double[][] priorSplitPoints,
                                           int[] frToTrainMap) {
        final Frame fr = new Frame();
        final int weightsIdx = trainFr.find(weightsColumn);
        for (int i = 0; i < trainFr.numCols(); ++i) {
            if (i != weightsIdx) {
                // split points supplied by the caller take precedence; don't recompute them
                if (priorSplitPoints != null && priorSplitPoints[i] != null) {
                    continue;
                }
                // only numeric, non-binary, non-constant columns benefit from global quantiles
                if (!trainFr.vec(i).isNumeric() || trainFr.vec(i).isCategorical() ||
                        trainFr.vec(i).isBinary() || trainFr.vec(i).isConst()) {
                    continue;
                }
            }
            frToTrainMap[fr.numCols()] = i;
            fr.add(trainFr.name(i), trainFr.vec(i));
        }
        return fr;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy