All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.procedure.impl.ProcedureCompiler Maven / Gradle / Ivy

There is a newer version: 5.25.1
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.procedure.impl;

import static java.lang.reflect.Modifier.isPublic;
import static java.util.Collections.emptyList;
import static org.apache.commons.lang3.ArrayUtils.EMPTY_STRING_ARRAY;
import static org.neo4j.configuration.GraphDatabaseSettings.procedure_unrestricted;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.neo4j.internal.kernel.api.exceptions.ProcedureException;
import org.neo4j.internal.kernel.api.procs.FieldSignature;
import org.neo4j.internal.kernel.api.procs.ProcedureSignature;
import org.neo4j.internal.kernel.api.procs.QualifiedName;
import org.neo4j.internal.kernel.api.procs.UserFunctionSignature;
import org.neo4j.kernel.api.CypherScope;
import org.neo4j.kernel.api.exceptions.ComponentInjectionException;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.api.procedure.CallableProcedure;
import org.neo4j.kernel.api.procedure.CallableUserAggregationFunction;
import org.neo4j.kernel.api.procedure.CallableUserFunction;
import org.neo4j.kernel.api.procedure.CypherVersionScope;
import org.neo4j.kernel.api.procedure.FailedLoadAggregatedFunction;
import org.neo4j.kernel.api.procedure.FailedLoadFunction;
import org.neo4j.kernel.api.procedure.FailedLoadProcedure;
import org.neo4j.kernel.api.procedure.SystemProcedure;
import org.neo4j.logging.InternalLog;
import org.neo4j.procedure.Admin;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.NotThreadSafe;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.ThreadSafe;
import org.neo4j.procedure.UnsupportedDatabaseTypes;
import org.neo4j.procedure.UserAggregationFunction;
import org.neo4j.procedure.UserAggregationResult;
import org.neo4j.procedure.UserAggregationUpdate;
import org.neo4j.procedure.UserFunction;
import org.neo4j.string.Globbing;

/**
 * Handles converting a class into one or more callable {@link CallableProcedure},
 * {@link CallableUserFunction} or {@link CallableUserAggregationFunction} instances.
 */
class ProcedureCompiler {

    // Produces the output field signatures of a procedure method.
    private final ProcedureOutputSignatureCompiler outputSignatureCompiler;
    // Produces the input signature of a procedure, function or aggregation-update method.
    private final MethodSignatureCompiler inputSignatureDeterminer;
    // Field setters backed by the "safe" component registry; used when a definition has no full access.
    private final FieldInjections safeFieldInjections;
    // Field setters backed by the registry of all components; used when full access is granted.
    private final FieldInjections allFieldInjections;
    // Receives warnings about skipped (not allowlisted) or sandbox-restricted definitions.
    private final InternalLog log;
    // Resolves the type checker for the return type of functions and aggregations.
    private final Cypher5TypeCheckers typeCheckers;
    // Answers the allowlist and full-access questions for a qualified name.
    private final ProcedureConfig config;
    // Naming rules verified for every function and aggregation function before it is compiled.
    private final NamingRestrictions functionRestrictions;
    // Naming rules verified for every procedure before it is compiled.
    private final NamingRestrictions procedureRestrictions;

    /**
     * Creates a compiler with the default naming restrictions: functions are verified against
     * {@link NamingRestrictions#rejectEmptyNamespace()} and procedures against
     * {@link NamingRestrictions#rejectNone()}.
     *
     * @param typeCheckers type checkers shared by the input and output signature compilers
     * @param safeComponents registry backing the field injections used for sandboxed definitions
     * @param allComponents registry backing the field injections used for definitions with full access
     * @param log log that receives warnings emitted while compiling
     * @param config configuration consulted for allowlisting and full access
     */
    ProcedureCompiler(
            Cypher5TypeCheckers typeCheckers,
            ComponentRegistry safeComponents,
            ComponentRegistry allComponents,
            InternalLog log,
            ProcedureConfig config) {
        this(
                new MethodSignatureCompiler(typeCheckers),
                new ProcedureOutputSignatureCompiler(typeCheckers),
                new FieldInjections(safeComponents),
                new FieldInjections(allComponents),
                log,
                typeCheckers,
                config,
                NamingRestrictions.rejectEmptyNamespace(),
                NamingRestrictions.rejectNone());
    }

    /**
     * Canonical constructor: stores every collaborator as-is. Used by the package-private constructor and by
     * {@link #withAdditionalProcedureRestrictions(NamingRestrictions)}, which reuses all collaborators of an
     * existing compiler and only swaps the procedure restrictions.
     */
    private ProcedureCompiler(
            MethodSignatureCompiler inputSignatureCompiler,
            ProcedureOutputSignatureCompiler outputSignatureCompiler,
            FieldInjections safeFieldInjections,
            FieldInjections allFieldInjections,
            InternalLog log,
            Cypher5TypeCheckers typeCheckers,
            ProcedureConfig config,
            NamingRestrictions functionRestrictions,
            NamingRestrictions procedureRestrictions) {
        this.inputSignatureDeterminer = inputSignatureCompiler;
        this.outputSignatureCompiler = outputSignatureCompiler;
        this.safeFieldInjections = safeFieldInjections;
        this.allFieldInjections = allFieldInjections;
        this.log = log;
        this.typeCheckers = typeCheckers;
        this.config = config;
        this.functionRestrictions = functionRestrictions;
        this.procedureRestrictions = procedureRestrictions;
    }

    /**
     * Compiles every method declared in {@code fcnDefinition} that is annotated with {@link UserFunction}.
     *
     * <p>Methods whose qualified name is rejected by {@code methodNameFilter} are skipped silently. Methods that
     * are neither built-in nor allowlisted by the configuration are skipped with a warning.
     *
     * @param fcnDefinition class declaring the function methods
     * @param isBuiltin when {@code true} the allowlist check is bypassed
     * @param parentClassLoader class loader handed to the code generation step
     * @param methodNameFilter predicate over the qualified function name; only matching functions are compiled
     * @return the compiled functions sorted by qualified name, or an empty list if the class declares none
     * @throws ProcedureException if the class has no public no-argument constructor or compilation fails
     */
    List<CallableUserFunction> compileFunction(
            Class<?> fcnDefinition,
            boolean isBuiltin,
            ClassLoader parentClassLoader,
            Predicate<String> methodNameFilter)
            throws ProcedureException {
        try {
            List<Method> functionMethods = Arrays.stream(fcnDefinition.getDeclaredMethods())
                    .filter(m -> m.isAnnotationPresent(UserFunction.class))
                    .toList();

            if (functionMethods.isEmpty()) {
                return emptyList();
            }

            // Checked up front so that a missing constructor is reported with a dedicated error message.
            assertValidConstructor(fcnDefinition);

            List<CallableUserFunction> out = new ArrayList<>(functionMethods.size());
            for (Method method : functionMethods) {
                UserFunction annotation = method.getAnnotation(UserFunction.class);
                QualifiedName funcName = extractName(fcnDefinition, method, annotation.value(), annotation.name());
                String qualifiedName = funcName.toString();

                if (!methodNameFilter.test(qualifiedName)) {
                    continue;
                }

                if (isBuiltin || config.isWhitelisted(qualifiedName)) {
                    out.add(compileFunction(fcnDefinition, method, funcName, parentClassLoader));
                } else {
                    log.warn(String.format("The function '%s' is not on the allowlist and won't be loaded.", funcName));
                }
            }
            out.sort(Comparator.comparing(a -> a.signature().name().toString()));
            return out;
        } catch (ProcedureException e) {
            // Already carries a precise status and message; do not wrap it again.
            throw e;
        } catch (Exception e) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    e,
                    "Failed to compile function defined in `%s`: %s",
                    fcnDefinition.getSimpleName(),
                    e.getMessage());
        }
    }

    /**
     * Compiles every method declared in {@code fcnDefinition} that is annotated with
     * {@link UserAggregationFunction}.
     *
     * <p>Methods whose qualified name is rejected by {@code methodNameFilter} are skipped silently. Methods that
     * are not allowlisted by the configuration are skipped with a warning.
     *
     * @param fcnDefinition class declaring the aggregation factory methods
     * @param parentClassLoader class loader handed to the code generation step
     * @param methodNameFilter predicate over the qualified function name; only matching functions are compiled
     * @return the compiled aggregation functions sorted by qualified name, or an empty list if the class
     *         declares none
     * @throws ProcedureException if the class has no public no-argument constructor or compilation fails
     */
    List<CallableUserAggregationFunction> compileAggregationFunction(
            Class<?> fcnDefinition, ClassLoader parentClassLoader, Predicate<String> methodNameFilter)
            throws ProcedureException {
        try {
            List<Method> methods = Arrays.stream(fcnDefinition.getDeclaredMethods())
                    .filter(m -> m.isAnnotationPresent(UserAggregationFunction.class))
                    .toList();

            if (methods.isEmpty()) {
                return emptyList();
            }

            // Checked up front so that a missing constructor is reported with a dedicated error message.
            assertValidConstructor(fcnDefinition);

            List<CallableUserAggregationFunction> out = new ArrayList<>(methods.size());
            for (Method method : methods) {
                UserAggregationFunction annotation = method.getAnnotation(UserAggregationFunction.class);
                QualifiedName funcName = extractName(fcnDefinition, method, annotation.value(), annotation.name());
                String qualifiedName = funcName.toString();

                if (!methodNameFilter.test(qualifiedName)) {
                    continue;
                }

                if (config.isWhitelisted(qualifiedName)) {
                    out.add(compileAggregationFunction(fcnDefinition, method, funcName, parentClassLoader));
                } else {
                    log.warn(String.format("The function '%s' is not on the allowlist and won't be loaded.", funcName));
                }
            }
            out.sort(Comparator.comparing(a -> a.signature().name().toString()));
            return out;
        } catch (ProcedureException e) {
            // Already carries a precise status and message; do not wrap it again.
            throw e;
        } catch (Exception e) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    e,
                    "Failed to compile function defined in `%s`: %s",
                    fcnDefinition.getSimpleName(),
                    e.getMessage());
        }
    }

    /**
     * Compiles every method declared in {@code procDefinition} that is annotated with {@link Procedure}.
     *
     * <p>Methods whose qualified name is rejected by {@code methodNameFilter} are skipped silently. Methods that
     * neither have full access nor are allowlisted by the configuration are skipped with a warning.
     *
     * @param procDefinition class declaring the procedure methods
     * @param fullAccess when {@code true} the allowlist check is bypassed and the procedures are compiled
     *        with unrestricted field injections
     * @param parentClassLoader class loader handed to the code generation step
     * @param methodNameFilter predicate over the qualified procedure name; only matching procedures are compiled
     * @return the compiled procedures sorted by qualified name, or an empty list if the class declares none
     * @throws ProcedureException if the class has no public no-argument constructor or compilation fails
     */
    List<CallableProcedure> compileProcedure(
            Class<?> procDefinition,
            boolean fullAccess,
            ClassLoader parentClassLoader,
            Predicate<String> methodNameFilter)
            throws ProcedureException {
        try {
            List<Method> procedureMethods = Arrays.stream(procDefinition.getDeclaredMethods())
                    .filter(m -> m.isAnnotationPresent(Procedure.class))
                    .toList();

            if (procedureMethods.isEmpty()) {
                return emptyList();
            }

            // Checked up front so that a missing constructor is reported with a dedicated error message.
            assertValidConstructor(procDefinition);

            List<CallableProcedure> out = new ArrayList<>(procedureMethods.size());
            for (Method method : procedureMethods) {
                Procedure annotation = method.getAnnotation(Procedure.class);
                QualifiedName procName = extractName(procDefinition, method, annotation.value(), annotation.name());
                String qualifiedName = procName.toString();

                if (!methodNameFilter.test(qualifiedName)) {
                    continue;
                }

                if (fullAccess || config.isWhitelisted(qualifiedName)) {
                    out.add(compileProcedure(procDefinition, method, fullAccess, procName, parentClassLoader));
                } else {
                    log.warn(
                            String.format("The procedure '%s' is not on the allowlist and won't be loaded.", procName));
                }
            }
            out.sort(Comparator.comparing(a -> a.signature().name().toString()));
            return out;
        } catch (ProcedureException e) {
            // Already carries a precise status and message; do not wrap it again.
            throw e;
        } catch (Exception e) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    e,
                    "Failed to compile procedure defined in `%s`: %s",
                    procDefinition.getSimpleName(),
                    e.getMessage());
        }
    }

    /**
     * Compiles a single {@link Procedure}-annotated method into a {@link CallableProcedure}.
     *
     * <p>The signature is assembled from the annotations present on the method. Unless {@code fullAccess} is
     * set or the configuration grants the procedure full access, only the safe field injections are used. If
     * those cannot be resolved, a {@link FailedLoadProcedure} is returned instead. Its signature uses
     * {@link Mode#DEFAULT} and a description explaining why the procedure is unavailable.
     *
     * @param procDefinition class declaring the procedure
     * @param method the annotated procedure method
     * @param fullAccess whether the procedure may use unrestricted field injections
     * @param procName qualified name the procedure is registered under
     * @param parentClassLoader class loader handed to the code generation step
     * @return the compiled procedure, or a {@link FailedLoadProcedure} when sandboxed injection fails
     * @throws ProcedureException if the name violates the procedure restrictions or compilation fails
     */
    private CallableProcedure compileProcedure(
            Class<?> procDefinition,
            Method method,
            boolean fullAccess,
            QualifiedName procName,
            ClassLoader parentClassLoader)
            throws ProcedureException {
        procedureRestrictions.verify(procName);

        List<FieldSignature> inputSignature = inputSignatureDeterminer.signatureFor(method);
        List<FieldSignature> outputSignature = outputSignatureCompiler.fieldSignatures(method);

        String description = description(method);
        Procedure procedure = method.getAnnotation(Procedure.class);
        Mode mode = procedure.mode();
        boolean admin = method.isAnnotationPresent(Admin.class);
        boolean systemProcedure = method.isAnnotationPresent(SystemProcedure.class);
        boolean allowExpiredCredentials =
                systemProcedure && method.getAnnotation(SystemProcedure.class).allowExpiredCredentials();
        boolean internal = method.isAnnotationPresent(Internal.class);
        var unsupportedDbTypes = method.isAnnotationPresent(UnsupportedDatabaseTypes.class)
                ? method.getAnnotation(UnsupportedDatabaseTypes.class).value()
                : new UnsupportedDatabaseTypes.DatabaseType[0];
        // Procedures are considered thread-safe unless they opt out explicitly.
        boolean threadSafe = !method.isAnnotationPresent(NotThreadSafe.class);
        boolean isDeprecated = method.isAnnotationPresent(Deprecated.class);
        String deprecated = deprecated(
                procedure::deprecatedBy,
                "Use of @Procedure(deprecatedBy) without @Deprecated in " + procName,
                isDeprecated);

        // Start from the unrestricted injections; replaced below when the procedure is sandboxed.
        var setters = allFieldInjections.setters(procDefinition);
        boolean loadFailed = false;
        if (!fullAccess && !config.fullAccessFor(procName.toString())) {
            try {
                setters = safeFieldInjections.setters(procDefinition);
            } catch (ComponentInjectionException e) {
                // The class needs components outside the sandbox: keep the procedure visible, but with a
                // description that explains the failure and with the default mode.
                description = describeAndLogLoadFailure(procName);
                mode = Mode.DEFAULT;
                loadFailed = true;
            }
        }

        ProcedureSignature signature = new ProcedureSignature(
                procName,
                inputSignature,
                outputSignature,
                mode,
                admin,
                isDeprecated,
                deprecated,
                description,
                null,
                procedure.eager(),
                false,
                systemProcedure,
                internal,
                allowExpiredCredentials,
                threadSafe,
                getSupportedCypherVersions(method),
                unsupportedDbTypes);

        if (loadFailed) {
            return new FailedLoadProcedure(signature);
        }
        return ProcedureCompilation.compileProcedure(signature, setters, method, parentClassLoader);
    }

    /**
     * Determines which Cypher scopes a method is available in.
     *
     * @param method the procedure, function or aggregation factory method to inspect
     * @return the scopes listed in the method's {@link CypherVersionScope} annotation, or
     *         {@link CypherScope#ALL_SCOPES} when the annotation is absent or lists no scope
     */
    private static Set<CypherScope> getSupportedCypherVersions(Method method) {
        var annotation = method.getAnnotation(CypherVersionScope.class);
        // Without an annotation, or without any listed scope, the method is assumed to support all
        // Cypher language versions.
        if (annotation == null) {
            return CypherScope.ALL_SCOPES;
        }

        var scope = annotation.scope();
        if (scope == null || scope.length == 0) {
            return CypherScope.ALL_SCOPES;
        }

        // Safe: EnumSet.copyOf only rejects an empty non-EnumSet collection, which is excluded above.
        return EnumSet.copyOf(Arrays.asList(scope));
    }

    /**
     * Compiles all procedures declared in {@code procDefinition} without filtering by name, using the class
     * loader of {@link CallableUserFunction} as parent class loader.
     *
     * @param procDefinition class declaring the procedure methods
     * @param fullAccess whether the procedures bypass the allowlist and use unrestricted field injections
     * @return the compiled procedures sorted by qualified name
     * @throws ProcedureException if compilation fails
     */
    List<CallableProcedure> compileProcedure(Class<?> procDefinition, boolean fullAccess) throws ProcedureException {
        return compileProcedure(
                procDefinition, fullAccess, CallableUserFunction.class.getClassLoader(), Globbing.MATCH_ALL);
    }

    /**
     * Compiles all aggregation functions declared in {@code fcnDefinition} without filtering by name, using
     * the class loader of {@link CallableUserFunction} as parent class loader.
     *
     * @param fcnDefinition class declaring the aggregation factory methods
     * @return the compiled aggregation functions sorted by qualified name
     * @throws ProcedureException if compilation fails
     */
    List<CallableUserAggregationFunction> compileAggregationFunction(Class<?> fcnDefinition)
            throws ProcedureException {
        return compileAggregationFunction(
                fcnDefinition, CallableUserFunction.class.getClassLoader(), Globbing.MATCH_ALL);
    }

    /**
     * Compiles all user functions declared in {@code fcnDefinition} without filtering by name, using the
     * class loader of {@link CallableUserFunction} as parent class loader.
     *
     * @param fcnDefinition class declaring the function methods
     * @param isBuiltin when {@code true} the allowlist check is bypassed
     * @return the compiled functions sorted by qualified name
     * @throws ProcedureException if compilation fails
     */
    List<CallableUserFunction> compileFunction(Class<?> fcnDefinition, boolean isBuiltin) throws ProcedureException {
        return compileFunction(
                fcnDefinition, isBuiltin, CallableUserFunction.class.getClassLoader(), Globbing.MATCH_ALL);
    }

    /**
     * Builds the message explaining that {@code name} could not be loaded inside the sandbox, logs it as a
     * warning and returns it so that it can double as the description of the failed definition.
     *
     * @param name qualified name of the procedure or function that failed to load
     * @return the message that was logged
     */
    private String describeAndLogLoadFailure(QualifiedName name) {
        final String qualifiedName = name.toString();
        final String settingName = procedure_unrestricted.name();
        final String message = qualifiedName
                + " is unavailable because it is sandboxed and has dependencies outside of the sandbox. "
                + "Sandboxing is controlled by the " + settingName + " setting. "
                + "Only unrestrict procedures you can trust with access to database internals.";
        log.warn(message);
        return message;
    }

    /**
     * Compiles a single {@link UserFunction}-annotated method into a {@link CallableUserFunction}.
     *
     * <p>Unless the configuration grants the function full access, only the safe field injections are used.
     * If those cannot be resolved, a {@link FailedLoadFunction} is returned whose description explains why
     * the function is unavailable.
     *
     * @param procDefinition class declaring the function
     * @param method the annotated function method
     * @param procName qualified name the function is registered under
     * @param parentClassLoader class loader handed to the code generation step
     * @return the compiled function, or a {@link FailedLoadFunction} when sandboxed injection fails
     * @throws ProcedureException if the name violates the function restrictions or compilation fails
     */
    private CallableUserFunction compileFunction(
            Class<?> procDefinition, Method method, QualifiedName procName, ClassLoader parentClassLoader)
            throws ProcedureException {
        functionRestrictions.verify(procName);

        List<FieldSignature> inputSignature = inputSignatureDeterminer.signatureFor(method);
        Class<?> returnType = method.getReturnType();
        Cypher5TypeCheckers.TypeChecker typeChecker = typeCheckers.checkerFor(returnType);
        String description = description(method);
        UserFunction function = method.getAnnotation(UserFunction.class);
        boolean internal = method.isAnnotationPresent(Internal.class);
        // Functions are considered thread-safe unless they opt out explicitly.
        boolean threadSafe = !method.isAnnotationPresent(NotThreadSafe.class);
        boolean isDeprecated = method.isAnnotationPresent(Deprecated.class);
        String deprecated = deprecated(
                function::deprecatedBy,
                "Use of @UserFunction(deprecatedBy) without @Deprecated in " + procName,
                isDeprecated);

        // Start from the unrestricted injections; replaced below when the function is sandboxed.
        var setters = allFieldInjections.setters(procDefinition);
        boolean loadFailed = false;
        if (!config.fullAccessFor(procName.toString())) {
            try {
                setters = safeFieldInjections.setters(procDefinition);
            } catch (ComponentInjectionException e) {
                // The class needs components outside the sandbox: keep the function visible, but with a
                // description that explains the failure.
                description = describeAndLogLoadFailure(procName);
                loadFailed = true;
            }
        }

        UserFunctionSignature signature = new UserFunctionSignature(
                procName,
                inputSignature,
                typeChecker.type(),
                isDeprecated,
                deprecated,
                description,
                null,
                false,
                false,
                internal,
                threadSafe,
                getSupportedCypherVersions(method));

        if (loadFailed) {
            return new FailedLoadFunction(signature);
        }
        return ProcedureCompilation.compileFunction(signature, setters, method, parentClassLoader);
    }

    /**
     * Compiles a single {@link UserAggregationFunction}-annotated factory method into a
     * {@link CallableUserAggregationFunction}.
     *
     * <p>The return type of {@code create} is the aggregator class. It must declare exactly one method
     * annotated with {@link UserAggregationUpdate} and exactly one annotated with
     * {@link UserAggregationResult}. The update method must return {@code void}. The factory method, the
     * aggregator class, the update method and the result method must all be public. The input signature is
     * taken from the update method and the return type from the result method.
     *
     * <p>Unless the configuration grants the function full access, only the safe field injections are used.
     * If those cannot be resolved, a {@link FailedLoadAggregatedFunction} is returned whose description
     * explains why the function is unavailable.
     *
     * @param definition class declaring the factory method
     * @param create the annotated factory method that returns the aggregator
     * @param funcName qualified name the function is registered under
     * @param parentClassLoader class loader handed to the code generation step
     * @return the compiled aggregation function, or a {@link FailedLoadAggregatedFunction} when sandboxed
     *         injection fails
     * @throws ProcedureException if the name violates the function restrictions, the aggregator class does
     *         not satisfy the requirements above, or compilation fails
     */
    private CallableUserAggregationFunction compileAggregationFunction(
            Class<?> definition, Method create, QualifiedName funcName, ClassLoader parentClassLoader)
            throws ProcedureException {
        functionRestrictions.verify(funcName);

        // Locate the update and result methods; only methods declared directly on the aggregator are
        // considered, and each annotation may appear at most once.
        Method update = null;
        Method result = null;
        Class<?> aggregator = create.getReturnType();
        for (Method m : aggregator.getDeclaredMethods()) {
            if (m.isAnnotationPresent(UserAggregationUpdate.class)) {
                if (update != null) {
                    throw new ProcedureException(
                            Status.Procedure.ProcedureRegistrationFailed,
                            "Class '%s' contains multiple methods annotated with '@%s'.",
                            aggregator.getSimpleName(),
                            UserAggregationUpdate.class.getSimpleName());
                }
                update = m;
            }
            if (m.isAnnotationPresent(UserAggregationResult.class)) {
                if (result != null) {
                    throw new ProcedureException(
                            Status.Procedure.ProcedureRegistrationFailed,
                            "Class '%s' contains multiple methods annotated with '@%s'.",
                            aggregator.getSimpleName(),
                            UserAggregationResult.class.getSimpleName());
                }
                result = m;
            }
        }
        if (result == null || update == null) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Class '%s' must contain methods annotated with both '@%s' as well as '@%s'.",
                    aggregator.getSimpleName(),
                    UserAggregationResult.class.getSimpleName(),
                    UserAggregationUpdate.class.getSimpleName());
        }
        if (update.getReturnType() != void.class) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Update method '%s' in %s has type '%s' but must have return type 'void'.",
                    update.getName(),
                    aggregator.getSimpleName(),
                    update.getReturnType().getSimpleName());
        }
        if (!isPublic(create.getModifiers())) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Aggregation method '%s' in %s must be public.",
                    create.getName(),
                    definition.getSimpleName());
        }
        if (!isPublic(aggregator.getModifiers())) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Aggregation class '%s' must be public.",
                    aggregator.getSimpleName());
        }
        if (!isPublic(update.getModifiers())) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Aggregation update method '%s' in %s must be public.",
                    update.getName(),
                    aggregator.getSimpleName());
        }
        if (!isPublic(result.getModifiers())) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Aggregation result method '%s' in %s must be public.",
                    result.getName(),
                    aggregator.getSimpleName());
        }

        List<FieldSignature> inputSignature = inputSignatureDeterminer.signatureFor(update);
        Class<?> returnType = result.getReturnType();
        Cypher5TypeCheckers.TypeChecker valueConverter = typeCheckers.checkerFor(returnType);
        String description = description(create);
        UserAggregationFunction function = create.getAnnotation(UserAggregationFunction.class);
        boolean isDeprecated = create.isAnnotationPresent(Deprecated.class);
        String deprecated = deprecated(
                function::deprecatedBy,
                "Use of @UserAggregationFunction(deprecatedBy) without @Deprecated in " + funcName,
                isDeprecated);

        boolean internal = create.isAnnotationPresent(Internal.class);
        // Unlike procedures and plain functions (thread-safe unless marked @NotThreadSafe), an aggregation
        // is only flagged thread-safe when it is explicitly annotated with @ThreadSafe.
        boolean threadSafe = create.isAnnotationPresent(ThreadSafe.class);

        // Start from the unrestricted injections; replaced below when the function is sandboxed.
        var setters = allFieldInjections.setters(definition);
        boolean loadFailed = false;
        if (!config.fullAccessFor(funcName.toString())) {
            try {
                setters = safeFieldInjections.setters(definition);
            } catch (ComponentInjectionException e) {
                // The class needs components outside the sandbox: keep the function visible, but with a
                // description that explains the failure.
                description = describeAndLogLoadFailure(funcName);
                loadFailed = true;
            }
        }

        UserFunctionSignature signature = new UserFunctionSignature(
                funcName,
                inputSignature,
                valueConverter.type(),
                isDeprecated,
                deprecated,
                description,
                null,
                false,
                false,
                internal,
                threadSafe,
                getSupportedCypherVersions(create));

        if (loadFailed) {
            return new FailedLoadAggregatedFunction(signature);
        }
        return ProcedureCompilation.compileAggregation(signature, setters, create, update, result, parentClassLoader);
    }

    /**
     * Resolves the "deprecated by" value of an annotated method.
     *
     * <p>A non-empty replacement name is always returned as-is. If it is supplied although the method is not
     * marked as deprecated, {@code warning} is logged as well.
     *
     * @param supplier yields the {@code deprecatedBy} value of the annotation; an empty string means "not set"
     * @param warning message logged when {@code deprecatedBy} is set but the method is not deprecated
     * @param isDeprecated whether the method carries {@link Deprecated}
     * @return the replacement name, or {@code null} when none was specified
     */
    private String deprecated(Supplier<String> supplier, String warning, boolean isDeprecated) {
        String deprecatedBy = supplier.get();
        if (deprecatedBy.isEmpty()) {
            return null;
        }
        if (!isDeprecated) {
            log.warn(warning);
        }
        return deprecatedBy;
    }

    /**
     * Reads the text of the {@link Description} annotation on {@code method}.
     *
     * @param method the method to inspect
     * @return the annotation value, or {@code null} when the method is not annotated
     */
    private String description(Method method) {
        Description annotation = method.getAnnotation(Description.class);
        return annotation == null ? null : annotation.value();
    }

    /**
     * Verifies that {@code procDefinition} has a public no-argument constructor.
     *
     * @param procDefinition class declaring procedures or functions
     * @throws ProcedureException if no public no-argument constructor exists
     */
    private void assertValidConstructor(Class<?> procDefinition) throws ProcedureException {
        // Class#getConstructors only returns public constructors, so checking the arity is sufficient.
        boolean hasValidConstructor = Arrays.stream(procDefinition.getConstructors())
                .anyMatch(constructor -> constructor.getParameterCount() == 0);
        if (!hasValidConstructor) {
            throw new ProcedureException(
                    Status.Procedure.ProcedureRegistrationFailed,
                    "Unable to find a usable public no-argument constructor in the class `%s`. "
                            + "Please add a valid, public constructor, recompile the class and try again.",
                    procDefinition.getSimpleName());
        }
    }

    /**
     * Determines the qualified name a method is registered under.
     *
     * <p>An explicit name wins: {@code definedName} if it is not blank, otherwise {@code valueName}. The
     * explicit name is split on {@code '.'}; the last part becomes the name and the preceding parts the
     * namespace. When neither is given, the namespace is the package of {@code procDefinition} and the name
     * is the Java method name.
     *
     * @param procDefinition class declaring the method
     * @param m the annotated method
     * @param valueName the annotation's {@code value} attribute
     * @param definedName the annotation's {@code name} attribute
     * @return the qualified name
     * @throws IllegalArgumentException if the explicit name consists solely of {@code '.'} separators
     */
    private QualifiedName extractName(Class<?> procDefinition, Method m, String valueName, String definedName) {
        String procName = definedName.isBlank() ? valueName : definedName;
        if (!procName.isBlank()) {
            String[] split = procName.split("\\.");
            if (split.length == 0) {
                // String#split drops trailing empty strings, so a name such as "." yields no parts at all.
                // Fail with a readable message instead of a NegativeArraySizeException further down.
                throw new IllegalArgumentException(
                        String.format("'%s' is not a valid name, it only consists of '.' separators.", procName));
            }
            if (split.length == 1) {
                return new QualifiedName(split[0]);
            }
            int lastElement = split.length - 1;
            return new QualifiedName(Arrays.copyOf(split, lastElement), split[lastElement]);
        }
        Package pkg = procDefinition.getPackage();
        // Package is null if class is in root package
        String[] namespace = pkg == null ? EMPTY_STRING_ARRAY : pkg.getName().split("\\.");
        String name = m.getName();
        return new QualifiedName(namespace, name);
    }

    /**
     * Returns a new compiler that shares every collaborator of this one but verifies procedure names against
     * the combination (via {@link NamingRestrictions#allOf}) of the current procedure restrictions and
     * {@code additionalProcedureRestrictions}. This instance and the function restrictions are left unchanged.
     *
     * @param additionalProcedureRestrictions restrictions to apply in addition to the existing ones
     * @return a new compiler with the combined procedure restrictions
     */
    ProcedureCompiler withAdditionalProcedureRestrictions(NamingRestrictions additionalProcedureRestrictions) {
        return new ProcedureCompiler(
                inputSignatureDeterminer,
                outputSignatureCompiler,
                safeFieldInjections,
                allFieldInjections,
                log,
                typeCheckers,
                config,
                functionRestrictions,
                NamingRestrictions.allOf(procedureRestrictions, additionalProcedureRestrictions));
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy