org.apache.cassandra.cql3.functions.JavaBasedUDFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cassandra-all Show documentation
Show all versions of cassandra-all Show documentation
The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.cql3.functions;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.*;
import java.nio.ByteBuffer;
import java.security.*;
import java.security.cert.Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.reflect.TypeToken;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.cql3.ColumnIdentifier;
import org.apache.cassandra.cql3.functions.types.TypeCodec;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.security.SecurityThreadGroup;
import org.apache.cassandra.security.ThreadAwareSecurityManager;
import org.eclipse.jdt.core.compiler.IProblem;
import org.eclipse.jdt.internal.compiler.*;
import org.eclipse.jdt.internal.compiler.Compiler;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFileReader;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFormatException;
import org.eclipse.jdt.internal.compiler.env.ICompilationUnit;
import org.eclipse.jdt.internal.compiler.env.INameEnvironment;
import org.eclipse.jdt.internal.compiler.env.NameEnvironmentAnswer;
import org.eclipse.jdt.internal.compiler.impl.CompilerOptions;
import org.eclipse.jdt.internal.compiler.problem.DefaultProblemFactory;
public final class JavaBasedUDFunction extends UDFunction
{
private static final String BASE_PACKAGE = "org.apache.cassandra.cql3.udf.gen";
private static final Pattern JAVA_LANG_PREFIX = Pattern.compile("\\bjava\\.lang\\.");
private static final Logger logger = LoggerFactory.getLogger(JavaBasedUDFunction.class);
private static final AtomicInteger classSequence = new AtomicInteger();
// use a JVM standard ExecutorService as ExecutorPlus references internal
// classes, which triggers AccessControlException from the UDF sandbox
private static final UDFExecutorService executor =
new UDFExecutorService(new NamedThreadFactory("UserDefinedFunctions",
Thread.MIN_PRIORITY,
udfClassLoader,
new SecurityThreadGroup("UserDefinedFunctions", null, UDFunction::initializeThread)),
"userfunction");
private static final EcjTargetClassLoader targetClassLoader = new EcjTargetClassLoader();
private static final UDFByteCodeVerifier udfByteCodeVerifier = new UDFByteCodeVerifier();
private static final ProtectionDomain protectionDomain;
private static final IErrorHandlingPolicy errorHandlingPolicy = DefaultErrorHandlingPolicies.proceedWithAllProblems();
private static final IProblemFactory problemFactory = new DefaultProblemFactory(Locale.ENGLISH);
private static final CompilerOptions compilerOptions;
/**
* Poor man's template - just a text file splitted at '#' chars.
* Each string at an even index is a constant string (just copied),
* each string at an odd index is an 'instruction'.
*/
private static final String[] javaSourceTemplate;
static
{
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "forName");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getClassLoader");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResource");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResourceAsStream");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "clearAssertionStatus");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResource");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResourceAsStream");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResources");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemClassLoader");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResource");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResourceAsStream");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResources");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "loadClass");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setClassAssertionStatus");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setDefaultAssertionStatus");
udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setPackageAssertionStatus");
udfByteCodeVerifier.addDisallowedMethodCall("java/nio/ByteBuffer", "allocateDirect");
for (String ia : new String[]{"java/net/InetAddress", "java/net/Inet4Address", "java/net/Inet6Address"})
{
// static method, probably performing DNS lookups (despite SecurityManager)
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByAddress");
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getAllByName");
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByName");
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getLocalHost");
// instance methods, probably performing DNS lookups (despite SecurityManager)
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getHostName");
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getCanonicalHostName");
// ICMP PING
udfByteCodeVerifier.addDisallowedMethodCall(ia, "isReachable");
}
udfByteCodeVerifier.addDisallowedClass("java/net/NetworkInterface");
udfByteCodeVerifier.addDisallowedClass("java/net/SocketException");
Map settings = new HashMap<>();
settings.put(CompilerOptions.OPTION_LineNumberAttribute,
CompilerOptions.GENERATE);
settings.put(CompilerOptions.OPTION_SourceFileAttribute,
CompilerOptions.DISABLED);
settings.put(CompilerOptions.OPTION_ReportDeprecation,
CompilerOptions.IGNORE);
settings.put(CompilerOptions.OPTION_Source,
CompilerOptions.VERSION_1_8);
settings.put(CompilerOptions.OPTION_TargetPlatform,
CompilerOptions.VERSION_1_8);
compilerOptions = new CompilerOptions(settings);
compilerOptions.parseLiteralExpressionsAsConstants = true;
try (InputStream input = JavaBasedUDFunction.class.getResource("JavaSourceUDF.txt").openConnection().getInputStream())
{
ByteArrayOutputStream output = new ByteArrayOutputStream();
FBUtilities.copy(input, output, Long.MAX_VALUE);
String template = output.toString();
StringTokenizer st = new StringTokenizer(template, "#");
javaSourceTemplate = new String[st.countTokens()];
for (int i = 0; st.hasMoreElements(); i++)
javaSourceTemplate[i] = st.nextToken();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
CodeSource codeSource;
try
{
codeSource = new CodeSource(new URL("udf", "localhost", 0, "/java", new URLStreamHandler()
{
protected URLConnection openConnection(URL u)
{
return null;
}
}), (Certificate[])null);
}
catch (MalformedURLException e)
{
throw new RuntimeException(e);
}
protectionDomain = new ProtectionDomain(codeSource, ThreadAwareSecurityManager.noPermissions, targetClassLoader, null);
}
private final JavaUDF javaUDF;
private static final Pattern patternJavaDriver = Pattern.compile("com\\.datastax\\.driver\\.core\\.");
JavaBasedUDFunction(FunctionName name, List argNames, List> argTypes,
AbstractType> returnType, boolean calledOnNullInput, String body)
{
super(name, argNames, argTypes, UDHelper.driverTypes(argTypes),
returnType, UDHelper.driverType(returnType), calledOnNullInput, "java", body);
// javaParamTypes is just the Java representation for argTypes resp. argDataTypes
TypeToken>[] javaParamTypes = UDHelper.typeTokens(argCodecs, calledOnNullInput);
// javaReturnType is just the Java representation for returnType resp. returnTypeCodec
TypeToken> javaReturnType = returnCodec.getJavaType();
// put each UDF in a separate package to prevent cross-UDF code access
String pkgName = BASE_PACKAGE + '.' + generateClassName(name, 'p');
String clsName = generateClassName(name, 'C');
String executeInternalName = generateClassName(name, 'x');
StringBuilder javaSourceBuilder = new StringBuilder();
int lineOffset = 1;
for (int i = 0; i < javaSourceTemplate.length; i++)
{
String s = javaSourceTemplate[i];
// strings at odd indexes are 'instructions'
if ((i & 1) == 1)
{
switch (s)
{
case "package_name":
s = pkgName;
break;
case "class_name":
s = clsName;
break;
case "body":
lineOffset = countNewlines(javaSourceBuilder);
s = patternJavaDriver.matcher(body).replaceAll("org.apache.cassandra.cql3.functions.types.");
break;
case "arguments":
s = generateArguments(javaParamTypes, argNames, false);
break;
case "arguments_aggregate":
s = generateArguments(javaParamTypes, argNames, true);
break;
case "argument_list":
s = generateArgumentList(javaParamTypes, argNames);
break;
case "return_type":
s = javaSourceName(javaReturnType);
break;
case "execute_internal_name":
s = executeInternalName;
break;
}
}
javaSourceBuilder.append(s);
}
String targetClassName = pkgName + '.' + clsName;
String javaSource = javaSourceBuilder.toString();
logger.trace("Compiling Java source UDF '{}' as class '{}' using source:\n{}", name, targetClassName, javaSource);
try
{
EcjCompilationUnit compilationUnit = new EcjCompilationUnit(javaSource, targetClassName);
Compiler compiler = new Compiler(compilationUnit,
errorHandlingPolicy,
compilerOptions,
compilationUnit,
problemFactory);
compiler.compile(new ICompilationUnit[]{ compilationUnit });
if (compilationUnit.problemList != null && !compilationUnit.problemList.isEmpty())
{
boolean fullSource = false;
StringBuilder problems = new StringBuilder();
for (IProblem problem : compilationUnit.problemList)
{
long ln = problem.getSourceLineNumber() - lineOffset;
if (ln < 1L)
{
if (problem.isError())
{
// if generated source around UDF source provided by the user is buggy,
// this code is appended.
problems.append("GENERATED SOURCE ERROR: line ")
.append(problem.getSourceLineNumber())
.append(" (in generated source): ")
.append(problem.getMessage())
.append('\n');
fullSource = true;
}
}
else
{
problems.append("Line ")
.append(Long.toString(ln))
.append(": ")
.append(problem.getMessage())
.append('\n');
}
}
if (fullSource)
throw new InvalidRequestException("Java source compilation failed:\n" + problems + "\n generated source:\n" + javaSource);
else
throw new InvalidRequestException("Java source compilation failed:\n" + problems);
}
// Verify the UDF bytecode against use of probably dangerous code
Set errors = udfByteCodeVerifier.verify(targetClassName, targetClassLoader.classData(targetClassName));
String validDeclare = "not allowed method declared: " + executeInternalName + '(';
for (Iterator i = errors.iterator(); i.hasNext();)
{
String error = i.next();
// we generate a random name of the private, internal execute method, which is detected by the byte-code verifier
if (error.startsWith(validDeclare))
i.remove();
}
if (!errors.isEmpty())
throw new InvalidRequestException("Java UDF validation failed: " + errors);
// Load the class and create a new instance of it
Thread thread = Thread.currentThread();
ClassLoader orig = thread.getContextClassLoader();
try
{
thread.setContextClassLoader(UDFunction.udfClassLoader);
// Execute UDF intiialization from UDF class loader
Class cls = Class.forName(targetClassName, false, targetClassLoader);
// Count only non-synthetic methods, so code coverage instrumentation doesn't cause a miscount
int nonSyntheticMethodCount = 0;
for (Method m : cls.getDeclaredMethods())
{
if (!m.isSynthetic())
{
nonSyntheticMethodCount += 1;
}
}
if (nonSyntheticMethodCount != 3 || cls.getDeclaredConstructors().length != 1)
throw new InvalidRequestException("Check your source to not define additional Java methods or constructors");
MethodType methodType = MethodType.methodType(void.class)
.appendParameterTypes(TypeCodec.class, TypeCodec[].class, UDFContext.class);
MethodHandle ctor = MethodHandles.lookup().findConstructor(cls, methodType);
this.javaUDF = (JavaUDF) ctor.invokeWithArguments(returnCodec, argCodecs, udfContext);
}
finally
{
thread.setContextClassLoader(orig);
}
}
catch (InvocationTargetException e)
{
// in case of an ITE, use the cause
logger.error(String.format("Could not compile function '%s' from Java source:%n%s", name, javaSource), e);
throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e.getCause()));
}
catch (InvalidRequestException | VirtualMachineError e)
{
throw e;
}
catch (Throwable e)
{
logger.error(String.format("Could not compile function '%s' from Java source:%n%s", name, javaSource), e);
throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e));
}
}
protected ExecutorService executor()
{
return executor;
}
protected ByteBuffer executeUserDefined(ProtocolVersion protocolVersion, List params)
{
return javaUDF.executeImpl(protocolVersion, params);
}
protected Object executeAggregateUserDefined(ProtocolVersion protocolVersion, Object firstParam, List params)
{
return javaUDF.executeAggregateImpl(protocolVersion, firstParam, params);
}
private static int countNewlines(StringBuilder javaSource)
{
int ln = 0;
for (int i = 0; i < javaSource.length(); i++)
if (javaSource.charAt(i) == '\n')
ln++;
return ln;
}
private static String generateClassName(FunctionName name, char prefix)
{
String qualifiedName = name.toString();
StringBuilder sb = new StringBuilder(qualifiedName.length() + 10);
sb.append(prefix);
for (int i = 0; i < qualifiedName.length(); i++)
{
char c = qualifiedName.charAt(i);
if (Character.isJavaIdentifierPart(c))
sb.append(c);
else
sb.append(Integer.toHexString(((short)c)&0xffff));
}
sb.append('_')
.append(ThreadLocalRandom.current().nextInt() & 0xffffff)
.append('_')
.append(classSequence.incrementAndGet());
return sb.toString();
}
@VisibleForTesting
public static String javaSourceName(TypeToken> type)
{
String n = type.toString();
return JAVA_LANG_PREFIX.matcher(n).replaceAll("");
}
private static String generateArgumentList(TypeToken>[] paramTypes, List argNames)
{
// initial builder size can just be a guess (prevent temp object allocations)
StringBuilder code = new StringBuilder(32 * paramTypes.length);
for (int i = 0; i < paramTypes.length; i++)
{
if (i > 0)
code.append(", ");
code.append(javaSourceName(paramTypes[i]))
.append(' ')
.append(argNames.get(i));
}
return code.toString();
}
/**
* Generate Java source code snippet for the arguments part to call the UDF implementation function -
* i.e. the {@code private #return_type# #execute_internal_name#(#argument_list#)} function
* (see {@code JavaSourceUDF.txt} template file for details).
*
* This method generates the arguments code snippet for both {@code executeImpl} and
* {@code executeAggregateImpl}. General signature for both is the {@code protocolVersion} and
* then all UDF arguments. For aggregation UDF calls the first argument is always unserialized as
* that is the state variable.
*
*
* An example output for {@code executeImpl}:
* {@code (double) super.compose_double(protocolVersion, 0, params.get(0)), (double) super.compose_double(protocolVersion, 1, params.get(1))}
*
*
* Similar output for {@code executeAggregateImpl}:
* {@code firstParam, (double) super.compose_double(protocolVersion, 1, params.get(1))}
*
*/
private static String generateArguments(TypeToken>[] paramTypes, List argNames, boolean forAggregate)
{
StringBuilder code = new StringBuilder(64 * paramTypes.length);
for (int i = 0; i < paramTypes.length; i++)
{
if (i > 0)
// add separator, if not the first argument
code.append(",\n");
// add comment only if trace is enabled
if (logger.isTraceEnabled())
code.append(" /* parameter '").append(argNames.get(i)).append("' */\n");
// cast to Java type
code.append(" (").append(javaSourceName(paramTypes[i])).append(") ");
if (forAggregate && i == 0)
// special case for aggregations where the state variable (1st arg to state + final function and
// return value from state function) is not re-serialized
code.append("firstParam");
else
// generate object representation of input parameter (call UDFunction.compose)
code.append(composeMethod(paramTypes[i])).append("(protocolVersion, ").append(i).append(", params.get(").append(forAggregate ? i - 1 : i).append("))");
}
return code.toString();
}
private static String composeMethod(TypeToken> type)
{
return (type.isPrimitive()) ? ("super.compose_" + type.getRawType().getName()) : "super.compose";
}
// Java source UDFs are a very simple compilation task, which allows us to let one class implement
// all interfaces required by ECJ.
static final class EcjCompilationUnit implements ICompilationUnit, ICompilerRequestor, INameEnvironment
{
List problemList;
private final String className;
private final char[] sourceCode;
EcjCompilationUnit(String sourceCode, String className)
{
this.className = className;
this.sourceCode = sourceCode.toCharArray();
}
// ICompilationUnit
@Override
public char[] getFileName()
{
return sourceCode;
}
@Override
public char[] getContents()
{
return sourceCode;
}
@Override
public char[] getMainTypeName()
{
int dot = className.lastIndexOf('.');
return ((dot > 0) ? className.substring(dot + 1) : className).toCharArray();
}
@Override
public char[][] getPackageName()
{
StringTokenizer izer = new StringTokenizer(className, ".");
char[][] result = new char[izer.countTokens() - 1][];
for (int i = 0; i < result.length; i++)
result[i] = izer.nextToken().toCharArray();
return result;
}
@Override
public boolean ignoreOptionalProblems()
{
return false;
}
// ICompilerRequestor
@Override
public void acceptResult(CompilationResult result)
{
if (result.hasErrors())
{
IProblem[] problems = result.getProblems();
if (problemList == null)
problemList = new ArrayList<>(problems.length);
Collections.addAll(problemList, problems);
}
else
{
ClassFile[] classFiles = result.getClassFiles();
for (ClassFile classFile : classFiles)
targetClassLoader.addClass(className, classFile.getBytes());
}
}
// INameEnvironment
@Override
public NameEnvironmentAnswer findType(char[][] compoundTypeName)
{
StringBuilder result = new StringBuilder();
for (int i = 0; i < compoundTypeName.length; i++)
{
if (i > 0)
result.append('.');
result.append(compoundTypeName[i]);
}
return findType(result.toString());
}
@Override
public NameEnvironmentAnswer findType(char[] typeName, char[][] packageName)
{
StringBuilder result = new StringBuilder();
int i = 0;
for (; i < packageName.length; i++)
{
if (i > 0)
result.append('.');
result.append(packageName[i]);
}
if (i > 0)
result.append('.');
result.append(typeName);
return findType(result.toString());
}
@SuppressWarnings("resource")
private NameEnvironmentAnswer findType(String className)
{
if (className.equals(this.className))
{
return new NameEnvironmentAnswer(this, null);
}
String resourceName = className.replace('.', '/') + ".class";
try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(resourceName))
{
if (is != null)
{
byte[] classBytes = ByteStreams.toByteArray(is);
char[] fileName = className.toCharArray();
ClassFileReader classFileReader = new ClassFileReader(classBytes, fileName, true);
return new NameEnvironmentAnswer(classFileReader, null);
}
}
catch (IOException | ClassFormatException exc)
{
throw new RuntimeException(exc);
}
return null;
}
private boolean isPackage(String result)
{
if (result.equals(this.className))
return false;
String resourceName = result.replace('.', '/') + ".class";
try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(resourceName))
{
return is == null;
}
catch (IOException e)
{
// we are here, since close on is failed. That means it was not null
return false;
}
}
@Override
public boolean isPackage(char[][] parentPackageName, char[] packageName)
{
StringBuilder result = new StringBuilder();
int i = 0;
if (parentPackageName != null)
for (; i < parentPackageName.length; i++)
{
if (i > 0)
result.append('.');
result.append(parentPackageName[i]);
}
if (Character.isUpperCase(packageName[0]) && !isPackage(result.toString()))
return false;
if (i > 0)
result.append('.');
result.append(packageName);
return isPackage(result.toString());
}
@Override
public void cleanup()
{
}
}
static final class EcjTargetClassLoader extends SecureClassLoader
{
EcjTargetClassLoader()
{
super(UDFunction.udfClassLoader);
}
// This map is usually empty.
// It only contains data *during* UDF compilation but not during runtime.
//
// addClass() is invoked by ECJ after successful compilation of the generated Java source.
// loadClass(targetClassName) is invoked by buildUDF() after ECJ returned from successful compilation.
//
private final Map classes = new ConcurrentHashMap<>();
void addClass(String className, byte[] classData)
{
classes.put(className, classData);
}
byte[] classData(String className)
{
return classes.get(className);
}
protected Class> findClass(String name) throws ClassNotFoundException
{
// remove the class binary - it's only used once - so it's wasting heap
byte[] classData = classes.remove(name);
if (classData != null)
return defineClass(name, classData, 0, classData.length, protectionDomain);
return getParent().loadClass(name);
}
protected PermissionCollection getPermissions(CodeSource codesource)
{
return ThreadAwareSecurityManager.noPermissions;
}
}}