All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.glm.solvers.GlmAlgorithmIrls Maven / Gradle / Ivy

There is a newer version: 1.0.6
Show newest version
package com.github.chen0040.glm.solvers;

import Jama.Matrix;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;


/**
 * Created by xschen on 15/8/15.
 */
/// 
/// In regression, we try to find a set of model coefficients x such that
/// A * x = b + e
///
/// A is known as the model matrix, x as the coefficient vector, b as the response vector, and e as the error term.
///
/// In weighted least squares (the generalization of OLS, Ordinary Least Squares, used here), we assume that the variance-covariance matrix V(e) = sigma^2 * W, where:
///   W is a symmetric positive definite matrix (here a diagonal matrix)
///   sigma is the standard error of e
///
/// The objective is to find x_bar such that e.transpose * W * e is minimized (note that since W is positive definite, e.transpose * W * e is always positive for non-zero e)
/// In other words, we are looking for x_bar such that (A * x_bar - b).transpose * W * (A * x_bar - b) is minimized
///
/// Let y = (A * x - b).transpose * W * (A * x - b)
/// Now differentiating y with respect to x, we have
/// dy / dx = A.transpose * W * (A * x - b) * 2
///
/// To find min y, set dy / dx = 0 at x = x_bar, we have
/// A.transpose * W * (A * x_bar - b) = 0
///
/// Transform this, we have
/// A.transpose * W * A * x_bar = A.transpose * W * b
///
/// Multiplying both sides by (A.transpose * W * A).inverse, we have
/// x_bar = (A.transpose * W * A).inverse * A.transpose * W * b
/// In a GLM the weights W themselves depend on x, so this is commonly solved iteratively using IRLS.
/// This implementation of Glm is based on iteratively re-weighted least squares estimation (IRLS).
///
/// Discussion:
///
/// Because this implementation explicitly inverts A.transpose * W * A, which may be singular or ill-conditioned, it is potentially numerically
/// unstable and not generally advised; use the QR or SVD variant of IRLS instead.
///
/// 
public class GlmAlgorithmIrls extends GlmAlgorithm {
    /** Substituted for a zero variance so that the IRLS weight {@code gprime^2 / variance} stays finite. */
    private static final double EPSILON = 1e-20;

    /** Model matrix: one row per observation, one column per coefficient. */
    private Matrix A;
    /** Response values stored as an m x 1 column vector. */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /**
     * Copies the state of {@code rhs} into this instance, deep-cloning the matrices.
     *
     * @param rhs the solver to copy from; must be a {@code GlmAlgorithmIrls}
     *            (the cast below throws {@link ClassCastException} otherwise)
     */
    @Override
    public void copy(GlmAlgorithm rhs) {
        super.copy(rhs);

        GlmAlgorithmIrls rhs2 = (GlmAlgorithmIrls) rhs;
        A = rhs2.A == null ? null : (Matrix) rhs2.A.clone();
        b = rhs2.b == null ? null : (Matrix) rhs2.b.clone();
        At = rhs2.At == null ? null : (Matrix) rhs2.At.clone();
    }

    /**
     * Creates a new {@code GlmAlgorithmIrls} and copies this instance's state into it.
     *
     * @return the copy
     */
    @Override
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrls clone = new GlmAlgorithmIrls();
        clone.copy(this);

        return clone;
    }

    /** Creates an empty solver; the matrices stay {@code null} until {@link #copy(GlmAlgorithm)} fills them. */
    public GlmAlgorithmIrls() {
        super();
    }

    /**
     * Creates a solver for the given distribution family and link function.
     *
     * @param distribution the distribution family, forwarded to the superclass
     * @param linkFunc     the link function, forwarded to the superclass
     * @param A            model matrix, one row per observation; must be non-empty
     * @param b            response values, one per observation
     * @throws IllegalArgumentException if {@code A} is null or has no rows
     */
    public GlmAlgorithmIrls(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b) {
        super(distribution, linkFunc, null, null, null);
        this.A = toMatrix(A);
        this.b = columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates a solver for the given distribution family, leaving the choice of link function to the superclass.
     *
     * @param distribution the distribution family, forwarded to the superclass
     * @param A            model matrix, one row per observation; must be non-empty
     * @param b            response values, one per observation
     * @throws IllegalArgumentException if {@code A} is null or has no rows
     */
    public GlmAlgorithmIrls(GlmDistributionFamily distribution, double[][] A, double[] b) {
        super(distribution);
        this.A = toMatrix(A);
        this.b = columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Copies a rectangular array into a new {@link Matrix} at full double precision.
     * The column count is taken from the first row.
     *
     * @throws IllegalArgumentException if {@code A} is null or has no rows
     */
    private static Matrix toMatrix(double[][] A) {
        if (A == null || A.length == 0) {
            throw new IllegalArgumentException("Model matrix A must contain at least one row");
        }

        int m = A.length;
        int n = A[0].length;

        Matrix Am = new Matrix(m, n);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                // Store the value as-is; narrowing to float here would discard precision.
                Am.set(i, j, A[i][j]);
            }
        }

        return Am;
    }

    /** Copies {@code b} into a new {@code b.length x 1} column vector. */
    private static Matrix columnVector(double[] b) {
        int m = b.length;
        Matrix B = new Matrix(m, 1);
        for (int i = 0; i < m; ++i) {
            B.set(i, 0, b[i]);
        }
        return B;
    }

    /** Creates a new {@code n x 1} matrix to serve as a column vector. */
    private static Matrix columnVector(int n) {
        return new Matrix(n, 1);
    }

    /** Creates an {@code m x m} matrix with ones on the diagonal. */
    private static Matrix identity(int m) {
        Matrix A = new Matrix(m, m);
        for (int i = 0; i < m; ++i) {
            A.set(i, i, 1);
        }
        return A;
    }

    /**
     * Estimates the model coefficients by iteratively re-weighted least squares.
     *
     * <p>Each iteration computes the linear predictor {@code eta = A * x}, builds the working
     * response {@code z} and the diagonal weight matrix {@code W}, and solves
     * {@code At * W * A * x = At * W * z} by explicit inversion. Iteration stops after
     * {@code maxIters} rounds or once the 2-norm of the change in {@code x} drops below {@code mTol}.
     * Afterwards the coefficients are stored in {@code glmCoefficients} and, if at least one
     * iteration ran, the fit statistics in {@code mStats} are refreshed.
     *
     * @return the estimated coefficients (the same array stored in {@code glmCoefficients})
     */
    @Override
    public double[] solve() {
        int m = A.getRowDimension();
        int n = A.getColumnDimension();

        Matrix x = columnVector(n);

        Matrix W = null;
        Matrix AtWAInv = null;

        for (int j = 0; j < maxIters; ++j) {

            Matrix eta = A.times(x); // eta: m x 1
            Matrix z = columnVector(m); // z: m x 1
            double[] g = new double[m];
            double[] gprime = new double[m];

            for (int k = 0; k < m; ++k) {
                g[k] = linkFunc.GetInvLink(eta.get(k, 0));
                gprime[k] = linkFunc.GetInvLinkDerivative(eta.get(k, 0));

                // working response: linear predictor plus the residual scaled by the inverse-link derivative
                z.set(k, 0, eta.get(k, 0) + (b.get(k, 0) - g[k]) / gprime[k]);
            }

            W = identity(m); // W: m x m, diagonal
            for (int k = 0; k < m; ++k) {
                double g_variance = getVariance(g[k]);
                if (g_variance == 0) {
                    // avoid division by zero in the weight below
                    g_variance = EPSILON;
                }
                W.set(k, k, gprime[k] * gprime[k] / g_variance);
            }

            Matrix x_old = x;

            Matrix AtW = At.times(W); // AtW: n x m

            // solve x for At * W * A * x = At * W * z
            Matrix AtWA = AtW.times(A); // AtWA: n x n
            AtWAInv = AtWA.inverse();
            x = AtWAInv.times(AtW).times(z);

            if ((x.minus(x_old)).norm2() < mTol) {
                break;
            }
        }

        glmCoefficients = new double[n];
        for (int i = 0; i < n; ++i) {
            glmCoefficients[i] = x.get(i, 0);
        }

        // If the loop never ran (maxIters <= 0) there is no covariance estimate to report.
        if (AtWAInv != null) {
            updateStatistics(AtWAInv, W);
        }

        return glmCoefficients;
    }

    /**
     * Refreshes {@code mStats} from the fitted coefficients.
     *
     * <p>Writes standard errors (square roots of the diagonal of {@code vcovmat}), the
     * variance-covariance matrix, and the response-scale residuals into the arrays returned by
     * {@code mStats}, then sets the residual standard deviation, response mean and variance,
     * R^2 and adjusted R^2.
     *
     * @param vcovmat variance-covariance matrix for the model coefficients ({@code (At * W * A).inverse()})
     * @param W       weight matrix of the final iteration (currently not used by this method)
     */
    private void updateStatistics(Matrix vcovmat, Matrix W) {
        int n = vcovmat.getRowDimension(); // number of coefficients
        int m = b.getRowDimension();       // number of observations

        // NOTE(review): these getters are assumed to return the live backing arrays of mStats,
        // since the values written here are read back through mStats below - confirm in GlmStatistics.
        double[] stdErr = mStats.getStandardErrors();
        double[][] VCovMatrix = mStats.getVCovMatrix();
        double[] residuals = mStats.getResiduals();

        for (int i = 0; i < n; ++i) {
            stdErr[i] = Math.sqrt(vcovmat.get(i, i));
            for (int j = 0; j < n; ++j) {
                VCovMatrix[i][j] = vcovmat.get(i, j);
            }
        }

        double[] outcomes = new double[m];
        for (int i = 0; i < m; i++) {
            double cross_prod = 0;
            for (int j = 0; j < n; ++j) {
                cross_prod += A.get(i, j) * glmCoefficients[j];
            }
            residuals[i] = b.get(i, 0) - linkFunc.GetInvLink(cross_prod);
            outcomes[i] = b.get(i, 0);
        }

        mStats.setResidualStdDev(StdDev.apply(mStats.getResiduals(), 0));
        mStats.setResponseMean(Mean.apply(outcomes));
        mStats.setResponseVariance(Variance.apply(outcomes, mStats.getResponseMean()));

        double unexplainedRatio = mStats.getResidualStdDev() * mStats.getResidualStdDev() / mStats.getResponseVariance();
        mStats.setR2(1 - unexplainedRatio);
        // The degrees-of-freedom correction is based on the number of observations (m), not the
        // number of coefficients: using n here made the denominator a constant -1.
        mStats.setAdjustedR2(1 - unexplainedRatio * (m - 1) / (m - glmCoefficients.length - 1));
    }
}






© 2015 - 2025 Weber Informatics LLC | Privacy Policy