All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.common.util.ExecutorUtil Maven / Gradle / Ivy

There is a newer version: 9.6.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.common.util;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;


/**
 * Helpers for creating and shutting down {@link ExecutorService}s.
 *
 * <p>The factory methods return {@link MDCAwareThreadPoolExecutor}s. Each task submitted to one
 * inherits the submitting thread's SLF4J {@link MDC} context and the state of every registered
 * {@link InheritableThreadLocalProvider}. The task also runs with the flag reported by
 * {@link #isSolrServerThread()} set.
 */
public class ExecutorUtil {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /**
   * Registered thread-local providers, maintained copy-on-write.
   *
   * <p>{@link #addThreadLocalProvider} (synchronized) never mutates the published list. It
   * publishes a fresh copy through this volatile field, so readers take a consistent snapshot by
   * reading the field once.
   */
  private static volatile List<InheritableThreadLocalProvider> providers = new ArrayList<>();

  /**
   * Registers a provider whose thread-local state should be carried into tasks run by
   * {@link MDCAwareThreadPoolExecutor}.
   *
   * <p>At most one provider per concrete class is kept. Registering a provider whose class is
   * already registered is a no-op, which avoids accidental repeated registration (e.g. in tests).
   *
   * @param provider the provider to register; must not be {@code null}
   * @throws NullPointerException if {@code provider} is {@code null}
   */
  public synchronized static void addThreadLocalProvider(InheritableThreadLocalProvider provider) {
    if (provider == null) {
      // Fail fast: a null entry would make every subsequent execute() throw.
      throw new NullPointerException("provider");
    }
    for (InheritableThreadLocalProvider p : providers) {
      if (p.getClass().equals(provider.getClass())) {
        return;
      }
    }
    List<InheritableThreadLocalProvider> copy = new ArrayList<>(providers);
    copy.add(provider);
    providers = copy;
  }

  /**
   * Carries thread-local values from a submitting thread into the pool thread that runs the task.
   *
   * <p>Any class that wants its thread-local values forwarded to threads run by these thread
   * pools must implement this interface and register it via {@link #addThreadLocalProvider}.
   * For one task, the same {@code ctx} instance is passed to {@link #store}, {@link #set} and
   * {@link #clean}.
   */
  public interface InheritableThreadLocalProvider {
    /**
     * Invoked in the parent thread that submits a task. Copy the necessary objects into
     * {@code ctx}. The same object is passed to all three methods.
     */
    public void store(AtomicReference ctx);

    /**
     * Invoked in the thread-pool thread before the task runs. Set the appropriate values in
     * this thread's thread-locals.
     */
    public void set(AtomicReference ctx);

    /**
     * Invoked in the thread-pool thread after the task runs. Clear everything that
     * {@link #set} installed.
     */
    public void clean(AtomicReference ctx);
  }

  // ** This will interrupt the threads! ** Lucene and Solr do not like this because it can close channels, so only use
  // this if you know what you are doing - you probably want shutdownAndAwaitTermination.
  // Marked as Deprecated to discourage use.
  /**
   * Calls {@link ExecutorService#shutdownNow()}, which interrupts running tasks, then blocks
   * until the pool has terminated.
   *
   * <p>If the calling thread is interrupted while waiting, it keeps waiting. Its interrupt status
   * is restored before returning.
   *
   * @param pool the executor to shut down
   * @deprecated interrupts worker threads; prefer {@link #shutdownAndAwaitTermination}
   */
  @Deprecated
  public static void shutdownWithInterruptAndAwaitTermination(ExecutorService pool) {
    pool.shutdownNow(); // Cancel currently executing tasks - NOTE: this interrupts!
    awaitTerminationUninterruptibly(pool);
  }

  // ** This will interrupt the threads! ** Lucene and Solr do not like this because it can close channels, so only use
  // this if you know what you are doing - you probably want shutdownAndAwaitTermination.
  // Marked as Deprecated to discourage use.
  /**
   * Calls {@link ExecutorService#shutdown()} and waits up to 60 seconds for tasks to finish.
   *
   * <p>If the pool has not terminated by then, or the calling thread is interrupted during that
   * first wait, this escalates once to {@link ExecutorService#shutdownNow()}, which interrupts
   * workers. It then blocks until the pool terminates. The caller's interrupt status is restored
   * before returning.
   *
   * @param pool the executor to shut down
   * @deprecated interrupts worker threads; prefer {@link #shutdownAndAwaitTermination}
   */
  @Deprecated
  public static void shutdownAndAwaitTerminationWithInterrupt(ExecutorService pool) {
    pool.shutdown(); // Disable new tasks from being submitted
    boolean shutdown = false;
    boolean shutdownNowCalled = false;
    boolean callerInterrupted = false;
    try {
      while (!shutdown) {
        try {
          // Wait a while for existing tasks to terminate
          shutdown = pool.awaitTermination(60, TimeUnit.SECONDS);
        } catch (InterruptedException ie) {
          // The throw cleared the flag. Remember it and restore it on exit instead of
          // re-setting it now, which would make the next awaitTermination throw immediately
          // and turn this loop into a busy spin.
          callerInterrupted = true;
        }
        if (!shutdown && !shutdownNowCalled) {
          pool.shutdownNow(); // Cancel currently executing tasks - NOTE: this interrupts!
          shutdownNowCalled = true;
        }
      }
    } finally {
      if (callerInterrupted) {
        Thread.currentThread().interrupt(); // Preserve interrupt status
      }
    }
  }

  /**
   * Calls {@link ExecutorService#shutdown()} (running tasks are <em>not</em> interrupted) and
   * blocks until the pool has terminated.
   *
   * <p>If the calling thread is interrupted while waiting, it keeps waiting. Its interrupt status
   * is restored before returning.
   *
   * @param pool the executor to shut down
   */
  public static void shutdownAndAwaitTermination(ExecutorService pool) {
    pool.shutdown(); // Disable new tasks from being submitted
    awaitTerminationUninterruptibly(pool);
  }

  /**
   * Blocks, in 60-second slices, until {@code pool} has terminated.
   *
   * <p>An interrupt of the calling thread does not end the wait. It is recorded and re-asserted
   * on the calling thread once the pool has terminated.
   */
  private static void awaitTerminationUninterruptibly(ExecutorService pool) {
    boolean callerInterrupted = false;
    try {
      boolean terminated = false;
      while (!terminated) {
        try {
          // Wait a while for existing tasks to terminate
          terminated = pool.awaitTermination(60, TimeUnit.SECONDS);
        } catch (InterruptedException ie) {
          // Don't re-interrupt inside the loop: the next awaitTermination would throw
          // immediately and the loop would spin. Defer restoring the flag to the finally block.
          callerInterrupted = true;
        }
      }
    } finally {
      if (callerInterrupted) {
        Thread.currentThread().interrupt(); // Preserve interrupt status
      }
    }
  }

  /**
   * MDC-aware equivalent of {@link java.util.concurrent.Executors#newFixedThreadPool(int, ThreadFactory)}:
   * {@code nThreads} core/max threads over an unbounded {@link LinkedBlockingQueue}.
   *
   * @param nThreads number of threads in the pool
   * @param threadFactory factory used to create the pool's threads
   * @return a new {@link MDCAwareThreadPoolExecutor}
   */
  public static ExecutorService newMDCAwareFixedThreadPool(int nThreads, ThreadFactory threadFactory) {
    return new MDCAwareThreadPoolExecutor(nThreads, nThreads,
        0L, TimeUnit.MILLISECONDS,
        new LinkedBlockingQueue<>(),
        threadFactory);
  }

  /**
   * MDC-aware equivalent of {@link java.util.concurrent.Executors#newSingleThreadExecutor(ThreadFactory)}:
   * one thread over an unbounded {@link LinkedBlockingQueue}.
   *
   * @param threadFactory factory used to create the pool's thread
   * @return a new {@link MDCAwareThreadPoolExecutor}
   */
  public static ExecutorService newMDCAwareSingleThreadExecutor(ThreadFactory threadFactory) {
    return new MDCAwareThreadPoolExecutor(1, 1,
        0L, TimeUnit.MILLISECONDS,
        new LinkedBlockingQueue<>(),
        threadFactory);
  }

  /**
   * Creates a cached thread pool whose threads come from a {@code SolrjNamedThreadFactory}
   * built with {@code name}.
   *
   * @param name name passed to the thread factory
   * @return a new {@link MDCAwareThreadPoolExecutor}
   */
  public static ExecutorService newMDCAwareCachedThreadPool(String name) {
    return newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory(name));
  }

  /**
   * MDC-aware equivalent of {@link java.util.concurrent.Executors#newCachedThreadPool(ThreadFactory)}:
   * zero core threads, unbounded max, 60-second keep-alive, direct hand-off via
   * {@link SynchronousQueue}.
   *
   * @param threadFactory factory used to create the pool's threads
   * @return a new {@link MDCAwareThreadPoolExecutor}
   */
  public static ExecutorService newMDCAwareCachedThreadPool(ThreadFactory threadFactory) {
    return new MDCAwareThreadPoolExecutor(0, Integer.MAX_VALUE,
        60L, TimeUnit.SECONDS,
        new SynchronousQueue<>(),
        threadFactory);
  }

  /**
   * A {@link ThreadPoolExecutor} that propagates the submitter's context into each task.
   *
   * <p>Each task submitted through {@link #execute} runs with:
   * <ul>
   *   <li>the submitting thread's {@link MDC} context map;</li>
   *   <li>the values of every registered {@link InheritableThreadLocalProvider};</li>
   *   <li>the {@link ExecutorUtil#isSolrServerThread()} flag set.</li>
   * </ul>
   * While the task runs, the worker thread's name is suffixed with the submitter's MDC values.
   * All of this is undone after the task completes.
   */
  @SuppressForbidden(reason = "class customizes ThreadPoolExecutor so it can be used instead")
  public static class MDCAwareThreadPoolExecutor extends ThreadPoolExecutor {

    /** Upper bound on the MDC-derived suffix appended to a worker thread's name. */
    private static final int MAX_THREAD_NAME_LEN = 512;

    public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
      super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
    }

    public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
      super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }

    public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
      super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
    }

    public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
      super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
    }

    /**
     * Captures the caller's MDC and provider state, then schedules a wrapper task.
     *
     * <p>The wrapper installs that state around {@code command}, runs it, and restores the
     * worker thread's previous MDC, name and thread-locals afterwards. Throwables other than
     * {@link OutOfMemoryError} are logged, with the submitter's stack trace, before being
     * rethrown.
     */
    @Override
    public void execute(final Runnable command) {
      final Map<String, String> submitterContext = MDC.getCopyOfContextMap();

      // Space-separated MDC values, used only as a human-readable thread-name suffix.
      String ctxStr = submitterContext == null ? "" : String.join(" ", submitterContext.values());
      ctxStr = ctxStr.replace("/", "//");
      final String submitterContextStr = ctxStr.length() <= MAX_THREAD_NAME_LEN ? ctxStr : ctxStr.substring(0, MAX_THREAD_NAME_LEN);

      // Created eagerly so a failure inside the task can be traced back to the code that submitted it.
      final Exception submitterStackTrace = new Exception("Submitter stack trace");

      // Read the volatile field exactly once. store/set/clean must all iterate the SAME snapshot,
      // otherwise a concurrent addThreadLocalProvider could make the index-based loops disagree.
      final List<InheritableThreadLocalProvider> providersCopy = providers;
      final List<AtomicReference<Object>> ctx;
      if (providersCopy.isEmpty()) {
        ctx = null;
      } else {
        ctx = new ArrayList<>(providersCopy.size());
        for (InheritableThreadLocalProvider provider : providersCopy) {
          AtomicReference<Object> reference = new AtomicReference<>();
          ctx.add(reference);
          provider.store(reference);
        }
      }

      super.execute(() -> {
        isServerPool.set(Boolean.TRUE);
        if (ctx != null) {
          for (int i = 0; i < providersCopy.size(); i++) {
            providersCopy.get(i).set(ctx.get(i));
          }
        }
        // Remember the worker's own MDC and name so they can be restored afterwards.
        final Map<String, String> threadContext = MDC.getCopyOfContextMap();
        final Thread currentThread = Thread.currentThread();
        final String oldName = currentThread.getName();
        if (submitterContext != null && !submitterContext.isEmpty()) {
          MDC.setContextMap(submitterContext);
          currentThread.setName(oldName + "-processing-" + submitterContextStr);
        } else {
          MDC.clear();
        }
        try {
          command.run();
        } catch (Throwable t) {
          if (t instanceof OutOfMemoryError) {
            throw t;
          }
          log.error("Uncaught exception {} thrown by thread: {}", t, currentThread.getName(), submitterStackTrace);
          throw t;
        } finally {
          isServerPool.remove();
          if (threadContext != null && !threadContext.isEmpty()) {
            MDC.setContextMap(threadContext);
          } else {
            MDC.clear();
          }
          if (ctx != null) {
            for (int i = 0; i < providersCopy.size(); i++) {
              providersCopy.get(i).clean(ctx.get(i));
            }
          }
          currentThread.setName(oldName);
        }
      });
    }
  }

  /**
   * Per-thread flag.
   *
   * <p>It is set to {@code TRUE} for the duration of each task run by
   * {@link MDCAwareThreadPoolExecutor} and removed afterwards. It can also be set explicitly via
   * {@link #setServerThreadFlag}.
   */
  private static final ThreadLocal<Boolean> isServerPool = new ThreadLocal<>();

  /**
   * Tells whether the current thread is owned/run by Solr.
   *
   * @return {@code true} only if the current thread's flag is {@code Boolean.TRUE}
   */
  public static boolean isSolrServerThread() {
    return Boolean.TRUE.equals(isServerPool.get());
  }

  /**
   * Sets or clears the current thread's "Solr server thread" flag.
   *
   * @param flag value to store; {@code null} removes the flag from the current thread
   */
  public static void setServerThreadFlag(Boolean flag) {
    if (flag == null) {
      isServerPool.remove();
    } else {
      isServerPool.set(flag);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy