All downloads are free. The search and download functionality uses the official Maven repository.

com.microsoft.semantickernel.aiservices.huggingface.models.chat.HuggingFaceChatCompletionService.ignore Maven / Gradle / Ivy

package com.microsoft.semantickernel.aiservices.huggingface.services;

import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.huggingface.HuggingFaceClient;
import com.microsoft.semantickernel.aiservices.huggingface.models.ChatCompletionRequest;
import com.microsoft.semantickernel.aiservices.huggingface.models.HuggingFaceXMLPromptParser;
import com.microsoft.semantickernel.aiservices.huggingface.models.HuggingFaceXMLPromptParser.HuggingFaceParsedPrompt;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import reactor.core.publisher.Mono;

// TODO Support TGI
public class HuggingFaceChatCompletionService implements ChatCompletionService {

    private final String modelId;
    private final String serviceId;
    private final HuggingFaceClient client;

    public HuggingFaceChatCompletionService(
        String modelId,
        String serviceId,
        HuggingFaceClient client) {
        this.modelId = modelId;
        this.serviceId = serviceId;
        this.client = client;
    }

    public Mono>> getChatMessageContentsAsync(
        ChatHistory chatHistory,
        @Nullable Kernel kernel,
        @Nullable HuggingFacePromptExecutionSettings executionSettings) {

        String model = modelId;
        if (executionSettings.getModelId() != null && !executionSettings.getModelId().isEmpty()) {
            model = executionSettings.getModelId();
        }

        ChatCompletionRequest request = new ChatCompletionRequest(
            model,
            false,
            chatHistory
                .getMessages()
                .stream()
                .map(
                    message -> {
                        return new ChatCompletionRequest.ChatMessage(
                            message.getAuthorRole().name(),
                            message.getContent(),
                            null,
                            null
                        );
                    }
                )
                .collect(Collectors.toList()),
            executionSettings.getLogprobs(),
            executionSettings.getTopLogProbs(),
            executionSettings.getMaxTokens(),
            new Float(executionSettings.getPresencePenalty()),
            executionSettings.getStopSequences(),
            executionSettings.getSeed(),
            new Float(executionSettings.getTemperature()),
            new Float(executionSettings.getTopP())
        );

        return client
            .getChatMessageContentsAsync(modelId, request)
            .map(result -> {
                return Collections.singletonList(new ChatMessageContent<>(
                    AuthorRole.SYSTEM,
                    result)
                );
            });
    }

    @Override
    public Mono>> getChatMessageContentsAsync(
        ChatHistory chatHistory,
        @Nullable Kernel kernel,
        @Nullable InvocationContext invocationContext) {

        HuggingFacePromptExecutionSettings executionSettings;
        if (invocationContext != null && invocationContext.getPromptExecutionSettings() != null) {
            executionSettings = HuggingFacePromptExecutionSettings.fromExecutionSettings(
                invocationContext.getPromptExecutionSettings());
        } else {
            executionSettings = new HuggingFacePromptExecutionSettings(
                PromptExecutionSettings.builder().build());
        }

        return getChatMessageContentsAsync(chatHistory, kernel, executionSettings);

    }

    @Override
    public Mono>> getChatMessageContentsAsync(
        String prompt,
        @Nullable Kernel kernel,
        @Nullable InvocationContext invocationContext) {
        HuggingFaceParsedPrompt parsed = HuggingFaceXMLPromptParser.parse(prompt);

        ChatHistory history = new ChatHistory();
        parsed.getChatRequestMessages()
            .forEach(message -> {
                history.addMessage(AuthorRole.valueOf(message.role.toUpperCase(Locale.ROOT)),
                    message.content);
            });

        return getChatMessageContentsAsync(history, kernel, invocationContext);

    }

    @Nullable
    @Override
    public String getModelId() {
        return modelId;
    }

    @Nullable
    @Override
    public String getServiceId() {
        return serviceId;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        @Nullable
        private String modelId;
        @Nullable
        private HuggingFaceClient client;
        @Nullable
        private String serviceId;

        /**
         * Sets the model ID for the service
         *
         * @param modelId The model ID
         * @return The builder
         */
        public Builder withModelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        /**
         * Sets the service ID for the service
         *
         * @param serviceId The service ID
         * @return The builder
         */
        public Builder withServiceId(String serviceId) {
            this.serviceId = serviceId;
            return this;
        }

        public Builder withHuggingFaceClient(HuggingFaceClient client) {
            this.client = client;
            return this;
        }

        public ChatCompletionService build() {
            return new HuggingFaceChatCompletionService(
                this.modelId,
                this.serviceId,
                this.client);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy