// org.elasticsearch.common.util.concurrent.ThreadContext Maven / Gradle / Ivy
// Show all versions of elasticsearch Show documentation
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.util.concurrent;
import org.apache.lucene.util.CloseableThreadLocal;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.store.Store;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
* a thread. It allows storing and retrieving header information across method calls, network calls as well as threads spawned from a
* thread that has a {@link ThreadContext} associated with it. Threads spawned from a {@link org.elasticsearch.threadpool.ThreadPool} have out of the box
* support for {@link ThreadContext} and all threads spawned will inherit the {@link ThreadContext} from the thread that they are forked from.
* Network calls will also preserve the sender's headers automatically.
*
* Consumers of ThreadContext usually don't need to interact with adding or stashing contexts. Every elasticsearch thread is managed by a thread pool or executor
* that is responsible for stashing and restoring the thread's context. For instance if a network request is received, all headers are deserialized from the network
* and directly added as the headers of the thread's {@link ThreadContext} (see {@link #readHeaders(StreamInput)}). In order to not modify the context that is currently
* active on this thread the network code uses a try-with-resources pattern to stash its current context, read headers into a fresh one and, once the request is handled or a handler thread
* is forked (which in turn inherits the context), restore the previous context. For instance:
*
* <pre>
*     // current context is stashed and replaced with a default context
*     try (StoredContext context = threadContext.stashContext()) {
*         threadContext.readHeaders(in); // read headers into current context
*         if (fork) {
*             threadPool.execute(() -&gt; request.handle()); // inherits context
*         } else {
*             request.handle();
*         }
*     }
*     // previous context is restored on StoredContext#close()
* </pre>
*
*/
public final class ThreadContext implements Closeable, Writeable {
/** Key prefix of the settings group that holds the default request headers. */
public static final String PREFIX = "request.headers";
/** Node-scoped group setting that collects every key below {@code request.headers.}. */
public static final Setting<Settings> DEFAULT_HEADERS_SETTING = Setting.groupSetting(PREFIX + ".", Property.NodeScope);
/** Shared empty context; the thread local hands it out whenever no context has been set for a thread. */
private static final ThreadContextStruct DEFAULT_CONTEXT = new ThreadContextStruct();
/** Unmodifiable default request headers, read from the settings once at construction time. */
private final Map<String, String> defaultHeader;
/** Per-thread holder of the currently active context. */
private final ContextThreadLocal threadLocal;
/**
* Creates a new ThreadContext instance
* @param settings the settings to read the default request headers from
*/
public ThreadContext(Settings settings) {
Settings headers = DEFAULT_HEADERS_SETTING.get(settings);
if (headers == null) {
this.defaultHeader = Collections.emptyMap();
} else {
Map defaultHeader = new HashMap<>();
for (String key : headers.names()) {
defaultHeader.put(key, headers.get(key));
}
this.defaultHeader = Collections.unmodifiableMap(defaultHeader);
}
threadLocal = new ContextThreadLocal();
}
/**
* Closes the per-thread context holder. Closing is delegated to {@code ContextThreadLocal#close()},
* which only closes the underlying thread local on the first call.
*
* @throws IOException declared by {@link Closeable}; the delegate call itself declares no checked exception
*/
@Override
public void close() throws IOException {
threadLocal.close();
}
/**
 * Swaps the calling thread's context for the default context.
 *
 * @return a {@link StoredContext} that puts the swapped-out context back when it is closed
 */
public StoredContext stashContext() {
    final ThreadContextStruct stashed = threadLocal.get();
    // An unset (null) slot is read back as the default context by the thread local.
    threadLocal.set(null);
    return () -> threadLocal.set(stashed);
}
/**
 * Removes the current context and installs a new context whose request headers are a merge of the
 * given headers and the current request headers. The removed context can be restored by closing
 * the returned {@link StoredContext}.
 * <p>
 * On a key collision the header already present in the current context wins over the supplied one.
 * The new context is derived from the default context, so it starts without the response and
 * transient headers of the current context.
 *
 * @param headers the headers to merge underneath the current request headers
 * @return a handle that restores the previous context when closed
 */
public StoredContext stashAndMergeHeaders(Map<String, String> headers) {
    final ThreadContextStruct context = threadLocal.get();
    final Map<String, String> newHeader = new HashMap<>(headers);
    // Current request headers are applied last so they override the supplied ones.
    newHeader.putAll(context.requestHeaders);
    threadLocal.set(DEFAULT_CONTEXT.putHeaders(newHeader));
    return () -> threadLocal.set(context);
}
/**
 * Captures the current context without replacing it, unlike {@link #stashContext()}.
 *
 * @param preserveResponseHeaders if set to {@code true}, the response headers present on the thread at
 *                                restore time are merged into the captured context when it is restored
 * @return a handle that re-installs the captured context when closed
 */
public StoredContext newStoredContext(boolean preserveResponseHeaders) {
    final ThreadContextStruct captured = threadLocal.get();
    return () -> {
        final ThreadContextStruct toInstall;
        if (preserveResponseHeaders == false || threadLocal.get() == captured) {
            // Nothing to carry over: either not requested or the context never changed.
            toInstall = captured;
        } else {
            toInstall = captured.putResponseHeaders(threadLocal.get().responseHeaders);
        }
        threadLocal.set(toInstall);
    };
}
/**
 * Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the
 * returned supplier is invoked. The context returned from the supplier is a stored version of the
 * supplier caller's context that should be restored once the originally gathered context is not needed anymore.
 * For instance this method should be used like this:
 *
 * <pre>
 *     Supplier&lt;ThreadContext.StoredContext&gt; restorable = context.newRestorableContext(true);
 *     new Thread() {
 *         public void run() {
 *             try (ThreadContext.StoredContext ctx = restorable.get()) {
 *                 // execute with the parents context and restore the threads context afterwards
 *             }
 *         }
 *
 *     }.start();
 * </pre>
 *
 * @param preserveResponseHeaders if set to {@code true} the response headers of the restore thread will be preserved.
 * @return a restorable context supplier
 */
public Supplier<StoredContext> newRestorableContext(boolean preserveResponseHeaders) {
    return wrapRestorable(newStoredContext(preserveResponseHeaders));
}
/**
 * Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore.
 * <p>
 * Each invocation of the returned supplier first captures the invoking thread's current context,
 * then restores {@code storedContext}, and returns the captured context so the caller can put it back.
 *
 * @param storedContext the context to restore
 * @return a supplier yielding a handle to the context that was active before the restore
 */
public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
    return () -> {
        StoredContext context = newStoredContext(false);
        storedContext.restore();
        return context;
    };
}
/**
 * Serializes the current thread's context to the stream, passing the default request headers
 * along so they are written together with the thread's own request headers.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    final ThreadContextStruct current = threadLocal.get();
    current.writeTo(out, defaultHeader);
}
/**
 * Replaces the current context with one deserialized from the stream. The new context carries the
 * request and response headers that were read and no transient headers.
 */
public void readHeaders(StreamInput in) throws IOException {
    final ThreadContextStruct fromStream = new ThreadContextStruct(in);
    threadLocal.set(fromStream);
}
/**
 * Looks up a request header, falling back to the default headers.
 *
 * @return the header for the given key or {@code null} if not present
 */
public String getHeader(String key) {
    final String fromRequest = threadLocal.get().requestHeaders.get(key);
    if (fromRequest != null) {
        return fromRequest;
    }
    return defaultHeader.get(key);
}
/**
 * Returns all of the request context's headers.
 *
 * @return an unmodifiable copy of the default headers overlaid with the current request headers;
 *         request headers win on a key collision
 */
public Map<String, String> getHeaders() {
    final Map<String, String> map = new HashMap<>(defaultHeader);
    map.putAll(threadLocal.get().requestHeaders);
    return Collections.unmodifiableMap(map);
}
/**
* Get a copy of all response headers.
*
* @return Never {@code null}.
*/
public Map> getResponseHeaders() {
Map> responseHeaders = threadLocal.get().responseHeaders;
HashMap> map = new HashMap<>(responseHeaders.size());
for (Map.Entry> entry : responseHeaders.entrySet()) {
map.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
}
return Collections.unmodifiableMap(map);
}
/**
* Copies all header key, value pairs into the current context
*/
public void copyHeaders(Iterable> headers) {
threadLocal.set(threadLocal.get().copyHeaders(headers));
}
/**
 * Adds a single request header to the current context.
 *
 * @throws IllegalArgumentException if a value for {@code key} is already present
 */
public void putHeader(String key, String value) {
    final ThreadContextStruct updated = threadLocal.get().putRequest(key, value);
    threadLocal.set(updated);
}
/**
 * Puts all of the given headers into this context.
 *
 * @param header the request headers to add
 */
public void putHeader(Map<String, String> header) {
    threadLocal.set(threadLocal.get().putHeaders(header));
}
/**
 * Stores a transient object under the given key in the current context.
 *
 * @throws IllegalArgumentException if a value for {@code key} is already present
 */
public void putTransient(String key, Object value) {
    final ThreadContextStruct updated = threadLocal.get().putTransient(key, value);
    threadLocal.set(updated);
}
/**
 * Returns a transient header object or {@code null} if there is no header for the given key.
 * <p>
 * The value is cast unchecked to the type the caller expects; a mismatching type only fails at
 * the call site.
 *
 * @param <T> the type the caller expects the stored object to have
 */
@SuppressWarnings("unchecked") // (T)object
public <T> T getTransient(String key) {
    return (T) threadLocal.get().transientHeaders.get(key);
}
/**
 * Adds {@code value} to the response header values stored under {@code key}.
 *
 * A value that is already recorded for the key is ignored.
 */
public void addResponseHeader(String key, String value) {
    final ThreadContextStruct updated = threadLocal.get().putResponse(key, value);
    threadLocal.set(updated);
}
/**
 * Captures the current thread context and wraps {@code command} in a {@link Runnable} that restores that
 * context before running it. A {@code command} that is already one of the context-preserving wrappers
 * is returned as is instead of being wrapped a second time.
 */
public Runnable preserveContext(Runnable command) {
    final boolean alreadyWrapped =
        command instanceof ContextPreservingAbstractRunnable || command instanceof ContextPreservingRunnable;
    if (alreadyWrapped) {
        return command;
    }
    final Runnable wrapped;
    if (command instanceof AbstractRunnable) {
        wrapped = new ContextPreservingAbstractRunnable((AbstractRunnable) command);
    } else {
        wrapped = new ContextPreservingRunnable(command);
    }
    return wrapped;
}
/**
 * Returns the original command behind a wrapper created by {@link #preserveContext(Runnable)};
 * any other runnable is returned unchanged.
 */
public Runnable unwrap(Runnable command) {
    final Runnable original;
    if (command instanceof ContextPreservingAbstractRunnable) {
        original = ((ContextPreservingAbstractRunnable) command).unwrap();
    } else if (command instanceof ContextPreservingRunnable) {
        original = ((ContextPreservingRunnable) command).unwrap();
    } else {
        original = command;
    }
    return original;
}
/**
 * Tells whether the calling thread currently runs with the shared default context instance.
 *
 * @return {@code true} if the current context is the default context
 */
boolean isDefaultContext() {
    final ThreadContextStruct current = threadLocal.get();
    return DEFAULT_CONTEXT == current;
}
/**
* Returns {@code true} if the context is closed, otherwise {@code false}.
* Reads the closed flag maintained by the thread local holder.
*/
boolean isClosed() {
return threadLocal.closed.get();
}
@FunctionalInterface
public interface StoredContext extends AutoCloseable {
@Override
void close();
default void restore() {
close();
}
}
private static final class ThreadContextStruct {
private final Map requestHeaders;
private final Map transientHeaders;
private final Map> responseHeaders;
private ThreadContextStruct(StreamInput in) throws IOException {
final int numRequest = in.readVInt();
Map requestHeaders = numRequest == 0 ? Collections.emptyMap() : new HashMap<>(numRequest);
for (int i = 0; i < numRequest; i++) {
requestHeaders.put(in.readString(), in.readString());
}
this.requestHeaders = requestHeaders;
this.responseHeaders = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
this.transientHeaders = Collections.emptyMap();
}
private ThreadContextStruct(Map requestHeaders,
Map> responseHeaders,
Map transientHeaders) {
this.requestHeaders = requestHeaders;
this.responseHeaders = responseHeaders;
this.transientHeaders = transientHeaders;
}
/**
* This represents the default context and it should only ever be called by {@link #DEFAULT_CONTEXT}.
*/
private ThreadContextStruct() {
this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
}
private ThreadContextStruct putRequest(String key, String value) {
Map newRequestHeaders = new HashMap<>(this.requestHeaders);
putSingleHeader(key, value, newRequestHeaders);
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders);
}
private void putSingleHeader(String key, String value, Map newHeaders) {
if (newHeaders.putIfAbsent(key, value) != null) {
throw new IllegalArgumentException("value for key [" + key + "] already present");
}
}
private ThreadContextStruct putHeaders(Map headers) {
if (headers.isEmpty()) {
return this;
} else {
final Map newHeaders = new HashMap<>();
for (Map.Entry entry : headers.entrySet()) {
putSingleHeader(entry.getKey(), entry.getValue(), newHeaders);
}
newHeaders.putAll(this.requestHeaders);
return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders);
}
}
private ThreadContextStruct putResponseHeaders(Map> headers) {
assert headers != null;
if (headers.isEmpty()) {
return this;
}
final Map> newResponseHeaders = new HashMap<>(this.responseHeaders);
for (Map.Entry> entry : headers.entrySet()) {
String key = entry.getKey();
final List existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
List newValues = Stream.concat(entry.getValue().stream(),
existingValues.stream()).distinct().collect(Collectors.toList());
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
} else {
newResponseHeaders.put(key, entry.getValue());
}
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
}
private ThreadContextStruct putResponse(String key, String value) {
assert value != null;
final Map> newResponseHeaders = new HashMap<>(this.responseHeaders);
final List existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
if (existingValues.contains(value)) {
return this;
}
final List newValues = new ArrayList<>(existingValues);
newValues.add(value);
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
} else {
newResponseHeaders.put(key, Collections.singletonList(value));
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
}
private ThreadContextStruct putTransient(String key, Object value) {
Map newTransient = new HashMap<>(this.transientHeaders);
if (newTransient.putIfAbsent(key, value) != null) {
throw new IllegalArgumentException("value for key [" + key + "] already present");
}
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient);
}
boolean isEmpty() {
return requestHeaders.isEmpty() && responseHeaders.isEmpty() && transientHeaders.isEmpty();
}
private ThreadContextStruct copyHeaders(Iterable> headers) {
Map newHeaders = new HashMap<>();
for (Map.Entry header : headers) {
newHeaders.put(header.getKey(), header.getValue());
}
return putHeaders(newHeaders);
}
private void writeTo(StreamOutput out, Map defaultHeaders) throws IOException {
final Map requestHeaders;
if (defaultHeaders.isEmpty()) {
requestHeaders = this.requestHeaders;
} else {
requestHeaders = new HashMap<>(defaultHeaders);
requestHeaders.putAll(this.requestHeaders);
}
out.writeVInt(requestHeaders.size());
for (Map.Entry entry : requestHeaders.entrySet()) {
out.writeString(entry.getKey());
out.writeString(entry.getValue());
}
out.writeMapOfLists(responseHeaders, StreamOutput::writeString, StreamOutput::writeString);
}
}
/**
 * Thread local holder for {@link ThreadContextStruct}. The default context is never stored:
 * setting {@link #DEFAULT_CONTEXT} stores {@code null}, and an empty slot is read back as
 * {@link #DEFAULT_CONTEXT}. Tracks whether it has been closed so that access after close can be
 * reported with an {@link IllegalStateException}.
 */
private static class ContextThreadLocal extends CloseableThreadLocal<ThreadContextStruct> {
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /** Stores the given context for the calling thread; the default context is stored as {@code null}. */
    @Override
    public void set(ThreadContextStruct object) {
        try {
            if (object == DEFAULT_CONTEXT) {
                super.set(null);
            } else {
                super.set(object);
            }
        } catch (NullPointerException ex) {
            /* This is odd but CloseableThreadLocal throws a NPE if it was closed but still accessed.
               to get a real exception we call ensureOpen() to tell the user we are already closed.*/
            ensureOpen();
            throw ex;
        }
    }

    /** Returns the calling thread's context, or {@link #DEFAULT_CONTEXT} if none is stored. */
    @Override
    public ThreadContextStruct get() {
        try {
            ThreadContextStruct threadContextStruct = super.get();
            if (threadContextStruct != null) {
                return threadContextStruct;
            }
            return DEFAULT_CONTEXT;
        } catch (NullPointerException ex) {
            /* This is odd but CloseableThreadLocal throws a NPE if it was closed but still accessed.
               to get a real exception we call ensureOpen() to tell the user we are already closed.*/
            ensureOpen();
            throw ex;
        }
    }

    /**
     * @throws IllegalStateException if this thread local has been closed
     */
    private void ensureOpen() {
        if (closed.get()) {
            throw new IllegalStateException("threadcontext is already closed");
        }
    }

    /** Closes the underlying thread local; only the first call has an effect. */
    @Override
    public void close() {
        if (closed.compareAndSet(false, true)) {
            super.close();
        }
    }
}
/**
* Wraps a Runnable to preserve the thread context.
*
* The context that is current when the wrapper is constructed is captured. When the wrapper runs, the
* executing thread's own context is stashed, the captured context is restored, the wrapped runnable is
* executed, and the stashed context is put back when the try-with-resources block exits.
*/
private class ContextPreservingRunnable implements Runnable {
// the wrapped command
private final Runnable in;
// handle to the context that was current when this wrapper was created
private final ThreadContext.StoredContext ctx;
private ContextPreservingRunnable(Runnable in) {
ctx = newStoredContext(false);
this.in = in;
}
@Override
public void run() {
// true only while the wrapped command itself is executing
boolean whileRunning = false;
try (ThreadContext.StoredContext ignore = stashContext()){
ctx.restore();
whileRunning = true;
in.run();
whileRunning = false;
} catch (IllegalStateException ex) {
// rethrow if the wrapped command threw it, or if the thread local has not been closed
if (whileRunning || threadLocal.closed.get() == false) {
throw ex;
}
// if we hit an ISE here we have been shutting down
// this comes from the threadcontext and barfs if
// our threadpool has been shutting down
}
}
// delegates to the wrapped command
@Override
public String toString() {
return in.toString();
}
// returns the wrapped command
public Runnable unwrap() {
return in;
}
}
/**
* Wraps an AbstractRunnable to preserve the thread context.
*
* The context that is current when the wrapper is constructed is captured. {@link #doRun()} stashes the
* executing thread's context and restores the captured one before delegating; {@link #onAfter()} puts the
* stashed context back after delegating to the wrapped runnable's {@code onAfter()}.
*/
private class ContextPreservingAbstractRunnable extends AbstractRunnable {
// the wrapped command
private final AbstractRunnable in;
// handle to the context that was current when this wrapper was created
private final ThreadContext.StoredContext creatorsContext;
// handle to the executing thread's context; stays null until doRun() has stashed it
private ThreadContext.StoredContext threadsOriginalContext = null;
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
creatorsContext = newStoredContext(false);
this.in = in;
}
@Override
public boolean isForceExecution() {
return in.isForceExecution();
}
@Override
public void onAfter() {
try {
in.onAfter();
} finally {
// restore the executing thread's context even if the delegate's onAfter() throws;
// null means doRun() did not get as far as stashing it
if (threadsOriginalContext != null) {
threadsOriginalContext.restore();
}
}
}
@Override
public void onFailure(Exception e) {
in.onFailure(e);
}
@Override
public void onRejection(Exception e) {
in.onRejection(e);
}
@Override
protected void doRun() throws Exception {
// true only while the wrapped command itself is executing
boolean whileRunning = false;
// NOTE(review): the stash happens before the try, so an IllegalStateException thrown by stashContext()
// is not handled by the catch below (ContextPreservingRunnable stashes inside its try) - confirm intended
threadsOriginalContext = stashContext();
try {
creatorsContext.restore();
whileRunning = true;
in.doRun();
whileRunning = false;
} catch (IllegalStateException ex) {
// rethrow if the wrapped command threw it, or if the thread local has not been closed
if (whileRunning || threadLocal.closed.get() == false) {
throw ex;
}
// if we hit an ISE here we have been shutting down
// this comes from the threadcontext and barfs if
// our threadpool has been shutting down
}
}
// delegates to the wrapped command
@Override
public String toString() {
return in.toString();
}
// returns the wrapped command
public AbstractRunnable unwrap() {
return in;
}
}
}