package io.quarkiverse.langchain4j.llama3.copy;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Byte Pair Encoding tokenizer.
 *
 * Based on minbpe; algorithmically it follows along the GPT-2 tokenizer.
 */
public class Tokenizer {
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
    private final Map<Pair<Integer, Integer>, Integer> merges;
    private final Map<String, Integer> specialTokens;
public String regexPattern() {
if (compiledPattern == null) {
return null;
}
return compiledPattern.pattern();
}
    public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
public boolean isSpecialToken(int tokenIndex) {
return specialTokens.containsValue(tokenIndex);
}
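    /**
     * @param vocabulary mapping between token strings and token ids
     * @param merges BPE merge pairs; each pair is resolved to the vocabulary index of the concatenation of its two tokens
     * @param regexPattern pre-tokenization split pattern (may be {@code null})
     * @param specialTokens mapping from special token string to its token id
     */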
    public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern,
            Map<String, Integer> specialTokens) {
this.vocabulary = vocabulary;
this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
this.specialTokens = new HashMap<>(specialTokens);
this.merges = new HashMap<>();
        for (Pair<Integer, Integer> pair : merges) {
int firstIndex = pair.first();
int secondIndex = pair.second();
int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
this.merges.put(pair, mergeIndex);
}
}
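    // Illustrative round-trip (a sketch; assumes a Tokenizer instance built elsewhere, e.g. by the model loader):
    //   List<Integer> ids = tokenizer.encodeAsList("Hello world");
    //   String text = tokenizer.decode(ids); // "Hello world"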
private int[] encodeImpl(String text) {
return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
}
    /**
     * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
     * {@code allowedSpecial} is the set of special tokens that are allowed to appear in {@code text};
     * every element must be a known special token (see {@link #getSpecialTokens()}).
     * If the set is empty, this falls back to the ordinary encoding.
     */
    List<Integer> encode(String text, Set<String> allowedSpecial) {
        // decode the user desire w.r.t. handling of special tokens
        Set<String> special = allowedSpecial;
assert getSpecialTokens().keySet().containsAll(special);
if (special.isEmpty()) {
// shortcut: if no special tokens, just use the ordinary encoding
return encodeOrdinary(text);
}
        // otherwise, we have to be careful with potential special tokens in text
        // we handle special tokens by splitting the text
        // based on the occurrence of any exact match with any of the special tokens
        // note: Python's re.split with a capturing group also returns the delimiters,
        // but Java's String.split does not, so we walk the matches explicitly and keep them
        String specialPattern = special
                .stream()
                .map(Pattern::quote)
                .collect(Collectors.joining("|"));
        // chunks of ordinary text are encoded separately, special tokens map directly to their ids,
        // then the results are joined
        List<Integer> ids = new ArrayList<>();
        Matcher specialMatcher = Pattern.compile(specialPattern).matcher(text);
        int lastEnd = 0;
        while (specialMatcher.find()) {
            // this is an ordinary sequence before the special token, encode it normally
            if (specialMatcher.start() > lastEnd) {
                ids.addAll(encodeOrdinary(text.substring(lastEnd, specialMatcher.start())));
            }
            // this is a special token, encode it separately as a special case
            ids.add(getSpecialTokens().get(specialMatcher.group()));
            lastEnd = specialMatcher.end();
        }
        // encode any remaining ordinary text after the last special token
        if (lastEnd < text.length()) {
            ids.addAll(encodeOrdinary(text.substring(lastEnd)));
        }
        return ids;
}
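    /**
     * Returns all non-overlapping matches of {@code pattern} in {@code text}, in order of appearance.
     */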
    private static List<String> findAll(Pattern pattern, String text) {
        List<String> allMatches = new ArrayList<>();
Matcher matcher = pattern.matcher(text);
while (matcher.find()) {
allMatches.add(matcher.group());
}
return allMatches;
}
/**
* Encoding that ignores any special tokens.
*/
    public List<Integer> encodeOrdinary(String text) {
        // split text into chunks of text by categories defined in regex pattern
        List<String> textChunks = findAll(compiledPattern, text);
        // all chunks of text are encoded separately, then results are joined
        List<Integer> ids = new ArrayList<>();
        for (String chunk : textChunks) {
            List<Integer> chunkIds = encodeChunk(chunk);
ids.addAll(chunkIds);
}
return ids;
}
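    /**
     * Counts how often each pair of adjacent token ids occurs in {@code ids}.
     */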
    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
        Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
        for (int i = 0; i + 1 < ids.size(); i++) {
            Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
map.put(key, map.getOrDefault(key, 0) + 1);
}
return map;
}
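    /**
     * Applies BPE to a single pre-tokenized chunk: start from the byte-level token ids and repeatedly
     * merge the adjacent pair with the lowest merge rank until no more merges apply.
     */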
    private List<Integer> encodeChunk(String chunk) {
        // return the token ids
        // first, map each (byte-level) character of the chunk to its token id in the vocabulary
        List<Integer> ids = new ArrayList<>();
        for (int b : chunk.toCharArray()) {
            int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
            ids.add(tokenIndex);
        }
while (ids.size() >= 2) {
// find the pair with the lowest merge index
            Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
            Pair<Integer, Integer> pair = stats.keySet().stream()
.min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
// subtle: if there are no more merges available, the key will
// result in an inf for every single pair, and the min will be
// just the first pair in the list, arbitrarily
// we can detect this terminating case by a membership check
if (!this.merges.containsKey(pair)) {
break; // nothing else can be merged anymore
}
// otherwise let's merge the best pair (lowest merge index)
int idx = this.merges.get(pair);
ids = merge(ids, pair, idx);
}
return ids;
}
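    /**
     * Replaces every non-overlapping occurrence of {@code pair} in {@code ids} with the merged token id {@code idx}.
     */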
    private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
        List<Integer> newids = new ArrayList<>();
int i = 0;
while (i < ids.size()) {
// if not at the very last position AND the pair matches, replace it
if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
newids.add(idx);
i += 2;
} else {
newids.add(ids.get(i));
i += 1;
}
}
return newids;
}
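    /**
     * Concatenates the token strings for {@code tokens} without undoing the byte-level mapping;
     * see {@link #decode(List)} for the full decoding.
     */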
    public String decodeImpl(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
String tokenString = vocabulary.get(token);
sb.append(tokenString);
}
return sb.toString();
}
    /**
     * Returns a mapping between utf-8 bytes and corresponding unicode code points.
     * The reversible bpe codes work on unicode strings.
     * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
     * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
     * This is a significant percentage of your normal, say, 32K bpe vocab.
     * To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
     * which also avoid mapping to whitespace/control characters the bpe code barfs on.
     */
    private static Map<Integer, Integer> bytesToUnicode() {
        List<Integer> bs = new ArrayList<>();
        IntStream.rangeClosed('!', '~').forEach(bs::add);
        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
        List<Integer> cs = new ArrayList<>(bs);
int n = 0;
for (int b = 0; b < 256; ++b) {
if (!bs.contains(b)) {
bs.add(b);
cs.add(256 + n);
n += 1;
}
}
// return dict(zip(bs, cs))
return IntStream.range(0, bs.size())
.boxed()
.collect(Collectors.toMap(bs::get, cs::get));
}
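    // Lookup tables implementing the GPT-2 style byte <-> printable code point mapping built above,
    // and its inverse, used by encode(String) and decode(List).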
    static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
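    /**
     * Encodes {@code text} by first mapping its UTF-8 bytes to printable code points via {@code BYTE_ENCODER}
     * and then running BPE on the resulting string.
     */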
public int[] encode(String text) {
StringBuilder sb = new StringBuilder();
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
for (byte b : bytes) {
sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
}
return encodeImpl(sb.toString());
}
public static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
// which distort the output (e.g. \n or much worse)
// https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
        // http://www.unicode.org/reports/tr44/#GC_Values_Table
StringBuilder chars = new StringBuilder();
for (int cp : codePoints) {
if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
} else {
chars.appendCodePoint(cp); // this character is ok
}
}
return chars.toString();
}
public static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
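    /**
     * Boxed variant of {@link #encode(String)}.
     */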
    public List<Integer> encodeAsList(String text) {
return Arrays.stream(encode(text)).boxed().toList();
}
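    /**
     * Decodes {@code tokens} back to text: token strings are concatenated, each code point is mapped
     * back to its original byte via {@code BYTE_DECODER}, and the resulting bytes are interpreted as UTF-8.
     */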
    public String decode(List<Integer> tokens) {
String decoded = decodeImpl(tokens);
int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
byte[] rawBytes = new byte[decodedBytesAsInts.length];
for (int i = 0; i < decoded.length(); i++) {
rawBytes[i] = (byte) decodedBytesAsInts[i];
}
return new String(rawBytes, StandardCharsets.UTF_8);
}
}