com.microsoft.semantickernel.semanticfunctions.KernelFunctionFromPrompt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of semantickernel-api Show documentation
Defines the public interface for the Semantic Kernel
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.semanticfunctions;
import com.azure.core.exception.HttpResponseException;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableType;
import com.microsoft.semantickernel.hooks.FunctionInvokedEvent;
import com.microsoft.semantickernel.hooks.FunctionInvokingEvent;
import com.microsoft.semantickernel.hooks.KernelHooks;
import com.microsoft.semantickernel.hooks.PromptRenderedEvent;
import com.microsoft.semantickernel.hooks.PromptRenderingEvent;
import com.microsoft.semantickernel.localization.SemanticKernelResources;
import com.microsoft.semantickernel.orchestration.FunctionResult;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.AIService;
import com.microsoft.semantickernel.services.AIServiceSelection;
import com.microsoft.semantickernel.services.TextAIService;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* A {@link KernelFunction} implementation that is created from a prompt template.
*
* @param the type of the return value of the function
*/
public class KernelFunctionFromPrompt extends KernelFunction {
private static final Logger LOGGER = LoggerFactory.getLogger(KernelFunctionFromPrompt.class);
private final PromptTemplate template;
/**
* Creates a new instance of {@link KernelFunctionFromPrompt}.
*
* @param template the prompt template to use for the function
* @param promptConfig the configuration for the prompt
* @param executionSettings the execution settings to use when invoking the function
*/
protected KernelFunctionFromPrompt(
PromptTemplate template,
PromptTemplateConfig promptConfig,
@Nullable Map executionSettings) {
super(
new KernelFunctionMetadata<>(
null,
getName(promptConfig),
promptConfig.getDescription(),
promptConfig.getKernelParametersMetadata(),
promptConfig.getKernelReturnParameterMetadata()),
executionSettings != null ? executionSettings : promptConfig.getExecutionSettings());
this.template = template;
}
private static String getName(PromptTemplateConfig promptConfig) {
if (promptConfig.getName() == null) {
return UUID.randomUUID().toString();
} else {
return promptConfig.getName();
}
}
/**
* Creates a new instance of {@link Builder}.
*
* @param The type of the return value of the function
* @return a new instance of {@link Builder}
*/
public static Builder builder() {
return new Builder<>();
}
/**
* Creates a new instance of {@link Builder}.
*
* @param returnType The type of the return value of the function
* @param The type of the return value of the function
* @return a new instance of {@link Builder}
*/
public static Builder builder(Class returnType) {
return new Builder<>();
}
private Flux> invokeInternalAsync(
Kernel kernel,
@Nullable KernelFunctionArguments argumentsIn,
@Nullable ContextVariableType contextVariableType,
@Nullable InvocationContext invocationContext) {
InvocationContext context = invocationContext != null ? invocationContext
: InvocationContext.builder().build();
// must be effectively final for lambda
KernelHooks kernelHooks = KernelHooks.merge(
kernel.getGlobalKernelHooks(),
context.getKernelHooks());
PromptRenderingEvent preRenderingHookResult = kernelHooks
.executeHooks(new PromptRenderingEvent(this, argumentsIn));
KernelFunctionArguments arguments = preRenderingHookResult.getArguments();
// TODO: put in method, add catch for classcastexception, fallback to noopconverter
ContextVariableType variableType = contextVariableType != null
? contextVariableType
: context.getContextVariableTypes().getVariableTypeForClass(
(Class) this.getMetadata().getOutputVariableType().getType());
return this.template
.renderAsync(kernel, arguments, context)
.flatMapMany(prompt -> {
PromptRenderedEvent promptHookResult = kernelHooks
.executeHooks(new PromptRenderedEvent(this, arguments, prompt));
prompt = promptHookResult.getPrompt();
KernelFunctionArguments args = promptHookResult.getArguments();
LOGGER.info(SemanticKernelResources.getString("rendered.prompt"), prompt);
FunctionInvokingEvent updateArguments = kernelHooks
.executeHooks(new FunctionInvokingEvent(this, args));
args = updateArguments.getArguments();
AIServiceSelection> aiServiceSelection = kernel
.getServiceSelector()
.trySelectAIService(
TextAIService.class,
this,
args);
AIService client = aiServiceSelection != null ? aiServiceSelection.getService()
: null;
if (aiServiceSelection == null) {
throw new IllegalStateException(
"Failed to initialise aiService, could not find any TextAIService implementations");
}
Flux> result;
// settings from prompt or use default
PromptExecutionSettings executionSettings = aiServiceSelection.getSettings();
if (client instanceof ChatCompletionService) {
InvocationContext contextWithExecutionSettings = context;
if (context.getPromptExecutionSettings() == null) {
contextWithExecutionSettings = InvocationContext.copy(context)
.withPromptExecutionSettings(executionSettings)
.build();
}
result = ((ChatCompletionService) client)
.getChatMessageContentsAsync(
prompt,
kernel,
contextWithExecutionSettings)
.flatMapMany(Flux::fromIterable)
.concatMap(chatMessageContent -> {
if (chatMessageContent.getAuthorRole() == AuthorRole.ASSISTANT) {
T value = variableType
.getConverter()
.fromObject(chatMessageContent);
if (value == null) {
value = variableType
.getConverter()
.fromPromptString(
chatMessageContent.getContent());
}
if (value == null) {
return Flux.empty();
}
return Flux.just(
new FunctionResult<>(
new ContextVariable<>(variableType, value),
chatMessageContent.getMetadata(),
chatMessageContent));
}
return Flux.empty();
})
.map(it -> {
return new FunctionResult<>(
new ContextVariable<>(
variableType,
it.getResult() != null
? variableType.of(it.getResult()).getValue()
: null),
it.getMetadata(),
it.getUnconvertedResult());
});
} else if (client instanceof TextGenerationService) {
result = ((TextGenerationService) client)
.getTextContentsAsync(
prompt,
executionSettings,
kernel)
.flatMapMany(Flux::fromIterable)
.concatMap(textContent -> {
T value = variableType
.getConverter()
.fromObject(textContent);
if (value == null) {
value = variableType
.getConverter()
.fromPromptString(textContent.getContent());
}
return Flux.just(
new FunctionResult<>(
new ContextVariable<>(
variableType,
value),
textContent.getMetadata(),
textContent));
});
} else {
return Flux.error(new IllegalStateException("Unknown service type"));
}
return result
.map(it -> {
FunctionInvokedEvent updatedResult = kernelHooks
.executeHooks(
new FunctionInvokedEvent<>(
this,
arguments,
it));
return updatedResult.getResult();
});
})
.doOnError(
ex -> {
LOGGER.warn(
SemanticKernelResources.getString(
"something.went.wrong.while.rendering.the.semantic.function.or.while.executing.the.text.completion.function.error"),
getPluginName(),
getName(),
ex.getMessage());
// Common message when you attempt to send text completion
// requests to a chat completion model:
// "logprobs, best_of and echo parameters are not
// available on gpt-35-turbo model"
if (ex instanceof HttpResponseException
&& ((HttpResponseException) ex).getResponse().getStatusCode() == 400
&& ex.getMessage() != null
&& ex.getMessage().contains("parameters are not available on")) {
LOGGER.warn(
SemanticKernelResources.getString(
"this.error.indicates.that.you.have.attempted.to.use.a.chat.completion.model"));
}
});
}
@Override
public Mono> invokeAsync(
Kernel kernel,
@Nullable KernelFunctionArguments arguments,
@Nullable ContextVariableType variableType,
@Nullable InvocationContext invocationContext) {
return invokeInternalAsync(kernel, arguments, variableType, invocationContext)
.takeLast(1).single();
}
/**
* A builder for creating a {@link KernelFunction} from a prompt template.
*
* @param the type of the return value of the function
*/
public static final class Builder implements FromPromptBuilder {
@Nullable
private PromptTemplate promptTemplate;
@Nullable
private String name;
@Nullable
private Map executionSettings = null;
@Nullable
private String description;
@Nullable
private List inputVariables;
@Nullable
private String template;
private String templateFormat = PromptTemplateConfig.SEMANTIC_KERNEL_TEMPLATE_FORMAT;
@Nullable
private OutputVariable> outputVariable;
@Nullable
private PromptTemplateFactory promptTemplateFactory;
@Nullable
private PromptTemplateConfig promptTemplateConfig;
@Override
public FromPromptBuilder withName(@Nullable String name) {
this.name = name;
return this;
}
@Override
public FromPromptBuilder withInputParameters(
@Nullable List inputVariables) {
if (inputVariables != null) {
this.inputVariables = new ArrayList<>(inputVariables);
} else {
this.inputVariables = null;
}
return this;
}
@Override
public FromPromptBuilder withPromptTemplate(@Nullable PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
@Override
public FromPromptBuilder withExecutionSettings(
@Nullable Map executionSettings) {
if (this.executionSettings == null) {
this.executionSettings = new HashMap<>();
}
if (executionSettings != null) {
this.executionSettings.putAll(executionSettings);
}
return this;
}
@Override
public FromPromptBuilder withDefaultExecutionSettings(
@Nullable PromptExecutionSettings executionSettings) {
if (executionSettings == null) {
return this;
}
if (this.executionSettings == null) {
this.executionSettings = new HashMap<>();
}
this.executionSettings.put(PromptExecutionSettings.DEFAULT_SERVICE_ID,
executionSettings);
if (executionSettings.getServiceId() != null) {
this.executionSettings.put(executionSettings.getServiceId(), executionSettings);
}
return this;
}
@Override
public FromPromptBuilder withDescription(@Nullable String description) {
this.description = description;
return this;
}
@Override
public FromPromptBuilder withTemplate(@Nullable String template) {
this.template = template;
return this;
}
@Override
public FromPromptBuilder withTemplateFormat(String templateFormat) {
this.templateFormat = templateFormat;
return this;
}
@Override
public FromPromptBuilder withOutputVariable(
@Nullable OutputVariable outputVariable) {
this.outputVariable = outputVariable;
return (FromPromptBuilder) this;
}
@Override
public FromPromptBuilder withOutputVariable(@Nullable String description, String type) {
return this.withOutputVariable(new OutputVariable(type, description));
}
@Override
public FromPromptBuilder withPromptTemplateFactory(
@Nullable PromptTemplateFactory promptTemplateFactory) {
this.promptTemplateFactory = promptTemplateFactory;
return this;
}
@Override
public FromPromptBuilder withPromptTemplateConfig(
@Nullable PromptTemplateConfig promptTemplateConfig) {
this.promptTemplateConfig = promptTemplateConfig;
return this;
}
@Override
public KernelFunction build() {
if (templateFormat == null) {
templateFormat = PromptTemplateConfig.SEMANTIC_KERNEL_TEMPLATE_FORMAT;
}
if (name == null) {
name = UUID.randomUUID().toString();
}
if (promptTemplateFactory == null) {
promptTemplateFactory = new KernelPromptTemplateFactory();
}
if (promptTemplateConfig != null) {
if (promptTemplate == null) {
promptTemplate = promptTemplateFactory.tryCreate(promptTemplateConfig);
}
return new KernelFunctionFromPrompt<>(
promptTemplate,
promptTemplateConfig,
executionSettings);
}
PromptTemplateConfig config = new PromptTemplateConfig(
name,
template,
templateFormat,
Collections.emptySet(),
description,
inputVariables,
outputVariable,
executionSettings);
PromptTemplate temp;
if (promptTemplate != null) {
temp = promptTemplate;
} else {
temp = new KernelPromptTemplateFactory().tryCreate(config);
}
return new KernelFunctionFromPrompt<>(temp, config, executionSettings);
}
}
}