// fi.evolver.ai.spring.util.TokenUtils (Maven / Gradle / Ivy)
package fi.evolver.ai.spring.util;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.Tokenizer;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Message;
/**
* General utilities for handling tokens.
*/
public class TokenUtils {
private static final int PER_FUNCTION_OVERHEAD = 50;
/**
* Calculate the number of tokens in a given text based on a specific model.
*
* @param text The text to tokenize
* @param model The model whose specific tokenization is used
* @return The number of tokens.
*/
public static int calculateTokens(String text, Model> model) {
return model.tokenizer().countTokens(text);
}
/**
* Calculate the number of tokens in a given text based on the default (GPT 3.5 turbo) model tokenization.
*
* @param text The text to tokenize
* @return The number of tokens.
*/
public static int calculateTokens(String text) {
return Tokenizer.CL100K_BASE.countTokens(text);
}
/**
* Calculate the number of tokens that a function call consumes based on a specific model.
*
* @param functionSpec The function specification
* @param model The model whose specific tokenization is used
* @return The number of tokens
*/
public static int calculateTokens(FunctionSpec> functionSpec, Model> model) {
int result = calculateTokens(functionSpec.toJsonSchema(), model) + PER_FUNCTION_OVERHEAD;
result += functionSpec.getTitle().map(t -> calculateTokens(t, model)).orElse(0);
result += functionSpec.getDescription().map(d -> calculateTokens(d, model)).orElse(0);
return result;
}
/**
* Calculate the number of tokens that a whole prompt consumes, i.e. sum of the token counts of the messages
* and functions of the prompt.
*
* @param prompt The function specification
* @return The number of tokens that the prompt consumes when sent to the engine.
*/
public static int calculateTokens(ChatPrompt prompt) {
int counter = prompt.messages().stream()
.map(Message::getContent)
.mapToInt(m -> calculateTokens(m, prompt.model()))
.sum();
counter += prompt.functions().stream()
.mapToInt(f -> calculateTokens(f, prompt.model()))
.sum();
return counter;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy