org.nd4j.linalg.jcublas.util.CudaArgs Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.jcublas.util;
import org.nd4j.shade.guava.collect.ArrayListMultimap;
import org.nd4j.shade.guava.collect.Multimap;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
/**
* Handles conversion of
* arguments passed to jcuda
* to their proper primitives
* when invoked with pointers.
*
* @author Adam Gibson
*/
public class CudaArgs {
private CudaArgs() {}
/**
* For invoking a cuda kernel
* this returns the module opName for the given op
* @param op the op to get the module opName for
* @return the module opName for the given op
*/
public static String getModuleNameFor(Op op) {
//String functionName = op instanceof TransformOp || op instanceof ReduceOp || op instanceof IndexAccumulation ? op.opName() + "_strided" : op.opName();
String moduleName = null;
if (op instanceof ReduceOp) {
moduleName = "reduce";
// FIXME: special case for reduce3
if (op.opName().equals("cosinesimilarity")) {
moduleName = "reduce3";
} else if (op.opName().equals("euclidean")) {
moduleName = "reduce3";
} else if (op.opName().equals("manhattan")) {
moduleName = "reduce3";
}
} else if (op instanceof TransformOp) {
// FIXME: we need special case for pairwise transforms for now. Later we should make them separate kernel call
if (op.opName().equals("add")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("copy")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("div")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("mul")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("rdiv")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("rsub")) {
moduleName = "pairWiseTransform";
} else if (op.opName().equals("sub")) {
moduleName = "pairWiseTransform";
} else {
moduleName = "transform";
}
} else if (op instanceof ScalarOp) {
moduleName = "scalar";
} else if (op instanceof BroadcastOp) {
moduleName = "broadcast";
} else if (op instanceof IndexAccumulation) {
moduleName = "indexReduce";
}
return moduleName;
}
public static int getOpCode(Op op) {
int code = -1;
String name = op.opName();
if (op instanceof ReduceOp) {
if (name.equals("mean")) {
code = 0;
} else if (name.equals("sum")) {
code = 1;
} else if (name.equals("bias")) {
code = 2;
} else if (name.equals("max")) {
code = 3;
} else if (name.equals("min")) {
code = 4;
} else if (name.equals("norm1")) {
code = 5;
} else if (name.equals("norm2")) {
code = 6;
} else if (name.equals("normmax")) {
code = 7;
} else if (name.equals("prod")) {
code = 8;
} else if (name.equals("std")) {
code = 9;
} else if (name.equals("var")) {
code = 10;
// FIXME: special case for reduce3
} else if (name.equals("manhattan")) {
code = 0;
} else if (name.equals("euclidean")) {
code = 1;
} else if (name.equals("cosinesimilarity")) {
code = 2;
}
} else if (op instanceof TransformOp) {
if (name.equals("abs")) {
code = 0;
} else if (name.equals("ceil")) {
code = 1;
} else if (name.equals("cos")) {
code = 2;
} else if (name.equals("exp")) {
code = 3;
} else if (name.equals("floor")) {
code = 4;
} else if (name.equals("log")) {
code = 5;
} else if (name.equals("neg")) {
code = 6;
} else if (name.equals("pow")) {
code = 7;
} else if (name.equals("round")) {
code = 8;
} else if (name.equals("setrange")) {
code = 9;
} else if (name.equals("sigmoid")) {
code = 10;
} else if (name.equals("sign")) {
code = 11;
} else if (name.equals("sin")) {
code = 12;
} else if (name.equals("softplus")) {
code = 13;
} else if (name.equals("sqrt")) {
code = 14;
} else if (name.equals("tanh")) {
code = 15;
} else if (name.equals("acos")) {
code = 16;
} else if (name.equals("asin")) {
code = 17;
} else if (name.equals("atan")) {
code = 18;
// FIXME: we need special case for pairwise transforms for now. Later we should make them separate kernel call
} else if (name.equals("add")) {
code = 0;
} else if (name.equals("copy")) {
code = 1;
} else if (name.equals("div")) {
code = 2;
} else if (name.equals("eq")) {
code = 3;
} else if (name.equals("gt")) {
code = 4;
} else if (name.equals("lt")) {
code = 5;
} else if (name.equals("mul")) {
code = 6;
} else if (name.equals("rdiv")) {
code = 7;
} else if (name.equals("rsub")) {
code = 8;
} else if (name.equals("sub")) {
code = 9;
} else if (name.equals("eps")) {
code = 10;
} else if (name.equals("gte")) {
code = 11;
} else if (name.equals("lte")) {
code = 12;
} else if (name.equals("max")) {
code = 13;
} else if (name.equals("min")) {
code = 14;
} else if (name.equals("neq")) {
code = 15;
}
} else if (op instanceof ScalarOp) {
if (name.startsWith("add")) {
code = 0;
} else if (name.startsWith("sub")) {
code = 1;
} else if (name.startsWith("mul")) {
code = 2;
} else if (name.startsWith("div")) {
code = 3;
} else if (name.startsWith("rdiv")) {
code = 4;
} else if (name.startsWith("rsub")) {
code = 5;
} else if (name.startsWith("max")) {
code = 6;
} else if (name.startsWith("lessthan")) {
code = 7;
} else if (name.startsWith("greaterthan")) {
code = 8;
} else if (name.startsWith("eq")) {
code = 9;
} else if (name.startsWith("lte")) {
code = 10;
} else if (name.startsWith("neq")) {
code = 11;
} else if (name.startsWith("min")) {
code = 12;
} else if (name.startsWith("set")) {
code = 13;
}
} else if (op instanceof BroadcastOp) {
if (name.equals("broadcastadd")) {
code = 0;
} else if (name.equals("broadcastsub")) {
code = 1;
} else if (name.equals("broadcastmul")) {
code = 2;
} else if (name.equals("broadcastdiv")) {
code = 3;
} else if (name.equals("broadcastrdiv")) {
code = 4;
} else if (name.equals("broadcastrsub")) {
code = 5;
} else if (name.equals("broadcastcopy")) {
code = 6;
}
} else if (op instanceof IndexAccumulation) {
if (name.equals("imax")) {
code = 0;
} else if (name.equals("imin")) {
code = 1;
}
}
// System.out.println("CALLING ["+getModuleNameFor(op)+"] -> ["+code+"]");
return code;
}
/**
* Returns number of SMs, based on device compute capability and number of processors.
*
* @param ccMajor
* @param ccMinor
* @return
*/
public static int convertMPtoCores(int ccMajor, int ccMinor, int numberOfProcessors) {
// Defines for GPU Architecture types (using the SM version to determine the # of cores per SM
if (ccMajor == 1)
return 8;
if (ccMajor == 2 && ccMinor == 1)
return 48;
if (ccMajor == 2)
return 32;
if (ccMajor == 3)
return 192;
if (ccMajor == 5)
return 128;
// return negative number if device is unknown
return -1;
}
/**
*
* @param context
* @param kernelParams
* @return
*/
public static ArgsAndReferences argsAndReference(CudaContext context, Object... kernelParams) {
// Map
© 2015 - 2025 Weber Informatics LLC | Privacy Policy