All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo Maven / Gradle / Ivy

There is a newer version: 0.21.0
Show newest version
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Supplier;

import org.eclipse.microprofile.config.ConfigProvider;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig;
import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser;
import io.quarkus.arc.impl.LazyValue;
import io.quarkus.runtime.annotations.RecordableConstructor;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public final class AiServiceMethodCreateInfo {
    private final String interfaceName;
    private final String methodName;
    private final Optional systemMessageInfo;
    private final UserMessageInfo userMessageInfo;
    private final Optional memoryIdParamPosition;
    private final boolean requiresModeration;
    private final String returnTypeSignature; // transient so bytecode recording ignores it
    private final transient LazyValue returnTypeVal; // transient so bytecode recording ignores it
    private final Optional metricsTimedInfo;
    private final Optional metricsCountedInfo;
    private final Optional spanInfo;
    // support @Toolbox
    private final List toolClassNames;
    private final ResponseSchemaInfo responseSchemaInfo;

    // support for guardrails
    private final List outputGuardrailsClassNames;
    private final List inputGuardrailsClassNames;

    // support for response augmenter, potentially null
    private final String responseAugmenterClassName;

    // these are populated when the AiService method is first called which can happen on any thread
    private transient final List toolSpecifications = new CopyOnWriteArrayList<>();
    private transient final Map toolExecutors = new ConcurrentHashMap<>();

    // Don't cache the instances, because of scope issues (some will need to be re-queried)
    private transient final List> outputGuardrails = new CopyOnWriteArrayList<>();
    private transient final List> inputGuardrails = new CopyOnWriteArrayList<>();
    private transient Class> augmenter;

    private final String outputTokenAccumulatorClassName;
    private OutputTokenAccumulator accumulator;

    private final LazyValue guardrailsMaxRetry;
    private final boolean switchToWorkerThread;

    @RecordableConstructor
    public AiServiceMethodCreateInfo(String interfaceName, String methodName,
            Optional systemMessageInfo,
            UserMessageInfo userMessageInfo,
            Optional memoryIdParamPosition,
            boolean requiresModeration,
            String returnTypeSignature,
            Optional metricsTimedInfo,
            Optional metricsCountedInfo,
            Optional spanInfo,
            ResponseSchemaInfo responseSchemaInfo,
            List toolClassNames,
            boolean switchToWorkerThread,
            List inputGuardrailsClassNames,
            List outputGuardrailsClassNames,
            String outputTokenAccumulatorClassName,
            String responseAugmenterClassName) {
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.systemMessageInfo = systemMessageInfo;
        this.userMessageInfo = userMessageInfo;
        this.memoryIdParamPosition = memoryIdParamPosition;
        this.requiresModeration = requiresModeration;
        this.returnTypeSignature = returnTypeSignature;
        this.returnTypeVal = new LazyValue<>(new Supplier<>() {
            @Override
            public Type get() {
                return TypeSignatureParser.parse(returnTypeSignature);
            }
        });
        this.metricsTimedInfo = metricsTimedInfo;
        this.metricsCountedInfo = metricsCountedInfo;
        this.spanInfo = spanInfo;
        this.responseSchemaInfo = responseSchemaInfo;
        this.toolClassNames = toolClassNames;
        this.inputGuardrailsClassNames = inputGuardrailsClassNames;
        this.outputGuardrailsClassNames = outputGuardrailsClassNames;
        this.outputTokenAccumulatorClassName = outputTokenAccumulatorClassName;
        // Use a lazy value to get the value at runtime.
        this.guardrailsMaxRetry = new LazyValue(new Supplier() {
            @Override
            public Integer get() {
                return ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.guardrails.max-retries", Integer.class)
                        .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
            }
        });
        this.switchToWorkerThread = switchToWorkerThread;
        this.responseAugmenterClassName = responseAugmenterClassName;
    }

    public String getInterfaceName() {
        return interfaceName;
    }

    public String getMethodName() {
        return methodName;
    }

    public Optional getSystemMessageInfo() {
        return systemMessageInfo;
    }

    public UserMessageInfo getUserMessageInfo() {
        return userMessageInfo;
    }

    public Optional getMemoryIdParamPosition() {
        return memoryIdParamPosition;
    }

    public boolean isRequiresModeration() {
        return requiresModeration;
    }

    public String getReturnTypeSignature() {
        return returnTypeSignature;
    }

    public Type getReturnType() {
        return returnTypeVal.get();
    }

    public Optional getMetricsTimedInfo() {
        return metricsTimedInfo;
    }

    public Optional getMetricsCountedInfo() {
        return metricsCountedInfo;
    }

    public Optional getSpanInfo() {
        return spanInfo;
    }

    public ResponseSchemaInfo getResponseSchemaInfo() {
        return responseSchemaInfo;
    }

    public List getToolClassNames() {
        return toolClassNames;
    }

    public List getToolSpecifications() {
        return toolSpecifications;
    }

    public Map getToolExecutors() {
        return toolExecutors;
    }

    public List getOutputGuardrailsClassNames() {
        return outputGuardrailsClassNames;
    }

    public List> getOutputGuardrailsClasses() {
        return outputGuardrails;
    }

    public String getResponseAugmenterClassName() {
        return responseAugmenterClassName;
    }

    @SuppressWarnings("unchecked")
    public Class> getResponseAugmenter() {
        if (this.responseAugmenterClassName == null) {
            return null;
        }

        synchronized (this) {
            if (this.augmenter == null) { // Not loaded yet.
                try {
                    this.augmenter = (Class>) Class.forName(
                            getResponseAugmenterClassName(), true,
                            Thread.currentThread().getContextClassLoader());
                } catch (Exception e) {
                    throw new RuntimeException(
                            "Could not find " + AiResponseAugmenter.class.getSimpleName() + " implementation class: "
                                    + getResponseAugmenterClassName(),
                            e);
                }
            }
            return augmenter;
        }
    }

    public int getGuardrailsMaxRetry() {
        return guardrailsMaxRetry.get();
    }

    public List getInputGuardrailsClassNames() {
        return inputGuardrailsClassNames;
    }

    public List> getInputGuardrailsClasses() {
        return inputGuardrails;
    }

    public String getOutputTokenAccumulatorClassName() {
        return outputTokenAccumulatorClassName;
    }

    public void setOutputTokenAccumulator(OutputTokenAccumulator accumulator) {
        this.accumulator = accumulator;
    }

    public OutputTokenAccumulator getOutputTokenAccumulator() {
        return accumulator;
    }

    public String getUserMessageTemplate() {
        Optional userMessageTemplateOpt = this.getUserMessageInfo().template()
                .flatMap(AiServiceMethodCreateInfo.TemplateInfo::text);

        return userMessageTemplateOpt.orElse("");
    }

    public boolean isSwitchToWorkerThread() {
        return switchToWorkerThread;
    }

    public void setResponseAugmenter(Class> augmenter) {
        this.augmenter = augmenter;
    }

    public record UserMessageInfo(Optional template,
            Optional paramPosition,
            Optional userNameParamPosition,
            Optional imageParamPosition) {

        public static UserMessageInfo fromMethodParam(int paramPosition, Optional userNameParamPosition,
                Optional imageParamPosition) {
            return new UserMessageInfo(Optional.empty(), Optional.of(paramPosition),
                    userNameParamPosition, imageParamPosition);
        }

        public static UserMessageInfo fromTemplate(TemplateInfo templateInfo, Optional userNameParamPosition,
                Optional imageUrlParamPosition) {
            return new UserMessageInfo(Optional.of(templateInfo), Optional.empty(), userNameParamPosition,
                    imageUrlParamPosition);
        }
    }

    /**
     * @param methodParamPosition this is used to determine the position of the parameter that holds the template, and it is
     *        never set if 'text' is set
     */
    public record TemplateInfo(Optional text, Map nameToParamPosition,
            Optional methodParamPosition) {

        public static TemplateInfo fromText(String text, Map nameToParamPosition) {
            return new TemplateInfo(Optional.of(text), nameToParamPosition, Optional.empty());
        }

        public static TemplateInfo fromMethodParam(Integer methodParamPosition,
                Map nameToParamPosition) {
            return new TemplateInfo(Optional.empty(), nameToParamPosition, Optional.of(methodParamPosition));
        }
    }

    public record MetricsTimedInfo(String name,
            boolean longTask,
            String[] extraTags,
            double[] percentiles,
            boolean histogram, String description) {

        public static class Builder {
            private final String name;
            private boolean longTask = false;
            private String[] extraTags = {};
            private double[] percentiles = {};
            private boolean histogram = false;
            private String description = "";

            public Builder(String name) {
                this.name = name;
            }

            public Builder setLongTask(boolean longTask) {
                this.longTask = longTask;
                return this;
            }

            public Builder setExtraTags(String[] extraTags) {
                this.extraTags = extraTags;
                return this;
            }

            public Builder setPercentiles(double[] percentiles) {
                this.percentiles = percentiles;
                return this;
            }

            public Builder setHistogram(boolean histogram) {
                this.histogram = histogram;
                return this;
            }

            public Builder setDescription(String description) {
                this.description = description;
                return this;
            }

            public MetricsTimedInfo build() {
                return new MetricsTimedInfo(name, longTask, extraTags, percentiles, histogram,
                        description);
            }
        }
    }

    public record MetricsCountedInfo(String name,
            String[] extraTags,
            boolean recordFailuresOnly,
            String description) {

        public static class Builder {
            private final String name;
            private String[] extraTags = {};
            private boolean recordFailuresOnly = false;
            private String description = "";

            public Builder(String name) {
                this.name = name;
            }

            public Builder setExtraTags(String[] extraTags) {
                this.extraTags = extraTags;
                return this;
            }

            public Builder setRecordFailuresOnly(boolean recordFailuresOnly) {
                this.recordFailuresOnly = recordFailuresOnly;
                return this;
            }

            public Builder setDescription(String description) {
                this.description = description;
                return this;
            }

            public MetricsCountedInfo build() {
                return new MetricsCountedInfo(name, extraTags, recordFailuresOnly, description);
            }
        }
    }

    public record SpanInfo(String name) {
    }

    public record ResponseSchemaInfo(boolean enabled, boolean isInSystemMessage, Optional isInUserMessage,
            String outputFormatInstructions) {

        public static ResponseSchemaInfo of(boolean enabled, Optional systemMessageInfo,
                Optional userMessageInfo,
                String outputFormatInstructions) {

            boolean systemMessage = systemMessageInfo.flatMap(TemplateInfo::text)
                    .map(text -> text.contains(ResponseSchemaUtil.placeholder()))
                    .orElse(false);

            Optional userMessage = Optional.empty();
            if (userMessageInfo.isPresent() && userMessageInfo.get().text.isPresent()) {
                userMessage = Optional.of(userMessageInfo.get().text.get().contains(ResponseSchemaUtil.placeholder()));
            }

            return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy