
io.quarkiverse.langchain4j.llama3.copy.Tokenizer

package io.quarkiverse.langchain4j.llama3.copy;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Byte Pair Encoding tokenizer.
 * <p>
 * Based on minbpe (https://github.com/karpathy/minbpe), algorithmically follows along the
 * GPT-2 tokenizer.
 */
public class Tokenizer {

    private final Pattern compiledPattern;
    private final Vocabulary vocabulary;
    private final Map<Pair<Integer, Integer>, Integer> merges;
    private final Map<String, Integer> specialTokens;

    public String regexPattern() {
        if (compiledPattern == null) {
            return null;
        }
        return compiledPattern.pattern();
    }

    public Map<String, Integer> getSpecialTokens() {
        return specialTokens;
    }

    public boolean isSpecialToken(int tokenIndex) {
        return specialTokens.containsValue(tokenIndex);
    }

    public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern,
            Map<String, Integer> specialTokens) {
        this.vocabulary = vocabulary;
        this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
        this.specialTokens = new HashMap<>(specialTokens);
        this.merges = new HashMap<>();
        for (Pair<Integer, Integer> pair : merges) {
            int firstIndex = pair.first();
            int secondIndex = pair.second();
            int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
            this.merges.put(pair, mergeIndex);
        }
    }

    private int[] encodeImpl(String text) {
        return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
    }

    /**
     * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
     * allowedSpecial is a custom set of special tokens that may appear in the text; every token
     * in the set must be registered with this tokenizer. Allowed special tokens are encoded as
     * their reserved ids, everything else goes through the ordinary BPE encoding.
     */
    List<Integer> encode(String text, Set<String> allowedSpecial) {
        // decode the user desire w.r.t. handling of special tokens
        Set<String> special = allowedSpecial;
        assert getSpecialTokens().keySet().containsAll(special);
        if (special.isEmpty()) {
            // shortcut: if no special tokens, just use the ordinary encoding
            return encodeOrdinary(text);
        }

        // otherwise, we have to be careful with potential special tokens in text:
        // split the text on every exact occurrence of an allowed special token, keeping the
        // special tokens themselves. (Java's String.split drops captured delimiters, unlike
        // Python's re.split, so a Matcher-based scan is used here.)
        Pattern specialPattern = Pattern.compile(
                special.stream().map(Pattern::quote).collect(Collectors.joining("|")));
        List<String> specialChunks = new ArrayList<>();
        Matcher specialMatcher = specialPattern.matcher(text);
        int last = 0;
        while (specialMatcher.find()) {
            if (specialMatcher.start() > last) {
                specialChunks.add(text.substring(last, specialMatcher.start()));
            }
            specialChunks.add(specialMatcher.group());
            last = specialMatcher.end();
        }
        if (last < text.length()) {
            specialChunks.add(text.substring(last));
        }

        // now all the special tokens are separated from the rest of the text;
        // all chunks of text are encoded separately, then results are joined
        List<Integer> ids = new ArrayList<>();
        for (String part : specialChunks) {
            if (special.contains(part)) {
                // this is a special token, encode it separately as a special case
                ids.add(getSpecialTokens().get(part));
            } else {
                // this is an ordinary sequence, encode it normally
                ids.addAll(encodeOrdinary(part));
            }
        }
        return ids;
    }

    private static List<String> findAll(Pattern pattern, String text) {
        List<String> allMatches = new ArrayList<>();
        Matcher matcher = pattern.matcher(text);
        while (matcher.find()) {
            allMatches.add(matcher.group());
        }
        return allMatches;
    }
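    // Illustrative note (not part of the upstream source): assuming "<|eot_id|>" is registered
    // as a special token with id 128009, encode("Hi<|eot_id|>", Set.of("<|eot_id|>")) splits the
    // text into ["Hi", "<|eot_id|>"], emits the reserved id 128009 for the special chunk, and
    // runs the ordinary BPE encoding over "Hi".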
    /**
     * Encoding that ignores any special tokens.
     */
    public List<Integer> encodeOrdinary(String text) {
        // split text into chunks of text by categories defined in regex pattern
        List<String> textChunks = findAll(compiledPattern, text);
        // all chunks of text are encoded separately, then results are joined
        List<Integer> ids = new ArrayList<>();
        for (String chunk : textChunks) {
            List<Integer> chunkIds = encodeChunk(chunk);
            ids.addAll(chunkIds);
        }
        return ids;
    }

    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
        Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
        for (int i = 0; i + 1 < ids.size(); i++) {
            Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
            map.put(key, map.getOrDefault(key, 0) + 1);
        }
        return map;
    }

    private List<Integer> encodeChunk(String chunk) {
        // return the token ids
        // first, convert every char of the (already byte-encoded) chunk to its token id
        List<Integer> ids = new ArrayList<>();
        for (int b : chunk.toCharArray()) {
            int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
            ids.add(tokenIndex);
        }

        while (ids.size() >= 2) {
            // find the pair with the lowest merge index
            Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
            Pair<Integer, Integer> pair = stats.keySet().stream()
                    .min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE)))
                    .orElseThrow();
            // subtle: if there are no more merges available, the key will
            // result in an inf for every single pair, and the min will be
            // just the first pair in the list, arbitrarily
            // we can detect this terminating case by a membership check
            if (!this.merges.containsKey(pair)) {
                break; // nothing else can be merged anymore
            }
            // otherwise let's merge the best pair (lowest merge index)
            int idx = this.merges.get(pair);
            ids = merge(ids, pair, idx);
        }
        return ids;
    }

    private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
        List<Integer> newids = new ArrayList<>();
        int i = 0;
        while (i < ids.size()) {
            // if not at the very last position AND the pair matches, replace it
            if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
                newids.add(idx);
                i += 2;
            } else {
                newids.add(ids.get(i));
                i += 1;
            }
        }
        return newids;
    }

    public String decodeImpl(List<Integer> tokens) {
        StringBuilder sb = new StringBuilder();
        for (int token : tokens) {
            String tokenString = vocabulary.get(token);
            sb.append(tokenString);
        }
        return sb.toString();
    }
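    // Illustrative note (not part of the upstream source): encodeChunk() merges bottom-up.
    // Assuming a vocabulary containing "l", "o", "w", "lo" and "low", with ("l","o") having a
    // lower merge index than ("lo","w"), the chunk "low" starts as the ids of [l, o, w]; the
    // lowest-ranked pair (l, o) is merged first giving [lo, w], then (lo, w) gives [low]; the
    // loop stops as soon as no adjacent pair appears in the merges table.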
    /**
     * Returns a mapping between utf-8 bytes and unicode characters.
     * The reversible bpe codes work on unicode strings.
     * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
     * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
     * This is a significant percentage of your normal, say, 32K bpe vocab.
     * To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     * This also avoids mapping to whitespace/control characters that the bpe code barfs on.
     */
    private static Map<Integer, Integer> bytesToUnicode() {
        List<Integer> bs = new ArrayList<>();
        IntStream.rangeClosed('!', '~').forEach(bs::add);
        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);

        List<Integer> cs = new ArrayList<>(bs);
        int n = 0;
        for (int b = 0; b < 256; ++b) {
            if (!bs.contains(b)) {
                bs.add(b);
                cs.add(256 + n);
                n += 1;
            }
        }

        // return dict(zip(bs, cs))
        return IntStream.range(0, bs.size())
                .boxed()
                .collect(Collectors.toMap(bs::get, cs::get));
    }

    static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));

    public int[] encode(String text) {
        StringBuilder sb = new StringBuilder();
        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
        for (byte b : bytes) {
            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
        }
        return encodeImpl(sb.toString());
    }

    public static String replaceControlCharacters(int[] codePoints) {
        // we don't want to print control characters
        // which distort the output (e.g. \n or much worse)
        // https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
        // http://www.unicode.org/reports/tr44/#GC_Values_Table
        StringBuilder chars = new StringBuilder();
        for (int cp : codePoints) {
            if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
                chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
            } else {
                chars.appendCodePoint(cp); // this character is ok
            }
        }
        return chars.toString();
    }

    public static String replaceControlCharacters(String str) {
        return replaceControlCharacters(str.codePoints().toArray());
    }

    public List<Integer> encodeAsList(String text) {
        return Arrays.stream(encode(text)).boxed().toList();
    }

    public String decode(List<Integer> tokens) {
        String decoded = decodeImpl(tokens);
        int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
        byte[] rawBytes = new byte[decodedBytesAsInts.length];
        for (int i = 0; i < rawBytes.length; i++) {
            rawBytes[i] = (byte) decodedBytesAsInts[i];
        }
        return new String(rawBytes, StandardCharsets.UTF_8);
    }
}
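For reference, the sketch below shows how the class above might be wired together. The vocabulary, merge list, pre-tokenization regex and special-token ids would normally be loaded from the model's tokenizer metadata by other code in this package; the loading helpers, the regex literal and the token ids shown here are illustrative placeholders, not part of this class.

// Minimal usage sketch (assumptions: loadVocabulary()/loadMerges() are hypothetical helpers;
// the regex and special-token ids are example values in the style of Llama 3).
Vocabulary vocabulary = loadVocabulary();
List<Pair<Integer, Integer>> merges = loadMerges();
Map<String, Integer> specialTokens = Map.of("<|begin_of_text|>", 128000, "<|end_of_text|>", 128001);
String regexPattern = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

Tokenizer tokenizer = new Tokenizer(vocabulary, merges, regexPattern, specialTokens);

List<Integer> ids = tokenizer.encodeAsList("Hello world");  // byte-level BPE token ids
String roundTripped = tokenizer.decode(ids);                // back to "Hello world" for a well-formed vocabulary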




