All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.semantickernel.services.OrderedAIServiceSelector Maven / Gradle / Ivy

// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.services;

import com.microsoft.semantickernel.implementation.Verify;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Implementation of {@link com.microsoft.semantickernel.services.AIServiceSelector} that selects
 * the AI service based on the order of the execution settings. Uses the service id or model id to
 * select the preferred service provider and then returns the service and associated execution
 * settings.
 */
public class OrderedAIServiceSelector extends BaseAIServiceSelector {

    private static final Logger LOGGER = LoggerFactory.getLogger(OrderedAIServiceSelector.class);

    /**
     * Initializes a new instance of the {@link OrderedAIServiceSelector} class with an empty
     * collection of services.
     */
    public OrderedAIServiceSelector() {
        super(new AIServiceCollection());
    }

    /**
     * Initializes a new instance of the {@link OrderedAIServiceSelector} class with the specified
     * services.
     *
     * @param services The services to select from.
     */
    public OrderedAIServiceSelector(AIServiceCollection services) {
        super(services);
    }

    @SuppressWarnings("unchecked")
    @Nullable
    private static  AIServiceSelection castServiceSelection(
        AIServiceSelection selection) {
        try {
            // unchecked cast
            return (AIServiceSelection) selection;
        } catch (ClassCastException e) {
            LOGGER.debug("{}", e.getMessage());
            return null;
        }
    }

    @Nullable
    private static Map settingsFromFunctionSettings(
        @Nullable KernelFunction function) {
        if (function != null) {
            return function.getExecutionSettings();
        }
        return null;
    }

    @Nullable
    @Override
    public  AIServiceSelection trySelectAIService(
        Class serviceType,
        @Nullable KernelFunction function,
        @Nullable KernelFunctionArguments arguments,
        Map, AIService> services) {

        // Allow the execution settings from the kernel arguments to take precedence
        Map executionSettings = settingsFromFunctionSettings(
            function);

        if (executionSettings == null || executionSettings.isEmpty()) {
            AIService service = getAnyService(serviceType);
            if (service != null) {
                return castServiceSelection(new AIServiceSelection<>(service, null));
            }
        } else {
            AIServiceSelection selection = executionSettings
                .entrySet()
                .stream()
                .map(keyValue -> {

                    PromptExecutionSettings settings = keyValue.getValue();
                    String serviceId = keyValue.getKey();

                    if (!Verify.isNullOrEmpty(serviceId)) {
                        AIService service = getService(serviceId);
                        if (service != null) {
                            return castServiceSelection(
                                new AIServiceSelection<>(service, settings));
                        }
                    }

                    return null;
                })
                .filter(Objects::nonNull)
                .findFirst()
                .orElseGet(() -> null);

            if (selection != null) {
                return castServiceSelection(selection);
            }

            selection = executionSettings
                .entrySet()
                .stream()
                .map(keyValue -> {
                    PromptExecutionSettings settings = keyValue.getValue();

                    if (!Verify.isNullOrEmpty(settings.getModelId())) {
                        AIService service = getServiceByModelId(settings.getModelId());
                        if (service != null) {
                            return castServiceSelection(
                                new AIServiceSelection<>(service, settings));
                        }
                    }

                    return null;
                })
                .filter(Objects::nonNull)
                .findFirst()
                .orElseGet(() -> null);

            if (selection != null) {
                return castServiceSelection(selection);
            }
        }

        AIService defaultService = getService(PromptExecutionSettings.DEFAULT_SERVICE_ID);

        if (defaultService != null && serviceType.isAssignableFrom(defaultService.getClass())) {
            return castServiceSelection(new AIServiceSelection<>(defaultService, null));
        }

        AIService service = getAnyService(serviceType);
        PromptExecutionSettings settings = null;

        if (executionSettings != null && !executionSettings.isEmpty()) {
            if (executionSettings.containsKey(PromptExecutionSettings.DEFAULT_SERVICE_ID)) {
                settings = executionSettings.get(PromptExecutionSettings.DEFAULT_SERVICE_ID);
            } else {
                settings = executionSettings.values().stream().findFirst().orElseGet(() -> null);
            }
        }

        if (service != null) {
            return castServiceSelection(new AIServiceSelection<>(service, settings));
        }

        LOGGER.warn("No service found meeting requirements");
        return null;
    }

    private AIService getServiceByModelId(String modelId) {
        return services
            .values()
            .stream()
            .filter(s -> modelId.equalsIgnoreCase(s.getModelId()))
            .findFirst()
            .orElseGet(() -> null);
    }

    /**
     * Gets the service with the specified service id.
     *
     * @param serviceId The service id.
     * @return The service with the specified service id, or {@code null} if no such service exists.
     * @see AIService#getServiceId()
     */
    @Nullable
    public AIService getService(String serviceId) {
        return services
            .values()
            .stream()
            .filter(s -> serviceId.equalsIgnoreCase(s.getServiceId()))
            .findFirst()
            .orElseGet(() -> null);
    }

    /**
     * Gets the service of the specified type.
     *
     * @param clazz The type of service to get.
     * @param    The type of service to get.
     * @return The service of the specified type, or {@code null} if no such service exists.
     */
    @Nullable
    @SuppressWarnings("unchecked")
    public  T getService(Class clazz) {
        T service = (T) services.get(clazz);

        if (service == null) {
            service = (T) services
                .entrySet()
                .stream()
                .filter(it -> clazz.isAssignableFrom(it.getKey()))
                .map(it -> it.getValue())
                .findFirst()
                .orElseGet(() -> null);
        }

        if (service == null &&
            (clazz.equals(TextGenerationService.class) ||
                clazz.equals(ChatCompletionService.class))) {
            LOGGER.warn(
                "Requested a non-existent service type of {}. Consider requesting a TextAIService instead.",
                clazz.getName());
        }

        return service;
    }

    /**
     * Gets the first service of the specified type, or that is an instance of the specified type.
     *
     * @param serviceType The type of service to get.
     * @return The first service of the specified type, or {@code null} if no such service exists.
     */
    @Nullable
    AIService getAnyService(Class serviceType) {
        List matchingServices = getServices(serviceType);
        if (matchingServices.isEmpty()) {
            return null;
        }
        return matchingServices.get(0);
    }

    private List getServices(Class serviceType) {
        return services
            .entrySet()
            .stream()
            .filter(it -> serviceType.isAssignableFrom(it.getKey()))
            .map(Map.Entry::getValue)
            .collect(Collectors.toList());
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy