ai.vespa.airlift.zstd.Huffman Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of airlift-zstd Show documentation
Show all versions of airlift-zstd Show documentation
Fork of https://github.com/airlift/aircompressor (zstd only).
This module is temporary until we get an official release that includes the
ZstdInputStream API (which is already implemented by two different people
but neither PR shows any progress).
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.vespa.airlift.zstd;
import java.util.Arrays;
import static ai.vespa.airlift.zstd.BitInputStream.isEndOfStream;
import static ai.vespa.airlift.zstd.BitInputStream.peekBitsFast;
import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
import static ai.vespa.airlift.zstd.Util.isPowerOf2;
import static ai.vespa.airlift.zstd.Util.verify;
/**
 * Huffman decoding table used by the zstd decoder in this package.
 *
 * <p>{@link #readTable} parses a serialized weight description and builds a flat lookup table
 * ({@code symbols} / {@code numbersOfBits}) indexed by the next {@code tableLog} bits of the stream.
 * {@link #decodeSingleStream} and {@link #decode4Streams} then expand Huffman-coded bit streams
 * using that table.
 *
 * <p>Instances keep mutable scratch and table state in fields and are reused across calls.
 * All memory access goes through {@code UNSAFE} with a (base object, address) pair, so callers
 * are responsible for passing address ranges that really belong to {@code inputBase} /
 * {@code outputBase}.
 *
 * <p>NOTE(review): {@code verify(condition, offset, message)} is defined elsewhere; it is assumed
 * to throw when {@code condition} is false — confirm against {@code Util.verify}.
 */
class Huffman
{
    /** Largest symbol value that can be encoded (symbols are single bytes). */
    public static final int MAX_SYMBOL = 255;
    /** Number of distinct symbols, i.e. {@code MAX_SYMBOL + 1}. */
    public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;

    /** Largest accepted table log; the lookup table has at most {@code 1 << MAX_TABLE_LOG} entries. */
    public static final int MAX_TABLE_LOG = 12;
    public static final int MIN_TABLE_LOG = 5;
    /** Table log passed to the FSE reader when the weights themselves are FSE-compressed. */
    public static final int MAX_FSE_TABLE_LOG = 6;

    // stats (scratch space, only meaningful while readTable is running)
    private final byte[] weights = new byte[MAX_SYMBOL + 1];
    private final int[] ranks = new int[MAX_TABLE_LOG + 1];

    // table; tableLog == -1 means that no table has been loaded successfully yet
    private int tableLog = -1;
    private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
    private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];

    private final FseTableReader reader = new FseTableReader();
    private final FiniteStateEntropy.Table fseTable = new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);

    /**
     * @return {@code true} once {@link #readTable} has completed successfully at least once
     */
    public boolean isLoaded()
    {
        return tableLog != -1;
    }

    /**
     * Parses a Huffman table description and rebuilds the decoding table.
     *
     * <p>The first byte is a header. A value of 128 or more means that {@code header - 127}
     * weights follow directly, packed two per byte (high nibble first). A smaller value is the
     * byte length of an FSE-compressed weight sequence. The weight of the last symbol is not
     * stored; it is derived so that all weights add up to a power of two.
     *
     * <p>All validation happens before the lookup table or {@code tableLog} are modified, so a
     * rejected description leaves a previously loaded table untouched.
     *
     * @param inputBase base object for {@code UNSAFE} reads
     * @param inputAddress address of the header byte
     * @param size number of readable bytes starting at {@code inputAddress}
     * @return number of bytes occupied by the table description (header byte included)
     */
    public int readTable(final Object inputBase, final long inputAddress, final int size)
    {
        Arrays.fill(ranks, 0);
        long input = inputAddress;

        // read table header
        verify(size > 0, input, "Not enough input bytes");
        int inputSize = UNSAFE.getByte(inputBase, input++) & 0xFF;

        int outputSize;
        if (inputSize >= 128) {
            // weights are stored raw, 4 bits each
            outputSize = inputSize - 127;
            inputSize = ((outputSize + 1) / 2);

            verify(inputSize + 1 <= size, input, "Not enough input bytes");
            verify(outputSize <= MAX_SYMBOL + 1, input, "Input is corrupted");

            for (int i = 0; i < outputSize; i += 2) {
                int value = UNSAFE.getByte(inputBase, input + i / 2) & 0xFF;
                weights[i] = (byte) (value >>> 4);
                // for an odd outputSize this writes one padding slot, which is overwritten
                // by the implicit last weight below
                weights[i + 1] = (byte) (value & 0b1111);
            }
        }
        else {
            // weights are FSE-compressed
            verify(inputSize + 1 <= size, input, "Not enough input bytes");

            long inputLimit = input + inputSize;
            input += reader.readFseTable(fseTable, inputBase, input, inputLimit, FiniteStateEntropy.MAX_SYMBOL, MAX_FSE_TABLE_LOG);
            outputSize = FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit, weights);
        }

        // The implicit last symbol is stored at weights[outputSize], so at most MAX_SYMBOL
        // explicit weights can be accepted without running off the end of the array.
        verify(outputSize <= MAX_SYMBOL, input, "Input is corrupted");

        int totalWeight = 0;
        for (int i = 0; i < outputSize; i++) {
            int weight = weights[i];
            // A raw nibble can be as large as 15 and an FSE-decoded byte can be anything, but
            // ranks only has MAX_TABLE_LOG + 1 slots: reject instead of indexing out of bounds.
            verify(weight >= 0 && weight <= MAX_TABLE_LOG, input, "Input is corrupted");
            ranks[weight]++;
            // (1 << w) >> 1 rather than 1 << (w - 1): a weight of 0 (unused symbol) must
            // contribute 0, whereas 1 << -1 would be 1 << 31 in Java.
            totalWeight += (1 << weight) >> 1;
        }
        verify(totalWeight != 0, input, "Input is corrupted");

        int newTableLog = Util.highestBit(totalWeight) + 1;
        verify(newTableLog <= MAX_TABLE_LOG, input, "Input is corrupted");

        // whatever is missing to reach the next power of two belongs to the implicit last symbol
        int total = 1 << newTableLog;
        int rest = total - totalWeight;
        verify(isPowerOf2(rest), input, "Input is corrupted");

        int lastWeight = Util.highestBit(rest) + 1;

        weights[outputSize] = (byte) lastWeight;
        ranks[lastWeight]++;

        int numberOfSymbols = outputSize + 1;

        // There must be an even number (at least two) of weight-1 symbols. Checked on the raw
        // counts, i.e. before anything in the lookup table is overwritten.
        verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, "Input is corrupted");

        // populate table: first turn the per-weight counts into start offsets ...
        int nextRankStart = 0;
        for (int i = 1; i < newTableLog + 1; ++i) {
            int current = nextRankStart;
            nextRankStart += ranks[i] << (i - 1);
            ranks[i] = current;
        }

        // ... then let every symbol fill (1 << weight) >> 1 consecutive slots of its rank
        for (int n = 0; n < numberOfSymbols; n++) {
            int weight = weights[n];
            int length = (1 << weight) >> 1; // 0 for weight 0, so unused symbols occupy no slot

            byte symbol = (byte) n;
            byte numberOfBits = (byte) (newTableLog + 1 - weight);
            for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
                symbols[i] = symbol;
                numbersOfBits[i] = numberOfBits;
            }
            ranks[weight] += length;
        }

        // publish the new table only after it has been fully validated and built
        tableLog = newTableLog;

        return inputSize + 1;
    }

    /**
     * Decodes one Huffman-coded bit stream into {@code [outputAddress, outputLimit)}.
     *
     * @param inputBase base object for {@code UNSAFE} reads
     * @param inputAddress start address of the bit stream
     * @param inputLimit address just past the bit stream
     * @param outputBase base object for {@code UNSAFE} writes
     * @param outputAddress address of the first byte to write
     * @param outputLimit address just past the last byte to write
     */
    public void decodeSingleStream(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
    {
        BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
        initializer.initialize();

        long bits = initializer.getBits();
        int bitsConsumed = initializer.getBitsConsumed();
        long currentAddress = initializer.getCurrentAddress();

        // copy the table state into locals for the hot loop
        int tableLog = this.tableLog;
        byte[] numbersOfBits = this.numbersOfBits;
        byte[] symbols = this.symbols;

        // 4 symbols at a time
        long output = outputAddress;
        long fastOutputLimit = outputLimit - 4;
        while (output < fastOutputLimit) {
            BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits, bitsConsumed);
            boolean done = loader.load();
            bits = loader.getBits();
            bitsConsumed = loader.getBitsConsumed();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            bitsConsumed = decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            output += SIZE_OF_INT;
        }

        decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits, outputBase, output, outputLimit);
    }

    /**
     * Decodes four interleaved Huffman-coded bit streams into {@code [outputAddress, outputLimit)}.
     *
     * <p>The input starts with a jump table of three 16-bit values: the byte lengths of streams
     * 1 to 3. Stream 4 takes the remaining input. The output is split into segments of
     * {@code (outputSize + 3) / 4} bytes for streams 1 to 3; stream 4 fills whatever is left.
     *
     * @param inputBase base object for {@code UNSAFE} reads
     * @param inputAddress address of the jump table
     * @param inputLimit address just past the last stream
     * @param outputBase base object for {@code UNSAFE} writes
     * @param outputAddress address of the first byte to write
     * @param outputLimit address just past the last byte to write
     */
    public void decode4Streams(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
    {
        verify(inputLimit - inputAddress >= 10, inputAddress, "Input is corrupted"); // jump table + 1 byte per stream

        long start1 = inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below
        long start2 = start1 + (UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF);
        long start3 = start2 + (UNSAFE.getShort(inputBase, inputAddress + 2) & 0xFFFF);
        long start4 = start3 + (UNSAFE.getShort(inputBase, inputAddress + 4) & 0xFFFF);

        // The stream boundaries come straight from the input. Make sure every stream lies inside
        // [inputAddress, inputLimit) before any bit reader is pointed at it.
        verify(start2 < start3 && start3 < start4 && start4 < inputLimit, inputAddress, "Input is corrupted");

        BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, start1, start2);
        initializer.initialize();
        int stream1bitsConsumed = initializer.getBitsConsumed();
        long stream1currentAddress = initializer.getCurrentAddress();
        long stream1bits = initializer.getBits();

        initializer = new BitInputStream.Initializer(inputBase, start2, start3);
        initializer.initialize();
        int stream2bitsConsumed = initializer.getBitsConsumed();
        long stream2currentAddress = initializer.getCurrentAddress();
        long stream2bits = initializer.getBits();

        initializer = new BitInputStream.Initializer(inputBase, start3, start4);
        initializer.initialize();
        int stream3bitsConsumed = initializer.getBitsConsumed();
        long stream3currentAddress = initializer.getCurrentAddress();
        long stream3bits = initializer.getBits();

        initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
        initializer.initialize();
        int stream4bitsConsumed = initializer.getBitsConsumed();
        long stream4currentAddress = initializer.getCurrentAddress();
        long stream4bits = initializer.getBits();

        int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4);

        long outputStart2 = outputAddress + segmentSize;
        long outputStart3 = outputStart2 + segmentSize;
        long outputStart4 = outputStart3 + segmentSize;

        // For very small outputs the three rounded-up segments do not fit. Without this check
        // the tail decoding of streams 1-3 would write past outputLimit.
        verify(outputStart4 <= outputLimit, inputAddress, "Input is corrupted");

        long output1 = outputAddress;
        long output2 = outputStart2;
        long output3 = outputStart3;
        long output4 = outputStart4;

        long fastOutputLimit = outputLimit - 7;
        int tableLog = this.tableLog;
        byte[] numbersOfBits = this.numbersOfBits;
        byte[] symbols = this.symbols;

        // fast path: 4 symbols from each stream per iteration, reloading the bit containers in between
        while (output4 < fastOutputLimit) {
            stream1bitsConsumed = decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            output1 += SIZE_OF_INT;
            output2 += SIZE_OF_INT;
            output3 += SIZE_OF_INT;
            output4 += SIZE_OF_INT;

            BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, start1, stream1currentAddress, stream1bits, stream1bitsConsumed);
            boolean done = loader.load();
            stream1bitsConsumed = loader.getBitsConsumed();
            stream1bits = loader.getBits();
            stream1currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            loader = new BitInputStream.Loader(inputBase, start2, stream2currentAddress, stream2bits, stream2bitsConsumed);
            done = loader.load();
            stream2bitsConsumed = loader.getBitsConsumed();
            stream2bits = loader.getBits();
            stream2currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            loader = new BitInputStream.Loader(inputBase, start3, stream3currentAddress, stream3bits, stream3bitsConsumed);
            done = loader.load();
            stream3bitsConsumed = loader.getBitsConsumed();
            stream3bits = loader.getBits();
            stream3currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            loader = new BitInputStream.Loader(inputBase, start4, stream4currentAddress, stream4bits, stream4bitsConsumed);
            done = loader.load();
            stream4bitsConsumed = loader.getBitsConsumed();
            stream4bits = loader.getBits();
            stream4currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }
        }

        // the fast loop does not bound streams 1-3 individually, so check that none of them
        // ran into the segment of its successor
        verify(output1 <= outputStart2 && output2 <= outputStart3 && output3 <= outputStart4, inputAddress, "Input is corrupted");

        // finish streams one by one
        decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed, stream1bits, outputBase, output1, outputStart2);
        decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed, stream2bits, outputBase, output2, outputStart3);
        decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed, stream3bits, outputBase, output3, outputStart4);
        decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed, stream4bits, outputBase, output4, outputLimit);
    }

    /**
     * Decodes the remaining symbols of one stream, one at a time, and checks that the stream
     * has been consumed completely once {@code outputLimit} is reached.
     *
     * @param inputBase base object for {@code UNSAFE} reads
     * @param startAddress start address of the bit stream
     * @param currentAddress current read position within the bit stream
     * @param bitsConsumed number of bits already consumed from {@code bits}
     * @param bits current bit container
     * @param outputBase base object for {@code UNSAFE} writes
     * @param outputAddress address of the next byte to write
     * @param outputLimit address just past the last byte this stream may write
     */
    private void decodeTail(final Object inputBase, final long startAddress, long currentAddress, int bitsConsumed, long bits, final Object outputBase, long outputAddress, final long outputLimit)
    {
        int tableLog = this.tableLog;
        byte[] numbersOfBits = this.numbersOfBits;
        byte[] symbols = this.symbols;

        // closer to the end: reload the bit container before every symbol
        while (outputAddress < outputLimit) {
            BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, startAddress, currentAddress, bits, bitsConsumed);
            boolean done = loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
        }

        // not more data in bit stream, so no need to reload
        while (outputAddress < outputLimit) {
            bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
        }

        verify(isEndOfStream(startAddress, currentAddress, bitsConsumed), startAddress, "Bit stream is not fully consumed");
    }

    /**
     * Decodes a single symbol: peeks {@code tableLog} bits from the container, writes the symbol
     * found at that table index to {@code outputAddress} and accounts for the bits of its code.
     *
     * @return the updated number of consumed bits
     */
    private static int decodeSymbol(Object outputBase, long outputAddress, long bitContainer, int bitsConsumed, int tableLog, byte[] numbersOfBits, byte[] symbols)
    {
        int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
        UNSAFE.putByte(outputBase, outputAddress, symbols[value]);
        return bitsConsumed + numbersOfBits[value];
    }
}