All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.linalg.BLAS.scala Maven / Gradle / Ivy

There is a newer version: 3.5.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.linalg

import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS => NetlibJavaBLAS, NativeBLAS => NetlibNativeBLAS}

/**
 * BLAS routines for MLlib's vectors and matrices.
 */
private[spark] object BLAS extends Serializable {

  // Cached pure-JVM BLAS instance; stays null until the `javaBLAS` accessor fills it in.
  // @transient keeps the cached handle out of the serialized form of this object.
  @transient private var _javaBLAS: NetlibBLAS = _
  // Cached instance handed out by the `nativeBLAS` accessor; null until first requested.
  @transient private var _nativeBLAS: NetlibBLAS = _
  // Size cut-off consulted by `getBLAS`: sizes below it get `javaBLAS`, all others get
  // `nativeBLAS`.
  private val nativeL1Threshold: Int = 256

  // Pure-JVM BLAS backend, created on first access and cached afterwards. `getBLAS` hands it
  // out for sizes below `nativeL1Threshold`, and `dspmv` always runs on it.
  private[spark] def javaBLAS: NetlibBLAS = {
    val cached = _javaBLAS
    if (cached != null) {
      cached
    } else {
      val created = NetlibJavaBLAS.getInstance
      _javaBLAS = created
      created
    }
  }

  // BLAS backend used directly by spr, syr, gemm and gemv, and by `getBLAS` for sizes at or
  // above `nativeL1Threshold`. Created on first access and cached afterwards.
  private[spark] def nativeBLAS: NetlibBLAS = {
    val cached = _nativeBLAS
    if (cached != null) {
      cached
    } else {
      val resolved =
        try {
          NetlibNativeBLAS.getInstance
        } catch {
          // Any failure to obtain the native instance falls back to the pure-JVM backend.
          // NOTE(review): the catch is intentionally as broad as Throwable — presumably so that
          // native-library loading errors are covered as well; confirm before narrowing.
          case _: Throwable => javaBLAS
        }
      _nativeBLAS = resolved
      resolved
    }
  }

  /**
   * Picks the BLAS backend for an operation over `vectorSize` elements.
   *
   * @param vectorSize number of elements the caller is about to process
   * @return `javaBLAS` when `vectorSize` is below `nativeL1Threshold`, otherwise `nativeBLAS`
   */
  private[spark] def getBLAS(vectorSize: Int): NetlibBLAS =
    if (vectorSize >= nativeL1Threshold) nativeBLAS else javaBLAS

  /**
   * Computes y += a * x in place.
   *
   * @param a scalar multiplier applied to `x`
   * @param x vector to add; must be a [[DenseVector]] or a [[SparseVector]]
   * @param y vector updated in place; must be a [[DenseVector]] of the same size as `x`
   */
  def axpy(a: Double, x: Vector, y: Vector): Unit = {
    require(x.size == y.size)
    (x, y) match {
      case (sparseX: SparseVector, denseY: DenseVector) =>
        axpy(a, sparseX, denseY)
      case (denseX: DenseVector, denseY: DenseVector) =>
        axpy(a, denseX, denseY)
      case (_, _: DenseVector) =>
        // y is fine, so the problem is the type of x.
        throw new UnsupportedOperationException(
          s"axpy doesn't support x type ${x.getClass}.")
      case _ =>
        throw new IllegalArgumentException(
          s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
    }
  }

  /**
   * Computes y += a * x in place for two dense vectors by delegating to BLAS `daxpy`.
   * The backend is chosen by `getBLAS` from the size of `x`.
   */
  private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = {
    val length = x.size
    val blas = getBLAS(length)
    blas.daxpy(length, a, x.values, 1, y.values, 1)
  }

  /**
   * Computes y += a * x in place for a sparse `x` and a dense `y`.
   * Only the positions listed in `x.indices` are touched.
   */
  private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
    val activeValues = x.values
    val activeIndices = x.indices
    val target = y.values
    val active = activeIndices.length

    var pos = 0
    if (a != 1.0) {
      while (pos < active) {
        target(activeIndices(pos)) += a * activeValues(pos)
        pos += 1
      }
    } else {
      // a == 1.0: plain addition, no multiplication needed.
      while (pos < active) {
        target(activeIndices(pos)) += activeValues(pos)
        pos += 1
      }
    }
  }

  /**
   * Computes Y += a * X in place for two dense matrices of identical dimensions, treating the
   * backing value arrays as flat vectors handed to BLAS `daxpy`.
   */
  private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
    val sameShape = X.numRows == Y.numRows && X.numCols == Y.numCols
    require(sameShape, "Dimension mismatch: " +
      s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
    val elementCount = X.numRows * X.numCols
    val blas = getBLAS(X.values.length)
    blas.daxpy(elementCount, a, X.values, 1, Y.values, 1)
  }

  /**
   * Computes the dot product of `x` and `y`.
   *
   * @param x first operand; dense or sparse
   * @param y second operand; dense or sparse, same size as `x`
   * @return the sum of the element-wise products
   */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
      s" x.size = ${x.size}, y.size = ${y.size}")

    def unsupported(): Nothing =
      throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")

    x match {
      case denseX: DenseVector =>
        y match {
          case denseY: DenseVector => dot(denseX, denseY)
          // The mixed kernel takes the sparse operand first.
          case sparseY: SparseVector => dot(sparseY, denseX)
          case _ => unsupported()
        }
      case sparseX: SparseVector =>
        y match {
          case denseY: DenseVector => dot(sparseX, denseY)
          case sparseY: SparseVector => dot(sparseX, sparseY)
          case _ => unsupported()
        }
      case _ => unsupported()
    }
  }

  /**
   * Dot product of two dense vectors via BLAS `ddot`; the backend is chosen by `getBLAS` from
   * the size of `x`.
   */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val length = x.size
    val blas = getBLAS(length)
    blas.ddot(length, x.values, 1, y.values, 1)
  }

  /**
   * Dot product of a sparse `x` with a dense `y`: walks the stored entries of `x` and looks up
   * the matching position in `y`.
   */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val activeValues = x.values
    val activeIndices = x.indices
    val dense = y.values
    val active = activeIndices.length

    var acc = 0.0
    var pos = 0
    while (pos < active) {
      acc += activeValues(pos) * dense(activeIndices(pos))
      pos += 1
    }
    acc
  }

  /**
   * Dot product of two sparse vectors. Both index arrays are scanned in lock-step; a product is
   * accumulated only where the same index is stored in both vectors.
   */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val leftValues = x.values
    val leftIndices = x.indices
    val rightValues = y.values
    val rightIndices = y.indices
    val leftCount = leftIndices.length
    val rightCount = rightIndices.length

    var left = 0
    var right = 0
    var acc = 0.0
    while (left < leftCount && right < rightCount) {
      val leftIndex = leftIndices(left)
      val rightIndex = rightIndices(right)
      if (rightIndex < leftIndex) {
        // y is still behind x: let it catch up.
        right += 1
      } else {
        if (rightIndex == leftIndex) {
          acc += leftValues(left) * rightValues(right)
          right += 1
        }
        left += 1
      }
    }
    acc
  }

  /**
   * Overwrites `y` with the contents of `x` (y = x).
   *
   * @param x source vector; dense or sparse
   * @param y destination; must be a [[DenseVector]] of the same size as `x`
   */
  def copy(x: Vector, y: Vector): Unit = {
    val n = y.size
    require(x.size == n)
    y match {
      case target: DenseVector =>
        val out = target.values
        x match {
          case source: SparseVector =>
            val storedIndices = source.indices
            val storedValues = source.values
            val stored = storedIndices.length

            // `cursor` is the next position of `out` still to be written.
            var cursor = 0
            var k = 0
            while (k < stored) {
              val nextActive = storedIndices(k)
              // Zero the gap before the next stored entry.
              while (cursor < nextActive) {
                out(cursor) = 0.0
                cursor += 1
              }
              out(cursor) = storedValues(k)
              cursor += 1
              k += 1
            }
            // Zero whatever remains after the last stored entry.
            while (cursor < n) {
              out(cursor) = 0.0
              cursor += 1
            }
          case source: DenseVector =>
            Array.copy(source.values, 0, out, 0, n)
        }
      case _ =>
        throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}")
    }
  }

  /**
   * Scales `x` in place (x = a * x) via BLAS `dscal`. For a sparse vector only the stored
   * values are scaled.
   *
   * @param a scale factor
   * @param x vector to scale; dense or sparse
   */
  def scal(a: Double, x: Vector): Unit = {
    x match {
      case sparse: SparseVector =>
        val stored = sparse.values
        getBLAS(stored.length).dscal(stored.length, a, stored, 1)
      case dense: DenseVector =>
        val data = dense.values
        getBLAS(dense.size).dscal(data.length, a, data, 1)
      case _ =>
        throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
    }
  }

  /**
   * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
   * Unwraps `U` and forwards to the array-based overload.
   *
   * @param alpha scalar applied to the outer product
   * @param v the vector whose outer product is accumulated
   * @param U the upper triangular part of the matrix in a [[DenseVector]](column major)
   */
  def spr(alpha: Double, v: Vector, U: DenseVector): Unit = {
    val packed = U.values
    spr(alpha, v, packed)
  }

  /**
   * y := alpha*A*x + beta*y, with A given in packed upper-triangular form.
   * Always executed on `javaBLAS`.
   *
   * @param n The order of the n by n matrix A.
   * @param alpha Scalar applied to A*x.
   * @param A The upper triangular part of A in a [[DenseVector]] (column major).
   * @param x The [[DenseVector]] transformed by A.
   * @param beta Scalar applied to the incoming contents of y.
   * @param y The [[DenseVector]] to be modified in place.
   */
  def dspmv(
      n: Int,
      alpha: Double,
      A: DenseVector,
      x: DenseVector,
      beta: Double,
      y: DenseVector): Unit = {
    val packed = A.values
    val input = x.values
    val output = y.values
    javaBLAS.dspmv("U", n, alpha, packed, input, 1, beta, output, 1)
  }

  /**
   * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
   *
   * A dense `v` is delegated to `nativeBLAS.dspr`. A sparse `v` is handled by a hand-written
   * loop that only updates the packed entries addressed by pairs of stored indices of `v`.
   *
   * @param alpha scalar applied to the outer product v * v.t
   * @param v the vector whose outer product is accumulated
   * @param U the upper triangular part of the matrix packed in an array (column major);
   *          updated in place
   * @throws IllegalArgumentException if `v` is neither a [[DenseVector]] nor a [[SparseVector]]
   */
  def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
    val n = v.size
    v match {
      case DenseVector(values) =>
        nativeBLAS.dspr("U", n, alpha, values, 1, U)
      case SparseVector(_, indices, values) =>
        val nnz = indices.length
        // Offset in U of the first element of the current column; packed column c starts at
        // c * (c + 1) / 2.
        var colStartIdx = 0
        var prevCol = 0
        var j = 0
        while (j < nnz) {
          val col = indices(j)
          // Skip empty columns: advance by start(col) - start(prevCol). The intermediate
          // product is roughly twice the final offset, so it is evaluated in Long to avoid
          // wrapping for offsets that still fit in an Int.
          colStartIdx += ((col.toLong - prevCol) * (col.toLong + prevCol + 1) / 2).toInt
          val av = alpha * values(j)
          // Rows 0..j of the stored entries lie on or above the diagonal of this column.
          var i = 0
          while (i <= j) {
            U(colStartIdx + indices(i)) += av * values(i)
            i += 1
          }
          prevCol = col
          j += 1
        }
      case _ =>
        throw new IllegalArgumentException(s"spr doesn't support vector type ${v.getClass}.")
    }
  }

  /**
   * A := alpha * x * x^T^ + A
   *
   * Validates that A is square and matches the size of x, then dispatches on the concrete
   * type of x.
   *
   * @param alpha a real scalar that will be multiplied to x * x^T^.
   * @param x the vector x that contains the n elements.
   * @param A the symmetric matrix A. Size of n x n.
   */
  def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = {
    val rows = A.numRows
    val cols = A.numCols
    require(rows == cols,
      s"A is not a square matrix (and hence is not symmetric). A: $rows x $cols")
    require(rows == x.size,
      s"The size of x doesn't match the rank of A. A: $rows x $cols, x: ${x.size}")

    x match {
      case dense: DenseVector =>
        syr(alpha, dense, A)
      case sparse: SparseVector =>
        syr(alpha, sparse, A)
      case _ =>
        throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
    }
  }

  /**
   * A := alpha * x * x^T^ + A for a dense `x`. BLAS `dsyr` updates the upper triangle only, so
   * the result is mirrored into the lower triangle afterwards.
   */
  private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
    val rowCount = A.numRows
    val colCount = A.numCols

    nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, rowCount)

    // Copy each entry above the diagonal to its mirror position below it.
    var upper = 0
    while (upper < colCount) {
      var lower = upper + 1
      while (lower < rowCount) {
        A(lower, upper) = A(upper, lower)
        lower += 1
      }
      upper += 1
    }
  }

  /**
   * A := alpha * x * x^T^ + A for a sparse `x`. Every pair of stored entries of `x` contributes
   * to exactly one element of `A.values`, so both triangles are updated directly.
   */
  private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
    val stride = A.numCols
    val activeIndices = x.indices
    val activeValues = x.values
    val active = activeValues.length
    val data = A.values

    var outer = 0
    while (outer < active) {
      val scaled = alpha * activeValues(outer)
      val base = activeIndices(outer) * stride
      var inner = 0
      while (inner < active) {
        data(activeIndices(inner) + base) += scaled * activeValues(inner)
        inner += 1
      }
      outer += 1
    }
  }

  /**
   * C := alpha * A * B + beta * C
   *
   * When alpha is 0.0 no product is computed: C is left untouched for beta == 1.0 and merely
   * scaled by beta otherwise.
   *
   * @param alpha a scalar to scale the multiplication A * B.
   * @param A the matrix A that will be left multiplied to B. Size of m x k.
   * @param B the matrix B that will be left multiplied by A. Size of k x n.
   * @param beta a scalar that can be used to scale matrix C.
   * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false.
   */
  def gemm(
      alpha: Double,
      A: Matrix,
      B: DenseMatrix,
      beta: Double,
      C: DenseMatrix): Unit = {
    require(!C.isTransposed,
      "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
    if (alpha != 0.0) {
      A match {
        case sparseA: SparseMatrix =>
          gemm(alpha, sparseA, B, beta, C)
        case denseA: DenseMatrix =>
          gemm(alpha, denseA, B, beta, C)
        case _ =>
          throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
      }
    } else if (beta != 1.0) {
      // The product term vanishes; only the scaling of C remains.
      val data = C.values
      getBLAS(data.length).dscal(data.length, beta, data, 1)
    }
    // alpha == 0.0 and beta == 1.0: C is already the result.
  }

  /**
   * CValues[0: A.numRows * B.numCols] := alpha * A * B + beta * CValues[0: A.numRows * B.numCols]
   *
   * When alpha is 0.0 no product is computed: the prefix of CValues is left untouched for
   * beta == 1.0 and merely scaled by beta otherwise.
   *
   * @param alpha a scalar to scale the multiplication A * B.
   * @param A the matrix A that will be left multiplied to B. Size of m x k.
   * @param B the matrix B that will be left multiplied by A. Size of k x n.
   * @param beta a scalar that can be used to scale matrix C.
   * @param CValues the values of matrix C. C.isTransposed is supposed to be false.
   */
  def gemm(
      alpha: Double,
      A: Matrix,
      B: DenseMatrix,
      beta: Double,
      CValues: Array[Double]): Unit = {
    require(A.numRows * B.numCols <= CValues.length,
      s"the length of CValues must be no less than ${A.numRows} X ${B.numCols}")
    if (alpha != 0.0) {
      A match {
        case sparseA: SparseMatrix =>
          gemmImpl(alpha, sparseA, B, beta, CValues)
        case denseA: DenseMatrix =>
          gemmImpl(alpha, denseA, B, beta, CValues)
        case _ =>
          throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
      }
    } else if (beta != 1.0) {
      // The product term vanishes; only the scaling of the result prefix remains.
      val resultLength = A.numRows * B.numCols
      getBLAS(resultLength).dscal(resultLength, beta, CValues, 1)
    }
    // alpha == 0.0 and beta == 1.0: CValues already holds the result.
  }

  /**
   * C := alpha * A * B + beta * C
   * For `DenseMatrix` A.
   *
   * Checks that the dimensions of A, B and C are compatible and then forwards to `gemmImpl`
   * with the backing array of C.
   */
  private def gemm(
      alpha: Double,
      A: DenseMatrix,
      B: DenseMatrix,
      beta: Double,
      C: DenseMatrix): Unit = {
    require(A.numCols == B.numRows,
      s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}")
    require(A.numRows == C.numRows,
      s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}")
    // The second value reported here is B's column count, so it is labelled "B".
    require(B.numCols == C.numCols,
      s"The columns of C don't match the columns of B. C: ${C.numCols}, B: ${B.numCols}")
    gemmImpl(alpha, A, B, beta, C.values)
  }

  /**
   * CValues[0: A.numRows * B.numCols] := alpha * A * B + beta * CValues[0: A.numRows * B.numCols]
   * For `DenseMatrix` A.
   *
   * Delegates to BLAS `dgemm` on `nativeBLAS`. A transposed operand is passed with the "T"
   * flag and its number of columns as leading dimension; otherwise "N" and its number of rows.
   */
  private def gemmImpl(
      alpha: Double,
      A: DenseMatrix,
      B: DenseMatrix,
      beta: Double,
      CValues: Array[Double]): Unit = {
    require(A.numCols == B.numRows,
      s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}")

    val (transA, leadingDimA) = if (A.isTransposed) ("T", A.numCols) else ("N", A.numRows)
    val (transB, leadingDimB) = if (B.isTransposed) ("T", B.numCols) else ("N", B.numRows)

    nativeBLAS.dgemm(transA, transB, A.numRows, B.numCols, A.numCols, alpha, A.values,
      leadingDimA, B.values, leadingDimB, beta, CValues, A.numRows)
  }

  /**
   * C := alpha * A * B + beta * C
   * For `SparseMatrix` A.
   *
   * Checks that the dimensions of A, B and C are compatible and then forwards to `gemmImpl`
   * with the backing array of C.
   */
  private def gemm(
      alpha: Double,
      A: SparseMatrix,
      B: DenseMatrix,
      beta: Double,
      C: DenseMatrix): Unit = {
    val mA: Int = A.numRows
    val nB: Int = B.numCols
    val kA: Int = A.numCols
    val kB: Int = B.numRows

    require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
    require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
    // The second value reported here is B's column count, so it is labelled "B".
    require(nB == C.numCols,
      s"The columns of C don't match the columns of B. C: ${C.numCols}, B: $nB")

    gemmImpl(alpha, A, B, beta, C.values)
  }

  /**
   * CValues[0: A.numRows * B.numCols] := alpha * A * B + beta * CValues[0: A.numRows * B.numCols]
   * For `SparseMatrix` A.
   *
   * Four hand-written kernels cover the combinations of `A.isTransposed` and `B.isTransposed`.
   * The result is written with stride `A.numRows` between consecutive columns of C.
   *
   * @param alpha scalar applied to the product A * B
   * @param A sparse left operand
   * @param B dense right operand; its row count must equal the column count of A
   * @param beta scalar applied to the incoming contents of CValues
   * @param CValues backing array of the result, updated in place
   */
  private def gemmImpl(
      alpha: Double,
      A: SparseMatrix,
      B: DenseMatrix,
      beta: Double,
      CValues: Array[Double]): Unit = {
    val mA: Int = A.numRows
    val nB: Int = B.numCols
    val kA: Int = A.numCols
    val kB: Int = B.numRows

    require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")

    val Avals = A.values
    val Bvals = B.values
    val Cvals = CValues
    val ArowIndices = A.rowIndices
    val AcolPtrs = A.colPtrs

    // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
    // With A transposed, AcolPtrs is indexed by row of A and ArowIndices supplies the position
    // inside that row, so each entry of C is one sparse-row-times-dense-column sum.
    if (A.isTransposed) {
      var colCounterForB = 0
      if (!B.isTransposed) { // Expensive to put the check inside the loop
        while (colCounterForB < nB) {
          var rowCounterForA = 0
          // Offsets of the current column inside C and inside B.
          val Cstart = colCounterForB * mA
          val Bstart = colCounterForB * kA
          while (rowCounterForA < mA) {
            var i = AcolPtrs(rowCounterForA)
            val indEnd = AcolPtrs(rowCounterForA + 1)
            var sum = 0.0
            while (i < indEnd) {
              sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
              i += 1
            }
            val Cindex = Cstart + rowCounterForA
            // beta is applied here, entry by entry, rather than in a separate scaling pass.
            Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
            rowCounterForA += 1
          }
          colCounterForB += 1
        }
      } else {
        // Same kernel as above, but B is read with stride nB because it is transposed.
        while (colCounterForB < nB) {
          var rowCounterForA = 0
          val Cstart = colCounterForB * mA
          while (rowCounterForA < mA) {
            var i = AcolPtrs(rowCounterForA)
            val indEnd = AcolPtrs(rowCounterForA + 1)
            var sum = 0.0
            while (i < indEnd) {
              sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i))
              i += 1
            }
            val Cindex = Cstart + rowCounterForA
            Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
            rowCounterForA += 1
          }
          colCounterForB += 1
        }
      }
    } else {
      // Scale matrix first if `beta` is not equal to 1.0
      // Only the first A.numRows * B.numCols entries of Cvals are scaled.
      if (beta != 1.0) {
        val n = A.numRows * B.numCols
        getBLAS(n).dscal(n, beta, Cvals, 1)
      }
      // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
      // B, and added to C.
      var colCounterForB = 0 // the column to be updated in C
      if (!B.isTransposed) { // Expensive to put the check inside the loop
        while (colCounterForB < nB) {
          var colCounterForA = 0 // The column of A to multiply with the row of B
          val Bstart = colCounterForB * kB
          val Cstart = colCounterForB * mA
          while (colCounterForA < kA) {
            var i = AcolPtrs(colCounterForA)
            val indEnd = AcolPtrs(colCounterForA + 1)
            // alpha is folded into the B entry once per column of A.
            val Bval = Bvals(Bstart + colCounterForA) * alpha
            while (i < indEnd) {
              // Scatter the scaled column of A into the current column of C.
              Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
              i += 1
            }
            colCounterForA += 1
          }
          colCounterForB += 1
        }
      } else {
        // Same kernel as above, but B is read with stride nB because it is transposed.
        while (colCounterForB < nB) {
          var colCounterForA = 0 // The column of A to multiply with the row of B
          val Cstart = colCounterForB * mA
          while (colCounterForA < kA) {
            var i = AcolPtrs(colCounterForA)
            val indEnd = AcolPtrs(colCounterForA + 1)
            val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha
            while (i < indEnd) {
              Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
              i += 1
            }
            colCounterForA += 1
          }
          colCounterForB += 1
        }
      }
    }
  }

  /**
   * y[0: A.numRows] := alpha * A * x[0: A.numCols] + beta * y[0: A.numRows]
   *
   * `x` and `y` may be longer than required; only the leading `A.numCols` / `A.numRows`
   * entries take part. When alpha is 0.0 no product is computed: y is left untouched for
   * beta == 1.0 and its first `A.numRows` entries are scaled by beta otherwise.
   *
   * @param alpha a scalar to scale the multiplication A * x.
   * @param A the matrix A; dense or sparse.
   * @param x array holding the input vector; at least `A.numCols` long.
   * @param beta a scalar that can be used to scale y.
   * @param y array holding the result, updated in place; at least `A.numRows` long.
   * @throws IllegalArgumentException if the arrays are too short or A has an unsupported type
   */
  def gemv(
      alpha: Double,
      A: Matrix,
      x: Array[Double],
      beta: Double,
      y: Array[Double]): Unit = {
    require(A.numCols <= x.length,
      s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.length}")
    require(A.numRows <= y.length,
      s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.length}")
    if (alpha == 0.0 && beta == 1.0) {
      // gemv: alpha is equal to 0 and beta is equal to 1. Returning y.
    } else if (alpha == 0.0) {
      getBLAS(A.numRows).dscal(A.numRows, beta, y, 1)
    } else {
      A match {
        case smA: SparseMatrix => gemvImpl(alpha, smA, x, beta, y)
        case dmA: DenseMatrix => gemvImpl(alpha, dmA, x, beta, y)
        // Explicit fallback, in line with the other gemm/gemv dispatchers in this object.
        case _ =>
          throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
      }
    }
  }

  /**
   * y := alpha * A * x + beta * y
   *
   * When alpha is 0.0 no product is computed: y is left untouched for beta == 1.0 and merely
   * scaled by beta otherwise.
   *
   * @param alpha a scalar to scale the multiplication A * x.
   * @param A the matrix A that will be left multiplied to x. Size of m x n.
   * @param x the vector x that will be left multiplied by A. Size of n x 1.
   * @param beta a scalar that can be used to scale vector y.
   * @param y the resulting vector y. Size of m x 1.
   */
  def gemv(
      alpha: Double,
      A: Matrix,
      x: Vector,
      beta: Double,
      y: DenseVector): Unit = {
    require(A.numCols == x.size,
      s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
    require(A.numRows == y.size,
      s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")
    if (alpha != 0.0) {
      // Pick the kernel from the concrete matrix/vector type combination.
      (A, x) match {
        case (sparseA: SparseMatrix, denseX: DenseVector) =>
          gemv(alpha, sparseA, denseX, beta, y)
        case (sparseA: SparseMatrix, sparseX: SparseVector) =>
          gemv(alpha, sparseA, sparseX, beta, y)
        case (denseA: DenseMatrix, denseX: DenseVector) =>
          gemv(alpha, denseA, denseX, beta, y)
        case (denseA: DenseMatrix, sparseX: SparseVector) =>
          gemv(alpha, denseA, sparseX, beta, y)
        case _ =>
          throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
            s"${A.getClass} and vector type ${x.getClass}.")
      }
    } else if (beta != 1.0) {
      // The product term vanishes; only the scaling of y remains.
      scal(beta, y)
    }
    // alpha == 0.0 and beta == 1.0: y is already the result.
  }

  /**
   * y := alpha * A * x + beta * y
   * For `DenseMatrix` A and `DenseVector` x. Unwraps both vectors and forwards to `gemvImpl`.
   */
  private def gemv(
      alpha: Double,
      A: DenseMatrix,
      x: DenseVector,
      beta: Double,
      y: DenseVector): Unit = {
    val input = x.values
    val output = y.values
    gemvImpl(alpha, A, input, beta, output)
  }

  /**
   * y[0: A.numRows] := alpha * A * x[0: A.numCols] + beta * y[0: A.numRows]
   * For `DenseMatrix` A.
   *
   * Delegates to BLAS `dgemv` on `nativeBLAS`. For a transposed A the row/column counts are
   * swapped and the "T" flag is passed.
   */
  private def gemvImpl(
      alpha: Double,
      A: DenseMatrix,
      xValues: Array[Double],
      beta: Double,
      yValues: Array[Double]): Unit = {
    val transposed = A.isTransposed
    val trans = if (transposed) "T" else "N"
    val storedRows = if (transposed) A.numCols else A.numRows
    val storedCols = if (transposed) A.numRows else A.numCols
    nativeBLAS.dgemv(trans, storedRows, storedCols, alpha, A.values, storedRows, xValues, 1,
      beta, yValues, 1)
  }

  /**
   * y := alpha * A * x + beta * y
   * For `DenseMatrix` A and `SparseVector` x.
   *
   * Each output row is the sum over the stored entries of x only. The position of element
   * (row, col) inside `A.values` depends on `A.isTransposed`, which is expressed below as a
   * pair of strides so that a single loop serves both layouts.
   */
  private def gemv(
      alpha: Double,
      A: DenseMatrix,
      x: SparseVector,
      beta: Double,
      y: DenseVector): Unit = {
    val rowCount: Int = A.numRows
    val colCount: Int = A.numCols
    val data = A.values

    val activeIndices = x.indices
    val activeValues = x.values
    val active = activeIndices.length
    val out = y.values

    // Transposed: element (r, c) sits at c + r * numCols. Otherwise: c * numRows + r.
    val (rowStride, colStride) = if (A.isTransposed) (colCount, 1) else (1, rowCount)

    var row = 0
    while (row < rowCount) {
      val rowBase = row * rowStride
      var acc = 0.0
      var k = 0
      while (k < active) {
        acc += activeValues(k) * data(activeIndices(k) * colStride + rowBase)
        k += 1
      }
      out(row) = acc * alpha + beta * out(row)
      row += 1
    }
  }

  /**
   * y := alpha * A * x + beta * y
   * For `SparseMatrix` A and `SparseVector` x.
   *
   * For a transposed A every output row is a sparse-sparse dot product, computed by scanning
   * the stored entries of that row and of x in lock-step. Otherwise y is scaled by beta first
   * and the columns of A matching stored entries of x are scattered into it.
   */
  private def gemv(
      alpha: Double,
      A: SparseMatrix,
      x: SparseVector,
      beta: Double,
      y: DenseVector): Unit = {
    val activeValues = x.values
    val activeIndices = x.indices
    val active = activeIndices.length

    val out = y.values

    val rowCount: Int = A.numRows
    val colCount: Int = A.numCols

    val data = A.values
    // In both layouts `colPtrs` delimits the slices of `values` and `rowIndices` holds the
    // position inside a slice; only the meaning of a slice (row vs. column of A) changes.
    val slicePtrs = A.colPtrs
    val sliceIndices = A.rowIndices

    if (!A.isTransposed) {
      if (beta != 1.0) scal(beta, y)

      var col = 0
      var k = 0
      while (col < colCount && k < active) {
        if (activeIndices(k) == col) {
          var i = slicePtrs(col)
          val end = slicePtrs(col + 1)

          val scaledX = activeValues(k) * alpha
          while (i < end) {
            out(sliceIndices(i)) += data(i) * scaledX
            i += 1
          }
          k += 1
        }
        col += 1
      }
    } else {
      var row = 0
      while (row < rowCount) {
        var i = slicePtrs(row)
        val end = slicePtrs(row + 1)
        var acc = 0.0
        var k = 0
        while (i < end && k < active) {
          val xIndex = activeIndices(k)
          val aIndex = sliceIndices(i)
          if (xIndex < aIndex) {
            k += 1
          } else if (xIndex > aIndex) {
            i += 1
          } else {
            acc += data(i) * activeValues(k)
            k += 1
            i += 1
          }
        }
        out(row) = acc * alpha + beta * out(row)
        row += 1
      }
    }
  }

  /**
   * y := alpha * A * x + beta * y
   * For `SparseMatrix` A and `DenseVector` x. Unwraps both vectors and forwards to `gemvImpl`.
   */
  private def gemv(
      alpha: Double,
      A: SparseMatrix,
      x: DenseVector,
      beta: Double,
      y: DenseVector): Unit = {
    val input = x.values
    val output = y.values
    gemvImpl(alpha, A, input, beta, output)
  }

  /**
   * y[0: A.numRows] := alpha * A * x[0: A.numCols] + beta * y[0: A.numRows]
   * For `SparseMatrix` A.
   *
   * For a transposed A every output row is computed as one sum over the stored entries of
   * that row. Otherwise the first `A.numRows` entries of y are scaled by beta and each column
   * of A, weighted by the matching entry of x, is scattered into y.
   */
  private def gemvImpl(
      alpha: Double,
      A: SparseMatrix,
      xValues: Array[Double],
      beta: Double,
      yValues: Array[Double]): Unit = {
    val rowCount: Int = A.numRows
    val colCount: Int = A.numCols

    val data = A.values
    // In both layouts `colPtrs` delimits the slices of `values` and `rowIndices` holds the
    // position inside a slice; only the meaning of a slice (row vs. column of A) changes.
    val slicePtrs = A.colPtrs
    val sliceIndices = A.rowIndices

    if (!A.isTransposed) {
      if (beta != 1.0) getBLAS(rowCount).dscal(rowCount, beta, yValues, 1)
      // Accumulate alpha * x(col) * A(:, col) into y, one column at a time.
      var col = 0
      while (col < colCount) {
        var i = slicePtrs(col)
        val end = slicePtrs(col + 1)
        val scaledX = xValues(col) * alpha
        while (i < end) {
          yValues(sliceIndices(i)) += data(i) * scaledX
          i += 1
        }
        col += 1
      }
    } else {
      // Each slice is a row of A, so a row of the result is a single running sum.
      var row = 0
      while (row < rowCount) {
        var i = slicePtrs(row)
        val end = slicePtrs(row + 1)
        var acc = 0.0
        while (i < end) {
          acc += data(i) * xValues(sliceIndices(i))
          i += 1
        }
        yValues(row) = beta * yValues(row) + acc * alpha
        row += 1
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy