All Downloads are FREE. Search and download functionalities are using the official Maven repository.

fi.evolver.ai.spring.chat.prompt.ChatPromptTemplateParser Maven / Gradle / Ivy

package fi.evolver.ai.spring.chat.prompt;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt.Builder;
import fi.evolver.ai.spring.chat.prompt.Message.Role;
import fi.evolver.ai.spring.prompt.template.HistoryTag;
import fi.evolver.ai.spring.prompt.template.PromptTemplateException;
import fi.evolver.ai.spring.prompt.template.TemplateUtils;
import fi.evolver.ai.spring.prompt.template.TemplateUtils.MetaDetails;
import fi.evolver.ai.spring.prompt.template.model.Section;
import fi.evolver.ai.spring.skill.Skill;
import fi.evolver.ai.spring.skill.SkillPrompt;
import fi.evolver.ai.spring.util.DurationUtils;
import freemarker.core.InvalidReferenceException;
import freemarker.template.Template;
import freemarker.template.TemplateException;


public class ChatPromptTemplateParser {
	private static Map> skillsByName;


	public static ChatPrompt parse(Reader reader, Map variables, List history) {
		List
sections = TemplateUtils.parseTemplate(reader); return buildPrompt(sections, variables, history, Optional.empty()); } public static ChatPrompt parse(String resource, Map variables, List history) { return parse( new InputStreamReader(ChatPromptTemplateParser.class.getResourceAsStream(resource), StandardCharsets.UTF_8), variables, history); } public static ChatPrompt parse(Reader reader, Map variables, List history, Optional> model) { List
sections = TemplateUtils.parseTemplate(reader); return buildPrompt(sections, variables, history, model); } public static ChatPrompt parse(String resource, Map variables, List history, Optional> model) { return parse( new InputStreamReader(ChatPromptTemplateParser.class.getResourceAsStream(resource), StandardCharsets.UTF_8), variables, history, model); } private static ChatPrompt buildPrompt(List
sections, Map variables, List history, Optional> model) { Builder builder = createBuilder(sections, model); for (Section section : sections) { switch (section.type()) { case TemplateUtils.SECTION_META -> {} case TemplateUtils.SECTION_FUNCTION -> handleFunction(builder, section); case TemplateUtils.SECTION_SKILL -> handleSkill(builder, section); case TemplateUtils.SECTION_ASSISTANT_MESSAGE -> handleMessage(builder, section, Role.ASSISTANT, variables, history); case TemplateUtils.SECTION_SYSTEM_MESSAGE -> handleMessage(builder, section, Role.SYSTEM, variables, history); case TemplateUtils.SECTION_USER_MESSAGE -> handleMessage(builder, section, Role.USER, variables, history); case TemplateUtils.SECTION_HISTORY -> handleHistory(builder, section, history); case TemplateUtils.SECTION_COMMENT -> {} default -> throw new PromptTemplateException(section.type(), "unsupported template section type"); } } return builder.build(); } private static Builder createBuilder(List
sections, Optional> model) { MetaDetails metaDetails = TemplateUtils.parseMetaDetails(sections, model.orElse(null)); Model templateModel = metaDetails.model(); Model modelToUse = templateModel != null ? templateModel : model.orElse(null); Builder builder = ChatPrompt.builder(modelToUse); metaDetails.timeout().ifPresent(builder::setTimeout); metaDetails.properties().forEach(builder::setParameter); return builder; } private static void handleFunction(Builder builder, Section section) { if (!"java".equals(section.expectProperty("mode"))) throw new PromptTemplateException(section.type(), "only mode=java is supported for now"); String className = section.expectProperty("class"); boolean mandatory = "true".equals(section.getProperty("mandatory").orElse(null)); try { Class clazz = ChatPromptTemplateParser.class.getClassLoader().loadClass(className); FunctionSpec functionSpec = FunctionSpec.of(clazz); builder.add(functionSpec, mandatory); } catch (ClassNotFoundException e) { throw new PromptTemplateException(section.type(), "unknown class %s", className); } } private static void handleSkill(Builder builder, Section section) { String name = section.expectProperty("name"); Skill skill = skillsByName.get(name); if (skill == null) throw new PromptTemplateException(section.type(), "unknown skill %s", name); boolean mandatory = "true".equals(section.getProperty("mandatory").orElse(null)); Duration timeout = section.getProperty("timeout").map(DurationUtils::parseDurationWithUnit).orElse(null); SkillPrompt skillPrompt = SkillPrompt.builder(skill) .setMandatory(mandatory) .setTimeout(timeout) .setParameter("mandatory", skill) .build(); builder.add(skillPrompt); } private static void handleMessage(Builder builder, Section section, Role role, Map variables, List history) { try { StringWriter writer = new StringWriter(); Template template = TemplateUtils.getTemplate(section); try (HistoryTag.NonFailingAutoCloseable c = TemplateUtils.TAG_HISTORY.setHistory(history)) { 
template.process(variables, writer); } builder.add(new Message(role, writer.toString())); } catch (InvalidReferenceException e) { throw new PromptTemplateException(e, section.type(), "missing template parameter %s", e.getBlamedExpressionString()); } catch (IOException | TemplateException e) { throw new PromptTemplateException(e, section.type(), "failed templating message"); } } private static void handleHistory(Builder builder, Section section, List history) { int count = section.getProperty("count").map(Integer::parseInt).orElse(10000); int skipFirst = section.getProperty("skipFirst").map(Integer::parseInt).orElse(0); int skipLast = section.getProperty("skipLast").map(Integer::parseInt).orElse(0); String roleFilter = section.expectProperty("roles"); builder.addAll(HistoryTag.findMessages(history, roleFilter, count, skipFirst, skipLast)); } public static synchronized void initSkills(List> skills) { if (skillsByName != null) throw new IllegalStateException("Skills have already been initialized"); skillsByName = skills.stream().collect(Collectors.toMap( Skill::getName, Function.identity())); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy