All downloads are free. Search and download functionality uses the official Maven repository.

io.quarkiverse.langchain4j.llama3.Llama3ChatModel Maven / Gradle / Ivy

The newest version!
package io.quarkiverse.langchain4j.llama3;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import org.jboss.logging.Logger;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.llama3.copy.ChatFormat;
import io.quarkiverse.langchain4j.llama3.copy.Llama;
import io.quarkiverse.langchain4j.llama3.copy.Llama3;
import io.quarkiverse.langchain4j.llama3.copy.Sampler;

public class Llama3ChatModel implements ChatLanguageModel {

    private static final Logger log = Logger.getLogger(Llama3ChatModel.class);

    private final Path modelPath;
    private final Llama model;
    private final Float temperature;
    private final Integer maxTokens;
    private final Float topP;
    private final Integer seed;
    private final boolean logRequests;
    private final boolean logResponses;

    public Llama3ChatModel(Builder builder) {
        Llama3ModelRegistry llama3ModelRegistry = Llama3ModelRegistry.getOrCreate(builder.modelCachePath);
        try {
            modelPath = llama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
                    Optional.ofNullable(builder.authToken), Optional.empty());
            model = llama3ModelRegistry.loadModel(builder.modelName, builder.quantization, builder.maxTokens, true);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        temperature = builder.temperature;
        maxTokens = builder.maxTokens;
        topP = builder.topP;
        seed = builder.seed;
        logRequests = builder.logRequests;
        logResponses = builder.logResponses;
    }

    @Override
    public Response generate(List messages) {

        if (logRequests) {
            log.info("Request: " + messages);
        }

        List llama3Messages = new ArrayList<>();
        for (ChatMessage message : messages) {
            llama3Messages.add(toLlama3Message(message));
        }

        Llama3.Options options = new Llama3.Options(
                modelPath,
                "", // unused
                "", // unused
                false,
                temperature,
                topP,
                seed,
                maxTokens,
                false, // stream
                false // echo
        );
        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
                options.seed());
        InferenceResponse inferenceResponse = runInference(model, sampler, options, llama3Messages);

        var response = Response.from(aiMessage(inferenceResponse.text()),
                new TokenUsage(inferenceResponse.promptTokens(), inferenceResponse.responseTokens()));

        if (logResponses) {
            log.info("Response: " + response);
        }

        return response;
    }

    private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options,
            List messages) {
        Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
        ChatFormat chatFormat = new ChatFormat(model.tokenizer());

        List promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));

        Set stopTokens = chatFormat.getStopTokens();
        List responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
                sampler, options.echo(), token -> {
                    if (options.stream()) {
                        if (!model.tokenizer().isSpecialToken(token)) {
                            System.out.print(model.tokenizer().decode(List.of(token)));
                        }
                    }
                });
        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
            responseTokens.removeLast();
        }

        return new InferenceResponse(model.tokenizer().decode(responseTokens), promptTokens.size(), responseTokens.size());
    }

    record InferenceResponse(String text, int promptTokens, int responseTokens) {

    }

    public static Builder builder() {
        return new Builder();
    }

    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
    public static class Builder {

        private Optional modelCachePath;
        private String modelName = Consts.DEFAULT_CHAT_MODEL_NAME;
        private String quantization = Consts.DEFAULT_CHAT_MODEL_QUANTIZATION;
        private String authToken;
        private Integer maxTokens = 4_000;
        private Float temperature = 0.7f;
        private Float topP = 0.95f;
        private Integer seed = 17;
        private boolean logRequests;
        private boolean logResponses;

        public Builder modelCachePath(Optional modelCachePath) {
            this.modelCachePath = modelCachePath;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder quantization(String quantization) {
            this.quantization = quantization;
            return this;
        }

        public Builder authToken(String authToken) {
            this.authToken = authToken;
            return this;
        }

        public Builder temperature(Float temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder maxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public Builder topP(Float topP) {
            this.topP = topP;
            return this;
        }

        public Builder seed(Integer seed) {
            this.seed = seed;
            return this;
        }

        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public Llama3ChatModel build() {
            return new Llama3ChatModel(this);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy