All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.fabric.executor.CallInTransactionsExecutor Maven / Gradle / Ivy

There is a newer version: 5.25.1
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see .
 */
package org.neo4j.fabric.executor;

import static scala.jdk.javaapi.CollectionConverters.asJava;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import org.neo4j.bolt.protocol.common.message.AccessMode;
import org.neo4j.cypher.internal.FullyParsedQuery;
import org.neo4j.cypher.internal.ast.SubqueryCall;
import org.neo4j.cypher.internal.expressions.ExplicitParameter;
import org.neo4j.cypher.internal.expressions.SignedDecimalIntegerLiteral;
import org.neo4j.cypher.internal.logical.plans.TransactionForeach$;
import org.neo4j.exceptions.ParameterNotFoundException;
import org.neo4j.fabric.eval.Catalog;
import org.neo4j.fabric.eval.UseEvaluation;
import org.neo4j.fabric.planning.FabricPlan;
import org.neo4j.fabric.planning.FabricPlanner;
import org.neo4j.fabric.planning.Fragment;
import org.neo4j.fabric.stream.Prefetcher;
import org.neo4j.fabric.stream.Record;
import org.neo4j.fabric.stream.Records;
import org.neo4j.fabric.stream.StatementResult;
import org.neo4j.fabric.stream.summary.MergedQueryStatistics;
import org.neo4j.fabric.transaction.FabricTransaction;
import org.neo4j.fabric.transaction.TransactionMode;
import org.neo4j.graphdb.GqlStatusObject;
import org.neo4j.graphdb.Notification;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.impl.query.QueryRoutingMonitor;
import org.neo4j.values.AnyValue;
import org.neo4j.values.storable.BooleanValue;
import org.neo4j.values.storable.IntegralValue;
import org.neo4j.values.storable.LongValue;
import org.neo4j.values.storable.NoValue;
import org.neo4j.values.storable.Values;
import org.neo4j.values.virtual.ListValueBuilder;
import org.neo4j.values.virtual.MapValue;
import org.neo4j.values.virtual.MapValueBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

class CallInTransactionsExecutor extends SingleQueryFragmentExecutor {

    private final Fragment.Apply callInTransactions;
    private final Fragment.Exec innerFragment;
    private final int batchSize;
    private final List inputRowsBuffer;
    private Catalog.Graph batchGraph;
    private TransactionMode batchTransactionMode;
    private OnErrorBreakContext onErrorBreakContext;

    CallInTransactionsExecutor(
            Fragment.Apply callInTransactions,
            FabricPlanner.PlannerInstance plannerInstance,
            Executor fabricWorkerExecutor,
            FabricTransaction.FabricExecutionContext ctx,
            UseEvaluation.Instance useEvaluator,
            FabricPlan plan,
            MapValue queryParams,
            AccessMode accessMode,
            Set notifications,
            Set gqlStatusObjects,
            QueryStatementLifecycles.StatementLifecycle lifecycle,
            Prefetcher prefetcher,
            QueryRoutingMonitor queryRoutingMonitor,
            MergedQueryStatistics statistics,
            Tracer tracer,
            FragmentExecutor fragmentExecutor) {
        super(
                plannerInstance,
                fabricWorkerExecutor,
                ctx,
                useEvaluator,
                plan,
                queryParams,
                accessMode,
                notifications,
                gqlStatusObjects,
                lifecycle,
                prefetcher,
                queryRoutingMonitor,
                statistics,
                tracer,
                fragmentExecutor);
        this.callInTransactions = callInTransactions;
        this.innerFragment = (Fragment.Exec) callInTransactions.inner();
        this.batchSize = batchSize();
        inputRowsBuffer = new ArrayList<>(batchSize);
        this.onErrorBreakContext = onErrorBreakContext();
    }

    private OnErrorBreakContext onErrorBreakContext() {
        var parameters = callInTransactions.inTransactionsParameters().get();
        if (!CallInTransactionsExecutorUtil.isOnErrorBreak(parameters)) {
            return null;
        }

        int variableOffset = extractBreakReportVariableOffset(parameters);
        return new OnErrorBreakContext(variableOffset, parameters.reportParams().isEmpty(), false);
    }

    private int extractBreakReportVariableOffset(SubqueryCall.InTransactionsParameters parameters) {
        var variableName = parameters
                .reportParams()
                .map(reportParameters -> reportParameters.reportAs().name())
                .getOrElse(Fragment.Apply$.MODULE$::REPORT_VARIABLE);

        List columns = asJava(innerFragment.outputColumns());
        for (int i = 0; i < columns.size(); i++) {
            if (columns.get(i).equals(variableName)) {
                return i;
            }
        }

        throw new IllegalStateException("Report variable not found among columns: " + columns);
    }

    Flux run(Record argument) {
        Flux input =
                fragmentExecutor().run(callInTransactions.input(), argument).records();
        Flux resultPipe = input.flatMap(this::processInputRecord, 1, 1);
        return Flux.concat(resultPipe, Flux.defer(this::processBufferedInputRows));
    }

    private int batchSize() {
        return callInTransactions
                .inTransactionsParameters()
                .flatMap(SubqueryCall.InTransactionsParameters::batchParams)
                .map(SubqueryCall.InTransactionsBatchParameters::batchSize)
                .map(expression -> {
                    if (expression instanceof SignedDecimalIntegerLiteral literal) {
                        return literal.value().intValue();
                    }

                    if (expression instanceof ExplicitParameter parameter) {
                        return batchSizeFromParam(parameter);
                    }

                    throw new IllegalArgumentException("Unexpected batch size expression: " + expression);
                })
                .getOrElse(() -> (int) TransactionForeach$.MODULE$.defaultBatchSize());
    }

    private int batchSizeFromParam(ExplicitParameter parameter) {
        AnyValue paramValue = queryParams().get(parameter.name());
        if (paramValue instanceof LongValue longValue) {
            return (int) longValue.value();
        }

        if (paramValue instanceof IntegralValue integralValue) {
            return integralValue.intValue();
        }

        if (paramValue instanceof NoValue) {
            throw new ParameterNotFoundException("Expected parameter(s): " + parameter.name());
        }

        // Parameter types are checked by semantic analysis, so this is here
        // only in case a wrong param can somehow sneak past the semantic analysis.
        throw new FabricException(
                // The semantic analysis check uses this status code,
                // so let's do the same even thought SyntaxError is a bit weird
                // for this case.
                Status.Statement.SyntaxError,
                "Type mismatch for parameter '%s': expected Integer but was %s"
                        .formatted(parameter.name(), paramValue.getTypeName()));
    }

    private Flux processInputRecord(Record argument) {
        if (onErrorBreakContext != null && onErrorBreakContext.breakExecution) {
            return produceBreakOutput(argument);
        }

        PrepareResult prepareResult = prepare(innerFragment, argument);

        if (batchGraph == null) {
            batchGraph = prepareResult.graph();
            batchTransactionMode = prepareResult.transactionMode();
        }

        if (!batchGraph.equals(prepareResult.graph())) {
            Flux result = processBufferedInputRows();
            batchGraph = prepareResult.graph();
            batchTransactionMode = prepareResult.transactionMode();
            inputRowsBuffer.add(new BufferedInputRow(prepareResult.argumentValues(), argument));
            return result;
        }

        inputRowsBuffer.add(new BufferedInputRow(prepareResult.argumentValues(), argument));
        if (inputRowsBuffer.size() == batchSize) {
            return processBufferedInputRows();
        }

        return Flux.empty();
    }

    private Flux produceBreakOutput(Record argument) {
        List columns = asJava(innerFragment.outputColumns());
        int columnCount = columns.size() - addedColumnsCount();
        List values = new ArrayList<>(columnCount);
        for (int i = 0; i < columnCount; i++) {
            if (i == onErrorBreakContext.reportVariableOffset) {
                MapValueBuilder builder = new MapValueBuilder(4);
                builder.add("started", BooleanValue.FALSE);
                builder.add("committed", BooleanValue.FALSE);
                builder.add("transactionId", NoValue.NO_VALUE);
                builder.add("errorMessage", NoValue.NO_VALUE);
                values.add(builder.build());
            } else {
                values.add(NoValue.NO_VALUE);
            }
        }

        return Flux.just(Records.join(argument, Records.of(values)));
    }

    private Flux processBufferedInputRows() {
        if (inputRowsBuffer.isEmpty()) {
            return Flux.empty();
        }

        MapValue params = addParamsFromInputRows();

        var result = doExecuteFragment(
                innerFragment,
                params,
                batchGraph,
                batchTransactionMode,
                // Inner part of CALL IN TRANSACTIONS does not have a child plan node
                // Unlike standard query execution with which most logic is shared
                () -> new FragmentResult(Flux.just(Records.empty()), Mono.empty(), Mono.empty()));
        var inputRecords = new ArrayList<>(inputRowsBuffer);
        Flux resultStream = result.records();
        if (onErrorBreakContext != null) {
            resultStream = resultStream.map(this::checkBreakCondition);
        }

        if (callInTransactions.outputColumns().isEmpty()) {
            resultStream = resultStream.map(outputRecord -> getMatchingInputRecord(outputRecord, inputRecords));
        } else {
            resultStream = resultStream.map(outputRecord ->
                    Records.join(getMatchingInputRecord(outputRecord, inputRecords), stripAddedColumns(outputRecord)));
        }
        batchGraph = null;
        batchTransactionMode = null;
        inputRowsBuffer.clear();
        return resultStream;
    }

    private Record getMatchingInputRecord(Record outputRecord, List inputRecords) {
        // We are using the knowledge that Stitcher adds this always as the last column.
        var rowIdColumn = innerFragment.outputColumns().size() - 1;
        var rowId = (IntegralValue) outputRecord.getValue(rowIdColumn);
        var rowIdAsInt = rowId instanceof LongValue ? (int) ((LongValue) rowId).value() : rowId.intValue();
        return inputRecords.get(rowIdAsInt).record;
    }

    private Record stripAddedColumns(Record record) {
        // Stitcher adds 1-2 columns always as the last columns
        int columnCount = record.size() - addedColumnsCount();
        AnyValue[] values = new AnyValue[columnCount];
        // We are using the knowledge that Stitcher adds Row ID column always as the last column.
        for (int i = 0; i < columnCount; i++) {
            values[i] = record.getValue(i);
        }

        return Records.of(values);
    }

    private int addedColumnsCount() {
        int addedColumnsCount = 1;
        if (onErrorBreakContext != null && onErrorBreakContext.reportVariableAdded) {
            addedColumnsCount++;
        }

        return addedColumnsCount;
    }

    private Record checkBreakCondition(Record outputRecord) {
        var value = outputRecord.getValue(onErrorBreakContext.reportVariableOffset);
        var mapValue = (MapValue) value;
        if (mapValue.get("errorMessage") != NoValue.NO_VALUE) {
            onErrorBreakContext = new OnErrorBreakContext(
                    onErrorBreakContext.reportVariableOffset, onErrorBreakContext.reportVariableAdded, true);
        }

        return outputRecord;
    }

    private MapValue addParamsFromInputRows() {
        List bindings = asJava(innerFragment.argumentColumns());

        var rowListBuilder = ListValueBuilder.newListBuilder(inputRowsBuffer.size());
        for (int i = 0; i < inputRowsBuffer.size(); i++) {
            MapValue rowParams = rowToParams(inputRowsBuffer.get(i), bindings, i);
            rowListBuilder.add(rowParams);
        }
        var rows = rowListBuilder.build();

        MapValueBuilder builder = new MapValueBuilder(queryParams().size() + 1);
        queryParams().foreach(builder::add);
        builder.add(Fragment.Apply$.MODULE$.CALL_IN_TX_ROWS(), rows);
        return builder.build();
    }

    private MapValue rowToParams(BufferedInputRow inputRow, List bindings, int rowId) {
        MapValueBuilder builder = new MapValueBuilder(bindings.size() + 1);
        bindings.forEach(
                var -> builder.add(var, validateValue(inputRow.argumentValues().get(var))));
        builder.add(Fragment.Apply$.MODULE$.CALL_IN_TX_ROW_ID(), Values.intValue(rowId));
        return builder.build();
    }

    @Override
    Mono runRemote(
            Location.Remote location,
            ExecutionOptions options,
            String query,
            TransactionMode transactionMode,
            MapValue params) {
        return Mono.just(
                ctx().getRemote().runInAutocommitTransaction(location, options, query, transactionMode, params));
    }

    @Override
    StatementResult runLocal(
            Location.Local location,
            TransactionMode transactionMode,
            QueryStatementLifecycles.StatementLifecycle parentLifecycle,
            FullyParsedQuery query,
            MapValue params,
            Flux input,
            ExecutionOptions executionOptions,
            Boolean targetsComposite) {
        return ctx().getLocal()
                .runInAutocommitTransaction(location, parentLifecycle, query, params, input, executionOptions);
    }

    private record BufferedInputRow(Map argumentValues, Record record) {}

    private record OnErrorBreakContext(int reportVariableOffset, boolean reportVariableAdded, boolean breakExecution) {}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy