All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.extension.ai.deployment.AIDependencyProcessor Maven / Gradle / Ivy

There is a newer version: 0.1.0
Show newest version
/*
 * Copyright The WildFly Authors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.wildfly.extension.ai.deployment;

import static org.wildfly.extension.ai.AILogger.ROOT_LOGGER;
import static org.wildfly.extension.ai.Capabilities.CHAT_MODEL_PROVIDER_CAPABILITY;
import static org.wildfly.extension.ai.Capabilities.EMBEDDING_MODEL_PROVIDER_CAPABILITY;
import static org.wildfly.extension.ai.Capabilities.EMBEDDING_STORE_PROVIDER_CAPABILITY;

import jakarta.inject.Named;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.jboss.as.server.deployment.Attachments;
import org.jboss.as.server.deployment.DeploymentPhaseContext;
import org.jboss.as.server.deployment.DeploymentUnit;
import org.jboss.as.server.deployment.DeploymentUnitProcessingException;
import org.jboss.as.server.deployment.DeploymentUnitProcessor;
import org.jboss.as.server.deployment.annotation.CompositeIndex;
import org.jboss.as.server.deployment.module.ModuleDependency;
import org.jboss.as.server.deployment.module.ModuleSpecification;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.DotName;
import org.jboss.jandex.FieldInfo;
import org.jboss.jandex.Type;
import org.jboss.modules.Module;
import org.jboss.modules.ModuleLoader;
import org.wildfly.extension.ai.Capabilities;

/**
 *
 * @author Emmanuel Hugonnet (c) 2024 Red Hat, Inc.
 */
public class AIDependencyProcessor implements DeploymentUnitProcessor {

    public static final String[] OPTIONAL_MODULES = {
        "dev.langchain4j.openai",
        "dev.langchain4j.ollama",
        "dev.langchain4j.weaviate",
        "dev.langchain4j.web-search-engines"
    };

    public static final String[] EXPORTED_MODULES = {
        "dev.langchain4j",
        "io.smallrye.llm",
        "org.wildfly.extension.ai.injection"
    };

    @Override
    public void deploy(DeploymentPhaseContext deploymentPhaseContext) throws DeploymentUnitProcessingException {
        DeploymentUnit deploymentUnit = deploymentPhaseContext.getDeploymentUnit();
        ModuleSpecification moduleSpecification = deploymentUnit.getAttachment(Attachments.MODULE_SPECIFICATION);
        ModuleLoader moduleLoader = Module.getBootModuleLoader();
        for (String module : OPTIONAL_MODULES) {
            moduleSpecification.addSystemDependency(new ModuleDependency(moduleLoader, module, true, false, true, false));
        }
        for (String module : EXPORTED_MODULES) {
            ModuleDependency modDep = new ModuleDependency(moduleLoader, module, false, true, true, false);
            modDep.addImportFilter(s -> s.equals("META-INF"), true);
            moduleSpecification.addSystemDependency(modDep);
        }
        final CompositeIndex index = deploymentUnit.getAttachment(Attachments.COMPOSITE_ANNOTATION_INDEX);
        if (index == null) {
            throw ROOT_LOGGER.unableToResolveAnnotationIndex(deploymentUnit);
        }
        List annotations = index.getAnnotations(DotName.createSimple(Named.class));
        if (annotations == null || annotations.isEmpty()) {
            return;
        }
        Set requiredChatModels = new HashSet<>();
        Set requiredEmbeddingModels = new HashSet<>();
        Set requiredEmbeddingStores = new HashSet<>();
        Set requiredContentRetrievers = new HashSet<>();
        for (AnnotationInstance annotation : annotations) {
            if (annotation.target().kind() == AnnotationTarget.Kind.FIELD) {
                FieldInfo field = annotation.target().asField();
                if (field.type().kind() == Type.Kind.CLASS) {
                    try {
                        Class fieldClass = Class.forName(field.type().asClassType().name().toString());
                        if (dev.langchain4j.model.chat.ChatLanguageModel.class.isAssignableFrom(fieldClass)) {
                            ROOT_LOGGER.debug("We need the ChatLanguageModel in the class " + field.declaringClass());
                            String chatLanguageModelName = annotation.value().asString();
                            ROOT_LOGGER.debug("We need the ChatLanguageModel called " + chatLanguageModelName);
                            requiredChatModels.add(chatLanguageModelName);
                        } else if (dev.langchain4j.model.embedding.EmbeddingModel.class.isAssignableFrom(fieldClass)) {
                            ROOT_LOGGER.debug("We need the EmbeddingModel in the class " + field.declaringClass());
                            String embeddingModelName = annotation.value().asString();
                            ROOT_LOGGER.debug("We need the EmbeddingModel called " + embeddingModelName);
                            requiredEmbeddingModels.add(embeddingModelName);
                        } else if (dev.langchain4j.store.embedding.EmbeddingStore.class.isAssignableFrom(fieldClass)) {
                            ROOT_LOGGER.debug("We need the EmbeddingStore in the class " + field.declaringClass());
                            String embeddingStoreName = annotation.value().asString();
                            ROOT_LOGGER.debug("We need the EmbeddingStore called " + embeddingStoreName);
                            requiredEmbeddingStores.add(embeddingStoreName);
                        }else if (dev.langchain4j.rag.content.retriever.ContentRetriever.class.isAssignableFrom(fieldClass)) {
                            ROOT_LOGGER.debug("We need the ContentRetriever in the class " + field.declaringClass());
                            String contentRetrieverName = annotation.value().asString();
                            ROOT_LOGGER.debug("We need the ContentRetriever called " + contentRetrieverName);
                            requiredContentRetrievers.add(contentRetrieverName);
                        }
                    } catch (ClassNotFoundException ex) {
                        ROOT_LOGGER.error("Coudln't get the class type for " + field.type().asClassType().name().toString() + " to be able to check what to inject", ex);
                    }
                }
            }
        }
        if (!requiredChatModels.isEmpty() || !requiredEmbeddingModels.isEmpty() || !requiredEmbeddingStores.isEmpty()) {
            if (!requiredChatModels.isEmpty()) {
                for (String chatLanguageModelName : requiredChatModels) {
                    deploymentUnit.addToAttachmentList(AIAttachements.CHAT_MODEL_KEYS, chatLanguageModelName);
                    deploymentPhaseContext.addDeploymentDependency(CHAT_MODEL_PROVIDER_CAPABILITY.getCapabilityServiceName(chatLanguageModelName), AIAttachements.CHAT_MODELS);
                }
            }
            if (!requiredEmbeddingModels.isEmpty()) {
                for (String embeddingModelName : requiredEmbeddingModels) {
                    deploymentUnit.addToAttachmentList(AIAttachements.EMBEDDING_MODEL_KEYS, embeddingModelName);
                    deploymentPhaseContext.addDeploymentDependency(EMBEDDING_MODEL_PROVIDER_CAPABILITY.getCapabilityServiceName(embeddingModelName), AIAttachements.EMBEDDING_MODELS);
                }
            }
            if (!requiredEmbeddingStores.isEmpty()) {
                for (String embeddingStoreName : requiredEmbeddingStores) {
                    deploymentUnit.addToAttachmentList(AIAttachements.EMBEDDING_STORE_KEYS, embeddingStoreName);
                    deploymentPhaseContext.addDeploymentDependency(EMBEDDING_STORE_PROVIDER_CAPABILITY.getCapabilityServiceName(embeddingStoreName), AIAttachements.EMBEDDING_STORES);
                }
            }
            if (!requiredContentRetrievers.isEmpty()) {
                for (String contentRetrieverName : requiredContentRetrievers) {
                    deploymentUnit.addToAttachmentList(AIAttachements.CONTENT_RETRIEVER_KEYS, contentRetrieverName);
                    deploymentPhaseContext.addDeploymentDependency(Capabilities.CONTENT_RETRIEVER_PROVIDER_CAPABILITY.getCapabilityServiceName(contentRetrieverName), AIAttachements.CONTENT_RETRIEVERS);
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy