All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.tjake.jlama.cli.commands.SimpleBaseCommand Maven / Gradle / Ivy

There is a newer version: 0.8.2
Show newest version
/*
 * Copyright 2024 T Jake Luciani
 *
 * The Jlama Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.github.tjake.jlama.cli.commands;

import com.github.tjake.jlama.cli.JlamaCli;
import java.io.File;
import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.util.TriConsumer;
import com.google.common.util.concurrent.Uninterruptibles;
import me.tongfei.progressbar.ProgressBar;
import me.tongfei.progressbar.ProgressBarBuilder;
import me.tongfei.progressbar.ProgressBarStyle;
import picocli.CommandLine;

public class SimpleBaseCommand extends JlamaCli {
    static AtomicReference progressRef = new AtomicReference<>();

    @CommandLine.ArgGroup(exclusive = false, heading = "Download Options:%n")
    protected DownloadSection downloadSection = new DownloadSection();

    @CommandLine.Option(names = {
        "--model-cache" }, paramLabel = "ARG", description = "The local directory for downloaded models (default: ${DEFAULT-VALUE})", defaultValue = "models")
    protected File modelDirectory = new File("models");

    @CommandLine.Parameters(index = "0", arity = "1", paramLabel = "", description = "The huggingface model owner/name pair")
    protected String modelName;

    static class DownloadSection {
        @CommandLine.Option(names = {
            "--auto-download" }, paramLabel = "ARG", description = "Download the model if missing (default: ${DEFAULT-VALUE})", defaultValue = "false")
        Boolean autoDownload = false;

        @CommandLine.Option(names = {
            "--branch" }, paramLabel = "ARG", description = "The model branch to download from (default: ${DEFAULT-VALUE})", defaultValue = "main")
        String branch = "main";

        @CommandLine.Option(names = { "--auth-token" }, paramLabel = "ARG", description = "HuggingFace auth token (for restricted models)")
        String authToken = null;
    }

    static String getOwner(String modelName) {
        String[] parts = modelName.split("/");
        if (parts.length == 0 || parts.length > 2) {
            System.err.println("Model name must be in the form owner/name");
            System.exit(1);
        }
        return parts[0];
    }

    static String getName(String modelName) {
        String[] parts = modelName.split("/");
        if (parts.length == 0 || parts.length > 2) {
            System.err.println("Model name must be in the form owner/name");
            System.exit(1);
        }
        return parts[1];
    }

    static Optional> getProgressConsumer() {
        if (System.console() == null) return Optional.empty();

        return Optional.of((n, c, t) -> {
            if (progressRef.get() == null || !progressRef.get().getTaskName().equals(n)) {
                ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(n).setInitialMax(t).setStyle(ProgressBarStyle.ASCII);

                if (t > 1000000) {
                    builder.setUnit("MB", 1000000);
                } else if (t > 1000) {
                    builder.setUnit("KB", 1000);
                } else {
                    builder.setUnit("B", 1);
                }

                progressRef.set(builder.build());
            }

            progressRef.get().stepTo(c);
            Uninterruptibles.sleepUninterruptibly(150, TimeUnit.MILLISECONDS);
        });
    }

    static void downloadModel(String owner, String name, File modelDirectory, String branch, String authToken, boolean downloadWeights) {
        try {
            SafeTensorSupport.maybeDownloadModel(
                modelDirectory.getAbsolutePath(),
                Optional.ofNullable(owner),
                name,
                downloadWeights,
                Optional.ofNullable(URLEncoder.encode(branch)),
                Optional.ofNullable(authToken),
                getProgressConsumer()
            );
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    static Path getModel(String modelName, File modelDirectory, boolean autoDownload, String branch, String authToken) {
        return getModel(modelName, modelDirectory, autoDownload, branch, authToken, true);
    }

    static Path getModel(
        String modelName,
        File modelDirectory,
        boolean autoDownload,
        String branch,
        String authToken,
        boolean downloadWeights
    ) {
        String owner = getOwner(modelName);
        String name = getName(modelName);

        Path modelPath = SafeTensorSupport.constructLocalModelPath(modelDirectory.getAbsolutePath(), owner, name);

        if (autoDownload) {
            downloadModel(owner, name, modelDirectory, branch, authToken, downloadWeights);
        } else if (!modelPath.toFile().exists()) {
            System.err.println("Model not found: " + modelPath);
            System.err.println("Use --auto-download to download the model");
            System.exit(1);
        }

        return modelPath;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy