// xyz.felh.openai.jtokkit.TokenEncoder (Maven / Gradle / Ivy)
// The newest version!
package xyz.felh.openai.jtokkit;
import java.util.*;
import static java.lang.Integer.MAX_VALUE;
import static java.lang.Integer.parseInt;
import static java.util.Collections.emptyMap;
import static xyz.felh.openai.jtokkit.TokenEncoderLarge.calculateTokensLarge;
/**
 * Maps byte sequences to integer ranks (tokens) and ranks back to byte sequences.
 *
 * <p>The vocabulary passed to the constructor is indexed by byte length for encoding and
 * inverted for decoding. A piece that is not itself in the vocabulary is split by repeatedly
 * merging the adjacent pair of sub-pieces with the lowest rank until no mergeable pair is left.
 */
public final class TokenEncoder {
    /** System property that overrides the byte length from which the large-input algorithm is used. */
    public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
    /** Marks a slot of the rank list whose bytes were merged into the piece on its left. */
    public static final int DUMMY_RANK = MAX_VALUE;
    /** Rank reported for a byte sequence that is not in the vocabulary. */
    public static final int MAX_RANK = MAX_VALUE - 1;

    /** {@code encoders[n]} holds the ranks of all n-byte vocabulary entries, or {@code null} if there are none. */
    private final Map<ByteArrayWrapper, Integer>[] encoders;
    /** Rank to bytes; filled once in the constructor and only read afterwards. */
    private final Map<Integer, byte[]> decoder;
    /** Pieces with at least this many bytes are handed to {@code calculateTokensLarge}. */
    private final int veryLargeTokenizerByteThreshold;

    /**
     * Builds the lookup tables from a vocabulary.
     *
     * @param encoder byte sequence to rank; may be empty (used for testing), must not be {@code null}
     * @throws NumberFormatException if the vocabulary is non-empty and the system property
     *         {@link #VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY} is set to a non-integer value
     */
    public TokenEncoder(Map<byte[], Integer> encoder) {
        if (!encoder.isEmpty()) {
            veryLargeTokenizerByteThreshold =
                    parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500"));

            // Group the vocabulary by byte length; the TreeMap gives us the longest length via lastKey().
            TreeMap<Integer, Map<ByteArrayWrapper, Integer>> tempEncoders = new TreeMap<>();
            encoder.forEach((k, v) -> {
                ByteArrayWrapper key = new ByteArrayWrapper(k);
                tempEncoders.computeIfAbsent(k.length, ignored -> new HashMap<>()).put(key, v);
            });

            // Generic arrays cannot be created directly; the raw array only ever stores
            // Map<ByteArrayWrapper, Integer> instances taken from tempEncoders.
            @SuppressWarnings("unchecked")
            Map<ByteArrayWrapper, Integer>[] byLength = new Map[tempEncoders.lastKey() + 1];
            tempEncoders.forEach((byteLength, ranksOfLength) -> byLength[byteLength] = ranksOfLength);
            encoders = byLength;

            Map<Integer, byte[]> rankToBytes = new HashMap<>(encoder.size());
            encoder.forEach((k, v) -> rankToBytes.put(v, k));
            decoder = rankToBytes;
        } else {
            // for testing
            // The property is not consulted here; a threshold of 0 routes every piece that is
            // not a direct vocabulary hit to calculateTokensLarge.
            veryLargeTokenizerByteThreshold = 0;
            @SuppressWarnings("unchecked") // empty array, nothing is ever stored in it
            Map<ByteArrayWrapper, Integer>[] none = new Map[0];
            encoders = none;
            decoder = emptyMap();
        }
    }

    /**
     * Finds the position of the smallest rank below {@link #MAX_RANK}.
     *
     * <p>Only indices {@code 0 .. ranks.size() - 3} are inspected; the last two entries of the
     * list are skipped. On ties the lowest index wins. {@link #DUMMY_RANK} and {@link #MAX_RANK}
     * entries are never selected because the comparison is strict.
     *
     * @param ranks rank list to scan
     * @return index of the minimum rank, or {@code -1} if no inspected entry is below {@code MAX_RANK}
     */
    private static int getMinRankIndex(List<Integer> ranks) {
        int minRankIndex = -1;
        int minRank = MAX_RANK;

        int i = 0;
        int length = ranks.size() - 3;
        for (; i < length - 2; i += 4) { // Unrolled loop
            {
                int r = ranks.get(i);
                if (r < minRank) {
                    minRankIndex = i;
                    minRank = r;
                }
            }
            {
                int r = ranks.get(i + 1);
                if (r < minRank) {
                    minRankIndex = i + 1;
                    minRank = r;
                }
            }
            {
                int r = ranks.get(i + 2);
                if (r < minRank) {
                    minRankIndex = i + 2;
                    minRank = r;
                }
            }
            {
                int r = ranks.get(i + 3);
                if (r < minRank) {
                    minRankIndex = i + 3;
                    minRank = r;
                }
            }
        }

        // Remainder of the unrolled loop.
        for (; i <= length; i++) {
            int r = ranks.get(i);
            if (r < minRank) {
                minRankIndex = i;
                minRank = r;
            }
        }

        return minRankIndex;
    }

    /**
     * Skips forward over {@link #DUMMY_RANK} slots.
     *
     * @return the first index {@code >= nextIndex} that is not a dummy, or {@code ranks.size()} if there is none
     */
    private static int getNextIndex(List<Integer> ranks, int nextIndex) {
        while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) {
            nextIndex++;
        }
        return nextIndex;
    }

    /**
     * Skips backward over {@link #DUMMY_RANK} slots.
     *
     * @return the first index {@code <= previousIndex} that is not a dummy, or {@code -1} if there is none
     */
    private static int getPreviousIndex(List<Integer> ranks, int previousIndex) {
        while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) {
            previousIndex--;
        }
        return previousIndex;
    }

    /**
     * Tokenizes one piece of bytes.
     *
     * <p>If the whole piece is a vocabulary entry it counts as one token. Otherwise the piece is
     * split by the small-input algorithm when it is shorter than the configured threshold and by
     * {@code calculateTokensLarge} when it is not.
     *
     * @param maxTokenCount upper bound on {@code out.size()} honoured by the small-input algorithm;
     *                      a direct vocabulary hit is appended without consulting it
     * @param keepEncodings whether the tokens are appended to {@code out} or only counted
     * @param utf8Bytes     the piece to tokenize
     * @param out           receives the tokens when {@code keepEncodings} is set
     * @param ranks         scratch list, cleared and reused by the small-input algorithm
     * @return the number of tokens the piece consists of
     */
    int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] utf8Bytes, List<Integer> out, ArrayList<Integer> ranks) {
        ByteArrayWrapper match = new ByteArrayWrapper(utf8Bytes);
        int encoded = encode(match);
        if (encoded != MAX_RANK) {
            if (keepEncodings) {
                out.add(encoded);
            }
            return 1;
        } else {
            int length = match.length();
            if (length < veryLargeTokenizerByteThreshold) {
                return calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, match, length);
            } else {
                return calculateTokensLarge(this, maxTokenCount, keepEncodings, out, match, length);
            }
        }
    }

    /**
     * Splits a piece with the list-based merge algorithm.
     *
     * <p>{@code ranks} is filled with {@code length + 1} entries; entry {@code i} is the rank of
     * the two bytes starting at {@code i}, or {@link #MAX_RANK} when that pair is not in the
     * vocabulary or runs past the end of the piece.
     *
     * @return the number of tokens after merging; this is not reduced when {@code out} is cut
     *         short by {@code maxTokenCount}
     */
    private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List<Integer> out, ArrayList<Integer> ranks, ByteArrayWrapper match, int length) {
        assert length > 1 : "Already filtered out";
        ranks.clear();
        ranks.ensureCapacity(length + 1);

        int validRanks = 0;
        int minRankIndex = -1;
        for (int i = 0, minRank = MAX_RANK; i < length + 1; i++) {
            int encoded = encode(match, i, i + 2);
            if (encoded != MAX_RANK) {
                validRanks++;
                if (encoded < minRank) {
                    minRankIndex = i;
                    minRank = encoded;
                }
            }
            ranks.add(encoded);
        }

        int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, validRanks, minRankIndex);
        if (keepEncodings) {
            // Every non-dummy slot is the start of a surviving piece; emit the bytes between
            // consecutive starts as one token each.
            for (int start = 0, end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) {
                if (ranks.get(end) != DUMMY_RANK) {
                    int token = encode(match, start, end);
                    assert token != MAX_RANK : "Token should not be MAX_RANK";
                    out.add(token);
                    start = end;
                }
            }
        }
        return tokenCount;
    }

    /**
     * Repeatedly merges the adjacent pair with the lowest rank until no mergeable pair remains.
     *
     * <p>On each round the slot after {@code minRankIndex} is overwritten with
     * {@link #DUMMY_RANK} (its bytes now belong to the piece starting at {@code minRankIndex}),
     * and the ranks of the two pairs affected by the merge are recomputed.
     *
     * @param piece        the bytes being tokenized
     * @param length       current number of pieces
     * @param ranks        pair ranks as produced by {@code calculateTokensSmall}; modified in place
     * @param validRanks   number of entries in {@code ranks} that are not {@link #MAX_RANK}
     * @param minRankIndex index of the lowest rank in {@code ranks}, or {@code -1}
     * @return the number of pieces left once nothing can be merged any more
     */
    int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List<Integer> ranks, int validRanks, int minRankIndex) {
        assert getMinRankIndex(ranks) == minRankIndex;
        while (validRanks > 0) {
            assert minRankIndex >= 0;

            // Neighbouring piece starts, looked up before the slot at nextIndex is marked as dummy.
            int previousIndex = getPreviousIndex(ranks, minRankIndex - 1);
            int nextIndex = getNextIndex(ranks, minRankIndex + 1);
            int nextNextIndex = getNextIndex(ranks, nextIndex + 1);
            int nextNextNextIndex = getNextIndex(ranks, nextNextIndex + 1);

            if (previousIndex >= 0) {
                // The pair starting at the previous piece now ends where the merged piece ends.
                assert ranks.get(previousIndex) != DUMMY_RANK;
                int newRank = encode(piece, previousIndex, nextNextIndex);
                int oldRank = ranks.set(previousIndex, newRank);
                if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) {
                    validRanks -= (newRank == MAX_RANK) ? 1 : -1;
                }
            }

            // The pair starting at the merged piece now extends over the piece that follows it.
            assert ranks.get(minRankIndex) != DUMMY_RANK;
            int newRank = encode(piece, minRankIndex, nextNextNextIndex);
            int oldRank = ranks.set(minRankIndex, newRank);
            if ((newRank == MAX_RANK) != (oldRank == MAX_RANK)) {
                validRanks--;
            }

            int oldDeletedRank = ranks.set(nextIndex, DUMMY_RANK);
            if (oldDeletedRank != MAX_RANK) {
                validRanks--;
            }

            length--;

            minRankIndex = getMinRankIndex(ranks);
        }
        assert getMinRankIndex(ranks) < 0;
        return length;
    }

    /**
     * Looks up the rank of a complete byte sequence.
     *
     * @return the rank, or {@link #MAX_RANK} if the sequence is not in the vocabulary
     */
    private int encode(ByteArrayWrapper payload) {
        if (payload.length() < encoders.length) {
            Map<ByteArrayWrapper, Integer> encoder = encoders[payload.length()];
            if (encoder != null) {
                Integer result = encoder.get(payload);
                if (result != null) {
                    return result;
                }
            }
        }
        return MAX_RANK;
    }

    /**
     * Looks up the rank of the bytes of {@code piece} between {@code start} and {@code end}.
     *
     * @return the rank, or {@link #MAX_RANK} if {@code end} lies beyond the piece or the bytes
     *         are not in the vocabulary
     */
    int encode(ByteArrayWrapper piece, int start, int end) {
        if (end > piece.length()) {
            return MAX_RANK;
        } else if (end - start == piece.length()) {
            // The range covers the whole piece, so no sub-range has to be extracted.
            return encode(piece);
        } else {
            return encode(piece.getBytesBetween(start, end));
        }
    }

    /**
     * Returns the bytes of a token.
     *
     * <p>The vocabulary is consulted first; tokens it does not contain are delegated to
     * {@code specialEncoder}. The lookup never modifies the decoder map, so it also works for an
     * encoder created from an empty vocabulary, whose decoder is the immutable empty map.
     *
     * @param token          the token to decode
     * @param specialEncoder fallback for tokens outside the vocabulary
     * @return the vocabulary bytes of the token, otherwise whatever
     *         {@code specialEncoder.decodeIfPresent} returns for it
     */
    public byte[] decodeToken(int token, SpecialEncoder specialEncoder) {
        byte[] bytes = decoder.get(token);
        return bytes != null ? bytes : specialEncoder.decodeIfPresent(token);
    }
}