All downloads are free. Search and download functionality uses the official Maven repository.

com.github.tjake.jlama.cli.commands.ChatCommand Maven / Gradle / Ivy

There is a newer version: 0.8.2
Show newest version
/*
 * Copyright 2024 T Jake Luciani
 *
 * The Jlama Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.github.tjake.jlama.cli.commands;

import static com.github.tjake.jlama.model.ModelSupport.loadModel;

import com.diogonunes.jcolor.AnsiFormat;
import com.diogonunes.jcolor.Attribute;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;

import java.io.PrintWriter;
import java.nio.file.Path;
import java.util.Optional;
import java.util.Scanner;
import java.util.UUID;
import java.util.function.BiConsumer;

import picocli.CommandLine.*;

/**
 * CLI command that runs an interactive chat loop against a loaded model.
 *
 * <p>Each line read from standard input is sent to the model as a user message and the generated
 * text is streamed back to the terminal. The loop ends when an empty line is entered or when
 * standard input reaches end-of-file.
 */
@Command(name = "chat", description = "Interact with the specified model", abbreviateSynopsis = true)
public class ChatCommand extends ModelBaseCommand {
    private static final AnsiFormat chatText = new AnsiFormat(Attribute.CYAN_TEXT());
    private static final AnsiFormat statsColor = new AnsiFormat(Attribute.BLUE_TEXT());

    @Option(names = { "--system-prompt" }, paramLabel = "ARG", description = "Change the default system prompt for this model")
    String systemPrompt = null;

    /**
     * Resolves and loads the model, then runs the read/generate/print loop.
     *
     * <p>Exits the JVM with status 1 if the loaded model reports no prompt support. The same
     * randomly generated session id is passed to every {@code generate} call of one invocation.
     * The system prompt, when supplied, is added only to the first prompt that is built.
     */
    @Override
    public void run() {
        Path modelPath = SimpleBaseCommand.getModel(
            modelName,
            modelDirectory,
            downloadSection.autoDownload,
            downloadSection.branch,
            downloadSection.authToken
        );
        AbstractModel m = loadModel(
            modelPath.toFile(),
            workingDirectory,
            advancedSection.workingMemoryType,
            advancedSection.workingQuantizationType,
            Optional.ofNullable(advancedSection.modelQuantization),
            Optional.ofNullable(advancedSection.threadCount)
        );

        if (m.promptSupport().isEmpty()) {
            System.err.println("This model does not support chat prompting");
            System.exit(1);
        }

        UUID session = UUID.randomUUID();
        PromptSupport promptSupport = m.promptSupport().get();
        PrintWriter out = consoleWriter();

        out.println("\nChatting with " + modelName + "...\n");
        out.flush();
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner sc = new Scanner(System.in);
        boolean first = true;
        while (true) {
            out.print("\nYou: ");
            out.flush();
            // End of input (Ctrl-D, exhausted pipe) ends the chat instead of letting
            // nextLine() throw NoSuchElementException.
            if (!sc.hasNextLine()) {
                out.println();
                out.flush();
                break;
            }
            String prompt = sc.nextLine();
            out.println();
            out.flush();
            if (prompt.isEmpty()) {
                break;
            }

            PromptSupport.Builder builder = promptSupport.builder();
            if (first && systemPrompt != null) {
                builder.addSystemMessage(systemPrompt);
            }
            builder.addUserMessage(prompt);
            PromptContext builtPrompt = builder.build();

            Generator.Response r = m.generate(session, builtPrompt, temperature, Integer.MAX_VALUE, makeOutHandler());

            out.println(
                "\n\n"
                    + statsColor.format(
                        msPerToken(r.promptTimeMs, r.promptTokens)
                            + " ms/tok (prompt), "
                            + msPerToken(r.generateTimeMs, r.generatedTokens)
                            + " ms/tok (gen)"
                    )
            );
            out.flush();

            first = false;
        }
    }

    /**
     * Prints the {@code "Jlama: "} prefix and returns the callback that streams generated text
     * to the terminal.
     *
     * <p>NOTE(review): the return type is raw in this file; generic parameters were presumably
     * lost somewhere. Confirm against the callback type expected by {@code generate} before
     * parameterizing the signature.
     *
     * @return a callback that prints its first argument in the chat colour and flushes; the
     *         second argument is ignored
     */
    protected BiConsumer makeOutHandler() {
        PrintWriter out = consoleWriter();
        out.print(chatText.format("Jlama: "));
        out.flush();

        // Typed locally so the text argument is a String; the second argument is unused, so any
        // Object is accepted.
        BiConsumer<String, Object> outCallback = (w, t) -> {
            out.print(chatText.format(w));
            out.flush();
        };

        return outCallback;
    }

    /**
     * Returns a writer for user-facing output.
     *
     * <p>{@link System#console()} returns {@code null} when the JVM has no interactive console
     * (for example when input or output is redirected), so fall back to standard output rather
     * than failing with a NullPointerException.
     */
    private static PrintWriter consoleWriter() {
        java.io.Console console = System.console();
        return console != null ? console.writer() : new PrintWriter(System.out, true);
    }

    /**
     * Computes the rounded average milliseconds per token.
     *
     * @param elapsedMs total elapsed time in milliseconds
     * @param tokens number of tokens processed in that time
     * @return the rounded average, or 0 when no tokens were processed (avoids rounding the
     *         Infinity/NaN produced by a division by zero)
     */
    private static long msPerToken(double elapsedMs, double tokens) {
        return tokens > 0 ? Math.round(elapsedMs / tokens) : 0L;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy