All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.quarkiverse.langchain4j.huggingface.runtime.HuggingFaceRecorder Maven / Gradle / Ivy

There is a newer version: 0.23.0.CR1
Show newest version
package io.quarkiverse.langchain4j.huggingface.runtime;

import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;

import java.net.URL;
import java.time.Duration;
import java.util.function.Supplier;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.DisabledChatLanguageModel;
import dev.langchain4j.model.embedding.DisabledEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel;
import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceEmbeddingModel;
import io.quarkiverse.langchain4j.huggingface.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.huggingface.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.huggingface.runtime.config.LangChain4jHuggingFaceConfig;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;

@Recorder
public class HuggingFaceRecorder {

    private static final String DUMMY_KEY = "dummy";
    private static final String HUGGING_FACE_URL_MARKER = "api-inference.huggingface.co";

    public Supplier chatModel(LangChain4jHuggingFaceConfig runtimeConfig, String configName) {
        LangChain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig = correspondingHuggingFaceConfig(runtimeConfig,
                configName);

        if (huggingFaceConfig.enableIntegration()) {
            String apiKey = huggingFaceConfig.apiKey();
            ChatModelConfig chatModelConfig = huggingFaceConfig.chatModel();
            URL url = chatModelConfig.inferenceEndpointUrl();

            if (DUMMY_KEY.equals(apiKey) && url.toExternalForm().contains(HUGGING_FACE_URL_MARKER)) { // when using the default base URL an API key is required
                throw new ConfigValidationException(createApiKeyConfigProblem(configName));
            }

            var builder = QuarkusHuggingFaceChatModel.builder()
                    .url(url)
                    .timeout(huggingFaceConfig.timeout().orElse(Duration.ofSeconds(10)))
                    .temperature(chatModelConfig.temperature())
                    .waitForModel(chatModelConfig.waitForModel())
                    .doSample(chatModelConfig.doSample())
                    .topP(chatModelConfig.topP())
                    .topK(chatModelConfig.topK())
                    .repetitionPenalty(chatModelConfig.repetitionPenalty())
                    .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), huggingFaceConfig.logRequests()))
                    .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), huggingFaceConfig.logResponses()))
                    .returnFullText(chatModelConfig.returnFullText());

            if (!DUMMY_KEY.equals(apiKey)) {
                builder.accessToken(apiKey);
            }

            if (chatModelConfig.maxNewTokens().isPresent()) {
                builder.maxNewTokens(chatModelConfig.maxNewTokens().get());
            }

            return new Supplier<>() {
                @Override
                public ChatLanguageModel get() {
                    return builder.build();
                }
            };
        } else {
            return new Supplier<>() {
                @Override
                public ChatLanguageModel get() {
                    return new DisabledChatLanguageModel();
                }
            };
        }
    }

    public Supplier embeddingModel(LangChain4jHuggingFaceConfig runtimeConfig, String configName) {
        LangChain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig = correspondingHuggingFaceConfig(runtimeConfig,
                configName);

        if (huggingFaceConfig.enableIntegration()) {
            String apiKey = huggingFaceConfig.apiKey();
            EmbeddingModelConfig embeddingModelConfig = huggingFaceConfig.embeddingModel();
            URL url = embeddingModelConfig.inferenceEndpointUrl();

            if (DUMMY_KEY.equals(apiKey) && url.toExternalForm().contains(HUGGING_FACE_URL_MARKER)) { // when using the default base URL an API key is required
                throw new ConfigValidationException(createApiKeyConfigProblem(configName));
            }

            var builder = QuarkusHuggingFaceEmbeddingModel.builder()
                    .url(url)
                    .timeout(huggingFaceConfig.timeout().orElse(Duration.ofSeconds(10)))
                    .waitForModel(embeddingModelConfig.waitForModel());

            if (!DUMMY_KEY.equals(apiKey)) {
                builder.accessToken(apiKey);
            }

            return new Supplier<>() {
                @Override
                public EmbeddingModel get() {
                    return builder.build();
                }
            };
        } else {
            return new Supplier<>() {
                @Override
                public EmbeddingModel get() {
                    return new DisabledEmbeddingModel();
                }
            };
        }
    }

    private LangChain4jHuggingFaceConfig.HuggingFaceConfig correspondingHuggingFaceConfig(
            LangChain4jHuggingFaceConfig runtimeConfig, String configName) {
        LangChain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig;
        if (NamedConfigUtil.isDefault(configName)) {
            huggingFaceConfig = runtimeConfig.defaultConfig();
        } else {
            huggingFaceConfig = runtimeConfig.namedConfig().get(configName);
        }
        return huggingFaceConfig;
    }

    private ConfigValidationException.Problem[] createApiKeyConfigProblem(String configName) {
        return createConfigProblems("api-key", configName);
    }

    private ConfigValidationException.Problem[] createConfigProblems(String key, String configName) {
        return new ConfigValidationException.Problem[] { createConfigProblem(key, configName) };
    }

    private static ConfigValidationException.Problem createConfigProblem(String key, String configName) {
        return new ConfigValidationException.Problem(String.format(
                "SRCFG00014: The config property quarkus.langchain4j.huggingface%s%s is required but it could not be found in any config source",
                NamedConfigUtil.isDefault(configName) ? "." : ("." + configName + "."), key));
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy