All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.PlattScalingHelper Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;

import static hex.ModelCategory.Binomial;

public class PlattScalingHelper {
    
    public interface ModelBuilderWithCalibration, P extends Model.Parameters, O extends Model.Output> {
        ModelBuilder getModelBuilder();
        Frame getCalibrationFrame();
        void setCalibrationFrame(Frame f);
    }

    public interface ParamsWithCalibration {
        Model.Parameters getParams();
        Frame getCalibrationFrame();
        boolean calibrateModel();
    }

    public interface OutputWithCalibration {
        ModelCategory getModelCategory();
        GLMModel calibrationModel();
    }
    
    public static void initCalibration(ModelBuilderWithCalibration builder, ParamsWithCalibration parms, boolean expensive) {
        // Calibration
        Frame cf = parms.getCalibrationFrame();  // User-given calibration set
        if (cf != null) {
            if (! parms.calibrateModel())
                builder.getModelBuilder().warn("_calibration_frame", "Calibration frame was specified but calibration was not requested.");
            Frame adaptedCf = builder.getModelBuilder().init_adaptFrameToTrain(cf, "Calibration Frame", "_calibration_frame", expensive);
            builder.setCalibrationFrame(adaptedCf);
        }
        if (parms.calibrateModel()) {
            if (builder.getModelBuilder().nclasses() != 2)
                builder.getModelBuilder().error("_calibrate_model", "Model calibration is only currently supported for binomial models.");
            if (cf == null)
                builder.getModelBuilder().error("_calibrate_model", "Calibration frame was not specified.");
        }
    }
    
    public static , P extends Model.Parameters, O extends Model.Output> GLMModel buildCalibrationModel(
        ModelBuilderWithCalibration builder, ParamsWithCalibration parms, Job job, M model
    ) {
        Key calibInputKey = Key.make();
        try {
            Scope.enter();
            job.update(0, "Calibrating probabilities");
            Frame calib = builder.getCalibrationFrame();
            Vec calibWeights = parms.getParams()._weights_column != null ? calib.vec(parms.getParams()._weights_column) : null;
            Frame calibPredict = Scope.track(model.score(calib, null, job, false));
            Frame calibInput = new Frame(calibInputKey,
                new String[]{"p", "response"}, new Vec[]{calibPredict.vec(1), calib.vec(parms.getParams()._response_column)});
            if (calibWeights != null) {
                calibInput.add("weights", calibWeights);
            }
            DKV.put(calibInput);

            Key calibModelKey = Key.make();
            Job calibJob = new Job<>(calibModelKey, ModelBuilder.javaName("glm"), "Platt Scaling (GLM)");
            GLM calibBuilder = ModelBuilder.make("GLM", calibJob, calibModelKey);
            calibBuilder._parms._intercept = true;
            calibBuilder._parms._response_column = "response";
            calibBuilder._parms._train = calibInput._key;
            calibBuilder._parms._family = GLMModel.GLMParameters.Family.binomial;
            calibBuilder._parms._lambda = new double[] {0.0};
            if (calibWeights != null) {
                calibBuilder._parms._weights_column = "weights";
            }

            return calibBuilder.trainModel().get();
        } finally {
            Scope.exit();
            DKV.remove(calibInputKey);
        }
    }

    public static Frame postProcessPredictions(Frame predictFr, Job j, OutputWithCalibration output) {
        if (output.calibrationModel() == null) {
            return predictFr;
        } else if (output.getModelCategory() == Binomial) {
            Key jobKey = j != null ? j._key : null;
            Key calibInputKey = Key.make();
            Frame calibOutput = null;
            try {
                Frame calibInput = new Frame(calibInputKey, new String[]{"p"}, new Vec[]{predictFr.vec(1)});
                calibOutput = output.calibrationModel().score(calibInput);
                assert calibOutput._names.length == 3;
                Vec[] calPredictions = calibOutput.remove(new int[]{1, 2});
                // append calibrated probabilities to the prediction frame
                predictFr.write_lock(jobKey);
                for (int i = 0; i < calPredictions.length; i++)
                    predictFr.add("cal_" + predictFr.name(1 + i), calPredictions[i]);
                return predictFr.update(jobKey);
            } finally {
                predictFr.unlock(jobKey);
                DKV.remove(calibInputKey);
                if (calibOutput != null)
                    calibOutput.remove();
            }
        } else {
            throw H2O.unimpl("Calibration is only supported for binomial models");
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy