All downloads are FREE. The search and download features use the official Maven repository.
Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
com.microsoft.semantickernel.aiservices.openai.textcompletion.OpenAITextGenerationService Maven / Gradle / Ivy
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.textcompletion;
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.CompletionsOptions;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.openai.OpenAiService;
import com.microsoft.semantickernel.aiservices.openai.implementation.OpenAIRequestSettings;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.AIException.ErrorCodes;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.textcompletion.StreamingTextContent;
import com.microsoft.semantickernel.services.textcompletion.TextContent;
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.text.StringEscapeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* An OpenAI implementation of a {@link TextGenerationService}.
*/
public class OpenAITextGenerationService extends OpenAiService implements TextGenerationService {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class);
/**
* Creates a new {@link OpenAITextGenerationService}.
*
* @param client OpenAI client
* @param modelId OpenAI model id
* @param serviceId Service id
*/
protected OpenAITextGenerationService(
OpenAIAsyncClient client,
String modelId,
@Nullable String serviceId,
String deploymentName) {
super(client, serviceId, modelId, deploymentName);
}
/**
* Creates a builder for creating a {@link OpenAITextGenerationService}.
*
* @return A new {@link OpenAITextGenerationService} builder.
*/
public static Builder builder() {
return new Builder();
}
@Override
public Mono> getTextContentsAsync(
String prompt,
@Nullable PromptExecutionSettings executionSettings,
@Nullable Kernel kernel) {
return this.internalCompleteTextAsync(prompt, executionSettings);
}
@Override
public Flux getStreamingTextContentsAsync(
String prompt,
@Nullable PromptExecutionSettings executionSettings,
@Nullable Kernel kernel) {
return this
.internalCompleteTextAsync(prompt, executionSettings)
.flatMapMany(it -> Flux.fromStream(it.stream())
.map(StreamingTextContent::new));
}
protected Mono> internalCompleteTextAsync(
String text,
@Nullable PromptExecutionSettings requestSettings) {
CompletionsOptions completionsOptions = getCompletionsOptions(text, requestSettings);
return getClient()
.getCompletionsWithResponse(getDeploymentName(), completionsOptions,
OpenAIRequestSettings.getRequestOptions())
.flatMap(completionsResult -> {
if (completionsResult.getStatusCode() >= 400) {
return Mono.error(new AIException(ErrorCodes.SERVICE_ERROR,
"Request failed: " + completionsResult.getStatusCode()));
}
return Mono.just(completionsResult.getValue());
})
.map(completions -> {
FunctionResultMetadata metadata = FunctionResultMetadata.build(
completions.getId(),
completions.getUsage(),
completions.getCreatedAt());
return completions
.getChoices()
.stream()
.map(choice -> {
return new TextContent(
choice.getText(),
completionsOptions.getModel(),
metadata);
})
.collect(Collectors.toList());
});
}
private CompletionsOptions getCompletionsOptions(
String text,
@Nullable PromptExecutionSettings requestSettings) {
text = StringEscapeUtils.unescapeXml(text);
if (requestSettings == null) {
return new CompletionsOptions(Collections.singletonList(text))
.setMaxTokens(PromptExecutionSettings.DEFAULT_MAX_TOKENS);
}
if (requestSettings.getMaxTokens() < 1) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, "Max tokens must be >0");
}
if (requestSettings.getResultsPerPrompt() < 1
|| requestSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format("Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
}
CompletionsOptions options = new CompletionsOptions(Collections.singletonList(text))
.setMaxTokens(requestSettings.getMaxTokens())
.setTemperature(requestSettings.getTemperature())
.setTopP(requestSettings.getTopP())
.setFrequencyPenalty(requestSettings.getFrequencyPenalty())
.setPresencePenalty(requestSettings.getPresencePenalty())
.setModel(getModelId())
.setN(requestSettings.getResultsPerPrompt())
.setUser(requestSettings.getUser())
.setBestOf(requestSettings.getBestOf())
.setLogitBias(new HashMap<>());
return options;
}
/**
* Builder for a TextGenerationService
*/
public static class Builder extends TextGenerationService.Builder {
@Override
public TextGenerationService build() {
if (this.client == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI client must be provided");
}
if (this.modelId == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI model id must be provided");
}
if (deploymentName == null) {
LOGGER.debug("Deployment name is not provided, using model id as deployment name");
deploymentName = modelId;
}
return new OpenAITextGenerationService(
this.client,
this.modelId,
this.serviceId,
this.deploymentName);
}
}
}