All downloads are free. Search and download functionality uses the official Maven repository.

com.dingtalk.baymax.framework.sdk.mercury.chain.Chain Maven / Gradle / Ivy

There is a newer version: 1.0.2
Show newest version
package com.dingtalk.baymax.framework.sdk.mercury.chain;

import com.alibaba.fastjson.annotation.JSONField;
import com.dingtalk.baymax.framework.sdk.mercury.annotation.ChainField;
import com.dingtalk.baymax.framework.sdk.mercury.constant.Constants;
import com.dingtalk.baymax.framework.sdk.mercury.domain.BaseModel;
import com.dingtalk.baymax.framework.sdk.mercury.interactive.BaseInteractive;
import com.dingtalk.baymax.framework.sdk.mercury.memory.BaseMemory;
import com.google.common.collect.ImmutableMap;
import org.jdeferred2.Promise;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static com.dingtalk.baymax.framework.sdk.mercury.constant.Constants.CONV_USER_OUT_MSG_KEY;
import static com.dingtalk.baymax.framework.sdk.mercury.constant.Constants.SKILL_OUTPUT_PROMISE_KEY;

/**
 * 所有链都应该实现的基类
 *
 * @author xiaoxuan.lp
 */
public abstract class Chain extends BaseModel {

    private static final Logger log = LoggerFactory.getLogger(Chain.class);

    @ChainField(label = "AI卡片 | AI Card", name = "interactive", optional = true)
    protected BaseInteractive interactive;

    /**
     * 内存
     */
    @ChainField(label = "记忆 | Memory", name = "memory")
    private BaseMemory memory;

    @ChainField(label = "是否自动加载记忆", name = "autoLoadMemory")
    private Boolean autoLoadMemory;

    /**
     * 是否打印
     */
    private boolean verbose;

    @Override
    public void init() {
        super.init();
        if (null != memory) {
            memory.init();
        }
    }

    @Override
    public void setup(Map context) {
        super.setup(context);

        if (null != memory) {
            memory.setup(context);
        }
    }

    @JSONField(serialize = false)
    public abstract List getInputKeys();

    @JSONField(serialize = false)
    public abstract List getOutputKeys();

    /**
     * 直接请求,不带记忆功能
     *
     * @param inputs
     * @return
     */
    protected abstract Map call(Map inputs);

    /**
     * 异步方式直接请求,不带记忆功能
     *
     * @param inputs
     * @return
     */
    protected abstract CompletableFuture> callAsync(Map inputs);

    /**
     * 带记忆功能
     *
     * @param inputs
     * @return
     */
    public Map run(Map inputs) {
        return run(inputs, false);
    }


    /**
     * 带记忆功能
     *
     * @param inputs
     * @param returnOnlyOutputs
     * @return
     */
    public Map run(Map inputs, boolean returnOnlyOutputs) {
        Map fullInputs = prepInputs(inputs);
        Map outputs = call(fullInputs);
        Map finalOutputs = prepOutputs(inputs, outputs, returnOnlyOutputs);
        return finalOutputs;
    }

    /**
     * 带记忆功能
     *
     * @param inputs
     * @return
     */
    public CompletableFuture> runAsync(Map inputs) {
        return runAsync(inputs, false);
    }

    /**
     * 带记忆功能
     *
     * @param inputs
     * @param returnOnlyOutputs
     * @return
     */
    public CompletableFuture> runAsync(Map inputs, boolean returnOnlyOutputs) {
        return CompletableFuture.supplyAsync(() -> run(inputs, returnOnlyOutputs));
    }

    private Map prepInputs(Map inputs) {
        if (memory != null && (null == autoLoadMemory || autoLoadMemory)) {
            Map external_context = memory.loadMemoryVariables(inputs);
            if (null != external_context) {
                inputs.putAll(external_context);
            }
        }
        return inputs;
    }

    private Map prepOutputs(Map inputs,
                                            Map outputs,
                                            boolean returnOnlyOutputs) {
        if (outputs instanceof ImmutableMap) {
            outputs = new HashMap<>(outputs);
        }
        if (memory != null && (null == autoLoadMemory || autoLoadMemory)) {
            final Map finalOutputs = outputs;
            Object p = outputs.get(SKILL_OUTPUT_PROMISE_KEY);
            if (p instanceof Promise) {
                final Promise promise = (Promise) p;

//                    final boolean isSanduo = Optional.ofNullable(inputs.get(Constants.CONTEXT_KEY))
//                            .map(Map.class::cast)
//                            .map(ctx -> ctx.get("botUid"))
//                            .map(Long.valueOf(572309721L)::equals).orElse(false);
//                    if (!isSanduo) {
//                        // 钉三多特殊逻辑, 不等待打字机完整输出
//                        promise.waitSafely();
//                    }
                promise.done(o -> {
                    finalOutputs.merge(CONV_USER_OUT_MSG_KEY, o, (oldVal, newVal) -> {
                        log.warn("MergeStrategies:overwrite, key: {}, oldVal: {}, newVal: {}", CONV_USER_OUT_MSG_KEY, oldVal, newVal);
                        return newVal;
                    });
                    finalOutputs.remove(SKILL_OUTPUT_PROMISE_KEY);
                    memory.saveContext(inputs, finalOutputs);
                }).fail(e -> {
                    throw new RuntimeException(e);
                });
            }
            memory.saveContext(inputs, outputs);
            String input = (String) inputs.get(Constants.CONV_USER_IN_MSG_KEY);
            if (BaseMemory.MAGIC_CMD_CLEAR_MEMORY.equals(input)) {
                memory.clear(inputs);
                outputs = ImmutableMap.of(Constants.CONV_USER_OUT_MSG_KEY, BaseMemory.MAGIC_RESPONSE_CLEAR_MEMORY);
            }

        }
        if (returnOnlyOutputs) {
            return outputs;
        } else {
            Map map = new HashMap<>();
            map.putAll(inputs);
            map.putAll(outputs);
            return map;
        }
    }


    public BaseInteractive getInteractive() {
        return interactive;
    }

    public void setInteractive(BaseInteractive interactive) {
        this.interactive = interactive;
    }

    public BaseMemory getMemory() {
        return memory;
    }

    public void setMemory(BaseMemory memory) {
        this.memory = memory;
    }

    public Boolean getAutoLoadMemory() {
        return autoLoadMemory;
    }

    public void setAutoLoadMemory(Boolean autoLoadMemory) {
        this.autoLoadMemory = autoLoadMemory;
    }

    public boolean isVerbose() {
        return verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy