/*
 * Copyright 2017-2023 O2 Czech Republic, a.s.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package cz.o2.proxima.direct.io.s3;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.Protocol;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
import com.amazonaws.services.s3.model.PartETag;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.SSECustomerKey;
import com.amazonaws.services.s3.model.UploadPartRequest;
import com.amazonaws.services.s3.model.UploadPartResult;
import cz.o2.proxima.core.functional.BiConsumer;
import cz.o2.proxima.core.functional.UnaryFunction;
import cz.o2.proxima.core.storage.UriUtil;
import cz.o2.proxima.direct.io.blob.RetryStrategy;
import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting;
import cz.o2.proxima.internal.com.google.common.base.Preconditions;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

@Slf4j
class S3Client implements Serializable {

  private static final long serialVersionUID = 1L;

  /** Part size in multi-part upload (5 MiB, the minimum S3 allows for every part but the last). */
  private static final int UPLOAD_PART_SIZE = 5 * 1024 * 1024;

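  /**
   * Builds {@link AmazonS3} clients from the string-keyed configuration map. Keys consumed here
   * are {@code access-key}, {@code secret-key}, {@code path-style-access}, {@code endpoint},
   * {@code signing-region}, {@code ssl-enabled}, {@code region}, {@code max-connections},
   * {@code connection-timeout-ms} and, during validation, {@code ssec-base64-key}.
   */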
  @VisibleForTesting
  static class AmazonS3Factory {
    private static final Map<String, BiConsumer<Object, AmazonS3ClientBuilder>> UPDATERS =
        new HashMap<>();

    static {
      UPDATERS.put(
          "path-style-access",
          (value, builder) -> builder.setPathStyleAccessEnabled(Boolean.valueOf(value.toString())));
      UPDATERS.put(
          "endpoint",
          (value, builder) ->
              builder.setEndpointConfiguration(
                  new EndpointConfiguration(
                      value.toString(), endpoint(builder).getSigningRegion())));
      UPDATERS.put(
          "signing-region",
          (value, builder) ->
              builder.setEndpointConfiguration(
                  new EndpointConfiguration(
                      endpoint(builder).getServiceEndpoint(), value.toString())));
      UPDATERS.put(
          "ssl-enabled",
          (value, builder) -> {
            if (!Boolean.parseBoolean(value.toString())) {
              clientConfiguration(builder).setProtocol(Protocol.HTTP);
            }
          });
      UPDATERS.put("region", (value, builder) -> builder.setRegion(value.toString()));

      UPDATERS.put(
          "max-connections",
          (value, builder) -> clientConfiguration(builder).setMaxConnections((int) value));

      UPDATERS.put(
          "connection-timeout-ms",
          (value, builder) -> clientConfiguration(builder).setConnectionTimeout((int) value));
    }

    private static ClientConfiguration clientConfiguration(AmazonS3ClientBuilder builder) {
      return Optional.ofNullable(builder.getClientConfiguration())
          .orElse(new ClientConfiguration());
    }

    static EndpointConfiguration endpoint(AmazonS3ClientBuilder builder) {
      return Optional.ofNullable(builder.getEndpoint()).orElse(new EndpointConfiguration("", ""));
    }

    private final Map<String, Object> cfg;

    AmazonS3Factory(Map<String, Object> cfg) {
      this.cfg = cfg;
    }

    AmazonS3 build() {
      validate();
      AmazonS3ClientBuilder builder = AmazonS3Client.builder();
      UPDATERS.forEach(
          (name, updater) ->
              Optional.ofNullable(cfg.get(name))
                  .ifPresent(value -> updater.accept(value, builder)));
      String accessKey = getOpt(cfg, "access-key", Object::toString, "");
      String secretKey = getOpt(cfg, "secret-key", Object::toString, "");
      builder.setCredentials(
          new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)));
      return builder.build();
    }

    private void validate() {
      String accessKey = getOpt(cfg, "access-key", Object::toString, "");
      String secretKey = getOpt(cfg, "secret-key", Object::toString, "");
      Preconditions.checkArgument(!accessKey.isEmpty(), "access-key must not be empty");
      Preconditions.checkArgument(!secretKey.isEmpty(), "secret-key must not be empty");
      @Nullable String base64SseKey = getOpt(cfg, "ssec-base64-key", Object::toString, null);
      boolean sslEnabled = getOpt(cfg, "ssl-enabled", Boolean::parseBoolean, false);
      @Nullable URI endpoint = getOpt(cfg, "endpoint", URI::create, null);
      if (!sslEnabled && endpoint != null) {
        sslEnabled =
            Optional.ofNullable(endpoint.getScheme())
                .map(e -> e.equalsIgnoreCase("https"))
                .orElse(false);
      }

      if (base64SseKey != null) {
        // SSE-C encryption requires SSL
        Preconditions.checkArgument(sslEnabled, "SSL is required when sse-c is enabled.");
      }
    }
  }

  @Getter private final String bucket;
  @Getter private final String path;
  @Getter private final RetryStrategy retry;
  private final Map<String, Object> cfg;
  @Nullable private transient AmazonS3 client;
  @Nullable @Getter private transient SSECustomerKey sseCustomerKey;

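  /**
   * Create a client for the bucket given by the URI authority, using the URI path as the
   * object-key prefix. The configuration is validated eagerly via {@link AmazonS3Factory}.
   *
   * @param uri storage URI whose authority names the bucket and whose path becomes the key prefix
   * @param cfg configuration map (credentials, optional endpoint, retry and SSE-C settings)
   */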
  S3Client(URI uri, Map<String, Object> cfg) {
    this.bucket = uri.getAuthority();
    this.path = toPath(uri);
    int initialRetryDelay = getOpt(cfg, "initial-retry-delay-ms", Integer::valueOf, 5000);
    int maxRetryDelay = getOpt(cfg, "max-retry-delay-ms", Integer::valueOf, (2 << 10) * 5000);
    this.retry = new RetryStrategy(initialRetryDelay, maxRetryDelay);
    @Nullable String base64SseKey = getOpt(cfg, "ssec-base64-key", Object::toString, null);
    if (base64SseKey != null) {
      sseCustomerKey = new SSECustomerKey(base64SseKey);
    }
    this.cfg = cfg;
    new AmazonS3Factory(cfg).validate();
  }

  // normalize the path so it does not start with a slash and ends with one
  private static String toPath(URI uri) {
    return UriUtil.getPathNormalized(uri) + "/";
  }

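  /**
   * Read the value stored under {@code name} in {@code cfg}, converting its string form with
   * {@code map}, or return {@code defval} when the key is absent.
   */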
  static <T> T getOpt(
      Map<String, Object> cfg, String name, UnaryFunction<String, T> map, T defval) {
    return Optional.ofNullable(cfg.get(name)).map(Object::toString).map(map::apply).orElse(defval);
  }

  @VisibleForTesting
  AmazonS3 client() {
    if (client == null) {
      client = new AmazonS3Factory(cfg).build();
    }
    return client;
  }

  public S3Object getObject(String blobName) {
    GetObjectRequest request = new GetObjectRequest(getBucket(), blobName);
    if (sseCustomerKey != null) {
      request.setSSECustomerKey(sseCustomerKey);
    }
    return client().getObject(request);
  }

  public void deleteObject(String key) {
    client().deleteObject(getBucket(), key);
  }

  /**
   * Put an object to S3 using multi-part upload.
   *
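   * <p>Illustrative usage (a hedged sketch; the URI, configuration map and blob name below are
   * made-up example values, not anything defined by this class):
   *
   * <pre>{@code
   * // cfg must contain at least access-key and secret-key
   * S3Client client = new S3Client(URI.create("s3://my-bucket/prefix"), cfg);
   * try (OutputStream os = client.putObject("2023/12/part-0000.blob")) {
   *   os.write(payload); // buffered into 5 MiB parts and uploaded on flush/close
   * }
   * // close() completes the multi-part upload
   * }</pre>
   *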
   * @param blobName Name of the blob we want to write.
   * @return Output stream that we can write data into.
   */
  public OutputStream putObject(String blobName) {
    Preconditions.checkState(!client().doesObjectExist(bucket, blobName), "Object already exists.");
    final String currentBucket = getBucket();
    InitiateMultipartUploadRequest request =
        new InitiateMultipartUploadRequest(currentBucket, blobName);
    if (sseCustomerKey != null) {
      request.setSSECustomerKey(sseCustomerKey);
    }
    final String uploadId = client().initiateMultipartUpload(request).getUploadId();
    final List<PartETag> eTags = new ArrayList<>();
    final byte[] partBuffer = new byte[UPLOAD_PART_SIZE];
    return new OutputStream() {

      /** Signals whether this output stream is closed. */
      private boolean closed = false;

      /** Number of unflushed bytes in the current part buffer. */
      private int currentBytes = 0;

      /** Part number of the current part in the multi-part upload; indexing starts at 1. */
      private int partNumber = 1;

      @Override
      public void write(int b) throws IOException {
        Preconditions.checkState(!closed, "Output stream already closed.");
        // The number of bytes written so far is also the position of the next write.
        partBuffer[currentBytes] = (byte) b;
        currentBytes++;
        if (currentBytes >= UPLOAD_PART_SIZE) {
          flush();
        }
      }

      @Override
      public void flush() throws IOException {
        Preconditions.checkState(!closed, "Output stream already closed.");
        if (currentBytes > 0) {
          try (final InputStream is = new ByteArrayInputStream(partBuffer, 0, currentBytes)) {
            final UploadPartRequest uploadPartRequest =
                new UploadPartRequest()
                    .withBucketName(currentBucket)
                    .withKey(blobName)
                    .withUploadId(uploadId)
                    .withPartNumber(partNumber)
                    .withInputStream(is)
                    .withPartSize(currentBytes);
            if (sseCustomerKey != null) {
              uploadPartRequest.setSSECustomerKey(sseCustomerKey);
            }

            final UploadPartResult uploadPartResult = client().uploadPart(uploadPartRequest);
            eTags.add(uploadPartResult.getPartETag());
            partNumber++;
          }
        }
        currentBytes = 0;
      }

      @Override
      public void close() throws IOException {
        if (!closed) {
          flush();
          client()
              .completeMultipartUpload(
                  new CompleteMultipartUploadRequest(currentBucket, blobName, uploadId, eTags));
          closed = true;
        }
      }
    };
  }
}