All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.procedure.impl.ProcedureRegistry Maven / Gradle / Ivy

/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.procedure.impl;

import static java.lang.String.format;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;
import org.neo4j.collection.ResourceRawIterator;
import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.internal.kernel.api.exceptions.ProcedureException;
import org.neo4j.internal.kernel.api.procs.FieldSignature;
import org.neo4j.internal.kernel.api.procs.ProcedureHandle;
import org.neo4j.internal.kernel.api.procs.ProcedureSignature;
import org.neo4j.internal.kernel.api.procs.QualifiedName;
import org.neo4j.internal.kernel.api.procs.UserAggregationReducer;
import org.neo4j.internal.kernel.api.procs.UserFunctionHandle;
import org.neo4j.internal.kernel.api.procs.UserFunctionSignature;
import org.neo4j.internal.kernel.api.security.AbstractSecurityLog;
import org.neo4j.internal.kernel.api.security.PermissionState;
import org.neo4j.kernel.api.QueryLanguage;
import org.neo4j.kernel.api.ResourceMonitor;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.api.procedure.CallableProcedure;
import org.neo4j.kernel.api.procedure.CallableUserAggregationFunction;
import org.neo4j.kernel.api.procedure.CallableUserFunction;
import org.neo4j.kernel.api.procedure.Context;
import org.neo4j.procedure.UnsupportedDatabaseTypes;
import org.neo4j.util.VisibleForTesting;
import org.neo4j.values.AnyValue;

/**
 * Registry of callable procedures, user functions and user aggregation functions.
 *
 * <p>Each kind of callable is kept in its own {@link ProcedureHolder}. Callables are registered under a
 * {@link QualifiedName} for every {@link QueryLanguage} their signature supports, and can afterwards be
 * resolved to a handle (signature + integer id) and invoked through that id.
 */
public class ProcedureRegistry {

    private final ProcedureHolder<CallableProcedure> procedures;
    private final ProcedureHolder<CallableUserFunction> functions;
    private final ProcedureHolder<CallableUserAggregationFunction> aggregationFunctions;

    /**
     * Create an empty registry backed by three fresh holders.
     */
    public ProcedureRegistry() {
        this(new ProcedureHolder<>(), new ProcedureHolder<>(), new ProcedureHolder<>());
    }

    // Used by copyOf/tombstone to wrap already-built holders.
    private ProcedureRegistry(
            ProcedureHolder<CallableProcedure> procedures,
            ProcedureHolder<CallableUserFunction> functions,
            ProcedureHolder<CallableUserAggregationFunction> aggregationFunctions) {
        this.procedures = procedures;
        this.functions = functions;
        this.aggregationFunctions = aggregationFunctions;
    }

    /**
     * Register a new procedure.
     *
     * @param proc the procedure.
     * @throws ProcedureException if the input or output signature contains duplicate field names, if the
     *         procedure is not declared void but has an empty output signature, or if the name is already
     *         registered as a procedure in any of the query languages the signature supports.
     */
    public void register(CallableProcedure proc) throws ProcedureException {
        ProcedureSignature signature = proc.signature();
        QualifiedName name = signature.name();

        String descriptiveName = signature.toString();
        validateSignature(descriptiveName, signature.inputSignature(), "input");
        validateSignature(descriptiveName, signature.outputSignature(), "output");

        if (!signature.isVoid() && signature.outputSignature().isEmpty()) {
            throw ProcedureException.classNotVoid(descriptiveName);
        }

        // Check every scope before inserting, so a clash in any scope leaves the holder untouched.
        var supportedScopes = signature.supportedQueryLanguages();
        for (var scope : supportedScopes) {
            if (procedures.contains(name, scope)) {
                throw ProcedureException.procedureNameAlreadyInUse(name.toString());
            }
        }

        procedures.put(name, supportedScopes, proc, signature.caseInsensitive());
    }

    /**
     * Register a new function.
     *
     * @param function the function.
     * @throws ProcedureException if, in any of the query languages the signature supports, the name is
     *         already registered as an aggregation function or as a function.
     */
    public void register(CallableUserFunction function) throws ProcedureException {
        UserFunctionSignature signature = function.signature();
        QualifiedName name = signature.name();
        var supportedScopes = signature.supportedQueryLanguages();

        // Functions and aggregation functions must not share a name within a scope.
        for (var scope : supportedScopes) {
            if (aggregationFunctions.contains(name, scope)) {
                throw ProcedureException.aggregationFunctionNameAlreadyInUseAsAggregationFunction(name.toString());
            }

            if (functions.contains(name, scope)) {
                throw ProcedureException.functionNameAlreadyInUse(name.toString());
            }
        }

        functions.put(name, supportedScopes, function, signature.caseInsensitive());
    }

    /**
     * Register a new aggregation function.
     *
     * @param function the aggregation function.
     * @throws ProcedureException if, in any of the query languages the signature supports, the name is
     *         already registered as a function or as an aggregation function.
     */
    public void register(CallableUserAggregationFunction function) throws ProcedureException {
        UserFunctionSignature signature = function.signature();
        QualifiedName name = signature.name();
        var supportedScopes = signature.supportedQueryLanguages();

        // Functions and aggregation functions must not share a name within a scope.
        for (var scope : supportedScopes) {
            if (functions.contains(name, scope)) {
                throw ProcedureException.aggregationFunctionNameAlreadyInUseAsFunction(name.toString());
            }

            if (aggregationFunctions.contains(name, scope)) {
                throw ProcedureException.aggregationFunctionNameAlreadyInUse(name.toString());
            }
        }

        aggregationFunctions.put(name, supportedScopes, function, signature.caseInsensitive());
    }

    /**
     * Ensure no two fields in {@code fields} share a name.
     *
     * @param descriptiveName textual form of the signature, used in the error.
     * @param fields the fields to check.
     * @param fieldType label for the kind of fields being checked (the callers pass "input" or "output").
     * @throws ProcedureException on the first duplicated field name.
     */
    private void validateSignature(String descriptiveName, List<FieldSignature> fields, String fieldType)
            throws ProcedureException {
        Set<String> names = new HashSet<>();
        for (FieldSignature field : fields) {
            // Set.add returns false when the name has been seen before.
            if (!names.add(field.name())) {
                throw ProcedureException.duplicateFieldName(descriptiveName, fieldType, field.name());
            }
        }
    }

    /**
     * Look up a procedure by name and query language.
     *
     * @param name the qualified name of the procedure.
     * @param scope the query language to look the name up in.
     * @return a handle carrying the procedure's signature and its id in this registry.
     * @throws ProcedureException if no procedure is registered under that name and scope.
     */
    public ProcedureHandle procedure(QualifiedName name, QueryLanguage scope) throws ProcedureException {
        CallableProcedure proc = procedures.getByKey(name, scope);
        if (proc == null) {
            throw ProcedureException.noSuchProcedure(name);
        }
        return new ProcedureHandle(proc.signature(), procedures.idOfKey(name, scope));
    }

    /**
     * Look up a (non-aggregating) function by name and query language.
     *
     * @param name the qualified name of the function.
     * @param scope the query language to look the name up in.
     * @return a handle carrying the function's signature and id, or {@code null} if there is no such function.
     */
    public UserFunctionHandle function(QualifiedName name, QueryLanguage scope) {
        CallableUserFunction func = functions.getByKey(name, scope);
        if (func == null) {
            return null;
        }
        return new UserFunctionHandle(func.signature(), functions.idOfKey(name, scope));
    }

    /**
     * Look up an aggregation function by name and query language.
     *
     * @param name the qualified name of the aggregation function.
     * @param scope the query language to look the name up in.
     * @return a handle carrying the function's signature and id, or {@code null} if there is no such function.
     */
    public UserFunctionHandle aggregationFunction(QualifiedName name, QueryLanguage scope) {
        CallableUserAggregationFunction func = aggregationFunctions.getByKey(name, scope);
        if (func == null) {
            return null;
        }
        return new UserFunctionHandle(func.signature(), aggregationFunctions.idOfKey(name, scope));
    }

    /**
     * Invoke the procedure registered under {@code id}.
     *
     * <p>Before the call, admin procedures are checked against the security context: when access is not
     * allowed the violation is written to the {@link AbstractSecurityLog} and an
     * {@link AuthorizationViolationException} is thrown. The procedure is also checked against the
     * database type of the current transaction.
     *
     * @param ctx the execution context.
     * @param id the id of the procedure, as found in its {@link ProcedureHandle}.
     * @param input the arguments to pass to the procedure.
     * @param resourceMonitor passed through to the procedure.
     * @return whatever the procedure returns.
     * @throws ProcedureException if the lookup or checks raise {@link IndexOutOfBoundsException} (reported
     *         as "no such procedure"), if the procedure is not supported for the database type, or if the
     *         procedure itself fails.
     */
    public ResourceRawIterator<AnyValue[], ProcedureException> callProcedure(
            Context ctx, int id, AnyValue[] input, ResourceMonitor resourceMonitor) throws ProcedureException {
        CallableProcedure proc;
        try {
            proc = procedures.getById(id);
            var permission = ctx.securityContext().allowExecuteAdminProcedure(id);
            if (proc.signature().admin() && !permission.allowsAccess()) {
                // Distinguish an explicit deny from a merely missing grant in the message.
                String errorDescriptor = (permission == PermissionState.EXPLICIT_DENY)
                        ? "is not allowed"
                        : "permission has not been granted";
                String message = format(
                        "Executing admin procedure '%s' %s for %s.",
                        proc.signature().name(),
                        errorDescriptor,
                        ctx.securityContext().description());
                ctx.dependencyResolver()
                        .resolveDependency(AbstractSecurityLog.class)
                        .error(ctx.securityContext(), message);
                throw AuthorizationViolationException.authorizationViolation(message);
            }
            verifyDBType(ctx, proc);
        } catch (IndexOutOfBoundsException e) {
            throw ProcedureException.noSuchProcedure(id);
        }
        return proc.apply(ctx, input, resourceMonitor);
    }

    /**
     * Reject the call when the transaction is an SPD transaction and the procedure's signature lists
     * {@link UnsupportedDatabaseTypes.DatabaseType#SPD} among its unsupported database types.
     *
     * @throws ProcedureException with {@link Status.Statement#SyntaxError} when the procedure is not supported.
     */
    private void verifyDBType(Context ctx, CallableProcedure proc) throws ProcedureException {
        if (ctx.kernelTransaction().isSPDTransaction()
                && Arrays.stream(proc.signature().unsupportedDbTypes())
                        .anyMatch(t -> t.equals(UnsupportedDatabaseTypes.DatabaseType.SPD))) {
            throw new ProcedureException(
                    Status.Statement.SyntaxError,
                    "Procedure '" + proc.signature().name() + "' is not supported in SPD.");
        }
    }

    /**
     * Invoke the function registered under {@code functionId}.
     *
     * @param ctx the execution context.
     * @param functionId the id of the function, as found in its {@link UserFunctionHandle}.
     * @param input the arguments to pass to the function.
     * @return the value produced by the function.
     * @throws ProcedureException if the lookup raises {@link IndexOutOfBoundsException} (reported as
     *         "no such function"), or if the function itself fails.
     */
    public AnyValue callFunction(Context ctx, int functionId, AnyValue[] input) throws ProcedureException {
        CallableUserFunction func;
        try {
            func = functions.getById(functionId);
        } catch (IndexOutOfBoundsException e) {
            throw ProcedureException.noSuchFunction(functionId);
        }
        return func.apply(ctx, input);
    }

    /**
     * Create a reducer for the aggregation function registered under {@code id}.
     *
     * @param ctx the execution context.
     * @param id the id of the aggregation function, as found in its {@link UserFunctionHandle}.
     * @return a new reducer created by the aggregation function.
     * @throws ProcedureException if the lookup or reducer creation raises {@link IndexOutOfBoundsException}
     *         (reported as "no such function"), or if creating the reducer fails.
     */
    public UserAggregationReducer createAggregationFunction(Context ctx, int id) throws ProcedureException {
        try {
            CallableUserAggregationFunction func = aggregationFunctions.getById(id);
            return func.createReducer(ctx);
        } catch (IndexOutOfBoundsException e) {
            throw ProcedureException.noSuchFunction(id);
        }
    }

    /**
     * @param scope the query language to filter on.
     * @return the signatures of all procedures in the holder that support {@code scope}.
     */
    public Stream<ProcedureSignature> getAllProcedures(QueryLanguage scope) {
        return stream(procedures, CallableProcedure::signature, (signature) -> signature
                .supportedQueryLanguages()
                .contains(scope));
    }

    /**
     * @return the ids of all procedures in the holder accepted by {@code predicate}.
     */
    int[] getIdsOfProceduresMatching(Predicate<CallableProcedure> predicate) {
        return getIdsOf(procedures, predicate);
    }

    /**
     * @param scope the query language to filter on.
     * @return the signatures of all non-aggregating functions in the holder that support {@code scope}.
     */
    public Stream<UserFunctionSignature> getAllNonAggregatingFunctions(QueryLanguage scope) {
        return stream(functions, CallableUserFunction::signature, (signature) -> signature
                .supportedQueryLanguages()
                .contains(scope));
    }

    /**
     * @return the ids of all non-aggregating functions in the holder accepted by {@code predicate}.
     */
    int[] getIdsOfFunctionsMatching(Predicate<CallableUserFunction> predicate) {
        return getIdsOf(functions, predicate);
    }

    /**
     * @param scope the query language to filter on.
     * @return the signatures of all aggregation functions in the holder that support {@code scope}.
     */
    public Stream<UserFunctionSignature> getAllAggregatingFunctions(QueryLanguage scope) {
        return stream(aggregationFunctions, CallableUserAggregationFunction::signature, (signature) -> signature
                .supportedQueryLanguages()
                .contains(scope));
    }

    /**
     * @return the ids of all aggregation functions in the holder accepted by {@code predicate}.
     */
    int[] getIdsOfAggregatingFunctionsMatching(Predicate<CallableUserAggregationFunction> predicate) {
        return getIdsOf(aggregationFunctions, predicate);
    }

    /**
     * Unregister {@code name} from the procedure, function and aggregation function holders.
     *
     * @param name the qualified name to unregister.
     */
    @VisibleForTesting
    public void unregister(QualifiedName name) {
        procedures.unregister(name);
        functions.unregister(name);
        aggregationFunctions.unregister(name);
    }

    /**
     * Create an immutable copy of the ProcedureRegistry
     *
     * @param ref The source {@link ProcedureRegistry} to copy.
     *
     * @return an immutable copy of the source
     **/
    public static ProcedureRegistry copyOf(ProcedureRegistry ref) {
        return new ProcedureRegistry(
                ProcedureHolder.copyOf(ref.procedures),
                ProcedureHolder.copyOf(ref.functions),
                ProcedureHolder.copyOf(ref.aggregationFunctions));
    }

    /**
     * Create a tombstoned copy of the ProcedureRegistry
     *
     * @param ref The source {@link ProcedureRegistry} to tombstone and copy.
     * @param which Which QualifiedNames should be filtered.
     *
     * @return a tombstoned copy.
     **/
    public static ProcedureRegistry tombstone(ProcedureRegistry ref, Predicate<QualifiedName> which) {
        return new ProcedureRegistry(
                ProcedureHolder.tombstone(ref.procedures, which),
                ProcedureHolder.tombstone(ref.functions, which),
                ProcedureHolder.tombstone(ref.aggregationFunctions, which));
    }

    /**
     * Collect the ids of every entry in {@code holder} whose value is accepted by {@code predicate}.
     */
    private static <T> int[] getIdsOf(ProcedureHolder<T> holder, Predicate<T> predicate) {
        var lst = new IntArrayList();
        holder.forEach((i, v) -> {
            if (predicate.test(v)) {
                lst.add(i);
            }
        });
        return lst.toArray();
    }

    /**
     * Map every entry in {@code holder} through {@code transform} and stream the results accepted by
     * {@code condition}. The results are collected eagerly before the stream is returned.
     */
    private static <T, R> Stream<R> stream(
            ProcedureHolder<T> holder, Function<T, R> transform, Predicate<R> condition) {
        Stream.Builder<R> builder = Stream.builder();
        holder.forEach((id, callable) -> {
            var value = transform.apply(callable);
            if (condition.test(value)) {
                builder.add(value);
            }
        });
        return builder.build();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy