///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.pitt.csb.mgm;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.StatUtils;
import org.apache.commons.math3.util.FastMath;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
/**
* Implementation of Lee and Hastie's (2012) pseudolikelihood method for learning Mixed Gaussian-Categorical
* Graphical Models. Created by ajsedgewick on 7/15/15.
*
* @author josephramsey
* @version $Id: $Id
*/
public class Mgm extends ConvexProximal implements IGraphSearch {
/**
* factory2D.
*/
private final DoubleFactory2D factory2D = DoubleFactory2D.dense;
/**
* factory1D.
*/
private final DoubleFactory1D factory1D = DoubleFactory1D.dense;
/**
* Continuous Data
*/
private final DoubleMatrix2D xDat;
/**
* Discrete Data coded as integers, no IntMatrix2D apparently...
*/
private final DoubleMatrix2D yDat;
/**
 * Penalty parameters, one each for cc, cd, and dd edges.
 */
private final DoubleMatrix1D lambda;
/**
* alg.
*/
private final Algebra alg = new Algebra();
/**
* Levels of Discrete variables
*/
private final int[] l;
/**
 * Number of continuous variables.
 */
int p;
/**
 * Number of discrete variables.
 */
int q;
/**
 * Number of samples.
 */
int n;
/**
* variables.
*/
private List<Node> variables;
/**
* initVariables.
*/
private List<Node> initVariables;
/**
* Discrete Data coded as dummy variables
*/
private DoubleMatrix2D dDat;
/**
* private long elapsedTime.
*/
private long elapsedTime;
/**
 * Total number of discrete levels (sum of l).
 */
private int lsum;
/**
 * Cumulative sums of l; offsets of each discrete variable's block in the dummy-coded data.
 */
private int[] lcumsum;
/**
* parameter weights
*/
private DoubleMatrix1D weights;
/**
* private MGMParams params.
*/
private MGMParams params;
/**
* Constructor for Mgm.
*
* @param x a {@link cern.colt.matrix.DoubleMatrix2D} object
* @param y a {@link cern.colt.matrix.DoubleMatrix2D} object
* @param variables a {@link java.util.List} object
* @param l an array of {@link int} objects
* @param lambda an array of {@link double} objects
*/
public Mgm(DoubleMatrix2D x, DoubleMatrix2D y, List<Node> variables, int[] l, double[] lambda) {
if (l.length != y.columns())
throw new IllegalArgumentException("length of l doesn't match number of variables in Y");
if (y.rows() != x.rows())
throw new IllegalArgumentException("different number of samples for x and y");
//lambda should have 3 values corresponding to cc, cd, and dd
if (lambda.length != 3)
throw new IllegalArgumentException("Lambda should have three values for cc, cd, and dd edges respectively");
this.xDat = x;
this.yDat = y;
this.l = l;
this.p = x.columns();
this.q = y.columns();
this.n = x.rows();
this.variables = variables;
this.lambda = this.factory1D.make(lambda);
fixData();
initParameters();
calcWeights();
makeDummy();
}
/**
* Constructor for Mgm.
*
* @param ds a {@link edu.cmu.tetrad.data.DataSet} object
* @param lambda an array of {@link double} objects
*/
public Mgm(DataSet ds, double[] lambda) {
this.variables = ds.getVariables();
// Notify the user that you need at least one continuous and one discrete variable to run MGM
boolean hasContinuous = false;
boolean hasDiscrete = false;
for (Node node : this.variables) {
if (node instanceof ContinuousVariable) {
hasContinuous = true;
}
if (node instanceof DiscreteVariable) {
hasDiscrete = true;
}
}
if (!hasContinuous || !hasDiscrete) {
throw new IllegalArgumentException("Please give data with at least one discrete and one continuous variable to run MGM.");
}
DataSet dsCont = MixedUtils.getContinousData(ds);
DataSet dsDisc = MixedUtils.getDiscreteData(ds);
this.xDat = this.factory2D.make(dsCont.getDoubleData().toArray());
this.yDat = this.factory2D.make(dsDisc.getDoubleData().toArray());
this.l = MixedUtils.getDiscLevels(ds);
this.p = this.xDat.columns();
this.q = this.yDat.columns();
this.n = this.xDat.rows();
//the variables are now ordered continuous first then discrete
this.variables = new ArrayList<>();
this.variables.addAll(dsCont.getVariables());
this.variables.addAll(dsDisc.getVariables());
this.initVariables = ds.getVariables();
this.lambda = this.factory1D.make(lambda);
//Data is checked for 0 or 1 indexing and for missing levels
fixData();
initParameters();
calcWeights();
makeDummy();
}
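// Typical usage (see runTests2 below): construct with a mixed DataSet and three penalties {cc, cd, dd},
// then call learn(tolerance, maxIterations) or learnEdges(maxIterations).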
//create column major vector from matrix (i.e. concatenate columns)
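// e.g. flatten([[1, 3], [2, 4]]) = [1, 2, 3, 4]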
/**
* flatten.
*
* @param m a {@link cern.colt.matrix.DoubleMatrix2D} object
* @return a {@link cern.colt.matrix.DoubleMatrix1D} object
*/
public static DoubleMatrix1D flatten(DoubleMatrix2D m) {
DoubleMatrix1D[] colArray = new DoubleMatrix1D[m.columns()];
for (int i = 0; i < m.columns(); i++) {
colArray[i] = m.viewColumn(i);
}
return DoubleFactory1D.dense.make(colArray);
}
/*
* PRIVATE UTILS
*/
//Utils
//sum rows together if marg == 1 and cols together if marg == 2
//Using row-major speeds up marg=1 5x
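// marg == 1 returns column sums (vector of length mat.columns()); marg == 2 returns row sums (length mat.rows())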
private static DoubleMatrix1D margSum(DoubleMatrix2D mat, int marg) {
int n = 0;
DoubleMatrix1D vec = null;
DoubleFactory1D fac = DoubleFactory1D.dense;
if (marg == 1) {
n = mat.columns();
vec = fac.make(n);
for (int j = 0; j < mat.rows(); j++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
for (int i = 0; i < n; i++) {
vec.setQuick(i, vec.getQuick(i) + mat.getQuick(j, i));
}
}
} else if (marg == 2) {
n = mat.rows();
vec = fac.make(n);
for (int i = 0; i < n; i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
vec.setQuick(i, mat.viewRow(i).zSum());
}
}
return vec;
}
//zeros out everything below the di-th diagonal
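// e.g. upperTri(mat, 0) keeps the main diagonal and above; upperTri(mat, 1) keeps only entries strictly above the diagonal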
/**
* upperTri.
*
* @param mat a {@link cern.colt.matrix.DoubleMatrix2D} object
* @param di a int
* @return a {@link cern.colt.matrix.DoubleMatrix2D} object
*/
public static DoubleMatrix2D upperTri(DoubleMatrix2D mat, int di) {
for (int i = FastMath.max(-di + 1, 0); i < mat.rows(); i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
for (int j = 0; j < FastMath.min(i + di, mat.rows()); j++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
mat.set(i, j, 0);
}
}
return mat;
}
//zeros out everything above the di-th diagonal
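// e.g. lowerTri(mat, 0) keeps the main diagonal and below; lowerTri(mat, -1) keeps only entries strictly below the diagonal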
private static DoubleMatrix2D lowerTri(DoubleMatrix2D mat, int di) {
for (int i = 0; i < mat.rows() - FastMath.max(di + 1, 0); i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
for (int j = FastMath.max(i + di + 1, 0); j < mat.rows(); j++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
mat.set(i, j, 0);
}
}
return mat;
}
// should move somewhere else...
private static double norm2(DoubleMatrix2D mat) {
//return FastMath.sqrt(mat.copy().assign(Functions.pow(2)).zSum());
Algebra al = new Algebra();
//norm found by svd so we need rows >= cols
if (mat.rows() < mat.columns()) {
return al.norm2(al.transpose(mat));
}
return al.norm2(mat);
}
private static double norm2(DoubleMatrix1D vec) {
//return FastMath.sqrt(vec.copy().assign(Functions.pow(2)).zSum());
return FastMath.sqrt(new Algebra().norm2(vec));
}
private static void runTests1() {
try {
final String path = "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data";
System.out.println(path);
DoubleMatrix2D xIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
DoubleMatrix2D yIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());
int[] L = new int[24];
Node[] vars = new Node[48];
for (int i = 0; i < 24; i++) {
L[i] = 2;
vars[i] = new ContinuousVariable("X" + i);
vars[i + 24] = new DiscreteVariable("Y" + i);
}
final double lam = .2;
Mgm model = new Mgm(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L, new double[]{lam, lam, lam});
Mgm model2 = new Mgm(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), L, new double[]{lam, lam, lam});
System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));
DoubleMatrix2D test = xIn.copy();
DoubleMatrix2D test2 = xIn.copy();
long t = MillisecondTimes.timeMillis();
for (int i = 0; i < 50000; i++) {
test2 = xIn.copy();
test.assign(test2);
}
System.out.println("assign Time: " + (MillisecondTimes.timeMillis() - t));
t = MillisecondTimes.timeMillis();
double[][] xArr = xIn.toArray();
for (int i = 0; i < 50000; i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
//test = DoubleFactory2D.dense.make(xArr);
test2 = xIn.copy();
test = test2;
}
System.out.println("equals Time: " + (MillisecondTimes.timeMillis() - t));
System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
t = MillisecondTimes.timeMillis();
model.learnEdges(700);
//model.learn(1e-7, 700);
System.out.println("Orig Time: " + (MillisecondTimes.timeMillis() - t));
System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
System.out.println("params:\n" + model.params);
System.out.println("adjMat:\n" + model.adjMatFromMGM());
} catch (IOException ex) {
ex.printStackTrace();
}
}
/**
 * Tests the no-penalty use case (lambda = 0).
 */
private static void runTests2() {
Graph g = GraphUtils.convert("X1-->X2,X3-->X2,X4-->X5");
//simple graph pm im gen example
HashMap<String, Integer> nd = new HashMap<>();
nd.put("X1", 0);
nd.put("X2", 0);
nd.put("X3", 4);
nd.put("X4", 4);
nd.put("X5", 4);
g = MixedUtils.makeMixedGraph(g, nd);
GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)");
System.out.println(pm);
GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
System.out.println(im);
final int samps = 1000;
DataSet ds = im.simulateDataFisher(samps);
ds = MixedUtils.makeMixedData(ds, nd);
//System.out.println(ds);
final double lambda = 0;
Mgm model = new Mgm(ds, new double[]{lambda, lambda, lambda});
System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
model.learn(1e-8, 1000);
System.out.println("Learned nll: " + model.smoothValue(model.params.toMatrix1D()));
System.out.println("Learned reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
System.out.println("params:\n" + model.params);
System.out.println("adjMat:\n" + model.adjMatFromMGM());
}
/**
* main.
*
* @param args an array of {@link java.lang.String} objects
*/
public static void main(String[] args) {
Mgm.runTests1();
}
/**
 * Setter for the field params.
 *
 * @param newParams a {@link edu.pitt.csb.mgm.Mgm.MGMParams} object
 */
public void setParams(MGMParams newParams) {
this.params = newParams;
}
//init all parameters to zeros except for betad which is set to 1s
private void initParameters() {
this.lcumsum = new int[this.l.length + 1];
this.lcumsum[0] = 0;
for (int i = 0; i < this.l.length; i++) {
this.lcumsum[i + 1] = this.lcumsum[i] + this.l[i];
}
this.lsum = this.lcumsum[this.l.length];
//LH init to zeros, maybe should be random init?
DoubleMatrix2D beta = this.factory2D.make(this.xDat.columns(), this.xDat.columns()); //continuous-continuous
DoubleMatrix1D betad = this.factory1D.make(this.xDat.columns(), 1.0); //cont squared node pot
DoubleMatrix2D theta = this.factory2D.make(this.lsum, this.xDat.columns());
//continuous-discrete
DoubleMatrix2D phi = this.factory2D.make(this.lsum, this.lsum); //discrete-discrete
DoubleMatrix1D alpha1 = this.factory1D.make(this.xDat.columns()); //cont linear node pot
DoubleMatrix1D alpha2 = this.factory1D.make(this.lsum); //disc node pot
this.params = new MGMParams(beta, betad, theta, phi, alpha1, alpha2);
//separate lambda for each type of edge, [cc, cd, dd]
//lambda = factory1D.make(3);
}
// avoid underflow in log(sum(exp(x))) calculation
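// uses the identity log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x)))) so the exponentials cannot overflow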
private double logsumexp(DoubleMatrix1D x) {
DoubleMatrix1D myX = x.copy();
double maxX = StatUtils.max(myX.toArray());
return FastMath.log(myX.assign(Functions.minus(maxX)).assign(Functions.exp).zSum()) + maxX;
}
//calculate parameter weights as in Lee and Hastie
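// continuous variables are weighted by their sample standard deviation; discrete variables by sqrt(sum_k p_k*(1 - p_k)) over their levels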
private void calcWeights() {
this.weights = this.factory1D.make(this.p + this.q);
for (int i = 0; i < this.p; i++) {
this.weights.set(i, StatUtils.sd(this.xDat.viewColumn(i).toArray()));
}
for (int j = 0; j < this.q; j++) {
double curWeight = 0;
for (int k = 0; k < this.l[j]; k++) {
double curp = this.yDat.viewColumn(j).copy().assign(Functions.equals(k + 1)).zSum() / (double) this.n;
curWeight += curp * (1 - curp);
}
this.weights.set(this.p + j, FastMath.sqrt(curWeight));
}
}
/**
* Convert discrete data (in yDat) to a matrix of dummy variables (stored in dDat)
*/
private void makeDummy() {
this.dDat = this.factory2D.make(this.n, this.lsum);
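// column lcumsum[i] + j of dDat is the 0/1 indicator for level j+1 of discrete variable i,
// e.g. a three-level variable with value 2 contributes the block [0, 1, 0]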
for (int i = 0; i < this.q; i++) {
for (int j = 0; j < this.l[i]; j++) {
DoubleMatrix1D curCol = this.yDat.viewColumn(i).copy().assign(Functions.equals(j + 1));
if (curCol.zSum() == 0)
throw new IllegalArgumentException("Discrete data is missing a level: variable " + i + " level " + j);
this.dDat.viewColumn(this.lcumsum[i] + j).assign(curCol);
}
}
}
/**
 * Checks whether yDat is zero-indexed and, if so, converts it to one-indexed. Z-scores the columns of xDat.
 */
private void fixData() {
double ymin = StatUtils.min(Mgm.flatten(this.yDat).toArray());
if (ymin < 0 || ymin > 1)
throw new IllegalArgumentException("Discrete data must be either zero or one indexed. Found min index: " + ymin);
if (ymin == 0) {
this.yDat.assign(Functions.plus(1.0));
}
//z-score columns of X
for (int i = 0; i < this.p; i++) {
this.xDat.viewColumn(i).assign(StatUtils.standardizeData(this.xDat.viewColumn(i).toArray()));
}
}
/**
* Calculate the smooth value of the given input vector.
*
* @param parIn The input vector.
* @return The smooth value.
*/
public double smoothValue(DoubleMatrix1D parIn) {
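// negative log pseudolikelihood, averaged over the n samples: a Gaussian squared loss for the
// continuous columns plus a multinomial (softmax) loss for each discrete variable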
//work with copy
MGMParams par = new MGMParams(parIn, this.p, this.lsum);
for (int i = 0; i < par.betad.size(); i++) {
if (par.betad.get(i) < 0)
return Double.POSITIVE_INFINITY;
}
//double nll = 0;
//int n = xDat.rows();
//beta=beta+beta';
//phi=phi+phi';
Mgm.upperTri(par.beta, 1);
par.beta.assign(this.alg.transpose(par.beta), Functions.plus);
for (int i = 0; i < this.q; i++) {
par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0);
}
// ensure mats are upper triangular
Mgm.upperTri(par.phi, 0);
par.phi.assign(this.alg.transpose(par.phi), Functions.plus);
//Xbeta=X*beta*diag(1./betad);
DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div));
DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD));
//Dtheta=D*theta*diag(1./betad);
DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD);
// Squared loss
//sqloss=-n/2*sum(log(betad))+...
//.5*norm((X-e*alpha1'-Xbeta-Dtheta)*diag(sqrt(betad)),'fro')^2;
DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns());
//wxprod=X*(theta')+D*phi+e*alpha2';
DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta));
wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus);
for (int i = 0; i < this.n; i++) {
for (int j = 0; j < this.xDat.columns(); j++) {
tempLoss.set(i, j, this.xDat.get(i, j) - par.alpha1.get(j) - xBeta.get(i, j) - dTheta.get(i, j));
}
for (int j = 0; j < this.dDat.columns(); j++) {
wxProd.set(i, j, wxProd.get(i, j) + par.alpha2.get(j));
}
}
double sqloss = -this.n / 2.0 * par.betad.copy().assign(Functions.log).zSum() +
.5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2);
// categorical loss
/*catloss=0;
wxprod=X*(theta')+D*phi+e*alpha2'; %this is n by Ltot
for r=1:q
wxtemp=wxprod(:,Lsum(r)+1:Lsum(r)+L(r));
denom= logsumexp(wxtemp,2); %this is n by 1
catloss=catloss-sum(wxtemp(sub2ind([n L(r)],(1:n)',Y(:,r))));
catloss=catloss+sum(denom);
end
*/
double catloss = 0;
for (int i = 0; i < this.yDat.columns(); i++) {
DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i], this.n, this.l[i]);
for (int k = 0; k < this.n; k++) {
DoubleMatrix1D curRow = wxTemp.viewRow(k);
catloss -= curRow.get((int) this.yDat.get(k, i) - 1);
catloss += logsumexp(curRow);
}
}
return (sqloss + catloss) / ((double) this.n);
}
/**
* Smooth method calculates the smooth loss and gradient given input parameters.
*
* @param parIn input Vector
* @param gradOutVec gradient of g(X)
* @return the smooth loss
*/
public double smooth(DoubleMatrix1D parIn, DoubleMatrix1D gradOutVec) {
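// computes the same smooth loss as smoothValue and additionally writes the gradient with respect to
// every parameter block (beta, betad, theta, phi, alpha1, alpha2) into gradOutVec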
//work with copy
MGMParams par = new MGMParams(parIn, this.p, this.lsum);
MGMParams gradOut = new MGMParams();
for (int i = 0; i < par.betad.size(); i++) {
if (par.betad.get(i) < 0)
return Double.POSITIVE_INFINITY;
}
//beta=beta-diag(diag(beta));
//for r=1:q
// phi(Lsum(r)+1:Lsum(r+1),Lsum(r)+1:Lsum(r+1))=0;
//end
//beta=triu(beta); phi=triu(phi);
//beta=beta+beta';
//phi=phi+phi';
Mgm.upperTri(par.beta, 1);
par.beta.assign(this.alg.transpose(par.beta), Functions.plus);
for (int i = 0; i < this.q; i++) {
par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0);
}
//ensure matrix is upper triangular
Mgm.upperTri(par.phi, 0);
par.phi.assign(this.alg.transpose(par.phi), Functions.plus);
//Xbeta=X*beta*diag(1./betad);
DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div));
DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD));
//Dtheta=D*theta*diag(1./betad);
DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD);
// Squared loss
//tempLoss = (X-e*alpha1'-Xbeta-Dtheta) = -res (in gradient code)
DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns());
//wxprod=X*(theta')+D*phi+e*alpha2';
DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta));
wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus);
for (int i = 0; i < this.n; i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
for (int j = 0; j < this.xDat.columns(); j++) {
tempLoss.set(i, j, this.xDat.get(i, j) - par.alpha1.get(j) - xBeta.get(i, j) - dTheta.get(i, j));
}
for (int j = 0; j < this.dDat.columns(); j++) {
wxProd.set(i, j, wxProd.get(i, j) + par.alpha2.get(j));
}
}
//sqloss=-n/2*sum(log(betad))+...
//.5*norm((X-e*alpha1'-Xbeta-Dtheta)*diag(sqrt(betad)),'fro')^2;
double sqloss = -this.n / 2.0 * par.betad.copy().assign(Functions.log).zSum() +
.5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2);
//ok now tempLoss = res
tempLoss.assign(Functions.mult(-1));
//gradbeta=X'*(res);
gradOut.beta = this.alg.mult(this.alg.transpose(this.xDat), tempLoss);
//gradbeta=gradbeta-diag(diag(gradbeta)); % zero out diag
//gradbeta=tril(gradbeta)'+triu(gradbeta);
DoubleMatrix2D lowerBeta = this.alg.transpose(Mgm.lowerTri(gradOut.beta.copy(), -1));
Mgm.upperTri(gradOut.beta, 1).assign(lowerBeta, Functions.plus);
//gradalpha1=diag(betad)*sum(res,1)';
gradOut.alpha1 = this.alg.mult(this.factory2D.diagonal(par.betad), Mgm.margSum(tempLoss, 1));
//gradtheta=D'*(res);
gradOut.theta = this.alg.mult(this.alg.transpose(this.dDat), tempLoss);
// categorical loss
/*catloss=0;
wxprod=X*(theta')+D*phi+e*alpha2'; %this is n by Ltot
for r=1:q
wxtemp=wxprod(:,Lsum(r)+1:Lsum(r)+L(r));
denom= logsumexp(wxtemp,2); %this is n by 1
catloss=catloss-sum(wxtemp(sub2ind([n L(r)],(1:n)',Y(:,r))));
catloss=catloss+sum(denom);
end
*/
double catloss = 0;
for (int i = 0; i < this.yDat.columns(); i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i], this.n, this.l[i]);
//need to copy init values for calculating nll
DoubleMatrix2D wxTemp0 = wxTemp.copy();
// does this need to be done in log space??
wxTemp.assign(Functions.exp);
DoubleMatrix1D invDenom = this.factory1D.make(this.n, 1.0).assign(Mgm.margSum(wxTemp, 2), Functions.div);
wxTemp.assign(this.alg.mult(this.factory2D.diagonal(invDenom), wxTemp));
for (int k = 0; k < this.n; k++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
DoubleMatrix1D curRow = wxTemp.viewRow(k);
DoubleMatrix1D curRow0 = wxTemp0.viewRow(k);
catloss -= curRow0.get((int) this.yDat.get(k, i) - 1);
catloss += logsumexp(curRow0);
//wxtemp(sub2ind(size(wxtemp),(1:n)',Y(:,r)))=wxtemp(sub2ind(size(wxtemp),(1:n)',Y(:,r)))-1;
curRow.set((int) this.yDat.get(k, i) - 1, curRow.get((int) this.yDat.get(k, i) - 1) - 1);
}
}
//gradalpha2=sum(wxprod,1)';
gradOut.alpha2 = Mgm.margSum(wxProd, 1);
//gradw=X'*wxprod;
DoubleMatrix2D gradW = this.alg.mult(this.alg.transpose(this.xDat), wxProd);
//gradtheta=gradtheta+gradw';
gradOut.theta.assign(this.alg.transpose(gradW), Functions.plus);
//gradphi=D'*wxprod;
gradOut.phi = this.alg.mult(this.alg.transpose(this.dDat), wxProd);
//zero out gradphi diagonal
//for r=1:q
//gradphi(Lsum(r)+1:Lsum(r+1),Lsum(r)+1:Lsum(r+1))=0;
//end
for (int i = 0; i < this.q; i++) {
gradOut.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0);
}
//gradphi=tril(gradphi)'+triu(gradphi);
DoubleMatrix2D lowerPhi = this.alg.transpose(Mgm.lowerTri(gradOut.phi.copy(), 0));
Mgm.upperTri(gradOut.phi, 0).assign(lowerPhi, Functions.plus);
/*
for s=1:p
gradbetad(s)=-n/(2*betad(s))+1/2*norm(res(:,s))^2-res(:,s)'*(Xbeta(:,s)+Dtheta(:,s));
end
*/
gradOut.betad = this.factory1D.make(this.xDat.columns());
for (int i = 0; i < this.p; i++) {
gradOut.betad.set(i, -this.n / (2.0 * par.betad.get(i)) + this.alg.norm2(tempLoss.viewColumn(i)) / 2.0 -
this.alg.mult(tempLoss.viewColumn(i), xBeta.viewColumn(i).copy().assign(dTheta.viewColumn(i), Functions.plus)));
}
gradOut.alpha1.assign(Functions.div(this.n));
gradOut.alpha2.assign(Functions.div(this.n));
gradOut.betad.assign(Functions.div(this.n));
gradOut.beta.assign(Functions.div(this.n));
gradOut.theta.assign(Functions.div(this.n));
gradOut.phi.assign(Functions.div(this.n));
gradOutVec.assign(gradOut.toMatrix1D());
return (sqloss + catloss) / ((double) this.n);
}
/**
* Calculates the non-smooth value for the given input vector.
*
* @param parIn the input vector
* @return the non-smooth value
*/
public double nonSmoothValue(DoubleMatrix1D parIn) {
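// weighted regularization term: each |beta| entry (cc edge) is scaled by the product of its two variables'
// weights, and each theta column block (cd edge) contributes its Euclidean norm scaled by the matching
// continuous/discrete weight pair; the three entries of lambda penalize cc, cd, and dd edges respectively
// (see the constructor)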
//DoubleMatrix1D tlam = lambda.copy().assign(Functions.mult(t));
//Dimension checked in constructor
//par is a copy so we can update it
MGMParams par = new MGMParams(parIn, this.p, this.lsum);
//penbeta = t(1).*(wv(1:p)'*wv(1:p));
//betascale=zeros(size(beta));
//betascale=max(0,1-penbeta./abs(beta));
DoubleMatrix2D weightMat = this.alg.multOuter(this.weights,
this.weights, null);
//int p = xDat.columns();
//weight beta
//betaw = (wv(1:p)'*wv(1:p)).*abs(beta);
//betanorms=sum(betaw(:));
DoubleMatrix2D betaWeight = weightMat.viewPart(0, 0, this.p, this.p);
DoubleMatrix2D absBeta = par.beta.copy().assign(Functions.abs);
double betaNorms = absBeta.assign(betaWeight, Functions.mult).zSum();
/*
thetanorms=0;
for s=1:p
for j=1:q
tempvec=theta(Lsums(j)+1:Lsums(j+1),s);
thetanorms=thetanorms+(wv(s)*wv(p+j))*norm(tempvec);
end
end
*/
double thetaNorms = 0;
for (int i = 0; i < this.p; i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
for (int j = 0; j < this.lcumsum.length - 1; j++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
DoubleMatrix1D tempVec = par.theta.viewColumn(i).viewPart(this.lcumsum[j], this.l[j]);
thetaNorms += weightMat.get(i, this.p + j) * FastMath.sqrt(this.alg.norm2(tempVec));
}
}
/*
for r=1:q
for j=1:q
if r