All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.semantickernel.orchestration.PromptExecutionSettings Maven / Gradle / Ivy

There is a newer version: 1.3.0
Show newest version
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.orchestration;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.responseformat.JsonObjectResponseFormat;
import com.microsoft.semantickernel.orchestration.responseformat.JsonSchemaResponseFormat;
import com.microsoft.semantickernel.orchestration.responseformat.ResponseFormat;
import com.microsoft.semantickernel.orchestration.responseformat.TextResponseFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;

/**
 * Configuration settings for prompt execution.
 */
public class PromptExecutionSettings {

    /**
     * The default for {@link #getServiceId()} if a {@link Builder#withServiceId(String) service id}
     * is not provided. Defaults to {@code "default"}
     */
    public static final String DEFAULT_SERVICE_ID = "default";

    /**
     * The default for {@link #getMaxTokens()} if {@link Builder#withMaxTokens(int) max_tokens} is
     * not provided. Defaults to {@code 256}
     */
    public static final int DEFAULT_MAX_TOKENS = 256;

    /**
     * The default for {@link #getTemperature()} if
     * {@link Builder#withTemperature(double) temperature} is not provided. Defaults to {@code 1.0}
     */
    public static final double DEFAULT_TEMPERATURE = 1.0;

    /**
     * The default for {@link #getTopP()} if {@link Builder#withTopP(double) top_p} is not provided.
     * Defaults to {@code 1.0}
     */
    public static final double DEFAULT_TOP_P = 1.0;

    /**
     * The default for {@link #getPresencePenalty()} if
     * {@link Builder#withPresencePenalty(double) presence_penalty} is not provided. Defaults to
     * {@code 0.0}
     */
    public static final double DEFAULT_PRESENCE_PENALTY = 0.0;

    /**
     * The default for {@link #getFrequencyPenalty()} if
     * {@link Builder#withFrequencyPenalty(double)} frequency_penalty} is not provided. Defaults to
     * {@code 0.0}
     */
    public static final double DEFAULT_FREQUENCY_PENALTY = 0.0;

    /**
     * The default for {@link #getBestOf()} if {@link Builder#withBestOf(int) best_of} is not
     * provided. Defaults to {@code 1}
     */
    public static final int DEFAULT_BEST_OF = 1;

    /**
     * The default for {@link #getResultsPerPrompt()} if
     * {@link Builder#withResultsPerPrompt(int) results per prompt (n)} is not provided. Defaults to
     * {@code 1}
     */
    public static final int DEFAULT_RESULTS_PER_PROMPT = 1;

    // 
    // Keys used as both @JsonProperty names and keys to the Builder's value map.
    //
    private static final String SERVICE_ID = "service_id";
    private static final String MODEL_ID = "model_id";
    private static final String TEMPERATURE = "temperature";
    private static final String TOP_P = "top_p";
    private static final String PRESENCE_PENALTY = "presence_penalty";
    private static final String FREQUENCY_PENALTY = "frequency_penalty";
    private static final String MAX_TOKENS = "max_tokens";
    private static final String BEST_OF = "best_of";
    private static final String USER = "user";
    private static final String STOP_SEQUENCES = "stop_sequences";
    private static final String RESULTS_PER_PROMPT = "results_per_prompt";
    private static final String TOKEN_SELECTION_BIASES = "token_selection_biases";
    private static final String RESPONSE_FORMAT = "response_format";

    private final String serviceId;
    private final String modelId;
    private final double temperature;
    private final double topP;
    private final double presencePenalty;
    private final double frequencyPenalty;
    private final int maxTokens;
    private final int bestOf;
    private final int resultsPerPrompt;
    private final String user;
    private final List stopSequences;
    private final Map tokenSelectionBiases;
    private final ResponseFormat responseFormat;

    /**
     * Create a new instance of PromptExecutionSettings.
     *
     * @param serviceId            The id of the AI service to use for prompt execution.
     * @param modelId              The id of the model to use for prompt execution.
     * @param temperature          The temperature setting for prompt execution.
     * @param topP                 The topP setting for prompt execution.
     * @param presencePenalty      The presence penalty setting for prompt execution.
     * @param frequencyPenalty     The frequency penalty setting for prompt execution.
     * @param maxTokens            The maximum number of tokens to generate in the output.
     * @param resultsPerPrompt     The number of results to generate for each prompt.
     * @param bestOf               The best of setting for prompt execution.
     * @param user                 The user to associate with the prompt execution.
     * @param stopSequences        The stop sequences to use for prompt execution.
     * @param tokenSelectionBiases The token selection biases to use for prompt execution.
     * @param responseFormat       The response format to use for prompt execution
     *                             {@link ResponseFormat}, Defaults to TextResponseFormat.
     */
    @JsonCreator
    public PromptExecutionSettings(
        @JsonProperty(SERVICE_ID) String serviceId,
        @JsonProperty(MODEL_ID) String modelId,
        @JsonProperty(TEMPERATURE) Double temperature,
        @JsonProperty(TOP_P) Double topP,
        @JsonProperty(PRESENCE_PENALTY) Double presencePenalty,
        @JsonProperty(FREQUENCY_PENALTY) Double frequencyPenalty,
        @JsonProperty(MAX_TOKENS) Integer maxTokens,
        @JsonProperty(RESULTS_PER_PROMPT) Integer resultsPerPrompt,
        @JsonProperty(BEST_OF) Integer bestOf,
        @JsonProperty(USER) String user,
        @Nullable @JsonProperty(STOP_SEQUENCES) List stopSequences,
        @Nullable @JsonProperty(TOKEN_SELECTION_BIASES) Map tokenSelectionBiases,
        @Nullable @JsonProperty(RESPONSE_FORMAT) ResponseFormat responseFormat) {
        this.serviceId = serviceId != null ? serviceId : DEFAULT_SERVICE_ID;
        this.modelId = modelId != null ? modelId : "";
        this.temperature = clamp(temperature, 0d, 2d, DEFAULT_TEMPERATURE);
        this.topP = clamp(topP, 0d, 1d, DEFAULT_TOP_P);
        this.presencePenalty = clamp(presencePenalty, -2d, 2d, DEFAULT_PRESENCE_PENALTY);
        this.frequencyPenalty = clamp(frequencyPenalty, -2d, 2d, DEFAULT_FREQUENCY_PENALTY);
        this.maxTokens = clamp(maxTokens, 1, Integer.MAX_VALUE, DEFAULT_MAX_TOKENS);
        this.resultsPerPrompt = clamp(resultsPerPrompt, 1, Integer.MAX_VALUE,
            DEFAULT_RESULTS_PER_PROMPT);
        this.bestOf = clamp(bestOf, 1, Integer.MAX_VALUE, DEFAULT_BEST_OF);
        this.user = user != null ? user : "";
        this.stopSequences = stopSequences != null ? new ArrayList<>(stopSequences)
            : Collections.emptyList();
        this.tokenSelectionBiases = tokenSelectionBiases != null
            ? new HashMap<>(tokenSelectionBiases)
            : Collections.emptyMap();
        this.tokenSelectionBiases.replaceAll((k, v) -> clamp(v, -100, 100, 0));

        if (responseFormat == null) {
            this.responseFormat = new TextResponseFormat();
        } else {
            this.responseFormat = responseFormat;
        }
    }

    /**
     * Create a new builder for PromptExecutionSettings.
     *
     * @return A new builder for PromptExecutionSettings.
     */
    public static Builder builder() {
        return new Builder();
    }

    private static  T clamp(T value, T min, T max, T defaultValue) {
        if (value == null) {
            return defaultValue;
        }
        if (value.doubleValue() < min.doubleValue()) {
            return min;
        }
        if (value.doubleValue() > max.doubleValue()) {
            return max;
        }
        return value;
    }

    /**
     * Get the id of the AI service to use for prompt execution.
     *
     * @return The id of the AI service to use for prompt execution.
     */
    @JsonProperty(SERVICE_ID)
    public String getServiceId() {
        return serviceId;
    }

    /**
     * Get the id of the model to use for prompt execution.
     *
     * @return The id of the model to use for prompt execution.
     */
    @JsonProperty(MODEL_ID)
    public String getModelId() {
        return modelId;
    }

    /**
     * The temperature setting controls the randomness of the output. Lower values produce more
     * deterministic outputs, while higher values produce more random outputs.
     *
     * @return The temperature setting.
     */
    @JsonProperty(TEMPERATURE)
    public double getTemperature() {
        return Double.isNaN(temperature) ? DEFAULT_TEMPERATURE : temperature;
    }

    /**
     * The topP setting controls how many different words or phrases are considered to predict the
     * next token. The value is a probability threshold, and the model considers the most likely
     * tokens whose cumulative probability mass is greater than the threshold. For example, if the
     * value is 0.1, the model considers only the tokens that make up the top 10% of the cumulative
     * probability mass.
     *
     * @return The topP setting.
     */
    @JsonProperty(TOP_P)
    public double getTopP() {
        return topP;
    }

    /**
     * Presence penalty encourages the model to use a more or less diverse range of tokens in the
     * output. A higher value means that the model will try to use a greater variety of tokens in
     * the output.
     *
     * @return The presence penalty setting.
     */
    @JsonProperty(PRESENCE_PENALTY)
    public double getPresencePenalty() {
        return presencePenalty;
    }

    /**
     * Frequency penalty encourages the model to avoid repeating the same token in the output. A
     * higher value means that the model will be less likely to repeat a token.
     *
     * @return The frequency penalty setting.
     */
    @JsonProperty(FREQUENCY_PENALTY)
    public double getFrequencyPenalty() {
        return frequencyPenalty;
    }

    /**
     * The maximum number of tokens to generate in the output.
     *
     * @return The maximum number of tokens to generate in the output.
     */
    @JsonProperty(MAX_TOKENS)
    public int getMaxTokens() {
        return maxTokens;
    }

    /**
     * The number of results to generate for each prompt.
     *
     * @return The number of results to generate for each prompt.
     */
    @JsonProperty(RESULTS_PER_PROMPT)
    public int getResultsPerPrompt() {
        return resultsPerPrompt;
    }

    /**
     * The log probability threshold for a result to be considered.
     *
     * @return The log probability threshold for a result to be considered.
     */
    @JsonProperty(BEST_OF)
    public int getBestOf() {
        // TODO: not present in com.azure:azure-ai-openai
        return bestOf;
    }

    /**
     * The user to associate with the prompt execution.
     *
     * @return The user to associate with the prompt execution.
     */
    @JsonProperty(USER)
    public String getUser() {
        return user;
    }

    /**
     * The stop sequences to use for prompt execution.
     *
     * @return The stop sequences to use for prompt execution.
     */
    @JsonProperty(STOP_SEQUENCES)
    public List getStopSequences() {
        return Collections.unmodifiableList(stopSequences);
    }

    /**
     * The token selection biases to use for prompt execution. The key is the token id from the
     * tokenizer, and the value is the bias. A negative bias will make the model less likely to use
     * the token, and a positive bias will make the model more likely to use the token.
     *
     * @return The token selection biases to use for prompt execution.
     */
    @JsonProperty(TOKEN_SELECTION_BIASES)
    public Map getTokenSelectionBiases() {
        return Collections.unmodifiableMap(tokenSelectionBiases);
    }

    @Override
    public int hashCode() {
        return Objects.hash(serviceId, modelId, temperature, topP, presencePenalty,
            frequencyPenalty,
            maxTokens, bestOf, resultsPerPrompt, user, stopSequences, tokenSelectionBiases,
            responseFormat);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!getClass().isInstance(obj)) {
            return false;
        }

        PromptExecutionSettings other = (PromptExecutionSettings) obj;

        if (!Objects.equals(serviceId, other.serviceId)) {
            return false;
        }
        if (!Objects.equals(modelId, other.modelId)) {
            return false;
        }
        if (Double.compare(temperature, other.temperature) != 0) {
            return false;
        }
        if (Double.compare(topP, other.topP) != 0) {
            return false;
        }
        if (Double.compare(presencePenalty, other.presencePenalty) != 0) {
            return false;
        }
        if (Double.compare(frequencyPenalty, other.frequencyPenalty) != 0) {
            return false;
        }
        if (maxTokens != other.maxTokens) {
            return false;
        }
        if (bestOf != other.bestOf) {
            return false;
        }
        if (resultsPerPrompt != other.resultsPerPrompt) {
            return false;
        }
        if (!Objects.equals(user, other.user)) {
            return false;
        }
        if (!Objects.equals(stopSequences, other.stopSequences)) {
            return false;
        }
        if (!Objects.equals(responseFormat, other.responseFormat)) {
            return false;
        }
        return Objects.equals(tokenSelectionBiases, other.tokenSelectionBiases);
    }

    /**
     * The response format to use for prompt execution. Currently this only applies to chat
     * completions.
     *
     * @return The response format to use for prompt execution.
     */
    @JsonProperty(RESPONSE_FORMAT)
    public ResponseFormat getResponseFormat() {
        return responseFormat;
    }

    /**
     * Builder for PromptExecutionSettings.
     */
    public static class Builder {

        Map settings = new HashMap<>();

        /**
         * Set the id of the AI service to use for prompt execution.
         *
         * @param serviceId The id of the AI service to use for prompt execution.
         * @return This builder.
         */
        public Builder withServiceId(String serviceId) {
            settings.put(SERVICE_ID, serviceId);
            return this;
        }

        /**
         * Set the id of the model to use for prompt execution.
         *
         * @param modelId The id of the model to use for prompt execution.
         * @return This builder.
         */
        public Builder withModelId(String modelId) {
            settings.put(MODEL_ID, modelId);
            return this;
        }

        /**
         * Set the temperature setting for prompt execution. The value is clamped to the range [0.0,
         * 2.0], and the default is 1.0.
         *
         * @param temperature The temperature setting for prompt execution.
         * @return This builder.
         */
        public Builder withTemperature(double temperature) {
            if (!Double.isNaN(temperature)) {
                settings.put(TEMPERATURE, temperature);
            }
            return this;
        }

        /**
         * Set the topP setting for prompt execution. The value is clamped to the range [0.0, 1.0],
         * and the default is 1.0.
         *
         * @param topP The topP setting for prompt execution.
         * @return This builder.
         */
        public Builder withTopP(double topP) {
            if (!Double.isNaN(topP)) {
                settings.put(TOP_P, topP);
            }
            return this;
        }

        /**
         * Set the presence penalty setting for prompt execution. The value is clamped to the range
         * [-2.0, 2.0], and the default is 0.0.
         *
         * @param presencePenalty The presence penalty setting for prompt execution.
         * @return This builder.
         */
        public Builder withPresencePenalty(double presencePenalty) {
            if (!Double.isNaN(presencePenalty)) {
                settings.put(PRESENCE_PENALTY, presencePenalty);
            }
            return this;
        }

        /**
         * Set the frequency penalty setting for prompt execution. The value is clamped to the range
         * [-2.0, 2.0], and the default is 0.0.
         *
         * @param frequencyPenalty The frequency penalty setting for prompt execution.
         * @return This builder.
         */
        public Builder withFrequencyPenalty(double frequencyPenalty) {
            if (!Double.isNaN(frequencyPenalty)) {
                settings.put(FREQUENCY_PENALTY, frequencyPenalty);
            }
            return this;
        }

        /**
         * Set the maximum number of tokens to generate in the output. The value is clamped to the
         * range [1, Integer.MAX_VALUE], and the default is 256.
         *
         * @param maxTokens The maximum number of tokens to generate in the output.
         * @return This builder.
         */
        public Builder withMaxTokens(int maxTokens) {
            settings.put(MAX_TOKENS, maxTokens);
            return this;
        }

        /**
         * Set the number of results to generate for each prompt. The value is clamped to the range
         * [1, Integer.MAX_VALUE], and the default is 1.
         *
         * @param resultsPerPrompt The number of results to generate for each prompt.
         * @return This builder.
         */
        public Builder withResultsPerPrompt(int resultsPerPrompt) {
            settings.put(RESULTS_PER_PROMPT, resultsPerPrompt);
            return this;
        }

        /**
         * Set the best of setting for prompt execution. The value is clamped to the range [1,
         * Integer.MAX_VALUE], and the default is 1.
         *
         * @param bestOf The best of setting for prompt execution.
         * @return This builder.
         */
        public Builder withBestOf(int bestOf) {
            settings.put(BEST_OF, bestOf);
            return this;
        }

        /**
         * Set the user to associate with the prompt execution.
         *
         * @param user The user to associate with the prompt execution.
         * @return This builder.
         */
        public Builder withUser(String user) {
            settings.put(USER, user);
            return this;
        }

        /**
         * Set the stop sequences to use for prompt execution.
         *
         * @param stopSequences The stop sequences to use for prompt execution.
         * @return This builder.
         */
        @SuppressWarnings("unchecked")
        public Builder withStopSequences(List stopSequences) {
            if (stopSequences != null) {
                ((List) settings.computeIfAbsent(STOP_SEQUENCES,
                    k -> new ArrayList<>())).addAll(stopSequences);
            }
            return this;
        }

        /**
         * Set the token selection biases to use for prompt execution. The bias values are clamped
         * to the range [-100, 100].
         *
         * @param tokenSelectionBiases The token selection biases to use for prompt execution.
         * @return This builder.
         */
        @SuppressWarnings("unchecked")
        public Builder withTokenSelectionBiases(Map tokenSelectionBiases) {
            if (tokenSelectionBiases != null) {
                ((Map) settings.computeIfAbsent(TOKEN_SELECTION_BIASES,
                    k -> new HashMap<>())).putAll(tokenSelectionBiases);
            }
            return this;
        }

        /**
         * Set the response format to use for prompt execution.
         *
         * @param responseFormat The response format to use for prompt execution.
         * @return This builder.
         */
        public Builder withResponseFormat(ResponseFormat responseFormat) {
            if (responseFormat != null) {
                settings.put(RESPONSE_FORMAT, responseFormat);
            }
            return this;
        }

        /**
         * Set the response format to use for prompt execution.
         *
         * @param responseFormat The response format to use for prompt execution.
         * @return This builder.
         */
        public Builder withResponseFormat(ResponseFormat.Type responseFormat) {
            switch (responseFormat) {
                case JSON_OBJECT:
                    settings.put(RESPONSE_FORMAT, new JsonObjectResponseFormat());
                    break;
                case TEXT:
                    settings.put(RESPONSE_FORMAT, new TextResponseFormat());
                    break;
                case JSON_SCHEMA:
                    throw new SKException(
                        "Cannot set JSON_SCHEMA response format without a schema, use withResponseFormat(ResponseFormat responseFormat)");
            }

            return this;
        }

        /**
         * Set the response format to use a json schema generated for the given class. The name of
         * the response format will be the name of the class.
         *
         * @param responseFormat The response format type.
         * @return This builder.
         */
        public Builder withJsonSchemaResponseFormat(Class responseFormat) {
            if (responseFormat != null) {
                settings.put(RESPONSE_FORMAT,
                    JsonSchemaResponseFormat.builder()
                        .setResponseFormat(responseFormat)
                        .setName(responseFormat.getSimpleName())
                        .build());
            }
            return this;
        }

        /**
         * Build the PromptExecutionSettings.
         *
         * @return A new PromptExecutionSettings from this builder.
         */
        @SuppressWarnings("unchecked")
        public PromptExecutionSettings build() {
            return new PromptExecutionSettings(
                (String) settings.getOrDefault(SERVICE_ID, DEFAULT_SERVICE_ID),
                (String) settings.getOrDefault(MODEL_ID, ""),
                (Double) settings.getOrDefault(TEMPERATURE, DEFAULT_TEMPERATURE),
                (Double) settings.getOrDefault(TOP_P, DEFAULT_TOP_P),
                (Double) settings.getOrDefault(PRESENCE_PENALTY, DEFAULT_PRESENCE_PENALTY),
                (Double) settings.getOrDefault(FREQUENCY_PENALTY, DEFAULT_FREQUENCY_PENALTY),
                (Integer) settings.getOrDefault(MAX_TOKENS, DEFAULT_MAX_TOKENS),
                (Integer) settings.getOrDefault(RESULTS_PER_PROMPT, DEFAULT_RESULTS_PER_PROMPT),
                (Integer) settings.getOrDefault(BEST_OF, DEFAULT_BEST_OF),
                (String) settings.getOrDefault(USER, ""),
                (List) settings.getOrDefault(STOP_SEQUENCES, Collections.emptyList()),
                (Map) settings.getOrDefault(TOKEN_SELECTION_BIASES,
                    Collections.emptyMap()),
                (ResponseFormat) settings.getOrDefault(RESPONSE_FORMAT, new TextResponseFormat()));
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy