
org.elasticsearch.tasks.TaskCancellationService Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.tasks;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.NodeNotConnectedException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
private final ResultDeduplicator deduplicator;
public TaskCancellationService(TransportService transportService) {
this.transportService = transportService;
this.taskManager = transportService.getTaskManager();
this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
transportService.registerRequestHandler(
BAN_PARENT_ACTION_NAME,
ThreadPool.Names.SAME,
BanParentTaskRequest::new,
new BanParentRequestHandler()
);
}
private String localNodeId() {
return transportService.getLocalNode().getId();
}
private static class CancelRequest {
final CancellableTask task;
final boolean waitForCompletion;
CancelRequest(CancellableTask task, boolean waitForCompletion) {
this.task = task;
this.waitForCompletion = waitForCompletion;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final CancelRequest that = (CancelRequest) o;
return waitForCompletion == that.waitForCompletion && Objects.equals(task, that.task);
}
@Override
public int hashCode() {
return Objects.hash(task, waitForCompletion);
}
}
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener finalListener) {
deduplicator.executeOnce(
new CancelRequest(task, waitForCompletion),
finalListener,
(r, listener) -> doCancelTaskAndDescendants(task, reason, waitForCompletion, listener)
);
}
void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {
final TaskId taskId = task.taskInfo(localNodeId(), false).taskId();
if (task.shouldCancelChildrenOnCancellation()) {
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener completedListener = new StepListener<>();
GroupedActionListener groupedListener = new GroupedActionListener<>(completedListener.map(r -> null), 3);
Collection childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
});
taskManager.cancel(task, reason, () -> {
logger.trace("task [{}] is cancelled", taskId);
groupedListener.onResponse(null);
});
StepListener setBanListener = new StepListener<>();
setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
setBanListener.addListener(groupedListener);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool()
.getThreadContext()
.preserveContext(() -> removeBanOnChildConnections(task, childConnections));
// We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis.
completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run());
// if wait_for_completion is true, then only return when (1) bans are placed on child connections, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child connections.
if (waitForCompletion) {
completedListener.addListener(listener);
} else {
setBanListener.addListener(listener);
}
} else {
logger.trace("task [{}] doesn't have any children that should be cancelled", taskId);
if (waitForCompletion) {
taskManager.cancel(task, reason, () -> listener.onResponse(null));
} else {
taskManager.cancel(task, reason, () -> {});
listener.onResponse(null);
}
}
}
private void setBanOnChildConnections(
String reason,
boolean waitForCompletion,
CancellableTask task,
Collection childConnections,
ActionListener listener
) {
if (childConnections.isEmpty()) {
listener.onResponse(null);
return;
}
final TaskId taskId = new TaskId(localNodeId(), task.getId());
logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
GroupedActionListener groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
transportService.sendRequest(
connection,
BAN_PARENT_ACTION_NAME,
banRequest,
TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onResponse(null);
}
@Override
public void handleException(TransportException exp) {
final Throwable cause = ExceptionsHelper.unwrapCause(exp);
assert cause instanceof ElasticsearchSecurityException == false;
if (isUnimportantBanFailure(cause)) {
logger.debug(
new ParameterizedMessage(
"cannot send ban for tasks with the parent [{}] on connection [{}]",
taskId,
connection
),
exp
);
} else if (logger.isDebugEnabled()) {
logger.warn(
new ParameterizedMessage(
"cannot send ban for tasks with the parent [{}] on connection [{}]",
taskId,
connection
),
exp
);
} else {
logger.warn(
"cannot send ban for tasks with the parent [{}] on connection [{}]: {}",
taskId,
connection,
exp.getMessage()
);
}
groupedListener.onFailure(exp);
}
}
);
}
}
private void removeBanOnChildConnections(CancellableTask task, Collection childConnections) {
final BanParentTaskRequest request = BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
transportService.sendRequest(
connection,
BAN_PARENT_ACTION_NAME,
request,
TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleException(TransportException exp) {
final Throwable cause = ExceptionsHelper.unwrapCause(exp);
assert cause instanceof ElasticsearchSecurityException == false;
if (isUnimportantBanFailure(cause)) {
logger.debug(
new ParameterizedMessage(
"failed to remove ban for tasks with the parent [{}] on connection [{}]",
request.parentTaskId,
connection
),
exp
);
} else if (logger.isDebugEnabled()) {
logger.warn(
new ParameterizedMessage(
"failed to remove ban for tasks with the parent [{}] on connection [{}]",
request.parentTaskId,
connection
),
exp
);
} else {
logger.warn(
"failed to remove ban for tasks with the parent [{}] on connection [{}]: {}",
request.parentTaskId,
connection,
exp.getMessage()
);
}
}
}
);
}
}
private static boolean isUnimportantBanFailure(Throwable cause) {
return cause instanceof NodeDisconnectedException || cause instanceof NodeNotConnectedException;
}
private static class BanParentTaskRequest extends TransportRequest {
private final TaskId parentTaskId;
private final boolean ban;
private final boolean waitForCompletion;
private final String reason;
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
}
static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId);
}
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
this.parentTaskId = parentTaskId;
this.ban = true;
this.reason = reason;
this.waitForCompletion = waitForCompletion;
}
private BanParentTaskRequest(TaskId parentTaskId) {
this.parentTaskId = parentTaskId;
this.ban = false;
this.reason = null;
this.waitForCompletion = false;
}
private BanParentTaskRequest(StreamInput in) throws IOException {
super(in);
parentTaskId = TaskId.readFromStream(in);
ban = in.readBoolean();
reason = ban ? in.readString() : null;
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
waitForCompletion = in.readBoolean();
} else {
waitForCompletion = false;
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
parentTaskId.writeTo(out);
out.writeBoolean(ban);
if (ban) {
out.writeString(reason);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeBoolean(waitForCompletion);
}
}
}
private class BanParentRequestHandler implements TransportRequestHandler {
@Override
public void messageReceived(final BanParentTaskRequest request, final TransportChannel channel, Task task) throws Exception {
if (request.ban) {
logger.debug(
"Received ban for the parent [{}] on the node [{}], reason: [{}]",
request.parentTaskId,
localNodeId(),
request.reason
);
final List childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
final GroupedActionListener listener = new GroupedActionListener<>(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1
);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
listener.onResponse(null);
} else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, localNodeId());
taskManager.removeBan(request.parentTaskId);
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy