
// com.microsoft.semantickernel.aiservices.huggingface.models.chat.HuggingFaceChatCompletionService.ignore Maven / Gradle / Ivy
package com.microsoft.semantickernel.aiservices.huggingface.services;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.huggingface.HuggingFaceClient;
import com.microsoft.semantickernel.aiservices.huggingface.models.ChatCompletionRequest;
import com.microsoft.semantickernel.aiservices.huggingface.models.HuggingFaceXMLPromptParser;
import com.microsoft.semantickernel.aiservices.huggingface.models.HuggingFaceXMLPromptParser.HuggingFaceParsedPrompt;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import reactor.core.publisher.Mono;
// TODO Support TGI
public class HuggingFaceChatCompletionService implements ChatCompletionService {
private final String modelId;
private final String serviceId;
private final HuggingFaceClient client;
public HuggingFaceChatCompletionService(
String modelId,
String serviceId,
HuggingFaceClient client) {
this.modelId = modelId;
this.serviceId = serviceId;
this.client = client;
}
public Mono>> getChatMessageContentsAsync(
ChatHistory chatHistory,
@Nullable Kernel kernel,
@Nullable HuggingFacePromptExecutionSettings executionSettings) {
String model = modelId;
if (executionSettings.getModelId() != null && !executionSettings.getModelId().isEmpty()) {
model = executionSettings.getModelId();
}
ChatCompletionRequest request = new ChatCompletionRequest(
model,
false,
chatHistory
.getMessages()
.stream()
.map(
message -> {
return new ChatCompletionRequest.ChatMessage(
message.getAuthorRole().name(),
message.getContent(),
null,
null
);
}
)
.collect(Collectors.toList()),
executionSettings.getLogprobs(),
executionSettings.getTopLogProbs(),
executionSettings.getMaxTokens(),
new Float(executionSettings.getPresencePenalty()),
executionSettings.getStopSequences(),
executionSettings.getSeed(),
new Float(executionSettings.getTemperature()),
new Float(executionSettings.getTopP())
);
return client
.getChatMessageContentsAsync(modelId, request)
.map(result -> {
return Collections.singletonList(new ChatMessageContent<>(
AuthorRole.SYSTEM,
result)
);
});
}
@Override
public Mono>> getChatMessageContentsAsync(
ChatHistory chatHistory,
@Nullable Kernel kernel,
@Nullable InvocationContext invocationContext) {
HuggingFacePromptExecutionSettings executionSettings;
if (invocationContext != null && invocationContext.getPromptExecutionSettings() != null) {
executionSettings = HuggingFacePromptExecutionSettings.fromExecutionSettings(
invocationContext.getPromptExecutionSettings());
} else {
executionSettings = new HuggingFacePromptExecutionSettings(
PromptExecutionSettings.builder().build());
}
return getChatMessageContentsAsync(chatHistory, kernel, executionSettings);
}
@Override
public Mono>> getChatMessageContentsAsync(
String prompt,
@Nullable Kernel kernel,
@Nullable InvocationContext invocationContext) {
HuggingFaceParsedPrompt parsed = HuggingFaceXMLPromptParser.parse(prompt);
ChatHistory history = new ChatHistory();
parsed.getChatRequestMessages()
.forEach(message -> {
history.addMessage(AuthorRole.valueOf(message.role.toUpperCase(Locale.ROOT)),
message.content);
});
return getChatMessageContentsAsync(history, kernel, invocationContext);
}
@Nullable
@Override
public String getModelId() {
return modelId;
}
@Nullable
@Override
public String getServiceId() {
return serviceId;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
@Nullable
private String modelId;
@Nullable
private HuggingFaceClient client;
@Nullable
private String serviceId;
/**
* Sets the model ID for the service
*
* @param modelId The model ID
* @return The builder
*/
public Builder withModelId(String modelId) {
this.modelId = modelId;
return this;
}
/**
* Sets the service ID for the service
*
* @param serviceId The service ID
* @return The builder
*/
public Builder withServiceId(String serviceId) {
this.serviceId = serviceId;
return this;
}
public Builder withHuggingFaceClient(HuggingFaceClient client) {
this.client = client;
return this;
}
public ChatCompletionService build() {
return new HuggingFaceChatCompletionService(
this.modelId,
this.serviceId,
this.client);
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy