All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.glm.solvers.GlmAlgorithm Maven / Gradle / Ivy

There is a newer version: 1.0.6
Show newest version
package com.github.chen0040.glm.solvers;


import com.github.chen0040.glm.links.*;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.ls.GradientEvaluationMethod;
import com.github.chen0040.ls.LocalSearch;
import com.github.chen0040.ls.TerminationEvaluationMethod;
import com.github.chen0040.ls.methods.cgs.NonlinearCGSearch;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.maths.MatrixOp;
import com.github.chen0040.ls.CostEvaluationMethod;
import com.github.chen0040.ls.solutions.NumericSolution;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

import java.util.Random;


/**
 * Created by xschen on 14/8/15.
 */
/// 
/// Link: http://bwlewis.github.io/GLM/
/// A GLM (generalized linear model) models a response b from the exponential family of distributions as b = g(a).
/// g(a) is the inverse link function.
///
/// Therefore, for a regression characterized by the inverse link function g(a), the regression problem can be
/// formulated as finding the model coefficient vector x in
/// g(A * x) = b + e
/// The objective is to find the x that solves:
/// min (g(A * x) - b).transpose * W * (g(A * x) - b)
///
/// Suppose we assume that e consists of uncorrelated random variables with identical variance; then W = sigma^(-2) * I,
/// and the objective min (g(A * x) - b).transpose * W * (g(A * x) - b) reduces to the OLS form:
/// min || g(A * x) - b ||^2
/// 
/**
 * Fits a generalized linear model by handing a (optionally L2-regularized) squared-error cost
 * {@code sum((g(A * x) - b)^2) / (2m)} and its gradient to a {@link LocalSearch} solver, where
 * {@code g} is the inverse link function of the configured distribution family.
 */
public class GlmAlgorithm implements Cloneable {
    /** Shared random source used by {@link #solve()} to draw the initial coefficient guess. */
    private static final Random random = new Random();

    /** Link function whose inverse maps the linear predictor {@code A * x} to the predicted response. */
    protected LinkFunction linkFunc;
    /** Iteration cap consulted by the default termination criterion. */
    protected int maxIters = 25;
    /** Improvement threshold below which the default termination criterion stops the search. */
    protected double mTol = 0.000001;
    /** L2 regularization weight; applied to every coefficient except the first (index 0). */
    protected double mRegularizationLambda = 0;
    protected GlmDistributionFamily mDistributionFamily;
    protected GlmStatistics mStats = new GlmStatistics();

    /**
     * Default termination criterion, bound to this instance's {@code maxIters} and {@code mTol}.
     *
     * The search stops once the iteration cap is reached, or as soon as an improving step gains
     * less than the tolerance. (The original predicate returned {@code false} exactly in the
     * stalled / below-tolerance case, so the tolerance never ended the search and the cap was
     * ignored once progress stalled.)
     * NOTE(review): assumes {@code state.improvement()} is the cost decrease of the last accepted
     * step and that returning {@code true} ends the search - confirm against the solver library.
     */
    private final TerminationEvaluationMethod defaultShouldTerminate = (state, iteration) -> {
        if (iteration >= maxIters) {
            return true;
        }
        return state.improved() && state.improvement() < mTol;
    };
    protected TerminationEvaluationMethod shouldTerminate = defaultShouldTerminate;

    /** Most recent solution found by {@link #solve()}; {@code null} until a solve has completed. */
    protected double[] glmCoefficients;
    private LocalSearch solver;
    private double[][] A; //first column of A corresponds to x_0 = 1
    private double[] b;

    /**
     * Default cost, bound to this instance's data:
     * {@code J(x) = (sum_i (g(A_i * x) - b_i)^2 + lambda * sum_{j>=1} x_j^2) / (2m)}.
     */
    private final CostEvaluationMethod defaultEvaluateCost = new CostEvaluationMethod() {
        public double apply(double[] x, double[] lowerBounds, double[] upperBounds, Object constraint) {
            int m = b.length;

            double[] c = MatrixOp.Multiply(A, x);
            double squaredError = 0;
            for (int i = 0; i < m; ++i) {
                double residual = linkFunc.GetInvLink(c[i]) - b[i];
                squaredError += residual * residual;
            }

            double J = squaredError / (2 * m);

            // The coefficient at index 0 is skipped so it is not penalized.
            for (int j = 1; j < x.length; ++j) {
                J += (mRegularizationLambda * x[j] * x[j]) / (2 * m);
            }

            return J;
        }
    };
    protected CostEvaluationMethod evaluateCost = defaultEvaluateCost;

    /**
     * Default gradient of the cost above, written into {@code gradx}:
     * {@code dJ/dx_i = sum_j (g(c_j) - b_j) * g'(c_j) * A[j][i] / m}, plus {@code lambda * x_i / m} for {@code i != 0}.
     */
    private final GradientEvaluationMethod defaultEvaluateGradient = new GradientEvaluationMethod() {
        public void apply(double[] x, double[] gradx, double[] lowerBounds, double[] upperBounds, Object constraint) {
            int m = b.length;
            int n = A[0].length;

            double[] c = MatrixOp.Multiply(A, x);

            // Evaluate the inverse link and its derivative once per observation.
            double[] g = new double[m];
            double[] gprime = new double[m];
            for (int j = 0; j < m; ++j) {
                g[j] = linkFunc.GetInvLink(c[j]);
                gprime[j] = linkFunc.GetInvLinkDerivative(c[j]);
            }

            for (int i = 0; i < n; ++i) {
                double crossprod = 0;
                for (int j = 0; j < m; ++j) {
                    double cb = g[j] - b[j];
                    crossprod += cb * gprime[j] * A[j][i];
                }

                gradx[i] = crossprod / m;

                if (i != 0) {
                    gradx[i] += (mRegularizationLambda * x[i]) / m;
                }
            }
        }
    };
    protected GradientEvaluationMethod evaluateGradient = defaultEvaluateGradient;


    /** Creates an unconfigured instance; used by {@link #makeCopy()} as the target of {@link #copy(GlmAlgorithm)}. */
    public GlmAlgorithm(){

    }

    /**
     * Creates a model with an explicitly chosen link function.
     *
     * @param distribution distribution family of the response
     * @param linkFunc     link function to use
     * @param A            design matrix, stored by reference; first column corresponds to x_0 = 1
     * @param b            observed responses, stored by reference
     * @param solver       local search used by {@link #solve()}
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b, LocalSearch solver) {
        this.mDistributionFamily = distribution;
        this.solver = solver;
        this.linkFunc = linkFunc;
        this.A = A;
        this.b = b;
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates a model using the link function that {@link #getLinkFunction(GlmDistributionFamily)}
     * selects for the distribution family.
     *
     * @param distribution distribution family of the response
     * @param A            design matrix, stored by reference; first column corresponds to x_0 = 1
     * @param b            observed responses, stored by reference
     * @param solver       local search used by {@link #solve()}
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b, LocalSearch solver) {
        this.solver = solver;
        this.mDistributionFamily = distribution;
        this.linkFunc = getLinkFunction(distribution);
        this.A = A;
        this.b = b;
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }


    /**
     * Same as the four-argument constructor, with a new {@link NonlinearCGSearch} as the solver.
     *
     * @param distribution distribution family of the response
     * @param A            design matrix, stored by reference; first column corresponds to x_0 = 1
     * @param b            observed responses, stored by reference
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b) {
        this.solver = new NonlinearCGSearch();
        this.mDistributionFamily = distribution;
        this.linkFunc = getLinkFunction(distribution);
        this.A = A;
        this.b = b;
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates an instance that only knows its distribution family and link function; no data and
     * no solver are set, so {@link #solve()} cannot be used on it.
     *
     * @param distribution distribution family of the response
     */
    public GlmAlgorithm(GlmDistributionFamily distribution) {
        this.linkFunc = getLinkFunction(distribution);
        this.mDistributionFamily = distribution;
    }

    /**
     * Creates a model that keeps its own copy of the design matrix and allows overriding the
     * iteration cap.
     *
     * @param distribution distribution family of the response
     * @param A            design matrix; the first {@code A[0].length} entries of every row are copied
     * @param b            observed responses, stored by reference
     * @param solver       local search used by {@link #solve()}
     * @param maxIters     iteration cap; values {@code <= 0} keep the default
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b, LocalSearch solver, int maxIters) {
        this.solver = solver;
        this.mDistributionFamily = distribution;
        this.linkFunc = getLinkFunction(distribution);

        int m = A.length;
        int n = A[0].length;

        this.A = new double[m][];
        for (int i = 0; i < m; i++) {
            this.A[i] = new double[n];
            System.arraycopy(A[i], 0, this.A[i], 0, n);
        }
        this.b = b;
        if (maxIters > 0) {
            this.maxIters = maxIters;
        }
        // Sized from the column count like every other constructor; the row count was passed here before.
        this.mStats = new GlmStatistics(n, b.length);
    }

    /**
     * Returns a new link function instance for the given distribution family.
     *
     * @param distribution distribution family
     * @return logit for Bernouli/Binomial/Categorical/Multinomial, inverse for Exponential/Gamma,
     *         inverse-squared for InverseGaussian, identity for Normal, log for Poisson
     * @throws NotImplementedException for any other family
     */
    public static LinkFunction getLinkFunction(GlmDistributionFamily distribution) {
        switch (distribution) {
            case Bernouli:
            case Binomial:
            case Categorical:
            case Multinomial:
                return new LogitLinkFunction();
            case Exponential:
            case Gamma:
                return new InverseLinkFunction();
            case InverseGaussian:
                return new InverseSquaredLinkFunction();
            case Normal:
                return new IdentityLinkFunction();
            case Poisson:
                return new LogLinkFunction();
            default:
                // NOTE(review): this is a JDK-internal (sun.*) exception type; it is kept so the
                // thrown type is unchanged, but UnsupportedOperationException is the portable choice.
                throw new NotImplementedException();
        }
    }

    /**
     * Copies a link function through {@link AbstractLinkFunction#makeCopy()}.
     * Not called anywhere in this class at present.
     *
     * @param rhs link function to copy, may be {@code null}
     * @return the copy, or {@code null} when {@code rhs} is {@code null}
     */
    private LinkFunction clone(LinkFunction rhs){
        if(rhs==null) return null;
        AbstractLinkFunction rhs2 = (AbstractLinkFunction)rhs;
        return rhs2.makeCopy();
    }

    /**
     * Overwrites this instance's configuration, data and last solution with those of {@code rhs}.
     *
     * The design matrix is copied row by row and {@code b}, the coefficients, the statistics and
     * the solver are copied as well, so later changes to either instance's arrays do not leak
     * into the other. The link function is shared by reference.
     *
     * @param rhs instance to copy from
     */
    public void copy(GlmAlgorithm rhs){
        linkFunc = rhs.linkFunc;
        maxIters = rhs.maxIters;
        mTol = rhs.mTol;
        mRegularizationLambda = rhs.mRegularizationLambda;
        mDistributionFamily = rhs.mDistributionFamily;
        mStats = rhs.mStats == null ? null : (GlmStatistics)rhs.mStats.clone();

        glmCoefficients = rhs.glmCoefficients == null ? null : rhs.glmCoefficients.clone();

        solver = rhs.solver== null ? null : rhs.solver.makeCopy();

        A = deepCopy(rhs.A); //first column of A corresponds to x_0 = 1
        b = rhs.b == null ? null : rhs.b.clone();

        // The default strategies are closures over their owning instance. Copying rhs's defaults
        // would make this object keep evaluating rhs's data and settings, so defaults are rebound
        // to this instance; strategies that a caller or subclass installed are carried over as-is.
        shouldTerminate = rhs.shouldTerminate == rhs.defaultShouldTerminate
                ? defaultShouldTerminate : rhs.shouldTerminate;
        evaluateCost = rhs.evaluateCost == rhs.defaultEvaluateCost
                ? defaultEvaluateCost : rhs.evaluateCost;
        evaluateGradient = rhs.evaluateGradient == rhs.defaultEvaluateGradient
                ? defaultEvaluateGradient : rhs.evaluateGradient;
    }

    /** Returns a copy of the matrix whose row arrays are independent of the source's; {@code null} stays {@code null}. */
    private static double[][] deepCopy(double[][] matrix) {
        if (matrix == null) {
            return null;
        }
        double[][] result = new double[matrix.length][];
        for (int i = 0; i < matrix.length; ++i) {
            result[i] = matrix[i] == null ? null : matrix[i].clone();
        }
        return result;
    }

    /**
     * Creates a new {@code GlmAlgorithm} and fills it via {@link #copy(GlmAlgorithm)}.
     *
     * @return the copy
     */
    public GlmAlgorithm makeCopy(){
        GlmAlgorithm clone = new GlmAlgorithm();
        clone.copy(this);
        return clone;
    }

    /** @return the improvement tolerance used by the default termination criterion */
    public double getTol() {
        return mTol;
    }

    /** @param value new improvement tolerance */
    public void setTol(double value) {
        mTol = value;
    }

    /** @return the configured distribution family */
    public GlmDistributionFamily getDistributionFamily() {
        return mDistributionFamily;
    }

    /**
     * Predicts the response for one input row by applying the inverse link to the dot product of
     * the input and the fitted coefficients.
     *
     * @param input_0 feature values, index-aligned with the coefficients; it must not be longer
     *                than the coefficient array or an index error is raised
     * @return the prediction, or {@code Double.NaN} when no coefficients are available yet
     */
    public double predict(double[] input_0) {

        if(glmCoefficients == null){
            return Double.NaN;
        }

        double linearPredictor = 0;
        for (int i = 0; i < input_0.length; ++i) {
            linearPredictor += glmCoefficients[i] * input_0[i];
        }
        return linkFunc.GetInvLink(linearPredictor);
    }

    /**
     * Evaluates the variance function of the configured distribution family.
     *
     * @param g value at which to evaluate
     * @return {@code g(1-g)} for Bernouli/Binomial/Categorical/Multinomial, {@code g^2} for
     *         Exponential/Gamma, {@code g^3} for InverseGaussian, {@code 1} for Normal,
     *         {@code g} for Poisson
     * @throws NotImplementedException for any other family
     */
    protected double getVariance(double g) {
        switch (mDistributionFamily) {
            case Bernouli:
            case Binomial:
            case Categorical:
            case Multinomial:
                return g * (1 - g);
            case Exponential:
            case Gamma:
                return g * g;
            case InverseGaussian:
                return g * g * g;
            case Normal:
                return 1;
            case Poisson:
                return g;
            default:
                throw new NotImplementedException();
        }
    }

    /** @return the iteration cap used by the default termination criterion */
    public int getMaxIters() {
        return maxIters;
    }

    /** @param value new iteration cap */
    public void setMaxIters(int value) {
        maxIters = value;
    }

    /** @return the coefficient array from the last solve (the internal array, not a copy), or {@code null} */
    public double[] getCoefficients() {
        return glmCoefficients;
    }

    /** @return the current statistics object */
    public GlmStatistics getStatistics() {
        return mStats;
    }

    /**
     * Fits the model: starts the solver from a random point in {@code [0, 1)^n}, stores the
     * resulting coefficients and rebuilds the statistics from them.
     *
     * @return the fitted coefficients
     * @throws IllegalStateException if no solver or no training data has been configured
     */
    public double[] solve() {
        if (solver == null) {
            throw new IllegalStateException("No solver configured; construct GlmAlgorithm with a LocalSearch");
        }
        if (A == null || A.length == 0 || b == null) {
            throw new IllegalStateException("No training data configured; construct GlmAlgorithm with A and b");
        }

        int n = A[0].length;

        double[] x_0 = new double[n];
        for (int i = 0; i < n; ++i) {
            x_0[i] = random.nextDouble();
        }

        NumericSolution s = solver.minimize(x_0, evaluateCost, evaluateGradient, shouldTerminate, null);

        glmCoefficients = s.values();

        updateStatistics();

        return getCoefficients();
    }

    /** Rebuilds {@code mStats} from the data and the current coefficients. */
    private void updateStatistics() {
        mStats = new GlmStatistics(A, b, glmCoefficients);
    }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy