fi.evolver.ai.spring.chat.prompt.ChatPromptTemplateParser Maven / Gradle / Ivy
package fi.evolver.ai.spring.chat.prompt;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt.Builder;
import fi.evolver.ai.spring.chat.prompt.Message.Role;
import fi.evolver.ai.spring.prompt.template.HistoryTag;
import fi.evolver.ai.spring.prompt.template.PromptTemplateException;
import fi.evolver.ai.spring.prompt.template.TemplateUtils;
import fi.evolver.ai.spring.prompt.template.TemplateUtils.MetaDetails;
import fi.evolver.ai.spring.prompt.template.model.Section;
import fi.evolver.ai.spring.skill.Skill;
import fi.evolver.ai.spring.skill.SkillPrompt;
import fi.evolver.ai.spring.util.DurationUtils;
import freemarker.core.InvalidReferenceException;
import freemarker.template.Template;
import freemarker.template.TemplateException;
public class ChatPromptTemplateParser {
private static Map> skillsByName;
public static ChatPrompt parse(Reader reader, Map variables, List history) {
List sections = TemplateUtils.parseTemplate(reader);
return buildPrompt(sections, variables, history, Optional.empty());
}
public static ChatPrompt parse(String resource, Map variables, List history) {
return parse(
new InputStreamReader(ChatPromptTemplateParser.class.getResourceAsStream(resource), StandardCharsets.UTF_8),
variables,
history);
}
public static ChatPrompt parse(Reader reader, Map variables, List history, Optional> model) {
List sections = TemplateUtils.parseTemplate(reader);
return buildPrompt(sections, variables, history, model);
}
public static ChatPrompt parse(String resource, Map variables, List history, Optional> model) {
return parse(
new InputStreamReader(ChatPromptTemplateParser.class.getResourceAsStream(resource), StandardCharsets.UTF_8),
variables,
history,
model);
}
private static ChatPrompt buildPrompt(List sections, Map variables, List history, Optional> model) {
Builder builder = createBuilder(sections, model);
for (Section section : sections) {
switch (section.type()) {
case TemplateUtils.SECTION_META -> {}
case TemplateUtils.SECTION_FUNCTION -> handleFunction(builder, section);
case TemplateUtils.SECTION_SKILL -> handleSkill(builder, section);
case TemplateUtils.SECTION_ASSISTANT_MESSAGE -> handleMessage(builder, section, Role.ASSISTANT, variables, history);
case TemplateUtils.SECTION_SYSTEM_MESSAGE -> handleMessage(builder, section, Role.SYSTEM, variables, history);
case TemplateUtils.SECTION_USER_MESSAGE -> handleMessage(builder, section, Role.USER, variables, history);
case TemplateUtils.SECTION_HISTORY -> handleHistory(builder, section, history);
case TemplateUtils.SECTION_COMMENT -> {}
default -> throw new PromptTemplateException(section.type(), "unsupported template section type");
}
}
return builder.build();
}
private static Builder createBuilder(List sections, Optional> model) {
MetaDetails metaDetails = TemplateUtils.parseMetaDetails(sections, model.orElse(null));
Model templateModel = metaDetails.model();
Model modelToUse = templateModel != null ? templateModel : model.orElse(null);
Builder builder = ChatPrompt.builder(modelToUse);
metaDetails.timeout().ifPresent(builder::setTimeout);
metaDetails.properties().forEach(builder::setParameter);
return builder;
}
private static void handleFunction(Builder builder, Section section) {
if (!"java".equals(section.expectProperty("mode")))
throw new PromptTemplateException(section.type(), "only mode=java is supported for now");
String className = section.expectProperty("class");
boolean mandatory = "true".equals(section.getProperty("mandatory").orElse(null));
try {
Class> clazz = ChatPromptTemplateParser.class.getClassLoader().loadClass(className);
FunctionSpec> functionSpec = FunctionSpec.of(clazz);
builder.add(functionSpec, mandatory);
}
catch (ClassNotFoundException e) {
throw new PromptTemplateException(section.type(), "unknown class %s", className);
}
}
private static void handleSkill(Builder builder, Section section) {
String name = section.expectProperty("name");
Skill, ?> skill = skillsByName.get(name);
if (skill == null)
throw new PromptTemplateException(section.type(), "unknown skill %s", name);
boolean mandatory = "true".equals(section.getProperty("mandatory").orElse(null));
Duration timeout = section.getProperty("timeout").map(DurationUtils::parseDurationWithUnit).orElse(null);
SkillPrompt skillPrompt = SkillPrompt.builder(skill)
.setMandatory(mandatory)
.setTimeout(timeout)
.setParameter("mandatory", skill)
.build();
builder.add(skillPrompt);
}
private static void handleMessage(Builder builder, Section section, Role role, Map variables, List history) {
try {
StringWriter writer = new StringWriter();
Template template = TemplateUtils.getTemplate(section);
try (HistoryTag.NonFailingAutoCloseable c = TemplateUtils.TAG_HISTORY.setHistory(history)) {
template.process(variables, writer);
}
builder.add(new Message(role, writer.toString()));
}
catch (InvalidReferenceException e) {
throw new PromptTemplateException(e, section.type(), "missing template parameter %s", e.getBlamedExpressionString());
}
catch (IOException | TemplateException e) {
throw new PromptTemplateException(e, section.type(), "failed templating message");
}
}
private static void handleHistory(Builder builder, Section section, List history) {
int count = section.getProperty("count").map(Integer::parseInt).orElse(10000);
int skipFirst = section.getProperty("skipFirst").map(Integer::parseInt).orElse(0);
int skipLast = section.getProperty("skipLast").map(Integer::parseInt).orElse(0);
String roleFilter = section.expectProperty("roles");
builder.addAll(HistoryTag.findMessages(history, roleFilter, count, skipFirst, skipLast));
}
public static synchronized void initSkills(List> skills) {
if (skillsByName != null)
throw new IllegalStateException("Skills have already been initialized");
skillsByName = skills.stream().collect(Collectors.toMap(
Skill::getName,
Function.identity()));
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy