net.maizegenetics.stats.EMMA.EMMAforDoubleMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tassel Show documentation
Show all versions of tassel Show documentation
TASSEL is a software package to evaluate trait associations, evolutionary patterns, and linkage
disequilibrium.
The newest version!
package net.maizegenetics.stats.EMMA;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrixFactory;
import net.maizegenetics.matrixalgebra.decomposition.EigenvalueDecomposition;
import net.maizegenetics.stats.linearmodels.LinearModelUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
 * Class for implementation of the EMMA algorithm often used in GWAS. Initially developed by
 * Kang et al Genetics March 2008 178:1709-1723.
 *
 * <p>Typical use: construct (which runs {@link #init()}), call {@link #solve()}, then read the
 * results through the getters. BLUPs, predicted values and residuals are only computed on request via
 * {@link #calculateBlupsPredictedResiduals()} or {@link #calculateBlupsPredicted()}.
 */
public class EMMAforDoubleMatrix {
    private static final Logger myLogger = LogManager.getLogger(EMMAforDoubleMatrix.class);

    private static final String RANK_DEFICIENT_MESSAGE =
            "The fixed effect design matrix has less than full column rank. The analysis will not be run.";

    /** Phenotypes as a column; the G-BLUP constructor keeps only the rows whose phenotype is not NaN. */
    protected DoubleMatrix y;
    /** Eigenvalues of SAS (bend removed), ordered by decreasing absolute value; length N - q. Set by init(). */
    protected double[] lambda;
    /** Squared elements of U'y. Set by solve(). */
    protected double[] eta2;
    /** Constant term of the log likelihood, (N-q) * ln((N-q) / (2*pi)) - (N-q). Set by init(). */
    protected double c;
    /** Number of rows of y. */
    protected int N;
    /** Number of columns of X. */
    protected int q;
    /** Number of rows of K (or of Z in the constructor that takes Z). */
    protected int Nran;
    /** nAlleles - 1 for the constructors that take nAlleles, otherwise 0. */
    protected int dfMarker = 0;
    /** Fixed effect design matrix for all individuals; only set by the G-BLUP constructor. */
    protected DoubleMatrix Xoriginal = null;
    /** Fixed effect design matrix for the rows of y. */
    protected DoubleMatrix X;
    /** Identity matrix of size Nran; only set by the G-BLUP constructor. */
    protected DoubleMatrix Zoriginal = null;
    /** Incidence matrix relating the rows of y to the rows of K. */
    protected DoubleMatrix Z = null;
    /** Kinship matrix. */
    protected DoubleMatrix K;
    /** Eigenvalue decomposition of SAS, where A has the bend added to its diagonal. */
    protected EigenvalueDecomposition eig;
    /** Eigenvalue decomposition of A = ZKZ', taken before any bend is added. */
    protected EigenvalueDecomposition eigA;
    /** Eigenvectors of SAS with columns reordered by decreasing absolute eigenvalue. */
    protected DoubleMatrix U;
    /** Inverse of H, computed from eigA as V * diag(1 / (eigenvalue + delta)) * V'. */
    protected DoubleMatrix invH;
    /** Inverse of X' invH X. */
    protected DoubleMatrix invXHX;
    /** Fixed effect estimates. */
    protected DoubleMatrix beta;
    /** Product of the design matrix and beta; see calculateBLUP() and calculatePred(). */
    protected DoubleMatrix Xbeta;
    // The next four fields are never assigned in this class.
    protected double ssModel;
    protected double ssError;
    protected double SST;
    protected double Rsq;
    /** Number of columns of the fixed design matrix after construction; q - 1 after solve(). */
    protected int dfModel;
    /** N - q, set by solve(). */
    protected int dfError;
    /** Ratio used in H; solve() reports varResidual = varRandomEffect * delta. */
    protected double delta;
    protected double varResidual;
    protected double varRandomEffect;
    protected DoubleMatrix blup;
    protected DoubleMatrix pred;
    protected DoubleMatrix res;
    /** Prediction error variances; only filled when calculatePEV is true. */
    protected DoubleMatrix pev;
    protected double lnLikelihood;
    /** False when a delta was supplied to the constructor, in which case solve() does not search for it. */
    protected boolean findDelta = true;
    protected boolean calculatePEV = false;
    /** Lower bound of the delta search. */
    protected double lowerlimit = 1e-5;
    /** Upper bound of the delta search. */
    protected double upperlimit = 1e5;
    /** Number of grid points (equally spaced on a log10 scale) used to scan the likelihood. */
    protected int nregions = 100;
    /** Newton-Raphson stops when the absolute first derivative falls below this value. */
    protected double convergence = 1e-10;
    /** Maximum number of Newton-Raphson iterations. */
    protected int maxiter = 50;
    /** Counts nested re-scans started by findMaximum(); guards against unbounded subdivision. */
    protected int subintervalCount = 0;

    /**
     * Same as {@link #EMMAforDoubleMatrix(DoubleMatrix, DoubleMatrix, DoubleMatrix, int, double)} with
     * delta = NaN, so that delta is estimated by solve().
     */
    public EMMAforDoubleMatrix(DoubleMatrix y, DoubleMatrix fixed, DoubleMatrix kin, int nAlleles) {
        this(y, fixed, kin, nAlleles, Double.NaN);
    }

    /**
     * This constructor assumes that Z is the identity matrix for calculating blups, predicted values and residuals.
     * If that is not true use the constructor that explicitly takes Z. This constructor treats A as ZKZ' so it can be
     * used if blups and residuals are not needed.
     * @param data      phenotypes; a single-row matrix is transposed to a column
     * @param fixed     fixed effect design matrix
     * @param kin       kinship matrix
     * @param nAlleles  number of marker alleles; dfMarker is set to nAlleles - 1
     * @param delta     value of delta to use, or NaN to have solve() estimate it
     * @throws IllegalArgumentException if fixed has less than full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin, int nAlleles, double delta) {
        dfModel = fullRankColumnCount(fixed);
        useSuppliedDelta(delta);
        setPhenotypeColumn(data);
        X = fixed;
        q = X.numberOfColumns();
        K = kin;
        Nran = K.numberOfRows();
        Z = DoubleMatrixFactory.DEFAULT.identity(Nran);
        dfMarker = nAlleles - 1;
        init();
    }

    /**
     * This constructor should be used when Z is not the identity matrix. Z is needed to calculate blups and residuals.
     * @param data      phenotypes; a single-row matrix is transposed to a column
     * @param fixed     fixed effect design matrix
     * @param kin       kinship matrix
     * @param inZ       incidence matrix Z
     * @param nAlleles  number of marker alleles; dfMarker is set to nAlleles - 1
     * @param delta     value of delta to use, or NaN to have solve() estimate it
     * @throws IllegalArgumentException if fixed has less than full column rank
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin, DoubleMatrix inZ, int nAlleles, double delta) {
        dfModel = fullRankColumnCount(fixed);
        useSuppliedDelta(delta);
        setPhenotypeColumn(data);
        X = fixed;
        q = X.numberOfColumns();
        Z = inZ;
        K = kin;
        Nran = Z.numberOfRows();
        dfMarker = nAlleles - 1;
        init();
    }

    /**
     * This constructor is designed for G-BLUP.
     * Uses kinship matrix with phenotypes in same order.
     * Phenotypes = NaN for prediction set.
     * Assumes that Z is the identity matrix for calculating blups, predicted values and residuals. If that is not
     * true use the constructor that explicitly takes Z.
     * @param data  a column matrix of phenotypic values; only column 0 is read, and rows whose value is NaN are
     *              left out of the fitted model
     * @param fixed fixed effect design matrix with one row per individual
     * @param kin   kinship matrix
     * @throws IllegalArgumentException if fixed has less than full column rank or data is a single row with
     *                                  more than one column
     */
    public EMMAforDoubleMatrix(DoubleMatrix data, DoubleMatrix fixed, DoubleMatrix kin) {
        dfModel = fullRankColumnCount(fixed);
        if (data.numberOfColumns() > 1 && data.numberOfRows() == 1)
            throw new IllegalArgumentException("The phenotype data must be a column matrix.");

        K = kin;
        Nran = K.numberOfRows();

        // Keep only the rows with an observed phenotype in y, Z and X.
        int[] nonmissingIndex = IntStream.range(0, Nran).filter(i -> !Double.isNaN(data.get(i, 0))).toArray();
        y = data.getSelection(nonmissingIndex, null);
        Zoriginal = DoubleMatrixFactory.DEFAULT.identity(Nran);
        Z = Zoriginal.getSelection(nonmissingIndex, null);
        N = y.numberOfRows();
        Xoriginal = fixed;
        X = fixed.getSelection(nonmissingIndex, null);
        q = X.numberOfColumns();
        init();
    }

    /** Returns the number of columns of fixed, throwing if its column rank is smaller than that. */
    private static int fullRankColumnCount(DoubleMatrix fixed) {
        int ncol = fixed.numberOfColumns();
        int rank = fixed.columnRank();
        if (rank < ncol) throw new IllegalArgumentException(RANK_DEFICIENT_MESSAGE);
        return ncol;
    }

    /** Stores a caller-supplied delta and switches off the search; NaN means "estimate it". */
    private void useSuppliedDelta(double delta) {
        if (!Double.isNaN(delta)) {
            this.delta = delta;
            findDelta = false;
        }
    }

    /** Stores the phenotypes in y (transposing a single-row matrix) and sets N. */
    private void setPhenotypeColumn(DoubleMatrix data) {
        y = data;
        if (y.numberOfColumns() > 1 && y.numberOfRows() == 1) this.y = y.transpose();
        N = y.numberOfRows();
    }

    /**
     * Computes everything that does not depend on y: the likelihood constant c, the eigenvalue
     * decomposition of A = ZKZ' (eigA), and the decomposition of SAS from which U and lambda are taken.
     * Called from every constructor.
     */
    protected void init() {
        int nreml = N - q;
        // 2.0 forces floating point division; nreml / 2 would truncate for odd nreml.
        c = nreml * Math.log(nreml / 2.0 / Math.PI) - nreml;
        lambda = new double[nreml];

        //find the eigenvalues of A
        DoubleMatrix A = Z.mult(K).tcrossproduct(Z);
        eigA = A.getEigenvalueDecomposition();
        double[] eigenvalA = eigA.getEigenvalues();
        double min = eigenvalA[0];
        for (int i = 1; i < eigenvalA.length; i++) min = Math.min(min, eigenvalA[i]);

        // If the smallest eigenvalue of A is below 0.01, shift the diagonal so that it becomes 0.5.
        double bend = 0.0;
        if (min < 0.01) bend = -1 * min + 0.5;

        //S = I - X inv(X'X) X'
        //X is assumed to be of full column rank, i.e. X'X is non-singular
        DoubleMatrix S = X.getXtXGM()[2];

        //add bend to the diagonal of A
        //this is necessary to get correct decomposition of SAS
        int n = A.numberOfRows();
        for (int i = 0; i < n; i++) A.set(i, i, A.get(i, i) + bend);
        DoubleMatrix SAS = S.mult(A.mult(S));

        //decompose SAS
        eig = SAS.getEigenvalueDecomposition();

        // Ordering by decreasing absolute value puts the (near) zero eigenvalues last,
        // so the first nreml entries are the ones that are kept.
        double[] eigenval = eig.getEigenvalues();
        int[] ndx = getSortedIndexofAbsoluteValues(eigenval);

        //sort V to get U
        U = eig.getEigenvectors().getSelection(null, ndx);

        //derive lambda, removing the bend that was added to A
        for (int i = 0; i < nreml; i++) lambda[i] = eigenval[ndx[i]] - bend;
    }

    /**
     * Returns the indices of values ordered by decreasing absolute value. Ties keep their original
     * relative order. NaN is ordered as larger than any other value.
     */
    private static int[] getSortedIndexofAbsoluteValues(double[] values) {
        int n = values.length;
        double[] magnitude = new double[n];
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            magnitude[i] = Math.abs(values[i]);
            order[i] = i;
        }

        // The object sort is stable, which preserves the input order of equal magnitudes.
        Arrays.sort(order, (a, b) -> Double.compare(magnitude[b], magnitude[a]));

        int[] index = new int[n];
        for (int i = 0; i < n; i++) index[i] = order[i];
        return index;
    }

    /**
     * Fits the model for the current y: estimates delta (unless one was supplied), then computes the log
     * likelihood, invH, beta, the degrees of freedom and the two variance components.
     */
    public void solve() {
        //calculate eta squared
        DoubleMatrix eta = U.crossproduct(y);
        int nrows = eta.numberOfRows();
        eta2 = new double[nrows];
        for (int i = 0; i < nrows; i++) eta2[i] = eta.get(i, 0) * eta.get(i, 0);

        if (findDelta) {
            double[] interval = new double[]{lowerlimit, upperlimit};
            delta = findDeltaInInterval(interval);
        }

        lnLikelihood = lnlk(delta);
        invH = inverseH(delta);
        beta = calculateBeta();
        double genvar = getGenvar(beta);
        dfModel = q - 1;
        dfError = N - q;
        varResidual = genvar * delta;
        varRandomEffect = genvar;
    }

    /** Computes blups, predicted values and residuals. solve() must have been called first. */
    public void calculateBlupsPredictedResiduals() {
        calculateBLUP();
        pred = calculatePred();
        res = calculateRes();
    }

    /** Computes blups and predicted values but not residuals. solve() must have been called first. */
    public void calculateBlupsPredicted() {
        calculateBLUP();
        pred = calculatePred();
    }

    /**
     * Searches interval {lower, upper} for the delta with the highest log likelihood. The likelihood is
     * scanned on a grid; the best grid point is the starting candidate and is replaced by any local
     * maximum, refined by findMaximum(), that has a higher likelihood.
     */
    private double findDeltaInInterval(double[] interval) {
        double[][] d = scanlnlk(interval[0], interval[1]);
        double[][] sgnchange = findSignChanges(d);
        int nchanges = sgnchange.length;
        double[] bestd = new double[]{Double.NaN, Double.NaN, Double.NaN};
        int n = d.length;

        //find the element of d with maximum ln Likelihood (bestd)
        for (int i = 0; i < n; i++) {
            if (Double.isNaN(bestd[1])) bestd = d[i];
            else if (!Double.isNaN(d[i][1]) && d[i][1] > bestd[1]) bestd = d[i];
        }

        double bestdelta = bestd[0];
        double lkDelta = bestd[1];
        for (int i = 0; i < nchanges; i++) {
            double newdelta = findMaximum(sgnchange[i]);
            if (!Double.isNaN(newdelta)) {
                double newlk = lnlk(newdelta);
                if (!Double.isNaN(newlk) && newlk > lkDelta) {
                    bestdelta = newdelta;
                    lkDelta = newlk;
                }
            }
        }
        return bestdelta;
    }

    /** Log likelihood at delta, or NaN if any lambda[i] + delta is negative. */
    private double lnlk(double delta) {
        double term1 = 0;
        double term2 = 0;
        int n = N - q;
        for (int i = 0; i < n; i++) {
            double val = (lambda[i] + delta);
            if (val < 0) return Double.NaN;
            term1 += eta2[i] / val;
            term2 += Math.log(val);
        }
        return (c - n * Math.log(term1) - term2) / 2;
    }

    /** First derivative of lnlk with respect to delta. */
    private double d1lnlk(double delta) {
        double term1 = 0;
        double term2 = 0;
        double term3 = 0;
        int n = N - q;
        for (int i = 0; i < n; i++) {
            double val = 1 / (lambda[i] + delta);
            double val2 = eta2[i] * val;
            term1 += val2;
            term2 += val2 * val;
            term3 += val;
        }
        return n * term2 / term1 / 2 - term3 / 2;
    }

    /**
     * Evaluates the likelihood at nregions values of delta equally spaced on a log10 scale between lower
     * and upper (inclusive).
     * @return one row per grid point: {delta, lnlk(delta), d1lnlk(delta)}
     */
    private double[][] scanlnlk(double lower, double upper) {
        double[][] result = new double[nregions][3];
        upper = Math.log10(upper);
        lower = Math.log10(lower);
        double incr = (upper - lower) / (nregions - 1);
        for (int i = 0; i < nregions; i++) {
            double delta = Math.pow(10.0, lower + i * incr);
            result[i][0] = delta;
            result[i][1] = lnlk(delta);
            result[i][2] = d1lnlk(delta);
        }
        return result;
    }

    /**
     * Finds adjacent grid points between which the first derivative goes from positive to non-positive,
     * i.e. intervals that bracket a local maximum. Intervals whose left likelihood is NaN are skipped.
     * @param scan  rows of {delta, lnlk, d1lnlk} as produced by scanlnlk()
     * @return one row per bracketing interval: {left delta, right delta}
     */
    private double[][] findSignChanges(double[][] scan) {
        ArrayList<double[]> changes = new ArrayList<>();
        int n = scan.length;
        for (int i = 0; i < n - 1; i++) {
            boolean slopeTurnsDown = scan[i][2] > 0 && scan[i + 1][2] <= 0;
            if (slopeTurnsDown && !Double.isNaN(scan[i][1])) {
                changes.add(new double[]{scan[i][0], scan[i + 1][0]});
            }
        }
        return changes.toArray(new double[0][]);
    }

    /**
     * Refines a local maximum of the likelihood inside interval {lower, upper}.
     * @return the delta found, or NaN if the search had to be abandoned
     */
    private double findMaximum(double[] interval) {
        //uses Newton Raphson to find local maximum
        //updates delta using delta' = delta - d1/d2
        //where d1 is the first derivative of lnlk at delta and
        //d2 is the second derivative of lnlk at delta
        //the local maximum is expected to fall between delta and max
        //if the algorithm finds a new delta outside this interval
        //then the likelihood is not well behaved in this interval
        //subdivide the interval and try again
        double delta = interval[0];
        boolean end = false;
        int n = N - q;
        int nIterations = 0;

        while (!end && nIterations < maxiter) {
            //A = sum[eta2/(lambda + delta)]
            //B = sum[eta2/(lambda + delta)^2]
            //C = sum[eta2/(lambda + delta)^3]
            //D = sum[1/(lambda + delta)]
            //E = sum[1/(lambda + delta)^2]
            double A = 0;
            double B = 0;
            double C = 0;
            double D = 0;
            double E = 0;

            for (int i = 0; i < n; i++) {
                double val = lambda[i] + delta;
                double val2 = val * val;
                double val3 = val2 * val;
                A += eta2[i]/val;
                B += eta2[i]/val2;
                C += eta2[i]/val3;
                D += 1/val;
                E += 1/val2;
            }

            // d1 and d2 are twice the derivatives of lnlk; the factor cancels in d1/d2.
            double d1 = n*B/A - D;
            if (Math.abs(d1) < convergence) end = true;
            else {
                double d2 = E + n*(B*B - 2*A*C)/A/A;
                delta = delta - d1/d2;
            }

            if (delta < interval[0] || delta > interval[1]) {
                // Newton-Raphson left the bracket: re-scan this interval on a finer grid,
                // giving up once more than three nested re-scans have been started.
                subintervalCount++;
                if (subintervalCount > 3) {
                    subintervalCount = 0;
                    return Double.NaN;
                }
                delta = findDeltaInInterval(interval);
                end = true;
            }
            nIterations++;
        }

        subintervalCount = 0;
        return delta;
    }

    /** Computes the inverse of H from eigA as V * diag(1 / (eigenvalue + delta)) * V'. */
    private DoubleMatrix inverseH(double delta) {
        DoubleMatrix V = eigA.getEigenvectors();
        DoubleMatrix D = eigA.getEigenvalueMatrix();
        int n = D.numberOfRows();
        for (int i = 0; i < n; i++) D.set(i, i, 1 / (D.get(i, i) + delta));
        return V.mult(D.tcrossproduct(V));
    }

    /** Computes beta = inv(X' invH X) X' invH y and stores inv(X' invH X) in invXHX. */
    private DoubleMatrix calculateBeta() {
        DoubleMatrix XtH = X.crossproduct(invH);
        invXHX = XtH.mult(X).inverse();
        return invXHX.mult(XtH.mult(y));
    }

    /**
     * Computes blup = K Z' invH (y - X beta) and, if calculatePEV is set, the prediction error variances.
     * solve() must have been called first.
     */
    public void calculateBLUP(){
        Xbeta = X.mult(beta);
        DoubleMatrix YminusXbeta = y.minus(Xbeta);
        DoubleMatrix KtransZ = K.mult(Z.transpose());
        DoubleMatrix KtransZinvH = KtransZ.mult(invH);
        blup = KtransZinvH.mult(YminusXbeta);

        if (calculatePEV) {
            // pev = varRandomEffect * diag( K - KZ'invH ZK' - (KZ'invH X) invXHX (KZ'invH X)' )
            DoubleMatrix KZHX = KtransZinvH.mult(X);
            DoubleMatrix pevMatrix = K.copy();
            pevMatrix = pevMatrix.minus(KtransZinvH.tcrossproduct(KtransZ));
            pevMatrix = pevMatrix.minus(KZHX.mult(invXHX).tcrossproduct(KZHX));
            int size = pevMatrix.numberOfRows();
            pev = DoubleMatrixFactory.DEFAULT.make(size, 1);
            for (int i = 0; i < size; i++) pev.set(i, 0, pevMatrix.get(i, i));
            pev.scalarMultEquals(varRandomEffect);
        }
    }

    /**
     * Predicted values, X beta + Z blup. When the G-BLUP constructor was used, the full design matrices
     * (Xoriginal, Zoriginal) are used so that individuals without a phenotype are predicted too.
     */
    private DoubleMatrix calculatePred(){
        if (Xoriginal == null){
            Xbeta = X.mult(beta);
            DoubleMatrix Zu = Z.mult(blup);
            return Xbeta.plus(Zu);
        }else{
            Xbeta = Xoriginal.mult(beta);
            DoubleMatrix Zu = Zoriginal.mult(blup);
            return Xbeta.plus(Zu);
        }
    }

    /**
     * Residuals, y minus the fitted values for the rows of y.
     */
    private DoubleMatrix calculateRes(){
        if (Xoriginal == null) return y.minus(pred);

        // G-BLUP: pred has one row per individual while y only holds the observed phenotypes,
        // so the residuals are taken against the fitted values of the observed rows.
        DoubleMatrix fittedObserved = X.mult(beta).plus(Z.mult(blup));
        return y.minus(fittedObserved);
    }

    /** Returns (y - X beta)' invH (y - X beta) / (N - q). */
    private double getGenvar(DoubleMatrix beta) {
        DoubleMatrix res = y.copy();
        res.minusEquals(X.mult(beta));
        return res.crossproduct(invH.mult(res)).get(0,0) / (N - q);
    }

    public int getDfMarker() {
        return dfMarker;
    }

    public DoubleMatrix getBeta() {
        return beta;
    }

    public int getDfModel() {
        return dfModel;
    }

    public int getDfError() {
        return dfError;
    }

    public double getDelta() {
        return delta;
    }

    public DoubleMatrix getInvH() {
        return invH;
    }

    public double getVarRes() {
        return varResidual;
    }

    public double getVarRan() {
        return varRandomEffect;
    }

    public DoubleMatrix getBlup() {
        return blup;
    }

    public DoubleMatrix getPev() {
        return pev;
    }

    public DoubleMatrix getPred() {
        return pred;
    }

    public DoubleMatrix getRes() {
        return res;
    }

    public double getLnLikelihood() {
        return lnLikelihood;
    }

    /**
     * Tests the marker effects, which are taken to be the last dfMarker elements of beta.
     * @return For markers with 2 df, F.fullModel, p.fullModel, additive effect, Fadd, padd, dominance effect, Fdom, pdom
     * For markers with other than 2 df, F and p for the full model only.
     * If dfMarker is less than 1, three NaN values.
     */
    public double[] getMarkerFp() {
        if (dfMarker < 1) return new double[]{Double.NaN, Double.NaN, Double.NaN};

        int nparm = beta.numberOfRows();
        int firstmarker = nparm - dfMarker;

        // M selects the marker parameters from beta.
        DoubleMatrix M = DoubleMatrixFactory.DEFAULT.make(dfMarker, nparm);
        for (int i = 0; i < dfMarker; i++) M.set(i, i + firstmarker, 1);
        DoubleMatrix MB = M.mult(beta);
        DoubleMatrix invMiM = M.mult(invXHX.tcrossproduct(M));
        invMiM.invert();

        double F = MB.crossproduct(invMiM.mult(MB)).get(0,0);
        F /= varRandomEffect;
        F /= dfMarker;
        double p;
        try {
            p = LinearModelUtils.Ftest(F, dfMarker, N - q);
        } catch (Exception e) {p = Double.NaN;}

        if (dfMarker != 2) return new double[]{F,p};

        //calculate add and dom effects and tests
        //assumes the betas are for the two homozygous classes and that the het effect = 0;
        double[] add = singleDfContrast(nparm, 0.5, -0.5);
        double[] dom = singleDfContrast(nparm, -0.5, -0.5);
        return new double[]{F, p, add[0], add[1], add[2], dom[0], dom[1], dom[2]};
    }

    /**
     * Estimates and tests a one degree of freedom contrast of the last two elements of beta.
     * F or p is NaN if its calculation throws.
     * @param nparm             number of rows of beta
     * @param secondLastWeight  contrast coefficient for beta[nparm - 2]
     * @param lastWeight        contrast coefficient for beta[nparm - 1]
     * @return {effect, F, p}
     */
    private double[] singleDfContrast(int nparm, double secondLastWeight, double lastWeight) {
        DoubleMatrix M = DoubleMatrixFactory.DEFAULT.make(1, nparm, 0);
        M.set(0, nparm - 2, secondLastWeight);
        M.set(0, nparm - 1, lastWeight);
        double effect = M.mult(beta).get(0, 0);

        double F;
        try {
            F = effect * effect / (M.mult(invXHX.tcrossproduct(M))).get(0,0) / varRandomEffect;
        } catch (Exception ex) {
            F = Double.NaN;
        }

        double p;
        try {
            p = LinearModelUtils.Ftest(F, 1, N - q);
        } catch (Exception e) {
            p = Double.NaN;
        }
        return new double[]{effect, F, p};
    }

    /**
     * Replaces y and calls solve(). X, Z and K are unchanged, so the decompositions from init() are reused.
     * @param y new phenotypes; stored as given, so it must already be a column with N rows
     */
    public void solveWithNewData(DoubleMatrix y) {
        this.y = y;
        solve();
    }

    /** Sets whether calculateBLUP() also computes prediction error variances. */
    public void setCalculatePEV(boolean calculatePEV) {
        this.calculatePEV = calculatePEV;
    }
}