io.airlift.compress.zstd.FiniteStateEntropy Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.airlift.compress.zstd;
import static io.airlift.compress.zstd.BitStream.peekBits;
import static io.airlift.compress.zstd.FseTableReader.FSE_MAX_SYMBOL_VALUE;
import static io.airlift.compress.zstd.UnsafeUtil.UNSAFE;
import static io.airlift.compress.zstd.ZstdFrameDecompressor.SIZE_OF_INT;
import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
/**
 * Decoder for a Finite State Entropy (FSE) compressed byte sequence.
 *
 * <p>Each call to {@link #decompress} first reads a decoding table from the input into the
 * {@link Table} owned by this instance and then decodes symbols from the remaining bit stream
 * using two interleaved decoder states.
 *
 * <p>The instance reuses a single {@link Table} across calls, so one instance should not be
 * shared between concurrent callers.
 */
class FiniteStateEntropy
{
    private static final int MAX_TABLE_LOG = 12;

    // Scratch table that is (re)filled by the reader on every decompress() call.
    private final FiniteStateEntropy.Table table;
    private final FseTableReader reader = new FseTableReader();

    /**
     * @param maxLog log2 of the number of entries to allocate in the decoding table
     */
    public FiniteStateEntropy(int maxLog)
    {
        table = new FiniteStateEntropy.Table(maxLog);
    }

    /**
     * Reads an FSE table from the input and decodes the bit stream that follows it into
     * {@code weights}.
     *
     * @param inputBase base object for the input (passed through to the bit stream helpers)
     * @param inputAddress offset of the first input byte relative to {@code inputBase}
     * @param inputLimit offset one past the last input byte
     * @param weights destination array; decoded symbols are written starting at index 0
     * @return the number of bytes written to {@code weights}
     * @throws IndexOutOfBoundsException if the input decodes to more symbols than
     *         {@code weights} can hold
     */
    public int decompress(final Object inputBase, final long inputAddress, final long inputLimit, byte[] weights)
    {
        long input = inputAddress;
        // readFseTable's return value is used as the number of input bytes the table occupied
        input += reader.readFseTable(table, inputBase, input, inputLimit, FSE_MAX_SYMBOL_VALUE, MAX_TABLE_LOG);

        final Object outputBase = weights;
        final long outputAddress = ARRAY_BYTE_BASE_OFFSET;
        final long outputLimit = outputAddress + weights.length;

        long output = outputAddress;

        // initialize bit stream
        BitStream.Initializer initializer = new BitStream.Initializer(inputBase, input, inputLimit);
        initializer.initialize();
        int bitsConsumed = initializer.getBitsConsumed();
        long currentAddress = initializer.getCurrentAddress();
        long bits = initializer.getBits();

        // initialize first FSE stream
        int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
        bitsConsumed += table.log2Size;

        BitStream.Loader loader = new BitStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
        loader.load();
        bits = loader.getBits();
        bitsConsumed = loader.getBitsConsumed();
        currentAddress = loader.getCurrentAddress();

        // initialize second FSE stream
        int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
        bitsConsumed += table.log2Size;

        loader = new BitStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
        loader.load();
        bits = loader.getBits();
        bitsConsumed = loader.getBitsConsumed();
        currentAddress = loader.getCurrentAddress();

        byte[] symbols = table.symbol;
        byte[] numbersOfBits = table.numberOfBits;
        int[] newStates = table.newState;

        // decode 4 symbols per loop
        // The writes below go through Unsafe and are NOT bounds checked, so the loop may only
        // run while at least four output bytes are still available.
        while (output <= outputLimit - SIZE_OF_INT) {
            int numberOfBits;

            UNSAFE.putByte(outputBase, output, symbols[state1]);
            numberOfBits = numbersOfBits[state1];
            state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
            bitsConsumed += numberOfBits;

            UNSAFE.putByte(outputBase, output + 1, symbols[state2]);
            numberOfBits = numbersOfBits[state2];
            state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits));
            bitsConsumed += numberOfBits;

            UNSAFE.putByte(outputBase, output + 2, symbols[state1]);
            numberOfBits = numbersOfBits[state1];
            state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
            bitsConsumed += numberOfBits;

            UNSAFE.putByte(outputBase, output + 3, symbols[state2]);
            numberOfBits = numbersOfBits[state2];
            state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits));
            bitsConsumed += numberOfBits;

            output += SIZE_OF_INT;

            loader = new BitStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            boolean done = loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }
        }

        // Tail: decode one symbol at a time, alternating between the two states, until the bit
        // stream reports overflow; at that point the symbol of the other state is flushed.
        while (true) {
            // Every step writes the current symbol and, before the loop can terminate, at
            // least one more symbol, so two free bytes are required here.
            ensureOutputSpace(output, outputLimit, 2);
            UNSAFE.putByte(outputBase, output++, symbols[state1]);
            int numberOfBits = numbersOfBits[state1];
            state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
            bitsConsumed += numberOfBits;

            loader = new BitStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (loader.isOverflow()) {
                UNSAFE.putByte(outputBase, output++, symbols[state2]);
                break;
            }

            ensureOutputSpace(output, outputLimit, 2);
            UNSAFE.putByte(outputBase, output++, symbols[state2]);
            int numberOfBits1 = numbersOfBits[state2];
            state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits1));
            bitsConsumed += numberOfBits1;

            loader = new BitStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (loader.isOverflow()) {
                UNSAFE.putByte(outputBase, output++, symbols[state1]);
                break;
            }
        }

        return (int) (output - outputAddress);
    }

    /**
     * Fails if fewer than {@code required} bytes remain between {@code output} and
     * {@code outputLimit}, so that unchecked Unsafe writes never leave the destination array.
     */
    private static void ensureOutputSpace(long output, long outputLimit, int required)
    {
        if (output > outputLimit - required) {
            // NOTE(review): the project may have a dedicated malformed-input exception that
            // callers expect here - confirm and switch if so.
            throw new IndexOutOfBoundsException("Output buffer is too small");
        }
    }

    /**
     * FSE decoding table: three parallel arrays indexed by decoder state.
     */
    public static final class Table
    {
        // log2 of the number of table entries in use; also the number of bits read to seed a state
        int log2Size;
        // base of the next state for each state (bits read from the stream are added to it)
        final int[] newState;
        // symbol emitted for each state
        final byte[] symbol;
        // number of bits to read from the stream when leaving each state
        final byte[] numberOfBits;

        /**
         * Allocates empty arrays with {@code 1 << log2Size} entries each.
         *
         * <p>The {@code log2Size} field is not assigned here and stays 0 until it is set
         * elsewhere; presumably the table reader fills it in together with the arrays -
         * TODO confirm.
         *
         * @param log2Size log2 of the capacity to allocate
         */
        public Table(int log2Size)
        {
            int size = 1 << log2Size;
            newState = new int[size];
            symbol = new byte[size];
            numberOfBits = new byte[size];
        }

        /**
         * Wraps pre-built arrays; the arrays are used directly, not copied.
         *
         * @param log2Size log2 of the table size
         * @param newState next-state bases, length {@code 1 << log2Size}
         * @param symbol symbols, length {@code 1 << log2Size}
         * @param numberOfBits bit counts, length {@code 1 << log2Size}
         * @throws IllegalArgumentException if any array length differs from {@code 1 << log2Size}
         */
        public Table(int log2Size, int[] newState, byte[] symbol, byte[] numberOfBits)
        {
            int size = 1 << log2Size;
            if (newState.length != size || symbol.length != size || numberOfBits.length != size) {
                throw new IllegalArgumentException("Expected arrays to match provided size");
            }

            this.log2Size = log2Size;
            this.newState = newState;
            this.symbol = symbol;
            this.numberOfBits = numberOfBits;
        }
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy