All downloads are free. Search and download functionality uses the official Maven repository.

com.github.luben.zstd.ZstdDictTrainer Maven / Gradle / Ivy

package com.github.luben.zstd;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

/**
 * Accumulates training samples in a fixed-capacity direct {@link ByteBuffer} and trains a
 * Zstandard dictionary from them via {@link Zstd#trainFromBufferDirect}.
 *
 * <p>{@link #addSample(byte[])} and {@link #trainSamplesDirect(boolean)} are {@code synchronized}
 * on this instance. Samples may therefore be added from several threads, and training never
 * observes a half-added sample.
 */
public class ZstdDictTrainer {
    /** Capacity, in bytes, of {@link #trainingSamples}. */
    private final int allocatedSize;
    /** Concatenation of every accepted sample, in insertion order. */
    private final ByteBuffer trainingSamples;
    /** Length of each accepted sample, in the same order as the bytes in {@link #trainingSamples}. */
    private final List<Integer> sampleSizes;
    /** Capacity, in bytes, of the buffer the trained dictionary is written into. */
    private final int dictSize;
    /** Total number of sample bytes accepted so far. */
    private long filledSize;

    /**
     * Creates a trainer with fixed capacities.
     *
     * @param sampleSize total capacity, in bytes, available for all samples combined
     * @param dictSize capacity, in bytes, of the buffer allocated for the trained dictionary;
     *     the returned dictionary is never longer than this
     * @throws IllegalArgumentException if {@code sampleSize} is negative (raised by
     *     {@link ByteBuffer#allocateDirect(int)})
     */
    public ZstdDictTrainer(int sampleSize, int dictSize) {
        this.trainingSamples = ByteBuffer.allocateDirect(sampleSize);
        this.sampleSizes = new ArrayList<Integer>();
        this.allocatedSize = sampleSize;
        this.dictSize = dictSize;
    }

    /**
     * Appends one sample to the training set.
     *
     * @param sample the sample bytes. They are copied into the internal buffer, so the caller may
     *     reuse the array afterwards.
     * @return {@code true} if the sample was stored; {@code false} if it would not fit in the
     *     remaining capacity, in which case nothing is stored
     */
    public synchronized boolean addSample(byte[] sample) {
        // filledSize is a long, so this sum cannot overflow even near Integer.MAX_VALUE.
        if (filledSize + sample.length > allocatedSize) {
            return false;
        }
        trainingSamples.put(sample);
        sampleSizes.add(sample.length);
        filledSize += sample.length;
        return true;
    }

    /**
     * Trains a dictionary with {@code legacy = false}.
     *
     * @return a direct buffer holding the dictionary; see {@link #trainSamplesDirect(boolean)}
     * @throws ZstdException if the native trainer reports an error
     */
    public ByteBuffer trainSamplesDirect() throws ZstdException {
        return trainSamplesDirect(false);
    }

    /**
     * Trains a dictionary from all samples added so far.
     *
     * @param legacy forwarded unchanged to {@link Zstd#trainFromBufferDirect}. It presumably
     *     selects the legacy training algorithm (TODO confirm against {@code Zstd}).
     * @return a new direct buffer with position 0 and limit equal to the dictionary length
     * @throws ZstdException if the native trainer returns an error code
     */
    public synchronized ByteBuffer trainSamplesDirect(boolean legacy) throws ZstdException {
        ByteBuffer dictBuffer = ByteBuffer.allocateDirect(dictSize);
        long result = Zstd.trainFromBufferDirect(
                trainingSamples, copyToIntArray(sampleSizes), dictBuffer, legacy);
        if (Zstd.isError(result)) {
            throw new ZstdException(result);
        }
        // On success the result is the number of dictionary bytes written; expose exactly those.
        dictBuffer.limit((int) result);
        return dictBuffer;
    }

    /**
     * Trains a dictionary with {@code legacy = false} and returns it as a heap array.
     *
     * @return the dictionary bytes
     * @throws ZstdException if the native trainer reports an error
     */
    public byte[] trainSamples() throws ZstdException {
        return trainSamples(false);
    }

    /**
     * Trains a dictionary and copies it out of the direct buffer into a heap array.
     *
     * @param legacy forwarded to {@link #trainSamplesDirect(boolean)}
     * @return the dictionary bytes
     * @throws ZstdException if the native trainer reports an error
     */
    public byte[] trainSamples(boolean legacy) throws ZstdException {
        ByteBuffer dictionary = trainSamplesDirect(legacy);
        byte[] bytes = new byte[dictionary.remaining()];
        dictionary.get(bytes);
        return bytes;
    }

    /** Unboxes the recorded sample sizes into the primitive array form the native trainer takes. */
    private static int[] copyToIntArray(List<Integer> list) {
        int[] ints = new int[list.size()];
        int idx = 0;
        for (Integer i : list) {
            ints[idx++] = i;
        }
        return ints;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy