All downloads are free. Search and download functionality uses the official Maven repository.

edu.pitt.csb.mgm.Mgm Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.pitt.csb.mgm;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.StatUtils;
import org.apache.commons.math3.util.FastMath;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

/**
 * Implementation of Lee and Hastie's (2012) pseudolikelihood method for learning Mixed Gaussian-Categorical Graphical
 * Models. Created by ajsedgewick on 7/15/15.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class Mgm extends ConvexProximal implements IGraphSearch {

    /**
     * Factory used to build dense two-dimensional Colt matrices.
     */
    private final DoubleFactory2D factory2D = DoubleFactory2D.dense;

    /**
     * Factory used to build dense one-dimensional Colt vectors.
     */
    private final DoubleFactory1D factory1D = DoubleFactory1D.dense;

    /**
     * Continuous data, one column per continuous variable. Columns are z-scored in place by fixData().
     */
    private final DoubleMatrix2D xDat;

    /**
     * Discrete data coded as integers (held in a double matrix, no IntMatrix2D apparently). fixData() shifts
     * zero-indexed data so that levels start at 1.
     */
    private final DoubleMatrix2D yDat;

    /**
     * Regularization parameters; three values for continuous-continuous (cc), continuous-discrete (cd) and
     * discrete-discrete (dd) edges, in that order.
     */
    private final DoubleMatrix1D lambda;

    /**
     * Colt linear-algebra helper (transpose, mult, norms).
     */
    private final Algebra alg = new Algebra();

    /**
     * Number of levels of each discrete variable; one entry per column of yDat.
     */
    private final int[] l;

    /**
     * Number of continuous variables (columns of xDat).
     */
    int p;

    /**
     * Number of discrete variables (columns of yDat).
     */
    int q;

    /**
     * Number of samples (rows of xDat).
     */
    int n;

    /**
     * Variables of the model. The DataSet constructor orders them continuous first, then discrete; the matrix
     * constructor stores the list it is given.
     */
    private List<Node> variables;

    /**
     * Variables as returned by the original data set; assigned only by the DataSet constructor.
     */
    private List<Node> initVariables;

    /**
     * Discrete Data coded as dummy variables (n rows, lsum columns); built by makeDummy().
     */
    private DoubleMatrix2D dDat;

    /**
     * Elapsed time of a run. NOTE(review): not written by any code visible here; presumably milliseconds, confirm
     * against the search methods.
     */
    private long elapsedTime;

    /**
     * Total number of levels over all discrete variables (last entry of lcumsum).
     */
    private int lsum;

    /**
     * Cumulative sums of l, with a leading 0: lcumsum[i] is the first dummy column belonging to discrete variable i.
     */
    private int[] lcumsum;

    /**
     * Parameter weights, one per variable (p continuous followed by q discrete); computed by calcWeights().
     */
    private DoubleMatrix1D weights;

    /**
     * Current model parameters.
     */
    private MGMParams params;

    /**
     * 

Constructor for Mgm.

* * @param x a {@link cern.colt.matrix.DoubleMatrix2D} object * @param y a {@link cern.colt.matrix.DoubleMatrix2D} object * @param variables a {@link java.util.List} object * @param l an array of {@link int} objects * @param lambda an array of {@link double} objects */ public Mgm(DoubleMatrix2D x, DoubleMatrix2D y, List variables, int[] l, double[] lambda) { if (l.length != y.columns()) throw new IllegalArgumentException("length of l doesn't match number of variables in Y"); if (y.rows() != x.rows()) throw new IllegalArgumentException("different number of samples for x and y"); //lambda should have 3 values corresponding to cc, cd, and dd if (lambda.length != 3) throw new IllegalArgumentException("Lambda should have three values for cc, cd, and dd edges respectively"); this.xDat = x; this.yDat = y; this.l = l; this.p = x.columns(); this.q = y.columns(); this.n = x.rows(); this.variables = variables; this.lambda = this.factory1D.make(lambda); fixData(); initParameters(); calcWeights(); makeDummy(); } /** *

Constructor for Mgm.

* * @param ds a {@link edu.cmu.tetrad.data.DataSet} object * @param lambda an array of {@link double} objects */ public Mgm(DataSet ds, double[] lambda) { this.variables = ds.getVariables(); // Notify the user that you need at least one continuous and one discrete variable to run MGM boolean hasContinuous = false; boolean hasDiscrete = false; for (Node node : this.variables) { if (node instanceof ContinuousVariable) { hasContinuous = true; } if (node instanceof DiscreteVariable) { hasDiscrete = true; } } if (!hasContinuous || !hasDiscrete) { throw new IllegalArgumentException("Please give data with at least one discrete and one continuous variable to run MGM."); } DataSet dsCont = MixedUtils.getContinousData(ds); DataSet dsDisc = MixedUtils.getDiscreteData(ds); this.xDat = this.factory2D.make(dsCont.getDoubleData().toArray()); this.yDat = this.factory2D.make(dsDisc.getDoubleData().toArray()); this.l = MixedUtils.getDiscLevels(ds); this.p = this.xDat.columns(); this.q = this.yDat.columns(); this.n = this.xDat.rows(); //the variables are now ordered continuous first then discrete this.variables = new ArrayList<>(); this.variables.addAll(dsCont.getVariables()); this.variables.addAll(dsDisc.getVariables()); this.initVariables = ds.getVariables(); this.lambda = this.factory1D.make(lambda); //Data is checked for 0 or 1 indexing and fore missing levels fixData(); initParameters(); calcWeights(); makeDummy(); } //create column major vector from matrix (i.e. concatenate columns) /** *

flatten.

* * @param m a {@link cern.colt.matrix.DoubleMatrix2D} object * @return a {@link cern.colt.matrix.DoubleMatrix1D} object */ public static DoubleMatrix1D flatten(DoubleMatrix2D m) { DoubleMatrix1D[] colArray = new DoubleMatrix1D[m.columns()]; for (int i = 0; i < m.columns(); i++) { colArray[i] = m.viewColumn(i); } return DoubleFactory1D.dense.make(colArray); } /* * PRIVATE UTILS */ //Utils //sum rows together if marg == 1 and cols together if marg == 2 //Using row-major speeds up marg=1 5x private static DoubleMatrix1D margSum(DoubleMatrix2D mat, int marg) { int n = 0; DoubleMatrix1D vec = null; DoubleFactory1D fac = DoubleFactory1D.dense; if (marg == 1) { n = mat.columns(); vec = fac.make(n); for (int j = 0; j < mat.rows(); j++) { if (Thread.currentThread().isInterrupted()) { break; } for (int i = 0; i < n; i++) { vec.setQuick(i, vec.getQuick(i) + mat.getQuick(j, i)); } } } else if (marg == 2) { n = mat.rows(); vec = fac.make(n); for (int i = 0; i < n; i++) { if (Thread.currentThread().isInterrupted()) { break; } vec.setQuick(i, mat.viewRow(i).zSum()); } } return vec; } //zeros out everthing below di-th diagonal /** *

upperTri.

* * @param mat a {@link cern.colt.matrix.DoubleMatrix2D} object * @param di a int * @return a {@link cern.colt.matrix.DoubleMatrix2D} object */ public static DoubleMatrix2D upperTri(DoubleMatrix2D mat, int di) { for (int i = FastMath.max(-di + 1, 0); i < mat.rows(); i++) { if (Thread.currentThread().isInterrupted()) { break; } for (int j = 0; j < FastMath.min(i + di, mat.rows()); j++) { if (Thread.currentThread().isInterrupted()) { break; } mat.set(i, j, 0); } } return mat; } //zeros out everthing above di-th diagonal private static DoubleMatrix2D lowerTri(DoubleMatrix2D mat, int di) { for (int i = 0; i < mat.rows() - FastMath.max(di + 1, 0); i++) { if (Thread.currentThread().isInterrupted()) { break; } for (int j = FastMath.max(i + di + 1, 0); j < mat.rows(); j++) { if (Thread.currentThread().isInterrupted()) { break; } mat.set(i, j, 0); } } return mat; } // should move somewhere else... private static double norm2(DoubleMatrix2D mat) { //return FastMath.sqrt(mat.copy().assign(Functions.pow(2)).zSum()); Algebra al = new Algebra(); //norm found by svd so we need rows >= cols if (mat.rows() < mat.columns()) { return al.norm2(al.transpose(mat)); } return al.norm2(mat); } private static double norm2(DoubleMatrix1D vec) { //return FastMath.sqrt(vec.copy().assign(Functions.pow(2)).zSum()); return FastMath.sqrt(new Algebra().norm2(vec)); } private static void runTests1() { try { final String path = "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data"; System.out.println(path); DoubleMatrix2D xIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray()); DoubleMatrix2D yIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray()); int[] L = new int[24]; Node[] vars = new Node[48]; for (int i = 0; i < 24; i++) { L[i] = 2; vars[i] = new ContinuousVariable("X" + i); vars[i + 24] = new DiscreteVariable("Y" + i); } final double lam = .2; Mgm model = 
new Mgm(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L, new double[]{lam, lam, lam}); Mgm model2 = new Mgm(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L, new double[]{lam, lam, lam}); System.out.println("Weights: " + Arrays.toString(model.weights.toArray())); DoubleMatrix2D test = xIn.copy(); DoubleMatrix2D test2 = xIn.copy(); long t = MillisecondTimes.timeMillis(); for (int i = 0; i < 50000; i++) { test2 = xIn.copy(); test.assign(test2); } System.out.println("assign Time: " + (MillisecondTimes.timeMillis() - t)); t = MillisecondTimes.timeMillis(); double[][] xArr = xIn.toArray(); for (int i = 0; i < 50000; i++) { if (Thread.currentThread().isInterrupted()) { break; } //test = DoubleFactory2D.dense.make(xArr); test2 = xIn.copy(); test = test2; } System.out.println("equals Time: " + (MillisecondTimes.timeMillis() - t)); System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D())); System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D())); t = MillisecondTimes.timeMillis(); model.learnEdges(700); //model.learn(1e-7, 700); System.out.println("Orig Time: " + (MillisecondTimes.timeMillis() - t)); System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D())); System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D())); System.out.println("params:\n" + model.params); System.out.println("adjMat:\n" + model.adjMatFromMGM()); } catch (IOException ex) { ex.printStackTrace(); } } /** * test non penalty use cases */ private static void runTests2() { Graph g = GraphUtils.convert("X1-->X2,X3-->X2,X4-->X5"); //simple graph pm im gen example HashMap nd = new HashMap<>(); nd.put("X1", 0); nd.put("X2", 0); nd.put("X3", 4); nd.put("X4", 4); nd.put("X5", 4); g = MixedUtils.makeMixedGraph(g, nd); GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)"); System.out.println(pm); GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm); System.out.println(im); final int 
samps = 1000; DataSet ds = im.simulateDataFisher(samps); ds = MixedUtils.makeMixedData(ds, nd); //System.out.println(ds); final double lambda = 0; Mgm model = new Mgm(ds, new double[]{lambda, lambda, lambda}); System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D())); System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D())); model.learn(1e-8, 1000); System.out.println("Learned nll: " + model.smoothValue(model.params.toMatrix1D())); System.out.println("Learned reg term: " + model.nonSmoothValue(model.params.toMatrix1D())); System.out.println("params:\n" + model.params); System.out.println("adjMat:\n" + model.adjMatFromMGM()); } /** *

* Program entry point; runs the first ad-hoc test routine.
     *
     * @param args an array of {@link java.lang.String} objects (not used)
     */
    public static void main(String[] args) {
        runTests1();
    }

    /**
     *

* Replaces the current model parameters.
     *
     * @param newParams a {@link edu.pitt.csb.mgm.Mgm.MGMParams} object
     */
    public void setParams(MGMParams newParams) {
        this.params = newParams;
    }

    /**
     * Builds the level offsets (lcumsum, lsum) and initializes every parameter to zero, except betad which is
     * set to ones.
     */
    private void initParameters() {
        final int discreteCount = this.l.length;

        // Running total of levels: lcumsum[i] is where discrete variable i starts among the dummy columns.
        this.lcumsum = new int[discreteCount + 1];
        int runningTotal = 0;
        this.lcumsum[0] = runningTotal;
        for (int idx = 0; idx < discreteCount; idx++) {
            runningTotal += this.l[idx];
            this.lcumsum[idx + 1] = runningTotal;
        }
        this.lsum = runningTotal;

        final int contCount = this.xDat.columns();

        //LH init to zeros, maybe should be random init?
        //continuous-continuous
        DoubleMatrix2D beta = this.factory2D.make(contCount, contCount);
        //cont squared node pot
        DoubleMatrix1D betad = this.factory1D.make(contCount, 1.0);
        //continuous-discrete
        DoubleMatrix2D theta = this.factory2D.make(this.lsum, contCount);
        //discrete-discrete
        DoubleMatrix2D phi = this.factory2D.make(this.lsum, this.lsum);
        //cont linear node pot
        DoubleMatrix1D alpha1 = this.factory1D.make(contCount);
        //disc node pot
        DoubleMatrix1D alpha2 = this.factory1D.make(this.lsum);

        this.params = new MGMParams(beta, betad, theta, phi, alpha1, alpha2);

        //separate lambda for each type of edge, [cc, cd, dd]
        //lambda = factory1D.make(3);
    }

    /**
     * Computes log(sum(exp(x))); the maximum is subtracted before exponentiating to avoid underflow.
     *
     * @param x the input vector; it is not modified
     * @return log(sum(exp(x)))
     */
    private double logsumexp(DoubleMatrix1D x) {
        DoubleMatrix1D shifted = x.copy();
        double peak = StatUtils.max(shifted.toArray());
        shifted.assign(Functions.minus(peak));
        shifted.assign(Functions.exp);
        return FastMath.log(shifted.zSum()) + peak;
    }

    /**
     * Calculates the parameter weights as in Lee and Hastie: the standard deviation for each continuous variable
     * and sqrt(sum_k p_k * (1 - p_k)) over the level frequencies p_k for each discrete variable.
     */
    private void calcWeights() {
        this.weights = this.factory1D.make(this.p + this.q);

        for (int c = 0; c < this.p; c++) {
            double[] column = this.xDat.viewColumn(c).toArray();
            this.weights.set(c, StatUtils.sd(column));
        }

        for (int d = 0; d < this.q; d++) {
            double spread = 0;
            // Levels are one-indexed in yDat.
            for (int level = 1; level <= this.l[d]; level++) {
                double matches = this.yDat.viewColumn(d).copy().assign(Functions.equals(level)).zSum();
                double freq = matches / (double) this.n;
                spread += freq * (1 - freq);
            }
            this.weights.set(this.p + d, FastMath.sqrt(spread));
        }
    }
/** * Convert discrete data (in yDat) to a matrix of dummy variables (stored in dDat) */ private void makeDummy() { this.dDat = this.factory2D.make(this.n, this.lsum); for (int i = 0; i < this.q; i++) { for (int j = 0; j < this.l[i]; j++) { DoubleMatrix1D curCol = this.yDat.viewColumn(i).copy().assign(Functions.equals(j + 1)); if (curCol.zSum() == 0) throw new IllegalArgumentException("Discrete data is missing a level: variable " + i + " level " + j); this.dDat.viewColumn(this.lcumsum[i] + j).assign(curCol); } } } /** * checks if yDat is zero indexed and converts to 1 index. zscores x */ private void fixData() { double ymin = StatUtils.min(Mgm.flatten(this.yDat).toArray()); if (ymin < 0 || ymin > 1) throw new IllegalArgumentException("Discrete data must be either zero or one indexed. Found min index: " + ymin); if (ymin == 0) { this.yDat.assign(Functions.plus(1.0)); } //z-score columns of X for (int i = 0; i < this.p; i++) { this.xDat.viewColumn(i).assign(StatUtils.standardizeData(this.xDat.viewColumn(i).toArray())); } } /** * Calculate the smooth value of the given input vector. * * @param parIn The input vector. * @return The smooth value. 
*/ public double smoothValue(DoubleMatrix1D parIn) { //work with copy MGMParams par = new MGMParams(parIn, this.p, this.lsum); for (int i = 0; i < par.betad.size(); i++) { if (par.betad.get(i) < 0) return Double.POSITIVE_INFINITY; } //double nll = 0; //int n = xDat.rows(); //beta=beta+beta'; //phi=phi+phi'; Mgm.upperTri(par.beta, 1); par.beta.assign(this.alg.transpose(par.beta), Functions.plus); for (int i = 0; i < this.q; i++) { par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0); } // ensure mats are upper triangular Mgm.upperTri(par.phi, 0); par.phi.assign(this.alg.transpose(par.phi), Functions.plus); //Xbeta=X*beta*diag(1./betad); DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div)); DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD)); //Dtheta=D*theta*diag(1./betad); DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD); // Squared loss //sqloss=-n/2*sum(log(betad))+... 
//.5*norm((X-e*alpha1'-Xbeta-Dtheta)*diag(sqrt(betad)),'fro')^2; DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns()); //wxprod=X*(theta')+D*phi+e*alpha2'; DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta)); wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus); for (int i = 0; i < this.n; i++) { for (int j = 0; j < this.xDat.columns(); j++) { tempLoss.set(i, j, this.xDat.get(i, j) - par.alpha1.get(j) - xBeta.get(i, j) - dTheta.get(i, j)); } for (int j = 0; j < this.dDat.columns(); j++) { wxProd.set(i, j, wxProd.get(i, j) + par.alpha2.get(j)); } } double sqloss = -this.n / 2.0 * par.betad.copy().assign(Functions.log).zSum() + .5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2); // categorical loss /*catloss=0; wxprod=X*(theta')+D*phi+e*alpha2'; %this is n by Ltot for r=1:q wxtemp=wxprod(:,Lsum(r)+1:Lsum(r)+L(r)); denom= logsumexp(wxtemp,2); %this is n by 1 catloss=catloss-sum(wxtemp(sub2ind([n L(r)],(1:n)',Y(:,r)))); catloss=catloss+sum(denom); end */ double catloss = 0; for (int i = 0; i < this.yDat.columns(); i++) { DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i], this.n, this.l[i]); for (int k = 0; k < this.n; k++) { DoubleMatrix1D curRow = wxTemp.viewRow(k); catloss -= curRow.get((int) this.yDat.get(k, i) - 1); catloss += logsumexp(curRow); } } return (sqloss + catloss) / ((double) this.n); } /** * Smooth method calculates the smooth loss and gradient given input parameters. 
* * @param parIn input Vector * @param gradOutVec gradient of g(X) * @return the smooth loss */ public double smooth(DoubleMatrix1D parIn, DoubleMatrix1D gradOutVec) { //work with copy MGMParams par = new MGMParams(parIn, this.p, this.lsum); MGMParams gradOut = new MGMParams(); for (int i = 0; i < par.betad.size(); i++) { if (par.betad.get(i) < 0) return Double.POSITIVE_INFINITY; } //beta=beta-diag(diag(beta)); //for r=1:q // phi(Lsum(r)+1:Lsum(r+1),Lsum(r)+1:Lsum(r+1))=0; //end //beta=triu(beta); phi=triu(phi); //beta=beta+beta'; //phi=phi+phi'; Mgm.upperTri(par.beta, 1); par.beta.assign(this.alg.transpose(par.beta), Functions.plus); for (int i = 0; i < this.q; i++) { par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0); } //ensure matrix is upper triangular Mgm.upperTri(par.phi, 0); par.phi.assign(this.alg.transpose(par.phi), Functions.plus); //Xbeta=X*beta*diag(1./betad); DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div)); DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD)); //Dtheta=D*theta*diag(1./betad); DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD); // Squared loss //tempLoss = (X-e*alpha1'-Xbeta-Dtheta) = -res (in gradient code) DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns()); //wxprod=X*(theta')+D*phi+e*alpha2'; DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta)); wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus); for (int i = 0; i < this.n; i++) { if (Thread.currentThread().isInterrupted()) { break; } for (int j = 0; j < this.xDat.columns(); j++) { tempLoss.set(i, j, this.xDat.get(i, j) - par.alpha1.get(j) - xBeta.get(i, j) - dTheta.get(i, j)); } for (int j = 0; j < this.dDat.columns(); j++) { wxProd.set(i, j, wxProd.get(i, j) + par.alpha2.get(j)); } } //sqloss=-n/2*sum(log(betad))+... 
//.5*norm((X-e*alpha1'-Xbeta-Dtheta)*diag(sqrt(betad)),'fro')^2; double sqloss = -this.n / 2.0 * par.betad.copy().assign(Functions.log).zSum() + .5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2); //ok now tempLoss = res tempLoss.assign(Functions.mult(-1)); //gradbeta=X'*(res); gradOut.beta = this.alg.mult(this.alg.transpose(this.xDat), tempLoss); //gradbeta=gradbeta-diag(diag(gradbeta)); % zero out diag //gradbeta=tril(gradbeta)'+triu(gradbeta); DoubleMatrix2D lowerBeta = this.alg.transpose(Mgm.lowerTri(gradOut.beta.copy(), -1)); Mgm.upperTri(gradOut.beta, 1).assign(lowerBeta, Functions.plus); //gradalpha1=diag(betad)*sum(res,1)'; gradOut.alpha1 = this.alg.mult(this.factory2D.diagonal(par.betad), Mgm.margSum(tempLoss, 1)); //gradtheta=D'*(res); gradOut.theta = this.alg.mult(this.alg.transpose(this.dDat), tempLoss); // categorical loss /*catloss=0; wxprod=X*(theta')+D*phi+e*alpha2'; %this is n by Ltot for r=1:q wxtemp=wxprod(:,Lsum(r)+1:Lsum(r)+L(r)); denom= logsumexp(wxtemp,2); %this is n by 1 catloss=catloss-sum(wxtemp(sub2ind([n L(r)],(1:n)',Y(:,r)))); catloss=catloss+sum(denom); end */ double catloss = 0; for (int i = 0; i < this.yDat.columns(); i++) { if (Thread.currentThread().isInterrupted()) { break; } DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i], this.n, this.l[i]); //need to copy init values for calculating nll DoubleMatrix2D wxTemp0 = wxTemp.copy(); // does this need to be done in log space?? 
wxTemp.assign(Functions.exp); DoubleMatrix1D invDenom = this.factory1D.make(this.n, 1.0).assign(Mgm.margSum(wxTemp, 2), Functions.div); wxTemp.assign(this.alg.mult(this.factory2D.diagonal(invDenom), wxTemp)); for (int k = 0; k < this.n; k++) { if (Thread.currentThread().isInterrupted()) { break; } DoubleMatrix1D curRow = wxTemp.viewRow(k); DoubleMatrix1D curRow0 = wxTemp0.viewRow(k); catloss -= curRow0.get((int) this.yDat.get(k, i) - 1); catloss += logsumexp(curRow0); //wxtemp(sub2ind(size(wxtemp),(1:n)',Y(:,r)))=wxtemp(sub2ind(size(wxtemp),(1:n)',Y(:,r)))-1; curRow.set((int) this.yDat.get(k, i) - 1, curRow.get((int) this.yDat.get(k, i) - 1) - 1); } } //gradalpha2=sum(wxprod,1)'; gradOut.alpha2 = Mgm.margSum(wxProd, 1); //gradw=X'*wxprod; DoubleMatrix2D gradW = this.alg.mult(this.alg.transpose(this.xDat), wxProd); //gradtheta=gradtheta+gradw'; gradOut.theta.assign(this.alg.transpose(gradW), Functions.plus); //gradphi=D'*wxprod; gradOut.phi = this.alg.mult(this.alg.transpose(this.dDat), wxProd); //zero out gradphi diagonal //for r=1:q //gradphi(Lsum(r)+1:Lsum(r+1),Lsum(r)+1:Lsum(r+1))=0; //end for (int i = 0; i < this.q; i++) { gradOut.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0); } //gradphi=tril(gradphi)'+triu(gradphi); DoubleMatrix2D lowerPhi = this.alg.transpose(Mgm.lowerTri(gradOut.phi.copy(), 0)); Mgm.upperTri(gradOut.phi, 0).assign(lowerPhi, Functions.plus); /* for s=1:p gradbetad(s)=-n/(2*betad(s))+1/2*norm(res(:,s))^2-res(:,s)'*(Xbeta(:,s)+Dtheta(:,s)); end */ gradOut.betad = this.factory1D.make(this.xDat.columns()); for (int i = 0; i < this.p; i++) { gradOut.betad.set(i, -this.n / (2.0 * par.betad.get(i)) + this.alg.norm2(tempLoss.viewColumn(i)) / 2.0 - this.alg.mult(tempLoss.viewColumn(i), xBeta.viewColumn(i).copy().assign(dTheta.viewColumn(i), Functions.plus))); } gradOut.alpha1.assign(Functions.div(this.n)); gradOut.alpha2.assign(Functions.div(this.n)); gradOut.betad.assign(Functions.div(this.n)); 
gradOut.beta.assign(Functions.div(this.n)); gradOut.theta.assign(Functions.div(this.n)); gradOut.phi.assign(Functions.div(this.n)); gradOutVec.assign(gradOut.toMatrix1D()); return (sqloss + catloss) / ((double) this.n); } /** * Calculates the non-smooth value for the given input vector. * * @param parIn the input vector * @return the non-smooth value */ public double nonSmoothValue(DoubleMatrix1D parIn) { //DoubleMatrix1D tlam = lambda.copy().assign(Functions.mult(t)); //Dimension checked in constructor //par is a copy so we can update it MGMParams par = new MGMParams(parIn, this.p, this.lsum); //penbeta = t(1).*(wv(1:p)'*wv(1:p)); //betascale=zeros(size(beta)); //betascale=max(0,1-penbeta./abs(beta)); DoubleMatrix2D weightMat = this.alg.multOuter(this.weights, this.weights, null); //int p = xDat.columns(); //weight beta //betaw = (wv(1:p)'*wv(1:p)).*abs(beta); //betanorms=sum(betaw(:)); DoubleMatrix2D betaWeight = weightMat.viewPart(0, 0, this.p, this.p); DoubleMatrix2D absBeta = par.beta.copy().assign(Functions.abs); double betaNorms = absBeta.assign(betaWeight, Functions.mult).zSum(); /* thetanorms=0; for s=1:p for j=1:q tempvec=theta(Lsums(j)+1:Lsums(j+1),s); thetanorms=thetanorms+(wv(s)*wv(p+j))*norm(tempvec); end end */ double thetaNorms = 0; for (int i = 0; i < this.p; i++) { if (Thread.currentThread().isInterrupted()) { break; } for (int j = 0; j < this.lcumsum.length - 1; j++) { if (Thread.currentThread().isInterrupted()) { break; } DoubleMatrix1D tempVec = par.theta.viewColumn(i).viewPart(this.lcumsum[j], this.l[j]); thetaNorms += weightMat.get(i, this.p + j) * FastMath.sqrt(this.alg.norm2(tempVec)); } } /* for r=1:q for j=1:q if r




© 2015 - 2024 Weber Informatics LLC | Privacy Policy