All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.filesystem.s3.S3OutputStream Maven / Gradle / Ivy

There is a newer version: 466
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.filesystem.s3;

import io.trino.filesystem.encryption.EncryptionKey;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.RequestPayer;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.services.s3.model.UploadPartResponse;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;

import static com.google.common.base.Verify.verify;
import static io.trino.filesystem.s3.S3FileSystemConfig.ObjectCannedAcl.getCannedAcl;
import static io.trino.filesystem.s3.S3FileSystemConfig.S3SseType.NONE;
import static io.trino.filesystem.s3.S3SseCUtils.encoded;
import static io.trino.filesystem.s3.S3SseCUtils.md5Checksum;
import static io.trino.filesystem.s3.S3SseRequestConfigurator.setEncryptionSettings;
import static java.lang.Math.clamp;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.System.arraycopy;
import static java.net.HttpURLConnection.HTTP_PRECON_FAILED;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.supplyAsync;

final class S3OutputStream
        extends OutputStream
{
    private final List parts = new ArrayList<>();
    private final LocalMemoryContext memoryContext;
    private final Executor uploadExecutor;
    private final S3Client client;
    private final S3Location location;
    private final S3Context context;
    private final int partSize;
    private final RequestPayer requestPayer;
    private final ObjectCannedACL cannedAcl;
    private final boolean exclusiveCreate;
    private final Optional key;

    private int currentPartNumber;
    private byte[] buffer = new byte[0];
    private int bufferSize;
    private int initialBufferSize = 64;

    private boolean closed;
    private boolean failed;
    private boolean multipartUploadStarted;
    private Future inProgressUploadFuture;

    // Mutated by background thread which does the multipart upload.
    // Read by both main thread and background thread.
    // Visibility is ensured by calling get() on inProgressUploadFuture.
    private Optional uploadId = Optional.empty();

    public S3OutputStream(AggregatedMemoryContext memoryContext, Executor uploadExecutor, S3Client client, S3Context context, S3Location location, boolean exclusiveCreate, Optional key)
    {
        this.memoryContext = memoryContext.newLocalMemoryContext(S3OutputStream.class.getSimpleName());
        this.uploadExecutor = requireNonNull(uploadExecutor, "uploadExecutor is null");
        this.client = requireNonNull(client, "client is null");
        this.location = requireNonNull(location, "location is null");
        this.exclusiveCreate = exclusiveCreate;
        this.context = requireNonNull(context, "context is null");
        this.partSize = context.partSize();
        this.requestPayer = context.requestPayer();
        this.cannedAcl = getCannedAcl(context.cannedAcl());
        this.key = requireNonNull(key, "key is null");

        verify(key.isEmpty() || context.s3SseContext().sseType() == NONE, "Encryption key cannot be used with SSE configuration");
    }

    @SuppressWarnings("NumericCastThatLosesPrecision")
    @Override
    public void write(int b)
            throws IOException
    {
        ensureOpen();
        ensureCapacity(1);
        buffer[bufferSize] = (byte) b;
        bufferSize++;
        flushBuffer(false);
    }

    @Override
    public void write(byte[] bytes, int offset, int length)
            throws IOException
    {
        ensureOpen();

        while (length > 0) {
            ensureCapacity(length);

            int copied = min(buffer.length - bufferSize, length);
            arraycopy(bytes, offset, buffer, bufferSize, copied);
            bufferSize += copied;

            flushBuffer(false);

            offset += copied;
            length -= copied;
        }
    }

    @Override
    public void flush()
            throws IOException
    {
        ensureOpen();
        flushBuffer(false);
    }

    @Override
    public void close()
            throws IOException
    {
        if (closed) {
            return;
        }
        closed = true;

        if (failed) {
            try {
                abortUpload();
                return;
            }
            catch (SdkException e) {
                throw new IOException(e);
            }
        }

        try {
            flushBuffer(true);
            memoryContext.close();
            waitForPreviousUploadFinish();
        }
        catch (IOException | RuntimeException e) {
            abortUploadSuppressed(e);
            throw e;
        }

        try {
            uploadId.ifPresent(this::finishUpload);
        }
        catch (SdkException e) {
            abortUploadSuppressed(e);
            throw new IOException(e);
        }
    }

    private void ensureOpen()
            throws IOException
    {
        if (closed) {
            throw new IOException("Output stream closed: " + location);
        }
    }

    private void ensureCapacity(int extra)
    {
        int capacity = min(partSize, bufferSize + extra);
        if (buffer.length < capacity) {
            int target = max(buffer.length, initialBufferSize);
            if (target < capacity) {
                target += target / 2; // increase 50%
                target = clamp(target, capacity, partSize);
            }
            buffer = Arrays.copyOf(buffer, target);
            memoryContext.setBytes(buffer.length);
        }
    }

    private void flushBuffer(boolean finished)
            throws IOException
    {
        // skip multipart upload if there would only be one part
        if (finished && !multipartUploadStarted) {
            PutObjectRequest request = PutObjectRequest.builder()
                    .overrideConfiguration(context::applyCredentialProviderOverride)
                    .acl(cannedAcl)
                    .requestPayer(requestPayer)
                    .bucket(location.bucket())
                    .key(location.key())
                    .contentLength((long) bufferSize)
                    .applyMutation(builder -> {
                        if (exclusiveCreate) {
                            builder.ifNoneMatch("*");
                        }
                        key.ifPresent(encryption -> {
                            builder.sseCustomerKey(encoded(encryption));
                            builder.sseCustomerAlgorithm(encryption.algorithm());
                            builder.sseCustomerKeyMD5(md5Checksum(encryption));
                        });
                        setEncryptionSettings(builder, context.s3SseContext());
                    })
                    .build();

            ByteBuffer bytes = ByteBuffer.wrap(buffer, 0, bufferSize);

            try {
                client.putObject(request, RequestBody.fromByteBuffer(bytes));
                return;
            }
            catch (S3Exception e) {
                failed = true;
                // when `location` already exists, the operation will fail with `412 Precondition Failed`
                if (e.statusCode() == HTTP_PRECON_FAILED) {
                    throw new FileAlreadyExistsException(location.toString());
                }
                throw new IOException("Put failed for bucket [%s] key [%s]: %s".formatted(location.bucket(), location.key(), e), e);
            }
            catch (SdkException e) {
                failed = true;
                throw new IOException("Put failed for bucket [%s] key [%s]: %s".formatted(location.bucket(), location.key(), e), e);
            }
        }

        // the multipart upload API only allows the last part to be smaller than 5MB
        if ((bufferSize == partSize) || (finished && (bufferSize > 0))) {
            byte[] data = buffer;
            int length = bufferSize;

            if (finished) {
                this.buffer = null;
            }
            else {
                this.buffer = new byte[0];
                this.initialBufferSize = partSize;
                bufferSize = 0;
            }
            memoryContext.setBytes(0);

            try {
                waitForPreviousUploadFinish();
            }
            catch (IOException e) {
                failed = true;
                abortUploadSuppressed(e);
                throw e;
            }
            multipartUploadStarted = true;
            inProgressUploadFuture = supplyAsync(() -> uploadPage(data, length), uploadExecutor);
        }
    }

    private void waitForPreviousUploadFinish()
            throws IOException
    {
        if (inProgressUploadFuture == null) {
            return;
        }

        try {
            inProgressUploadFuture.get();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new InterruptedIOException();
        }
        catch (ExecutionException e) {
            throw new IOException("Streaming upload failed", e);
        }
    }

    private CompletedPart uploadPage(byte[] data, int length)
    {
        if (uploadId.isEmpty()) {
            CreateMultipartUploadRequest request = CreateMultipartUploadRequest.builder()
                    .overrideConfiguration(context::applyCredentialProviderOverride)
                    .acl(cannedAcl)
                    .requestPayer(requestPayer)
                    .bucket(location.bucket())
                    .key(location.key())
                    .applyMutation(builder ->
                        key.ifPresentOrElse(
                                encryption ->
                                    builder.sseCustomerKey(encoded(encryption))
                                            .sseCustomerAlgorithm(encryption.algorithm())
                                            .sseCustomerKeyMD5(md5Checksum(encryption)),
                                    () -> setEncryptionSettings(builder, context.s3SseContext())))
                    .build();

            uploadId = Optional.of(client.createMultipartUpload(request).uploadId());
        }

        currentPartNumber++;
        UploadPartRequest request = UploadPartRequest.builder()
                .overrideConfiguration(context::applyCredentialProviderOverride)
                .requestPayer(requestPayer)
                .bucket(location.bucket())
                .key(location.key())
                .contentLength((long) length)
                .uploadId(uploadId.get())
                .partNumber(currentPartNumber)
                .applyMutation(builder ->
                    key.ifPresentOrElse(
                            encryption ->
                                builder.sseCustomerKey(encoded(encryption))
                                        .sseCustomerAlgorithm(encryption.algorithm())
                                        .sseCustomerKeyMD5(md5Checksum(encryption)),
                            () -> setEncryptionSettings(builder, context.s3SseContext())))
                .build();

        ByteBuffer bytes = ByteBuffer.wrap(data, 0, length);

        UploadPartResponse response = client.uploadPart(request, RequestBody.fromByteBuffer(bytes));

        CompletedPart part = CompletedPart.builder()
                .partNumber(currentPartNumber)
                .eTag(response.eTag())
                .build();

        parts.add(part);
        return part;
    }

    private void finishUpload(String uploadId)
    {
        CompleteMultipartUploadRequest request = CompleteMultipartUploadRequest.builder()
                .overrideConfiguration(context::applyCredentialProviderOverride)
                .requestPayer(requestPayer)
                .bucket(location.bucket())
                .key(location.key())
                .uploadId(uploadId)
                .multipartUpload(x -> x.parts(parts))
                .applyMutation(builder -> {
                    key.ifPresentOrElse(
                            encryption ->
                                    builder.sseCustomerKey(encoded(encryption))
                                            .sseCustomerAlgorithm(encryption.algorithm())
                                            .sseCustomerKeyMD5(md5Checksum(encryption)),
                            () -> setEncryptionSettings(builder, context.s3SseContext()));
                    if (exclusiveCreate) {
                        builder.ifNoneMatch("*");
                    }
                })
                .build();

        client.completeMultipartUpload(request);
    }

    private void abortUpload()
    {
        uploadId.map(id -> AbortMultipartUploadRequest.builder()
                        .overrideConfiguration(context::applyCredentialProviderOverride)
                        .requestPayer(requestPayer)
                        .bucket(location.bucket())
                        .key(location.key())
                        .uploadId(id)
                        .build())
                .ifPresent(client::abortMultipartUpload);
    }

    @SuppressWarnings("ObjectEquality")
    private void abortUploadSuppressed(Throwable throwable)
    {
        try {
            abortUpload();
        }
        catch (Throwable t) {
            if (throwable != t) {
                throwable.addSuppressed(t);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy