All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.tasks.TaskCancellationService Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.tasks;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.NodeNotConnectedException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;

import static org.elasticsearch.core.Strings.format;

/**
 * Cancels a task together with its descendants.
 * <p>
 * On construction the service registers two transport request handlers: {@link #BAN_PARENT_ACTION_NAME}, which sets or removes a
 * ban for the children of a parent task, and {@link #CANCEL_CHILD_ACTION_NAME}, which cancels a single child task identified by
 * the parent task and the child request ID.
 */
public class TaskCancellationService {
    /** Name of the transport action that sets or removes a ban for the children of a parent task. */
    public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
    /** Name of the transport action that cancels a single child task of a parent task. */
    public static final String CANCEL_CHILD_ACTION_NAME = "internal:admin/tasks/cancel_child";
    /** Minimum transport version of a connection for {@link #cancelChildRemote} to send a cancel-child request over it. */
    public static final TransportVersion VERSION_SUPPORTING_CANCEL_CHILD_ACTION = TransportVersion.V_8_8_0;
    private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
    private final TransportService transportService;
    private final TaskManager taskManager;
    /** Routes cancellation requests through a single execution per {@link CancelRequest} key. */
    private final ResultDeduplicator<CancelRequest, Void> deduplicator;

    /**
     * Creates the service and registers the ban-parent and cancel-child request handlers on the given transport service.
     * Both handlers are registered with {@link ThreadPool.Names#SAME}.
     *
     * @param transportService transport service used to send and receive cancellation requests; also supplies the task manager
     *                         and the thread context used by this service
     */
    public TaskCancellationService(TransportService transportService) {
        this.transportService = transportService;
        this.taskManager = transportService.getTaskManager();
        this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
        transportService.registerRequestHandler(
            BAN_PARENT_ACTION_NAME,
            ThreadPool.Names.SAME,
            BanParentTaskRequest::new,
            new BanParentRequestHandler()
        );
        transportService.registerRequestHandler(
            CANCEL_CHILD_ACTION_NAME,
            ThreadPool.Names.SAME,
            CancelChildRequest::new,
            new CancelChildRequestHandler()
        );
    }

    /** Returns the ID of the local node as reported by the transport service. */
    private String localNodeId() {
        return transportService.getLocalNode().getId();
    }

    /**
     * Key used by the deduplicator. Two keys are equal when they refer to equal tasks and carry the same
     * {@code waitForCompletion} flag; the cancellation reason is deliberately not part of the key.
     */
    private static class CancelRequest {
        final CancellableTask task;
        final boolean waitForCompletion;

        CancelRequest(CancellableTask task, boolean waitForCompletion) {
            this.task = task;
            this.waitForCompletion = waitForCompletion;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            final CancelRequest that = (CancelRequest) o;
            return waitForCompletion == that.waitForCompletion && Objects.equals(task, that.task);
        }

        @Override
        public int hashCode() {
            return Objects.hash(task, waitForCompletion);
        }
    }

    /**
     * Cancels the given task and its descendants, going through the deduplicator keyed by the task and the
     * {@code waitForCompletion} flag.
     *
     * @param task              the task to cancel
     * @param reason            the reason passed on to the task manager and to ban requests
     * @param waitForCompletion whether the listener should only be completed after the cancellation has fully completed
     * @param finalListener     listener notified with the outcome of the cancellation
     */
    void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> finalListener) {
        deduplicator.executeOnce(
            new CancelRequest(task, waitForCompletion),
            finalListener,
            (r, listener) -> doCancelTaskAndDescendants(task, reason, waitForCompletion, listener)
        );
    }

    /**
     * Performs the actual cancellation of the given task and, if the task asks for it, of its descendants.
     * <p>
     * When the task cancels its children, bans are sent over every child connection and are removed again once the whole
     * cancellation has completed. Otherwise only the task itself is cancelled.
     *
     * @param task              the task to cancel
     * @param reason            the reason passed on to the task manager and to ban requests
     * @param waitForCompletion if {@code true}, the listener is completed only after everything has completed; if {@code false},
     *                          it is completed as soon as the bans have been placed (or immediately for a task without children
     *                          to cancel)
     * @param listener          listener notified according to {@code waitForCompletion}
     */
    void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
        final TaskId taskId = task.taskInfo(localNodeId(), false).taskId();
        if (task.shouldCancelChildrenOnCancellation()) {
            logger.trace("cancelling task [{}] and its descendants", taskId);
            ListenableFuture<Void> completedListener = new ListenableFuture<>();
            ListenableFuture<Void> setBanListener = new ListenableFuture<>();

            Collection<Transport.Connection> childConnections;
            // The ref-counted action chains completedListener onto setBanListener; it is meant to run only after the child-task
            // reference, the main-task reference and the try block's own reference have all been released.
            try (var refs = new RefCountingRunnable(() -> setBanListener.addListener(completedListener))) {
                var banChildrenRef = refs.acquire();
                var cancelTaskRef = refs.acquire();

                childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> {
                    logger.trace("child tasks of parent [{}] are completed", taskId);
                    banChildrenRef.close();
                });

                taskManager.cancel(task, reason, () -> {
                    logger.trace("task [{}] is cancelled", taskId);
                    cancelTaskRef.close();
                });
            }
            setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);

            // We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis.
            completedListener.addListener(
                ActionListener.running(
                    transportService.getThreadPool()
                        .getThreadContext()
                        // If we start unbanning when the last child task completed and that child task executed with a specific user, then
                        // unban requests are denied because internal requests can't run with a user. We need to remove bans with the
                        // current thread context.
                        .preserveContext(() -> removeBanOnChildConnections(task, childConnections))
                )
            );

            if (waitForCompletion) {
                // Wait until (1) bans are placed on child connections, (2) child tasks are completed or failed, (3) main task is cancelled.
                completedListener.addListener(listener);
            } else {
                // Only wait until bans are placed on child connections
                setBanListener.addListener(listener);
            }
        } else {
            logger.trace("task [{}] doesn't have any children that should be cancelled", taskId);
            if (waitForCompletion) {
                taskManager.cancel(task, reason, () -> listener.onResponse(null));
            } else {
                // Fire-and-forget: start the cancellation and complete the listener right away.
                taskManager.cancel(task, reason, () -> {});
                listener.onResponse(null);
            }
        }
    }

    /**
     * Sends a set-ban request for the given parent task over every child connection.
     * <p>
     * The listener is completed immediately when there are no child connections. Otherwise it is driven by a
     * {@link CountDownActionListener} sized to the number of connections, which receives one response or failure per connection.
     *
     * @param reason            the cancellation reason carried by the ban request
     * @param waitForCompletion flag carried by the ban request
     * @param task              the parent task whose children are banned
     * @param childConnections  connections over which the ban request is sent; each must already be unwrapped
     * @param listener          listener notified once all ban requests have been answered
     */
    private void setBanOnChildConnections(
        String reason,
        boolean waitForCompletion,
        CancellableTask task,
        Collection<Transport.Connection> childConnections,
        ActionListener<Void> listener
    ) {
        if (childConnections.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        final TaskId taskId = new TaskId(localNodeId(), task.getId());
        logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
        CountDownActionListener countDownListener = new CountDownActionListener(childConnections.size(), listener);
        // The same request instance is sent over every connection.
        final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
        for (Transport.Connection connection : childConnections) {
            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
            transportService.sendRequest(
                connection,
                BAN_PARENT_ACTION_NAME,
                banRequest,
                TransportRequestOptions.EMPTY,
                new TransportResponseHandler.Empty() {
                    @Override
                    public Executor executor(ThreadPool threadPool) {
                        return TransportResponseHandler.TRANSPORT_WORKER;
                    }

                    @Override
                    public void handleResponse() {
                        logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
                        countDownListener.onResponse(null);
                    }

                    @Override
                    public void handleException(TransportException exp) {
                        final Throwable cause = ExceptionsHelper.unwrapCause(exp);
                        assert cause instanceof ElasticsearchSecurityException == false;
                        // Disconnect-style failures go to debug; anything else is a warning, with the stack trace only when
                        // debug logging is enabled.
                        if (isUnimportantBanFailure(cause)) {
                            logger.debug(
                                () -> format("cannot send ban for tasks with the parent [%s] on connection [%s]", taskId, connection),
                                exp
                            );
                        } else if (logger.isDebugEnabled()) {
                            logger.warn(
                                () -> format("cannot send ban for tasks with the parent [%s] on connection [%s]", taskId, connection),
                                exp
                            );
                        } else {
                            logger.warn(
                                "cannot send ban for tasks with the parent [{}] on connection [{}]: {}",
                                taskId,
                                connection,
                                exp.getMessage()
                            );
                        }

                        countDownListener.onFailure(exp);
                    }
                }
            );
        }
    }

    /**
     * Sends a remove-ban request for the given parent task over every child connection. Responses are ignored and failures
     * are only logged.
     *
     * @param task             the parent task whose ban is removed
     * @param childConnections connections over which the remove-ban request is sent; each must already be unwrapped
     */
    private void removeBanOnChildConnections(CancellableTask task, Collection<Transport.Connection> childConnections) {
        final BanParentTaskRequest request = BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
        for (Transport.Connection connection : childConnections) {
            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
            logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
            transportService.sendRequest(
                connection,
                BAN_PARENT_ACTION_NAME,
                request,
                TransportRequestOptions.EMPTY,
                new TransportResponseHandler.Empty() {
                    @Override
                    public Executor executor(ThreadPool threadPool) {
                        return TransportResponseHandler.TRANSPORT_WORKER;
                    }

                    @Override
                    public void handleResponse() {}

                    @Override
                    public void handleException(TransportException exp) {
                        final Throwable cause = ExceptionsHelper.unwrapCause(exp);
                        assert cause instanceof ElasticsearchSecurityException == false;
                        // Same logging policy as for set-ban failures: debug for disconnects, warn otherwise.
                        if (isUnimportantBanFailure(cause)) {
                            logger.debug(
                                () -> format(
                                    "failed to remove ban for tasks with the parent [%s] on connection [%s]",
                                    request.parentTaskId,
                                    connection
                                ),
                                exp
                            );
                        } else if (logger.isDebugEnabled()) {
                            logger.warn(
                                () -> format(
                                    "failed to remove ban for tasks with the parent [%s] on connection [%s]",
                                    request.parentTaskId,
                                    connection
                                ),
                                exp
                            );
                        } else {
                            logger.warn(
                                "failed to remove ban for tasks with the parent [{}] on connection [{}]: {}",
                                request.parentTaskId,
                                connection,
                                exp.getMessage()
                            );
                        }
                    }
                }
            );
        }
    }

    /** Returns {@code true} if the failure was caused by the node being disconnected or not connected. */
    private static boolean isUnimportantBanFailure(Throwable cause) {
        return cause instanceof NodeDisconnectedException || cause instanceof NodeNotConnectedException;
    }

    /**
     * Transport request that either sets ({@code ban == true}) or removes ({@code ban == false}) a ban for the children of a
     * parent task. The reason is only present, and only serialized, for set-ban requests.
     */
    private static class BanParentTaskRequest extends TransportRequest {

        private final TaskId parentTaskId;
        private final boolean ban;
        private final boolean waitForCompletion;
        private final String reason;

        /** Creates a request that sets a ban for the children of the given parent task. */
        static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
            return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
        }

        /** Creates a request that removes the ban for the children of the given parent task. */
        static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
            return new BanParentTaskRequest(parentTaskId);
        }

        private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
            this.parentTaskId = parentTaskId;
            this.ban = true;
            this.reason = reason;
            this.waitForCompletion = waitForCompletion;
        }

        private BanParentTaskRequest(TaskId parentTaskId) {
            this.parentTaskId = parentTaskId;
            this.ban = false;
            this.reason = null;
            this.waitForCompletion = false;
        }

        /** Reads the request from the stream; mirrors {@link #writeTo(StreamOutput)}. */
        private BanParentTaskRequest(StreamInput in) throws IOException {
            super(in);
            parentTaskId = TaskId.readFromStream(in);
            ban = in.readBoolean();
            reason = ban ? in.readString() : null;
            // The waitForCompletion flag is only on the wire from 7.8.0 onwards; older senders imply false.
            if (in.getTransportVersion().onOrAfter(TransportVersion.V_7_8_0)) {
                waitForCompletion = in.readBoolean();
            } else {
                waitForCompletion = false;
            }
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            parentTaskId.writeTo(out);
            out.writeBoolean(ban);
            if (ban) {
                out.writeString(reason);
            }
            if (out.getTransportVersion().onOrAfter(TransportVersion.V_7_8_0)) {
                out.writeBoolean(waitForCompletion);
            }
        }
    }

    /**
     * Handles {@link BanParentTaskRequest}s: a set-ban request registers the ban and cancels the local child tasks and their
     * descendants; a remove-ban request removes the ban and responds immediately.
     */
    private class BanParentRequestHandler implements TransportRequestHandler<BanParentTaskRequest> {
        @Override
        public void messageReceived(final BanParentTaskRequest request, final TransportChannel channel, Task task) throws Exception {
            if (request.ban) {
                logger.debug(
                    "Received ban for the parent [{}] on the node [{}], reason: [{}]",
                    request.parentTaskId,
                    localNodeId(),
                    request.reason
                );
                final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
                // One slot per child task plus one extra slot that is completed right after the loop, so the group can also
                // complete when there are no child tasks.
                final GroupedActionListener<Void> listener = new GroupedActionListener<>(
                    childTasks.size() + 1,
                    new ChannelActionListener<>(channel).map(r -> TransportResponse.Empty.INSTANCE)
                );
                for (CancellableTask childTask : childTasks) {
                    cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
                }
                listener.onResponse(null);
            } else {
                logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, localNodeId());
                taskManager.removeBan(request.parentTaskId);
                channel.sendResponse(TransportResponse.Empty.INSTANCE);
            }
        }
    }

    /**
     * Transport request asking the receiving node to cancel the child task that belongs to the given parent task and was
     * started by the request with the given ID.
     */
    private static class CancelChildRequest extends TransportRequest {

        private final TaskId parentTaskId;
        private final long childRequestId;
        private final String reason;

        /** Creates a cancel-child request for the given parent task, child request ID and reason. */
        static CancelChildRequest createCancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
            return new CancelChildRequest(parentTaskId, childRequestId, reason);
        }

        private CancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
            this.parentTaskId = parentTaskId;
            this.childRequestId = childRequestId;
            this.reason = reason;
        }

        /** Reads the request from the stream; mirrors {@link #writeTo(StreamOutput)}. */
        private CancelChildRequest(StreamInput in) throws IOException {
            super(in);
            parentTaskId = TaskId.readFromStream(in);
            childRequestId = in.readLong();
            reason = in.readString();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            parentTaskId.writeTo(out);
            out.writeLong(childRequestId);
            out.writeString(reason);
        }
    }

    /** Handles {@link CancelChildRequest}s by delegating to the task manager and then sending an empty response. */
    private class CancelChildRequestHandler implements TransportRequestHandler<CancelChildRequest> {
        @Override
        public void messageReceived(final CancelChildRequest request, final TransportChannel channel, Task task) throws Exception {
            taskManager.cancelChildLocal(request.parentTaskId, request.childRequestId, request.reason);
            channel.sendResponse(TransportResponse.Empty.INSTANCE);
        }
    }

    /** Response handler for cancel-child requests; it is built from a no-op listener, so the outcome is not acted upon. */
    private static final TransportResponseHandler.Empty NOOP_HANDLER = TransportResponseHandler.empty(
        TransportResponseHandler.TRANSPORT_WORKER,
        ActionListener.noop()
    );

    /**
     * Sends an action to cancel a child task, associated with the given request ID and parent task.
     * <p>
     * Nothing is sent when the transport version of the child connection is older than
     * {@link #VERSION_SUPPORTING_CANCEL_CHILD_ACTION}. The response to the request is ignored.
     *
     * @param parentTask      ID of the parent task whose child should be cancelled
     * @param childRequestId  ID of the request that started the child task
     * @param childConnection connection to the node running the child task; used for the version check and to resolve the node
     * @param reason          the cancellation reason carried by the request
     */
    public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) {
        if (childConnection.getTransportVersion().onOrAfter(VERSION_SUPPORTING_CANCEL_CHILD_ACTION)) {
            DiscoveryNode childNode = childConnection.getNode();
            logger.debug(
                "sending cancellation of child of parent task [{}] with request ID [{}] to node [{}] because of [{}]",
                parentTask,
                childRequestId,
                childNode,
                reason
            );
            final CancelChildRequest request = CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason);
            // NOTE(review): the request is addressed to the node rather than sent over childConnection itself — confirm intended.
            transportService.sendRequest(childNode, CANCEL_CHILD_ACTION_NAME, request, TransportRequestOptions.EMPTY, NOOP_HANDLER);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy