All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.util.PointerUtil Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.util;

import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.driver.CUdeviceptr;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DoubleBuffer;
import org.nd4j.linalg.api.ops.ScalarOp;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

import static jcuda.driver.JCudaDriver.cuMemAlloc;

/**
 * Static helpers used by the jcuda executioner: converting op arguments to
 * primitive arrays, wrapping complex numbers in JCuda {@link Pointer}s,
 * computing kernel launch dimensions, and allocating device memory.
 *
 * @author Adam Gibson
 */
public class PointerUtil {


    /**
     * Converts an array of arbitrary objects to primitive doubles by parsing each
     * element's {@code toString()} representation.
     *
     * <p>Parsing the string form (rather than widening via {@code Number#doubleValue()})
     * keeps a value such as {@code 0.1f} as {@code 0.1} instead of
     * {@code 0.10000000149011612}.
     *
     * @param extraArgs the values to convert; {@code null} is treated as an empty array
     * @return a new array of the same length holding the parsed values
     * @throws NumberFormatException if an element's string form is not a valid double
     * @throws NullPointerException  if any element is {@code null}
     */
    public static double[] toDoubles(Object[] extraArgs) {
        if (extraArgs == null) {
            return new double[0];
        }
        double[] ret = new double[extraArgs.length];
        for (int i = 0; i < extraArgs.length; i++) {
            ret[i] = Double.parseDouble(extraArgs[i].toString());
        }
        return ret;
    }

    /**
     * Gets a pointer to a single double-precision complex number.
     *
     * <p>The real part ({@code x.x}) and imaginary part ({@code x.y}) are copied, in
     * that order, into a newly allocated direct buffer in native byte order. The
     * returned pointer refers to that copy, not to {@code x} itself.
     *
     * @param x the number to get the pointer for
     * @return a pointer to a two-element double buffer {@code [x.x, x.y]}
     */
    public static Pointer getPointer(cuDoubleComplex x) {
        ByteBuffer bytes = ByteBuffer.allocateDirect(Sizeof.DOUBLE * 2);
        bytes.order(ByteOrder.nativeOrder());
        // Fully qualified: the simple name DoubleBuffer is imported from nd4j in this file.
        java.nio.DoubleBuffer values = bytes.asDoubleBuffer();
        values.put(0, x.x);
        values.put(1, x.y);
        return Pointer.to(values);
    }

    /**
     * Gets a pointer to a single single-precision complex number.
     *
     * <p>The real part ({@code x.x}) and imaginary part ({@code x.y}) are copied, in
     * that order, into a newly allocated direct buffer in native byte order. The
     * returned pointer refers to that copy, not to {@code x} itself.
     *
     * @param x the number to get the pointer for
     * @return a pointer to a two-element float buffer {@code [x.x, x.y]}
     */
    public static Pointer getPointer(cuComplex x) {
        ByteBuffer bytes = ByteBuffer.allocateDirect(Sizeof.FLOAT * 2);
        bytes.order(ByteOrder.nativeOrder());
        FloatBuffer values = bytes.asFloatBuffer();
        values.put(0, x.x);
        values.put(1, x.y);
        return Pointer.to(values);
    }


    /**
     * Converts an array of arbitrary objects to primitive floats by parsing each
     * element's {@code toString()} representation.
     *
     * @param extraArgs the values to convert; {@code null} is treated as an empty array
     * @return a new array of the same length holding the parsed values
     * @throws NumberFormatException if an element's string form is not a valid float
     * @throws NullPointerException  if any element is {@code null}
     */
    public static float[] toFloats(Object[] extraArgs) {
        if (extraArgs == null) {
            return new float[0];
        }
        float[] ret = new float[extraArgs.length];
        for (int i = 0; i < extraArgs.length; i++) {
            ret[i] = Float.parseFloat(extraArgs[i].toString());
        }
        return ret;
    }


    /**
     * Computes the number of blocks that should be used for the given input size
     * and limits.
     *
     * <p>The result is {@code ceil(n / (2 * threads))}, capped at {@code maxBlocks},
     * where {@code threads = getNumThreads(n, maxThreads)}. The divisor is sized as
     * if each thread covers two elements. An input size of {@code 0} yields
     * {@code 0} blocks.
     *
     * @param n          The input size
     * @param maxBlocks  The maximum number of blocks
     * @param maxThreads The maximum number of threads
     * @return The number of blocks
     */
    public static int getNumBlocks(int n, int maxBlocks, int maxThreads) {
        int threads = getNumThreads(n, maxThreads);
        // Use long arithmetic: n + 2 * threads - 1 overflows int when n is near Integer.MAX_VALUE.
        long perBlock = 2L * threads;
        long blocks = (n + perBlock - 1) / perBlock;
        return (int) Math.min(maxBlocks, blocks);
    }

    /**
     * Computes the number of threads that should be used for the given input size
     * and limits.
     *
     * <p>If {@code n < 2 * maxThreads}, the result is the next power of two that is
     * at least {@code ceil(n / 2)}, and never less than 1. Otherwise it is
     * {@code maxThreads}.
     *
     * @param n          The input size
     * @param maxThreads The maximum number of threads
     * @return The number of threads
     */
    public static int getNumThreads(int n, int maxThreads) {
        // 2L keeps the comparison from overflowing for very large maxThreads.
        return (n < 2L * maxThreads) ? nextPow2((n + 1) / 2) : maxThreads;
    }

    /**
     * Returns the smallest power of 2 that is greater than or equal to {@code x}.
     *
     * <p>Inputs {@code <= 1} (including 0 and negatives) yield {@code 1}. Inputs
     * greater than {@code 2^30} have no {@code int} result and overflow to
     * {@link Integer#MIN_VALUE}.
     *
     * @param x The input
     * @return The next power of 2
     */
    public static int nextPow2(int x) {
        if (x <= 1) {
            // The bit trick below would return 0 for x == 0, which is not a power of 2.
            return 1;
        }
        // Copy the highest set bit of (x - 1) into every lower bit, then add one.
        int v = x - 1;
        v |= v >> 1;
        v |= v >> 2;
        v |= v >> 4;
        v |= v >> 8;
        v |= v >> 16;
        return v + 1;
    }

    /**
     * Constructs and allocates a device pointer.
     *
     * <p>Allocates {@code length} elements of {@code Sizeof.FLOAT} bytes when
     * {@code dType} is FLOAT, and of {@code Sizeof.DOUBLE} bytes for any other type.
     *
     * @param length the number of elements to allocate
     * @param dType  the data type to choose
     * @return the new pointer
     */
    public static CUdeviceptr constructAndAlloc(int length, DataBuffer.Type dType) {
        // Allocate device output memory
        CUdeviceptr deviceOutput = new CUdeviceptr();
        long elementSize = dType == DataBuffer.Type.FLOAT ? Sizeof.FLOAT : Sizeof.DOUBLE;
        // Multiply in long: length * 8 overflows int beyond ~268M elements.
        // NOTE(review): the CUresult returned by cuMemAlloc is ignored here; failures surface
        // only if JCuda exceptions are enabled elsewhere — confirm.
        cuMemAlloc(deviceOutput, length * elementSize);
        return deviceOutput;
    }

    /**
     * Returns the element size in bytes for a data type.
     *
     * @param dataType the data type
     * @return {@code Sizeof.DOUBLE} for DOUBLE, otherwise {@code Sizeof.FLOAT}
     */
    public static int sizeFor(DataBuffer.Type dataType) {
        return dataType == DataBuffer.Type.DOUBLE ? Sizeof.DOUBLE : Sizeof.FLOAT;
    }


    /**
     * Wraps the scalar of a scalar op in a one-element primitive array whose type
     * matches the data type of the op's {@code x()} buffer.
     *
     * @param scalarOp the operation whose scalar should be wrapped
     * @return a {@code float[]} for FLOAT data, or a {@code double[]} for DOUBLE data
     * @throws IllegalStateException if the scalar is {@code null} or the data type is
     *                               neither FLOAT nor DOUBLE
     */
    public static Object getPointer(ScalarOp scalarOp) {
        if (scalarOp.scalar() != null) {
            DataBuffer.Type type = scalarOp.x().data().dataType();
            if (type == DataBuffer.Type.FLOAT) {
                return new float[]{scalarOp.scalar().floatValue()};
            }
            if (type == DataBuffer.Type.DOUBLE) {
                return new double[]{scalarOp.scalar().doubleValue()};
            }
        }

        throw new IllegalStateException("Unable to get pointer for scalar operation " + scalarOp);
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy