All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazonaws.http.AmazonRxNettyHttpClient Maven / Gradle / Ivy

package com.amazonaws.http;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URLEncoder;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import rx.Observable;
import rx.schedulers.Schedulers;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpContentDecompressor;
import iep.io.reactivex.netty.RxNetty;
import iep.io.reactivex.netty.client.PooledConnectionReleasedEvent;
import iep.io.reactivex.netty.protocol.http.AbstractHttpConfigurator;
import iep.io.reactivex.netty.protocol.http.server.HttpServerRequest;
import iep.io.reactivex.netty.protocol.http.server.HttpServerResponse;
import iep.io.reactivex.netty.protocol.http.server.HttpResponseHeaders;
import iep.io.reactivex.netty.protocol.http.client.ClientRequestResponseConverter;
import iep.io.reactivex.netty.protocol.http.client.HttpClient;
import iep.io.reactivex.netty.protocol.http.client.HttpClient.HttpClientConfig;
import iep.io.reactivex.netty.protocol.http.client.HttpClientRequest;
import iep.io.reactivex.netty.protocol.http.client.HttpClientResponse;
import iep.io.reactivex.netty.protocol.http.client.HttpClientPipelineConfigurator;
import iep.io.reactivex.netty.protocol.http.client.HttpRequestHeaders;
import iep.io.reactivex.netty.pipeline.PipelineConfigurator;
import iep.io.reactivex.netty.pipeline.PipelineConfiguratorComposite;
import iep.io.reactivex.netty.pipeline.ssl.DefaultFactories;

import org.w3c.dom.Node;

import com.amazonaws.*;
import com.amazonaws.auth.*;
import com.amazonaws.event.*;
import com.amazonaws.handlers.*;
import com.amazonaws.http.*;
import com.amazonaws.internal.*;
import com.amazonaws.metrics.*;
import com.amazonaws.regions.*;
import com.amazonaws.transform.*;
import com.amazonaws.util.*;
import com.amazonaws.util.AWSRequestMetrics.Field;

import org.joda.time.format.ISODateTimeFormat;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import javax.xml.bind.DatatypeConverter;

import java.security.MessageDigest;


abstract public class AmazonRxNettyHttpClient extends AmazonWebServiceClient {

  private static final Map> CLIENTS =
    new ConcurrentHashMap>();

  protected String mkToken(String... tokens) {
    if (tokens.length == 1)
      return tokens[0];
    else if (Arrays.stream(tokens).anyMatch(t -> { return t != null; }))
      return Arrays.stream(tokens).reduce((s1, s2) -> s1 + "|" + s2).get();
    else
      return null;
  }

  static {
    SignerFactory.registerSigner("AWS4SignerType", RxAWS4Signer.class);
  }

  private AWSCredentialsProvider awsCredentialsProvider;

  public AmazonRxNettyHttpClient() {
    this(new DefaultAWSCredentialsProviderChain(), new ClientConfiguration());
  }

  public AmazonRxNettyHttpClient(AWSCredentialsProvider awsCredentialsProvider) {
    this(awsCredentialsProvider, new ClientConfiguration());
  }

  public AmazonRxNettyHttpClient(ClientConfiguration clientConfiguration) {
    this(new DefaultAWSCredentialsProviderChain(), clientConfiguration);
  }

  public AmazonRxNettyHttpClient(
    AWSCredentialsProvider awsCredentialsProvider,
    ClientConfiguration clientConfiguration
  ) {
    super(clientConfiguration);
    this.awsCredentialsProvider = awsCredentialsProvider;
    init();
  }

  abstract protected void init();

  private  Observable getBackoffStrategyDelay(Request request, int cnt, AmazonClientException error) {
      if (cnt == 0) return Observable.just(0L);
      else {
        long delay = clientConfiguration.getRetryPolicy().getBackoffStrategy().delayBeforeNextRetry(request.getOriginalRequest(), error, cnt);
        return Observable.timer(delay, TimeUnit.MILLISECONDS);
      }
    }

  protected  Observable invokeStax(
    Request request,
    Unmarshaller unmarshaller,
    List> errorUnmarshallers,
    ExecutionContext executionContext
  ) {
    StaxRxNettyResponseHandler responseHandler = new StaxRxNettyResponseHandler(unmarshaller);
    XmlRxNettyErrorResponseHandler errorResponseHandler = new XmlRxNettyErrorResponseHandler(errorUnmarshallers);

    return invoke(request, responseHandler, errorResponseHandler, executionContext);
  }

  protected  Observable invokeJson(
    Request request,
    Unmarshaller unmarshaller,
    List errorUnmarshallers,
    ExecutionContext executionContext
  ) {
    JsonRxNettyResponseHandler responseHandler = new JsonRxNettyResponseHandler(unmarshaller);
    JsonRxNettyErrorResponseHandler errorResponseHandler = new JsonRxNettyErrorResponseHandler(request.getServiceName(), errorUnmarshallers);

    return invoke(request, responseHandler, errorResponseHandler, executionContext);
  }

  protected  Observable invokeJsonV2(
    Request request,
    Unmarshaller unmarshaller,
    List errorUnmarshallers,
    ExecutionContext executionContext
  ) {
    JsonRxNettyResponseHandler responseHandler = new JsonRxNettyResponseHandler(unmarshaller);
    JsonRxNettyErrorResponseHandlerV2 errorResponseHandler = new JsonRxNettyErrorResponseHandlerV2(request.getServiceName(), errorUnmarshallers);

    return invoke(request, responseHandler, errorResponseHandler, executionContext);
  }

  protected  Observable invoke(
    Request request,
    RxNettyResponseHandler> responseHandler,
    RxNettyResponseHandler errorResponseHandler,
    ExecutionContext executionContext
  ) {
    return Observable.defer(() -> {
      final AtomicReference error =
        new AtomicReference(null);

      final AtomicInteger cnt = new AtomicInteger(0);

      final Map> originalParameters =
        new LinkedHashMap>(request.getParameters());
      final Map originalHeaders =
        new HashMap(request.getHeaders());
      // Always mark the input stream before execution.
      final InputStream originalContent = request.getContent();
      if (originalContent != null && originalContent.markSupported()) {
        AmazonWebServiceRequest awsreq = request.getOriginalRequest();
        final int readLimit = awsreq.getRequestClientOptions().getReadLimit();
        originalContent.mark(readLimit);
      }

      return Observable.using(
        () -> {
          if (cnt.get() == 0) return Boolean.valueOf(false);
          return Boolean.valueOf(true);
        },
        (isPrepared) -> {
          assert(cnt.get() == 0 || error.get() != null);
          if (
            cnt.get() == 0 ||
            (
              cnt.get() < clientConfiguration.getRetryPolicy().getMaxErrorRetry() &&
              clientConfiguration.getRetryPolicy().getRetryCondition().shouldRetry(
                request.getOriginalRequest(), error.get(), cnt.get()
              )
            )
          ) {
            return Observable.defer(() -> {
              if (isPrepared) return Observable.just(null);
              return  prepareRequest(request, executionContext);
            })
            .flatMap(v -> { return getBackoffStrategyDelay(request, cnt.get(), error.get()); })
            .flatMap(i -> {
              try {
                if (isPrepared) {
                  request.setParameters(originalParameters);
                  request.setHeaders(originalHeaders);
                  request.setContent(originalContent);
                }
                return invokeImpl(request, responseHandler, errorResponseHandler, executionContext);
              }
              catch (java.io.UnsupportedEncodingException e) {
                return Observable.error(e);
              }
            })
            .doOnNext(n -> {
              error.set(null);
            })
            .onErrorResumeNext(t -> {
              if (t instanceof AmazonClientException) error.set((AmazonClientException) t);
              else error.set(new AmazonClientException(t));
              return Observable.empty();
            });
          }
          else return Observable.error(error.get());
        },
        (isPrepared) -> {
          cnt.getAndIncrement();
        }
      )
      .repeat()
      .first();
    });
  }

  protected  Observable prepareRequest(
    Request request,
    ExecutionContext executionContext
  ) {
    return Observable.defer(() -> {
      request.setEndpoint(endpoint);
      request.setTimeOffset(timeOffset);
      request.addHeader("User-agent", "rx-" + clientConfiguration.getUserAgent());
      if (clientConfiguration.useGzip())
        request.addHeader("Accept-encoding", "gzip");
      AmazonWebServiceRequest originalRequest = request.getOriginalRequest();

      AWSCredentials credentials = request.getOriginalRequest().getRequestCredentials();
      if (credentials == null) {
        credentials = awsCredentialsProvider.getCredentials();
      }

      executionContext.setCredentials(credentials);

      for (RequestHandler2 requestHandler2 : requestHandler2s(executionContext)) {
        if (requestHandler2 instanceof CredentialsRequestHandler)
          ((CredentialsRequestHandler) requestHandler2).setCredentials(
            executionContext.getCredentials()
          );
        requestHandler2.beforeRequest(request);
      }

      ProgressListener listener = originalRequest.getGeneralProgressListener();

      String serviceName = request.getServiceName().substring(6).toLowerCase();
      if (serviceName.endsWith("v2"))
        serviceName = serviceName.substring(0, serviceName.length() - 2);
      if (serviceName.equals("cloudwatch"))
        serviceName = "monitoring";
      String hostName = endpoint.getHost();
      String regionName = AwsHostNameUtils.parseRegionName(hostName, serviceName);

      Signer signer = SignerFactory.getSigner(serviceName, regionName);
      signer.sign(request, credentials);
      return Observable.just(null);
    });
  }

  private List requestHandler2s(ExecutionContext executionContext) {
    List requestHandler2s = executionContext.getRequestHandler2s();
    return (requestHandler2s == null) ? java.util.Collections.emptyList() : requestHandler2s;
  }

  private void afterError(
    Request request,
    Response response,
    List requestHandler2s,
    AmazonClientException e
  ) {
    for (RequestHandler2 handler2 : requestHandler2s) {
      handler2.afterError(request, response, e);
    }
  }

  private  void afterResponse(
    Request request,
    List requestHandler2s,
    Response response
  ) {
    for (RequestHandler2 handler2 : requestHandler2s) {
      handler2.afterResponse(request, response);
    }
  }

  protected  Observable invokeImpl(
    Request request,
    RxNettyResponseHandler> responseHandler,
    RxNettyResponseHandler errorResponseHandler,
    ExecutionContext executionContext
  ) throws java.io.UnsupportedEncodingException {

    return Observable.defer(() -> {
      final List requestHandler2s = requestHandler2s(executionContext);
      StringBuffer sbPath = new StringBuffer();
      if (request.getResourcePath() != null) sbPath.append(request.getResourcePath());
      if (sbPath.length() == 0) sbPath.append("/");

      String content = null;
      if (request.getContent() != null)
        content = ((StringInputStream) request.getContent()).getString();

      String queryString = RxSdkHttpUtils.encodeParameters(request);
      if (RxSdkHttpUtils.usePayloadForQueryParameters(request)) {
        request.addHeader("Content-Type", "application/x-www-form-urlencoded; charset=utf-8");
        content = queryString;
      }
      else if (queryString != null) {
        sbPath.append("?").append(queryString);
      }

      HttpClientRequest rxRequest = HttpClientRequest.create(
        HttpMethod.valueOf(request.getHttpMethod().toString()),
        sbPath.toString()
      );
      HttpRequestHeaders rxHeaders = rxRequest.getHeaders();
      request.getHeaders().entrySet().stream().forEach(e -> {
        rxHeaders.set(e.getKey(), e.getValue());
      });

      if (content != null) rxRequest.withContent(content);

      return getClient(endpoint.getHost()).submit(rxRequest)
      .flatMap(response -> {
        if (response.getStatus().code() / 100 == 2) {
          try {
            return responseHandler.handle(response).map(r -> {
              Response awsResponse = new Response(r.getResult(), null);
              afterResponse(request, requestHandler2s, awsResponse);
              return awsResponse.getAwsResponse();
            });
          }
          catch (Exception e) {
            return Observable.error(e);
          }
        }
        else {
          try {
            return errorResponseHandler.handle(response).flatMap(e -> {
              e.setServiceName(request.getServiceName());
              afterError(request, null, requestHandler2s, e);
              return Observable.error(e);
            });
          }
          catch (Exception e) {
            return Observable.error(e);
          }
        }
      });
    })
    .onErrorResumeNext(t -> {
      if (t instanceof AmazonClientException) return Observable.error(t);
      else return  Observable.error(new AmazonClientException(t));
    });
  }

  private HttpClient getClient(String host) {
    Protocol protocol = clientConfiguration.getProtocol();
    String key = protocol + "|" + host;
    if (!CLIENTS.containsKey(key)) {
      boolean isSecure;
      int port;
      if (Protocol.HTTP.equals(protocol)) {
        isSecure = false;
        port = 80;
      }
      else if (Protocol.HTTPS.equals(protocol)) {
        isSecure = true;
        port = 443;
      }
      else {
        throw new IllegalStateException("unknown protocol: " + protocol);
      }

      HttpClientConfig config = new HttpClient.HttpClientConfig.Builder()
        .setFollowRedirect(true)
        .readTimeout(clientConfiguration.getSocketTimeout(), TimeUnit.MILLISECONDS)
        .responseSubscriptionTimeout(5000, TimeUnit.MILLISECONDS)
        .build();

      HttpClient client = RxNetty.newHttpClientBuilder(host, port)
        .withName(host + "." + port)
        .config(config)
        .channelOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, clientConfiguration.getConnectionTimeout())
        .withMaxConnections(clientConfiguration.getMaxConnections())
        .withIdleConnectionsTimeoutMillis(60000)
        .withSslEngineFactory((isSecure) ? DefaultFactories.trustAll() : null)
        .pipelineConfigurator(
          new PipelineConfiguratorComposite,HttpClientRequest>(
            new HttpClientPipelineConfigurator(),
            new HttpDecompressionConfigurator()
          )
        )
        .appendPipelineConfigurator(
          pipeline -> pipeline.addLast(new ActiveLifeTracker(clientConfiguration.getConnectionTTL()))
        )
        .disableAutoReleaseBuffers()
        .build();
      CLIENTS.putIfAbsent(key, client);
    }
    return CLIENTS.get(key);
  }

  public class HttpDecompressionConfigurator implements PipelineConfigurator {
    @Override
    public void configureNewPipeline(ChannelPipeline pipeline) {
      pipeline.addLast("deflater", new HttpContentDecompressor());
    }
  }

  private static class ActiveLifeTracker extends ChannelDuplexHandler {
    private long activationTime;
    private long ttl;

    public ActiveLifeTracker(long ttl) {
      this.ttl = ttl;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
      activationTime = System.currentTimeMillis();
      super.channelActive(ctx);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
      if (evt instanceof PooledConnectionReleasedEvent) {
        long timeActive = System.currentTimeMillis() - activationTime;
        if (ttl >= 0 && timeActive > ttl) {
          ctx.channel().attr(ClientRequestResponseConverter.DISCARD_CONNECTION).set(true);
        }
      }
      super.userEventTriggered(ctx, evt);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy