All downloads are FREE. Search and download functionality uses the official Maven repository.

fi.evolver.ai.spring.config.ApiConfigurationService Maven / Gradle / Ivy

package fi.evolver.ai.spring.config;

import java.time.Duration;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import fi.evolver.ai.spring.Api;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.config.LlmApiConfiguration.ApiConfig;
import fi.evolver.ai.spring.config.LlmApiConfiguration.ModelConfig;
import fi.evolver.ai.spring.config.LlmApiConfiguration.ProviderConfig;
import fi.evolver.utils.NullSafetyUtils;


/**
 * Resolves connectivity parameters (headers, URL, port, timeout) for LLM API endpoints from the
 * injected {@link LlmApiConfiguration}.
 *
 * <p>Each value is resolved hierarchically: a value configured for the model wins over the
 * API-level value, which wins over the provider-level value. Resolved parameters are cached per
 * provider class, provider name, API type and model name.
 *
 * <p>The cache belongs to this instance rather than being static. Instances built from different
 * configurations, such as separate Spring contexts in tests, therefore never observe each other's
 * resolved parameters.
 */
@Service
public class ApiConfigurationService {
	/** Matches the conventional {@code Api} suffix of API type names. */
	private static final Pattern API_SUFFIX = Pattern.compile("Api$");
	/** Matches the boundary between a lower-case and a following upper-case letter. */
	private static final Pattern CAMEL_CASE_BOUNDARY = Pattern.compile("(?<=[a-z])(?=[A-Z])");

	/** Model used when the caller does not ask for a specific one; looked up under the name "default". */
	private static final Model<?> DEFAULT = new Model<>("default", 0, null);

	/** Timeout used when no level of the configuration specifies {@code timeoutMs}. */
	private static final int DEFAULT_TIMEOUT_MS = 30_000;


	/** Resolved endpoint parameters, keyed by every input that influences the lookup. */
	private final Map<ApiConfigurationCacheKey, ApiEndpointParameters> modelParameters = new ConcurrentHashMap<>();

	private final LlmApiConfiguration llmApiConfiguration;


	@Autowired
	public ApiConfigurationService(LlmApiConfiguration llmApiConfiguration) {
		this.llmApiConfiguration = llmApiConfiguration;
	}


	/**
	 * Resolve (or fetch from the cache) the endpoint parameters for an API type and model.
	 * <p>
	 * Failed lookups are not cached: {@link ConcurrentHashMap#computeIfAbsent} records no mapping when
	 * the mapping function throws, so the lookup is retried on the next call.
	 */
	private ApiEndpointParameters getEndpointParameters(Class<?> providerClass, Optional<String> providerName, String apiType, Model<?> model) {
		ApiConfigurationCacheKey key = new ApiConfigurationCacheKey(providerClass, providerName, apiType, model.name());
		return modelParameters.computeIfAbsent(key, this::fetchModelConfig);
	}


	/**
	 * Get connectivity configuration for the given API and model.
	 *
	 * @param providerClass The class providing the API; matched by simple name against the provider's configured class.
	 * @param providerName The API endpoint's provider's name, or empty to accept any provider.
	 * @param api The type of the API; e.g. {@code ImageGenerationApi} is looked up as {@code image_generation}.
	 * @param model The model to be used.
	 * @return Configuration for the endpoint providing the requested API.
	 * @throws IllegalArgumentException If no configured provider matches.
	 * @throws ApiResponseException If no URL is configured at any level.
	 */
	public ApiEndpointParameters getEndpointParameters(Class<?> providerClass, Optional<String> providerName, Class<?> api, Model<?> model) {
		return getEndpointParameters(providerClass, providerName, toApiType(api), model);
	}

	/**
	 * Get connectivity configuration for the given API, using the {@code default} model entry.
	 *
	 * @param providerClass The class providing the API.
	 * @param providerName The API endpoint's provider's name, or empty to accept any provider.
	 * @param api The type of the API.
	 * @return Configuration for the endpoint providing the requested API.
	 * @throws IllegalArgumentException If no configured provider matches.
	 * @throws ApiResponseException If no URL is configured at any level.
	 */
	public ApiEndpointParameters getEndpointParameters(Class<?> providerClass, Optional<String> providerName, Class<?> api) {
		return getEndpointParameters(providerClass, providerName, api, DEFAULT);
	}

	/**
	 * Get connectivity configuration for the given custom API type, using the {@code default} model entry.
	 *
	 * @param providerClass The class providing the API.
	 * @param providerName The API endpoint's provider's name, or empty to accept any provider.
	 * @param apiType The custom API type, used verbatim as the configuration key.
	 * @return Configuration for the endpoint providing the requested API.
	 * @throws IllegalArgumentException If no configured provider matches.
	 * @throws ApiResponseException If no URL is configured at any level.
	 */
	public ApiEndpointParameters getEndpointParameters(Class<?> providerClass, Optional<String> providerName, String apiType) {
		return getEndpointParameters(providerClass, providerName, apiType, DEFAULT);
	}


	/**
	 * Derive the configuration key of an API type from its simple name: drop a trailing {@code Api},
	 * separate camelCase words with {@code _} and lower-case the result
	 * (e.g. {@code ImageGenerationApi -> image_generation}).
	 */
	private static String toApiType(Class<?> api) {
		String withoutSuffix = API_SUFFIX.matcher(api.getSimpleName()).replaceFirst("");
		// Locale.ROOT: default-locale lower-casing maps 'I' to dotless 'ı' under e.g. a Turkish locale
		return CAMEL_CASE_BOUNDARY.matcher(withoutSuffix).replaceAll("_").toLowerCase(Locale.ROOT);
	}


	/**
	 * Build endpoint parameters for a cache key by merging model, API and provider level settings;
	 * the first non-null value wins.
	 */
	private ApiEndpointParameters fetchModelConfig(ApiConfigurationCacheKey key) {
		ProviderConfig providerConfig = getProviderConfig(key.providerClass(), key.apiType(), key.providerName());

		// Present: getProviderConfig only returns providers whose apis contain this type
		ApiConfig apiConfig = providerConfig.apis().get(key.apiType());
		// An API may be configured without any model-specific overrides
		Optional<ModelConfig> modelConfig = Optional.ofNullable(apiConfig.models())
				.map(models -> models.get(key.modelName()));

		String urlString = NullSafetyUtils.denull(
				modelConfig.map(ModelConfig::url).orElse(null),
				apiConfig.url(),
				providerConfig.url());
		if (urlString == null)
			throw new ApiResponseException("The API connection for %s has not been initialized correctly".formatted(key));

		Map<String, String> headers = NullSafetyUtils.denull(
				modelConfig.map(ModelConfig::headers).orElse(null),
				apiConfig.headers(),
				providerConfig.headers(),
				Map.of());

		Integer port = NullSafetyUtils.denull(
				modelConfig.map(ModelConfig::port).orElse(null),
				apiConfig.port(),
				providerConfig.port());

		Duration timeout = Duration.ofMillis(NullSafetyUtils.denull(
				modelConfig.map(ModelConfig::timeoutMs).orElse(null),
				apiConfig.timeoutMs(),
				providerConfig.timeoutMs(),
				DEFAULT_TIMEOUT_MS));

		return new ApiEndpointParameters(headers, urlString, Optional.ofNullable(port), timeout);
	}


	/**
	 * Find the first configured provider that matches the optional name, whose configured provider
	 * class equals {@code providerClass}'s simple name, and that configures {@code apiType}.
	 *
	 * @throws IllegalArgumentException If no provider matches.
	 */
	private ProviderConfig getProviderConfig(Class<?> providerClass, String apiType, Optional<String> providerName) {
		String providerClassName = providerClass.getSimpleName();
		return llmApiConfiguration.providers().entrySet().stream()
				.filter(e -> providerName.map(e.getKey()::equals).orElse(true))
				.map(Map.Entry::getValue)
				.filter(p -> providerClassName.equals(p.providerClass()))
				// A provider without any configured APIs cannot serve this one
				.filter(p -> p.apis() != null && p.apis().containsKey(apiType))
				.findFirst()
				.orElseThrow(() -> new IllegalArgumentException("No matching provider for %s with %s found for api %s".formatted(
						providerClassName,
						providerName.map("name %s"::formatted).orElse("any name"),
						apiType)));
	}


	/** Identifies one resolved endpoint configuration; every component influences the lookup. */
	private record ApiConfigurationCacheKey(
			Class<?> providerClass,
			Optional<String> providerName,
			String apiType,
			String modelName) {
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy