All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.clickzetta.platform.test.CZTestBase Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
package com.clickzetta.platform.test;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import cz.proto.ingestion.Ingestion;
import io.grpc.BindableService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

@VisibleForTesting
public abstract class CZTestBase implements Serializable {

  protected Logger LOG = LoggerFactory.getLogger(this.getClass());

  private final AtomicReference serverException = new AtomicReference<>();

  protected enum ServerName {
    ROUTER,
    CONTROLLER,
    WORKER,
  }

  // router.
  private final List routerServers = new ArrayList<>();
  private final List routerThreads = new ArrayList<>();

  // controller.
  private final List controllerServers = new ArrayList<>();
  private final List controllerThreads = new ArrayList<>();

  // workers.
  private final List workerServers = new ArrayList<>();
  private final List workerThreads = new ArrayList<>();

  public abstract Ingestion.HostPortTuple getRouterTuple();

  public abstract Ingestion.HostPortTuple getControllerTuple();

  public abstract List getWorkerTuples();

  public abstract List getRouterServices();

  public abstract List getControllerServices();

  public abstract List getWorkerServices();

  public abstract boolean routerModeEnable();

  public final String getHostPortTuple = "getHostPortTuple";
  public final String getBindableService = "getBindableService";

  /**
   * cache module with all hostPort tuple & services.
   */
  private final Map cacheMap = new ConcurrentHashMap<>();

  private Object cacheComputeIfAbsent(ServerName serverName, String suffix, Function applyFunc) {
    String cacheKey = serverName + "_" + suffix;
    return cacheMap.computeIfAbsent(cacheKey, applyFunc);
  }

  public void removeTargetInCache(ServerName serverName, String suffix) {
    String cacheKey = serverName + "_" + suffix;
    cacheMap.remove(cacheKey);
  }

  protected  T getHostPortTupleWithCache(ServerName serverName) {
    Function applyFunc;
    switch (serverName) {
      case ROUTER:
        applyFunc = s -> getRouterTuple();
        break;
      case CONTROLLER:
        applyFunc = s -> getControllerTuple();
        break;
      case WORKER:
        applyFunc = s -> getWorkerTuples();
        break;
      default:
        throw new UnsupportedOperationException("not support get hostPort tuple with serverName: " + serverName);
    }
    return (T) cacheComputeIfAbsent(serverName, getHostPortTuple, applyFunc);
  }

  public List getBindableServiceWithCache(ServerName serverName) {
    Function applyFunc;
    switch (serverName) {
      case ROUTER:
        applyFunc = s -> getRouterServices();
        break;
      case CONTROLLER:
        applyFunc = s -> getControllerServices();
        break;
      case WORKER:
        applyFunc = s -> getWorkerServices();
        break;
      default:
        throw new UnsupportedOperationException("not support get bindable service with serverName: " + serverName);
    }
    return (List) cacheComputeIfAbsent(serverName, getBindableService, applyFunc);
  }

  public synchronized int safeStartupService(ServerName serverName) throws IOException {
    validException();
    int index = 0;
    switch (serverName) {
      case ROUTER:
        startMockServer(serverName, routerServers, routerThreads);
        index = routerThreads.size() - 1;
        break;
      case CONTROLLER:
        startMockServer(serverName, controllerServers, controllerThreads);
        index = controllerThreads.size() - 1;
        break;
      case WORKER:
        startMockServer(serverName, workerServers, workerThreads);
        index = workerThreads.size() - 1;
        break;
      default:
        throw new UnsupportedOperationException("not support safeStartupService with serverName: " + serverName);
    }
    validException();
    return index;
  }

  public synchronized void safeShutdownService(ServerName serverName, int index) throws IOException {
    List servers;
    List threads;
    switch (serverName) {
      case ROUTER:
        Preconditions.checkArgument(index < routerThreads.size());
        servers = routerServers;
        threads = routerThreads;
        break;
      case CONTROLLER:
        Preconditions.checkArgument(index < controllerThreads.size());
        servers = controllerServers;
        threads = controllerThreads;
        break;
      case WORKER:
        Preconditions.checkArgument(index < workerThreads.size());
        servers = workerServers;
        threads = workerThreads;
        break;
      default:
        throw new UnsupportedOperationException("not support safeShutdownService with serverName: " + serverName);
    }
    MockServer mockServer = servers.remove(index);
    Thread thread = threads.remove(index);
    stopMockServer(serverName,
        new ArrayList() {{
          add(mockServer);
        }},
        new ArrayList() {{
          add(thread);
        }});
  }

  private void validException() throws IOException {
    if (serverException.get() != null) {
      throw serverException.get();
    }
  }

  public synchronized void initCluster() throws IOException {
    Preconditions.checkArgument(getBindableServiceWithCache(ServerName.CONTROLLER).size() > 0);
    Preconditions.checkArgument(((List) getHostPortTupleWithCache(ServerName.WORKER)).size() > 0);
    // first start router if needed.
    if (routerModeEnable()) {
      validException();
      startMockServer(ServerName.ROUTER, routerServers, routerThreads);
    }
    // second start controller.
    {
      validException();
      startMockServer(ServerName.CONTROLLER, controllerServers, controllerThreads);
    }
    // final start worker.
    {
      validException();
      startMockServer(ServerName.WORKER, workerServers, workerThreads);
    }
    validException();
    LOG.info("start cz cluster success.");
  }

  public synchronized void destroyCluster() throws IOException {
    // stop worker first.
    stopMockServer(ServerName.WORKER, workerServers, workerThreads);
    // stop controller.
    stopMockServer(ServerName.CONTROLLER, controllerServers, controllerThreads);
    // stop router.
    stopMockServer(ServerName.ROUTER, routerServers, routerThreads);
    validException();
    LOG.info("stop cz cluster success.");
  }

  private synchronized void startMockServer(ServerName serverName,
                                            List collectServers,
                                            List collectThreads) throws IOException {
    Object object = getHostPortTupleWithCache(serverName);
    List bindableServices = getBindableServiceWithCache(serverName);
    List hostPortTuples = new ArrayList<>();
    if (object instanceof Ingestion.HostPortTuple) {
      hostPortTuples.add((Ingestion.HostPortTuple) object);
    } else {
      hostPortTuples.addAll((List) object);
    }
    for (int index = 0; index < hostPortTuples.size(); index++) {
      Ingestion.HostPortTuple hostPortTuple = hostPortTuples.get(index);
      Thread thread = new Thread(() -> {
        try {
          MockServer mockServer = new MockServer(hostPortTuple.getPort(), bindableServices);
          collectServers.add(mockServer);
          mockServer.start();
        } catch (IOException e) {
          serverException.compareAndSet(null, e);
          LOG.error("start {} thread {} failed.", serverName, ClassUtil.getListClassName(getControllerServices()));
        }
      });
      thread.setName(String.format("%s-thread-%s", serverName, index));
      thread.start();
      collectThreads.add(thread);
      sleep(1 * 1000);
    }
  }

  private synchronized void stopMockServer(ServerName serverName,
                                           List collectServers,
                                           List collectThreads) throws IOException {
    for (int i = collectServers.size() - 1; i >= 0; i--) {
      MockServer mockServer = collectServers.get(i);
      try {
        mockServer.stop();
      } catch (IOException ioe) {
        serverException.compareAndSet(null, ioe);
        LOG.error("stop {} server {} failed.", serverName, ClassUtil.getListClassName(mockServer.getServiceBases()));
      }
    }
    for (int i = collectThreads.size() - 1; i >= 0; i--) {
      Thread thread = collectThreads.get(i);
      if (thread != null) {
        thread.stop();
      }
    }
  }

  private void sleep(long millis) {
    try {
      Thread.sleep(millis);
    } catch (InterruptedException ite) {
      throw new RuntimeException(ite);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy