
com.nativelibs4java.opencl.generator.JavaCLGenerator Maven / Gradle / Ivy
The newest version!
package com.nativelibs4java.opencl.generator;
import com.nativelibs4java.opencl.*;
import com.ochafik.io.IOUtils;
import com.ochafik.lang.jnaerator.*;
import com.ochafik.lang.jnaerator.PreprocessorUtils.MacroUseCallback;
import com.ochafik.lang.jnaerator.TypeConversion.JavaPrimitive;
import com.ochafik.lang.jnaerator.TypeConversion.TypeConversionMode;
import com.ochafik.lang.jnaerator.parser.*;
import com.ochafik.lang.jnaerator.runtime.NativeSize;
import com.ochafik.util.listenable.Adapter;
import com.ochafik.util.listenable.Pair;
import com.ochafik.util.string.RegexUtils;
import com.ochafik.util.string.StringUtils;
import java.io.File;
import java.io.*;
import java.util.*;
import static com.ochafik.lang.jnaerator.parser.ElementsHelper.*;
import java.util.regex.Pattern;
import org.anarres.cpp.LexerException;
public class JavaCLGenerator extends JNAerator {
static Pattern nameExtPatt = Pattern.compile("(.*?)\\.(\\w+)");
public JavaCLGenerator(JNAeratorConfig config) {
super(config);
config.noCPlusPlus = true;
config.genCPlusPlus = false;
config.gccLong = true;
config.putTopStructsInSeparateFiles = false;
config.runtime = JNAeratorConfig.Runtime.JNAerator;//NL4JStructs;
config.fileToLibrary = new Adapter() {
@Override
public String adapt(File value) {
String[] m = RegexUtils.match(value.getName(), nameExtPatt);
return m == null ? null : m[1];
}
};
config.functionsAccepter = new Adapter() {
@Override
public Boolean adapt(Function value) {
List mods = value.getModifiers();
if (ModifierType.__kernel.isContainedBy(mods))
return true;
if (value.getValueType() == null)
return null;
mods = value.getValueType().getModifiers();
return ModifierType.__kernel.isContainedBy(mods);
}
};
}
Map> macrosByFile = new HashMap>();
@Override
protected JNAeratorParser createJNAeratorParser() {
return new JNAeratorParser() {
@Override
protected com.ochafik.lang.jnaerator.parser.ObjCppParser newObjCppParser(TypeConversion typeConverter, String s, boolean verbose, PrintStream errorOut) throws IOException {
com.ochafik.lang.jnaerator.parser.ObjCppParser parser = super.newObjCppParser(typeConverter, s, verbose, errorOut);
parser.allowKinds(ModifierKind.OpenCL);
return parser;
}
};
}
static Set openclPrimitives = new HashSet();
static {
openclPrimitives.add("half");
openclPrimitives.add("image2d_t");
openclPrimitives.add("image3d_t");
openclPrimitives.add("sampler_t");
openclPrimitives.add("event_t");
}
@Override
public Result createResult(final ClassOutputter outputter, Feedback feedback) {
return new Result(config, outputter, feedback) {
@Override
public void init() {
typeConverter = new JNATypeConversion(this) {
@Override
public void initTypes() {
super.initTypes();
}
@Override
public boolean isObjCppPrimitive(String s) {
int len;
if (s == null || (len = s.length()) == 0)
return false;
if (super.isObjCppPrimitive(s))
return true;
// handle case of "(int|long|short|byte|double|float)\\d"
if (len > 1 && Character.isDigit(s.charAt(len - 1))) {
String ss = s.substring(0, len - 1);
if (ss.charAt(0) == 'u')
ss = ss.substring(1);
if (super.isObjCppPrimitive(ss))
return true;
}
return openclPrimitives.contains(s);
}
};
declarationsConverter = new JNADeclarationsConverter(this) {
@Override
public void convertFunction(Function function, Signatures signatures, boolean isCallback, DeclarationsHolder out, Identifier libraryClassName, int unused) {
if (isCallback)
return;
if (!result.config.functionsAccepter.adapt(function))
return;
List args = function.getArgs();
List convArgs = new ArrayList(args.size());
String queueName = "commandQueue";
convArgs.add(new Arg(queueName, typeRef(CLQueue.class)));
List convArgExpr = new ArrayList(args.size());
List extraStatements = new ArrayList();
int iArg = 1;
for (Arg arg : args) {
TypeRef tr = arg.createMutatedType();
if (tr == null)
return;
try {
tr = result.typeConverter.resolveTypeDef(tr, libraryClassName, true, false);
List mods = tr.getModifiers();
TypeRef convTr;
String argName = arg.getName() == null ? "arg" + iArg : arg.getName();
Expression argExpr;
if (ModifierType.__local.isContainedBy(mods)) {
argName += "LocalByteSize";
//convTr = typeRef(Long.TYPE);
//argExpr = new Expression.New(typeRef(CLKernel.LocalSize.class), varRef(argName));
convTr = typeRef(CLKernel.LocalSize.class);
argExpr = varRef(argName);//new Expression.New(typeRef(CLKernel.LocalSize.class), varRef(argName));
} else {
Conversion conv = convertTypeToJavaCL(result, argName, tr, TypeConversion.TypeConversionMode.PrimitiveOrBufferParameter, null);
convTr = conv.outerJavaTypeRef;
argExpr = conv.convertedExpr;
extraStatements.addAll(conv.extraStatements);
//String convTrStr = convTr.toString();
/*if (convTrStr.equals(NativeSize.class.getName()) || convTrStr.equals(NativeLong.class.getName()))
argExpr = new Expression.New(tr, varRef(conv.argName));
else
argExpr = varRef(ident(argName));*/
}
convArgs.add(new Arg(argName, convTr));
convArgExpr.add(argExpr);//varRef(argName));
} catch (UnsupportedConversionException ex) {
out.addDeclaration(skipDeclaration(function, ex.toString()));
}
iArg++;
}
String globalWSName = "globalWorkSizes", localWSName = "localWorkSizes", eventsName = "eventsToWaitFor";
convArgs.add(new Arg(globalWSName, typeRef(int[].class)));
convArgs.add(new Arg(localWSName, typeRef(int[].class)));
convArgs.add(new Arg(eventsName, typeRef(CLEvent.class)).setVarArg(true));
String functionName = function.getName().toString();
String kernelVarName = functionName + "_kernel";
if (signatures.addVariable(kernelVarName))
out.addDeclaration(new VariablesDeclaration(typeRef(CLKernel.class), new Declarator.DirectDeclarator(kernelVarName)));
Function method = new Function(Function.Type.JavaMethod, ident(functionName), typeRef(CLEvent.class));
method.addModifiers(ModifierType.Public, ModifierType.Synchronized);
method.addThrown(typeRef(CLBuildException.class));
method.setArgs(convArgs);
List statements = new ArrayList();
statements.add(
new Statement.If(
expr(varRef(kernelVarName), Expression.BinaryOperator.IsEqual, new Expression.NullExpression()),
stat(
expr(
varRef(kernelVarName), Expression.AssignmentOperator.Equal,
methodCall(
"createKernel",
new Expression.Constant(Expression.Constant.Type.String, functionName, null)
)
)
),
null
)
);
statements.addAll(extraStatements);
statements.add(
stat(methodCall(
varRef(kernelVarName),
Expression.MemberRefStyle.Dot,
"setArgs",
convArgExpr.toArray(new Expression[convArgExpr.size()])
))
);
statements.add(
new Statement.Return(methodCall(
varRef(kernelVarName),
Expression.MemberRefStyle.Dot,
"enqueueNDRange",
varRef(queueName),
varRef(globalWSName),
varRef(localWSName),
varRef(eventsName)
))
);
method.setBody(block(statements.toArray(new Statement[statements.size()])));
if (signatures.addMethod(method))
out.addDeclaration(method);
}
};
globalsGenerator = new JNAGlobalsGenerator(this);
objectiveCGenerator = new ObjectiveCGenerator(this);
universalReconciliator = new UniversalReconciliator();
}
};
}
static class CLPrim {
TypeConversion.JavaPrimitive javaPrim;
int arity;
boolean isLong, isShort;
Expression assertExpr;
Statement checkStatement;
Expression convertStatement;
Class> argClass;
public CLPrim(JavaPrimitive javaPrim, int arity) {
this.javaPrim = javaPrim;
this.arity = arity;
}
static Pattern patt = Pattern.compile("(?:(long|short)\\s+)?(float|double|u?(?:char|long|short|int))(\\d)");
public static CLPrim parse(Result result, TypeRef tr) {
String s = tr.toString();
if (s == null || s.length() == 0)
return null;
char c = s.charAt(s.length() - 1);
if (!Character.isDigit(c)) {
//JavaPrim prim = result.typeConverter.getPrimitive(
return null;
}
String[] m = RegexUtils.match(tr.toString(), patt);
if (m == null)
return null;
//boolean isShort = false,
//result.typeConverter
return null;
}
}
static class Conversion {
TypeRef outerJavaTypeRef;
Expression convertedExpr;
String argName;
List extraStatements = new ArrayList();
}
static Map>> buffersAndArityByType = new HashMap>>();
static Map>> arraysAndArityByType = new HashMap>>();
static {
Object[] data = new Object[] {
"char", Byte.TYPE, byte[].class, Byte.class,
"long", Long.TYPE, long[].class, Long.class,
"int", Integer.TYPE, int[].class, Integer.class,
"short", Short.TYPE, short[].class, Short.class,
"wchar_t", Character.TYPE, char[].class, Short.class,
"double", Double.TYPE, double[].class, Double.class,
"float", Float.TYPE, float[].class, Float.class,
"bool", Boolean.TYPE, boolean[].class, Boolean.class
};
for (int arity : new int[] { 1, 2, 4, 8, 16 }) {
String suffix = arity == 1 ? "" : arity +"";
for (int i = 0; i < data.length; i += 4) {
String rawType = (String)data[i];
Class> scalClass = (Class>)data[i + 1];
Class> arrClass = (Class>)data[i + 2];
Class> buffClass = (Class>)data[i + 3];
Pair>
buffPair = new Pair>(arity, buffClass),
arrPair = new Pair>(arity, arity == 1 ? scalClass : arrClass);
for (String type : new String[] { rawType + suffix, "u" + rawType + suffix}) {
buffersAndArityByType.put(type, buffPair);
arraysAndArityByType.put(type, arrPair);
}
}
}
}
private Conversion convertTypeToJavaCL(Result result, String argName, TypeRef valueType, TypeConversionMode typeConversionMode, Identifier libraryClassName) throws UnsupportedConversionException {
Conversion ret = new Conversion();
ret.argName = argName;
ret.convertedExpr = varRef(argName);
if (valueType instanceof TypeRef.Pointer) {
TypeRef target = ((TypeRef.Pointer)valueType).getTarget();
if (target instanceof TypeRef.SimpleTypeRef) {
TypeRef.SimpleTypeRef starget = (TypeRef.SimpleTypeRef)target;
Pair> pair = buffersAndArityByType.get(starget.getName().toString());
if (pair != null) {
ret.outerJavaTypeRef = typeRef(ident(CLBuffer.class, expr(typeRef(pair.getSecond()))));
return ret;
}
}
} else if (valueType instanceof TypeRef.SimpleTypeRef) {
TypeRef.SimpleTypeRef sr = (TypeRef.SimpleTypeRef)valueType;
String name = sr.getName() == null ? sr.toString() : sr.getName().toString();
if (name.equals("size_t")) {
ret.outerJavaTypeRef = typeRef(Long.TYPE);
ret.convertedExpr = new Expression.New(typeRef(NativeSize.class), ret.convertedExpr);
return ret;
} else {
Pair> pair = arraysAndArityByType.get(name);
if (pair != null) {
ret.outerJavaTypeRef = typeRef(pair.getSecond());
if (pair.getFirst().intValue() != 1) {
ret.extraStatements.add(
stat(
methodCall(
"checkArrayLength",
varRef(ret.argName),
expr(
Expression.Constant.Type.Int,
pair.getFirst()
),
expr(
Expression.Constant.Type.String,
ret.argName
)
)
)
);
}
return ret;
}
}
}
throw new UnsupportedConversionException(valueType, "Unhandled type : " + valueType);
}
@Override
protected void generateLibraryFiles(SourceFiles sourceFiles, Result result) throws IOException {
//super.generateLibraryFiles(sourceFiles, result);
for (SourceFile sourceFile : sourceFiles.getSourceFiles()) {
String rawSrcFilePath = new File(sourceFile.getElementFile()).getCanonicalPath();
String srcFilePath = result.config.relativizeFileForSourceComments(rawSrcFilePath);
File srcFile = new File(srcFilePath);
String srcParent = srcFile.getParent();
String srcFileName = srcFile.getName();
String[] nameExt = RegexUtils.match(srcFileName, nameExtPatt);
if (nameExt == null)
continue;
String name = nameExt[1], ext = nameExt[2];
if (!ext.equals("c") && !ext.equals("cl"))
continue;
String packageName = srcParent == null || srcParent.length() == 0 ? null : srcParent.replace('/', '.').replace('\\', '.');
Identifier packageIdent = ident(packageName);
String className = (packageName == null ? "" : packageName + ".") + name;
Struct interf = new Struct();
interf.addToCommentBefore("Wrapper around the OpenCL program " + name);
interf.addModifiers(ModifierType.Public);
interf.setTag(ident(name));
interf.addParent(ident(CLAbstractUserProgram.class));
interf.setType(Struct.Type.JavaClass);
String[] constrArgNames = new String[] { "context", "program" };
Class>[] constrArgTypes = new Class>[] { CLContext.class, CLProgram.class };
for (int i = 0; i < constrArgNames.length; i++) {
String argName = constrArgNames[i];
Function constr = new Function(Function.Type.JavaMethod, ident(name), null, new Arg(argName, typeRef(constrArgTypes[i])));
constr.addModifiers(ModifierType.Public);
constr.addThrown(typeRef(IOException.class));
constr.setBody(
block(
stat(
methodCall(
"super",
varRef(argName),
methodCall(
"readRawSourceForClass",
result.typeConverter.typeLiteral(typeRef(name))
)
)
)
)
);
interf.addDeclaration(constr);
}
//result.declarationsConverter.convertStructs(null, null, interf, null)
Signatures signatures = new Signatures();//result.getSignaturesForOutputClass(fullLibraryClassName);
result.typeConverter.allowFakePointers = true;
String library = name;
Identifier fullLibraryClassName = ident(className);
result.declarationsConverter.convertStructs(result.structsByLibrary.get(library), signatures, interf, fullLibraryClassName, library);
//result.declarationsConverter.convertCallbacks(result.callbacksByLibrary.get(library), signatures, interf, fullLibraryClassName);
int declCount = interf.getDeclarations().size();
result.declarationsConverter.convertFunctions(result.functionsByLibrary.get(library), signatures, interf, fullLibraryClassName);
result.declarationsConverter.convertEnums(result.enumsByLibrary.get(library), signatures, interf, fullLibraryClassName);
result.declarationsConverter.convertConstants(library, result.definesByLibrary.get(library), sourceFiles, signatures, interf, fullLibraryClassName);
boolean hasKernels = interf.getDeclarations().size() > declCount;
if (!hasKernels)
continue;
//for ()
/*
public SampleUserProgram(CLContext context) throws IOException {
super(context, readRawSourceForClass(SampleUserProgram.class));
}*/
for (Set set : macrosByFile.values()) {
for (String macroName : set) {
if (macroName.equals("__LINE__") ||
macroName.equals("__FILE__") ||
macroName.equals("__COUNTER__") ||
config.preprocessorConfig.macros.containsKey(macroName))
continue;
String[] parts = macroName.split("_+");
List newParts = new ArrayList(parts.length);
for (String part : parts) {
if (part == null || (part = part.trim()).length() == 0)
continue;
newParts.add(StringUtils.capitalize(part));
}
String functionName = "define" + StringUtils.implode(newParts, "");
Function macroDef = new Function(Function.Type.JavaMethod, ident(functionName), typeRef("void"));
String valueName = "value";
macroDef.addArg(new Arg(valueName, typeRef(String.class)));
macroDef.setBody(block(stat(methodCall("defineMacro", expr(Expression.Constant.Type.String, macroName), varRef(valueName)))));
interf.addDeclaration(macroDef);
}
}
PrintWriter out = result.classOutputter.getClassSourceWriter(className);
result.printJavaClass(packageIdent, interf, out);
//if (packageName != null)
// out.println("package " + packageName + ";");
//out.println(interf);
out.close();
}
}
@Override
protected void autoConfigure() {
super.autoConfigure();
/*
__OPENCL_VERSION__
__ENDIAN_LITTLE__
__IMAGE_SUPPORT__
__FAST_RELAXED_MATH__
*/
}
public static void main(String[] args) {
JNAerator.main(new JavaCLGenerator(new JNAeratorConfig()),
new String[] {
"-o", "target/generated-sources/main/java",
//"-o", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/generated-sources/main/java",
"-noJar",
"-noComp",
"-v",
"-addRootDir", "src/main/opencl",
"src/main/opencl",
//"-addRootDir", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Blas/target/../src/main/opencl",
//"/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Blas/src/main/opencl/com/nativelibs4java/opencl/blas/LinearAlgebraKernels.c"
//"-addRootDir", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/../src/main/opencl",
//"/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/../src/main/opencl/com/nativelibs4java/opencl/demos/sobelfilter/SimpleSobel.cl"
}
);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy