// fi.evolver.ai.spring.prompt.template.TemplateUtils Maven / Gradle / Ivy
package fi.evolver.ai.spring.prompt.template;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.knuddels.jtokkit.api.EncodingType;
import fi.evolver.ai.spring.Api;
import fi.evolver.ai.spring.JtokkitTokenizer;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.Tokenizer;
import fi.evolver.ai.spring.prompt.template.model.Section;
import fi.evolver.ai.spring.prompt.template.model.SectionProperty;
import fi.evolver.ai.spring.util.DurationUtils;
import freemarker.template.Configuration;
import freemarker.template.Template;
public class TemplateUtils {
public static final String META_PROPERTY_MODEL = "model";
private static final Logger LOG = LoggerFactory.getLogger(TemplateUtils.class);
private static final Configuration FREEMARKER_CONFIGURATION;
private static final Map TEMPLATE_CACHE = new ConcurrentHashMap<>();
private static final Pattern REGEX_MODEL_WITH_CL100K_BASE = Pattern.compile("(?:gpt-3.5-turbo|gpt-4)(?:[-o].*)?"); // gpt-4o model uses new tokenizer o200k_base which isn't yet supported by jtokkit.
public static final String SECTION_META = "META";
public static final String SECTION_COMMENT = "COMMENT";
public static final String SECTION_FUNCTION = "FUNCTION";
public static final String SECTION_SKILL = "SKILL";
public static final String SECTION_ASSISTANT_MESSAGE = "ASSISTANT_MESSAGE";
public static final String SECTION_SYSTEM_MESSAGE = "SYSTEM_MESSAGE";
public static final String SECTION_USER_MESSAGE = "USER_MESSAGE";
public static final String SECTION_HISTORY = "HISTORY";
public static final String SECTION_PROMPT = "PROMPT";
public static final HistoryTag TAG_HISTORY = new HistoryTag();
static {
disableFreemarkerLogging();
FREEMARKER_CONFIGURATION = new Configuration(Configuration.VERSION_2_3_31);
FREEMARKER_CONFIGURATION.setSharedVariable("history", TemplateUtils.TAG_HISTORY);
}
public static Template getTemplate(Section section) throws IOException {
Optional templateResource = section.getProperty("template");
return getTemplate(templateResource.isPresent() ? readResource(templateResource.get()) : section.content());
}
private static Template getTemplate(String template) throws IOException {
String hash = DigestUtils.sha1Hex(template);
Template result = TEMPLATE_CACHE.get(hash);
if (result == null) {
result = createTemplate(hash, template);
TEMPLATE_CACHE.put(hash, result);
}
return result;
}
private static Template createTemplate(String hash, String template) throws IOException {
return new Template(hash, template, FREEMARKER_CONFIGURATION);
}
private static String readResource(String resource) throws IOException {
try (Reader reader = new InputStreamReader(TemplateUtils.class.getResourceAsStream(resource), StandardCharsets.UTF_8)) {
return IOUtils.toString(reader);
}
}
@SuppressWarnings("deprecation")
private static void disableFreemarkerLogging() {
try {
freemarker.log.Logger.selectLoggerLibrary(freemarker.log.Logger.LIBRARY_NONE);
}
catch (ClassNotFoundException e) {
LOG.warn("Could not disable Freemarker log spam");
}
}
/**
* Parse the meta details from the given sections
*
* NOTE: This method will throw an exception if the model is not defined in the meta section
*
* @param sections
* @return The parsed meta details
* @param
*/
public static MetaDetails parseMetaDetails(List sections) {
return parseMetaDetails(sections, null);
}
/**
* Parse the meta details from the given sections
*
* NOTE: model has to be defined in the meta section or the defaultModel has to be provided
*
* @param sections
* @param defaultModel used as a fallback if the model is not defined in the meta section
* @return The parsed meta details
* @param
*/
public static MetaDetails parseMetaDetails(List sections, Model defaultModel) {
Map properties = getMetaProperties(sections);
String modelName = properties.remove("model");
if (modelName == null && defaultModel != null)
modelName = defaultModel.name();
if (modelName == null)
throw new PromptTemplateException("META", "missing the required 'model' property");
Optional timeout = Optional.ofNullable(properties.remove("timeout")).map(Object::toString).map(DurationUtils::parseDurationWithUnit);
Model model = new Model<>(
modelName,
inferTokenLimit(modelName, properties.remove("tokenLimit")),
inferTokenizer(modelName, properties.remove("tokenizer")));
return new MetaDetails<>(model, timeout, properties);
}
public static int inferTokenLimit(String model, String tokenLimit) {
if (tokenLimit != null)
return Integer.parseInt(tokenLimit);
if (model.startsWith("gpt-4-turbo") || model.startsWith("gpt-4-1106") || model.startsWith("gpt-4o"))
return 128_000;
if (model.startsWith("gpt-4"))
return 8_192;
if (model.startsWith("gpt-3.5-turbo")) {
String[] parts = model.split("-");
String version = parts.length >= 4 ? parts[3] : null;
if (version != null && ("16k".equals(version) || (version.matches("\\d{4}") && "0125".compareTo(version) >= 0)))
return 16_385;
else
return 4_096;
}
if (model.startsWith("text-embedding-ada-002"))
return 8_192;
throw new PromptTemplateException("META", "missing property token_limit, could not infer the value", model);
}
public static Tokenizer inferTokenizer(String model, String tokenizer) {
if (tokenizer != null)
return JtokkitTokenizer.of(EncodingType.fromName(tokenizer).orElseThrow(() -> new PromptTemplateException("META", "unsupported tokenizer '%s'", tokenizer)));
if (REGEX_MODEL_WITH_CL100K_BASE.matcher(model).matches())
return Tokenizer.CL100K_BASE;
throw new PromptTemplateException("META", "missing property tokenizer, could not infer the value", model);
}
private static Map getMetaProperties(List sections) {
Map results = new LinkedHashMap<>();
sections.stream()
.filter(s -> SECTION_META.equals(s.type()))
.map(Section::properties)
.forEach(results::putAll);
return results;
}
public static List parseTemplate(Reader reader) {
try (TemplateLineStream lines = new TemplateLineStream(reader)) {
return parseTemplate(lines);
}
catch (IOException e) {
throw new UncheckedIOException("Failed reading prompt template", e);
}
}
private static List parseTemplate(TemplateLineStream lines) throws IOException {
List results = new ArrayList<>();
while (lines.hasSectionHeader())
results.add(parseSection(lines));
return results;
}
private static Section parseSection(TemplateLineStream lines) throws IOException {
String type = lines.expectSectionHeader();
Map properties = new LinkedHashMap<>();
while (lines.hasProperty()) {
SectionProperty property = lines.expectProperty();
properties.put(property.key(), property.value());
}
StringBuilder content = new StringBuilder();
while (lines.hasNext() && !lines.hasSectionHeader())
content.append(lines.next()).append("\n");
return new Section(type, properties, content.toString());
}
public static record MetaDetails(
Model model,
Optional timeout,
Map properties) {
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy