All downloads are free. Search and download functionality uses the official Maven repository.

org.wildfly.extension.ai.chat.OpenAIChatModelProviderServiceConfigurator Maven / Gradle / Ivy

There is a newer version: 0.1.0
Show newest version
/*
 * Copyright The WildFly Authors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.wildfly.extension.ai.chat;

import static org.wildfly.extension.ai.AIAttributeDefinitions.API_KEY;
import static org.wildfly.extension.ai.AIAttributeDefinitions.BASE_URL;
import static org.wildfly.extension.ai.AIAttributeDefinitions.CONNECT_TIMEOUT;
import static org.wildfly.extension.ai.AIAttributeDefinitions.LOG_REQUESTS;
import static org.wildfly.extension.ai.AIAttributeDefinitions.LOG_RESPONSES;
import static org.wildfly.extension.ai.AIAttributeDefinitions.MAX_TOKEN;
import static org.wildfly.extension.ai.AIAttributeDefinitions.MODEL_NAME;
import static org.wildfly.extension.ai.AIAttributeDefinitions.RESPONSE_FORMAT;
import static org.wildfly.extension.ai.AIAttributeDefinitions.TEMPERATURE;
import static org.wildfly.extension.ai.AIAttributeDefinitions.TOP_P;
import static org.wildfly.extension.ai.chat.OpenAIChatLanguageModelProviderRegistrar.FREQUENCY_PENALTY;
import static org.wildfly.extension.ai.chat.OpenAIChatLanguageModelProviderRegistrar.ORGANIZATION_ID;
import static org.wildfly.extension.ai.chat.OpenAIChatLanguageModelProviderRegistrar.PRESENCE_PENALTY;
import static org.wildfly.extension.ai.chat.OpenAIChatLanguageModelProviderRegistrar.SEED;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.time.Duration;
import java.util.function.Supplier;
import org.jboss.as.controller.OperationContext;
import org.jboss.as.controller.OperationFailedException;
import org.jboss.dmr.ModelNode;
import org.wildfly.extension.ai.AIAttributeDefinitions;
import org.wildfly.service.capture.ValueRegistry;
import org.wildfly.subsystem.service.ResourceServiceInstaller;

/**
 * Configures the service that supplies an OpenAI-backed {@link ChatLanguageModel}.
 * <p>
 * All resource attributes are resolved eagerly inside {@code configure}; the
 * {@link OpenAiChatModel} itself is only built when the installed supplier is invoked.
 */
public class OpenAIChatModelProviderServiceConfigurator extends AbstractChatModelProviderServiceConfigurator {

    /** Retry count applied to every chat model built by this configurator. */
    private static final int MAX_RETRIES = 5;

    /** Response format handed to the builder when the resource asks for JSON responses. */
    private static final String JSON_RESPONSE_FORMAT = "json_object";

    /**
     * Creates the configurator.
     *
     * @param registry passed unchanged to the parent configurator
     */
    public OpenAIChatModelProviderServiceConfigurator(ValueRegistry registry) {
        super(registry);
    }

    /**
     * Resolves the OpenAI chat model attributes of the current resource and returns the installer
     * for a service whose value is built from them.
     *
     * @param context the operation context used to resolve attribute values and to obtain the
     *                current address value, which is passed to {@code installService}
     * @param model   the resource model holding the attribute values
     * @return the installer produced by {@code installService}
     * @throws OperationFailedException if resolving an attribute fails
     */
    @Override
    public ResourceServiceInstaller configure(OperationContext context, ModelNode model) throws OperationFailedException {
        // NOTE(review): asString() renders an undefined node as the text "undefined"; this is only
        // safe for base-url, api-key and model-name if they are required or have defaults - confirm
        // against their attribute definitions.
        String baseUrl = BASE_URL.resolveModelAttribute(context, model).asString();
        String key = API_KEY.resolveModelAttribute(context, model).asString();
        String modelName = MODEL_NAME.resolveModelAttribute(context, model).asString();
        // Read as nullable like the other registrar-specific attributes, so that an unset
        // organization id reaches the builder as null instead of the literal "undefined".
        String organizationId = ORGANIZATION_ID.resolveModelAttribute(context, model).asStringOrNull();
        // asLong() already yields a primitive; keeping it unboxed avoids a pointless box/unbox pair.
        long connectTimeoutMillis = CONNECT_TIMEOUT.resolveModelAttribute(context, model).asLong();
        Double frequencyPenalty = FREQUENCY_PENALTY.resolveModelAttribute(context, model).asDoubleOrNull();
        Integer maxToken = MAX_TOKEN.resolveModelAttribute(context, model).asIntOrNull();
        Double presencePenalty = PRESENCE_PENALTY.resolveModelAttribute(context, model).asDoubleOrNull();
        Boolean logRequests = LOG_REQUESTS.resolveModelAttribute(context, model).asBooleanOrNull();
        Boolean logResponses = LOG_RESPONSES.resolveModelAttribute(context, model).asBooleanOrNull();
        Integer seed = SEED.resolveModelAttribute(context, model).asIntOrNull();
        Double temperature = TEMPERATURE.resolveModelAttribute(context, model).asDoubleOrNull();
        Double topP = TOP_P.resolveModelAttribute(context, model).asDoubleOrNull();
        boolean isJson = AIAttributeDefinitions.ResponseFormat.isJson(RESPONSE_FORMAT.resolveModelAttribute(context, model).asStringOrNull());

        // Typed supplier (the original used a raw Supplier); the model is built on each get().
        Supplier<ChatLanguageModel> factory = () -> {
            OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder()
                    .apiKey(key)
                    .baseUrl(baseUrl)
                    .frequencyPenalty(frequencyPenalty)
                    .logRequests(logRequests)
                    .logResponses(logResponses)
                    .maxRetries(MAX_RETRIES)
                    .maxTokens(maxToken)
                    .modelName(modelName)
                    .organizationId(organizationId)
                    .presencePenalty(presencePenalty)
                    .seed(seed)
                    .temperature(temperature)
                    .timeout(Duration.ofMillis(connectTimeoutMillis))
                    .topP(topP);
            if (isJson) {
                builder.responseFormat(JSON_RESPONSE_FORMAT);
            }
            return builder.build();
        };
        return installService(context.getCurrentAddressValue(), factory);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy