// fi.evolver.ai.spring.skill.SkillService Maven / Gradle / Ivy
package fi.evolver.ai.spring.skill;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPromptTemplateParser;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.chat.prompt.MessageContent;
import fi.evolver.ai.spring.chat.prompt.MessageContent.ToolResultContent;
import fi.evolver.ai.spring.skill.mock.LlmMockSkillConfigurationRepository;
import fi.evolver.ai.spring.skill.mock.LlmSkill;
import fi.evolver.ai.spring.skill.mock.entity.LlmMockSkillConfiguration;
import fi.evolver.utils.ContextUtils;
import io.swagger.v3.core.util.Json;
@Component
public class SkillService {
private static final Logger LOG = LoggerFactory.getLogger(SkillService.class);
private static final Duration TIMEOUT_DEFAULT = Duration.ofSeconds(30);
private final Executor executor = createExecutor();
private final Map> skills;
private final LlmMockSkillConfigurationRepository llmMockSkillConfigurationRepository;
private final ChatApi chatApi;
@Autowired
public SkillService(
List> skills,
LlmMockSkillConfigurationRepository llmMockSkillConfigurationRepository,
ChatApi chatApi) {
this.skills = skills.stream().collect(Collectors.toMap(Skill::getName, Function.identity()));
ChatPromptTemplateParser.initSkills(skills);
this.llmMockSkillConfigurationRepository = llmMockSkillConfigurationRepository;
this.chatApi = chatApi;
}
protected Duration getDefaultTimeout() {
return TIMEOUT_DEFAULT;
}
protected Executor createExecutor() {
return ContextUtils.makeContextAware(Executors.newCachedThreadPool(r -> {
Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
return t;
}));
}
public List apply(ChatResponse response) throws SkillException {
return apply(response, (a) -> {});
}
public List apply(ChatResponse response, SkillExecutionListener listener) throws SkillException {
List> skillRuns = new ArrayList<>();
for (FunctionCall functionCall: response.getFunctionCalls()) {
SkillPrompt skillPrompt = response.getPrompt().getSkillPrompt(functionCall.getFunctionName());
if (skillPrompt == null)
continue;
Skill, ?> skill = skills.get(skillPrompt.skillName());
if (skill == null)
throw new SkillException("Unknown skill: %s".formatted(functionCall.getFunctionName()));
Skill, ?> implementation = chooseImplentation(skill);
SkillRun, ?> run = new SkillRun<>(skillPrompt, implementation, functionCall, listener);
run.execute(executor);
skillRuns.add(run);
}
return skillRuns.stream()
.map(SkillRun::getResult)
.toList();
}
private Skill chooseImplentation(Skill skill) {
Optional mockConfig = llmMockSkillConfigurationRepository.findBySkillName(skill.getName());
return mockConfig.filter(LlmMockSkillConfiguration::isEnabled).map(c -> mock(skill, c)).orElse(skill);
}
private Skill mock(Skill skill, LlmMockSkillConfiguration config) {
return new LlmSkill<>(
skill.getParameterType(),
skill.getResultType(),
chatApi,
config.getModel(),
createSystemPrompt(skill, config));
}
private String createSystemPrompt(Skill, ?> skill, LlmMockSkillConfiguration config) {
if (config.getPrompt().isPresent())
return config.getPrompt().get();
FunctionSpec> parameterSpec = skill.getFunctionSpec();
StringBuilder builder = new StringBuilder();
builder.append("You are a mock implementation of a function named '").append(parameterSpec.getFunctionName()).append("'. ").append("Always respond to the given parameters in the requested format with convincing mock data.");
parameterSpec.getTitle().ifPresent(t -> builder.append("\nFunction title: '").append(t).append("'"));
parameterSpec.getDescription().ifPresent(d -> builder.append("\nFunction description: '").append(d).append("'"));
return builder.toString();
}
private class SkillRun {
private final Skill skill;
private final SkillPrompt skillPrompt;
private final FunctionCall functionCall;
private final SkillExecutionListener listener;
private volatile SkillResult result;
private Future future;
public SkillRun(SkillPrompt skillPrompt, Skill skill, FunctionCall functionCall, SkillExecutionListener listener) {
this.skillPrompt = skillPrompt;
this.skill = skill;
this.functionCall = functionCall;
this.listener = listener;
}
public void execute(Executor executor) {
Duration timeout = skillPrompt.timeout().orElse(getDefaultTimeout());
future = CompletableFuture.runAsync(this::run, executor)
.orTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS);
}
private void run() {
try {
listener.onStart(skillPrompt);
T parameters = functionCall.toResult(skill.getFunctionSpec());
R value = skill.apply(parameters);
done(value);
}
catch (RuntimeException e) {
fail(e);
}
}
public SkillResult getResult() {
if (future == null)
throw new IllegalStateException("Call execute(Executor) first!");
try {
future.get();
}
catch (InterruptedException | ExecutionException e) {
fail(e);
}
return result;
}
private synchronized void done(Object value) {
if (result != null)
return;
MessageContent content;
if (value instanceof MessageContent mc)
content = mc;
else
content = MessageContent.text(Json.pretty(value));
result = new SkillResult(
Optional.empty(),
skillPrompt,
value,
MessageContent.toolResult(functionCall.getToolCallId(), content));
listener.onComplete(skillPrompt, result);
}
public synchronized void fail(Exception e) {
if (result != null) {
LOG.warn("Running {} failed after a successfull completion: ignoring", skill.getName(), e);
return;
}
Throwable cause = e;
if (cause instanceof ExecutionException && cause.getCause() != null)
cause = cause.getCause();
SkillException error = null;
String llmError = "unexpected failure";
if (cause instanceof SkillException se) {
llmError = cause.getMessage();
error = se;
}
else if (cause instanceof TimeoutException) {
llmError = "tool call timed out";
}
if (error == null)
error = new SkillException(llmError, cause);
listener.onError(skillPrompt, cause);
result = new SkillResult(
Optional.of(error),
skillPrompt,
null,
MessageContent.toolResult(functionCall.getToolCallId(), llmError));
}
}
public record SkillResult(
Optional error,
SkillPrompt prompt,
Object value,
ToolResultContent toolResult) {
/**
* Return the skill's result if applying the skill succeeded.
*
* @return The skill's result.
* @throws SkillException if the skill failed.
*/
public Object value() {
if (error.isPresent())
throw error.get();
return value;
}
public boolean success() {
return error.isEmpty();
}
/**
* Create an user message of the skill results.
*
* @param skillResults The results to add to the message.
* @param moreContent Any extra contents to add into the message.
* @return The created message.
*/
public static Message toMessage(Collection extends SkillResult> skillResults, MessageContent... moreContent) {
List contents = new ArrayList<>(skillResults.size() + moreContent.length);
for (SkillResult skillResult: skillResults)
contents.add(skillResult.toolResult());
contents.addAll(Arrays.asList(moreContent));
return Message.user(contents);
}
}
public interface SkillExecutionListener {
/**
* Called when the execution starts.
*
* @param skillPrompt The prompt for the skill to be executed.
*/
void onStart(SkillPrompt skillPrompt);
/**
* Called when the execution has finished.
*
* @param skillPrompt The prompt for the executed skill.
* @param value The result value of the skill.
*/
default void onComplete(SkillPrompt skillPrompt, SkillResult result) {
}
/**
* Called on error.
*
* @param throwable The throwable detailing the issue.
*/
default void onError(SkillPrompt skillPrompt, Throwable throwable) {
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy