All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.testng.internal.thread.graph.GraphThreadPoolExecutor Maven / Gradle / Ivy

package org.testng.internal.thread.graph;

import org.testng.TestNGException;
import org.testng.collections.Maps;
import org.testng.internal.DynamicGraph;
import org.testng.internal.DynamicGraph.Status;
import org.testng.internal.RuntimeBehavior;
import org.testng.internal.thread.TestNGThreadFactory;
import org.testng.log4testng.Logger;

import javax.annotation.Nonnull;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * An Executor that launches tasks per batches. It takes a {@code DynamicGraph} of tasks to be run
 * and a {@code IThreadWorkerFactory} to initialize/create {@code Runnable} wrappers around those
 * tasks
 */
public class GraphThreadPoolExecutor extends ThreadPoolExecutor {

  private final DynamicGraph m_graph;
  private final IThreadWorkerFactory m_factory;
  private final Map> mapping = Maps.newConcurrentMap();
  private final Map upstream = Maps.newConcurrentMap();

  public GraphThreadPoolExecutor(
      String name,
      DynamicGraph graph,
      IThreadWorkerFactory factory,
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue workQueue) {
    super(
        corePoolSize,
        maximumPoolSize,
        keepAliveTime,
        unit,
        workQueue,
        new TestNGThreadFactory(name));
    m_graph = graph;
    m_factory = factory;

    if (m_graph.getFreeNodes().isEmpty()) {
      throw new TestNGException("The graph of methods contains a cycle:" + graph);
    }
  }

  public void run() {
    synchronized (m_graph) {
      List freeNodes = m_graph.getFreeNodes();
      runNodes(freeNodes);
    }
  }

  /** Create one worker per node and execute them. */
  private void runNodes(List freeNodes) {
    List> workers = m_factory.createWorkers(freeNodes);
    mapNodeToWorker(workers, freeNodes);

    for (int ix = 0; ix < workers.size(); ix++) {
      IWorker worker = workers.get(ix);
      mapNodeToParent(freeNodes, ix);
      setStatus(worker, Status.RUNNING);
      try {
        execute(worker);
      } catch (Exception ex) {
        Logger.getLogger(GraphThreadPoolExecutor.class).error(ex.getMessage(), ex);
      }
    }
  }

  @Override
  @SuppressWarnings("unchecked")
  public void afterExecute(Runnable r, Throwable t) {
    synchronized (m_graph) {
      setStatus((IWorker) r, computeStatus(r));
      if (m_graph.getNodeCount() == m_graph.getNodeCountWithStatus(Status.FINISHED)) {
        shutdown();
      } else {
        List freeNodes = m_graph.getFreeNodes();
        handleThreadAffinity(freeNodes);
        runNodes(freeNodes);
      }
    }
  }

  private void setStatus(IWorker worker, Status status) {
    synchronized (m_graph) {
      for (T m : worker.getTasks()) {
        m_graph.setStatus(m, status);
      }
    }
  }

  @SuppressWarnings("unchecked")
  private Status computeStatus(Runnable r) {
    IWorker worker = (IWorker) r;
    Status status = Status.FINISHED;
    if (RuntimeBehavior.enforceThreadAffinity() && !worker.completed()) {
      status = Status.READY;
    }
    return status;
  }

  private void mapNodeToWorker(List> runnables, List freeNodes) {
    if (!RuntimeBehavior.enforceThreadAffinity()) {
      return;
    }
    int index = 0;
    for (IWorker runnable : runnables) {
      T freeNode = freeNodes.get(index++);
      IWorker w = mapping.get(freeNode);
      if (w != null) {
        long current = w.getThreadIdToRunOn();
        runnable.setThreadIdToRunOn(current);
      }
      mapping.put(freeNode, runnable);
    }
  }

  private void mapNodeToParent(List freeNodes, int ix) {
    if (!RuntimeBehavior.enforceThreadAffinity()) {
      return;
    }
    T key = freeNodes.get(ix);
    List nodes = m_graph.getDependenciesFor(key);
    nodes.forEach(eachNode -> upstream.put(eachNode, key));
  }

  private void handleThreadAffinity(List freeNodes) {
    if (!RuntimeBehavior.enforceThreadAffinity()) {
      return;
    }
    for (T node : freeNodes) {
      IWorker w = mapping.get(upstream.get(node));
      long threadId = w.getCurrentThreadId();
      mapping.put(node, new PhoneyWorker(threadId));
    }
  }

  private class PhoneyWorker implements IWorker {
    private long threadId;

    public PhoneyWorker(long threadId) {
      this.threadId = threadId;
    }

    @Override
    public List getTasks() {
      return null;
    }

    @Override
    public long getTimeOut() {
      return 0;
    }

    @Override
    public int getPriority() {
      return 0;
    }

    @Override
    public int compareTo(@Nonnull IWorker o) {
      return 0;
    }

    @Override
    public void run() {}

    @Override
    public long getThreadIdToRunOn() {
      return threadId;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy