
com.hw.langchain.llms.openai.BaseOpenAI Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hw.langchain.llms.openai;

import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.completions.Choice;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.completions.CompletionResp;

import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Wrapper around OpenAI large language models.
 *
 * @author HamaWhite
 */
@SuperBuilder
public class BaseOpenAI extends BaseLLM {

    protected OpenAiClient client;

    /**
     * Model name to use.
     */
    @Builder.Default
    protected String model = "text-davinci-003";

    /**
     * What sampling temperature to use.
     */
    @Builder.Default
    protected float temperature = 0.7f;

    /**
     * The maximum number of tokens to generate in the completion.
     * -1 returns as many tokens as possible given the prompt and the model's maximal context size.
     */
    @Builder.Default
    protected int maxTokens = 256;

    /**
     * Total probability mass of tokens to consider at each step.
     */
    @Builder.Default
    protected float topP = 1f;

    /**
     * Penalizes repeated tokens according to frequency.
     */
    protected float frequencyPenalty;

    /**
     * Penalizes repeated tokens.
     */
    protected float presencePenalty;

    /**
     * How many completions to generate for each prompt.
     */
    @Builder.Default
    protected int n = 1;

    /**
     * Generates best_of completions server-side and returns the "best".
     */
    @Builder.Default
    protected int bestOf = 1;

    /**
     * API key for OpenAI.
     */
    protected String openaiApiKey;

    /**
     * Base URL for OpenAI API.
     */
    protected String openaiApiBase;

    /**
     * Organization ID for OpenAI.
     */
    protected String openaiOrganization;

    /**
     * Supports an explicit proxy for OpenAI.
     */
    protected String openaiProxy;

    /**
     * Batch size to use when passing multiple documents to generate.
     */
    @Builder.Default
    protected int batchSize = 20;

    /**
     * Timeout for requests to OpenAI completion API. Default is 10 seconds.
     */
    @Builder.Default
    protected long requestTimeout = 10;

    /**
     * Adjust the probability of specific tokens being generated.
     */
    protected Map<String, Integer> logitBias;

    /**
     * Maximum number of retries to make when generating.
     */
    @Builder.Default
    protected int maxRetries = 6;

    /**
     * Whether to stream the results or not.
     */
    protected boolean stream;

    /**
     * Set of special tokens that are allowed.
     */
    protected Set<String> allowedSpecial;

    /**
     * Set of special tokens that are not allowed.
     */
    protected Set<String> disallowedSpecial;

    @Override
    public String predict(String text, List<String> stop) {
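        // Not implemented in this base class.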
        return null;
    }

    @Override
    public BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop) {
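        // Not implemented in this base class.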
        return null;
    }

    @Override
    public String llmType() {
        return "openai";
    }

    /**
     * Call out to OpenAI's endpoint with k unique prompts.
     *
     * @param prompts The prompts to pass into the model.
     * @param stop    list of stop words to use when generating.
     * @return The full LLM output.
     */
    @Override
    protected LLMResult _generate(List<String> prompts, List<String> stop) {
        List<Choice> choices = new ArrayList<>();
        List<List<String>> subPrompts = getSubPrompts(prompts);
        Completion completion = Completion.builder()
                .model(model)
                .temperature(temperature)
                .maxTokens(maxTokens)
                .topP(topP)
                .frequencyPenalty(frequencyPenalty)
                .presencePenalty(presencePenalty)
                .n(n)
                .logitBias(logitBias)
                .stop(stop)
                .build();

        for (var prompt : subPrompts) {
            completion.setPrompt(prompt);
            CompletionResp response = client.create(completion);
            choices.addAll(response.getChoices());
        }

        return createLLMResult(choices, prompts, Map.of());
    }

    /**
     * Create the LLMResult from the choices and prompts.
     */
    private LLMResult createLLMResult(List<Choice> choices, List<String> prompts, Map<String, Integer> tokenUsage) {
        List<List<Generation>> generations = new ArrayList<>();
        for (int i = 0; i < prompts.size(); i++) {
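            // The API returns n choices per prompt, in prompt order; take the i-th prompt's slice.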
            List<Choice> subChoices = choices.subList(i * n, (i + 1) * n);
            List<Generation> generationList = new ArrayList<>();
            for (Choice choice : subChoices) {
                Map<String, Object> generationInfo = new HashMap<>(2);
                generationInfo.put("finishReason", choice.getFinishReason());
                generationInfo.put("logprobs", choice.getLogprobs());

                Generation generation = Generation.builder()
                        .text(choice.getText())
                        .generationInfo(generationInfo)
                        .build();
                generationList.add(generation);
            }
            generations.add(generationList);
        }

        Map<String, Object> llmOutput = new HashMap<>(2);
        llmOutput.put("token_usage", tokenUsage);
        llmOutput.put("model_name", model);

        return new LLMResult(generations, llmOutput);
    }

    /**
     * Get the sub prompts for llm call.
     */
    private List<List<String>> getSubPrompts(List<String> prompts) {
        if (maxTokens == -1) {
            checkArgument(prompts.size() == 1, "maxTokens set to -1 not supported for multiple inputs.");
            maxTokens = maxTokensForPrompt(prompts.get(0));
        }
        List<List<String>> subPrompts = new ArrayList<>();
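        // Split prompts into consecutive batches of at most batchSize elements,
        // e.g. 45 prompts with batchSize = 20 yield sublists of sizes 20, 20 and 5.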
        for (int i = 0; i < prompts.size(); i += this.batchSize) {
            subPrompts.add(prompts.subList(i, Math.min(i + this.batchSize, prompts.size())));
        }
        return subPrompts;
    }

    /**
     * Calculate the maximum number of tokens possible to generate for a prompt.
     *
     * @param prompt The prompt to pass into the model.
     * @return The maximum number of tokens to generate for a prompt.
     */
    private int maxTokensForPrompt(String prompt) {
        // TODO
        return 0;
    }
}
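
Usage sketch: Lombok's @SuperBuilder generates a fluent builder for this class, so the wrapper can be configured along the following lines. This is illustrative only; the environment-variable name and the assumption that a concrete subclass initializes the OpenAiClient and implements the predict methods are not defined in this file.

    BaseOpenAI llm = BaseOpenAI.builder()
            .model("text-davinci-003")
            .temperature(0.7f)
            .maxTokens(256)
            .openaiApiKey(System.getenv("OPENAI_API_KEY")) // assumed environment variable
            .build();
    // A concrete subclass is expected to wire up the OpenAiClient before
    // _generate(...) can actually reach the OpenAI completion endpoint.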



