// org.apache.flink.ml.math.BLAS.scala (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.math
import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
/**
 * BLAS routines for vectors and matrices.
 *
 * The vector routines (`axpy`, `dot`, `scal`) are dispatched to the pure-Java
 * `F2jBLAS` implementation, while `syr` uses the instance returned by netlib's
 * `BLAS.getInstance`. All routines that write a result do so in place, directly
 * into the backing `data` array of the target vector or matrix.
 *
 * Original code from the Apache Spark project:
 * http://git.io/vfZUe
 */
object BLAS extends Serializable {

  // Both handles are created on first use and are @transient, so they are
  // never part of the serialized form and are simply re-created when needed.
  @transient private var _f2jBLAS: NetlibBLAS = _
  @transient private var _nativeBLAS: NetlibBLAS = _

  // For level-1 routines, we use Java implementation.
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  // For the matrix routine (syr), we use the instance handed out by netlib's getInstance.
  private def nativeBLAS: NetlibBLAS = {
    if (_nativeBLAS == null) {
      _nativeBLAS = NativeBLAS
    }
    _nativeBLAS
  }

  /**
   * y += a * x
   *
   * @param a scalar multiplier applied to `x`
   * @param x the vector to add; dense or sparse
   * @param y the vector updated in place; must be a [[DenseVector]]
   * @throws IllegalArgumentException if the sizes differ or `y` is not dense
   * @throws UnsupportedOperationException if `x` is neither dense nor sparse
   */
  def axpy(a: Double, x: Vector, y: Vector): Unit = {
    require(x.size == y.size,
      "BLAS.axpy(a: Double, x: Vector, y: Vector) was given Vectors with non-matching sizes:" +
        s" x.size = ${x.size}, y.size = ${y.size}")
    y match {
      case dy: DenseVector =>
        x match {
          case sx: SparseVector =>
            axpy(a, sx, dy)
          case dx: DenseVector =>
            axpy(a, dx, dy)
          case _ =>
            throw new UnsupportedOperationException(
              s"axpy doesn't support x type ${x.getClass}.")
        }
      case _ =>
        throw new IllegalArgumentException(
          s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
    }
  }

  /**
   * y += a * x (dense x, dense y), delegated to `daxpy` with unit strides.
   */
  private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = {
    val n = x.size
    f2jBLAS.daxpy(n, a, x.data, 1, y.data, 1)
  }

  /**
   * y += a * x (sparse x, dense y). Only the entries stored in `x` are visited.
   */
  private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
    val xValues = x.data
    val xIndices = x.indices
    val yValues = y.data
    val nnz = xIndices.length

    if (a == 1.0) {
      // Skip the multiplication entirely for the common a == 1 case.
      var k = 0
      while (k < nnz) {
        yValues(xIndices(k)) += xValues(k)
        k += 1
      }
    } else {
      var k = 0
      while (k < nnz) {
        yValues(xIndices(k)) += a * xValues(k)
        k += 1
      }
    }
  }

  /**
   * dot(x, y)
   *
   * @param x first operand; dense or sparse
   * @param y second operand; dense or sparse
   * @return the inner product of `x` and `y`
   * @throws IllegalArgumentException if the sizes differ or a vector type is unsupported
   */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
        s" x.size = ${x.size}, y.size = ${y.size}")
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        // The dot product is symmetric, so reuse the (sparse, dense) kernel.
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
   * dot(x, y) for two dense vectors, delegated to `ddot` with unit strides.
   */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.data, 1, y.data, 1)
  }

  /**
   * dot(x, y) for a sparse and a dense vector. Only the entries stored in `x` contribute.
   */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.data
    val xIndices = x.indices
    val yValues = y.data
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
   * dot(x, y) for two sparse vectors, computed by a single merge-style pass over both
   * index arrays.
   *
   * NOTE(review): the pass only ever moves forward, so it assumes both index arrays are
   * sorted in ascending order -- confirm against SparseVector's invariants.
   */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.data
    val xIndices = x.indices
    val yValues = y.data
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // y catching x
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      // Advance y's cursor until it reaches (or passes) the current index of x.
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }

  /**
   * y = x
   *
   * @param x the source vector; dense or sparse
   * @param y the destination, overwritten in place; must be a [[DenseVector]]
   * @throws IllegalArgumentException if the sizes differ, `y` is not dense, or `x` is
   *                                  neither dense nor sparse
   */
  def copy(x: Vector, y: Vector): Unit = {
    val n = y.size
    require(x.size == n,
      "BLAS.copy(x: Vector, y: Vector) was given Vectors with non-matching sizes:" +
        s" x.size = ${x.size}, y.size = $n")
    y match {
      case dy: DenseVector =>
        x match {
          case sx: SparseVector =>
            val sxIndices = sx.indices
            val sxValues = sx.data
            val dyValues = dy.data
            val nnz = sxIndices.length

            // Walk the destination once: zero the gaps between stored entries and write
            // each stored value at its index.
            // NOTE(review): relies on sx.indices being ascending -- confirm.
            var i = 0
            var k = 0
            while (k < nnz) {
              val j = sxIndices(k)
              while (i < j) {
                dyValues(i) = 0.0
                i += 1
              }
              dyValues(i) = sxValues(k)
              i += 1
              k += 1
            }
            // Zero everything after the last stored entry.
            while (i < n) {
              dyValues(i) = 0.0
              i += 1
            }
          case dx: DenseVector =>
            Array.copy(dx.data, 0, dy.data, 0, n)
          case _ =>
            // Fail with a descriptive error instead of a bare MatchError,
            // in line with the other routines in this object.
            throw new IllegalArgumentException(s"copy doesn't support x type ${x.getClass}.")
        }
      case _ =>
        throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}")
    }
  }

  /**
   * x = a * x
   *
   * For a sparse vector only the stored values are scaled.
   *
   * @param a the scaling factor
   * @param x the vector scaled in place; dense or sparse
   * @throws IllegalArgumentException if `x` is neither dense nor sparse
   */
  def scal(a: Double, x: Vector): Unit = {
    x match {
      case sx: SparseVector =>
        f2jBLAS.dscal(sx.data.length, a, sx.data, 1)
      case dx: DenseVector =>
        f2jBLAS.dscal(dx.data.length, a, dx.data, 1)
      case _ =>
        throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
    }
  }

  /**
   * A := alpha * x * x^T^ + A
   *
   * @param alpha a real scalar that will be multiplied to x * x^T^.
   * @param x the vector x that contains the n elements.
   * @param A the symmetric matrix A. Size of n x n. Updated in place.
   * @throws IllegalArgumentException if `A` is not square, if its dimension differs from
   *                                  the size of `x`, or if `x` is neither dense nor sparse
   */
  def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = {
    val mA = A.numRows
    val nA = A.numCols
    require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
    require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")

    x match {
      case dv: DenseVector => syr(alpha, dv, A)
      case sv: SparseVector => syr(alpha, sv, A)
      case _ =>
        throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
    }
  }

  /**
   * Dense rank-1 update: `dsyr` updates the upper triangle, which is then mirrored into
   * the lower triangle so that the whole of `A` stays symmetric.
   *
   * NOTE(review): passing A.data with numRows as the leading dimension presumes that
   * DenseMatrix stores its entries column-major -- confirm against DenseMatrix.
   */
  private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
    val numRows = A.numRows
    val numCols = A.numCols

    nativeBLAS.dsyr("U", x.size, alpha, x.data, 1, A.data, numRows)

    // Fill lower triangular part of A
    var i = 0
    while (i < numCols) {
      var j = i + 1
      while (j < numRows) {
        A(j, i) = A(i, j)
        j += 1
      }
      i += 1
    }
  }

  /**
   * Sparse rank-1 update: only the entries addressed by pairs of stored indices of `x`
   * are touched. Every ordered pair (i, j) is visited, so both triangles are updated.
   */
  private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
    // A is square when reached through the public syr, so this is also its row count.
    val n = A.numCols
    val xIndices = x.indices
    val xValues = x.data
    val nnz = xValues.length
    val aValues = A.data

    var i = 0
    while (i < nnz) {
      val multiplier = alpha * xValues(i)
      val offset = xIndices(i) * n
      var j = 0
      while (j < nnz) {
        aValues(xIndices(j) + offset) += multiplier * xValues(j)
        j += 1
      }
      i += 1
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy