All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hw.langchain.chains.base.Chain Maven / Gradle / Ivy

There is a newer version: 0.2.2
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hw.langchain.chains.base;

import com.hw.langchain.schema.BaseMemory;

import java.util.*;

/**
 * Base interface that all chains should implement.
 *
 * @author HamaWhite
 */
public abstract class Chain {

    protected BaseMemory memory;

    public abstract String chainType();

    /**
     * Input keys this chain expects.
     */
    public abstract List inputKeys();

    /**
     * Output keys this chain expects.
     */
    public abstract List outputKeys();

    /**
     * Check that all inputs are present
     */
    private void validateInputs(Map inputs) {
        Set missingKeys = new HashSet<>(inputKeys());
        missingKeys.removeAll(inputs.keySet());
        if (!missingKeys.isEmpty()) {
            throw new IllegalArgumentException(String.format("Missing some input keys: %s", missingKeys));
        }
    }

    private void validateOutputs(Map outputs) {
        Set missingKeys = new HashSet<>(outputKeys());
        missingKeys.removeAll(outputs.keySet());
        if (!missingKeys.isEmpty()) {
            throw new IllegalArgumentException(String.format("Missing some output keys: %s", missingKeys));
        }
    }

    /**
     * Run the logic of this chain and return the output.
     */
    public abstract Map _call(Map inputs);

    /**
     * Run the logic of this chain and add to output if desired.
     *
     * @param input             single input if chain expects only one param.
     * @param returnOnlyOutputs boolean for whether to return only outputs in the response.
     *                          If True, only new keys generated by this chain will be returned.
     *                          If False, both input keys and new keys generated by this chain will be returned.
     *                          Defaults to False.
     */
    public Map call(String input, boolean returnOnlyOutputs) {
        Map inputs = prepInputs(input);
        return call(inputs, returnOnlyOutputs);
    }

    /**
     * Run the logic of this chain and add to output if desired.
     *
     * @param inputs            Dictionary of inputs.
     * @param returnOnlyOutputs boolean for whether to return only outputs in the response.
     *                          If True, only new keys generated by this chain will be returned.
     *                          If False, both input keys and new keys generated by this chain will be returned.
     *                          Defaults to False.
     */
    public Map call(Map inputs, boolean returnOnlyOutputs) {
        inputs = prepInputs(inputs);
        Map outputs = _call(inputs);
        return prepOutputs(inputs, outputs, returnOnlyOutputs);
    }

    /**
     * Validate and prep outputs.
     */
    private Map prepOutputs(Map inputs, Map outputs,
            boolean returnOnlyOutputs) {
        validateOutputs(outputs);
        if (memory != null) {
            memory.saveContext(inputs, outputs);
        }
        if (returnOnlyOutputs) {
            return outputs;
        } else {
            Map result = new HashMap<>();
            inputs.forEach((k, v) -> result.put(k, v.toString()));
            result.putAll(outputs);
            return result;
        }
    }

    /**
     * Validate and prep inputs.
     */
    private Map prepInputs(String input) {
        Set inputKeys = new HashSet<>(inputKeys());
        if (memory != null) {
            // If there are multiple input keys, but some get set by memory so that only one is not set,
            // we can still figure out which key it is.
            Set memoryVariables = new HashSet<>(memory.memoryVariables());
            inputKeys.removeAll(memoryVariables);
        }
        if (inputKeys.size() != 1) {
            throw new IllegalArgumentException(
                    String.format(
                            "A single string input was passed in, but this chain expects multiple inputs (%s). " +
                                    "When a chain expects multiple inputs, please call it by passing in a dictionary, "
                                    +
                                    "eg `chain(Map.of('foo', 1, 'bar', 2))`",
                            inputKeys));
        }
        return Map.of(new ArrayList<>(inputKeys).get(0), input);
    }

    /**
     * Validate and prep inputs.
     */
    private Map prepInputs(Map inputs) {
        Map newInputs = new HashMap<>(inputs);
        if (memory != null) {
            Map externalContext = memory.loadMemoryVariables(inputs);
            newInputs.putAll(externalContext);
        }
        validateInputs(newInputs);
        return newInputs;
    }

    /**
     * Run the chain as text in, text out
     */
    public String run(String args) {
        if (outputKeys().size() != 1) {
            throw new IllegalArgumentException(
                    "The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
                            + ".");
        }
        return call(args, false).get(outputKeys().get(0));
    }

    /**
     * Run the chain as multiple variables, text out.
     */
    public String run(Map args) {
        if (outputKeys().size() != 1) {
            throw new IllegalArgumentException(
                    "The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
                            + ".");
        }
        return call(args, false).get(outputKeys().get(0));
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy