All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.projectglow.sql.expressions.FirthTest.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.sql.expressions

import breeze.linalg.{*, diag, max, qr, sum, DenseMatrix, DenseVector, MatrixSingularException}
import breeze.numerics.{abs, log, sigmoid, sqrt}
import com.github.fommil.netlib.LAPACK
import org.apache.spark.ml.linalg.{DenseMatrix => SparkDenseMatrix}
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.netlib.util.intW

/**
 * Firth logistic regression test.
 *
 * For each genotype vector, a null model (covariates only) and a full model (covariates plus
 * the genotype column) are fit with [[fitFirth]]; the genotype coefficient, the Fisher
 * information of the full model and both log-likelihoods are then handed to
 * `LogisticRegressionGwas.makeStats`.
 */
object FirthTest extends LogitTest {
  // NOTE(review): the meaning of this flag is defined by LogitTest, which is not visible here;
  // presumably `false` lets one fit state be reused across phenotypes -- confirm.
  override def fitStatePerPhenotype: Boolean = false
  override type FitState = FirthFitState
  override val resultSchema: StructType = Encoders.product[LogitTestResults].schema

  /**
   * Builds the reusable fit state for a covariate matrix.
   *
   * The design matrix is the covariates with one extra all-zero column appended; `runTest`
   * overwrites that last column with the genotypes of each row. Scratch buffers are allocated
   * for the null fit (`numCols` coefficients) and the full fit (`numCols + 1` coefficients).
   *
   * @param phenotypes not used by this implementation
   * @param covariates covariate matrix (samples x covariates)
   * @param offsetOption not used by this implementation
   * @return a fresh [[FirthFitState]]
   */
  override def init(
      phenotypes: Array[Double],
      covariates: SparkDenseMatrix,
      offsetOption: Option[Array[Double]]): FirthFitState = {

    // Breeze reads the backing array as column-major, but a Spark matrix flagged `isTransposed`
    // stores its values row-major. `toArray` always yields column-major order, so use it in that
    // case instead of silently scrambling the covariates.
    val columnMajorValues =
      if (covariates.isTransposed) covariates.toArray else covariates.values
    val covariateX =
      new DenseMatrix[Double](covariates.numRows, covariates.numCols, columnMajorValues)
    FirthFitState(
      DenseMatrix.horzcat(covariateX, DenseMatrix.zeros[Double](covariateX.rows, 1)),
      new FirthNewtonArgs(covariates.numRows, covariates.numCols),
      new FirthNewtonArgs(covariates.numRows, covariates.numCols + 1)
    )
  }

  /**
   * Runs the test for one genotype vector.
   *
   * Mutates `fitState`: the last column of `fitState.x` is overwritten with `genotypes` and both
   * sets of Newton buffers are cleared and refilled.
   *
   * @param genotypes genotype values, one per sample
   * @param phenotypes phenotype values, one per sample
   * @param offsetOption optional per-sample offset added to the linear predictor
   * @param fitState state created by [[init]]
   * @return `LogitTestResults.nanRow` if either fit fails to converge, otherwise the row produced
   *         by `LogisticRegressionGwas.makeStats`
   */
  override def runTest(
      genotypes: DenseVector[Double],
      phenotypes: DenseVector[Double],
      offsetOption: Option[DenseVector[Double]],
      fitState: FirthFitState): InternalRow = {

    fitState.x(::, -1) := genotypes
    fitState.nullFitArgs.clear()
    val nullFit = fitFirth(fitState.x, phenotypes, offsetOption, fitState.nullFitArgs)

    if (!nullFit.converged) {
      LogitTestResults.nanRow
    } else {
      // Start the full fit from the null-fit coefficients; the genotype coefficient starts at 0.
      fitState.fullFitArgs.clear()
      fitState.fullFitArgs.initFromNullFit(nullFit.fitState)
      val fullFit = fitFirth(fitState.x, phenotypes, offsetOption, fitState.fullFitArgs)

      if (!fullFit.converged) {
        LogitTestResults.nanRow
      } else {
        // The genotype coefficient is the last entry of the full-fit coefficient vector.
        val beta = fullFit.fitState.b(-1)
        // Fisher information x^T * diag(mu * (1 - mu)) * x evaluated at the full fit.
        val weights = fullFit.fitState.mu *:* (1d - fullFit.fitState.mu)
        val fisher = fitState.x.t * (fitState.x(::, *) *:* weights)
        LogisticRegressionGwas.makeStats(beta, fisher, fullFit.logLkhd, nullFit.logLkhd)
      }
    }
  }

  /**
   * Fits a Firth-penalized logistic model by Newton iterations, updating `args` in place.
   * Adapted from Hail's `fitFirth` method.
   *
   * Only the first `args.b.length` columns of `x` enter the linear predictor and receive
   * coefficient updates, but the QR factorization in every iteration is taken over all columns
   * of `x`.
   *
   * @param x design matrix (samples x columns)
   * @param y response vector, one entry per sample
   * @param offsetOption optional per-sample offset added to the linear predictor
   * @param args buffers holding the starting coefficients; overwritten with the fitted values
   * @param maxIter maximum number of Newton iterations
   * @param tol convergence threshold on the largest absolute coefficient update
   * @return the fit result; `logLkhd` is only populated when `converged` is true (0 otherwise).
   *         Both `converged` and `exploded` are false when `maxIter` is exhausted.
   */
  def fitFirth(
      x: DenseMatrix[Double],
      y: DenseVector[Double],
      offsetOption: Option[DenseVector[Double]],
      args: FirthNewtonArgs,
      maxIter: Int = 100,
      tol: Double = 1e-6): FirthFit = {

    val m0 = args.b.length
    var logLkhd = 0d
    var iter = 1
    var converged = false
    var exploded = false

    while (!converged && !exploded && iter <= maxIter) {
      try {
        val linearPredictor = x(::, 0 until m0) * args.b
        val eta = offsetOption match {
          case Some(offset) => offset + linearPredictor
          case None => linearPredictor
        }
        args.mu := sigmoid(eta)
        args.sqrtW := sqrt(args.mu *:* (1d - args.mu))
        // Scale every column of x element-wise by sqrtW, then factorize.
        val QR = qr.reduced(x(::, *) *:* args.sqrtW)
        // Squared row norms of Q.
        val h = QR.q(*, ::).map(r => r dot r)
        // Newton step for the first m0 coefficients via the leading m0 x m0 block of R.
        val deltaB = solveUpperTriangular(
          QR.r(0 until m0, 0 until m0),
          QR.q(::, 0 until m0).t * (((y - args.mu) + (h *:* (0.5 - args.mu))) /:/ args.sqrtW))

        if (deltaB(0).isNaN) {
          exploded = true
        } else if (max(abs(deltaB)) < tol && iter > 1) {
          // Convergence is never declared on the very first iteration.
          converged = true
          // Log-likelihood plus the penalty term sum(log|diag(R)|).
          logLkhd = sum(log((y *:* args.mu) + ((1d - y) *:* (1d - args.mu)))) +
            sum(log(abs(diag(QR.r))))
        } else {
          iter += 1
          args.b += deltaB
        }
      } catch {
        // Numerical failures in the factorization or the solve abort the fit.
        case _: breeze.linalg.MatrixSingularException => exploded = true
        case _: breeze.linalg.NotConvergedException => exploded = true
      }
    }

    FirthFit(args, logLkhd, converged, exploded)
  }

  /**
   * Solves the upper triangular system `a * x = b` using LAPACK's `dtrtrs`.
   * Adapted from Hail's TriSolver.
   *
   * Neither `a` nor `b` is modified: both are copied before being passed to LAPACK.
   *
   * @param a square upper triangular matrix
   * @param b right-hand side with `a.rows` entries
   * @return the solution vector
   * @throws MatrixSingularException if LAPACK reports a positive status (singular matrix)
   * @throws IllegalArgumentException if the dimensions do not match or LAPACK reports a
   *                                  negative status (illegal argument)
   */
  def solveUpperTriangular(a: DenseMatrix[Double], b: DenseVector[Double]): DenseVector[Double] = {

    require(a.rows == a.cols)
    require(a.rows == b.length)

    // dtrtrs overwrites the right-hand side with the solution, so work on a copy of b.
    val x = DenseVector(b.toArray)

    val status = new intW(0)
    LAPACK
      .getInstance()
      .dtrtrs("U", "N", "N", a.rows, 1, a.toArray, a.rows, x.data, x.length, status) // x := A \ x
    val info: Int = status.`val`

    if (info > 0) {
      throw new MatrixSingularException()
    } else if (info < 0) {
      throw new IllegalArgumentException(s"LAPACK dtrtrs rejected argument ${-info}")
    }

    x
  }
}

/**
 * Mutable, reusable buffers for a single Newton fit.
 *
 * @param numRows length of the per-sample buffers `mu` and `sqrtW`
 * @param numCols length of the coefficient buffer `b`
 */
class FirthNewtonArgs(numRows: Int, numCols: Int) {
  val b: DenseVector[Double] = DenseVector.zeros[Double](numCols)
  val mu: DenseVector[Double] = DenseVector.zeros[Double](numRows)
  val sqrtW: DenseVector[Double] = DenseVector.zeros[Double](numRows)

  /** Resets every buffer to all zeros so the instance can be reused for another fit. */
  def clear(): Unit = {
    sqrtW := 0d
    mu := 0d
    b := 0d
  }

  /**
   * Copies the coefficients of `nullFit` into the leading entries of `b`; entries beyond
   * `nullFit.b.length` are left untouched.
   */
  def initFromNullFit(nullFit: FirthNewtonArgs): Unit = {
    val nullCoefficients = nullFit.b
    b(0 until nullCoefficients.length) := nullCoefficients
  }
}

/**
 * State shared between successive Firth test invocations.
 *
 * @param x design matrix: the covariates followed by one extra column for the genotypes
 * @param nullFitArgs Newton buffers for the null fit
 * @param fullFitArgs Newton buffers for the full fit
 */
case class FirthFitState(
    // The last column of x will be rewritten with the genotypes for each new row
    x: DenseMatrix[Double],
    nullFitArgs: FirthNewtonArgs,
    fullFitArgs: FirthNewtonArgs
)

/**
 * Outcome of one `FirthTest.fitFirth` call.
 *
 * @param fitState the Newton buffers that were passed to the fit and updated in place
 * @param logLkhd the log-likelihood computed on convergence; 0 if the fit did not converge
 * @param converged whether the coefficient update dropped below the tolerance
 * @param exploded whether the fit was aborted because of a NaN update or a linear algebra
 *                 exception
 */
case class FirthFit(
    fitState: FirthNewtonArgs,
    logLkhd: Double,
    converged: Boolean,
    exploded: Boolean)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy