All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.JCublasWrapper Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas;

import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;


/**
 * {@link org.nd4j.linalg.factory.BlasWrapper} implementation for JCUDA.
 *
 * <p>Every supported operation is a thin delegation to the static routines of
 * {@link SimpleJCublas}; this class only adapts the argument order of the
 * wrapper interface to the one {@code SimpleJCublas} expects. Operations that
 * have no implementation here throw {@link UnsupportedOperationException}
 * naming the missing routine.
 *
 * @author mjk
 * @author Adam Gibson
 */
public class JCublasWrapper implements org.nd4j.linalg.factory.BlasWrapper {

    /**
     * Builds the exception thrown by every operation this wrapper does not implement.
     *
     * @param operation name of the unsupported routine, used in the message
     * @return the exception to throw
     */
    private static UnsupportedOperationException unsupported(String operation) {
        return new UnsupportedOperationException(operation + " is not supported by JCublasWrapper");
    }

    /**
     * Delegates to {@code SimpleJCublas.swap(x, y)}.
     *
     * @param x first array
     * @param y second array
     * @return {@code y}
     */
    @Override
    public INDArray swap(INDArray x, INDArray y) {
        SimpleJCublas.swap(x, y);
        return y;
    }

    /**
     * Delegates to {@code SimpleJCublas.scal(alpha, x)}.
     *
     * @param alpha the scalar
     * @param x     the array to scale
     * @return whatever {@code SimpleJCublas.scal} returns
     */
    @Override
    public INDArray scal(double alpha, INDArray x) {
        return SimpleJCublas.scal(alpha, x);
    }

    /**
     * Delegates to {@code SimpleJCublas.scal(alpha, x)}.
     *
     * @param alpha the scalar
     * @param x     the array to scale
     * @return {@code x}
     */
    @Override
    public INDArray scal(float alpha, INDArray x) {
        SimpleJCublas.scal(alpha, x);
        return x;
    }

    /**
     * Delegates to {@code SimpleJCublas.scal(alpha, x)}.
     *
     * @param alpha the complex scalar
     * @param x     the complex array to scale
     * @return whatever {@code SimpleJCublas.scal} returns
     */
    @Override
    public IComplexNDArray scal(IComplexFloat alpha, IComplexNDArray x) {
        return SimpleJCublas.scal(alpha, x);
    }

    /**
     * Delegates to {@code SimpleJCublas.scal(alpha, x)}.
     *
     * @param alpha the complex scalar
     * @param x     the complex array to scale
     * @return whatever {@code SimpleJCublas.scal} returns
     */
    @Override
    public IComplexNDArray scal(IComplexDouble alpha, IComplexNDArray x) {
        return SimpleJCublas.scal(alpha, x);
    }

    /**
     * Delegates to {@code SimpleJCublas.copy(x, y)}.
     *
     * @param x first argument passed to the copy routine
     * @param y second argument passed to the copy routine
     * @return {@code y}
     */
    @Override
    public INDArray copy(INDArray x, INDArray y) {
        SimpleJCublas.copy(x, y);
        return y;
    }

    /**
     * Delegates to {@code SimpleJCublas.copy(x, y)}.
     *
     * @param x first argument passed to the copy routine
     * @param y second argument passed to the copy routine
     * @return {@code y}
     */
    @Override
    public IComplexNDArray copy(IComplexNDArray x, IComplexNDArray y) {
        SimpleJCublas.copy(x, y);
        return y;
    }

    /**
     * Delegates to {@code SimpleJCublas.axpy(da, dx, dy)}.
     *
     * @param da the scalar
     * @param dx the first array
     * @param dy the second array
     * @return {@code dy}
     */
    @Override
    public INDArray axpy(double da, INDArray dx, INDArray dy) {
        SimpleJCublas.axpy(da, dx, dy);
        return dy;
    }

    /**
     * Delegates to {@code SimpleJCublas.axpy(da, dx, dy)}.
     *
     * @param da the scalar
     * @param dx the first array
     * @param dy the second array
     * @return {@code dy}
     */
    @Override
    public INDArray axpy(float da, INDArray dx, INDArray dy) {
        SimpleJCublas.axpy(da, dx, dy);
        return dy;
    }

    /**
     * Delegates to the {@code SimpleJCublas.axpy} overload matching the runtime
     * type of {@code da}: the {@link IComplexDouble} overload when {@code da} is
     * an {@code IComplexDouble}, otherwise the {@link IComplexFloat} overload
     * (via a cast).
     *
     * @param da the complex scalar
     * @param dx the first array
     * @param dy the second array
     * @return {@code dy}
     */
    @Override
    public IComplexNDArray axpy(IComplexNumber da, IComplexNDArray dx, IComplexNDArray dy) {
        if (da instanceof IComplexDouble) {
            SimpleJCublas.axpy((IComplexDouble) da, dx, dy);
        } else {
            SimpleJCublas.axpy((IComplexFloat) da, dx, dy);
        }

        return dy;
    }

    /**
     * Delegates to {@code SimpleJCublas.dot(x, y)}.
     *
     * @param x first array
     * @param y second array
     * @return the value returned by {@code SimpleJCublas.dot}
     */
    public double dot(INDArray x, INDArray y) {
        return SimpleJCublas.dot(x, y);
    }

    /**
     * Delegates to {@code SimpleJCublas.dot(x, y)}; identical to {@link #dot(INDArray, INDArray)}.
     *
     * @param x first array
     * @param y second array
     * @return the value returned by {@code SimpleJCublas.dot}
     */
    // NOTE(review): @Override was commented out in the original; presumably the
    // interface does not declare this method - confirm before re-enabling.
    //@Override
    public double dotd(INDArray x, INDArray y) {
        return SimpleJCublas.dot(x, y);
    }

    /**
     * Delegates to {@code SimpleJCublas.dot(x, y)} for complex arrays.
     *
     * @param x first complex array
     * @param y second complex array
     * @return the value returned by {@code SimpleJCublas.dot}
     */
    @Override
    public IComplexNumber dotc(IComplexNDArray x, IComplexNDArray y) {
        return SimpleJCublas.dot(x, y);
    }

    /**
     * Delegates to {@code SimpleJCublas.dotu(x, y)}.
     *
     * @param x first complex array
     * @param y second complex array
     * @return the value returned by {@code SimpleJCublas.dotu}
     */
    @Override
    public IComplexNumber dotu(IComplexNDArray x, IComplexNDArray y) {
        return SimpleJCublas.dotu(x, y);
    }

    /**
     * Delegates to {@code SimpleJCublas.nrm2(x)}.
     *
     * @param x the array
     * @return the value returned by {@code SimpleJCublas.nrm2}
     */
    @Override
    public double nrm2(INDArray x) {
        return SimpleJCublas.nrm2(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.nrm2(x)}.
     *
     * @param x the complex array
     * @return the value returned by {@code SimpleJCublas.nrm2}
     */
    @Override
    public double nrm2(IComplexNDArray x) {
        return SimpleJCublas.nrm2(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.asum(x)}.
     *
     * @param x the array
     * @return the value returned by {@code SimpleJCublas.asum}
     */
    @Override
    public double asum(INDArray x) {
        return SimpleJCublas.asum(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.asum(x)}.
     *
     * @param x the complex array
     * @return the value returned by {@code SimpleJCublas.asum}
     */
    @Override
    public double asum(IComplexNDArray x) {
        return SimpleJCublas.asum(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.iamax(x)}.
     *
     * @param x the array
     * @return the value returned by {@code SimpleJCublas.iamax}
     */
    @Override
    public int iamax(INDArray x) {
        return SimpleJCublas.iamax(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.iamax(x)}.
     *
     * @param x the complex array
     * @return the value returned by {@code SimpleJCublas.iamax}
     */
    @Override
    public int iamax(IComplexNDArray x) {
        return SimpleJCublas.iamax(x);
    }

    /**
     * Delegates to {@code SimpleJCublas.gemv(a, x, y, alpha, beta)}.
     *
     * @param alpha scalar passed as the fourth argument
     * @param a     the matrix
     * @param x     the input vector
     * @param beta  scalar passed as the fifth argument
     * @param y     the output vector
     * @return {@code y}
     */
    @Override
    public INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y) {
        SimpleJCublas.gemv(a, x, y, alpha, beta);
        return y;
    }

    /**
     * Delegates to {@code SimpleJCublas.gemv(a, x, y, alpha, beta)}.
     *
     * @param alpha scalar passed as the fourth argument
     * @param a     the matrix
     * @param x     the input vector
     * @param beta  scalar passed as the fifth argument
     * @param y     the output vector
     * @return whatever {@code SimpleJCublas.gemv} returns
     */
    @Override
    public INDArray gemv(float alpha, INDArray a, INDArray x, float beta, INDArray y) {
        return SimpleJCublas.gemv(a, x, y, alpha, beta);
    }

    /**
     * Delegates to {@code SimpleJCublas.ger(x, y, a, alpha)}.
     *
     * @param alpha the scalar
     * @param x     first vector
     * @param y     second vector
     * @param a     the matrix
     * @return whatever {@code SimpleJCublas.ger} returns
     */
    @Override
    public INDArray ger(double alpha, INDArray x, INDArray y, INDArray a) {
        return SimpleJCublas.ger(x, y, a, alpha);
    }

    /**
     * Delegates to {@code SimpleJCublas.ger(x, y, a, alpha)}.
     *
     * @param alpha the scalar
     * @param x     first vector
     * @param y     second vector
     * @param a     the matrix
     * @return whatever {@code SimpleJCublas.ger} returns
     */
    @Override
    public INDArray ger(float alpha, INDArray x, INDArray y, INDArray a) {
        return SimpleJCublas.ger(x, y, a, alpha);
    }

    /**
     * Delegates to {@code SimpleJCublas.gemv(a, x, alpha, y, beta)}.
     *
     * @param alpha complex scalar passed as the third argument
     * @param a     the complex matrix
     * @param x     the complex input vector
     * @param beta  complex scalar passed as the fifth argument
     * @param y     the complex output vector
     * @return whatever {@code SimpleJCublas.gemv} returns
     */
    @Override
    public IComplexNDArray gemv(IComplexDouble alpha, IComplexNDArray a, IComplexNDArray x, IComplexDouble beta, IComplexNDArray y) {
        return SimpleJCublas.gemv(a, x, alpha, y, beta);
    }

    /**
     * Delegates to {@code SimpleJCublas.gemv(a, x, alpha, y, beta)}.
     *
     * @param alpha complex scalar passed as the third argument
     * @param a     the complex matrix
     * @param x     the complex input vector
     * @param beta  complex scalar passed as the fifth argument
     * @param y     the complex output vector
     * @return whatever {@code SimpleJCublas.gemv} returns
     */
    @Override
    public IComplexNDArray gemv(IComplexFloat alpha, IComplexNDArray a, IComplexNDArray x, IComplexFloat beta, IComplexNDArray y) {
        return SimpleJCublas.gemv(a, x, alpha, y, beta);
    }

    /**
     * Not implemented for an {@link IComplexDouble} scalar.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public IComplexNDArray geru(IComplexDouble alpha, IComplexNDArray x, IComplexNDArray y, IComplexNDArray a) {
        throw unsupported("geru(IComplexDouble, ...)");
    }

    /**
     * Delegates to {@code SimpleJCublas.geru(x, y, a, alpha.asDouble())}; the
     * scalar is converted with {@code asDouble()} before the call.
     *
     * @param alpha the complex scalar
     * @param x     first complex vector
     * @param y     second complex vector
     * @param a     the complex matrix
     * @return whatever {@code SimpleJCublas.geru} returns
     */
    @Override
    public IComplexNDArray geru(IComplexFloat alpha, IComplexNDArray x, IComplexNDArray y, IComplexNDArray a) {
        return SimpleJCublas.geru(x, y, a, alpha.asDouble());
    }

    /**
     * Delegates to {@code SimpleJCublas.gerc(x, y, a, alpha.asDouble())}; the
     * scalar is converted with {@code asDouble()} before the call.
     *
     * @param alpha the complex scalar
     * @param x     first complex vector
     * @param y     second complex vector
     * @param a     the complex matrix
     * @return whatever {@code SimpleJCublas.gerc} returns
     */
    @Override
    public IComplexNDArray gerc(IComplexFloat alpha, IComplexNDArray x, IComplexNDArray y, IComplexNDArray a) {
        return SimpleJCublas.gerc(x, y, a, alpha.asDouble());
    }

    /**
     * Not implemented for an {@link IComplexDouble} scalar.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public IComplexNDArray gerc(IComplexDouble alpha, IComplexNDArray x, IComplexNDArray y, IComplexNDArray a) {
        throw unsupported("gerc(IComplexDouble, ...)");
    }

    /**
     * Delegates to {@code SimpleJCublas.gemm(a, b, c, alpha, beta)}.
     *
     * @param alpha scalar passed as the fourth argument
     * @param a     first matrix
     * @param b     second matrix
     * @param beta  scalar passed as the fifth argument
     * @param c     result matrix
     * @return whatever {@code SimpleJCublas.gemm} returns
     */
    @Override
    public INDArray gemm(double alpha, INDArray a, INDArray b, double beta, INDArray c) {
        return SimpleJCublas.gemm(a, b, c, alpha, beta);
    }

    /**
     * Delegates to {@code SimpleJCublas.gemm(a, b, c, alpha, beta)}.
     *
     * @param alpha scalar passed as the fourth argument
     * @param a     first matrix
     * @param b     second matrix
     * @param beta  scalar passed as the fifth argument
     * @param c     result matrix
     * @return whatever {@code SimpleJCublas.gemm} returns
     */
    @Override
    public INDArray gemm(float alpha, INDArray a, INDArray b, float beta, INDArray c) {
        return SimpleJCublas.gemm(a, b, c, alpha, beta);
    }

    /**
     * Delegates to {@code SimpleJCublas.gemm(a, b, alpha, c, beta)}. The
     * precision is chosen from the runtime type of {@code beta} alone: when it
     * is an {@link IComplexDouble} both scalars are converted with
     * {@code asDouble()}, otherwise both are converted with {@code asFloat()}.
     *
     * @param alpha first complex scalar
     * @param a     first complex matrix
     * @param b     second complex matrix
     * @param beta  second complex scalar; its runtime type selects the precision
     * @param c     result matrix
     * @return {@code c}
     */
    @Override
    public IComplexNDArray gemm(IComplexNumber alpha, IComplexNDArray a, IComplexNDArray b, IComplexNumber beta, IComplexNDArray c) {
        if (beta instanceof IComplexDouble) {
            SimpleJCublas.gemm(a, b, alpha.asDouble(), c, beta.asDouble());
        } else {
            SimpleJCublas.gemm(a, b, alpha.asFloat(), c, beta.asFloat());
        }

        return c;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray gesv(INDArray a, int[] ipiv, INDArray b) {
        throw unsupported("gesv");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void checkInfo(String name, int info) {
        throw unsupported("checkInfo");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray sysv(char uplo, INDArray a, int[] ipiv, INDArray b) {
        throw unsupported("sysv");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int syev(char jobz, char uplo, INDArray a, INDArray w) {
        // Fail loudly like the other missing routines instead of returning 0
        // without having touched a or w.
        throw unsupported("syev");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int syevx(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol, INDArray w, INDArray z) {
        throw unsupported("syevx");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    // NOTE(review): @Override was commented out in the original; presumably the
    // interface does not declare this overload - confirm before re-enabling.
    //@Override
    public int syevx(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol, INDArray w, INDArray z) {
        throw unsupported("syevx");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int syevd(char jobz, char uplo, INDArray A, INDArray w) {
        // Fail loudly like the other missing routines instead of returning 0
        // without having touched A or w.
        throw unsupported("syevd");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int syevr(char jobz, char range, char uplo, INDArray a, float vl, float vu, int il, int iu, float abstol, INDArray w, INDArray z, int[] isuppz) {
        throw unsupported("syevr");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    // NOTE(review): @Override was commented out in the original; presumably the
    // interface does not declare this overload - confirm before re-enabling.
    //@Override
    public int syevr(char jobz, char range, char uplo, INDArray a, double vl, double vu, int il, int iu, double abstol, INDArray w, INDArray z, int[] isuppz) {
        throw unsupported("syevr");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void posv(char uplo, INDArray A, INDArray B) {
        throw unsupported("posv");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int geev(char jobvl, char jobvr, INDArray A, INDArray WR, INDArray WI, INDArray VL, INDArray VR) {
        throw unsupported("geev");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int sygvd(int itype, char jobz, char uplo, INDArray A, INDArray B, INDArray W) {
        throw unsupported("sygvd");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void gelsd(INDArray A, INDArray B) {
        throw unsupported("gelsd");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void geqrf(INDArray A, INDArray tau) {
        throw unsupported("geqrf");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void ormqr(char side, char trans, INDArray A, INDArray tau, INDArray C) {
        throw unsupported("ormqr");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void dcopy(int n, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
        throw unsupported("dcopy");
    }

    /**
     * Delegates to {@code SimpleJCublas.axpy(alpha, x, y)}.
     *
     * @param alpha the alpha to scale by
     * @param x     the ndarray to use
     * @param y     the ndarray to use
     */
    @Override
    public void saxpy(double alpha, INDArray x, INDArray y) {
        SimpleJCublas.axpy(alpha, x, y);
    }

    /**
     * Abstraction over saxpy; delegates to {@code SimpleJCublas.saxpy(alpha, x, y)}.
     *
     * @param alpha the alpha to scale by
     * @param x     the ndarray to use
     * @param y     the ndarray to use
     */
    @Override
    public void saxpy(float alpha, INDArray x, INDArray y) {
        SimpleJCublas.saxpy(alpha, x, y);
    }


    /*
    Missing functions (each throws UnsupportedOperationException):
        gesv
        sysv
        syev
        syevx
        syevd
        syevr
        posv
        geev
        sygvd
        gelsd
        geqrf
        ormqr

     */
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy