// io.quarkiverse.langchain4j.deployment.AiServicesProcessor (Maven / Gradle / Ivy)
package io.quarkiverse.langchain4j.deployment;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static io.quarkiverse.langchain4j.deployment.ExceptionUtil.illegalConfigurationForMethod;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MEMORY_ID;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.Dependent;
import jakarta.inject.Inject;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.analysis.AnalyzerException;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.ToolBox;
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
import io.quarkiverse.langchain4j.runtime.RequestScopeStateDefaultMemoryIdProvider;
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo.ResponseSchemaInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemorySeeder;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsCountedWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsTimedWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.HotDeploymentWatchedFileBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
import io.quarkus.gizmo.FieldDescriptor;
import io.quarkus.gizmo.Gizmo;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.smallrye.mutiny.Multi;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AiServicesProcessor {
private static final Logger log = Logger.getLogger(AiServicesProcessor.class);

/** Name of Micrometer's {@code @Timed} annotation, referenced by name rather than by class literal. */
public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed");
/** Name of Micrometer's {@code @Counted} annotation, referenced by name rather than by class literal. */
public static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted");

private static final String DEFAULT_DELIMITER = "\n";

/** Matches annotation instances whose target is a method parameter. */
private static final Predicate<AnnotationInstance> IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target()
        .kind() == AnnotationTarget.Kind.METHOD_PARAMETER;
/**
 * Maps a method-parameter annotation to the position of the annotated parameter.
 * Only valid for instances accepted by {@link #IS_METHOD_PARAMETER_ANNOTATION}.
 */
private static final Function<AnnotationInstance, Integer> METHOD_PARAMETER_POSITION_FUNCTION = ai -> Integer
        .valueOf(ai.target()
                .asMethodParameter().position());

public static final MethodDescriptor OBJECT_CONSTRUCTOR = MethodDescriptor.ofConstructor(Object.class);
/** {@code AiServicesRecorder.getAiServiceMethodCreateInfo(String, String)}, called from generated implementations. */
private static final MethodDescriptor RECORDER_METHOD_CREATE_INFO = MethodDescriptor.ofMethod(AiServicesRecorder.class,
        "getAiServiceMethodCreateInfo", AiServiceMethodCreateInfo.class, String.class, String.class);
/** {@code AiServiceMethodImplementationSupport.implement(Input)}. */
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
        AiServiceMethodImplementationSupport.class,
        "implement", Object.class, AiServiceMethodImplementationSupport.Input.class);
/** {@code QuarkusAiServiceContext.close()}. */
private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod(
        QuarkusAiServiceContext.class, "close", void.class);
/** {@code QuarkusAiServiceContext.removeChatMemoryIds(Object[])}. */
private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS = MethodDescriptor.ofMethod(
        QuarkusAiServiceContext.class, "removeChatMemoryIds", void.class, Object[].class);
/** {@code ChatMemorySeeder.Context.methodName()}. */
public static final MethodDescriptor CHAT_MEMORY_SEEDER_CONTEXT_METHOD_NAME = MethodDescriptor
        .ofMethod(ChatMemorySeeder.Context.class, "methodName", String.class);

private static final String METRICS_DEFAULT_NAME = "langchain4j.aiservices";

private static final Class<?>[] EMPTY_CLASS_ARRAY = new Class<?>[0];
private static final String[] EMPTY_STRING_ARRAY = new String[0];
private static final ResultHandle[] EMPTY_RESULT_HANDLES_ARRAY = new ResultHandle[0];

private static final ServiceOutputParser SERVICE_OUTPUT_PARSER = new QuarkusServiceOutputParser(); // TODO: this might need to be improved
/**
 * Produces the native-image registrations needed by AI services.
 * <ul>
 * <li>Every class declaring a field annotated with {@code LangChain4jDotNames.DESCRIPTION} is registered for
 * reflective field access.</li>
 * <li>The return type of every AI service method is registered for reflection, without constructors. Primitive,
 * {@code void} and {@code java.*} types are skipped.</li>
 * <li>{@link RequestScopeStateDefaultMemoryIdProvider} is registered as a {@link DefaultMemoryIdProvider} service
 * provider.</li>
 * </ul>
 *
 * @param indexBuildItem the combined Jandex index
 * @param aiServicesMethodBuildItems all AI service methods discovered during the build
 * @param reflectiveClassProducer receives the reflection registrations
 * @param serviceProviderProducer receives the service provider registration
 */
@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
        List<AiServicesMethodBuildItem> aiServicesMethodBuildItems,
        BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
        BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
    IndexView index = indexBuildItem.getIndex();

    Set<ClassInfo> classesUsingDescription = new HashSet<>();
    for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.DESCRIPTION)) {
        if (instance.target().kind() == AnnotationTarget.Kind.FIELD) {
            classesUsingDescription.add(instance.target().asField().declaringClass());
        }
    }
    if (!classesUsingDescription.isEmpty()) {
        reflectiveClassProducer.produce(ReflectiveClassBuildItem
                .builder(classesUsingDescription.stream().map(i -> i.name().toString()).toArray(String[]::new))
                .fields(true)
                .build());
    }

    Set<DotName> returnTypesToRegister = new HashSet<>();
    for (AiServicesMethodBuildItem aiServicesMethodBuildItem : aiServicesMethodBuildItems) {
        Type type = aiServicesMethodBuildItem.methodInfo.returnType();
        // primitives and void have no class that could be registered for reflection
        if (type.kind() == Type.Kind.PRIMITIVE || type.kind() == Type.Kind.VOID) {
            continue;
        }
        DotName returnTypeName = type.name();
        if (returnTypeName.toString().startsWith("java.")) {
            continue;
        }
        returnTypesToRegister.add(returnTypeName);
    }
    if (!returnTypesToRegister.isEmpty()) {
        reflectiveClassProducer.produce(ReflectiveClassBuildItem
                .builder(returnTypesToRegister.stream().map(DotName::toString).toArray(String[]::new))
                .constructors(false)
                .build());
    }

    serviceProviderProducer.produce(new ServiceProviderBuildItem(DefaultMemoryIdProvider.class.getName(),
            RequestScopeStateDefaultMemoryIdProvider.class.getName()));
}
/**
 * Turns every class annotated with {@code LangChain4jDotNames.REGISTER_AI_SERVICES} into a
 * {@link DeclarativeAiServiceBuildItem}.
 * <p>
 * For each annotated class this step:
 * <ul>
 * <li>resolves the configured suppliers (chat model, tools, chat memory provider, retriever, retrieval augmentor,
 * audit service and moderation model);</li>
 * <li>validates non-default suppliers and registers them for reflection;</li>
 * <li>determines the CDI scope: the scope annotation found by {@code customScopes}, otherwise request scope;</li>
 * <li>records which named chat and moderation models must be exposed as beans.</li>
 * </ul>
 *
 * @throws IllegalConfigurationException if both {@code retriever} and a non-default {@code retrievalAugmentor} are
 *         set, or if a validated supplier lacks a no-args constructor
 */
@BuildStep
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
        CustomScopeAnnotationsBuildItem customScopes,
        BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
        BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
        BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
        BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
        BuildProducer<GeneratedClassBuildItem> generatedClassProducer) {
    IndexView index = indexBuildItem.getIndex();

    Set<String> chatModelNames = new HashSet<>();
    Set<String> moderationModelNames = new HashSet<>();
    ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);

    for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
        if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
            continue; // should never happen
        }
        ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();

        // chat model: an explicit (non-default) supplier wins; otherwise a possibly named chat model bean is used
        DotName chatLanguageModelSupplierClassDotName = null;
        AnnotationValue chatLanguageModelSupplierValue = instance.value("chatLanguageModelSupplier");
        if (chatLanguageModelSupplierValue != null) {
            chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierValue.asClass().name();
            if (chatLanguageModelSupplierClassDotName.equals(LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER)) {
                // the default was set explicitly, so we just ignore it
                chatLanguageModelSupplierClassDotName = null;
            } else {
                validateSupplierAndRegisterForReflection(chatLanguageModelSupplierClassDotName, index,
                        reflectiveClassProducer);
            }
        }
        String chatModelName = NamedConfigUtil.DEFAULT_NAME;
        if (chatLanguageModelSupplierClassDotName == null) {
            chatModelName = readModelNameOrDefault(instance);
            chatModelNames.add(chatModelName);
        }

        List<DotName> toolDotNames = Collections.emptyList();
        AnnotationValue toolsInstance = instance.value("tools");
        if (toolsInstance != null) {
            toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
                    .collect(Collectors.toList());
        }

        // defaults to the ChatMemoryProvider bean supplier; the NO_CHAT_MEMORY_PROVIDER_SUPPLIER marker disables
        // chat memory (represented as null), any other supplier must be instantiable reflectively
        DotName chatMemoryProviderSupplierClassDotName = LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
        AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
        if (chatMemoryProviderSupplierValue != null) {
            chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
            if (chatMemoryProviderSupplierClassDotName.equals(
                    LangChain4jDotNames.NO_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
                chatMemoryProviderSupplierClassDotName = null;
            } else if (!chatMemoryProviderSupplierClassDotName
                    .equals(LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
                validateSupplierAndRegisterForReflection(chatMemoryProviderSupplierClassDotName, index,
                        reflectiveClassProducer);
            }
        }

        DotName retrieverClassDotName = null;
        AnnotationValue retrieverValue = instance.value("retriever");
        if (retrieverValue != null) {
            retrieverClassDotName = retrieverValue.asClass().name();
            if (NO_RETRIEVER.equals(retrieverClassDotName)) {
                retrieverClassDotName = null;
            }
        }

        boolean customRetrievalAugmentorSupplierClassIsABean = false;
        DotName retrievalAugmentorSupplierClassName = BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER;
        AnnotationValue retrievalAugmentorSupplierValue = instance.value("retrievalAugmentor");
        if (retrievalAugmentorSupplierValue != null) {
            DotName configuredAugmentorSupplier = retrievalAugmentorSupplierValue.asClass().name();
            if (NO_RETRIEVAL_AUGMENTOR_SUPPLIER.equals(configuredAugmentorSupplier)) {
                retrievalAugmentorSupplierClassName = null;
            } else if (!BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER.equals(configuredAugmentorSupplier)) {
                retrievalAugmentorSupplierClassName = configuredAugmentorSupplier;
                // a supplier declaring a built-in scope is used as a CDI bean; otherwise we must be able to
                // instantiate it. An un-indexed class is not a bean: the validation below only warns about it
                // (BuiltinScope.from would throw a NullPointerException on a missing ClassInfo).
                ClassInfo supplierClassInfo = index.getClassByName(configuredAugmentorSupplier);
                if (supplierClassInfo != null && BuiltinScope.from(supplierClassInfo) != null) {
                    customRetrievalAugmentorSupplierClassIsABean = true;
                } else {
                    validateSupplierAndRegisterForReflection(configuredAugmentorSupplier, index,
                            reflectiveClassProducer);
                }
            }
        }
        if (retrieverClassDotName != null && retrievalAugmentorSupplierClassName != null
                && !retrievalAugmentorSupplierClassName.equals(BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER)) {
            throw new IllegalConfigurationException("Both 'retriever' and 'retrievalAugmentor' are set for "
                    + declarativeAiServiceClassInfo.name().toString()
                    + ". Only one of them can be set.");
        }

        DotName auditServiceSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
        AnnotationValue auditServiceSupplierValue = instance.value("auditServiceSupplier");
        if (auditServiceSupplierValue != null) {
            auditServiceSupplierClassName = auditServiceSupplierValue.asClass().name();
            validateSupplierAndRegisterForReflection(auditServiceSupplierClassName, index, reflectiveClassProducer);
        }

        DotName moderationModelSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER;
        AnnotationValue moderationModelSupplierValue = instance.value("moderationModelSupplier");
        if (moderationModelSupplierValue != null) {
            moderationModelSupplierClassName = moderationModelSupplierValue.asClass().name();
            validateSupplierAndRegisterForReflection(moderationModelSupplierClassName, index, reflectiveClassProducer);
        }

        // a moderation model bean is only requested when some method is annotated with @Moderate
        // and the default (bean-based) moderation model supplier is in use
        String moderationModelName = NamedConfigUtil.DEFAULT_NAME;
        boolean usesModeration = declarativeAiServiceClassInfo.methods().stream()
                .anyMatch(method -> method.hasAnnotation(LangChain4jDotNames.MODERATE));
        if (usesModeration
                && moderationModelSupplierClassName.equals(LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER)) {
            moderationModelName = readModelNameOrDefault(instance);
            moderationModelNames.add(moderationModelName);
        }

        DotName cdiScope = customScopes.getScope(declarativeAiServiceClassInfo.annotations())
                .map(AnnotationInstance::name)
                .orElse(BuiltinScope.REQUEST.getInfo().getDotName());

        declarativeAiServiceProducer.produce(
                new DeclarativeAiServiceBuildItem(
                        declarativeAiServiceClassInfo,
                        chatLanguageModelSupplierClassDotName,
                        toolDotNames,
                        chatMemoryProviderSupplierClassDotName,
                        retrieverClassDotName,
                        retrievalAugmentorSupplierClassName,
                        customRetrievalAugmentorSupplierClassIsABean,
                        auditServiceSupplierClassName,
                        moderationModelSupplierClassName,
                        determineChatMemorySeeder(declarativeAiServiceClassInfo, generatedClassOutput),
                        cdiScope,
                        chatModelName,
                        moderationModelName));
    }

    for (String chatModelName : chatModelNames) {
        requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(chatModelName));
    }
    for (String moderationModelName : moderationModelNames) {
        requestModerationModelBeanProducer.produce(new RequestModerationModelBeanBuildItem(moderationModelName));
    }
}

/**
 * Returns the {@code modelName} attribute of the given annotation, or {@link NamedConfigUtil#DEFAULT_NAME} when the
 * attribute is absent, {@code null} or empty.
 */
private static String readModelNameOrDefault(AnnotationInstance instance) {
    AnnotationValue modelNameValue = instance.value("modelName");
    if (modelNameValue != null) {
        String modelName = modelNameValue.asString();
        if (modelName != null && !modelName.isEmpty()) {
            return modelName;
        }
    }
    return NamedConfigUtil.DEFAULT_NAME;
}
/**
 * Checks that a user-provided supplier class can be instantiated through a no-args constructor and registers its
 * constructors for reflection.
 * <p>
 * If the class is not present in the Jandex index, a warning is logged and nothing is registered.
 *
 * @param supplierDotName the supplier class to validate
 * @param index the index used to look the class up
 * @param producer receives the reflection registration
 * @throws IllegalConfigurationException if the class has no no-args constructor
 */
private void validateSupplierAndRegisterForReflection(DotName supplierDotName, IndexView index,
        BuildProducer<ReflectiveClassBuildItem> producer) {
    ClassInfo classInfo = index.getClassByName(supplierDotName);
    if (classInfo == null) {
        log.warn("'" + supplierDotName.toString() + "' cannot be indexed"); // TODO: maybe this should be an error
        return;
    }
    if (!classInfo.hasNoArgsConstructor()) {
        throw new IllegalConfigurationException(
                "Class '" + supplierDotName + "' must contain a no-args constructor.");
    }
    producer.produce(ReflectiveClassBuildItem.builder(supplierDotName.toString()).constructors(true).build());
}
/**
 * Registers one synthetic {@link QuarkusAiServiceContext} bean per declarative AI service.
 * <p>
 * Each bean has these properties:
 * <ul>
 * <li>It is {@code @Dependent}, initialized at runtime through the recorder, and qualified with the service
 * interface name.</li>
 * <li>It declares injection points for the beans it relies on: the (streaming) chat model, tools, chat memory
 * provider, retriever, retrieval augmentor, audit service and moderation model.</li>
 * <li>The corresponding bean types are then marked unremovable.</li>
 * </ul>
 *
 * @throws IllegalConfigurationException if a service method returns {@code Multi} with anything other than a
 *         {@code String} type argument
 */
@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleDeclarativeServices(AiServicesRecorder recorder,
        List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
        List<SelectedChatModelProviderBuildItem> selectedChatModelProvider,
        BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer,
        BuildProducer<UnremovableBeanBuildItem> unremoveableProducer) {

    boolean needsChatModelBean = false;
    boolean needsStreamingChatModelBean = false;
    boolean needsChatMemoryProviderBean = false;
    boolean needsRetrieverBean = false;
    boolean needsRetrievalAugmentorBean = false;
    boolean needsAuditServiceBean = false;
    boolean needsModerationModelBean = false;

    Set<DotName> allToolNames = new HashSet<>();

    for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
        ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
        String serviceClassName = declarativeAiServiceClassInfo.name().toString();

        // Objects.toString(x, null) keeps "not configured" as null instead of producing the string "null"
        String chatLanguageModelSupplierClassName = Objects.toString(bi.getLanguageModelSupplierClassDotName(), null);
        List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString)
                .collect(Collectors.toList());
        String chatMemoryProviderSupplierClassName = Objects
                .toString(bi.getChatMemoryProviderSupplierClassDotName(), null);
        String retrieverClassName = Objects.toString(bi.getRetrieverClassDotName(), null);
        String retrievalAugmentorSupplierClassName = Objects
                .toString(bi.getRetrievalAugmentorSupplierClassDotName(), null);
        String auditServiceClassSupplierName = Objects.toString(bi.getAuditServiceClassSupplierDotName(), null);
        String moderationModelSupplierClassName = Objects.toString(bi.getModerationModelSupplierDotName(), null);
        String chatMemorySeederClassName = Objects.toString(bi.getChatMemorySeederClassDotName(), null);

        // a streaming chat model is needed when some method returns Multi; only Multi<String> is accepted
        boolean injectStreamingChatModelBean = false;
        for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
            Type returnType = method.returnType();
            if (!LangChain4jDotNames.MULTI.equals(returnType.name())) {
                continue;
            }
            boolean isMultiString = returnType.kind() == Type.Kind.PARAMETERIZED_TYPE
                    && LangChain4jDotNames.STRING
                            .equals(returnType.asParameterizedType().arguments().get(0).name());
            if (!isMultiString) {
                throw illegalConfiguration(
                        "Only Multi<String> is supported as a Multi return type. Offending method is '"
                                + method.declaringClass().name().toString() + "#" + method.name() + "'");
            }
            injectStreamingChatModelBean = true;
        }

        boolean injectModerationModelBean = declarativeAiServiceClassInfo.methods().stream()
                .anyMatch(method -> method.hasAnnotation(Moderate.class));

        String chatModelName = bi.getChatModelName();
        String moderationModelName = bi.getModerationModelName();
        SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
                .configure(QuarkusAiServiceContext.class)
                .forceApplicationClass()
                .createWith(recorder.createDeclarativeAiService(
                        new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
                                toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
                                retrievalAugmentorSupplierClassName,
                                auditServiceClassSupplierName,
                                moderationModelSupplierClassName,
                                chatMemorySeederClassName,
                                chatModelName,
                                moderationModelName,
                                injectStreamingChatModelBean,
                                injectModerationModelBean)))
                .setRuntimeInit()
                .addQualifier()
                .annotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName)
                .done()
                .scope(Dependent.class);

        if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) {
            if (NamedConfigUtil.isDefault(chatModelName)) {
                configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL));
                if (injectStreamingChatModelBean) {
                    configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL));
                    needsStreamingChatModelBean = true;
                }
            } else {
                configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL),
                        AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build());
                if (injectStreamingChatModelBean) {
                    configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL),
                            AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build());
                    needsStreamingChatModelBean = true;
                }
            }
            needsChatModelBean = true;
        }

        for (String toolClassName : toolClassNames) {
            DotName dotName = DotName.createSimple(toolClassName);
            configurator.addInjectionPoint(ClassType.create(dotName));
            allToolNames.add(dotName);
        }

        if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
            configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MEMORY_PROVIDER));
            needsChatMemoryProviderBean = true;
        }

        if (retrieverClassName != null) {
            configurator.addInjectionPoint(ClassType.create(retrieverClassName));
            needsRetrieverBean = true;
        }

        if (LangChain4jDotNames.BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER.toString()
                .equals(retrievalAugmentorSupplierClassName)) {
            // Use a CDI bean of type `RetrievalAugmentor` if one exists, otherwise
            // don't use an augmentor.
            configurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
                    new Type[] { ClassType.create(LangChain4jDotNames.RETRIEVAL_AUGMENTOR) }, null));
            needsRetrievalAugmentorBean = true;
        } else if (retrievalAugmentorSupplierClassName != null
                && bi.isCustomRetrievalAugmentorSupplierClassIsABean()) {
            // A custom supplier that is a CDI bean gets an injection point here; a non-bean supplier
            // is instantiated through its no-args constructor by the recorder.
            configurator.addInjectionPoint(ClassType.create(retrievalAugmentorSupplierClassName));
            unremoveableProducer
                    .produce(UnremovableBeanBuildItem.beanClassNames(retrievalAugmentorSupplierClassName));
        }

        if (LangChain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) {
            configurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
                    new Type[] { ClassType.create(LangChain4jDotNames.AUDIT_SERVICE) }, null));
            needsAuditServiceBean = true;
        }

        if (LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER.toString()
                .equals(moderationModelSupplierClassName) && injectModerationModelBean) {
            if (NamedConfigUtil.isDefault(moderationModelName)) {
                configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.MODERATION_MODEL));
            } else {
                configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.MODERATION_MODEL),
                        AnnotationInstance.builder(ModelName.class).add("value", moderationModelName).build());
            }
            needsModerationModelBean = true;
        }

        syntheticBeanProducer.produce(configurator.done());
    }

    if (needsChatModelBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.CHAT_MODEL));
    }
    if (needsStreamingChatModelBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.STREAMING_CHAT_MODEL));
    }
    if (needsChatMemoryProviderBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.CHAT_MEMORY_PROVIDER));
    }
    if (needsRetrieverBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.RETRIEVER));
    }
    if (needsRetrievalAugmentorBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.RETRIEVAL_AUGMENTOR));
    }
    if (needsAuditServiceBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AUDIT_SERVICE));
    }
    if (needsModerationModelBean) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.MODERATION_MODEL));
    }
    if (!allToolNames.isEmpty()) {
        unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
    }
}
/**
 * Registers every resource referenced through the {@code fromResource} attribute of a system-message or
 * user-message annotation as a hot-deployment watched file.
 *
 * @param indexBuildItem the combined Jandex index
 * @param producer receives one watched-file item per {@code fromResource} value
 */
@BuildStep
public void watchResourceFiles(CombinedIndexBuildItem indexBuildItem,
        BuildProducer<HotDeploymentWatchedFileBuildItem> producer) {
    IndexView index = indexBuildItem.getIndex();

    List<AnnotationInstance> instances = new ArrayList<>();
    instances.addAll(index.getAnnotations(LangChain4jDotNames.SYSTEM_MESSAGE));
    instances.addAll(index.getAnnotations(LangChain4jDotNames.USER_MESSAGE));

    for (AnnotationInstance instance : instances) {
        // value(...) is null when the attribute is not explicitly set on the annotation
        AnnotationValue fromResource = instance.value("fromResource");
        if (fromResource != null) {
            producer.produce(new HotDeploymentWatchedFileBuildItem(fromResource.asString()));
        }
    }
}
@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleAiServices(
LangChain4jBuildConfig config,
AiServicesRecorder recorder,
CombinedIndexBuildItem indexBuildItem,
List declarativeAiServiceItems,
BuildProducer generatedClassProducer,
BuildProducer generatedBeanProducer,
BuildProducer reflectiveClassProducer,
BuildProducer aiServicesMethodProducer,
BuildProducer additionalBeanProducer,
BuildProducer unremovableBeanProducer,
Optional metricsCapability,
Capabilities capabilities) {
IndexView index = indexBuildItem.getIndex();
List aiServicesAnalysisResults = new ArrayList<>();
for (ClassInfo classInfo : index.getKnownUsers(LangChain4jDotNames.AI_SERVICES)) {
String className = classInfo.name().toString();
if (className.startsWith("io.quarkiverse.langchain4j") || className.startsWith("dev.langchain4j")) { // TODO: this can be made smarter if
// needed
continue;
}
try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(
className.replace('.', '/') + ".class")) {
if (is == null) {
return;
}
var cn = new ClassNode(Gizmo.ASM_API_VERSION);
var cr = new ClassReader(is);
cr.accept(cn, 0);
for (MethodNode method : cn.methods) {
aiServicesAnalysisResults.addAll(AiServicesUseAnalyzer.analyze(cn, method).entries);
}
} catch (IOException e) {
throw new UncheckedIOException("Reading bytecode of class '" + className + "' failed", e);
} catch (AnalyzerException e) {
log.debug("Unable to analyze bytecode of class '" + className + "'", e);
}
}
Map nameToUsed = aiServicesAnalysisResults.stream()
.collect(Collectors.toMap(e -> e.createdClassName, e -> e.chatMemoryProviderUsed, (u1, u2) -> u1 || u2));
for (var entry : nameToUsed.entrySet()) {
String className = entry.getKey();
ClassInfo classInfo = index.getClassByName(className);
if (classInfo == null) {
continue;
}
if (!classInfo.annotations(LangChain4jDotNames.MEMORY_ID).isEmpty() && !entry.getValue()) {
log.warn("Class '" + className
+ "' is used in AiServices and while it leverages @MemoryId, a ChatMemoryProvider has not been configured. This will likely result in an exception being thrown when the service is used.");
}
}
Set detectedForCreate = new HashSet<>(nameToUsed.keySet());
addCreatedAware(index, detectedForCreate);
addIfacesWithMessageAnns(index, detectedForCreate);
Set registeredAiServiceClassNames = declarativeAiServiceItems.stream()
.map(bi -> bi.getServiceClassInfo().name().toString()).collect(
Collectors.toUnmodifiableSet());
detectedForCreate.addAll(registeredAiServiceClassNames);
Set ifacesForCreate = new HashSet<>();
for (String className : detectedForCreate) {
ClassInfo classInfo = index.getClassByName(className);
if (classInfo == null) {
log.warn("'" + className
+ "' used for creating an AiService was not found in the Quarkus index. Attempting to create "
+ "an AiService using this class will fail");
continue;
}
if (!classInfo.isInterface()) {
log.warn("'" + className
+ "' used for creating an AiService is not an interface. Attempting to create an AiService "
+ "using this class will fail");
}
ifacesForCreate.add(classInfo);
}
var addMicrometerMetrics = metricsCapability.isPresent()
&& metricsCapability.get().metricsSupported(MetricsFactory.MICROMETER);
if (addMicrometerMetrics) {
additionalBeanProducer.produce(AdditionalBeanBuildItem.builder().addBeanClass(MetricsTimedWrapper.class).build());
additionalBeanProducer.produce(AdditionalBeanBuildItem.builder().addBeanClass(MetricsCountedWrapper.class).build());
}
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
if (addOpenTelemetrySpan) {
additionalBeanProducer.produce(AdditionalBeanBuildItem.builder().addBeanClass(SpanWrapper.class).build());
}
Map perClassMetadata = new HashMap<>();
if (!ifacesForCreate.isEmpty()) {
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
ClassOutput generatedBeanOutput = new GeneratedBeanGizmoAdaptor(generatedBeanProducer);
for (ClassInfo iface : ifacesForCreate) {
List allMethods = new ArrayList<>(iface.methods());
JandexUtil.getAllSuperinterfaces(iface, index).forEach(ci -> allMethods.addAll(ci.methods()));
List methodsToImplement = new ArrayList<>();
Map perMethodMetadata = new HashMap<>();
for (MethodInfo method : allMethods) {
short modifiers = method.flags();
if (Modifier.isStatic(modifiers) || Modifier.isPrivate(modifiers) || JandexUtil.isDefault(
modifiers)) {
continue;
}
if (methodsToImplement.stream().anyMatch(m -> MethodUtil.methodSignaturesMatch(m, method))) {
continue;
}
methodsToImplement.add(method);
}
String ifaceName = iface.name().toString();
String implClassName = ifaceName + "$$QuarkusImpl";
boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName);
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(isRegisteredService ? generatedBeanOutput : generatedClassOutput)
.className(implClassName)
.interfaces(ifaceName, ChatMemoryRemovable.class.getName());
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
try (ClassCreator classCreator = classCreatorBuilder.build()) {
if (isRegisteredService) {
// we need to make this a bean, so we need to add the proper scope annotation
DotName scopeInfo = declarativeAiServiceItems.stream()
.filter(bi -> bi.getServiceClassInfo().equals(iface))
.findFirst().orElseThrow(() -> new IllegalStateException(
"Unable to determine the CDI scope of " + iface))
.getCdiScope();
classCreator.addAnnotation(scopeInfo.toString());
}
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
.getFieldDescriptor();
for (MethodInfo methodInfo : methodsToImplement) {
// The implementation essentially gets the context and delegates to
// MethodImplementationSupport#implement
String methodId = createMethodId(methodInfo);
AiServiceMethodCreateInfo methodCreateInfo = gatherMethodMetadata(methodInfo, index,
addMicrometerMetrics,
addOpenTelemetrySpan,
config.responseSchema());
if (!methodCreateInfo.getToolClassNames().isEmpty()) {
unremovableBeanProducer.produce(UnremovableBeanBuildItem
.beanClassNames(methodCreateInfo.getToolClassNames().toArray(EMPTY_STRING_ARRAY)));
}
perMethodMetadata.put(methodId, methodCreateInfo);
{
MethodCreator ctor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
QuarkusAiServiceContext.class);
ctor.setModifiers(Modifier.PUBLIC);
ctor.addAnnotation(Inject.class);
ctor.getParameterAnnotations(0)
.addAnnotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString())
.add("value", ifaceName);
ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis());
ctor.writeInstanceField(contextField, ctor.getThis(),
ctor.getMethodParam(0));
ctor.returnValue(null);
}
{
MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V");
noArgsCtor.setModifiers(Modifier.PUBLIC);
noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis());
noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull());
noArgsCtor.returnValue(null);
}
{ // actual method we need to implement
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
// copy annotations
for (AnnotationInstance annotationInstance : methodInfo.declaredAnnotations()) {
// TODO: we need to review this
if (annotationInstance.name().toString()
.startsWith("org.eclipse.microprofile.faulttolerance")
|| annotationInstance.name().toString()
.startsWith("io.smallrye.faulttolerance.api")) {
mc.addAnnotation(annotationInstance);
}
}
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
mc.load(ifaceName),
mc.load(methodId));
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i));
}
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class,
Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);
ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
mc.returnValue(resultHandle);
aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
}
}
if (isRegisteredService) {
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "close", void.class));
mc.addAnnotation(PreDestroy.class);
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
}
{
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "remove", void.class, Object[].class));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS, contextHandle,
mc.getMethodParam(0));
mc.returnVoid();
}
}
perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
// make the constructor accessible reflectively since that is how we create the instance
reflectiveClassProducer.produce(ReflectiveClassBuildItem.builder(implClassName).build());
}
}
recorder.setMetadata(perClassMetadata);
}
/**
 * Emits bytecode that looks up a bean through Arc, equivalent to
 * {@code Arc.container().instance(beanClass, new Annotation[0]).get()}, where {@code beanClass} is loaded
 * through the thread context class loader.
 *
 * @param mc the method being generated
 * @param className fully qualified name of the bean class to look up
 * @return a handle to the bean instance (typed as {@code Object})
 */
private ResultHandle getFromCDI(MethodCreator mc, String className) {
    MethodDescriptor containerAccessor = MethodDescriptor.ofMethod(Arc.class, "container", ArcContainer.class);
    MethodDescriptor instanceLookup = MethodDescriptor.ofMethod(ArcContainer.class, "instance", InstanceHandle.class,
            Class.class, Annotation[].class);
    MethodDescriptor handleGet = MethodDescriptor.ofMethod(InstanceHandle.class, "get", Object.class);

    // bytecode is emitted in this exact order: container, class literal, empty qualifier array, lookup, get
    ResultHandle arcContainer = mc.invokeStaticMethod(containerAccessor);
    ResultHandle beanClass = mc.loadClassFromTCCL(className);
    ResultHandle noQualifiers = mc.newArray(Annotation.class, 0);
    ResultHandle instanceHandle = mc.invokeInterfaceMethod(instanceLookup, arcContainer, beanClass, noQualifiers);
    return mc.invokeInterfaceMethod(handleGet, instanceHandle);
}
/**
 * Builds the key under which a method's metadata is registered: the method name followed by the
 * {@link Arrays#toString(Object[])} rendering of its parameter type names, e.g. {@code chat([java.lang.String])}.
 */
private String createMethodId(MethodInfo methodInfo) {
    Object[] parameterTypeNames = methodInfo.parameters().stream()
            .map(parameter -> parameter.type().name().toString())
            .toArray();
    return methodInfo.name() + "(" + Arrays.toString(parameterTypeNames) + ")";
}
/**
 * Adds to {@code detectedForCreate} the name of every interface that carries {@code @SystemMessage},
 * {@code @UserMessage} or {@code @Moderate}, either on one of its methods or on the interface itself.
 * Occurrences on other targets (e.g. method parameters) are ignored here, as are non-interface classes.
 */
private void addIfacesWithMessageAnns(IndexView index, Set<String> detectedForCreate) {
    List<DotName> messageAnnotations = List.of(LangChain4jDotNames.SYSTEM_MESSAGE, LangChain4jDotNames.USER_MESSAGE,
            LangChain4jDotNames.MODERATE);
    for (DotName messageAnnotation : messageAnnotations) {
        for (AnnotationInstance occurrence : index.getAnnotations(messageAnnotation)) {
            AnnotationTarget target = occurrence.target();
            ClassInfo owner = switch (target.kind()) {
                case METHOD -> target.asMethod().declaringClass();
                case CLASS -> target.asClass();
                default -> null;
            };
            if (owner != null && owner.isInterface()) {
                detectedForCreate.add(owner.name().toString());
            }
        }
    }
}
/**
 * Adds to {@code detectedForCreate} the name of every class annotated with the annotation referenced by
 * {@code LangChain4jDotNames.CREATED_AWARE}; occurrences on non-class targets are skipped.
 */
private static void addCreatedAware(IndexView index, Set<String> detectedForCreate) {
    for (AnnotationInstance createdAware : index.getAnnotations(LangChain4jDotNames.CREATED_AWARE)) {
        AnnotationTarget target = createdAware.target();
        if (target.kind() == AnnotationTarget.Kind.CLASS) {
            detectedForCreate.add(target.asClass().name().toString());
        }
    }
}
/**
 * Collects everything needed to implement a single AiService method and bundles it into an
 * {@link AiServiceMethodCreateInfo}.
 *
 * @param method the AiService interface method being implemented
 * @param index the Jandex index, used to build the {@code TypeArgMapper} for the return type signature
 * @param addMicrometerMetrics whether Micrometer timer/counter metadata should be gathered
 * @param addOpenTelemetrySpans whether OpenTelemetry span metadata should be gathered
 * @param generateResponseSchema value of {@code quarkus.langchain4j.response-schema}; when {@code false} the
 *        response-schema placeholder must not appear in the system or user message
 * @return the collected method metadata
 * @throws RuntimeException if the response-schema placeholder is used while response schema generation is disabled
 */
private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, IndexView index, boolean addMicrometerMetrics,
        boolean addOpenTelemetrySpans, boolean generateResponseSchema) {
    validateReturnType(method);

    boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
    java.lang.reflect.Type returnType = javaLangReturnType(method);
    // getGenericReturnType() yields a ParameterizedType for e.g. Multi<String>; compare on the raw type so that
    // parameterized Multi return types are recognized, not only a raw Multi
    java.lang.reflect.Type rawReturnType = returnType instanceof java.lang.reflect.ParameterizedType parameterizedReturnType
            ? parameterizedReturnType.getRawType()
            : returnType;

    List<MethodParameterInfo> params = method.parameters();

    // TODO give user ability to provide custom OutputParser
    String outputFormatInstructions = "";
    if (generateResponseSchema && !Multi.class.equals(rawReturnType)) {
        outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
    }

    List<TemplateParameterInfo> templateParams = gatherTemplateParamInfo(params);
    Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo = gatherSystemMessageInfo(method, templateParams);
    AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams);

    AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema,
            systemMessageInfo,
            userMessageInfo.template(), outputFormatInstructions);

    // the schema placeholder is only meaningful when response schema generation is enabled
    if (!generateResponseSchema
            && (responseSchemaInfo.isInSystemMessage()
                    || (responseSchemaInfo.isInUserMessage().isPresent() && responseSchemaInfo.isInUserMessage().get()))) {
        throw new RuntimeException(
                "The %s placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: %s"
                        .formatted(ResponseSchemaUtil.placeholder(), method.declaringClass()));
    }

    Optional<Integer> memoryIdParamPosition = gatherMemoryIdParamName(method);
    Optional<AiServiceMethodCreateInfo.MetricsTimedInfo> metricsTimedInfo = gatherMetricsTimedInfo(method,
            addMicrometerMetrics);
    Optional<AiServiceMethodCreateInfo.MetricsCountedInfo> metricsCountedInfo = gatherMetricsCountedInfo(method,
            addMicrometerMetrics);
    Optional<AiServiceMethodCreateInfo.SpanInfo> spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans);
    List<String> methodToolClassNames = gatherMethodToolClassNames(method);

    return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
            userMessageInfo, memoryIdParamPosition, requiresModeration,
            returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
            metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames);
}
/**
 * Rejects AiService methods whose return type cannot be handled.
 * <p>
 * Only class types (e.g. {@code String} or a POJO) and parameterized types (e.g. {@code List<Foo>}) are accepted;
 * {@code void} and every other kind (primitives, arrays, type variables, ...) are refused.
 *
 * @param method the AiService method to check
 * @throws IllegalConfigurationException if the return type is not supported
 */
private void validateReturnType(MethodInfo method) {
    Type returnType = method.returnType();
    Type.Kind returnTypeKind = returnType.kind();
    if (returnTypeKind == Type.Kind.VOID) {
        throw illegalConfiguration("Return type of method '%s' cannot be void", method);
    }
    if (returnTypeKind != Type.Kind.CLASS && returnTypeKind != Type.Kind.PARAMETERIZED_TYPE) {
        // message previously lacked the closing quote around the method
        throw illegalConfiguration("Unsupported return type of method '%s'", method);
    }
}
/**
 * Resolves the reflective generic return type of a Jandex method.
 * <p>
 * The declaring class is loaded (without being initialized) through the thread context class loader, the
 * parameter types are loaded via {@code JandexUtil.load}, and the matching declared method is looked up.
 *
 * @throws IllegalStateException wrapping a {@link ClassNotFoundException} or {@link NoSuchMethodException}
 */
private java.lang.reflect.Type javaLangReturnType(MethodInfo method) {
    ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
    try {
        Class<?> owner = Class.forName(method.declaringClass().name().toString(), false, contextClassLoader);
        Class<?>[] parameterClasses = new Class<?>[method.parametersCount()];
        int index = 0;
        for (Type parameterType : method.parameterTypes()) {
            parameterClasses[index++] = JandexUtil.load(parameterType, contextClassLoader);
        }
        return owner.getDeclaredMethod(method.name(), parameterClasses).getGenericReturnType();
    } catch (ClassNotFoundException | NoSuchMethodException e) {
        throw new IllegalStateException(e);
    }
}
/**
 * Computes the signature string of a method's return type via {@code AsmUtil.getSignature}.
 *
 * @param returnType the Jandex return type
 * @param typeArgMapper forwarded to {@code AsmUtil.getSignature}; presumably resolves type variables of the
 *        declaring interface (TODO confirm)
 * @return the signature produced by {@code AsmUtil}
 */
private String returnTypeSignature(Type returnType, TypeArgMapper typeArgMapper) {
    String signature = AsmUtil.getSignature(returnType, typeArgMapper);
    return signature;
}
/**
 * Determines which method parameters can be referenced by name from a prompt template.
 * <p>
 * A parameter qualifies when it has no relevant annotations (see {@code effectiveParamAnnotations}) or is
 * annotated with {@code @MemoryId}; it is then registered under its own name. A parameter with a valued
 * {@code @V} annotation is registered under that value. If the method has exactly one parameter and it qualified,
 * it is additionally exposed under the implicit name {@code it}.
 * <p>
 * A warning is logged when every collected name is {@code null}, which indicates that parameter names were not
 * retained at compile time.
 *
 * @param params the parameters of the AiService method
 * @return the template parameters in the order of {@code params}, followed by the optional {@code it} alias
 */
private List<TemplateParameterInfo> gatherTemplateParamInfo(List<MethodParameterInfo> params) {
    if (params.isEmpty()) {
        return Collections.emptyList();
    }

    List<TemplateParameterInfo> collected = new ArrayList<>();
    for (MethodParameterInfo param : params) {
        List<AnnotationInstance> relevantAnnotations = effectiveParamAnnotations(param);
        boolean unannotated = relevantAnnotations.isEmpty();
        // @MemoryId parameters may also be used inside the template
        boolean memoryId = relevantAnnotations.stream()
                .map(AnnotationInstance::name)
                .anyMatch(MEMORY_ID::equals);
        if (unannotated || memoryId) {
            collected.add(new TemplateParameterInfo(param.position(), param.name()));
            continue;
        }
        AnnotationInstance vAnnotation = param.annotation(V);
        AnnotationValue vName = (vAnnotation == null) ? null : vAnnotation.value();
        if (vName != null) {
            collected.add(new TemplateParameterInfo(param.position(), vName.asString()));
        }
    }

    boolean singleQualifyingParam = collected.size() == 1 && params.size() == 1;
    if (singleQualifyingParam) {
        collected.add(new TemplateParameterInfo(0, "it"));
    }

    boolean noNamesAvailable = !collected.isEmpty()
            && collected.stream().allMatch(info -> info.name() == null);
    if (noNamesAvailable) {
        log.warn(
                "The application has been compiled without the '-parameters' being set flag on javac. Make sure your build tool is configured to pass this flag to javac, otherwise Quarkus LangChain4j is unlikely to work properly without it.");
    }
    return collected;
}
/**
 * Returns the parameter's annotations, excluding those that do not influence template handling: annotations
 * whose name starts with {@code kotlin}, {@code jakarta.validation.constraints} or {@code io.opentelemetry},
 * and any annotation whose name ends with {@code NotNull}.
 *
 * @param param the method parameter
 * @return a new mutable list with the remaining annotations
 */
private List<AnnotationInstance> effectiveParamAnnotations(MethodParameterInfo param) {
    List<AnnotationInstance> effective = new ArrayList<>();
    for (AnnotationInstance annotation : param.annotations()) {
        String annotationName = annotation.name().toString();
        boolean ignored = annotationName.startsWith("kotlin")
                || annotationName.startsWith("jakarta.validation.constraints")
                || annotationName.endsWith("NotNull")
                || annotationName.startsWith("io.opentelemetry");
        if (!ignored) {
            effective.add(annotation);
        }
    }
    return effective;
}
/**
 * Resolves the system message template of an AiService method.
 * <p>
 * A {@code @SystemMessage} found via {@code method.annotation(...)} takes precedence over one declared on the
 * interface itself.
 *
 * @param method the AiService method
 * @param templateParams parameters that may be referenced from the template
 * @return the template info, or empty when no {@code @SystemMessage} applies
 * @throws IllegalConfigurationException (via {@code illegalConfigurationForMethod}) if the template is empty
 */
private Optional<AiServiceMethodCreateInfo.TemplateInfo> gatherSystemMessageInfo(MethodInfo method,
        List<TemplateParameterInfo> templateParams) {
    AnnotationInstance systemMessage = method.annotation(LangChain4jDotNames.SYSTEM_MESSAGE);
    if (systemMessage == null) {
        systemMessage = method.declaringClass().declaredAnnotation(LangChain4jDotNames.SYSTEM_MESSAGE);
    }
    if (systemMessage == null) {
        return Optional.empty();
    }

    String template = getTemplateFromAnnotationInstance(systemMessage);
    if (template.isEmpty()) {
        throw illegalConfigurationForMethod("@SystemMessage's template parameter cannot be empty", method);
    }
    // TODO: we should probably add a lot more template validation here
    Map<String, Integer> nameToPosition = TemplateParameterInfo.toNameToArgsPositionMap(templateParams);
    return Optional.of(AiServiceMethodCreateInfo.TemplateInfo.fromText(template, nameToPosition));
}
/**
 * Finds the position of the parameter annotated with {@code @MemoryId}.
 * Only occurrences on method parameters count; if several exist, the first one returned by Jandex is used.
 *
 * @return the parameter position, or empty if no parameter carries {@code @MemoryId}
 */
private Optional<Integer> gatherMemoryIdParamName(MethodInfo method) {
    List<AnnotationInstance> memoryIdAnnotations = method.annotations(LangChain4jDotNames.MEMORY_ID);
    return memoryIdAnnotations.stream()
            .filter(IS_METHOD_PARAMETER_ANNOTATION)
            .map(METHOD_PARAMETER_POSITION_FUNCTION)
            .findFirst();
}
/**
 * Determines where the user message of an AiService method comes from, in order of precedence:
 * <ol>
 * <li>a {@code @UserMessage} template declared on the method itself;</li>
 * <li>a parameter annotated with {@code @UserMessage} (treated as a template if it is a {@code String} and the
 * method has template parameters, otherwise used verbatim);</li>
 * <li>the sole parameter of a single-argument method.</li>
 * </ol>
 * The position of a {@code @UserName}-annotated parameter, if any, is attached in every case.
 *
 * @param method the AiService method
 * @param templateParams parameters that may be referenced from the template
 * @return the user message info
 * @throws IllegalConfigurationException (via {@code illegalConfigurationForMethod}) if {@code {{it}}} is used
 *         on a method without exactly one parameter, or if no user message source can be determined
 */
private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodInfo method,
        List<TemplateParameterInfo> templateParams) {
    Optional<Integer> userNameParamName = method.annotations(LangChain4jDotNames.USER_NAME).stream()
            .filter(IS_METHOD_PARAMETER_ANNOTATION)
            .map(METHOD_PARAMETER_POSITION_FUNCTION)
            .findFirst();

    // 1. template declared on the method
    AnnotationInstance userMessageOnMethod = method.declaredAnnotation(LangChain4jDotNames.USER_MESSAGE);
    if (userMessageOnMethod != null) {
        String template = getTemplateFromAnnotationInstance(userMessageOnMethod);
        if (template.contains("{{it}}") && method.parametersCount() != 1) {
            throw illegalConfigurationForMethod(
                    "Error: The {{it}} placeholder is present but the method does not have exactly one parameter. " +
                            "Please ensure that methods using the {{it}} placeholder have exactly one parameter",
                    method);
        }
        // TODO: we should probably add a lot more template validation here
        return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
                AiServiceMethodCreateInfo.TemplateInfo.fromText(template,
                        TemplateParameterInfo.toNameToArgsPositionMap(templateParams)),
                userNameParamName);
    }

    // 2. a parameter annotated with @UserMessage
    Optional<AnnotationInstance> userMessageOnParam = method.annotations(LangChain4jDotNames.USER_MESSAGE)
            .stream()
            .filter(IS_METHOD_PARAMETER_ANNOTATION)
            .findFirst();
    if (userMessageOnParam.isPresent()) {
        MethodParameterInfo userMessageParam = userMessageOnParam.get().target().asMethodParameter();
        boolean usableAsTemplate = DotNames.STRING.equals(userMessageParam.type().name())
                && !templateParams.isEmpty();
        if (usableAsTemplate) {
            return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
                    AiServiceMethodCreateInfo.TemplateInfo.fromMethodParam((int) userMessageParam.position(),
                            TemplateParameterInfo.toNameToArgsPositionMap(templateParams)),
                    userNameParamName);
        }
        return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(userMessageParam.position(),
                userNameParamName);
    }

    // 3. fall back to the single argument
    int parameterCount = method.parametersCount();
    if (parameterCount == 0) {
        throw illegalConfigurationForMethod("Method should have at least one argument", method);
    }
    if (parameterCount == 1) {
        return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamName);
    }
    throw illegalConfigurationForMethod(
            "For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
            method);
}
/**
 * Meant to be called with instances of {@link dev.langchain4j.service.SystemMessage} or
 * {@link dev.langchain4j.service.UserMessage}.
 * <p>
 * If {@code fromResource} is set, the template is read from that classpath resource through the thread context
 * class loader and decoded as UTF-8; a leading {@code /} in the resource name is optional. Otherwise the
 * {@code value} lines are joined with {@code delimiter}, falling back to {@code DEFAULT_DELIMITER}.
 *
 * @return the String value of the template, or an empty string if not specified or if the referenced resource
 *         cannot be found
 * @throws UncheckedIOException if the resource exists but cannot be read
 */
private String getTemplateFromAnnotationInstance(AnnotationInstance instance) {
    AnnotationValue fromResourceValue = instance.value("fromResource");
    if (fromResourceValue != null) {
        String fromResource = fromResourceValue.asString();
        // ClassLoader resource names are relative to the classpath root and carry no leading '/'
        // (the '/' prefix is a Class#getResource convention), so normalize it away
        String resourceName = fromResource.startsWith("/") ? fromResource.substring(1) : fromResource;
        try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(resourceName)) {
            if (is != null) {
                // decode explicitly; new String(byte[]) would depend on the platform default charset
                return new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    } else {
        AnnotationValue valueValue = instance.value();
        if (valueValue != null) {
            AnnotationValue delimiterValue = instance.value("delimiter");
            String delimiter = delimiterValue != null ? delimiterValue.asString() : DEFAULT_DELIMITER;
            return String.join(delimiter, valueValue.asStringArray());
        }
    }
    return "";
}
/**
 * Builds the Micrometer timer metadata for an AiService method.
 * <p>
 * The annotation referenced by {@code MICROMETER_TIMED} is looked up on the method first, then on the declaring
 * interface. Without it, the method is still timed using {@code METRICS_DEFAULT_NAME} and the default tags.
 * When present, a non-empty {@code value} overrides the metric name, and the {@code extraTags},
 * {@code longTask}, {@code percentiles}, {@code histogram} and {@code description} attributes are copied
 * when explicitly set.
 *
 * @param method the AiService method
 * @param addMicrometerMetrics whether Micrometer metrics are enabled; if not, nothing is gathered
 * @return the timer metadata, or empty when metrics are disabled
 */
private Optional<AiServiceMethodCreateInfo.MetricsTimedInfo> gatherMetricsTimedInfo(MethodInfo method,
        boolean addMicrometerMetrics) {
    if (!addMicrometerMetrics) {
        return Optional.empty();
    }
    List<String> tags = defaultMetricsTags(method);

    AnnotationInstance timed = method.annotation(MICROMETER_TIMED);
    if (timed == null) {
        timed = method.declaringClass().declaredAnnotation(MICROMETER_TIMED);
    }
    if (timed == null) {
        // no annotation anywhere: AiService methods are timed by default
        return Optional.of(new AiServiceMethodCreateInfo.MetricsTimedInfo.Builder(METRICS_DEFAULT_NAME)
                .setExtraTags(tags.toArray(EMPTY_STRING_ARRAY))
                .build());
    }

    AnnotationValue nameValue = timed.value();
    String configuredName = (nameValue == null) ? null : nameValue.asString();
    boolean hasCustomName = configuredName != null && !configuredName.isEmpty();
    var builder = new AiServiceMethodCreateInfo.MetricsTimedInfo.Builder(
            hasCustomName ? configuredName : METRICS_DEFAULT_NAME);

    AnnotationValue extraTagsValue = timed.value("extraTags");
    if (extraTagsValue != null) {
        Collections.addAll(tags, extraTagsValue.asStringArray());
    }
    builder.setExtraTags(tags.toArray(EMPTY_STRING_ARRAY));

    AnnotationValue longTaskValue = timed.value("longTask");
    if (longTaskValue != null) {
        builder.setLongTask(longTaskValue.asBoolean());
    }
    AnnotationValue percentilesValue = timed.value("percentiles");
    if (percentilesValue != null) {
        builder.setPercentiles(percentilesValue.asDoubleArray());
    }
    AnnotationValue histogramValue = timed.value("histogram");
    if (histogramValue != null) {
        builder.setHistogram(histogramValue.asBoolean());
    }
    AnnotationValue descriptionValue = timed.value("description");
    if (descriptionValue != null) {
        builder.setDescription(descriptionValue.asString());
    }
    return Optional.of(builder.build());
}
/**
 * Builds the Micrometer counter metadata for an AiService method.
 * <p>
 * The annotation referenced by {@code MICROMETER_COUNTED} is looked up on the method first, then on the declaring
 * interface. Without it, the method is still counted using {@code METRICS_DEFAULT_NAME} and the default tags.
 * When present, a non-empty {@code value} overrides the metric name, and the {@code extraTags},
 * {@code recordFailuresOnly} and {@code description} attributes are copied when explicitly set.
 *
 * @param method the AiService method
 * @param addMicrometerMetrics whether Micrometer metrics are enabled; if not, nothing is gathered
 * @return the counter metadata, or empty when metrics are disabled
 */
private Optional<AiServiceMethodCreateInfo.MetricsCountedInfo> gatherMetricsCountedInfo(MethodInfo method,
        boolean addMicrometerMetrics) {
    if (!addMicrometerMetrics) {
        return Optional.empty();
    }
    String name = METRICS_DEFAULT_NAME;
    List<String> tags = defaultMetricsTags(method);

    AnnotationInstance countedInstance = method.annotation(MICROMETER_COUNTED);
    if (countedInstance == null) {
        countedInstance = method.declaringClass().declaredAnnotation(MICROMETER_COUNTED);
    }
    if (countedInstance == null) {
        // we default to having all AiServices being counted
        return Optional.of(new AiServiceMethodCreateInfo.MetricsCountedInfo.Builder(name)
                .setExtraTags(tags.toArray(EMPTY_STRING_ARRAY)).build());
    }

    AnnotationValue nameValue = countedInstance.value();
    if (nameValue != null) {
        String nameStr = nameValue.asString();
        if (nameStr != null && !nameStr.isEmpty()) {
            name = nameStr;
        }
    }
    var builder = new AiServiceMethodCreateInfo.MetricsCountedInfo.Builder(name);

    AnnotationValue extraTagsValue = countedInstance.value("extraTags");
    if (extraTagsValue != null) {
        tags.addAll(Arrays.asList(extraTagsValue.asStringArray()));
    }
    builder.setExtraTags(tags.toArray(EMPTY_STRING_ARRAY));

    AnnotationValue recordFailuresOnlyValue = countedInstance.value("recordFailuresOnly");
    if (recordFailuresOnlyValue != null) {
        builder.setRecordFailuresOnly(recordFailuresOnlyValue.asBoolean());
    }
    AnnotationValue descriptionValue = countedInstance.value("description");
    if (descriptionValue != null) {
        builder.setDescription(descriptionValue.asString());
    }
    return Optional.of(builder.build());
}
/**
 * Returns the tags attached to every AiService metric, as a flat list of alternating keys and values:
 * {@code aiservice=<interface simple name>} and {@code method=<method name>}.
 *
 * @return a new mutable list, so callers may append extra tags
 */
private List<String> defaultMetricsTags(MethodInfo method) {
    String serviceSimpleName = method.declaringClass().name().withoutPackagePrefix();
    return new ArrayList<>(List.of(
            "aiservice", serviceSimpleName,
            "method", method.name()));
}
/**
 * Builds the OpenTelemetry span metadata for an AiService method, using the default span name.
 *
 * @param method the AiService method
 * @param addOpenTelemetrySpans whether spans are enabled; if not, nothing is gathered
 * @return the span metadata, or empty when spans are disabled
 */
private Optional<AiServiceMethodCreateInfo.SpanInfo> gatherSpanInfo(MethodInfo method,
        boolean addOpenTelemetrySpans) {
    // TODO: add more
    return addOpenTelemetrySpans
            ? Optional.of(new AiServiceMethodCreateInfo.SpanInfo(defaultAiServiceSpanName(method)))
            : Optional.empty();
}
/**
 * Returns the class names listed in the method's {@code @ToolBox} annotation.
 *
 * @return the tool class names, or an immutable empty list when there is no {@code @ToolBox} or it lists no
 *         classes
 */
private List<String> gatherMethodToolClassNames(MethodInfo method) {
    AnnotationInstance toolBox = method.declaredAnnotation(ToolBox.class);
    if (toolBox != null) {
        AnnotationValue toolBoxValue = toolBox.value();
        if (toolBoxValue != null) {
            Type[] toolClasses = toolBoxValue.asClassArray();
            if (toolClasses.length > 0) {
                List<String> toolClassNames = new ArrayList<>(toolClasses.length);
                for (Type toolClass : toolClasses) {
                    toolClassNames.add(toolClass.name().toString());
                }
                return toolClassNames;
            }
        }
    }
    return Collections.emptyList();
}
/**
 * Generates a chat memory seeder for the AiService if it declares a {@code @SeedMemory} method.
 *
 * @param iface the AiService interface
 * @param classOutput where the generated seeder class is written
 * @return the name of the generated seeder class, or {@code null} if the interface has no {@code @SeedMemory}
 * @throws IllegalConfigurationException if more than one {@code @SeedMemory} is found, or if it does not target
 *         a method
 */
private DotName determineChatMemorySeeder(ClassInfo iface, ClassOutput classOutput) {
    List<AnnotationInstance> seedMemoryAnnotations = iface.annotations(SEED_MEMORY);
    if (seedMemoryAnnotations.isEmpty()) {
        return null;
    }
    if (seedMemoryAnnotations.size() != 1) {
        throw new IllegalConfigurationException(
                "Only a single @SeedMemory annotation is allowed per AiService. Offending class is '" + iface.name() + "'");
    }

    AnnotationTarget target = seedMemoryAnnotations.get(0).target();
    if (target.kind() != AnnotationTarget.Kind.METHOD) {
        throw new IllegalConfigurationException(
                "The @SeedMemory annotation can only be placed on methods. Offending target is '" + target + "'");
    }
    String seederClassName = generateAiServiceChatMemorySeeder(iface, target.asMethod(), classOutput);
    return DotName.createSimple(seederClassName);
}
/**
 * Generates a class that looks like the following:
 *
 * <pre>
 * {@code
 * public class SomeAiService$$QuarkusChatMemorySeeder implements ChatMemorySeeder {
 *
 *     @Override
 *     public List<ChatMessage> seed(Context context) {
 *         return SomeAiService.someMethod(context.methodName());
 *     }
 * }
 * }
 * </pre>
 *
 * The target method must be {@code static}, return {@code List<ChatMessage>} and take only {@code String}
 * parameters (or none). Each of those parameters receives the value obtained through
 * {@code CHAT_MEMORY_SEEDER_CONTEXT_METHOD_NAME} ({@code context.methodName()} in the example above).
 *
 * @param iface the AiService interface; the generated class name is derived from it
 * @param seedMemoryTargetMethod the method annotated with {@code @SeedMemory}
 * @param classOutput where the generated class is written
 * @return the name of the generated class
 * @throws IllegalConfigurationException if the target method does not meet the requirements above
 */
private String generateAiServiceChatMemorySeeder(ClassInfo iface, MethodInfo seedMemoryTargetMethod,
        ClassOutput classOutput) {
    String offendingMethod = seedMemoryTargetMethod.declaringClass().name() + "#" + seedMemoryTargetMethod.name();
    if (!Modifier.isStatic(seedMemoryTargetMethod.flags())) {
        throw new IllegalConfigurationException(
                "The @SeedMemory annotation can only be placed on static methods. Offending method is '"
                        + offendingMethod + "'");
    }

    Type returnType = seedMemoryTargetMethod.returnType();
    boolean hasListChatMessageReturnType = false;
    if (returnType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
        ParameterizedType parameterizedType = returnType.asParameterizedType();
        hasListChatMessageReturnType = DotNames.LIST.equals(parameterizedType.name())
                && parameterizedType.arguments().size() == 1
                && LangChain4jDotNames.CHAT_MESSAGE.equals(parameterizedType.arguments().get(0).name());
    }
    if (!hasListChatMessageReturnType) {
        throw new IllegalConfigurationException(
                "The @SeedMemory annotation can only be placed on methods that return List<ChatMessage>. Offending method is '"
                        + offendingMethod + "'");
    }

    // Validate parameters before opening the ClassCreator: throwing inside try-with-resources would still
    // close it and write out a half-generated class.
    for (Type paramType : seedMemoryTargetMethod.parameterTypes()) {
        if (!paramType.name().equals(DotNames.STRING)) {
            throw new IllegalConfigurationException(
                    "The @SeedMemory annotation can only be placed on methods that only take parameters of type 'String' (or no parameters at all). Offending method is '"
                            + offendingMethod + "'");
        }
    }

    String implClassName = iface.name() + "$$QuarkusChatMemorySeeder";
    ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
            .classOutput(classOutput)
            .className(implClassName)
            .interfaces(ChatMemorySeeder.class.getName());
    try (ClassCreator classCreator = classCreatorBuilder.build()) {
        MethodCreator methodCreator = classCreator.getMethodCreator("seed", List.class, ChatMemorySeeder.Context.class);

        // Parallel lists rather than a map keyed by type name: with a map, several String parameters collapsed
        // into a single entry and the generated call no longer matched the target method's descriptor.
        List<String> targetParamTypeNames = new ArrayList<>();
        List<ResultHandle> targetParamHandles = new ArrayList<>();
        for (Type paramType : seedMemoryTargetMethod.parameterTypes()) {
            targetParamTypeNames.add(paramType.name().toString());
            targetParamHandles.add(methodCreator.invokeVirtualMethod(CHAT_MEMORY_SEEDER_CONTEXT_METHOD_NAME,
                    methodCreator.getMethodParam(0)));
        }

        ResultHandle resultHandle = methodCreator.invokeStaticInterfaceMethod(
                MethodDescriptor.ofMethod(
                        seedMemoryTargetMethod.declaringClass().name().toString(),
                        seedMemoryTargetMethod.name(),
                        returnType.name().toString(),
                        targetParamTypeNames.toArray(EMPTY_STRING_ARRAY)),
                targetParamHandles.toArray(EMPTY_RESULT_HANDLES_ARRAY));
        methodCreator.returnValue(resultHandle);
    }
    return implClassName;
}
/**
 * Returns the default span name {@code langchain4j.aiservices.<InterfaceSimpleName>.<methodName>}.
 */
private String defaultAiServiceSpanName(MethodInfo method) {
    String serviceSimpleName = method.declaringClass().name().withoutPackagePrefix();
    return String.join(".", "langchain4j.aiservices", serviceSimpleName, method.name());
}
private record TemplateParameterInfo(int position, String name) {
static Map toNameToArgsPositionMap(List list) {
return list.stream()
.collect(Collectors.toMap(TemplateParameterInfo::name, TemplateParameterInfo::position));
}
}
/**
 * Build item produced once for every AiService method for which an implementation is generated.
 */
public static final class AiServicesMethodBuildItem extends MultiBuildItem {

    /** The implemented AiService interface method. */
    private final MethodInfo methodInfo;

    public AiServicesMethodBuildItem(MethodInfo methodInfo) {
        this.methodInfo = methodInfo;
    }

    /**
     * @return the Jandex representation of the implemented AiService method
     */
    public MethodInfo getMethodInfo() {
        return this.methodInfo;
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy