All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel Maven / Gradle / Ivy

There is a newer version: 0.23.0.CR1
Show newest version
package io.quarkiverse.langchain4j.huggingface;

import static java.util.stream.Collectors.joining;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalInt;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.client.Options;
import dev.langchain4j.model.huggingface.client.Parameters;
import dev.langchain4j.model.huggingface.client.TextGenerationRequest;
import dev.langchain4j.model.huggingface.client.TextGenerationResponse;
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.huggingface.runtime.config.ChatModelConfig;

/**
 * This is a Quarkus specific version of the HuggingFace model.
 * 

* TODO: remove this in the future when the stock {@link dev.langchain4j.model.huggingface.HuggingFaceChatModel} * has been updated to fit our needs (i.e. allowing {@code returnFullText} to be null and making {code accessToken} optional) */ public class QuarkusHuggingFaceChatModel implements ChatLanguageModel { public static final QuarkusHuggingFaceClientFactory CLIENT_FACTORY = new QuarkusHuggingFaceClientFactory(); private final HuggingFaceClient client; private final Double temperature; private final Integer maxNewTokens; private final Boolean returnFullText; private final Boolean waitForModel; private final Optional doSample; private final OptionalDouble topP; private final OptionalInt topK; private final OptionalDouble repetitionPenalty; private QuarkusHuggingFaceChatModel(Builder builder) { this.client = CLIENT_FACTORY.create(builder, new HuggingFaceClientFactory.Input() { @Override public String apiKey() { return builder.accessToken; } @Override public String modelId() { throw new UnsupportedOperationException("Should not be called"); } @Override public Duration timeout() { return builder.timeout; } }, builder.url); this.temperature = builder.temperature; this.maxNewTokens = builder.maxNewTokens; this.returnFullText = builder.returnFullText; this.waitForModel = builder.waitForModel; this.doSample = builder.doSample; this.topP = builder.topP; this.topK = builder.topK; this.repetitionPenalty = builder.repetitionPenalty; } public static Builder builder() { return new Builder(); } @Override public Response generate(List messages) { Parameters.Builder builder = Parameters.builder() .temperature(temperature) .maxNewTokens(maxNewTokens) .returnFullText(returnFullText); doSample.ifPresent(builder::doSample); topK.ifPresent(builder::topK); topP.ifPresent(builder::topP); repetitionPenalty.ifPresent(builder::repetitionPenalty); Parameters parameters = builder .build(); TextGenerationRequest request = TextGenerationRequest.builder() .inputs(messages.stream() .map(ChatMessage::text) .collect(joining("\n"))) .parameters(parameters) .options(Options.builder() .waitForModel(waitForModel) .build()) .build(); TextGenerationResponse textGenerationResponse = client.chat(request); return Response.from(AiMessage.from(textGenerationResponse.getGeneratedText())); } @Override public Response generate(List messages, List toolSpecifications) { throw new IllegalArgumentException("Tools are currently not supported for HuggingFace models"); } @Override public Response generate(List messages, ToolSpecification toolSpecification) { throw new IllegalArgumentException("Tools are currently not supported for HuggingFace models"); } public static final class Builder { private String accessToken; private Duration timeout = Duration.ofSeconds(15); private Double temperature; private Integer maxNewTokens; private Boolean returnFullText; private Boolean waitForModel = true; private URI url = URI.create(ChatModelConfig.DEFAULT_INFERENCE_ENDPOINT); private Optional doSample; private OptionalInt topK; private OptionalDouble topP; private OptionalDouble repetitionPenalty; public boolean logResponses; public boolean logRequests; public Builder accessToken(String accessToken) { this.accessToken = accessToken; return this; } public Builder url(URL url) { try { this.url = url.toURI(); } catch (URISyntaxException e) { throw new RuntimeException(e); } return this; } public Builder timeout(Duration timeout) { this.timeout = timeout; return this; } public Builder temperature(Double temperature) { this.temperature = temperature; return this; } public Builder maxNewTokens(Integer maxNewTokens) { this.maxNewTokens = maxNewTokens; return this; } public Builder returnFullText(Boolean returnFullText) { this.returnFullText = returnFullText; return this; } public Builder waitForModel(Boolean waitForModel) { this.waitForModel = waitForModel; return this; } public Builder doSample(Optional doSample) { this.doSample = doSample; return this; } public Builder topK(OptionalInt topK) { this.topK = topK; return this; } public Builder topP(OptionalDouble topP) { this.topP = topP; return this; } public Builder repetitionPenalty(OptionalDouble repetitionPenalty) { this.repetitionPenalty = repetitionPenalty; return this; } public QuarkusHuggingFaceChatModel build() { return new QuarkusHuggingFaceChatModel(this); } public Builder logRequests(boolean logRequests) { this.logRequests = logRequests; return this; } public Builder logResponses(boolean logResponses) { this.logResponses = logResponses; return this; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy