All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.netty.handler.codec.compression.ZstdDecoder Maven / Gradle / Ivy

/*
 * Copyright 2024 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;

/**
 * Decompresses a compressed block {@link ByteBuf} using the Zstandard algorithm.
 * See <a href="https://facebook.github.io/zstd/">Zstandard</a>.
 *
 * <p>Each output buffer is allocated with an initial capacity derived from the size of the compressed
 * input currently available, capped at a configurable maximum so that a single large input does not
 * trigger an equally large (or integer-overflowing) allocation.
 */
public final class ZstdDecoder extends ByteToMessageDecoder {
    // Don't use static here as we want to still allow to load the classes.
    {
        try {
            Zstd.ensureAvailability();
        } catch (Throwable throwable) {
            throw new ExceptionInInitializerError(throwable);
        }
    }

    /** Default cap, in bytes, on the initial capacity of each output buffer (4 MiB). */
    private static final int DEFAULT_MAXIMUM_ALLOCATION_SIZE = 4 * 1024 * 1024;

    // Upper bound on the initial capacity of each output buffer; 0 means "no limit".
    private final int maximumAllocationSize;
    private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
    private ZstdInputStreamNoFinalizer zstdIs;

    private State currentState = State.DECOMPRESS_DATA;

    /**
     * Current state of stream.
     */
    private enum State {
        DECOMPRESS_DATA,
        CORRUPTED
    }

    /**
     * Creates a new decoder whose output buffers are each allocated with at most 4 MiB of initial capacity.
     */
    public ZstdDecoder() {
        this(DEFAULT_MAXIMUM_ALLOCATION_SIZE);
    }

    /**
     * Creates a new decoder.
     *
     * @param maximumAllocationSize the maximum initial capacity, in bytes, of each output buffer allocated
     *                              while decompressing, or {@code 0} for no limit
     * @throws IllegalArgumentException if {@code maximumAllocationSize} is negative
     */
    public ZstdDecoder(int maximumAllocationSize) {
        if (maximumAllocationSize < 0) {
            throw new IllegalArgumentException(
                    "maximumAllocationSize: " + maximumAllocationSize + " (expected: >= 0)");
        }
        this.maximumAllocationSize = maximumAllocationSize;
    }

    /**
     * Decompresses all readable bytes of {@code in}, adding each non-empty output buffer to {@code out}.
     * Once a failure has occurred, the decoder is marked corrupted and all subsequent input is discarded.
     *
     * @throws DecompressionException wrapping any exception raised while decompressing
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        try {
            if (currentState == State.CORRUPTED) {
                // A previous failure left the stream in an unknown state; drop everything that follows.
                in.skipBytes(in.readableBytes());
                return;
            }
            final int compressedLength = in.readableBytes();

            // Expose the current input to the zstd stream for the duration of this call only.
            inputStream.current = in;

            ByteBuf outBuffer = null;
            try {
                int w;
                do {
                    if (outBuffer == null) {
                        outBuffer = ctx.alloc().heapBuffer(outputBufferCapacity(compressedLength));
                    }
                    // Fill the buffer until it is full or the zstd stream returns -1. The wrapped
                    // MutableByteBufInputStream reports -1 once `in` has been drained.
                    do {
                        w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
                    } while (w != -1 && outBuffer.isWritable());
                    if (outBuffer.isReadable()) {
                        // Ownership passes to `out`; clear the local so the finally block won't release it.
                        out.add(outBuffer);
                        outBuffer = null;
                    }
                } while (w != -1);
            } finally {
                // Release a buffer that was allocated but received no data (or was abandoned by an exception).
                if (outBuffer != null) {
                    outBuffer.release();
                }
            }
        } catch (Exception e) {
            currentState = State.CORRUPTED;
            throw new DecompressionException(e);
        } finally {
            inputStream.current = null;
        }
    }

    /**
     * Computes the initial capacity of an output buffer for the given amount of compressed input.
     *
     * <p>Starts at twice the compressed length, because the input buffer often does not hold everything
     * needed and we don't want to reserve too much memory. The multiplication is done in {@code long} so
     * it cannot overflow. The result is then limited by {@link #maximumAllocationSize} (when non-zero)
     * and by {@link Integer#MAX_VALUE}. It is never smaller than 1, so every read request asks for
     * at least one byte.
     */
    private int outputBufferCapacity(int compressedLength) {
        long capacity = compressedLength * 2L;
        if (maximumAllocationSize != 0) {
            capacity = Math.min(capacity, maximumAllocationSize);
        }
        capacity = Math.min(capacity, Integer.MAX_VALUE);
        return (int) Math.max(1L, capacity);
    }

    /**
     * Creates the zstd input stream over {@link MutableByteBufInputStream} in continuous mode.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        super.handlerAdded(ctx);
        zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
        zstdIs.setContinuous(true);
    }

    /**
     * Closes the zstd input stream (ignoring close failures) before delegating to the superclass.
     */
    @Override
    protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
        try {
            closeSilently(zstdIs);
        } finally {
            super.handlerRemoved0(ctx);
        }
    }

    /**
     * Closes {@code closeable} if it is non-null, deliberately ignoring any {@link IOException}.
     */
    private static void closeSilently(Closeable closeable) {
        if (closeable != null) {
            try {
                closeable.close();
            } catch (IOException ignore) {
                // Best-effort cleanup during handler removal; nothing useful can be done with the failure.
            }
        }
    }

    /**
     * An {@link InputStream} view over a swappable {@link ByteBuf}. It returns {@code -1} whenever
     * {@link #current} is {@code null} or has no readable bytes, i.e. "no more data right now".
     */
    private static final class MutableByteBufInputStream extends InputStream {
        // The buffer currently being decoded; set and cleared by decode().
        ByteBuf current;

        @Override
        public int read() {
            if (current == null || !current.isReadable()) {
                return -1;
            }
            return current.readByte() & 0xff;
        }

        @Override
        public int read(byte[] b, int off, int len) {
            int available = available();
            if (available == 0) {
                return -1;
            }

            len = Math.min(available, len);
            current.readBytes(b, off, len);
            return len;
        }

        @Override
        public int available() {
            return current == null ? 0 : current.readableBytes();
        }
    }
}