All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.server.protocol.spooling.SpoolingManagerBridge Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.server.protocol.spooling;

import com.google.inject.Inject;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.spi.TrinoException;
import io.trino.spi.protocol.SpooledLocation;
import io.trino.spi.protocol.SpooledSegmentHandle;
import io.trino.spi.protocol.SpoolingContext;
import io.trino.spi.protocol.SpoolingManager;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Optional;

import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.spi.StandardErrorCode.CONFIGURATION_INVALID;
import static io.trino.spi.protocol.SpooledLocation.CoordinatorLocation;
import static io.trino.spi.protocol.SpooledLocation.DirectLocation;
import static io.trino.spi.protocol.SpooledLocation.coordinatorLocation;
import static java.util.Base64.getUrlDecoder;
import static java.util.Base64.getUrlEncoder;
import static java.util.Objects.requireNonNull;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;

/**
 * {@link SpoolingManager} that forwards calls to the spooling manager currently loaded in the
 * {@link SpoolingManagerRegistry}, applying settings from {@link SpoolingConfig} on top of it.
 * <p>
 * Besides plain delegation, this bridge:
 * <ul>
 *   <li>reports the initial and maximum segment sizes from the configuration,</li>
 *   <li>allows segment inlining only when both the configuration and the delegate allow it,</li>
 *   <li>gates direct storage access (and its optional fallback) on the configuration, and</li>
 *   <li>encrypts the identifier of every {@link CoordinatorLocation} returned by
 *       {@link #location(SpooledSegmentHandle)} with the shared secret key, and decrypts it again
 *       in {@link #handle(SpooledLocation)}.</li>
 * </ul>
 */
public class SpoolingManagerBridge
        implements SpoolingManager
{
    // Authenticated encryption: identifiers handed back by clients cannot be altered without
    // detection. A fresh random IV per call also avoids the deterministic, pattern-revealing
    // output of the bare "AES" transformation, which defaults to ECB in the JDK's SunJCE provider.
    // NOTE(review): NIST SP 800-38D caps random 96-bit IVs at 2^32 encryptions per key; rotate
    // the shared key accordingly on very long-lived clusters.
    private static final String CIPHER_TRANSFORMATION = "AES/GCM/NoPadding";
    private static final int GCM_IV_LENGTH_BYTES = 12;
    private static final int GCM_TAG_LENGTH_BITS = 128;
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private final SpoolingManagerRegistry registry;
    private final DataSize initialSegmentSize;
    private final DataSize maximumSegmentSize;
    private final boolean inlineSegments;
    private final SecretKey secretKey;
    private final boolean directStorageAccess;
    private final boolean directStorageFallback;

    /**
     * @param spoolingConfig source of segment sizes, inlining/direct-access flags and the shared key
     * @param registry provides the delegate spooling manager, looked up on every call
     * @throws IllegalArgumentException if the shared encryption key is not configured
     */
    @Inject
    public SpoolingManagerBridge(SpoolingConfig spoolingConfig, SpoolingManagerRegistry registry)
    {
        this.registry = requireNonNull(registry, "registry is null");
        requireNonNull(spoolingConfig, "spoolingConfig is null");
        this.initialSegmentSize = spoolingConfig.getInitialSegmentSize();
        this.maximumSegmentSize = spoolingConfig.getMaximumSegmentSize();
        this.inlineSegments = spoolingConfig.isInlineSegments();
        this.directStorageAccess = spoolingConfig.isDirectStorageAccess();
        this.directStorageFallback = spoolingConfig.isDirectStorageFallback();
        this.secretKey = spoolingConfig.getSharedEncryptionKey()
                .orElseThrow(() -> new IllegalArgumentException("protocol.spooling.shared-secret-key is not set"));
    }

    /**
     * Returns the configured maximum segment size in bytes.
     */
    @Override
    public long maximumSegmentSize()
    {
        return maximumSegmentSize.toBytes();
    }

    /**
     * Returns the configured initial segment size in bytes.
     */
    @Override
    public long initialSegmentSize()
    {
        return initialSegmentSize.toBytes();
    }

    /**
     * Inlining is allowed only when enabled in the configuration and allowed by the delegate.
     * The delegate is not consulted when inlining is disabled in the configuration.
     */
    @Override
    public boolean allowSegmentInlining()
    {
        return inlineSegments && delegate().allowSegmentInlining();
    }

    @Override
    public SpooledSegmentHandle create(SpoolingContext context)
    {
        return delegate().create(context);
    }

    @Override
    public OutputStream createOutputStream(SpooledSegmentHandle handle)
            throws IOException
    {
        return delegate().createOutputStream(handle);
    }

    @Override
    public InputStream openInputStream(SpooledSegmentHandle handle)
            throws IOException
    {
        return delegate().openInputStream(handle);
    }

    @Override
    public void acknowledge(SpooledSegmentHandle handle)
            throws IOException
    {
        delegate().acknowledge(handle);
    }

    /**
     * Returns the delegate's location for {@code handle}. A {@link DirectLocation} is returned
     * unchanged. A {@link CoordinatorLocation} is returned with its identifier encrypted (and
     * base64url-encoded) under the shared secret key, keeping its headers.
     */
    @Override
    public SpooledLocation location(SpooledSegmentHandle handle)
    {
        return switch (delegate().location(handle)) {
            case DirectLocation directLocation -> directLocation;
            case CoordinatorLocation coordinatorLocation ->
                    coordinatorLocation(toUri(secretKey, coordinatorLocation.identifier()), coordinatorLocation.headers());
        };
    }

    /**
     * Returns a direct storage location for {@code handle}, or empty when the client should fetch
     * the segment through the coordinator.
     *
     * @return empty when direct storage access is disabled in the configuration, or when the
     *         delegate fails with an {@link IOException} and fallback is enabled
     * @throws TrinoException with {@code CONFIGURATION_INVALID} if direct access is enabled but the
     *         delegate throws {@link UnsupportedOperationException}
     * @throws IOException from the delegate when fallback is disabled
     */
    @Override
    public Optional<DirectLocation> directLocation(SpooledSegmentHandle handle)
            throws IOException
    {
        if (!directStorageAccess) {
            // Disabled - client fetches data through the coordinator
            return Optional.empty();
        }

        // Resolve once so the call and the error message refer to the same manager
        SpoolingManager delegate = delegate();
        try {
            return delegate.directLocation(handle);
        }
        catch (UnsupportedOperationException e) {
            throw new TrinoException(CONFIGURATION_INVALID, "Direct storage access is enabled but not supported by " + delegate.getClass().getSimpleName(), e);
        }
        catch (IOException e) {
            if (directStorageFallback) {
                // Fall back to serving the segment through the coordinator
                return Optional.empty();
            }
            throw e;
        }
    }

    /**
     * Converts a coordinator location, as produced by {@link #location(SpooledSegmentHandle)},
     * back into a segment handle by decrypting its identifier and passing it to the delegate.
     *
     * @throws IllegalArgumentException for a {@link DirectLocation}, or if the identifier is not
     *         valid base64url or is too short to contain an IV
     * @throws RuntimeException if the identifier fails decryption or authentication
     */
    @Override
    public SpooledSegmentHandle handle(SpooledLocation location)
    {
        return switch (location) {
            case DirectLocation _ -> throw new IllegalArgumentException("Cannot convert direct location to handle");
            case CoordinatorLocation coordinatorLocation ->
                    delegate().handle(coordinatorLocation(fromUri(secretKey, coordinatorLocation.identifier()), coordinatorLocation.headers()));
        };
    }

    /**
     * Returns the spooling manager currently loaded in the registry.
     *
     * @throws IllegalStateException if no spooling manager has been loaded
     */
    private SpoolingManager delegate()
    {
        return registry
                .getSpoolingManager()
                .orElseThrow(() -> new IllegalStateException("Spooling manager is not loaded"));
    }

    /**
     * Encrypts {@code input} with AES-GCM under {@code secretKey} using a fresh random IV and returns
     * {@code base64url(iv || ciphertext || tag)} as a UTF-8 slice.
     */
    private static Slice toUri(SecretKey secretKey, Slice input)
    {
        try {
            byte[] iv = new byte[GCM_IV_LENGTH_BYTES];
            SECURE_RANDOM.nextBytes(iv);

            Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
            cipher.init(ENCRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv));
            // For GCM, doFinal returns the ciphertext with the authentication tag appended
            byte[] ciphertext = cipher.doFinal(input.getBytes());

            byte[] payload = new byte[iv.length + ciphertext.length];
            System.arraycopy(iv, 0, payload, 0, iv.length);
            System.arraycopy(ciphertext, 0, payload, iv.length, ciphertext.length);
            return utf8Slice(getUrlEncoder().encodeToString(payload));
        }
        catch (GeneralSecurityException e) {
            throw new RuntimeException("Could not encode segment identifier to URI", e);
        }
    }

    /**
     * Reverses {@link #toUri(SecretKey, Slice)}: base64url-decodes {@code input}, splits off the IV
     * prefix, and decrypts and authenticates the remainder under {@code secretKey}.
     *
     * @throws IllegalArgumentException if {@code input} is not valid base64url or is shorter than an IV
     * @throws RuntimeException if decryption or tag verification fails
     */
    private static Slice fromUri(SecretKey secretKey, Slice input)
    {
        try {
            byte[] payload = getUrlDecoder().decode(input.getBytes());
            if (payload.length < GCM_IV_LENGTH_BYTES) {
                throw new IllegalArgumentException("Could not decode segment identifier from URI: payload is too short");
            }

            Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION);
            cipher.init(DECRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_LENGTH_BITS, payload, 0, GCM_IV_LENGTH_BYTES));
            return wrappedBuffer(cipher.doFinal(payload, GCM_IV_LENGTH_BYTES, payload.length - GCM_IV_LENGTH_BYTES));
        }
        catch (GeneralSecurityException e) {
            throw new RuntimeException("Could not decode segment identifier from URI", e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy