All downloads are free. Search and download functionality uses the official Maven repository.

smile.classification.PlattScaling Maven / Gradle / Ivy

There is a newer version: 4.3.0
Show newest version
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

package smile.classification;

import java.io.Serializable;
import static java.lang.Math.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Platt scaling (Platt calibration) transforms the real-valued outputs of a
 * binary classifier into posterior probabilities by fitting the sigmoid
 * <pre>
 *     P(y = 1 | s) = 1 / (1 + exp(alpha * s + beta))
 * </pre>
 * to the classifier's scores {@code s}. The method was invented by John Platt
 * in the context of support vector machines, but can be applied to any
 * classification model that produces a score.
 * <p>
 * Platt suggested the Levenberg–Marquardt algorithm to fit {@code alpha} and
 * {@code beta}. This class instead implements a later-proposed Newton method
 * with backtracking line search, which is intended to be more numerically
 * stable.
 * <p>
 * The static helper {@link #multiclass(int, double[][], double[])} couples
 * pairwise probability estimates into a distribution over {@code k} classes.
 *
 * @author Haifeng Li
 */
public class PlattScaling implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(PlattScaling.class);

    /** Smallest step size the backtracking line search will try. */
    private static final double MIN_STEP = 1e-10;
    /** Ridge added to the Hessian diagonal so it is numerically strictly positive definite. */
    private static final double SIGMA = 1e-12;
    /** Convergence tolerance on both components of the gradient. */
    private static final double EPS = 1e-5;
    /** Sufficient-decrease (Armijo) constant used by the line search. */
    private static final double ARMIJO = 1e-4;

    /** Slope of the fitted sigmoid (see class documentation). */
    private final double alpha;
    /** Intercept of the fitted sigmoid (see class documentation). */
    private final double beta;

    /**
     * Trains the Platt scaling with at most 100 Newton iterations.
     *
     * @param scores the predicted scores.
     * @param y the training labels; labels {@code > 0} are treated as the
     *          positive class, all others as the negative class.
     * @throws IllegalArgumentException if {@code scores} and {@code y} have
     *          different lengths.
     */
    public PlattScaling(double[] scores, int[] y) {
        this(scores, y, 100);
    }

    /**
     * Trains the Platt scaling.
     *
     * @param scores the predicted scores.
     * @param y the training labels; labels {@code > 0} are treated as the
     *          positive class, all others as the negative class.
     * @param maxIters the maximal number of Newton iterations. Non-positive
     *          values skip the optimization (a warning is logged) and keep
     *          the initial point.
     * @throws IllegalArgumentException if {@code scores} and {@code y} have
     *          different lengths.
     */
    public PlattScaling(double[] scores, int[] y, int maxIters) {
        if (scores.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of scores and labels don't match: %d != %d", scores.length, y.length));
        }

        int l = scores.length;
        int positives = 0;
        for (int label : y) {
            if (label > 0) {
                positives++;
            }
        }
        int negatives = l - positives;

        // Platt's regularized targets, used instead of hard 0/1 labels to
        // reduce overfitting to the training scores.
        double hiTarget = (positives + 1.0) / (positives + 2.0);
        double loTarget = 1.0 / (negatives + 2.0);
        double[] t = new double[l];
        for (int i = 0; i < l; i++) {
            t[i] = y[i] > 0 ? hiTarget : loTarget;
        }

        // Initial point: flat sigmoid at the (smoothed) prior log-odds.
        double a = 0.0;
        double b = log((negatives + 1.0) / (positives + 1.0));
        double fval = objective(scores, t, a, b);

        int iter = 0;
        for (; iter < maxIters; iter++) {
            // Gradient and Hessian of the objective, with H' = H + SIGMA * I.
            double h11 = SIGMA;
            double h22 = SIGMA;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < l; i++) {
                double fApB = scores[i] * a + b;
                // exp(-|fApB|) <= 1, so this never overflows; it equals
                // exp(-fApB) when fApB >= 0 and exp(fApB) otherwise.
                double e = exp(-abs(fApB));
                double p, q;  // p = 1 / (1 + exp(fApB)), q = 1 - p
                if (fApB >= 0) {
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                double d2 = p * q;
                h11 += scores[i] * scores[i] * d2;
                h22 += d2;
                h21 += scores[i] * d2;
                double d1 = t[i] - p;
                g1 += scores[i] * d1;
                g2 += d1;
            }

            // Stopping criterion.
            if (abs(g1) < EPS && abs(g2) < EPS) {
                break;
            }

            // Newton direction: -inv(H') * g (closed form for a 2x2 matrix).
            double det = h11 * h22 - h21 * h21;
            double dA = -(h22 * g1 - h21 * g2) / det;
            double dB = -(-h21 * g1 + h11 * g2) / det;
            double gd = g1 * dA + g2 * dB;

            // Backtracking line search with sufficient-decrease check.
            double stepsize = 1.0;
            while (stepsize >= MIN_STEP) {
                double newA = a + stepsize * dA;
                double newB = b + stepsize * dB;
                double newf = objective(scores, t, newA, newB);
                if (newf < fval + ARMIJO * stepsize * gd) {
                    a = newA;
                    b = newB;
                    fval = newf;
                    break;
                }
                stepsize /= 2.0;
            }

            if (stepsize < MIN_STEP) {
                logger.error("Line search fails.");
                break;
            }
        }

        if (iter >= maxIters) {
            logger.warn("Reaches maximal iterations");
        }

        alpha = a;
        beta = b;
    }

    /**
     * Evaluates the cross-entropy between the targets {@code t} and the
     * sigmoid {@code 1 / (1 + exp(a * s + b))} over all scores. Each term is
     * rewritten so that {@code exp} is only called on non-positive arguments
     * and cannot overflow.
     */
    private static double objective(double[] scores, double[] t, double a, double b) {
        double f = 0.0;
        for (int i = 0; i < scores.length; i++) {
            double fApB = scores[i] * a + b;
            if (fApB >= 0) {
                f += t[i] * fApB + log1p(exp(-fApB));
            } else {
                f += (t[i] - 1) * fApB + log1p(exp(fApB));
            }
        }
        return f;
    }

    /**
     * Returns the posterior probability estimate P(y = 1 | x), i.e.
     * {@code 1 / (1 + exp(alpha * y + beta))}.
     *
     * @param y the binary classifier output score.
     * @return the estimated probability.
     */
    public double predict(double y) {
        double fApB = y * alpha + beta;
        // Branch so that exp is only evaluated on non-positive arguments;
        // this avoids overflow and catastrophic cancellation in 1 - p.
        if (fApB >= 0) {
            double e = exp(-fApB);
            return e / (1.0 + e);
        }
        return 1.0 / (1 + exp(fApB));
    }

    /**
     * Estimates multiclass probabilities by coupling pairwise probability
     * estimates.
     * <p>
     * Builds {@code Q} with {@code Q[t][t] = sum_{j != t} r[j][t]^2} and
     * {@code Q[t][j] = -r[j][t] * r[t][j]} for {@code j != t}. Then, starting
     * from the uniform distribution, it repeatedly updates each {@code p[t]} so
     * that {@code (Qp)[t]} approaches {@code p'Qp}. After every update it
     * renormalizes, so {@code p} keeps summing to 1. Iteration stops once
     * {@code max_t |(Qp)[t] - p'Qp| < 0.005 / k}, or after {@code max(100, k)}
     * sweeps; in the latter case a warning is logged. If some {@code Q[t][t]}
     * is 0, the update divides by zero.
     * <p>
     * NOTE(review): this appears to follow the pairwise-coupling method of
     * Wu, Lin and Weng (2004) as implemented in LIBSVM. Presumably
     * {@code r[i][j]} estimates P(class i | class i or j, x) with
     * {@code r[i][j] + r[j][i] = 1}; confirm against callers.
     *
     * @param k the number of classes.
     * @param r the k-by-k matrix of pairwise probability estimates; only the
     *          off-diagonal entries are read.
     * @param p output array of length at least {@code k}; its first {@code k}
     *          entries are overwritten with the estimated class probabilities.
     * @throws IllegalArgumentException if {@code k} is negative or the arrays
     *          are too small for {@code k} classes.
     */
    public static void multiclass(int k, double[][] r, double[] p) {
        if (k < 0) {
            throw new IllegalArgumentException("Invalid number of classes: " + k);
        }
        if (k > 0 && p.length < k) {
            throw new IllegalArgumentException(String.format(
                    "The output array is too small: %d < %d", p.length, k));
        }
        if (k > 1 && r.length < k) {
            throw new IllegalArgumentException(String.format(
                    "The pairwise probability matrix is too small: %d < %d", r.length, k));
        }

        double[][] Q = new double[k][k];
        double[] Qp = new double[k];
        double eps = 0.005 / k;

        for (int t = 0; t < k; t++) {
            p[t] = 1.0 / k;  // Valid if k = 1
            for (int j = 0; j < t; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = Q[j][t];
            }
            for (int j = t + 1; j < k; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = -r[j][t] * r[t][j];
            }
        }

        int maxIter = max(100, k);
        int iter = 0;
        for (; iter < maxIter; iter++) {
            // Stopping condition: recompute Qp and p'Qp from scratch each
            // sweep for numerical accuracy.
            double pQp = 0.0;
            for (int t = 0; t < k; t++) {
                double sum = 0.0;
                for (int j = 0; j < k; j++) {
                    sum += Q[t][j] * p[j];
                }
                Qp[t] = sum;
                pQp += p[t] * sum;
            }

            double maxError = 0.0;
            for (int t = 0; t < k; t++) {
                double error = abs(Qp[t] - pQp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < eps) {
                break;
            }

            for (int t = 0; t < k; t++) {
                double diff = (-Qp[t] + pQp) / Q[t][t];
                p[t] += diff;
                // p now sums to 1 + diff; rescale p, Qp and p'Qp accordingly.
                double scale = 1 + diff;
                pQp = (pQp + diff * (diff * Q[t][t] + 2 * Qp[t])) / scale / scale;
                for (int j = 0; j < k; j++) {
                    Qp[j] = (Qp[j] + diff * Q[t][j]) / scale;
                    p[j] /= scale;
                }
            }
        }

        if (iter >= maxIter) {
            logger.warn("Reaches maximal iterations");
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy