All downloads are free. Search and download functionality uses the official Maven repository.

com.github.romualdrousseau.shuju.math.Blas Maven / Gradle / Ivy

Go to download

Collection of various algorithms implemented in Java covering domains such as mathematics and data science.

There is a newer version: 1.29.7
Show newest version
package com.github.romualdrousseau.shuju.math;

import java.util.stream.IntStream;

public class Blas {

    private static int ncpu = 1;
    static {
        Blas.ncpu = Runtime.getRuntime().availableProcessors();
    }

    public static void fgemv(final boolean tranA, final float[][] a, final float alpha, final float[] b,
            final float beta, final float[] c) {
        if (alpha == 0.0f && beta == 1.0f) {
            return;
        }

        final int M = tranA ? a[0].length : a.length;
        final int K = b.length;
        final int K_ = tranA ? a.length : a[0].length;

        // Check dimensionalities

        assert (K == K_) : "Illegal Dimension";
        assert (c.length == M) : "Illegal Dimension";

        // bC = C

        if (beta == 0.0f) {
            for (int i = 0; i < M; i++) {
                c[i] = 0.0f;
            }
        } else {
            for (int i = 0; i < M; i++) {
                c[i] *= beta;
            }
        }

        if (alpha == 0.0f) {
            return;
        }

        if (tranA) {

            // aA*@B + bC = C

            for (int j = 0; j < K; j++) {
                final float[] a_j = a[j];
                final float b_j = alpha * b[j];
                for (int i = 0; i < M; i++) {
                    c[i] += a_j[i] * b_j;
                }
            }
        } else {

            // aA@B + bC = C

            for (int i = 0; i < M; i++) {
                final float[] a_i = a[i];
                for (int j = 0; j < K; j++) {
                    c[i] += alpha * a_i[j] * b[j];
                }
            }
        }
    }

    public static void fgemm(final boolean tranA, final boolean tranB, final float[][] a, final float alpha,
            final float[][] b, final float beta, final float[][] c) {
        if (alpha == 0.0f && beta == 1.0f) {
            return;
        }

        // bC = C

        if (beta == 0.0f) {
            for (int i = 0; i < c.length; i++) {
                final float[] c_i = c[i];
                for (int j = 0; j < c_i.length; j++) {
                    c_i[j] = 0.0f;
                }
            }
        } else if (beta != 1.0f) {
            for (int i = 0; i < c.length; i++) {
                final float[] c_i = c[i];
                for (int j = 0; j < c_i.length; j++) {
                    c_i[j] *= beta;
                }
            }
        }

        if (alpha == 0.0f) {
            return;
        }

        final int M = tranA ? a[0].length : a.length;
        final int K = tranA ? a.length : a[0].length;

        if (b == null) {
            assert (c.length == M && c[0].length == K) : "Illegal Dimension";

            if (tranA) {

                // aA* + bC = C

                for (int j = 0; j < K; j++) {
                    final float[] a_j = a[j];
                    for (int i = 0; i < M; i++) {
                        c[i][j] += alpha * a_j[i];
                    }
                }
            } else {

                // aA + bC = C

                for (int i = 0; i < M; i++) {
                    final float[] a_i = a[i];
                    final float[] c_i = c[i];
                    for (int j = 0; j < K; j++) {
                        c_i[j] += alpha * a_i[j];
                    }
                }
            }
        } else {
            final int N = tranB ? b.length : b[0].length;
            final int K_ = tranB ? b[0].length : b.length;

            // Check dimensionalities

            assert (K == K_) : "Illegal Dimension";
            assert (c.length == M && c[0].length == N) : "Illegal Dimension";

            final int count = Math.min(1 + (M * N) / (64 * 64), Blas.ncpu);

            if (count <= 1) {

                // Single core calculation

                Blas.fgemm_kernel_cpu(tranA, tranB, M, N, K, a, 0, alpha, b, 0, beta, c, 0);
            } else {
                final int stride = M / count + 1;

                // Muti core calculation

                IntStream.rangeClosed(0, count).map(i -> i * stride).parallel().forEach(i -> {
                    final int remaining = Math.min(M - i, stride);
                    Blas.fgemm_kernel_cpu(tranA, tranB, remaining, N, K, a, i, alpha, b, 0, beta, c, i);
                });
            }
        }
    }

    private static void fgemm_kernel_cpu(final boolean tranA, final boolean tranB, final int M, final int N, final int K,
            final float[][] a, int offA, final float alpha, final float[][] b, int offB, final float beta,
            final float[][] c, int offC) {

        if (tranA && tranB) {
            final float[] cT_i = new float[N];

            // aA*@B* = C, but faster to compute a(B@A)* = C

            for (int i = 0; i < M; i++) {

                // Transpose one column of C_i => C_i*

                for (int j = 0; j < N; j++) {
                    cT_i[j] = c[j][offC + i];
                }

                // Compute one line B_i@A => C_i*

                final float[] b_i = b[offA + i];
                for (int k = 0; k < K; k++) {
                    final float b_ik = alpha * b_i[k];
                    final float[] a_k = a[offB + k];
                    for (int j = 0; j < N; j++) {
                        cT_i[j] += a_k[j] * b_ik;
                    }
                }

                // Transpose back one row of C_i* => C_i

                for (int j = 0; j < N; j++) {
                    c[j][offC + i] = cT_i[j] + c[j][offC + i];
                }
            }
        } else if (tranA && !tranB) {
            final float[] aT_i = new float[K];

            // aA@B* = C

            for (int i = 0; i < M; i++) {

                // Transpose one column of A_i => A_i*

                for (int k = 0; k < K; k++) {
                    aT_i[k] = a[k][offA + i];
                }

                // Compute one line A_i*@B = C_i

                final float[] c_i = c[offC + i];
                for (int k = 0; k < K; k++) {
                    final float aT_ik = alpha * aT_i[k];
                    final float[] b_k = b[offB + k];
                    for (int j = 0; j < N; j++) {
                        c_i[j] += aT_ik * b_k[j];
                    }
                }
            }
        } else if (!tranA && tranB) {
            final float[] bT_k = new float[N];

            // aA*@B = C

            for (int k = 0; k < K; k++) {

                // Transpose one column of B_i => B_i*

                for (int j = 0; j < N; j++) {
                    bT_k[j] = b[j][offB + k];
                }

                // Compute one line A@B_i* = C_i

                for (int i = 0; i < M; i++) {
                    final float a_ik = alpha * a[offA + i][k];
                    final float[] c_i = c[offC + i];
                    for (int j = 0; j < N; j++) {
                        c_i[j] += a_ik * bT_k[j];
                    }
                }
            }
        } else {

            // aA@B = C

            for (int i = 0; i < M; i++) {

                // Compute one line A_i@B = C_i

                final float[] a_i = a[offA + i];
                final float[] c_i = c[offC + i];
                for (int k = 0; k < K; k++) {
                    final float a_ik = alpha * a_i[k];
                    final float[] b_k = b[offB + k];
                    for (int j = 0; j < N; j++) {
                        c_i[j] += a_ik * b_k[j];
                    }
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy