All downloads are free. Search and download functionality uses the official Maven repository.

io.grpc.testing.integration.Http2Client Maven / Gradle / Ivy

There is a newer version: 1.66.0
Show newest version
/*
 * Copyright 2016, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.testing.integration;

import static java.util.concurrent.Executors.newFixedThreadPool;

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.PayloadType;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Client application for the {@link TestServiceGrpc.TestServiceImplBase} that runs through a series
 * of HTTP/2 interop tests. The tests are designed to simulate incorrect behavior on the part of the
 * server. Some of the test cases require server-side checks and do not have assertions within the
 * client code.
 */
public final class Http2Client {
  private static final Logger logger = Logger.getLogger(Http2Client.class.getName());

  /**
   * The main application allowing this client to be launched from the command line.
   */
  public static void main(String[] args) throws Exception {
    final Http2Client client = new Http2Client();
    client.parseArgs(args);
    client.setUp();

    Runtime.getRuntime().addShutdownHook(new Thread() {
      @Override
      public void run() {
        try {
          client.shutdown();
        } catch (Exception e) {
          logger.log(Level.SEVERE, e.getMessage(), e);
        }
      }
    });

    try {
      client.run();
    } finally {
      client.shutdown();
    }
  }

  private String serverHost = "localhost";
  private int serverPort = 8080;
  private String testCase = Http2TestCases.RST_AFTER_DATA.name();

  private Tester tester = new Tester();
  private ListeningExecutorService threadpool;

  protected ManagedChannel channel;
  protected TestServiceGrpc.TestServiceBlockingStub blockingStub;
  protected TestServiceGrpc.TestServiceStub asyncStub;

  private void parseArgs(String[] args) {
    boolean usage = false;
    for (String arg : args) {
      if (!arg.startsWith("--")) {
        System.err.println("All arguments must start with '--': " + arg);
        usage = true;
        break;
      }
      String[] parts = arg.substring(2).split("=", 2);
      String key = parts[0];
      if ("help".equals(key)) {
        usage = true;
        break;
      }
      if (parts.length != 2) {
        System.err.println("All arguments must be of the form --arg=value");
        usage = true;
        break;
      }
      String value = parts[1];
      if ("server_host".equals(key)) {
        serverHost = value;
      } else if ("server_port".equals(key)) {
        serverPort = Integer.parseInt(value);
      } else if ("test_case".equals(key)) {
        testCase = value;
      } else {
        System.err.println("Unknown argument: " + key);
        usage = true;
        break;
      }
    }
    if (usage) {
      Http2Client c = new Http2Client();
      System.out.println(
          "Usage: [ARGS...]"
              + "\n"
              + "\n  --server_host=HOST          Server to connect to. Default " + c.serverHost
              + "\n  --server_port=PORT          Port to connect to. Default " + c.serverPort
              + "\n  --test_case=TESTCASE        Test case to run. Default " + c.testCase
              + "\n    Valid options:"
              + validTestCasesHelpText()
      );
      System.exit(1);
    }
  }

  private void setUp() {
    channel = createChannel();
    blockingStub = TestServiceGrpc.newBlockingStub(channel);
    asyncStub = TestServiceGrpc.newStub(channel);
  }

  private void shutdown() {
    try {
      if (channel != null) {
        channel.shutdownNow();
        channel.awaitTermination(1, TimeUnit.SECONDS);
      }
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    try {
      if (threadpool != null) {
        threadpool.shutdownNow();
      }
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }
  }

  private void run() {
    logger.info("Running test " + testCase);
    try {
      runTest(Http2TestCases.fromString(testCase));
    } catch (RuntimeException ex) {
      throw ex;
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }
    logger.info("Test completed.");
  }

  private void runTest(Http2TestCases testCase) throws Exception {
    switch (testCase) {
      case RST_AFTER_HEADER:
        tester.rstAfterHeader();
        break;
      case RST_AFTER_DATA:
        tester.rstAfterData();
        break;
      case RST_DURING_DATA:
        tester.rstDuringData();
        break;
      case GOAWAY:
        tester.goAway();
        break;
      case PING:
        tester.ping();
        break;
      case MAX_STREAMS:
        tester.maxStreams();
        break;
      default:
        throw new IllegalArgumentException("Unknown test case: " + testCase);
    }
  }

  private class Tester {
    private final int timeoutSeconds = 180;

    private final int responseSize = 314159;
    private final int payloadSize = 271828;
    private final SimpleRequest simpleRequest = SimpleRequest.newBuilder()
        .setResponseSize(responseSize)
        .setResponseType(PayloadType.COMPRESSABLE)
        .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[payloadSize])))
        .build();
    final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
        .setPayload(Payload.newBuilder()
            .setType(PayloadType.COMPRESSABLE)
            .setBody(ByteString.copyFrom(new byte[responseSize])))
        .build();

    private void rstAfterHeader() throws Exception {
      try {
        blockingStub.unaryCall(simpleRequest);
        throw new AssertionError("Expected call to fail");
      } catch (StatusRuntimeException ex) {
        assertRstStreamReceived(ex.getStatus());
      }
    }

    private void rstAfterData() throws Exception {
      // Use async stub to verify data is received.
      RstStreamObserver responseObserver = new RstStreamObserver();
      asyncStub.unaryCall(simpleRequest, responseObserver);
      if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
        throw new AssertionError("Operation timed out");
      }
      if (responseObserver.getError() == null) {
        throw new AssertionError("Expected call to fail");
      }
      assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
      if (responseObserver.getResponses().size() != 1) {
        throw new AssertionError("Expected one response");
      }
    }

    private void rstDuringData() throws Exception {
      // Use async stub to verify no data is received.
      RstStreamObserver responseObserver = new RstStreamObserver();
      asyncStub.unaryCall(simpleRequest, responseObserver);
      if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
        throw new AssertionError("Operation timed out");
      }
      if (responseObserver.getError() == null) {
        throw new AssertionError("Expected call to fail");
      }
      assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
      if (responseObserver.getResponses().size() != 0) {
        throw new AssertionError("Expected zero responses");
      }
    }

    private void goAway() throws Exception {
      assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
      TimeUnit.SECONDS.sleep(1);
      assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
    }

    private void ping() throws Exception {
      assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
    }

    private void maxStreams() throws Exception {
      final int numThreads = 10;

      // Preliminary call to ensure MAX_STREAMS setting is received by the client.
      assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);

      threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads));
      List> workerFutures = new ArrayList>();
      for (int i = 0; i < numThreads; i++) {
        workerFutures.add(threadpool.submit(new MaxStreamsWorker(i, simpleRequest)));
      }
      ListenableFuture f = Futures.allAsList(workerFutures);
      f.get(timeoutSeconds, TimeUnit.SECONDS);
    }

    private class RstStreamObserver implements StreamObserver {
      private final CountDownLatch latch = new CountDownLatch(1);
      private final List responses = new ArrayList();
      private Throwable error;

      @Override
      public void onNext(SimpleResponse value) {
        responses.add(value);
      }

      @Override
      public void onError(Throwable t) {
        error = t;
        latch.countDown();
      }

      @Override
      public void onCompleted() {
        latch.countDown();
      }

      public List getResponses() {
        return responses;
      }

      public Throwable getError() {
        return error;
      }

      public boolean awaitCompletion(long timeout, TimeUnit unit) throws Exception {
        return latch.await(timeout, unit);
      }
    }

    private class MaxStreamsWorker implements Runnable {
      int threadNum;
      SimpleRequest request;

      MaxStreamsWorker(int threadNum, SimpleRequest request) {
        this.threadNum = threadNum;
        this.request = request;
      }

      @Override
      public void run() {
        Thread.currentThread().setName("thread:" + threadNum);
        try {
          TestServiceGrpc.TestServiceBlockingStub blockingStub =
              TestServiceGrpc.newBlockingStub(channel);
          assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      }
    }

    private void assertRstStreamReceived(Status status) {
      if (!status.getCode().equals(Status.Code.UNAVAILABLE)) {
        throw new AssertionError("Wrong status code. Expected: " + Status.Code.UNAVAILABLE
            + " Received: " + status.getCode());
      }
      String http2ErrorPrefix = "HTTP/2 error code: NO_ERROR";
      if (status.getDescription() == null
          || !status.getDescription().startsWith(http2ErrorPrefix)) {
        throw new AssertionError("Wrong HTTP/2 error code. Expected: " + http2ErrorPrefix
            + " Received: " + status.getDescription());
      }
    }

    private void assertResponseEquals(SimpleResponse response, SimpleResponse goldenResponse) {
      if (!response.equals(goldenResponse)) {
        throw new AssertionError("Incorrect response received");
      }
    }
  }

  private ManagedChannel createChannel() {
    InetAddress address;
    try {
      address = InetAddress.getByName(serverHost);
    } catch (UnknownHostException ex) {
      throw new RuntimeException(ex);
    }
    return NettyChannelBuilder.forAddress(new InetSocketAddress(address, serverPort))
        .negotiationType(NegotiationType.PLAINTEXT)
        .build();
  }

  private static String validTestCasesHelpText() {
    StringBuilder builder = new StringBuilder();
    for (Http2TestCases testCase : Http2TestCases.values()) {
      String strTestcase = testCase.name().toLowerCase();
      builder.append("\n      ")
          .append(strTestcase)
          .append(": ")
          .append(testCase.description());
    }
    return builder.toString();
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy