All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.grpc.testing.integration.XdsTestClient Maven / Gradle / Ivy

There is a newer version: 1.68.0
Show newest version
/*
 * Copyright 2020 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.testing.integration;

import com.google.common.base.CaseFormat;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableScheduledFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.protobuf.services.ProtoReflectionService;
import io.grpc.services.AdminInterface;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.integration.Messages.ClientConfigureRequest;
import io.grpc.testing.integration.Messages.ClientConfigureRequest.RpcType;
import io.grpc.testing.integration.Messages.ClientConfigureResponse;
import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest;
import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse;
import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse.MethodStats;
import io.grpc.testing.integration.Messages.LoadBalancerStatsRequest;
import io.grpc.testing.integration.Messages.LoadBalancerStatsResponse;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import io.grpc.xds.XdsChannelCredentials;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;

/** Client for xDS interop tests. */
public final class XdsTestClient {
  private static Logger logger = Logger.getLogger(XdsTestClient.class.getName());

  private final Set watchers = new HashSet<>();
  private final Object lock = new Object();
  private final List channels = new ArrayList<>();
  private final StatsAccumulator statsAccumulator = new StatsAccumulator();

  private int numChannels = 1;
  private boolean printResponse = false;
  private int qps = 1;
  private volatile List rpcConfigs;
  private int rpcTimeoutSec = 20;
  private boolean secureMode = false;
  private String server = "localhost:8080";
  private int statsPort = 8081;
  private Server statsServer;
  private long currentRequestId;
  private ListeningScheduledExecutorService exec;

  /**
   * The main application allowing this client to be launched from the command line.
   */
  public static void main(String[] args) {
    final XdsTestClient client = new XdsTestClient();
    client.parseArgs(args);
    Runtime.getRuntime()
        .addShutdownHook(
            new Thread() {
              @Override
              @SuppressWarnings("CatchAndPrintStackTrace")
              public void run() {
                try {
                  client.stop();
                } catch (Exception e) {
                  e.printStackTrace();
                }
              }
            });
    client.run();
  }

  private void parseArgs(String[] args) {
    boolean usage = false;
    List rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
    EnumMap metadata = new EnumMap<>(RpcType.class);
    for (String arg : args) {
      if (!arg.startsWith("--")) {
        System.err.println("All arguments must start with '--': " + arg);
        usage = true;
        break;
      }
      String[] parts = arg.substring(2).split("=", 2);
      String key = parts[0];
      if ("help".equals(key)) {
        usage = true;
        break;
      }
      if (parts.length != 2) {
        System.err.println("All arguments must be of the form --arg=value");
        usage = true;
        break;
      }
      String value = parts[1];
      if ("metadata".equals(key)) {
        metadata = parseMetadata(value);
      } else if ("num_channels".equals(key)) {
        numChannels = Integer.valueOf(value);
      } else if ("print_response".equals(key)) {
        printResponse = Boolean.valueOf(value);
      } else if ("qps".equals(key)) {
        qps = Integer.valueOf(value);
      } else if ("rpc".equals(key)) {
        rpcTypes = parseRpcs(value);
      } else if ("rpc_timeout_sec".equals(key)) {
        rpcTimeoutSec = Integer.valueOf(value);
      } else if ("server".equals(key)) {
        server = value;
      } else if ("stats_port".equals(key)) {
        statsPort = Integer.valueOf(value);
      } else if ("secure_mode".equals(key)) {
        secureMode = Boolean.valueOf(value);
      } else {
        System.err.println("Unknown argument: " + key);
        usage = true;
        break;
      }
    }
    List configs = new ArrayList<>();
    for (RpcType type : rpcTypes) {
      Metadata md = new Metadata();
      if (metadata.containsKey(type)) {
        md = metadata.get(type);
      }
      configs.add(new RpcConfig(type, md, rpcTimeoutSec));
    }
    rpcConfigs = Collections.unmodifiableList(configs);

    if (usage) {
      XdsTestClient c = new XdsTestClient();
      System.err.println(
          "Usage: [ARGS...]"
              + "\n"
              + "\n  --num_channels=INT     Default: "
              + c.numChannels
              + "\n  --print_response=BOOL  Write RPC response to stdout. Default: "
              + c.printResponse
              + "\n  --qps=INT              Qps per channel, for each type of RPC. Default: "
              + c.qps
              + "\n  --rpc=STR              Types of RPCs to make, ',' separated string. RPCs can "
              + "be EmptyCall or UnaryCall. Default: UnaryCall"
              + "\n[deprecated] Use XdsUpdateClientConfigureService"
              + "\n  --metadata=STR         The metadata to send with each RPC, in the format "
              + "EmptyCall:key1:value1,UnaryCall:key2:value2."
              + "\n[deprecated] Use XdsUpdateClientConfigureService"
              + "\n  --rpc_timeout_sec=INT  Per RPC timeout seconds. Default: "
              + c.rpcTimeoutSec
              + "\n  --server=host:port     Address of server. Default: "
              + c.server
              + "\n  --secure_mode=BOOLEAN  Use true to enable XdsCredentials. Default: "
              + c.secureMode
              + "\n  --stats_port=INT       Port to expose peer distribution stats service. "
              + "Default: "
              + c.statsPort);
      System.exit(1);
    }
  }

  private static List parseRpcs(String rpcArg) {
    List rpcs = new ArrayList<>();
    for (String rpc : Splitter.on(',').split(rpcArg)) {
      rpcs.add(parseRpc(rpc));
    }
    return rpcs;
  }

  private static EnumMap parseMetadata(String metadataArg) {
    EnumMap rpcMetadata = new EnumMap<>(RpcType.class);
    for (String metadata : Splitter.on(',').omitEmptyStrings().split(metadataArg)) {
      List parts = Splitter.on(':').splitToList(metadata);
      if (parts.size() != 3) {
        throw new IllegalArgumentException("Invalid metadata: '" + metadata + "'");
      }
      RpcType rpc = parseRpc(parts.get(0));
      String key = parts.get(1);
      String value = parts.get(2);
      Metadata md = new Metadata();
      md.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value);
      if (rpcMetadata.containsKey(rpc)) {
        rpcMetadata.get(rpc).merge(md);
      } else {
        rpcMetadata.put(rpc, md);
      }
    }
    return rpcMetadata;
  }

  private static RpcType parseRpc(String rpc) {
    if ("EmptyCall".equals(rpc)) {
      return RpcType.EMPTY_CALL;
    } else if ("UnaryCall".equals(rpc)) {
      return RpcType.UNARY_CALL;
    } else {
      throw new IllegalArgumentException("Unknown RPC: '" + rpc + "'");
    }
  }

  private void run() {
    statsServer =
        Grpc.newServerBuilderForPort(statsPort, InsecureServerCredentials.create())
            .addService(new XdsStatsImpl())
            .addService(new ConfigureUpdateServiceImpl())
            .addService(ProtoReflectionService.newInstance())
            .addServices(AdminInterface.getStandardServices())
            .build();
    try {
      statsServer.start();
      for (int i = 0; i < numChannels; i++) {
        channels.add(
            Grpc.newChannelBuilder(
                    server,
                    secureMode
                        ? XdsChannelCredentials.create(InsecureChannelCredentials.create())
                        : InsecureChannelCredentials.create())
                .enableRetry()
                .build());
      }
      exec = MoreExecutors.listeningDecorator(Executors.newSingleThreadScheduledExecutor());
      runQps();
    } catch (Throwable t) {
      logger.log(Level.SEVERE, "Error running client", t);
      System.exit(1);
    }
  }

  private void stop() throws InterruptedException {
    if (statsServer != null) {
      statsServer.shutdownNow();
      if (!statsServer.awaitTermination(5, TimeUnit.SECONDS)) {
        System.err.println("Timed out waiting for server shutdown");
      }
    }
    for (ManagedChannel channel : channels) {
      channel.shutdownNow();
    }
    if (exec != null) {
      exec.shutdownNow();
    }
  }


  private void runQps() throws InterruptedException, ExecutionException {
    final SettableFuture failure = SettableFuture.create();
    final class PeriodicRpc implements Runnable {

      @Override
      public void run() {
        List configs = rpcConfigs;
        for (RpcConfig cfg : configs) {
          makeRpc(cfg);
        }
      }

      private void makeRpc(final RpcConfig config) {
        final long requestId;
        final Set savedWatchers = new HashSet<>();
        synchronized (lock) {
          currentRequestId += 1;
          requestId = currentRequestId;
          savedWatchers.addAll(watchers);
        }

        ManagedChannel channel = channels.get((int) (requestId % channels.size()));
        TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
        final AtomicReference> clientCallRef = new AtomicReference<>();
        final AtomicReference hostnameRef = new AtomicReference<>();
        stub =
            stub.withDeadlineAfter(config.timeoutSec, TimeUnit.SECONDS)
                .withInterceptors(
                    new ClientInterceptor() {
                      @Override
                      public  ClientCall interceptCall(
                          MethodDescriptor method,
                          CallOptions callOptions,
                          Channel next) {
                        ClientCall call = next.newCall(method, callOptions);
                        clientCallRef.set(call);
                        return new SimpleForwardingClientCall(call) {
                          @Override
                          public void start(Listener responseListener, Metadata headers) {
                            headers.merge(config.metadata);
                            super.start(
                                new SimpleForwardingClientCallListener(responseListener) {
                                  @Override
                                  public void onHeaders(Metadata headers) {
                                    hostnameRef.set(headers.get(XdsTestServer.HOSTNAME_KEY));
                                    super.onHeaders(headers);
                                  }
                                },
                                headers);
                          }
                        };
                      }
                    });

        if (config.rpcType == RpcType.EMPTY_CALL) {
          stub.emptyCall(
              EmptyProtos.Empty.getDefaultInstance(),
              new StreamObserver() {
                @Override
                public void onCompleted() {
                  handleRpcCompleted(requestId, config.rpcType, hostnameRef.get(), savedWatchers);
                }

                @Override
                public void onError(Throwable t) {
                  handleRpcError(requestId, config.rpcType, Status.fromThrowable(t),
                      savedWatchers);
                }

                @Override
                public void onNext(EmptyProtos.Empty response) {}
              });
        } else if (config.rpcType == RpcType.UNARY_CALL) {
          SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
          stub.unaryCall(
              request,
              new StreamObserver() {
                @Override
                public void onCompleted() {
                  handleRpcCompleted(requestId, config.rpcType, hostnameRef.get(), savedWatchers);
                }

                @Override
                public void onError(Throwable t) {
                  if (printResponse) {
                    logger.log(Level.WARNING, "Rpc failed", t);
                  }
                  handleRpcError(requestId, config.rpcType, Status.fromThrowable(t),
                      savedWatchers);
                }

                @Override
                public void onNext(SimpleResponse response) {
                  // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
                  // service and rely on parsing stdout.
                  if (printResponse) {
                    System.out.println(
                        "Greeting: Hello world, this is "
                            + response.getHostname()
                            + ", from "
                            + clientCallRef
                                .get()
                                .getAttributes()
                                .get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
                  }
                  // Use the hostname from the response if not present in the metadata.
                  // TODO(ericgribkoff) Delete when server is deployed that sets metadata value.
                  if (hostnameRef.get() == null) {
                    hostnameRef.set(response.getHostname());
                  }
                }
              });
        } else {
          throw new AssertionError("Unknown RPC type: " + config.rpcType);
        }
        statsAccumulator.recordRpcStarted(config.rpcType);
      }

      private void handleRpcCompleted(long requestId, RpcType rpcType, String hostname,
          Set watchers) {
        statsAccumulator.recordRpcFinished(rpcType, Status.OK);
        notifyWatchers(watchers, rpcType, requestId, hostname);
      }

      private void handleRpcError(long requestId, RpcType rpcType, Status status,
          Set watchers) {
        statsAccumulator.recordRpcFinished(rpcType, status);
        notifyWatchers(watchers, rpcType, requestId, null);
      }
    }

    long nanosPerQuery = TimeUnit.SECONDS.toNanos(1) / qps;
    ListenableScheduledFuture future =
        exec.scheduleAtFixedRate(new PeriodicRpc(), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
    Futures.addCallback(
        future,
        new FutureCallback() {

          @Override
          public void onFailure(Throwable t) {
            failure.setException(t);
          }

          @Override
          public void onSuccess(Object o) {}
        },
        MoreExecutors.directExecutor());

    failure.get();
  }

  private void notifyWatchers(
      Set watchers, RpcType rpcType, long requestId, String hostname) {
    for (XdsStatsWatcher watcher : watchers) {
      watcher.rpcCompleted(rpcType, requestId, hostname);
    }
  }

  private final class ConfigureUpdateServiceImpl extends
      XdsUpdateClientConfigureServiceGrpc.XdsUpdateClientConfigureServiceImplBase {
    @Override
    public void configure(ClientConfigureRequest request,
        StreamObserver responseObserver) {
      EnumMap newMetadata = new EnumMap<>(RpcType.class);
      for (ClientConfigureRequest.Metadata metadata : request.getMetadataList()) {
        Metadata md = newMetadata.get(metadata.getType());
        if (md == null) {
          md = new Metadata();
        }
        md.put(Metadata.Key.of(metadata.getKey(), Metadata.ASCII_STRING_MARSHALLER),
            metadata.getValue());
        newMetadata.put(metadata.getType(), md);
      }
      List configs = new ArrayList<>();
      for (RpcType type : request.getTypesList()) {
        Metadata md = newMetadata.containsKey(type) ? newMetadata.get(type) : new Metadata();
        int timeout = request.getTimeoutSec() != 0 ? request.getTimeoutSec() : rpcTimeoutSec;
        configs.add(new RpcConfig(type, md, timeout));
      }
      rpcConfigs = Collections.unmodifiableList(configs);
      responseObserver.onNext(ClientConfigureResponse.getDefaultInstance());
      responseObserver.onCompleted();
    }
  }

  private class XdsStatsImpl extends LoadBalancerStatsServiceGrpc.LoadBalancerStatsServiceImplBase {
    @Override
    public void getClientStats(
        LoadBalancerStatsRequest req, StreamObserver responseObserver) {
      XdsStatsWatcher watcher;
      synchronized (lock) {
        long startId = currentRequestId + 1;
        long endId = startId + req.getNumRpcs();
        watcher = new XdsStatsWatcher(startId, endId);
        watchers.add(watcher);
      }
      LoadBalancerStatsResponse response = watcher.waitForRpcStats(req.getTimeoutSec());
      synchronized (lock) {
        watchers.remove(watcher);
      }
      responseObserver.onNext(response);
      responseObserver.onCompleted();
    }

    @Override
    public void getClientAccumulatedStats(LoadBalancerAccumulatedStatsRequest request,
        StreamObserver responseObserver) {
      responseObserver.onNext(statsAccumulator.getRpcStats());
      responseObserver.onCompleted();
    }
  }

  /** Configuration applies to the specific type of RPCs. */
  private static final class RpcConfig {
    private final RpcType rpcType;
    private final Metadata metadata;
    private final int timeoutSec;

    private RpcConfig(RpcType rpcType, Metadata metadata, int timeoutSec) {
      this.rpcType = rpcType;
      this.metadata = metadata;
      this.timeoutSec = timeoutSec;
    }
  }

  /** Stats recorder for test RPCs. */
  @ThreadSafe
  private static final class StatsAccumulator {
    private final Map rpcsStartedByMethod = new HashMap<>();
    // TODO(chengyuanzhang): delete the following two after corresponding fields deleted in proto.
    private final Map rpcsFailedByMethod = new HashMap<>();
    private final Map rpcsSucceededByMethod = new HashMap<>();
    private final Map> rpcStatusByMethod = new HashMap<>();

    private synchronized void recordRpcStarted(RpcType rpcType) {
      String method = getRpcTypeString(rpcType);
      int count = rpcsStartedByMethod.containsKey(method) ? rpcsStartedByMethod.get(method) : 0;
      rpcsStartedByMethod.put(method, count + 1);
    }

    private synchronized void recordRpcFinished(RpcType rpcType, Status status) {
      String method = getRpcTypeString(rpcType);
      if (status.isOk()) {
        int count =
            rpcsSucceededByMethod.containsKey(method) ? rpcsSucceededByMethod.get(method) : 0;
        rpcsSucceededByMethod.put(method, count + 1);
      } else {
        int count = rpcsFailedByMethod.containsKey(method) ? rpcsFailedByMethod.get(method) : 0;
        rpcsFailedByMethod.put(method, count + 1);
      }
      int statusCode = status.getCode().value();
      Map statusCounts = rpcStatusByMethod.get(method);
      if (statusCounts == null) {
        statusCounts = new HashMap<>();
        rpcStatusByMethod.put(method, statusCounts);
      }
      int count = statusCounts.containsKey(statusCode) ? statusCounts.get(statusCode) : 0;
      statusCounts.put(statusCode, count + 1);
    }

    @SuppressWarnings("deprecation")
    private synchronized LoadBalancerAccumulatedStatsResponse getRpcStats() {
      LoadBalancerAccumulatedStatsResponse.Builder builder =
          LoadBalancerAccumulatedStatsResponse.newBuilder();
      builder.putAllNumRpcsStartedByMethod(rpcsStartedByMethod);
      builder.putAllNumRpcsSucceededByMethod(rpcsSucceededByMethod);
      builder.putAllNumRpcsFailedByMethod(rpcsFailedByMethod);

      for (String method : rpcsStartedByMethod.keySet()) {
        MethodStats.Builder methodStatsBuilder = MethodStats.newBuilder();
        methodStatsBuilder.setRpcsStarted(rpcsStartedByMethod.get(method));
        if (rpcStatusByMethod.containsKey(method)) {
          methodStatsBuilder.putAllResult(rpcStatusByMethod.get(method));
        }
        builder.putStatsPerMethod(method, methodStatsBuilder.build());
      }
      return builder.build();
    }

    // e.g., RpcType.UNARY_CALL -> "UNARY_CALL"
    private static String getRpcTypeString(RpcType rpcType) {
      return rpcType.name();
    }
  }

  /** Records the remote peer distribution for a given range of RPCs. */
  private static class XdsStatsWatcher {
    private final CountDownLatch latch;
    private final long startId;
    private final long endId;
    private final Map rpcsByPeer = new HashMap<>();
    private final EnumMap> rpcsByTypeAndPeer =
        new EnumMap<>(RpcType.class);
    private final Object lock = new Object();
    private int rpcsFailed;

    private XdsStatsWatcher(long startId, long endId) {
      latch = new CountDownLatch(Ints.checkedCast(endId - startId));
      this.startId = startId;
      this.endId = endId;
    }

    void rpcCompleted(RpcType rpcType, long requestId, @Nullable String hostname) {
      synchronized (lock) {
        if (startId <= requestId && requestId < endId) {
          if (hostname != null) {
            if (rpcsByPeer.containsKey(hostname)) {
              rpcsByPeer.put(hostname, rpcsByPeer.get(hostname) + 1);
            } else {
              rpcsByPeer.put(hostname, 1);
            }
            if (rpcsByTypeAndPeer.containsKey(rpcType)) {
              if (rpcsByTypeAndPeer.get(rpcType).containsKey(hostname)) {
                rpcsByTypeAndPeer
                    .get(rpcType)
                    .put(hostname, rpcsByTypeAndPeer.get(rpcType).get(hostname) + 1);
              } else {
                rpcsByTypeAndPeer.get(rpcType).put(hostname, 1);
              }
            } else {
              Map rpcMap = new HashMap<>();
              rpcMap.put(hostname, 1);
              rpcsByTypeAndPeer.put(rpcType, rpcMap);
            }
          } else {
            rpcsFailed += 1;
          }
          latch.countDown();
        }
      }
    }

    LoadBalancerStatsResponse waitForRpcStats(long timeoutSeconds) {
      try {
        boolean success = latch.await(timeoutSeconds, TimeUnit.SECONDS);
        if (!success) {
          logger.log(Level.INFO, "Await timed out, returning partial stats");
        }
      } catch (InterruptedException e) {
        logger.log(Level.INFO, "Await interrupted, returning partial stats", e);
        Thread.currentThread().interrupt();
      }
      LoadBalancerStatsResponse.Builder builder = LoadBalancerStatsResponse.newBuilder();
      synchronized (lock) {
        builder.putAllRpcsByPeer(rpcsByPeer);
        for (Map.Entry> entry : rpcsByTypeAndPeer.entrySet()) {
          LoadBalancerStatsResponse.RpcsByPeer.Builder rpcs =
              LoadBalancerStatsResponse.RpcsByPeer.newBuilder();
          rpcs.putAllRpcsByPeer(entry.getValue());
          builder.putRpcsByMethod(getRpcTypeString(entry.getKey()), rpcs.build());
        }
        builder.setNumFailures(rpcsFailed);
      }
      return builder.build();
    }

    // e.g., RpcType.UNARY_CALL -> "UnaryCall"
    private static String getRpcTypeString(RpcType rpcType) {
      return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, rpcType.name());
    }
  }
}