All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.cassandra.io.util.RewindableDataInputStreamPlus Maven / Gradle / Ivy

There is a newer version: 3.11.12.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.io.util;

import java.io.Closeable;
import java.io.File;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.apache.cassandra.utils.Throwables.maybeFail;
import static org.apache.cassandra.utils.Throwables.merge;

/**
 * Adds mark/reset functionality to another input stream by caching read bytes to a memory buffer and
 * spilling to disk if necessary.
 *
 * When the stream is marked via {@link #mark()} or {@link #mark(int)}, up to
 * maxMemBufferSize bytes will be cached in memory (heap). If more than
 * maxMemBufferSize bytes are read while the stream is marked, the
 * following bytes are cached in the spillFile, up to maxDiskBufferSize bytes.
 *
 * The spillFile is reused across successive marks and is treated as a circular buffer:
 * spilled bytes are written from the current file position up to maxDiskBufferSize and
 * then wrap around to the beginning of the file.
 *
 * If more than maxMemBufferSize + maxDiskBufferSize bytes are read while the stream is marked,
 * the following {@link #reset()} invocation will throw an {@link IOException}, and the stream
 * stays marked, so it cannot be marked again.
 *
 * A new mark can only be set after every byte cached for the previous mark has been re-read.
 * Only {@link #mark(int)} and {@link #reset()} are {@code synchronized}; the read and skip
 * methods are not.
 */
public class RewindableDataInputStreamPlus extends FilterInputStream implements RewindableDataInput, Closeable
{
    private boolean marked = false;
    // Set once more bytes were read while marked than the memory + disk buffers can hold; makes reset() fail.
    private boolean exhausted = false;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    // While marked: remaining memory capacity. While not marked: cached bytes in memBuffer still to be replayed.
    protected int memAvailable = 0;
    // While marked: -1 until the first byte is spilled, then the remaining capacity of the tail segment
    // [spill start, maxDiskBufferSize) and of the head segment [0, spill start) of the spill file.
    // While not marked: cached bytes still to be replayed from each segment.
    protected int diskTailAvailable = 0;
    protected int diskHeadAvailable = 0;

    private final File spillFile;
    private final int initialMemBufferSize;
    private final int maxMemBufferSize;
    private final int maxDiskBufferSize;

    private volatile byte[] memBuffer;
    // Number of valid cached bytes in memBuffer, fixed at the last reset().
    private int memBufferSize;
    private RandomAccessFile spillBuffer;

    private final DataInputPlus dataReader;

    /**
     * @param in the stream to wrap
     * @param initialMemBufferSize initial size of the in-memory cache, allocated on the first cached byte
     * @param maxMemBufferSize maximum number of bytes cached in memory while marked
     * @param spillFile file caching bytes beyond maxMemBufferSize; it and its parent directories are created on demand
     * @param maxDiskBufferSize maximum number of bytes cached in spillFile while marked
     */
    public RewindableDataInputStreamPlus(InputStream in, int initialMemBufferSize, int maxMemBufferSize,
                                         File spillFile, int maxDiskBufferSize)
    {
        super(in);
        dataReader = new DataInputStreamPlus(this);
        this.initialMemBufferSize = initialMemBufferSize;
        this.maxMemBufferSize = maxMemBufferSize;
        this.spillFile = spillFile;
        this.maxDiskBufferSize = maxDiskBufferSize;
    }

    /* RewindableDataInput methods */

    /**
     * Marks the current position of a stream to return to this position later via the {@link #reset(DataPosition)} method.
     * @return An empty {@link DataPosition} object
     */
    public DataPosition mark()
    {
        mark(0);
        return new RewindableDataInputPlusMark();
    }

    /**
     * Rewinds to the previously marked position via the {@link #mark()} method.
     * @param mark it's not possible to return to a custom position, so this parameter is ignored.
     * @throws IOException if an error occurs while resetting
     */
    public void reset(DataPosition mark) throws IOException
    {
        reset();
    }

    /**
     * Returns the number of bytes cached since the last mark. The computation treats the counters as
     * remaining capacities, so it assumes the stream is currently marked; the argument is ignored.
     */
    public long bytesPastMark(DataPosition mark)
    {
        return maxMemBufferSize - memAvailable + (diskTailAvailable == -1? 0 : maxDiskBufferSize - diskHeadAvailable - diskTailAvailable);
    }


    protected static class RewindableDataInputPlusMark implements DataPosition
    {
    }

    /* InputStream methods */

    @Override
    public boolean markSupported()
    {
        return true;
    }

    /**
     * Marks the current position of a stream to return to this position
     * later via the {@link #reset()} method.
     * @param readlimit ignored; the caching capacity is fixed by maxMemBufferSize and maxDiskBufferSize
     * @throws IllegalStateException if the stream is already marked, or bytes cached for a previous
     *                               mark have not all been re-read yet
     */
    @Override
    public synchronized void mark(int readlimit)
    {
        if (marked)
            throw new IllegalStateException("Cannot mark already marked stream.");

        if (memAvailable > 0 || diskHeadAvailable > 0 || diskTailAvailable > 0)
            throw new IllegalStateException("Can only mark stream after reading previously marked data.");

        marked = true;
        memAvailable = maxMemBufferSize;
        // -1 signals "nothing spilled to disk yet" to the write paths and to reset()
        diskHeadAvailable = -1;
        diskTailAvailable = -1;
    }

    /**
     * Rewinds to the position recorded by the last mark, so that the cached bytes are replayed
     * (memory first, then the disk tail segment, then the disk head segment) before reading from
     * the wrapped stream again.
     * @throws IOException if the stream is not marked, or more bytes were read than can be cached
     */
    @Override
    public synchronized void reset() throws IOException
    {
        if (!marked)
            throw new IOException("Must call mark() before calling reset().");

        if (exhausted)
            throw new IOException(String.format("Read more than capacity: %d bytes.", (long) maxMemBufferSize + maxDiskBufferSize));

        // memAvailable held the remaining memory capacity; convert it to the number of cached bytes to replay.
        memAvailable = maxMemBufferSize - memAvailable;
        memBufferSize = memAvailable;

        if (diskTailAvailable == -1)
        {
            // nothing was spilled to disk while marked
            diskHeadAvailable = 0;
            diskTailAvailable = 0;
        }
        else
        {
            RandomAccessFile spill = getIfNotClosed(spillBuffer);
            // Recover the offset where spilling started. If the tail still has room, writing never wrapped
            // and the untouched head capacity equals that offset. Otherwise the file pointer sits just past the
            // bytes written at the head, so start = bytes written at head + remaining head capacity.
            int initialPos = diskTailAvailable > 0 ? 0 : (int) spill.getFilePointer();
            int diskMarkpos = initialPos + diskHeadAvailable;
            spill.seek(diskMarkpos);

            // Convert remaining capacities into the number of cached bytes to replay from each segment.
            diskHeadAvailable = diskMarkpos - diskHeadAvailable;
            diskTailAvailable = (maxDiskBufferSize - diskTailAvailable) - diskMarkpos;
        }

        marked = false;
    }

    @Override
    public int available() throws IOException
    {
        // cached bytes only count once the stream has been reset and they are replayable
        return super.available() + (marked? 0 : memAvailable + diskHeadAvailable + diskTailAvailable);
    }

    @Override
    public int read() throws IOException
    {
        int read = readOne();
        if (read == -1)
            return read;

        if (marked)
        {
            // mark exhausted: stop caching; the next reset() will fail
            if (isExhausted(1))
            {
                exhausted = true;
                return read;
            }

            writeOne(read);
        }

        return read;
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException
    {
        int readBytes = readMulti(b, off, len);
        if (readBytes == -1)
            return readBytes;

        if (marked)
        {
            // check we have space on buffer; if not, stop caching and let the next reset() fail
            if (isExhausted(readBytes))
            {
                exhausted = true;
                return readBytes;
            }

            writeMulti(b, off, readBytes);
        }

        return readBytes;
    }

    /** Lazily creates the spill file (and its parent directories) and opens it for read/write. */
    private void maybeCreateDiskBuffer() throws IOException
    {
        if (spillBuffer == null)
        {
            File parent = spillFile.getParentFile();
            // a bare file name has no parent directory to create
            if (parent != null && !parent.exists())
                parent.mkdirs();
            spillFile.createNewFile();

            this.spillBuffer = new RandomAccessFile(spillFile, "rw");
        }
    }


    /**
     * Reads one byte: while not marked, from the memory cache, then the disk tail segment, then the
     * disk head segment; otherwise (or once the caches are drained) from the wrapped stream.
     */
    private int readOne() throws IOException
    {
        if (!marked)
        {
            if (memAvailable > 0)
            {
                int pos = memBufferSize - memAvailable;
                memAvailable--;
                return getIfNotClosed(memBuffer)[pos] & 0xff;
            }

            if (diskTailAvailable > 0)
            {
                int read = getIfNotClosed(spillBuffer).read();
                diskTailAvailable--;
                // tail drained: the head segment (if any) starts at the beginning of the file
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
                return read;
            }

            if (diskHeadAvailable > 0)
            {
                int read = getIfNotClosed(spillBuffer).read();
                diskHeadAvailable--;
                return read;
            }
        }

        return getIfNotClosed(in).read();
    }

    /** Returns true if caching {@code readBytes} more bytes would exceed the remaining memory + disk capacity. */
    private boolean isExhausted(int readBytes)
    {
        return exhausted || readBytes > memAvailable + (long)(diskTailAvailable == -1? maxDiskBufferSize : diskTailAvailable + diskHeadAvailable);
    }

    /**
     * Bulk counterpart of {@link #readOne()}: drains the caches (when not marked) in the same order, then
     * tops up from the wrapped stream.
     * @return the number of bytes read, or -1 if nothing was cached and the wrapped stream is at EOF
     */
    private int readMulti(byte[] b, int off, int len) throws IOException
    {
        int readBytes = 0;
        if (!marked)
        {
            if (memAvailable > 0)
            {
                int readFromMem = Math.min(memAvailable, len);
                int pos = memBufferSize - memAvailable;
                System.arraycopy(getIfNotClosed(memBuffer), pos, b, off, readFromMem);
                memAvailable -= readFromMem;
                readBytes += readFromMem;
                off += readFromMem;
                len -= readFromMem;
            }
            if (len > 0 && diskTailAvailable > 0)
            {
                int readFromTail = Math.min(diskTailAvailable, len);
                // RandomAccessFile.read may return fewer bytes than asked; these bytes are known to be cached
                getIfNotClosed(spillBuffer).readFully(b, off, readFromTail);
                readBytes += readFromTail;
                diskTailAvailable -= readFromTail;
                off += readFromTail;
                len -= readFromTail;
                // tail drained: the head segment (if any) starts at the beginning of the file
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
            }
            if (len > 0 && diskHeadAvailable > 0)
            {
                int readFromHead = Math.min(diskHeadAvailable, len);
                getIfNotClosed(spillBuffer).readFully(b, off, readFromHead);
                readBytes += readFromHead;
                diskHeadAvailable -= readFromHead;
                off += readFromHead;
                len -= readFromHead;
            }
        }

        if (len > 0)
        {
            int readFromSource = getIfNotClosed(in).read(b, off, len);
            // EOF on the wrapped stream: report the cached bytes already copied, or propagate EOF if there were none
            if (readFromSource == -1)
                return readBytes > 0 ? readBytes : -1;
            readBytes += readFromSource;
        }

        return readBytes;
    }

    /**
     * Caches {@code len} bytes read while marked: into memory while capacity remains, then into the spill
     * file's tail segment, then wrapping into its head segment. Callers check capacity via isExhausted first.
     */
    private void writeMulti(byte[] b, int off, int len) throws IOException
    {
        if (memAvailable > 0)
        {
            if (memBuffer == null)
                memBuffer = new byte[initialMemBufferSize];
            int pos = maxMemBufferSize - memAvailable;
            int memWritten = Math.min(memAvailable, len);
            if (pos + memWritten > getIfNotClosed(memBuffer).length)
                growMemBuffer(pos, memWritten);
            System.arraycopy(b, off, memBuffer, pos, memWritten);
            off += memWritten;
            len -= memWritten;
            memAvailable -= memWritten;
        }

        if (len > 0)
        {
            if (diskTailAvailable == -1)
            {
                maybeCreateDiskBuffer();
                // first spill since mark(): write from the current position to the end (tail), then wrap to
                // [0, current position) (head)
                diskHeadAvailable = (int) spillBuffer.getFilePointer();
                diskTailAvailable = maxDiskBufferSize - diskHeadAvailable;
            }

            if (len > 0 && diskTailAvailable > 0)
            {
                int diskTailWritten = Math.min(diskTailAvailable, len);
                getIfNotClosed(spillBuffer).write(b, off, diskTailWritten);
                off += diskTailWritten;
                len -= diskTailWritten;
                diskTailAvailable -= diskTailWritten;
                // tail full: continue writing at the beginning of the file
                if (diskTailAvailable == 0)
                    spillBuffer.seek(0);
            }

            if (len > 0 && diskHeadAvailable > 0)
            {
                int diskHeadWritten = Math.min(diskHeadAvailable, len);
                getIfNotClosed(spillBuffer).write(b, off, diskHeadWritten);
                diskHeadAvailable -= diskHeadWritten;
            }
        }
    }

    /** Single-byte counterpart of {@link #writeMulti(byte[], int, int)}. */
    private void writeOne(int value) throws IOException
    {
        if (memAvailable > 0)
        {
            if (memBuffer == null)
                memBuffer = new byte[initialMemBufferSize];
            int pos = maxMemBufferSize - memAvailable;
            if (pos >= getIfNotClosed(memBuffer).length)
                growMemBuffer(pos, 1);
            getIfNotClosed(memBuffer)[pos] = (byte) value;
            memAvailable--;
            return;
        }

        if (diskTailAvailable == -1)
        {
            maybeCreateDiskBuffer();
            // first spill since mark(): see writeMulti
            diskHeadAvailable = (int) spillBuffer.getFilePointer();
            diskTailAvailable = maxDiskBufferSize - diskHeadAvailable;
        }

        if (diskTailAvailable > 0)
        {
            getIfNotClosed(spillBuffer).write(value);
            diskTailAvailable--;
            // tail full: continue writing at the beginning of the file
            if (diskTailAvailable == 0)
                spillBuffer.seek(0);
        }
        else if (diskHeadAvailable > 0)
        {
            getIfNotClosed(spillBuffer).write(value);
            diskHeadAvailable--;
        }
    }

    @Override
    public int read(byte[] b) throws IOException
    {
        return read(b, 0, b.length);
    }

    /** Reallocates memBuffer so that {@code pos + writeSize} bytes fit, doubling up to maxMemBufferSize. */
    private void growMemBuffer(int pos, int writeSize)
    {
        // widen to long so the doubling cannot overflow for large buffer limits
        int newSize = (int) Math.min(2L * (pos + writeSize), maxMemBufferSize);
        byte[] newBuffer = new byte[newSize];
        System.arraycopy(memBuffer, 0, newBuffer, 0, pos);
        memBuffer = newBuffer;
    }

    /**
     * Skips up to {@code n} bytes. While marked, skipped bytes are read (and therefore cached); otherwise
     * cached bytes are skipped first, then bytes of the wrapped stream.
     * @return the number of bytes actually skipped; 0 if {@code n <= 0}
     */
    @Override
    public long skip(long n) throws IOException
    {
        // InputStream contract: a non-positive count skips nothing
        if (n <= 0)
            return 0;

        long skipped = 0;

        if (marked)
        {
            // if marked, we need to cache skipped bytes
            while (n-- > 0 && read() != -1)
                skipped++;
            return skipped;
        }

        if (memAvailable > 0)
        {
            int skipFromMem = (int) Math.min(memAvailable, n);
            memAvailable -= skipFromMem;
            skipped += skipFromMem;
            n -= skipFromMem;
        }
        if (n > 0 && diskTailAvailable > 0)
        {
            int skipFromTail = (int) Math.min(diskTailAvailable, n);
            getIfNotClosed(spillBuffer).skipBytes(skipFromTail);
            diskTailAvailable -= skipFromTail;
            skipped += skipFromTail;
            n -= skipFromTail;
            // tail drained: the head segment (if any) starts at the beginning of the file
            if (diskTailAvailable == 0)
                spillBuffer.seek(0);
        }
        if (n > 0 && diskHeadAvailable > 0)
        {
            int skipFromHead = (int) Math.min(diskHeadAvailable, n);
            getIfNotClosed(spillBuffer).skipBytes(skipFromHead);
            diskHeadAvailable -= skipFromHead;
            skipped += skipFromHead;
            n -= skipFromHead;
        }

        if (n > 0)
            skipped += getIfNotClosed(in).skip(n);

        return skipped;
    }

    /** Returns {@code resource} unchanged, or throws if this stream has been closed. */
    private <T> T getIfNotClosed(T resource) throws IOException
    {
        if (closed.get())
            throw new IOException("Stream closed");
        return resource;
    }

    @Override
    public void close() throws IOException
    {
        close(true);
    }

    /**
     * Closes this stream once; later calls are no-ops. Closes the spill file and attempts to delete it,
     * collecting any failures and rethrowing them together at the end.
     * @param closeUnderlying whether to also close the wrapped stream
     */
    public void close(boolean closeUnderlying) throws IOException
    {
        if (closed.compareAndSet(false, true))
        {
            Throwable fail = null;
            if (closeUnderlying)
            {
                try
                {
                    super.close();
                }
                catch (IOException e)
                {
                    fail = merge(fail, e);
                }
            }
            try
            {
                if (spillBuffer != null)
                {
                    this.spillBuffer.close();
                    this.spillBuffer = null;
                }
            }
            catch (IOException e)
            {
                fail = merge(fail, e);
            }
            try
            {
                if (spillFile.exists())
                {
                    spillFile.delete();
                }
            }
            catch (Throwable e)
            {
                fail = merge(fail, e);
            }
            maybeFail(fail, IOException.class);
        }
    }

    /* DataInputPlus methods: delegate to dataReader, which wraps this stream (see constructor). */

    public void readFully(byte[] b) throws IOException
    {
        dataReader.readFully(b);
    }

    public void readFully(byte[] b, int off, int len) throws IOException
    {
        dataReader.readFully(b, off, len);
    }

    public int skipBytes(int n) throws IOException
    {
        return dataReader.skipBytes(n);
    }

    public boolean readBoolean() throws IOException
    {
        return dataReader.readBoolean();
    }

    public byte readByte() throws IOException
    {
        return dataReader.readByte();
    }

    public int readUnsignedByte() throws IOException
    {
        return dataReader.readUnsignedByte();
    }

    public short readShort() throws IOException
    {
        return dataReader.readShort();
    }

    public int readUnsignedShort() throws IOException
    {
        return dataReader.readUnsignedShort();
    }

    public char readChar() throws IOException
    {
        return dataReader.readChar();
    }

    public int readInt() throws IOException
    {
        return dataReader.readInt();
    }

    public long readLong() throws IOException
    {
        return dataReader.readLong();
    }

    public float readFloat() throws IOException
    {
        return dataReader.readFloat();
    }

    public double readDouble() throws IOException
    {
        return dataReader.readDouble();
    }

    public String readLine() throws IOException
    {
        return dataReader.readLine();
    }

    public String readUTF() throws IOException
    {
        return dataReader.readUTF();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy