All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.skill.SkillService Maven / Gradle / Ivy

package fi.evolver.ai.spring.skill;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.ChatResponse;
import fi.evolver.ai.spring.chat.FunctionCall;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPromptTemplateParser;
import fi.evolver.ai.spring.chat.prompt.Message;
import fi.evolver.ai.spring.chat.prompt.MessageContent;
import fi.evolver.ai.spring.chat.prompt.MessageContent.ToolResultContent;
import fi.evolver.ai.spring.skill.mock.LlmMockSkillConfigurationRepository;
import fi.evolver.ai.spring.skill.mock.LlmSkill;
import fi.evolver.ai.spring.skill.mock.entity.LlmMockSkillConfiguration;
import fi.evolver.utils.ContextUtils;
import io.swagger.v3.core.util.Json;

@Component
public class SkillService {
	private static final Logger LOG = LoggerFactory.getLogger(SkillService.class);

	private static final Duration TIMEOUT_DEFAULT = Duration.ofSeconds(30);
	private final Executor executor = createExecutor();
	private final Map> skills;

	private final LlmMockSkillConfigurationRepository llmMockSkillConfigurationRepository;
	private final ChatApi chatApi;


	@Autowired
	public SkillService(
			List> skills,
			LlmMockSkillConfigurationRepository llmMockSkillConfigurationRepository,
			ChatApi chatApi) {

		this.skills = skills.stream().collect(Collectors.toMap(Skill::getName, Function.identity()));
		ChatPromptTemplateParser.initSkills(skills);
		this.llmMockSkillConfigurationRepository = llmMockSkillConfigurationRepository;
		this.chatApi = chatApi;
	}


	protected Duration getDefaultTimeout() {
		return TIMEOUT_DEFAULT;
	}

	protected Executor createExecutor() {
		return ContextUtils.makeContextAware(Executors.newCachedThreadPool(r -> {
			Thread t = Executors.defaultThreadFactory().newThread(r);
		    t.setDaemon(true);
		    return t;
		}));
	}


	public List apply(ChatResponse response) throws SkillException {
		return apply(response, (a) -> {});
	}


	public List apply(ChatResponse response, SkillExecutionListener listener) throws SkillException {
		List> skillRuns = new ArrayList<>();

		for (FunctionCall functionCall: response.getFunctionCalls()) {
			SkillPrompt skillPrompt = response.getPrompt().getSkillPrompt(functionCall.getFunctionName());
			if (skillPrompt == null)
				continue;

			Skill skill = skills.get(skillPrompt.skillName());
			if (skill == null)
				throw new SkillException("Unknown skill: %s".formatted(functionCall.getFunctionName()));

			Skill implementation = chooseImplentation(skill);
			SkillRun run = new SkillRun<>(skillPrompt, implementation, functionCall, listener);
			run.execute(executor);
			skillRuns.add(run);
		}

		return skillRuns.stream()
				.map(SkillRun::getResult)
				.toList();
	}


	private  Skill chooseImplentation(Skill skill) {
		Optional mockConfig = llmMockSkillConfigurationRepository.findBySkillName(skill.getName());
		return mockConfig.filter(LlmMockSkillConfiguration::isEnabled).map(c -> mock(skill, c)).orElse(skill);
	}


	private  Skill mock(Skill skill, LlmMockSkillConfiguration config) {
		return new LlmSkill<>(
				skill.getParameterType(),
				skill.getResultType(),
				chatApi,
				config.getModel(),
				createSystemPrompt(skill, config));
	}


	private String createSystemPrompt(Skill skill, LlmMockSkillConfiguration config) {
		if (config.getPrompt().isPresent())
			return config.getPrompt().get();

		FunctionSpec parameterSpec = skill.getFunctionSpec();
		StringBuilder builder = new StringBuilder();
		builder.append("You are a mock implementation of a function named '").append(parameterSpec.getFunctionName()).append("'. ").append("Always respond to the given parameters in the requested format with convincing mock data.");
		parameterSpec.getTitle().ifPresent(t -> builder.append("\nFunction title: '").append(t).append("'"));
		parameterSpec.getDescription().ifPresent(d -> builder.append("\nFunction description: '").append(d).append("'"));
		return builder.toString();
	}


	private class SkillRun {
		private final Skill skill;
		private final SkillPrompt skillPrompt;
		private final FunctionCall functionCall;
		private final SkillExecutionListener listener;
		private volatile SkillResult result;
		private Future future;


		public SkillRun(SkillPrompt skillPrompt, Skill skill, FunctionCall functionCall, SkillExecutionListener listener) {
			this.skillPrompt = skillPrompt;
			this.skill = skill;
			this.functionCall = functionCall;
			this.listener = listener;
		}


		public void execute(Executor executor) {
			Duration timeout = skillPrompt.timeout().orElse(getDefaultTimeout());
			future = CompletableFuture.runAsync(this::run, executor)
					.orTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS);
		}


		private void run() {
			try {
				listener.onStart(skillPrompt);
				T parameters = functionCall.toResult(skill.getFunctionSpec());
				R value = skill.apply(parameters);
				done(value);
			}
			catch (RuntimeException e) {
				fail(e);
			}
		}


		public SkillResult getResult() {
			if (future == null)
				throw new IllegalStateException("Call execute(Executor) first!");

			try {
				future.get();
			}
			catch (InterruptedException | ExecutionException e) {
				fail(e);
			}

			return result;
		}



		private synchronized void done(Object value) {
			if (result != null)
				return;

			MessageContent content;
			if (value instanceof MessageContent mc)
				content = mc;
			else
				content = MessageContent.text(Json.pretty(value));

			result = new SkillResult(
					Optional.empty(),
					skillPrompt,
					value,
					MessageContent.toolResult(functionCall.getToolCallId(), content));
			listener.onComplete(skillPrompt, result);
		}


		public synchronized void fail(Exception e) {
			if (result != null) {
				LOG.warn("Running {} failed after a successfull completion: ignoring", skill.getName(), e);
				return;
			}

			Throwable cause = e;
			if (cause instanceof ExecutionException && cause.getCause() != null)
				cause = cause.getCause();

			SkillException error = null;
			String llmError = "unexpected failure";
			if (cause instanceof SkillException se) {
				llmError = cause.getMessage();
				error = se;
			}
			else if (cause instanceof TimeoutException) {
				llmError = "tool call timed out";
			}

			if (error == null)
				error = new SkillException(llmError, cause);

			listener.onError(skillPrompt, cause);
			result = new SkillResult(
					Optional.of(error),
					skillPrompt,
					null,
					MessageContent.toolResult(functionCall.getToolCallId(), llmError));
		}

	}


	public record SkillResult(
			Optional error,
			SkillPrompt prompt,
			Object value,
			ToolResultContent toolResult) {


		/**
		 * Return the skill's result if applying the skill succeeded.
		 *
		 * @return The skill's result.
		 * @throws SkillException if the skill failed.
		 */
		public Object value() {
			if (error.isPresent())
				throw error.get();
			return value;
		}

		public boolean success() {
			return error.isEmpty();
		}


		/**
		 * Create an user message of the skill results.
		 *
		 * @param skillResults The results to add to the message.
		 * @param moreContent Any extra contents to add into the message.
		 * @return The created message.
		 */
		public static Message toMessage(Collection skillResults, MessageContent... moreContent) {
			List contents = new ArrayList<>(skillResults.size() + moreContent.length);
			for (SkillResult skillResult: skillResults)
				contents.add(skillResult.toolResult());
			contents.addAll(Arrays.asList(moreContent));
			return Message.user(contents);
		}
	}


	public interface SkillExecutionListener {

		/**
		 * Called when the execution starts.
		 *
		 * @param skillPrompt The prompt for the skill to be executed.
		 */
		void onStart(SkillPrompt skillPrompt);

		/**
		 * Called when the execution has finished.
		 *
		 * @param skillPrompt The prompt for the executed skill.
		 * @param value The result value of the skill.
		 */
		default void onComplete(SkillPrompt skillPrompt, SkillResult result) {
		}

		/**
		 * Called on error.
		 *
		 * @param throwable The throwable detailing the issue.
		 */
		default void onError(SkillPrompt skillPrompt, Throwable throwable) {
		}

	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy