All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.facebook.presto.spark.util.PrestoSparkUtils Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.util;

import com.facebook.airlift.json.Codec;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.spark.PrestoSparkServiceWaitTimeMetrics;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spi.page.PageCompressor;
import com.facebook.presto.spi.page.PageDecompressor;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.page.SerializedPage;
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import com.github.luben.zstd.ZstdOutputStreamNoFinalizer;
import com.google.common.collect.AbstractIterator;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.Duration;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaFutureAction;
import scala.reflect.ClassTag;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.zip.DeflaterInputStream;
import java.util.zip.InflaterOutputStream;

import static com.facebook.presto.common.block.BlockUtil.compactArray;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.propagateIfPossible;
import static com.google.common.io.ByteStreams.toByteArray;
import static java.lang.Math.toIntExact;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * Static helpers used by the Presto-on-Spark integration: conversions between
 * {@link SerializedPage} and {@link PrestoSparkSerializedPage}, deflate and zstd
 * (de)compression helpers, a zstd-backed {@link PagesSerde} factory, timeout-aware
 * waiting on Spark actions, and small Scala/collection utilities.
 *
 * <p>This class is not instantiable.
 */
public class PrestoSparkUtils
{
    private static final int COMPRESSION_LEVEL = 3; // default level

    private PrestoSparkUtils() {}

    /**
     * Converts a {@link SerializedPage} into a {@link PrestoSparkSerializedPage}.
     *
     * @param serializedPage page whose slice must be backed by a byte array
     * @return a page carrying the slice bytes (passed through {@code compactArray}) and the
     *         position count, uncompressed size, codec markers and checksum of the input
     * @throws IllegalArgumentException if the page's slice is not backed by a byte array
     */
    public static PrestoSparkSerializedPage toPrestoSparkSerializedPage(SerializedPage serializedPage)
    {
        Slice slice = serializedPage.getSlice();
        checkArgument(slice.hasByteArray(), "slice is expected to be based on a byte array");
        return new PrestoSparkSerializedPage(
                compactArray(slice.byteArray(), slice.byteArrayOffset(), slice.length()),
                serializedPage.getPositionCount(),
                serializedPage.getUncompressedSizeInBytes(),
                serializedPage.getPageCodecMarkers(),
                serializedPage.getChecksum());
    }

    /**
     * Converts a {@link PrestoSparkSerializedPage} back into a {@link SerializedPage}.
     * The byte array is wrapped, not copied.
     *
     * @param prestoSparkSerializedPage page to convert
     * @return the equivalent {@link SerializedPage}
     * @throws ArithmeticException if the position count does not fit in an {@code int}
     */
    public static SerializedPage toSerializedPage(PrestoSparkSerializedPage prestoSparkSerializedPage)
    {
        return new SerializedPage(
                Slices.wrappedBuffer(prestoSparkSerializedPage.getBytes()),
                prestoSparkSerializedPage.getPageCodecMarkers(),
                toIntExact(prestoSparkSerializedPage.getPositionCount()),
                prestoSparkSerializedPage.getUncompressedSizeInBytes(),
                prestoSparkSerializedPage.getChecksum());
    }

    /**
     * Compresses the given bytes with deflate.
     *
     * @param bytes uncompressed input
     * @return the deflate-compressed bytes, readable by {@link #decompress(byte[])}
     * @throws UncheckedIOException if the underlying stream fails
     */
    public static byte[] compress(byte[] bytes)
    {
        // DeflaterInputStream compresses the data that is read through it
        try (DeflaterInputStream compressor = new DeflaterInputStream(new ByteArrayInputStream(bytes))) {
            return toByteArray(compressor);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Decompresses bytes produced by {@link #compress(byte[])}.
     *
     * @param bytes deflate-compressed input
     * @return the uncompressed bytes
     * @throws UncheckedIOException if the input cannot be inflated
     */
    public static byte[] decompress(byte[] bytes)
    {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        // InflaterOutputStream decompresses the data that is written through it;
        // closing it flushes any remaining inflated bytes into the buffer
        try (InflaterOutputStream decompressor = new InflaterOutputStream(output)) {
            decompressor.write(bytes);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return output.toByteArray();
    }

    /**
     * Serializes an instance with the given codec and compresses the result with zstd.
     *
     * @param codec codec used to write the instance
     * @param instance value to serialize
     * @param <T> type handled by the codec
     * @return the zstd-compressed serialized form
     * @throws UncheckedIOException if writing or compressing fails
     */
    public static <T> byte[] serializeZstdCompressed(Codec<T> codec, T instance)
    {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        // The zstd stream must be closed before the buffer is read so that the
        // compressed frame is complete; try-with-resources does that exactly once.
        try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer(output)) {
            codec.writeBytes(zstdOutput, instance);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return output.toByteArray();
    }

    /**
     * Decompresses zstd-compressed bytes and deserializes them with the given codec.
     *
     * @param codec codec used to read the instance
     * @param bytes zstd-compressed serialized form
     * @param <T> type handled by the codec
     * @return the deserialized instance
     * @throws UncheckedIOException if decompressing or reading fails
     */
    public static <T> T deserializeZstdCompressed(Codec<T> codec, byte[] bytes)
    {
        try (InputStream input = new ByteArrayInputStream(bytes);
                ZstdInputStreamNoFinalizer zstdInput = new ZstdInputStreamNoFinalizer(input)) {
            return codec.readBytes(zstdInput);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Creates a {@link PagesSerde} that uses the zstd-based compressor and decompressor
     * defined in this class; the last constructor argument is left empty.
     *
     * @param blockEncodingManager block encoding manager passed to the serde
     * @return a new {@link PagesSerde}
     */
    public static PagesSerde createPagesSerde(BlockEncodingManager blockEncodingManager)
    {
        return new PagesSerde(
                blockEncodingManager,
                Optional.of(createPageCompressor()),
                Optional.of(createPageDecompressor()),
                Optional.empty());
    }

    /**
     * Creates a {@link PageCompressor} that compresses with zstd at {@link #COMPRESSION_LEVEL}.
     */
    private static PageCompressor createPageCompressor()
    {
        // based on ZstdJniCompressor
        return new PageCompressor()
        {
            @Override
            public int maxCompressedLength(int uncompressedSize)
            {
                return toIntExact(Zstd.compressBound(uncompressedSize));
            }

            @Override
            public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
            {
                long size = Zstd.compressByteArray(output, outputOffset, maxOutputLength, input, inputOffset, inputLength, COMPRESSION_LEVEL);
                if (Zstd.isError(size)) {
                    throw new RuntimeException(Zstd.getErrorName(size));
                }
                return toIntExact(size);
            }

            /**
             * Compresses the remaining bytes of {@code input} into {@code output}.
             * Only the output buffer's position is advanced, by the number of bytes written.
             */
            @Override
            public void compress(ByteBuffer input, ByteBuffer output)
            {
                if (input.isDirect() || output.isDirect() || !input.hasArray() || !output.hasArray()) {
                    throw new IllegalArgumentException("Non-direct byte buffer backed by byte array required");
                }
                int inputOffset = input.arrayOffset() + input.position();
                int outputOffset = output.arrayOffset() + output.position();

                int written = compress(input.array(), inputOffset, input.remaining(), output.array(), outputOffset, output.remaining());
                output.position(output.position() + written);
            }
        };
    }

    /**
     * Creates a {@link PageDecompressor} that decompresses zstd data.
     */
    private static PageDecompressor createPageDecompressor()
    {
        // based on ZstdJniDecompressor
        return new PageDecompressor()
        {
            @Override
            public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
            {
                // Write starting at outputOffset (not index 0) so that callers passing a
                // non-zero offset, such as the ByteBuffer overload below, get the data
                // where they asked for it; this mirrors the compressor's use of outputOffset.
                long size = Zstd.decompressByteArray(output, outputOffset, maxOutputLength, input, inputOffset, inputLength);
                if (Zstd.isError(size)) {
                    String errorName = Zstd.getErrorName(size);
                    throw new RuntimeException("Zstd JNI decompressor failed with " + errorName);
                }
                return toIntExact(size);
            }

            /**
             * Decompresses the remaining bytes of {@code input} into {@code output}.
             * Only the output buffer's position is advanced, by the number of bytes written.
             */
            @Override
            public void decompress(ByteBuffer input, ByteBuffer output)
            {
                if (input.isDirect() || output.isDirect() || !input.hasArray() || !output.hasArray()) {
                    throw new IllegalArgumentException("Non-direct byte buffer backed by byte array required");
                }
                int inputOffset = input.arrayOffset() + input.position();
                int outputOffset = output.arrayOffset() + output.position();

                int written = decompress(input.array(), inputOffset, input.remaining(), output.array(), outputOffset, output.remaining());
                output.position(output.position() + written);
            }
        };
    }

    /**
     * Returns the number of milliseconds left until the given deadline.
     *
     * @param queryCompletionDeadline deadline as epoch milliseconds
     * @return remaining milliseconds; zero or negative once the deadline has passed
     */
    public static long computeNextTimeout(long queryCompletionDeadline)
    {
        return queryCompletionDeadline - System.currentTimeMillis();
    }

    /**
     * Waits for a Spark action to complete, extending the deadline by the total wait time
     * currently reported by {@code waitTimeMetrics}. The action is cancelled on exit if it
     * has not completed.
     *
     * @param action Spark action to wait for
     * @param timeout base timeout, before any wait-time extension
     * @param timeUnit unit of {@code timeout}
     * @param waitTimeMetrics metrics whose summed wait time is added to the deadline
     * @param <T> result type of the action
     * @return the action's result
     * @throws SparkException if the action failed with a {@link SparkException}
     * @throws TimeoutException if the extended deadline elapses before the action completes
     * @throws RuntimeException if the waiting thread is interrupted (the interrupt flag is
     *         restored), or if the action failed with an unchecked exception
     */
    public static <T> T getActionResultWithTimeout(JavaFutureAction<T> action, long timeout, TimeUnit timeUnit, Set<PrestoSparkServiceWaitTimeMetrics> waitTimeMetrics)
            throws SparkException, TimeoutException
    {
        long deadline = System.currentTimeMillis() + timeUnit.toMillis(timeout);
        try {
            while (true) {
                long totalWaitTime = waitTimeMetrics.stream().map(PrestoSparkServiceWaitTimeMetrics::getWaitTime).mapToLong(Duration::toMillis).sum();
                long nextTimeoutInMillis = (deadline + totalWaitTime) - System.currentTimeMillis();
                if (nextTimeoutInMillis <= 0) {
                    throw new TimeoutException();
                }
                try {
                    return action.get(nextTimeoutInMillis, MILLISECONDS);
                }
                catch (TimeoutException ignored) {
                    // Deliberately ignored: the reported wait time may have grown while we
                    // were blocked, so loop to recompute the extended deadline. The check
                    // above throws once that deadline has really passed.
                }
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (ExecutionException e) {
            propagateIfPossible(e.getCause(), SparkException.class);
            propagateIfPossible(e.getCause(), RuntimeException.class);

            // this should never happen
            throw new UncheckedExecutionException(e.getCause());
        }
        finally {
            if (!action.isDone()) {
                action.cancel(true);
            }
        }
    }

    /**
     * Obtains the Scala {@link ClassTag} for a Java class.
     *
     * @param clazz class to create the tag for
     * @param <T> type represented by the class
     * @return the class tag for {@code clazz}
     */
    public static <T> ClassTag<T> classTag(Class<T> clazz)
    {
        return scala.reflect.ClassTag$.MODULE$.apply(clazz);
    }

    /**
     * Returns an iterator over the list that sets each slot to {@code null} as its element
     * is handed out, so the list stops referencing elements already consumed.
     * The list must support {@link List#set(int, Object)}.
     *
     * @param list list to iterate over; it is modified during iteration
     * @param <T> element type
     * @return an iterator yielding the list's elements in index order
     */
    public static <T> Iterator<T> getNullifyingIterator(List<T> list)
    {
        return new AbstractIterator<T>()
        {
            private int index;

            @Override
            protected T computeNext()
            {
                if (index >= list.size()) {
                    return endOfData();
                }
                T element = list.get(index);
                // drop the list's reference to the element being returned
                list.set(index, null);
                index++;
                return element;
            }
        };
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy