All downloads are free. Search and download functionality uses the official Maven repository.

com.dingtalk.baymax.framework.sdk.mercury.model.BaseLLM Maven / Gradle / Ivy

package com.dingtalk.baymax.framework.sdk.mercury.model;

import com.dingtalk.baymax.framework.sdk.mercury.domain.AIMessage;
import com.dingtalk.baymax.framework.sdk.mercury.domain.BaseMessage;
import com.dingtalk.baymax.framework.sdk.mercury.prompt.PromptValue;
import com.dingtalk.baymax.framework.sdk.mercury.util.BaseMessageUtils;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
 * Base class for LLM wrappers: an LLM wrapper accepts a prompt and returns a string.
 *
 * <p>Subclasses supply {@link #generate(List, List)} and {@link #generateAsync(List, List)};
 * the prompt-based and {@code predict*} entry points in this class are built on top of them.
 *
 * <p>Change history (from the original authors): {@code generate} and {@code generateAsync}
 * were moved down into {@code BaseLLM}, and the {@code predictMessage} methods were added.
 *
 * <p>NOTE(review): the method signatures below use raw {@code List} / {@code CompletableFuture}
 * types; they are kept exactly as found so that the overrides still match the parent class.
 *
 * @author xiaoxuan.lp
 * @author xiaoyan.wjw
 */
public abstract class BaseLLM extends BaseLanguageModel {

    /**
     * Timeout handed to {@code waitSafely} while waiting for streamed output in
     * {@link #predict(String, List)}. Presumably milliseconds — confirm against the promise API.
     */
    private static final int STREAM_OUTPUT_WAIT_TIMEOUT = 60000;

    /**
     * Takes a list of prompts and returns an {@link LLMResult}.
     *
     * @param prompts the prompts to run; {@link #generatePrompt(List, List)} passes the
     *                {@code toString()} form of each prompt value
     * @param stops   stop list forwarded by the callers; this class passes {@code null} here
     *                when {@link #predict(String)} is used, so implementations must accept it
     * @return the generation result
     */
    public abstract LLMResult generate(List prompts, List stops);

    /**
     * Streaming variant of {@link #generate(List, List)}.
     *
     * <p>The default implementation does nothing; subclasses that support streaming are
     * expected to override it.
     *
     * @param prompts        the prompts to run
     * @param stops          stop list, may be {@code null}
     * @param streamObserver observer handed to the streaming implementation
     */
    public void generateStream(List prompts, List stops, StreamObserver streamObserver) {
    }

    /**
     * Asynchronous variant of {@link #generate(List, List)}.
     *
     * @param prompts the prompts to run
     * @param stops   stop list, may be {@code null}
     * @return a future supplied by the subclass
     */
    public abstract CompletableFuture generateAsync(List prompts, List stops);

    /**
     * Converts each prompt value to its string form and delegates to
     * {@link #generate(List, List)}.
     */
    @Override
    public LLMResult generatePrompt(List prompts, List stops) {
        List promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        return generate(promptStrings, stops);
    }

    /**
     * Converts each prompt value to its string form and delegates to
     * {@link #generateStream(List, List, StreamObserver)}.
     */
    @Override
    public void generatePrompt(List prompts, List stops, StreamObserver streamObserver) {
        List promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        generateStream(promptStrings, stops, streamObserver);
    }

    /**
     * Converts each prompt value to its string form and delegates to
     * {@link #generateAsync(List, List)}.
     */
    @Override
    public CompletableFuture generatePromptAsync(List prompts, List stops) {
        List promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        return generateAsync(promptStrings, stops);
    }

    /**
     * Runs a single prompt with no stop list.
     *
     * @param text the prompt text
     * @return the text of the first generation
     */
    @Override
    public String predict(String text) {
        return predict(text, null);
    }

    /**
     * Runs a single prompt through {@link #generate(List, List)} and returns the text of the
     * first generation for that prompt.
     *
     * <p>If the result carries a stream-output promise, this method first blocks on it (with a
     * timeout) before reading the generations.
     *
     * @param text  the prompt text
     * @param stops stop list, may be {@code null}
     * @return the text of the first generation
     * @throws RuntimeException if the calling thread is interrupted while waiting for the
     *                          streamed output; the thread's interrupt flag is restored first
     */
    @Override
    public String predict(String text, List stops) {
        LLMResult llmResult = generate(Collections.singletonList(text), stops);
        Optional.ofNullable(llmResult.getStreamOutputPromise()).ifPresent(p -> {
            try {
                // NOTE(review): the outcome of the wait is not inspected here, so a timeout is
                // not distinguished from completion — confirm this is intended.
                p.waitSafely(STREAM_OUTPUT_WAIT_TIMEOUT);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so code further up the stack can still observe
                // the interruption after we convert it to an unchecked exception.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while waiting for streamed LLM output", e);
            }
        });
        return llmResult.getGenerations().get(0).get(0).getText();
    }

    /**
     * Asynchronous form of {@link #predict(String)}.
     *
     * @param text the prompt text
     * @return a future completing with the predicted text
     */
    @Override
    public CompletableFuture predictAsync(String text) {
        return predictAsync(text, null);
    }

    /**
     * Runs {@link #predict(String, List)} via {@link CompletableFuture#supplyAsync}; no
     * executor is passed, so the JDK's default asynchronous executor is used.
     *
     * @param text  the prompt text
     * @param stops stop list, may be {@code null}
     * @return a future completing with the predicted text
     */
    @Override
    public CompletableFuture predictAsync(String text, List stops) {
        return CompletableFuture.supplyAsync(() -> predict(text, stops));
    }

    /**
     * Predicts the next message for a conversation with no stop list.
     *
     * @param messages the conversation so far
     * @return the predicted message
     */
    @Override
    public BaseMessage predictMessage(List messages) {
        return predictMessage(messages, null);
    }

    /**
     * Flattens the messages into a single text buffer (using the "Human" and "AI" labels),
     * runs {@link #predict(String, List)} on it and wraps the answer in an {@link AIMessage}.
     *
     * @param messages the conversation so far
     * @param stops    stop list, may be {@code null}
     * @return a new {@link AIMessage} whose content is the predicted text
     */
    @Override
    public BaseMessage predictMessage(List messages, List stops) {
        String text = BaseMessageUtils.getBufferString(messages, "Human", "AI");
        String content = predict(text, stops);
        AIMessage message = new AIMessage();
        message.setContent(content);
        return message;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy