// io.grpc.testing.integration.Http2Client (Maven / Gradle / Ivy)
/*
* Copyright 2016, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.testing.integration;
import static java.util.concurrent.Executors.newFixedThreadPool;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.PayloadType;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Client application for the {@link TestServiceGrpc.TestServiceImplBase} that runs through a series
* of HTTP/2 interop tests. The tests are designed to simulate incorrect behavior on the part of the
* server. Some of the test cases require server-side checks and do not have assertions within the
* client code.
*/
public final class Http2Client {
private static final Logger logger = Logger.getLogger(Http2Client.class.getName());
/**
* The main application allowing this client to be launched from the command line.
*/
public static void main(String[] args) throws Exception {
final Http2Client client = new Http2Client();
client.parseArgs(args);
client.setUp();
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override
public void run() {
try {
client.shutdown();
} catch (Exception e) {
logger.log(Level.SEVERE, e.getMessage(), e);
}
}
});
try {
client.run();
} finally {
client.shutdown();
}
}
private String serverHost = "localhost";
private int serverPort = 8080;
private String testCase = Http2TestCases.RST_AFTER_DATA.name();
private Tester tester = new Tester();
private ListeningExecutorService threadpool;
protected ManagedChannel channel;
protected TestServiceGrpc.TestServiceBlockingStub blockingStub;
protected TestServiceGrpc.TestServiceStub asyncStub;
private void parseArgs(String[] args) {
boolean usage = false;
for (String arg : args) {
if (!arg.startsWith("--")) {
System.err.println("All arguments must start with '--': " + arg);
usage = true;
break;
}
String[] parts = arg.substring(2).split("=", 2);
String key = parts[0];
if ("help".equals(key)) {
usage = true;
break;
}
if (parts.length != 2) {
System.err.println("All arguments must be of the form --arg=value");
usage = true;
break;
}
String value = parts[1];
if ("server_host".equals(key)) {
serverHost = value;
} else if ("server_port".equals(key)) {
serverPort = Integer.parseInt(value);
} else if ("test_case".equals(key)) {
testCase = value;
} else {
System.err.println("Unknown argument: " + key);
usage = true;
break;
}
}
if (usage) {
Http2Client c = new Http2Client();
System.out.println(
"Usage: [ARGS...]"
+ "\n"
+ "\n --server_host=HOST Server to connect to. Default " + c.serverHost
+ "\n --server_port=PORT Port to connect to. Default " + c.serverPort
+ "\n --test_case=TESTCASE Test case to run. Default " + c.testCase
+ "\n Valid options:"
+ validTestCasesHelpText()
);
System.exit(1);
}
}
private void setUp() {
channel = createChannel();
blockingStub = TestServiceGrpc.newBlockingStub(channel).withWaitForReady();
asyncStub = TestServiceGrpc.newStub(channel).withWaitForReady();
}
private void shutdown() {
try {
if (channel != null) {
channel.shutdownNow();
channel.awaitTermination(1, TimeUnit.SECONDS);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
try {
if (threadpool != null) {
threadpool.shutdownNow();
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}
private void run() {
logger.info("Running test " + testCase);
try {
runTest(Http2TestCases.fromString(testCase));
} catch (RuntimeException ex) {
throw ex;
} catch (Exception ex) {
throw new RuntimeException(ex);
}
logger.info("Test completed.");
}
private void runTest(Http2TestCases testCase) throws Exception {
switch (testCase) {
case RST_AFTER_HEADER:
tester.rstAfterHeader();
break;
case RST_AFTER_DATA:
tester.rstAfterData();
break;
case RST_DURING_DATA:
tester.rstDuringData();
break;
case GOAWAY:
tester.goAway();
break;
case PING:
tester.ping();
break;
case MAX_STREAMS:
tester.maxStreams();
break;
default:
throw new IllegalArgumentException("Unknown test case: " + testCase);
}
}
private class Tester {
private final int timeoutSeconds = 5;
private final int responseSize = 314159;
private final int payloadSize = 271828;
private final SimpleRequest simpleRequest = SimpleRequest.newBuilder()
.setResponseSize(responseSize)
.setResponseType(PayloadType.COMPRESSABLE)
.setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[payloadSize])))
.build();
final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setType(PayloadType.COMPRESSABLE)
.setBody(ByteString.copyFrom(new byte[responseSize])))
.build();
private void rstAfterHeader() throws Exception {
try {
blockingStub.unaryCall(simpleRequest);
throw new AssertionError("Expected call to fail");
} catch (StatusRuntimeException ex) {
assertRstStreamReceived(ex.getStatus());
}
}
private void rstAfterData() throws Exception {
// Use async stub to verify data is received.
RstStreamObserver responseObserver = new RstStreamObserver();
asyncStub.unaryCall(simpleRequest, responseObserver);
if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
throw new AssertionError("Operation timed out");
}
if (responseObserver.getResponses().size() != 1) {
throw new AssertionError("Expected one response");
}
if (responseObserver.getError() == null) {
throw new AssertionError("Expected call to fail");
}
assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
}
private void rstDuringData() throws Exception {
// Use async stub to verify no data is received.
RstStreamObserver responseObserver = new RstStreamObserver();
asyncStub.unaryCall(simpleRequest, responseObserver);
if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
throw new AssertionError("Operation timed out");
}
if (responseObserver.getResponses().size() != 0) {
throw new AssertionError("Expected zero responses");
}
if (responseObserver.getError() == null) {
throw new AssertionError("Expected call to fail");
}
assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
}
private void goAway() throws Exception {
assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
TimeUnit.SECONDS.sleep(1);
assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
}
private void ping() throws Exception {
assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
}
private void maxStreams() throws Exception {
final int numThreads = 10;
// Preliminary call to ensure MAX_STREAMS setting is received by the client.
assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads));
List> workerFutures = new ArrayList>();
for (int i = 0; i < numThreads; i++) {
workerFutures.add(threadpool.submit(new MaxStreamsWorker(i, simpleRequest)));
}
ListenableFuture> f = Futures.allAsList(workerFutures);
f.get(timeoutSeconds, TimeUnit.SECONDS);
}
private class RstStreamObserver implements StreamObserver {
private final CountDownLatch latch = new CountDownLatch(1);
private final List responses = new ArrayList();
private Throwable error;
@Override
public void onNext(SimpleResponse value) {
responses.add(value);
}
@Override
public void onError(Throwable t) {
error = t;
latch.countDown();
}
@Override
public void onCompleted() {
latch.countDown();
}
public List getResponses() {
return responses;
}
public Throwable getError() {
return error;
}
public boolean awaitCompletion(long timeout, TimeUnit unit) throws Exception {
return latch.await(timeout, unit);
}
}
private class MaxStreamsWorker implements Runnable {
int threadNum;
SimpleRequest request;
MaxStreamsWorker(int threadNum, SimpleRequest request) {
this.threadNum = threadNum;
this.request = request;
}
@Override
public void run() {
Thread.currentThread().setName("thread:" + threadNum);
try {
TestServiceGrpc.TestServiceBlockingStub blockingStub =
TestServiceGrpc.newBlockingStub(channel);
assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
private void assertRstStreamReceived(Status status) {
if (!status.getCode().equals(Status.Code.UNAVAILABLE)) {
throw new AssertionError("Wrong status code. Expected: " + Status.Code.UNAVAILABLE
+ " Received: " + status.getCode());
}
String http2ErrorPrefix = "HTTP/2 error code: NO_ERROR";
if (status.getDescription() == null
|| !status.getDescription().startsWith(http2ErrorPrefix)) {
throw new AssertionError("Wrong HTTP/2 error code. Expected: " + http2ErrorPrefix
+ " Received: " + status.getDescription());
}
}
private void assertResponseEquals(SimpleResponse response, SimpleResponse goldenResponse) {
if (!response.equals(goldenResponse)) {
throw new AssertionError("Incorrect response received");
}
}
}
private ManagedChannel createChannel() {
InetAddress address;
try {
address = InetAddress.getByName(serverHost);
} catch (UnknownHostException ex) {
throw new RuntimeException(ex);
}
return NettyChannelBuilder.forAddress(new InetSocketAddress(address, serverPort))
.negotiationType(NegotiationType.PLAINTEXT)
.build();
}
private static String validTestCasesHelpText() {
StringBuilder builder = new StringBuilder();
for (Http2TestCases testCase : Http2TestCases.values()) {
String strTestcase = testCase.name().toLowerCase();
builder.append("\n ")
.append(strTestcase)
.append(": ")
.append(testCase.description());
}
return builder.toString();
}
}