All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.prompt.template.TemplateUtils Maven / Gradle / Ivy

package fi.evolver.ai.spring.prompt.template;


import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.knuddels.jtokkit.api.EncodingType;

import fi.evolver.ai.spring.Api;
import fi.evolver.ai.spring.JtokkitTokenizer;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.Tokenizer;
import fi.evolver.ai.spring.prompt.template.model.Section;
import fi.evolver.ai.spring.prompt.template.model.SectionProperty;
import fi.evolver.ai.spring.util.DurationUtils;
import freemarker.template.Configuration;
import freemarker.template.Template;

public class TemplateUtils {
	public static final String META_PROPERTY_MODEL = "model";
	private static final Logger LOG = LoggerFactory.getLogger(TemplateUtils.class);
	private static final Configuration FREEMARKER_CONFIGURATION;
	private static final Map TEMPLATE_CACHE = new ConcurrentHashMap<>();

	private static final Pattern REGEX_MODEL_WITH_CL100K_BASE = Pattern.compile("(?:gpt-3.5-turbo|gpt-4)(?:[-o].*)?"); // gpt-4o model uses new tokenizer o200k_base which isn't yet supported by jtokkit.

	public static final String SECTION_META = "META";
	public static final String SECTION_COMMENT = "COMMENT";
	public static final String SECTION_FUNCTION = "FUNCTION";
	public static final String SECTION_SKILL = "SKILL";
	public static final String SECTION_ASSISTANT_MESSAGE = "ASSISTANT_MESSAGE";
	public static final String SECTION_SYSTEM_MESSAGE = "SYSTEM_MESSAGE";
	public static final String SECTION_USER_MESSAGE = "USER_MESSAGE";
	public static final String SECTION_HISTORY = "HISTORY";
	public static final String SECTION_PROMPT = "PROMPT";

	public static final HistoryTag TAG_HISTORY = new HistoryTag();


	static {
		disableFreemarkerLogging();
		FREEMARKER_CONFIGURATION = new Configuration(Configuration.VERSION_2_3_31);
		FREEMARKER_CONFIGURATION.setSharedVariable("history", TemplateUtils.TAG_HISTORY);
	}


	public static Template getTemplate(Section section) throws IOException {
		Optional templateResource = section.getProperty("template");
		return getTemplate(templateResource.isPresent() ? readResource(templateResource.get()) : section.content());
	}

	private static Template getTemplate(String template) throws IOException {
		String hash = DigestUtils.sha1Hex(template);
		Template result = TEMPLATE_CACHE.get(hash);
		if (result == null) {
			result = createTemplate(hash, template);
			TEMPLATE_CACHE.put(hash, result);
		}
		return result;
	}

	private static Template createTemplate(String hash, String template) throws IOException {
		return new Template(hash, template, FREEMARKER_CONFIGURATION);
	}

	private static String readResource(String resource) throws IOException {
		try (Reader reader = new InputStreamReader(TemplateUtils.class.getResourceAsStream(resource), StandardCharsets.UTF_8)) {
			return IOUtils.toString(reader);
		}
	}

	@SuppressWarnings("deprecation")
	private static void disableFreemarkerLogging() {
		try {
			freemarker.log.Logger.selectLoggerLibrary(freemarker.log.Logger.LIBRARY_NONE);
		}
		catch (ClassNotFoundException e) {
			LOG.warn("Could not disable Freemarker log spam");
		}
	}


	/**
	 * Parse the meta details from the given sections
	 *
	 * NOTE: This method will throw an exception if the model is not defined in the meta section
	 *
	 * @param sections
	 * @return The parsed meta details
	 * @param 
	 */
	public static  MetaDetails parseMetaDetails(List
sections) { return parseMetaDetails(sections, null); } /** * Parse the meta details from the given sections * * NOTE: model has to be defined in the meta section or the defaultModel has to be provided * * @param sections * @param defaultModel used as a fallback if the model is not defined in the meta section * @return The parsed meta details * @param */ public static MetaDetails parseMetaDetails(List
sections, Model defaultModel) { Map properties = getMetaProperties(sections); String modelName = properties.remove("model"); if (modelName == null && defaultModel != null) modelName = defaultModel.name(); if (modelName == null) throw new PromptTemplateException("META", "missing the required 'model' property"); Optional timeout = Optional.ofNullable(properties.remove("timeout")).map(Object::toString).map(DurationUtils::parseDurationWithUnit); Model model = new Model<>( modelName, inferTokenLimit(modelName, properties.remove("tokenLimit")), inferTokenizer(modelName, properties.remove("tokenizer"))); return new MetaDetails<>(model, timeout, properties); } public static int inferTokenLimit(String model, String tokenLimit) { if (tokenLimit != null) return Integer.parseInt(tokenLimit); if (model.startsWith("gpt-4-turbo") || model.startsWith("gpt-4-1106") || model.startsWith("gpt-4o")) return 128_000; if (model.startsWith("gpt-4")) return 8_192; if (model.startsWith("gpt-3.5-turbo")) { String[] parts = model.split("-"); String version = parts.length >= 4 ? parts[3] : null; if (version != null && ("16k".equals(version) || (version.matches("\\d{4}") && "0125".compareTo(version) >= 0))) return 16_385; else return 4_096; } if (model.startsWith("text-embedding-ada-002")) return 8_192; throw new PromptTemplateException("META", "missing property token_limit, could not infer the value", model); } public static Tokenizer inferTokenizer(String model, String tokenizer) { if (tokenizer != null) return JtokkitTokenizer.of(EncodingType.fromName(tokenizer).orElseThrow(() -> new PromptTemplateException("META", "unsupported tokenizer '%s'", tokenizer))); if (REGEX_MODEL_WITH_CL100K_BASE.matcher(model).matches()) return Tokenizer.CL100K_BASE; throw new PromptTemplateException("META", "missing property tokenizer, could not infer the value", model); } private static Map getMetaProperties(List
sections) { Map results = new LinkedHashMap<>(); sections.stream() .filter(s -> SECTION_META.equals(s.type())) .map(Section::properties) .forEach(results::putAll); return results; } public static List
parseTemplate(Reader reader) { try (TemplateLineStream lines = new TemplateLineStream(reader)) { return parseTemplate(lines); } catch (IOException e) { throw new UncheckedIOException("Failed reading prompt template", e); } } private static List
parseTemplate(TemplateLineStream lines) throws IOException { List
results = new ArrayList<>(); while (lines.hasSectionHeader()) results.add(parseSection(lines)); return results; } private static Section parseSection(TemplateLineStream lines) throws IOException { String type = lines.expectSectionHeader(); Map properties = new LinkedHashMap<>(); while (lines.hasProperty()) { SectionProperty property = lines.expectProperty(); properties.put(property.key(), property.value()); } StringBuilder content = new StringBuilder(); while (lines.hasNext() && !lines.hasSectionHeader()) content.append(lines.next()).append("\n"); return new Section(type, properties, content.toString()); } public static record MetaDetails( Model model, Optional timeout, Map properties) { } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy