package com.dingtalk.baymax.framework.sdk.mercury.model;

import com.dingtalk.baymax.framework.sdk.mercury.domain.AIMessage;
import com.dingtalk.baymax.framework.sdk.mercury.domain.BaseMessage;
import com.dingtalk.baymax.framework.sdk.mercury.prompt.PromptValue;
import com.dingtalk.baymax.framework.sdk.mercury.util.BaseMessageUtils;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
 * An LLM wrapper should accept a prompt and return a string.
 *
 * @author xiaoxuan.lp
 *
 * Moved generate and generateAsync down into BaseLLM and
 * added the predictMessage method.
 *
 * @author xiaoyan.wjw
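 *
 * <p>Illustrative usage (a hedged sketch; {@code MyLLM} stands in for any
 * concrete subclass and is not part of this SDK):</p>
 * <pre>{@code
 * BaseLLM llm = new MyLLM();
 *
 * // Single-turn completion; stop words may be omitted.
 * String answer = llm.predict("Translate 'hello' to French.");
 *
 * // Batch generation with explicit stop words.
 * LLMResult result = llm.generate(
 *         Arrays.asList("prompt one", "prompt two"),
 *         Arrays.asList("\n"));
 * }</pre>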
*/
public abstract class BaseLLM extends BaseLanguageModel {
    /**
     * Takes a list of prompt strings and returns an LLMResult.
     *
     * @param prompts the prompt strings to complete
     * @param stops   optional stop words that truncate generation; may be null
     * @return the generations produced for each prompt
     */
    public abstract LLMResult generate(List<String> prompts, List<String> stops);

    /**
     * Streaming variant of {@link #generate(List, List)}. The default
     * implementation is a no-op; subclasses that support streaming should
     * override it and emit results through the observer.
     */
    public void generateStream(List<String> prompts, List<String> stops, StreamObserver streamObserver) {
    }

    /**
     * Asynchronous variant: takes a list of prompt strings and returns a
     * future that completes with the LLMResult.
     *
     * @param prompts the prompt strings to complete
     * @param stops   optional stop words that truncate generation; may be null
     * @return a future holding the generation results
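     *
     * <p>Illustrative call (a sketch; {@code llm} is assumed to be some
     * concrete subclass instance):</p>
     * <pre>{@code
     * CompletableFuture<LLMResult> future =
     *         llm.generateAsync(Arrays.asList("prompt"), null);
     * LLMResult result = future.join(); // blocks until generation completes
     * }</pre>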
     */
    public abstract CompletableFuture<LLMResult> generateAsync(List<String> prompts, List<String> stops);

    @Override
    public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stops) {
        List<String> promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        return generate(promptStrings, stops);
    }

    @Override
    public void generatePrompt(List<PromptValue> prompts, List<String> stops, StreamObserver streamObserver) {
        List<String> promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        generateStream(promptStrings, stops, streamObserver);
    }

    @Override
    public CompletableFuture<LLMResult> generatePromptAsync(List<PromptValue> prompts, List<String> stops) {
        List<String> promptStrings = prompts.stream().map(PromptValue::toString).collect(Collectors.toList());
        return generateAsync(promptStrings, stops);
    }

    @Override
    public String predict(String text) {
        return predict(text, null);
    }

    @Override
    public String predict(String text, List<String> stops) {
        LLMResult llmResult = generate(Collections.singletonList(text), stops);
        // For streaming models the generations are filled in asynchronously;
        // wait (up to 60s) for the stream to finish before reading the text.
        Optional.ofNullable(llmResult.getStreamOutputPromise()).ifPresent(p -> {
            try {
                p.waitSafely(60000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        });
        return llmResult.getGenerations().get(0).get(0).getText();
    }

    @Override
    public CompletableFuture<String> predictAsync(String text) {
        return predictAsync(text, null);
    }

    @Override
    public CompletableFuture<String> predictAsync(String text, List<String> stops) {
        return CompletableFuture.supplyAsync(() -> predict(text, stops));
    }

    @Override
    public BaseMessage predictMessage(List<BaseMessage> messages) {
        return predictMessage(messages, null);
    }

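    /**
     * Flattens the message history into a single "Human"/"AI" transcript and
     * delegates to {@link #predict(String, List)}, wrapping the reply in an
     * {@link AIMessage}.
     *
     * <p>Illustrative sketch of the flattening (the exact layout is up to
     * {@code BaseMessageUtils.getBufferString}):</p>
     * <pre>{@code
     * // A history of [human "Hi", AI "Hello!"] becomes roughly:
     * // Human: Hi
     * // AI: Hello!
     * }</pre>
     */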
    @Override
    public BaseMessage predictMessage(List<BaseMessage> messages, List<String> stops) {
        // Render the conversation as one prompt string with role prefixes.
        String text = BaseMessageUtils.getBufferString(messages, "Human", "AI");
        String content = predict(text, stops);
        AIMessage message = new AIMessage();
        message.setContent(content);
        return message;
    }
}