All downloads are free. The search and download features use the official Maven repository.

org.elasticsearch.tasks.TaskCancellationService Maven / Gradle / Ivy

There is a newer version: 8.14.0
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.tasks;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.NodeNotConnectedException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

public class TaskCancellationService {
    public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
    private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
    private final TransportService transportService;
    private final TaskManager taskManager;
    private final ResultDeduplicator deduplicator;

    public TaskCancellationService(TransportService transportService) {
        this.transportService = transportService;
        this.taskManager = transportService.getTaskManager();
        this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
        transportService.registerRequestHandler(
            BAN_PARENT_ACTION_NAME,
            ThreadPool.Names.SAME,
            BanParentTaskRequest::new,
            new BanParentRequestHandler()
        );
    }

    private String localNodeId() {
        return transportService.getLocalNode().getId();
    }

    private static class CancelRequest {
        final CancellableTask task;
        final boolean waitForCompletion;

        CancelRequest(CancellableTask task, boolean waitForCompletion) {
            this.task = task;
            this.waitForCompletion = waitForCompletion;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            final CancelRequest that = (CancelRequest) o;
            return waitForCompletion == that.waitForCompletion && Objects.equals(task, that.task);
        }

        @Override
        public int hashCode() {
            return Objects.hash(task, waitForCompletion);
        }
    }

    void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener finalListener) {
        deduplicator.executeOnce(
            new CancelRequest(task, waitForCompletion),
            finalListener,
            (r, listener) -> doCancelTaskAndDescendants(task, reason, waitForCompletion, listener)
        );
    }

    void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {
        final TaskId taskId = task.taskInfo(localNodeId(), false).taskId();
        if (task.shouldCancelChildrenOnCancellation()) {
            logger.trace("cancelling task [{}] and its descendants", taskId);
            StepListener completedListener = new StepListener<>();
            GroupedActionListener groupedListener = new GroupedActionListener<>(completedListener.map(r -> null), 3);
            Collection childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> {
                logger.trace("child tasks of parent [{}] are completed", taskId);
                groupedListener.onResponse(null);
            });
            taskManager.cancel(task, reason, () -> {
                logger.trace("task [{}] is cancelled", taskId);
                groupedListener.onResponse(null);
            });
            StepListener setBanListener = new StepListener<>();
            setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
            setBanListener.addListener(groupedListener);
            // If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
            // requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
            final Runnable removeBansRunnable = transportService.getThreadPool()
                .getThreadContext()
                .preserveContext(() -> removeBanOnChildConnections(task, childConnections));
            // We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis.
            completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run());
            // if wait_for_completion is true, then only return when (1) bans are placed on child connections, (2) child tasks are
            // completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child connections.
            if (waitForCompletion) {
                completedListener.addListener(listener);
            } else {
                setBanListener.addListener(listener);
            }
        } else {
            logger.trace("task [{}] doesn't have any children that should be cancelled", taskId);
            if (waitForCompletion) {
                taskManager.cancel(task, reason, () -> listener.onResponse(null));
            } else {
                taskManager.cancel(task, reason, () -> {});
                listener.onResponse(null);
            }
        }
    }

    private void setBanOnChildConnections(
        String reason,
        boolean waitForCompletion,
        CancellableTask task,
        Collection childConnections,
        ActionListener listener
    ) {
        if (childConnections.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        final TaskId taskId = new TaskId(localNodeId(), task.getId());
        logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
        GroupedActionListener groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
        final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
        for (Transport.Connection connection : childConnections) {
            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
            transportService.sendRequest(
                connection,
                BAN_PARENT_ACTION_NAME,
                banRequest,
                TransportRequestOptions.EMPTY,
                new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
                    @Override
                    public void handleResponse(TransportResponse.Empty response) {
                        logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
                        groupedListener.onResponse(null);
                    }

                    @Override
                    public void handleException(TransportException exp) {
                        final Throwable cause = ExceptionsHelper.unwrapCause(exp);
                        assert cause instanceof ElasticsearchSecurityException == false;
                        if (isUnimportantBanFailure(cause)) {
                            logger.debug(
                                new ParameterizedMessage(
                                    "cannot send ban for tasks with the parent [{}] on connection [{}]",
                                    taskId,
                                    connection
                                ),
                                exp
                            );
                        } else if (logger.isDebugEnabled()) {
                            logger.warn(
                                new ParameterizedMessage(
                                    "cannot send ban for tasks with the parent [{}] on connection [{}]",
                                    taskId,
                                    connection
                                ),
                                exp
                            );
                        } else {
                            logger.warn(
                                "cannot send ban for tasks with the parent [{}] on connection [{}]: {}",
                                taskId,
                                connection,
                                exp.getMessage()
                            );
                        }

                        groupedListener.onFailure(exp);
                    }
                }
            );
        }
    }

    private void removeBanOnChildConnections(CancellableTask task, Collection childConnections) {
        final BanParentTaskRequest request = BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
        for (Transport.Connection connection : childConnections) {
            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
            logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
            transportService.sendRequest(
                connection,
                BAN_PARENT_ACTION_NAME,
                request,
                TransportRequestOptions.EMPTY,
                new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
                    @Override
                    public void handleException(TransportException exp) {
                        final Throwable cause = ExceptionsHelper.unwrapCause(exp);
                        assert cause instanceof ElasticsearchSecurityException == false;
                        if (isUnimportantBanFailure(cause)) {
                            logger.debug(
                                new ParameterizedMessage(
                                    "failed to remove ban for tasks with the parent [{}] on connection [{}]",
                                    request.parentTaskId,
                                    connection
                                ),
                                exp
                            );
                        } else if (logger.isDebugEnabled()) {
                            logger.warn(
                                new ParameterizedMessage(
                                    "failed to remove ban for tasks with the parent [{}] on connection [{}]",
                                    request.parentTaskId,
                                    connection
                                ),
                                exp
                            );
                        } else {
                            logger.warn(
                                "failed to remove ban for tasks with the parent [{}] on connection [{}]: {}",
                                request.parentTaskId,
                                connection,
                                exp.getMessage()
                            );
                        }
                    }
                }
            );
        }
    }

    private static boolean isUnimportantBanFailure(Throwable cause) {
        return cause instanceof NodeDisconnectedException || cause instanceof NodeNotConnectedException;
    }

    private static class BanParentTaskRequest extends TransportRequest {

        private final TaskId parentTaskId;
        private final boolean ban;
        private final boolean waitForCompletion;
        private final String reason;

        static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
            return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
        }

        static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
            return new BanParentTaskRequest(parentTaskId);
        }

        private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
            this.parentTaskId = parentTaskId;
            this.ban = true;
            this.reason = reason;
            this.waitForCompletion = waitForCompletion;
        }

        private BanParentTaskRequest(TaskId parentTaskId) {
            this.parentTaskId = parentTaskId;
            this.ban = false;
            this.reason = null;
            this.waitForCompletion = false;
        }

        private BanParentTaskRequest(StreamInput in) throws IOException {
            super(in);
            parentTaskId = TaskId.readFromStream(in);
            ban = in.readBoolean();
            reason = ban ? in.readString() : null;
            if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
                waitForCompletion = in.readBoolean();
            } else {
                waitForCompletion = false;
            }
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            parentTaskId.writeTo(out);
            out.writeBoolean(ban);
            if (ban) {
                out.writeString(reason);
            }
            if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
                out.writeBoolean(waitForCompletion);
            }
        }
    }

    private class BanParentRequestHandler implements TransportRequestHandler {
        @Override
        public void messageReceived(final BanParentTaskRequest request, final TransportChannel channel, Task task) throws Exception {
            if (request.ban) {
                logger.debug(
                    "Received ban for the parent [{}] on the node [{}], reason: [{}]",
                    request.parentTaskId,
                    localNodeId(),
                    request.reason
                );
                final List childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
                final GroupedActionListener listener = new GroupedActionListener<>(
                    new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE),
                    childTasks.size() + 1
                );
                for (CancellableTask childTask : childTasks) {
                    cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
                }
                listener.onResponse(null);
            } else {
                logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, localNodeId());
                taskManager.removeBan(request.parentTaskId);
                channel.sendResponse(TransportResponse.Empty.INSTANCE);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy