All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.quarkiverse.langchain4j.llama3.copy.Llama3 Maven / Gradle / Ivy

///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector
//MAIN io.quarkiverse.langchain4j.llama3.copy.Llama3

// Practical Llama 3 (and 3.1) inference in a single Java file
// Author: Alfonso² Peterssen
// Based on Andrej Karpathy's llama2.c and minbpe projects
//
// Supports llama.cpp's GGUF format, restricted to Q4_0 and Q8_0 quantized models
// Multi-threaded matrix vector multiplication routines implemented using Java's Vector API
// Simple CLI with --chat and --instruct mode
//
// To run just:
// jbang Llama3.java --help
//
// Enjoy!
package io.quarkiverse.langchain4j.llama3.copy;

import java.io.IOException;
import java.io.PrintStream;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.function.IntConsumer;
import java.util.function.LongConsumer;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import jdk.incubator.vector.*;
import sun.misc.Unsafe;

/**
 * Command-line front end: parses {@link Options}, loads a model, picks a {@link Sampler}
 * and runs either an interactive chat loop or a single instruct-style completion.
 */
public class Llama3 {
    // Batch-size used in prompt evaluation.
    public static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);

    /**
     * Builds the token sampler matching the requested decoding settings.
     *
     * @param vocabularySize number of entries in the vocabulary; sizes the top-p scratch buffer
     * @param temperature {@code 0} selects greedy argmax; any other value divides the logits before softmax
     * @param topp nucleus threshold; values {@code <= 0} or {@code >= 1} disable top-p and sample
     *        from the full distribution
     * @param rngSeed seed for the random generator used by the stochastic samplers
     * @return the sampler to use for generation
     */
    public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
        Sampler sampler;
        if (temperature == 0.0f) {
            // greedy argmax sampling: take the token with the highest probability
            sampler = Sampler.ARGMAX;
        } else {
            // we sample from this distribution to get the next token
            RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
            Sampler innerSampler;
            if (topp <= 0 || topp >= 1) {
                // simply sample from the predicted probability distribution
                innerSampler = new CategoricalSampler(rng);
            } else {
                // top-p (nucleus) sampling, clamping the least likely tokens to zero
                innerSampler = new ToppSampler(vocabularySize, topp, rng);
            }
            sampler = logits -> {
                // apply the temperature to the logits
                logits.divideInPlace(0, logits.size(), temperature);
                // apply softmax to the logits to get the probabilities for next token
                logits.softmaxInPlace(0, logits.size());
                return innerSampler.sampleToken(logits);
            };
        }
        return sampler;
    }

    /**
     * Runs the chat loop: reads user lines from standard input, appends them to the running
     * conversation and generates a reply for each one.
     * <p>
     * The commands {@code /quit} and {@code /exit} leave the loop, {@code /context} prints token usage.
     * The loop also ends at end-of-input, or when a reply did not finish with a stop token.
     */
    static void runInteractive(Llama model, Sampler sampler, Options options) {
        Llama.State state = null;
        List<Integer> conversationTokens = new ArrayList<>();
        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
        conversationTokens.add(chatFormat.beginOfText);
        if (options.systemPrompt() != null) {
            conversationTokens
                    .addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
        }
        // Index of the first conversation token that has not been fed to the model yet.
        int startPosition = 0;
        Scanner in = new Scanner(System.in);
        loop: while (true) {
            System.out.print("> ");
            System.out.flush();
            if (!in.hasNextLine()) {
                // End of input (closed pipe, Ctrl-D): leave the chat instead of failing in nextLine().
                break;
            }
            String userText = in.nextLine();
            switch (userText) {
                case "/quit":
                case "/exit":
                    break loop;
                case "/context": {
                    System.out.printf("%d out of %d context tokens used (%d tokens remaining)%n",
                            conversationTokens.size(),
                            options.maxTokens(),
                            options.maxTokens() - conversationTokens.size());
                    continue;
                }
            }
            if (state == null) {
                // Created lazily so commands issued before the first message allocate nothing.
                state = model.createNewState(BATCH_SIZE);
            }
            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
            conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
            Set<Integer> stopTokens = chatFormat.getStopTokens();
            List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition,
                    conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
                    sampler, options.echo(), token -> {
                        if (options.stream()) {
                            if (!model.tokenizer().isSpecialToken(token)) {
                                System.out.print(model.tokenizer().decode(List.of(token)));
                            }
                        }
                    });
            // Include stop token in the prompt history, but not in the response displayed to the user.
            conversationTokens.addAll(responseTokens);
            startPosition = conversationTokens.size();
            Integer stopToken = null;
            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
                stopToken = responseTokens.getLast();
                responseTokens.removeLast();
            }
            if (!options.stream()) {
                String responseText = model.tokenizer().decode(responseTokens);
                System.out.println(responseText);
            }
            if (stopToken == null) {
                System.err.println("Ran out of context length...");
                break;
            }
        }
    }

    /**
     * Answers the single prompt given in {@code options}: encodes the optional system prompt and the
     * user prompt, generates a reply and prints it (token by token when streaming, otherwise at the end).
     */
    static void runInstructOnce(Llama model, Sampler sampler, Options options) {
        Llama.State state = model.createNewState(BATCH_SIZE);
        ChatFormat chatFormat = new ChatFormat(model.tokenizer());

        List<Integer> promptTokens = new ArrayList<>();
        promptTokens.add(chatFormat.beginOfText);
        if (options.systemPrompt() != null) {
            promptTokens
                    .addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
        }
        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

        Set<Integer> stopTokens = chatFormat.getStopTokens();
        List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
                sampler, options.echo(), token -> {
                    if (options.stream()) {
                        if (!model.tokenizer().isSpecialToken(token)) {
                            System.out.print(model.tokenizer().decode(List.of(token)));
                        }
                    }
                });
        // The trailing stop token is not part of the text shown to the user.
        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
            responseTokens.removeLast();
        }
        if (!options.stream()) {
            String responseText = model.tokenizer().decode(responseTokens);
            System.out.println(responseText);
        }
    }

    /**
     * Parsed command-line options. The compact constructor validates them; a violated requirement
     * prints the error and usage to standard output and terminates the JVM with exit code -1.
     */
    public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
            float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {

        static final int DEFAULT_MAX_TOKENS = 512;

        public Options {
            require(modelPath != null, "Missing argument: --model <path> is required");
            require(interactive || prompt != null,
                    "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
            require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
            require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
        }

        /**
         * Exits the JVM with a usage message when {@code condition} is false.
         *
         * @param condition requirement that must hold
         * @param messageFormat {@link String#formatted} pattern for the error message
         * @param args arguments for {@code messageFormat}
         */
        static void require(boolean condition, String messageFormat, Object... args) {
            if (!condition) {
                System.out.println("ERROR " + messageFormat.formatted(args));
                System.out.println();
                printUsage(System.out);
                System.exit(-1);
            }
        }

        /** Prints the command-line help to {@code out}. */
        static void printUsage(PrintStream out) {
            out.println("Usage:  jbang Llama3.java [options]");
            out.println();
            out.println("Options:");
            out.println("  --model, -m <path>            required, path to .gguf file");
            out.println("  --interactive, --chat, -i     run in chat mode");
            out.println("  --instruct                    run in instruct (once) mode, default mode");
            out.println("  --prompt, -p <string>         input prompt");
            out.println("  --system-prompt, -sp <string> (optional) system prompt");
            out.println("  --temperature, -temp <float>  temperature in [0,inf], default 0.1");
            out.println("  --top-p <float>               p value in top-p (nucleus) sampling in [0,1] default 0.95");
            out.println("  --seed <long>                 random seed, default System.nanoTime()");
            out.println("  --max-tokens, -n <int>        number of steps to run for < 0 = limited by context length, default "
                    + DEFAULT_MAX_TOKENS);
            out.println(
                    "  --stream <boolean>            print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
            out.println(
                    "  --echo <boolean>              print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
            out.println();
            out.println("Examples:");
            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\"");
            out.println(
                    "  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\"");
            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat");
            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat");
            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false");
        }

        /**
         * Parses the command line. Options taking a value accept both {@code --name value} and
         * {@code --name=value}. Unknown options, missing values and malformed numbers are reported
         * through {@link #require}, which terminates the JVM.
         *
         * @param args raw command-line arguments
         * @return the validated options
         */
        static Options parseOptions(String[] args) {
            String prompt = null;
            String systemPrompt = null;
            float temperature = 0.1f;
            float topp = 0.95f;
            Path modelPath = null;
            long seed = System.nanoTime();
            // Keep max context length small for low-memory devices.
            int maxTokens = DEFAULT_MAX_TOKENS;
            boolean interactive = false;
            boolean stream = true;
            boolean echo = false;

            for (int i = 0; i < args.length; i++) {
                String optionName = args[i];
                require(optionName.startsWith("-"), "Invalid option %s", optionName);
                switch (optionName) {
                    case "--interactive", "--chat", "-i" -> interactive = true;
                    case "--instruct" -> interactive = false;
                    case "--help", "-h" -> {
                        printUsage(System.out);
                        System.exit(0);
                    }
                    default -> {
                        String nextArg;
                        if (optionName.contains("=")) {
                            String[] parts = optionName.split("=", 2);
                            optionName = parts[0];
                            nextArg = parts[1];
                        } else {
                            require(i + 1 < args.length, "Missing argument for option %s", optionName);
                            nextArg = args[i + 1];
                            i += 1; // skip arg
                        }
                        try {
                            switch (optionName) {
                                case "--prompt", "-p" -> prompt = nextArg;
                                case "--system-prompt", "-sp" -> systemPrompt = nextArg;
                                // "-temp" is the spelling shown in the usage text; "--temp" is kept for compatibility.
                                case "--temperature", "--temp", "-temp" -> temperature = Float.parseFloat(nextArg);
                                case "--top-p" -> topp = Float.parseFloat(nextArg);
                                case "--model", "-m" -> modelPath = Paths.get(nextArg);
                                case "--seed", "-s" -> seed = Long.parseLong(nextArg);
                                case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
                                case "--stream" -> stream = Boolean.parseBoolean(nextArg);
                                case "--echo" -> echo = Boolean.parseBoolean(nextArg);
                                default -> require(false, "Unknown option: %s", optionName);
                            }
                        } catch (NumberFormatException e) {
                            // Report a malformed number like any other usage error instead of a stack trace.
                            require(false, "Invalid value for option %s: %s", optionName, nextArg);
                        }
                    }
                }
            }
            return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo);
        }
    }

    /**
     * Program entry point: parses the arguments, obtains the model (a pre-loaded one when
     * {@code AOT.tryUsePreLoaded} returns non-null, otherwise loaded from the given file) and
     * dispatches to chat or instruct mode.
     */
    public static void main(String[] args) throws IOException {
        Options options = Options.parseOptions(args);
        Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
        if (model == null) {
            // No compatible preloaded model found, fallback to fully parse and load the specified file.
            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
        }
        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
                options.seed());
        if (options.interactive()) {
            runInteractive(model, sampler, options);
        } else {
            runInstructOnce(model, sampler, options);
        }
    }
}

/**
 * Helpers that run an action once per index of a half-open range using parallel streams.
 * The exact range {@code [0, 1)} is special-cased and executed directly on the calling thread.
 */
final class Parallel {
    /**
     * Invokes {@code action} for every {@code int} in {@code [startInclusive, endExclusive)}.
     *
     * @param startInclusive first index
     * @param endExclusive one past the last index
     * @param action callback receiving each index
     */
    public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
        boolean singleZeroIndex = endExclusive == 1 && startInclusive == 0;
        if (!singleZeroIndex) {
            IntStream indices = IntStream.range(startInclusive, endExclusive);
            indices.parallel().forEach(action);
            return;
        }
        // Only index 0 to process: skip the stream machinery entirely.
        action.accept(0);
    }

    /**
     * Invokes {@code action} for every {@code long} in {@code [startInclusive, endExclusive)}.
     *
     * @param startInclusive first index
     * @param endExclusive one past the last index
     * @param action callback receiving each index
     */
    public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) {
        boolean singleZeroIndex = endExclusive == 1 && startInclusive == 0;
        if (!singleZeroIndex) {
            LongStream indices = LongStream.range(startInclusive, endExclusive);
            indices.parallel().forEach(action);
            return;
        }
        // Only index 0 to process: skip the stream machinery entirely.
        action.accept(0);
    }
}

/** Constants for 16-bit floating point values. */
final class Float16 {
    /** Storage size of one 16-bit float, in bytes (same width as a {@code short}). */
    public static final int BYTES = Short.BYTES;
}

/**
 * Over-simplified, shapeless, float tensor.
 * 

* Not a strict tensor, but rather just a sequence of floats, not required to be backed by memory * e.g. can represent a sequence of quantized floats. */ abstract class FloatTensor { static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize()); static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0; // static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN); // static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); // The use of Unsafe in this file is a temporary workaround to support native-image. static final Unsafe UNSAFE; static { try { Field f = Unsafe.class.getDeclaredField("theUnsafe"); f.setAccessible(true); UNSAFE = (Unsafe) f.get(null); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } static short readShort(MemorySegment memorySegment, long offset) { // The MemorySegment.get* methods should be used instead. return UNSAFE.getShort(memorySegment.address() + offset); } static byte readByte(MemorySegment memorySegment, long offset) { // The MemorySegment.get* methods should be used instead. return UNSAFE.getByte(memorySegment.address() + offset); } // Preferred vector size for the fast multiplication routines. // (Apple Silicon) NEON only supports up-to 128bit vectors. static final VectorSpecies F_SPECIES = USE_VECTOR_API ? VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class) : null; abstract int size(); abstract float getFloat(int index); abstract void setFloat(int index, float value); abstract FloatVector getFloatVector(VectorSpecies species, int offset); abstract GGMLType type(); public static int numberOfElements(int... 
dimensions) { assert Arrays.stream(dimensions).allMatch(i -> i > 0); return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow(); } static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { float result = 0f; for (int j = 0; j < size; j++) { result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j); } return result; } float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { return scalarDot(this, thisOffset, that, thatOffset, size); } void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) { Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1))); } void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) { if (that.length != out.length) { throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length)); } Parallel.parallelForLong(0, dim0 * context, ti -> { int idxArr = (int) (ti / dim0); int i = (int) (ti % dim0); out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1)); }); } @FunctionalInterface interface AggregateFunction { float apply(float acc, float value); } float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) { float result = seed; for (int i = 0; i < size; ++i) { result = reduce.apply(result, getFloat(thisOffset + i)); } return result; } float sum(int thisOffset, int size) { return reduce(thisOffset, size, 0f, Float::sum); } float max(int thisOffset, int size) { return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max); } void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) { that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset)); } int argmax(int thisOffset, int size) { assert size > 0; int maxIndex = thisOffset; float maxValue = this.getFloat(maxIndex); int endIndex = thisOffset + size; for (int i = thisOffset; i < endIndex; ++i) { float f = 
this.getFloat(i); if (f > maxValue) { maxValue = f; maxIndex = i; } } return maxIndex; } int argmax() { return argmax(0, size()); } @FunctionalInterface interface MapFunction { float apply(float value); } @FunctionalInterface interface MapWithIndexFunction { float apply(float value, int index); } FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) { int endIndex = thisOffset + size; for (int i = thisOffset; i < endIndex; ++i) { setFloat(i, mapFunction.apply(getFloat(i))); } return this; } FloatTensor mapInPlace(MapFunction mapFunction) { return mapInPlace(0, size(), mapFunction); } FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) { int endOffset = thisOffset + size; for (int i = thisOffset; i < endOffset; ++i) { setFloat(i, mapWithIndexFunction.apply(getFloat(i), i)); } return this; } FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset)); } FloatTensor addInPlace(FloatTensor that) { return addInPlace(0, that, 0, size()); } FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset)); } FloatTensor multiplyInPlace(FloatTensor that) { return multiplyInPlace(0, that, 0, size()); } FloatTensor divideInPlace(int thisOffset, int size, float value) { return mapInPlace(thisOffset, size, f -> f / value); } FloatTensor fillInPlace(int thisOffset, int size, float value) { return mapInPlace(thisOffset, size, unused -> value); } FloatTensor softmaxInPlace(int thisOffset, int size) { // find max value (for numerical stability) float maxVal = max(thisOffset, size); // exp and sum mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal)); float sum = sum(thisOffset, size); // normalize return 
divideInPlace(thisOffset, size, sum); } FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) { // this[thatOffset ... thatOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... thisOffset + size) for (int i = 0; i < size; ++i) { setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i)); } return this; } } /** * {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format. *

* This tensor implementation is not compatible with {@link FloatTensor}, but * {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when * the second argument implements {@link FloatTensor}. */ final class Q4_0FloatTensor extends FloatTensor { final int size; final MemorySegment memorySegment; public Q4_0FloatTensor(int size, MemorySegment memorySegment) { this.size = size; this.memorySegment = memorySegment; } @Override int size() { return size; } @Override public void setFloat(int index, float value) { throw new UnsupportedOperationException("setFloat"); } @Override FloatVector getFloatVector(VectorSpecies species, int index) { throw new UnsupportedOperationException("getFloatVector"); } @Override public GGMLType type() { return GGMLType.Q4_0; } @Override public float getFloat(int index) { assert 0 <= index && index < size; int blockIndex = index / GGMLType.Q4_0.getBlockSize(); int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize(); float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); byte quant; int modIndex = index % GGMLType.Q4_0.getBlockSize(); if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) { quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F); } else { quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F); } quant -= 8; return quant * scale; } @Override public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { if (FloatTensor.USE_VECTOR_API) { return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); } else { return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); } } private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { float result = 0f; int j = 0; // Align thisOffset + j to type().getBlockSize(). 
assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2"; int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1)); if (alignmentBound > 0) { result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); j += alignmentBound; } assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0; FloatVector val = FloatVector.zero(F_SPECIES); int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize(); int upperBound = size / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize(); for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) { float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); var loBytes = wBytes.and((byte) 0xF).sub((byte) 8); var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8); switch (F_SPECIES.vectorBitSize()) { case 512 -> { var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) .mul(loBytes.castShape(F_SPECIES, 0)); var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) .mul(hiBytes.castShape(F_SPECIES, 0)); val = sum0.add(sum2).fma(wScale, val); } case 256 -> { var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) .mul(loBytes.castShape(F_SPECIES, 0)); var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) .mul(loBytes.castShape(F_SPECIES, 1)); var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) .mul(hiBytes.castShape(F_SPECIES, 0)); var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) .mul(hiBytes.castShape(F_SPECIES, 1)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } 
case 128 -> { // This loop cannot be unrolled, why? for (int i = 0; i < 2; ++i) { var tmp = i == 0 ? loBytes : hiBytes; var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()) .mul(tmp.castShape(F_SPECIES, 0)); var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()) .mul(tmp.castShape(F_SPECIES, 1)); var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()) .mul(tmp.castShape(F_SPECIES, 2)); var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()) .mul(tmp.castShape(F_SPECIES, 3)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } } default -> throw new UnsupportedOperationException(F_SPECIES.toString()); } } result += val.reduceLanes(VectorOperators.ADD); // Remaining entries. if (j < size) { result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); } return result; } } final class Q8_0FloatTensor extends FloatTensor { final int size; final MemorySegment memorySegment; public Q8_0FloatTensor(int size, MemorySegment memorySegment) { this.size = size; this.memorySegment = memorySegment; } @Override int size() { return size; } @Override public void setFloat(int index, float value) { throw new UnsupportedOperationException("setFloat"); } @Override FloatVector getFloatVector(VectorSpecies species, int index) { throw new UnsupportedOperationException("getFloatVector"); } @Override public GGMLType type() { return GGMLType.Q8_0; } @Override public float getFloat(int index) { assert 0 <= index && index < size; int blockIndex = index / GGMLType.Q8_0.getBlockSize(); int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize(); int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize(); byte quant = readByte(memorySegment, blockOffset + Float16.BYTES + withinBlockIndex); float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); return quant * scale; } public static final 
ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); @Override public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { if (FloatTensor.USE_VECTOR_API) { return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); } else { return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); } } private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { float result = 0f; int j = 0; // Align thisOffset + startIndex to type().getBlockSize(). assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2"; int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1)); if (alignmentBound > 0) { result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); j += alignmentBound; } assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0; FloatVector val = FloatVector.zero(F_SPECIES); int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize(); int upperBound = size / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize(); for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) { float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); switch (F_SPECIES.vectorBitSize()) { case 512 -> { var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 0)); var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 1)); val = sum0.add(sum1).fma(wScale, val); } case 256 -> { var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, 
thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 0)); var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 1)); var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 2)); var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 3)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } case 128 -> { // This loop cannot be unrolled, why? for (int i = 0; i < 2; ++i) { var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + Float16.BYTES + i * ByteVector.SPECIES_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN); var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 0)); var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 1)); var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 2)); var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()) .mul(wBytes.castShape(F_SPECIES, 3)); val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); } } default -> throw new UnsupportedOperationException(F_SPECIES.toString()); } } result += val.reduceLanes(VectorOperators.ADD); // Remaining entries. if (j < size) { result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); } return result; } } final class ArrayFloatTensor extends FloatTensor { final float[] values; ArrayFloatTensor(float[] values) { this.values = values; } public static FloatTensor allocate(int... 
dims) { int numberOfElements = FloatTensor.numberOfElements(dims); return new ArrayFloatTensor(new float[numberOfElements]); } @Override public int size() { return values.length; } @Override public float getFloat(int index) { return values[index]; } @Override public void setFloat(int index, float value) { values[index] = value; } @Override public GGMLType type() { return GGMLType.F32; } @Override public FloatTensor fillInPlace(int thisOffset, int size, float value) { Arrays.fill(values, thisOffset, thisOffset + size, value); return this; } @Override public FloatVector getFloatVector(VectorSpecies species, int index) { if (!USE_VECTOR_API) { throw new UnsupportedOperationException(); } return FloatVector.fromArray(species, values, index); } } final class RoPE { public static Pair precomputeFreqsCis(int contextLength, int headSize, double theta, boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) { assert headSize % 2 == 0; float[] cr = new float[contextLength * (headSize / 2)]; float[] ci = new float[contextLength * (headSize / 2)]; int n = 0; for (int pos = 0; pos < contextLength; ++pos) { for (int i = 0; i < headSize; i += 2) { float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize)); if (ropeScaling) { // Llama 3.1 scaling float loFreqWavelen = oldContextLength / loFreqFactor; float hiFreqWavelen = oldContextLength / hiFreqFactor; float wavelen = (float) (2.0 * Math.PI / freq); if (wavelen < hiFreqWavelen) { freq = freq; } else if (wavelen > loFreqWavelen) { freq = freq / scaleFactor; } else { float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor); freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq; } } float val = pos * freq; cr[n] = (float) Math.cos(val); ci[n] = (float) Math.sin(val); n++; } } assert contextLength * (headSize / 2) == n; return new Pair<>(cr, ci); } } record CategoricalSampler(RandomGenerator rng) implements Sampler { @Override public int 
sampleToken(FloatTensor logits) { // sample index from probabilities (they must sum to 1!) float random0to1 = rng.nextFloat(1f); float cdf = 0.0f; for (int i = 0; i < logits.size(); i++) { cdf += logits.getFloat(i); if (random0to1 < cdf) { return i; } } return logits.size() - 1; // in case of rounding errors } } final class ToppSampler implements Sampler { final int[] indices; final float topp; final RandomGenerator rng; public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) { this.indices = new int[maxNumberOfElements]; this.topp = topp; this.rng = rng; } static void swap(int[] array, int from, int to) { int tmp = array[from]; array[from] = array[to]; array[to] = tmp; } static void siftDown(int[] array, int from, int n, Comparator comparator) { int prev = from, next; while ((next = 2 * prev + 1) < n) { int r = 2 * prev + 2; if (r < n && comparator.compare(array[r], array[next]) < 0) { next = r; } if (comparator.compare(array[next], array[prev]) < 0) { swap(array, prev, next); prev = next; } else { break; } } } @Override public int sampleToken(FloatTensor logits) { // top-p sampling (or "nucleus sampling") samples from the smallest set of // tokens that exceed probability topp. This way we never sample tokens that // have very low probabilities and are less likely to go "off the rails". 
Comparator comparator = Comparator.comparingDouble(logits::getFloat).reversed(); int n = logits.size(); int head = 0; int tail = n - 1; // values smaller than (1 - topp) / (n - 1) cannot be part of the result // so for efficiency we crop these out as candidates before sorting float cutoff = (1.0f - topp) / (n - 1); for (int i = 0; i < indices.length; i++) { if (logits.getFloat(i) >= cutoff) { indices[head++] = i; } else { indices[tail--] = i; } } int n0 = head; // build heap O(n0) for (int i = n0 / 2 - 1; i >= 0; --i) { siftDown(indices, i, n0, comparator); } // truncate the list where cumulative probability of the largest k elements exceeds topp // O(k lg n0) float cumulativeProb = 0.0f; int lastIndex = 0; for (int i = n0 - 1; i >= 0; i--) { swap(indices, 0, i); cumulativeProb += logits.getFloat(indices[i]); if (cumulativeProb > topp) { lastIndex = i; break; // we've exceeded topp by including lastIndex } siftDown(indices, 0, i - 1, comparator); } // sample from the truncated list float r = rng.nextFloat(1f) * cumulativeProb; float cdf = 0.0f; for (int i = n0 - 1; i >= lastIndex; i--) { cdf += logits.getFloat(indices[i]); if (r < cdf) { return indices[i]; } } return indices[lastIndex]; // in case of rounding errors } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy