jacobi.core.op.Mul Maven / Gradle / Ivy
/*
* The MIT License
*
* Copyright 2017 Y.K. Chan
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package jacobi.core.op;
import jacobi.api.Matrices;
import jacobi.api.Matrix;
import jacobi.api.annotations.Immutate;
import jacobi.core.impl.ColumnVector;
import jacobi.core.util.ParallelSupplier;
import jacobi.core.util.Throw;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
/**
 * Matrix Multiplication operator, i.e. computes C = A * B.
 *
 * <p>Each row of C is computed as u * B, where u is the corresponding row of A. The work on a row
 * is partitioned into blocks of {@code stride} rows of B by {@code stride} columns of B, intended
 * for cache utilization. Products that are large enough are distributed across threads via
 * {@link ParallelSupplier}.</p>
 *
 * @author Y.K. Chan
 */
@Immutate
public class Mul {

    /**
     * Default partition size for cache utilization.
     */
    public static final int DEFAULT_STRIDE_LENGTH = 8;

    /**
     * Constructor using {@link #DEFAULT_STRIDE_LENGTH} as the partition size.
     */
    public Mul() {
        this(DEFAULT_STRIDE_LENGTH);
    }

    /**
     * Constructor.
     * @param stride Partition size, must be positive
     * @throws IllegalArgumentException if stride is less than 1
     */
    public Mul(int stride) {
        // A non-positive stride would later divide by zero (stride == 0) or make the block loops in
        // computeRow step backwards (stride < 0), so reject it up front.
        if(stride < 1){
            throw new IllegalArgumentException("Stride must be positive, found " + stride);
        }
        this.stride = stride;
        this.mulT = new MulT();
    }

    /**
     * Compute matrix C where C = A * B.
     *
     * <p>Fails via {@code Throw} if either operand is null or if the column count of A differs
     * from the row count of B. If B is a {@link ColumnVector}, a specialized path is taken: a
     * scalar (dot product) when A has a single row, otherwise a matrix-vector product.</p>
     *
     * @param a Input matrix A
     * @param b Input matrix B
     * @return Resultant matrix C
     */
    public Matrix compute(Matrix a, Matrix b) {
        Throw.when()
            .isNull(() -> a, () -> "First operand is missing.")
            .isNull(() -> b, () -> "Second operand is missing.")
            .isTrue(
                () -> a.getColCount() != b.getRowCount(),
                () -> "Dimension mismatch. Unable to multiply a "
                    + a.getRowCount()+ "x" + a.getColCount()
                    + " matrix with a "
                    + b.getRowCount()+ "x" + b.getColCount()
                    + " matrix.");
        if(b instanceof ColumnVector){
            ColumnVector vector = (ColumnVector) b;
            return a.getRowCount() == 1
                ? Matrices.scalar(this.dot(a.getRow(0), vector.getVector()))
                : this.mulVector(a, vector);
        }
        Matrix ans = Matrices.zeros(a.getRowCount(), b.getColCount());
        // Multiply against a row-wrapping view of B (see copy) rather than B itself.
        this.compute(a, this.copy(b), ans);
        return ans;
    }

    /**
     * Fill the entries of resultant matrix C where C = A * B.
     *
     * <p>Runs serially when A has fewer rows than the stride or the estimated number of
     * multiply-adds is below {@link ParallelSupplier#DEFAULT_FLOP_THRESHOLD}; otherwise in
     * parallel.</p>
     *
     * @param a Input matrix A
     * @param b Input matrix B
     * @param ans Resultant matrix C, expected to be zero-filled by the caller
     */
    protected void compute(Matrix a, Matrix b, Matrix ans) {
        // Widen to long before multiplying so the estimate cannot overflow int.
        long numFlop = ((long) ans.getRowCount() * ans.getColCount()) * a.getColCount();
        if(a.getRowCount() < this.stride || numFlop < ParallelSupplier.DEFAULT_FLOP_THRESHOLD){
            this.serial(a, b, ans);
        }else{
            this.parallel(a, b, ans);
        }
    }

    /**
     * Fill the entries of resultant matrix C where C = A * B in serial.
     * @param a Input matrix A
     * @param b Input matrix B
     * @param ans Resultant matrix C
     */
    protected void serial(Matrix a, Matrix b, Matrix ans) {
        for(int i = 0; i < ans.getRowCount(); i++){
            double[] u = a.getRow(i);
            double[] v = ans.getRow(i);
            this.computeRow(u, b, v);
            // Write back explicitly in case getRow handed out a copy rather than the backing row.
            ans.setRow(i, v);
        }
    }

    /**
     * Fill the entries of resultant matrix C where C = A * B in parallel.
     *
     * <p>Rows of C are distributed across threads by {@link ParallelSupplier#cyclic}.</p>
     *
     * @param a Input matrix A
     * @param b Input matrix B
     * @param ans Resultant matrix C
     */
    protected void parallel(Matrix a, Matrix b, Matrix ans) {
        // NOTE(review): the thread count is capped by the stride as well as the default pool size;
        // confirm this coupling is intentional.
        int numThreads = Math.min(this.stride, ParallelSupplier.DEFAULT_NUM_THREADS);
        // Presumably getAndSet hands row i of ans to the function and stores the result back.
        IntConsumer task = (i) -> ans.getAndSet(i, (r) -> this.computeRow(a.getRow(i), b, r));
        ParallelSupplier.cyclic(task, 0, a.getRowCount(), numThreads);
    }

    /**
     * Compute v = v + u * B, where u and v are vectors and B is a matrix, block by block.
     *
     * <p>Blocks are at most {@code stride} rows by {@code stride} columns; the last block in each
     * direction is truncated to the vector length.</p>
     *
     * @param u Input vector u, of length equal to the row count of B
     * @param b Input matrix B
     * @param v Output vector v, of length equal to the column count of B, accumulated in place
     */
    protected void computeRow(double[] u, Matrix b, double[] v) {
        for(int i = 0; i < u.length; i += this.stride){
            // Written as i + min(...) rather than min(i + stride, ...) to avoid int overflow.
            int rowEnd = i + Math.min(this.stride, u.length - i);
            for(int j = 0; j < v.length; j += this.stride){
                int colEnd = j + Math.min(this.stride, v.length - j);
                this.computeBlock(u, b, i, rowEnd, j, colEnd, v);
            }
        }
    }

    /**
     * Compute v[p:q] = v[p:q] + u[a:b] * B[a:b, p:q], where B[a:b, p:q] is a selected sub-matrix of B,
     * with a = rowBegin, b = rowEnd, p = colBegin and q = colEnd (ends exclusive).
     * @param u Input vector u
     * @param b Input matrix B
     * @param rowBegin Begin of rows selected
     * @param rowEnd End of rows selected
     * @param colBegin Begin of columns selected
     * @param colEnd End of columns selected
     * @param v Output vector v
     */
    protected void computeBlock(double[] u, Matrix b, int rowBegin, int rowEnd, int colBegin, int colEnd, double[] v) {
        for(int i = rowBegin; i < rowEnd; i++){
            double elem = u[i];
            // Fetch the row once per row, not once per column: it does not depend on j.
            double[] row = b.getRow(i);
            for(int j = colBegin; j < colEnd; j++){
                v[j] += elem * row[j];
            }
        }
    }

    /**
     * Compute A * b where b is a column vector.
     *
     * <p>Delegates to {@code MulT} with b wrapped as a single-row matrix; this assumes
     * {@code MulT} computes A * B^T.</p>
     *
     * @param a Input matrix A
     * @param b Input column vector b
     * @return A * b
     */
    protected Matrix mulVector(Matrix a, ColumnVector b) {
        return this.mulT.compute(a, Matrices.wrap(new double[][]{ b.getVector() }));
    }

    /**
     * Compute the dot product of vector u and v.
     * @param u Input vector u
     * @param v Input vector v, at least as long as u
     * @return Dot product
     */
    protected double dot(double[] u, double[] v) {
        double ans = 0.0;
        for(int i = 0; i < u.length; i++){
            ans += u[i] * v[i];
        }
        return ans;
    }

    /**
     * Shallow copy of a matrix: the arrays returned by {@code getRow} are wrapped into a new matrix
     * without copying their elements. Whether those arrays alias the original matrix depends on
     * its {@code getRow} implementation.
     * The copy of the matrix should not be mutated.
     * @param mat Matrix to be copied
     * @return Shallow copy of a matrix
     */
    protected Matrix copy(Matrix mat) {
        return Matrices.wrap(IntStream.range(0, mat.getRowCount())
            .mapToObj((i) -> mat.getRow(i))
            .toArray((n) -> new double[n][]));
    }

    /** Block size used when partitioning a row computation; always positive. */
    private final int stride;
    /** Delegate used for matrix-vector products. */
    private final MulT mulT;
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy