/*
 * All downloads are free. The search and download functionality uses the official Maven repository.
 *
 * io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo Maven / Gradle / Ivy
 *
 * There is a newer version: 0.21.0
 * Show newest version
 */
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Supplier;

import org.eclipse.microprofile.config.ConfigProvider;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig;
import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser;
import io.quarkus.arc.impl.LazyValue;
import io.quarkus.runtime.annotations.RecordableConstructor;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public final class AiServiceMethodCreateInfo {
    private final String interfaceName;
    private final String methodName;
    private final Optional systemMessageInfo;
    private final UserMessageInfo userMessageInfo;
    private final Optional memoryIdParamPosition;
    private final boolean requiresModeration;
    private final String returnTypeSignature; // transient so bytecode recording ignores it
    private final transient LazyValue returnTypeVal; // transient so bytecode recording ignores it
    private final Optional metricsTimedInfo;
    private final Optional metricsCountedInfo;
    private final Optional spanInfo;
    // support @Toolbox
    private final List toolClassNames;
    private final ResponseSchemaInfo responseSchemaInfo;

    // support for guardrails
    private final List outputGuardrailsClassNames;
    private final List inputGuardrailsClassNames;

    // these are populated when the AiService method is first called which can happen on any thread
    private transient final List toolSpecifications = new CopyOnWriteArrayList<>();
    private transient final Map toolExecutors = new ConcurrentHashMap<>();

    // Don't cache the instances, because of scope issues (some will need to be re-queried)
    private transient final List> outputGuardrails = new CopyOnWriteArrayList<>();
    private transient final List> inputGuardrails = new CopyOnWriteArrayList<>();

    private final LazyValue guardrailsMaxRetry;

    @RecordableConstructor
    public AiServiceMethodCreateInfo(String interfaceName, String methodName,
            Optional systemMessageInfo,
            UserMessageInfo userMessageInfo,
            Optional memoryIdParamPosition,
            boolean requiresModeration,
            String returnTypeSignature,
            Optional metricsTimedInfo,
            Optional metricsCountedInfo,
            Optional spanInfo,
            ResponseSchemaInfo responseSchemaInfo,
            List toolClassNames,
            List inputGuardrailsClassNames,
            List outputGuardrailsClassNames) {
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.systemMessageInfo = systemMessageInfo;
        this.userMessageInfo = userMessageInfo;
        this.memoryIdParamPosition = memoryIdParamPosition;
        this.requiresModeration = requiresModeration;
        this.returnTypeSignature = returnTypeSignature;
        this.returnTypeVal = new LazyValue<>(new Supplier<>() {
            @Override
            public Type get() {
                return TypeSignatureParser.parse(returnTypeSignature);
            }
        });
        this.metricsTimedInfo = metricsTimedInfo;
        this.metricsCountedInfo = metricsCountedInfo;
        this.spanInfo = spanInfo;
        this.responseSchemaInfo = responseSchemaInfo;
        this.toolClassNames = toolClassNames;
        this.inputGuardrailsClassNames = inputGuardrailsClassNames;
        this.outputGuardrailsClassNames = outputGuardrailsClassNames;
        // Use a lazy value to get the value at runtime.
        this.guardrailsMaxRetry = new LazyValue(new Supplier() {
            @Override
            public Integer get() {
                return ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.guardrails.max-retries", Integer.class)
                        .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
            }
        });
    }

    public String getInterfaceName() {
        return interfaceName;
    }

    public String getMethodName() {
        return methodName;
    }

    public Optional getSystemMessageInfo() {
        return systemMessageInfo;
    }

    public UserMessageInfo getUserMessageInfo() {
        return userMessageInfo;
    }

    public Optional getMemoryIdParamPosition() {
        return memoryIdParamPosition;
    }

    public boolean isRequiresModeration() {
        return requiresModeration;
    }

    public String getReturnTypeSignature() {
        return returnTypeSignature;
    }

    public Type getReturnType() {
        return returnTypeVal.get();
    }

    public Optional getMetricsTimedInfo() {
        return metricsTimedInfo;
    }

    public Optional getMetricsCountedInfo() {
        return metricsCountedInfo;
    }

    public Optional getSpanInfo() {
        return spanInfo;
    }

    public ResponseSchemaInfo getResponseSchemaInfo() {
        return responseSchemaInfo;
    }

    public List getToolClassNames() {
        return toolClassNames;
    }

    public List getToolSpecifications() {
        return toolSpecifications;
    }

    public Map getToolExecutors() {
        return toolExecutors;
    }

    public List getOutputGuardrailsClassNames() {
        return outputGuardrailsClassNames;
    }

    public List> getOutputGuardrailsClasses() {
        return outputGuardrails;
    }

    public int getGuardrailsMaxRetry() {
        return guardrailsMaxRetry.get();
    }

    public List getInputGuardrailsClassNames() {
        return inputGuardrailsClassNames;
    }

    public List> getInputGuardrailsClasses() {
        return inputGuardrails;
    }

    public record UserMessageInfo(Optional template,
            Optional paramPosition,
            Optional userNameParamPosition,
            Optional imageParamPosition) {

        public static UserMessageInfo fromMethodParam(int paramPosition, Optional userNameParamPosition,
                Optional imageParamPosition) {
            return new UserMessageInfo(Optional.empty(), Optional.of(paramPosition),
                    userNameParamPosition, imageParamPosition);
        }

        public static UserMessageInfo fromTemplate(TemplateInfo templateInfo, Optional userNameParamPosition,
                Optional imageUrlParamPosition) {
            return new UserMessageInfo(Optional.of(templateInfo), Optional.empty(), userNameParamPosition,
                    imageUrlParamPosition);
        }
    }

    /**
     * @param methodParamPosition this is used to determine the position of the parameter that holds the template, and it is
     *        never set if 'text' is set
     */
    public record TemplateInfo(Optional text, Map nameToParamPosition,
            Optional methodParamPosition) {

        public static TemplateInfo fromText(String text, Map nameToParamPosition) {
            return new TemplateInfo(Optional.of(text), nameToParamPosition, Optional.empty());
        }

        public static TemplateInfo fromMethodParam(Integer methodParamPosition,
                Map nameToParamPosition) {
            return new TemplateInfo(Optional.empty(), nameToParamPosition, Optional.of(methodParamPosition));
        }
    }

    public record MetricsTimedInfo(String name,
            boolean longTask,
            String[] extraTags,
            double[] percentiles,
            boolean histogram, String description) {

        public static class Builder {
            private final String name;
            private boolean longTask = false;
            private String[] extraTags = {};
            private double[] percentiles = {};
            private boolean histogram = false;
            private String description = "";

            public Builder(String name) {
                this.name = name;
            }

            public Builder setLongTask(boolean longTask) {
                this.longTask = longTask;
                return this;
            }

            public Builder setExtraTags(String[] extraTags) {
                this.extraTags = extraTags;
                return this;
            }

            public Builder setPercentiles(double[] percentiles) {
                this.percentiles = percentiles;
                return this;
            }

            public Builder setHistogram(boolean histogram) {
                this.histogram = histogram;
                return this;
            }

            public Builder setDescription(String description) {
                this.description = description;
                return this;
            }

            public MetricsTimedInfo build() {
                return new MetricsTimedInfo(name, longTask, extraTags, percentiles, histogram,
                        description);
            }
        }
    }

    public record MetricsCountedInfo(String name,
            String[] extraTags,
            boolean recordFailuresOnly,
            String description) {

        public static class Builder {
            private final String name;
            private String[] extraTags = {};
            private boolean recordFailuresOnly = false;
            private String description = "";

            public Builder(String name) {
                this.name = name;
            }

            public Builder setExtraTags(String[] extraTags) {
                this.extraTags = extraTags;
                return this;
            }

            public Builder setRecordFailuresOnly(boolean recordFailuresOnly) {
                this.recordFailuresOnly = recordFailuresOnly;
                return this;
            }

            public Builder setDescription(String description) {
                this.description = description;
                return this;
            }

            public MetricsCountedInfo build() {
                return new MetricsCountedInfo(name, extraTags, recordFailuresOnly, description);
            }
        }
    }

    public record SpanInfo(String name) {
    }

    public record ResponseSchemaInfo(boolean enabled, boolean isInSystemMessage, Optional isInUserMessage,
            String outputFormatInstructions) {

        public static ResponseSchemaInfo of(boolean enabled, Optional systemMessageInfo,
                Optional userMessageInfo,
                String outputFormatInstructions) {

            boolean systemMessage = systemMessageInfo.flatMap(TemplateInfo::text)
                    .map(text -> text.contains(ResponseSchemaUtil.placeholder()))
                    .orElse(false);

            Optional userMessage = Optional.empty();
            if (userMessageInfo.isPresent() && userMessageInfo.get().text.isPresent()) {
                userMessage = Optional.of(userMessageInfo.get().text.get().contains(ResponseSchemaUtil.placeholder()));
            }

            return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions);
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy