///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector
//MAIN com.llama4j.Llama3
// Practical Llama 3 (and 3.1) inference in a single Java file
// Author: Alfonso² Peterssen
// Based on Andrej Karpathy's llama2.c and minbpe projects
//
// Supports llama.cpp's GGUF format, restricted to Q4_0 and Q8_0 quantized models
// Multi-threaded matrix vector multiplication routines implemented using Java's Vector API
// Simple CLI with --chat and --instruct modes
//
// To run just:
// jbang Llama3.java --help
//
// Enjoy!
package io.quarkiverse.langchain4j.llama3.copy;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.function.IntConsumer;
import java.util.function.LongConsumer;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import jdk.incubator.vector.*;
import sun.misc.Unsafe;
public class Llama3 {
// Batch-size used in prompt evaluation.
public static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
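// The batch size can be overridden with a system property, e.g. (jbang forwards -D flags to the JVM):
//   jbang -Dllama.BatchSize=32 Llama3.java --model model.gguf --chat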
public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
Sampler sampler;
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
sampler = Sampler.ARGMAX;
} else {
// we sample from this distribution to get the next token
RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed);
Sampler innerSampler;
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
innerSampler = new CategoricalSampler(rng);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
innerSampler = new ToppSampler(vocabularySize, topp, rng);
}
sampler = logits -> {
// apply the temperature to the logits
logits.divideInPlace(0, logits.size(), temperature);
// apply softmax to the logits to get the probabilities for next token
logits.softmaxInPlace(0, logits.size());
return innerSampler.sampleToken(logits);
};
}
return sampler;
}
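// Illustrative sketch of the three regimes above (example values, not the CLI defaults):
//   selectSampler(vocabSize, 0.0f, 0.95f, seed) -> greedy argmax, deterministic
//   selectSampler(vocabSize, 0.7f, 0.0f, seed)  -> temperature-scaled categorical sampling
//   selectSampler(vocabSize, 0.7f, 0.95f, seed) -> temperature-scaled top-p (nucleus) sampling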
static void runInteractive(Llama model, Sampler sampler, Options options) {
Llama.State state = null;
List<Integer> conversationTokens = new ArrayList<>();
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
conversationTokens.add(chatFormat.beginOfText);
if (options.systemPrompt() != null) {
conversationTokens
.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
int startPosition = 0;
Scanner in = new Scanner(System.in);
loop: while (true) {
System.out.print("> ");
System.out.flush();
String userText = in.nextLine();
switch (userText) {
case "/quit":
case "/exit":
break loop;
case "/context": {
System.out.printf("%d out of %d context tokens used (%d tokens remaining)%n",
conversationTokens.size(),
options.maxTokens(),
options.maxTokens() - conversationTokens.size());
continue;
}
}
if (state == null) {
state = model.createNewState(BATCH_SIZE);
}
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
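// Only the tokens appended since startPosition are fed to the model below; `state` retains
// the earlier positions from previous turns (hence the subList starting at startPosition).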
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition,
conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
// Include stop token in the prompt history, but not in the response displayed to the user.
conversationTokens.addAll(responseTokens);
startPosition = conversationTokens.size();
Integer stopToken = null;
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
stopToken = responseTokens.getLast();
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
if (stopToken == null) {
System.err.println("Ran out of context length...");
break;
}
}
}
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
Llama.State state = model.createNewState(BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.beginOfText);
if (options.systemPrompt() != null) {
promptTokens
.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
}
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
Set<Integer> stopTokens = chatFormat.getStopTokens();
List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
sampler, options.echo(), token -> {
if (options.stream()) {
if (!model.tokenizer().isSpecialToken(token)) {
System.out.print(model.tokenizer().decode(List.of(token)));
}
}
});
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
responseTokens.removeLast();
}
if (!options.stream()) {
String responseText = model.tokenizer().decode(responseTokens);
System.out.println(responseText);
}
}
public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
static final int DEFAULT_MAX_TOKENS = 512;
public Options {
require(modelPath != null, "Missing argument: --model is required");
require(interactive || prompt != null,
"Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
static void require(boolean condition, String messageFormat, Object... args) {
if (!condition) {
System.out.println("ERROR " + messageFormat.formatted(args));
System.out.println();
printUsage(System.out);
System.exit(-1);
}
}
static void printUsage(PrintStream out) {
out.println("Usage: jbang Llama3.java [options]");
out.println();
out.println("Options:");
out.println(" --model, -m required, path to .gguf file");
out.println(" --interactive, --chat, -i run in chat mode");
out.println(" --instruct run in instruct (once) mode, default mode");
out.println(" --prompt, -p input prompt");
out.println(" --system-prompt, -sp (optional) system prompt");
out.println(" --temperature, -temp temperature in [0,inf], default 0.1");
out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --seed random seed, default System.nanoTime()");
out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default "
+ DEFAULT_MAX_TOKENS);
out.println(
" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(
" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println();
out.println("Examples:");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\"");
out.println(
" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\"");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat");
out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false");
}
static Options parseOptions(String[] args) {
String prompt = null;
String systemPrompt = null;
float temperature = 0.1f;
float topp = 0.95f;
Path modelPath = null;
long seed = System.nanoTime();
// Keep max context length small for low-memory devices.
int maxTokens = DEFAULT_MAX_TOKENS;
boolean interactive = false;
boolean stream = true;
boolean echo = false;
for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
}
default -> {
String nextArg;
if (optionName.contains("=")) {
String[] parts = optionName.split("=", 2);
optionName = parts[0];
nextArg = parts[1];
} else {
require(i + 1 < args.length, "Missing argument for option %s", optionName);
nextArg = args[i + 1];
i += 1; // skip arg
}
switch (optionName) {
case "--prompt", "-p" -> prompt = nextArg;
case "--system-prompt", "-sp" -> systemPrompt = nextArg;
case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
case "--top-p" -> topp = Float.parseFloat(nextArg);
case "--model", "-m" -> modelPath = Paths.get(nextArg);
case "--seed", "-s" -> seed = Long.parseLong(nextArg);
case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
case "--echo" -> echo = Boolean.parseBoolean(nextArg);
default -> require(false, "Unknown option: %s", optionName);
}
}
}
}
return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo);
}
}
public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
if (model == null) {
// No compatible preloaded model found, fallback to fully parse and load the specified file.
model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
}
Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
options.seed());
if (options.interactive()) {
runInteractive(model, sampler, options);
} else {
runInstructOnce(model, sampler, options);
}
}
}
final class Parallel {
public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
IntStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) {
if (startInclusive == 0 && endExclusive == 1) {
action.accept(0);
return;
}
LongStream.range(startInclusive, endExclusive).parallel().forEach(action);
}
}
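// Note: the parallel streams above run on the common ForkJoinPool; its size can be tuned via the
// standard -Djava.util.concurrent.ForkJoinPool.common.parallelism=N system property.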
final class Float16 {
public static final int BYTES = 2;
}
/**
 * Over-simplified, shapeless, float tensor.
 *
 * Not a strict tensor, but rather just a sequence of floats, not required to be backed by
 * memory, e.g. it can represent a sequence of quantized floats.
 */
abstract class FloatTensor {
static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize());
static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0;
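// Setting -Dllama.VectorBitSize=0 disables the Vector API paths; all dot products then fall
// back to the scalar routines below.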
// static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN);
// static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN);
// The use of Unsafe in this file is a temporary workaround to support native-image.
static final Unsafe UNSAFE;
static {
try {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
static short readShort(MemorySegment memorySegment, long offset) {
// The MemorySegment.get* methods should be used instead.
return UNSAFE.getShort(memorySegment.address() + offset);
}
static byte readByte(MemorySegment memorySegment, long offset) {
// The MemorySegment.get* methods should be used instead.
return UNSAFE.getByte(memorySegment.address() + offset);
}
// Preferred vector size for the fast multiplication routines.
// (Apple Silicon) NEON only supports up to 128-bit vectors.
static final VectorSpecies<Float> F_SPECIES = USE_VECTOR_API
? VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class)
: null;
abstract int size();
abstract float getFloat(int index);
abstract void setFloat(int index, float value);
abstract FloatVector getFloatVector(VectorSpecies<Float> species, int offset);
abstract GGMLType type();
public static int numberOfElements(int... dimensions) {
assert Arrays.stream(dimensions).allMatch(i -> i > 0);
return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow();
}
static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) {
float result = 0f;
for (int j = 0; j < size; j++) {
result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
}
return result;
}
float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
return scalarDot(this, thisOffset, that, thatOffset, size);
}
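// Row-major matrix-vector product: `this` is treated as a dim0 x dim1 matrix flattened row by
// row, so matmul computes out[i] = dot(this[i*dim1 .. i*dim1 + dim1), that[0 .. dim1)) per row i.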
void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) {
Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1)));
}
void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) {
if (that.length != out.length) {
throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length));
}
Parallel.parallelForLong(0, (long) dim0 * context, ti -> {
int idxArr = (int) (ti / dim0);
int i = (int) (ti % dim0);
out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1));
});
}
@FunctionalInterface
interface AggregateFunction {
float apply(float acc, float value);
}
float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) {
float result = seed;
for (int i = 0; i < size; ++i) {
result = reduce.apply(result, getFloat(thisOffset + i));
}
return result;
}
float sum(int thisOffset, int size) {
return reduce(thisOffset, size, 0f, Float::sum);
}
float max(int thisOffset, int size) {
return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max);
}
void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset));
}
int argmax(int thisOffset, int size) {
assert size > 0;
int maxIndex = thisOffset;
float maxValue = this.getFloat(maxIndex);
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
float f = this.getFloat(i);
if (f > maxValue) {
maxValue = f;
maxIndex = i;
}
}
return maxIndex;
}
int argmax() {
return argmax(0, size());
}
@FunctionalInterface
interface MapFunction {
float apply(float value);
}
@FunctionalInterface
interface MapWithIndexFunction {
float apply(float value, int index);
}
FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) {
int endIndex = thisOffset + size;
for (int i = thisOffset; i < endIndex; ++i) {
setFloat(i, mapFunction.apply(getFloat(i)));
}
return this;
}
FloatTensor mapInPlace(MapFunction mapFunction) {
return mapInPlace(0, size(), mapFunction);
}
FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) {
int endOffset = thisOffset + size;
for (int i = thisOffset; i < endOffset; ++i) {
setFloat(i, mapWithIndexFunction.apply(getFloat(i), i));
}
return this;
}
FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor addInPlace(FloatTensor that) {
return addInPlace(0, that, 0, size());
}
FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset));
}
FloatTensor multiplyInPlace(FloatTensor that) {
return multiplyInPlace(0, that, 0, size());
}
FloatTensor divideInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, f -> f / value);
}
FloatTensor fillInPlace(int thisOffset, int size, float value) {
return mapInPlace(thisOffset, size, unused -> value);
}
FloatTensor softmaxInPlace(int thisOffset, int size) {
// find max value (for numerical stability)
float maxVal = max(thisOffset, size);
// exp and sum
mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal));
float sum = sum(thisOffset, size);
// normalize
return divideInPlace(thisOffset, size, sum);
}
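// e.g. softmaxInPlace over [1, 2, 3] yields ~[0.090, 0.245, 0.665]; subtracting the max first
// does not change the result but keeps exp() from overflowing for large logits.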
FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
// this[thisOffset ... thisOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... thisOffset + size)
for (int i = 0; i < size; ++i) {
setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i));
}
return this;
}
}
/**
 * {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format.
 *
 * This tensor is read-only ({@link #setFloat} is unsupported), but
 * {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when
 * the second argument is an {@link ArrayFloatTensor}.
 */
final class Q4_0FloatTensor extends FloatTensor {
final int size;
final MemorySegment memorySegment;
public Q4_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q4_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q4_0.getBlockSize();
int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
byte quant;
int modIndex = index % GGMLType.Q4_0.getBlockSize();
if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F);
} else {
quant = (byte) ((readByte(memorySegment,
blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
}
quant -= 8;
return quant * scale;
}
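// Q4_0 block layout (GGUF): 32 weights packed into 18 bytes:
//   bytes [0, 2)  : fp16 scale d
//   bytes [2, 18) : byte k holds weight k in its low nibble and weight k+16 in its high nibble
// Each weight dequantizes to (quant - 8) * d, e.g. nibble 0x0 -> -8*d and 0xF -> +7*d.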
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Align thisOffset + j to type().getBlockSize().
assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2";
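// For a power-of-two block size b, -thisOffset & (b - 1) == (b - thisOffset % b) % b, i.e. the
// number of scalar elements needed to reach the next block boundary.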
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize();
int upperBound = size / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize();
for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment, blockOffset + Float16.BYTES,
ByteOrder.LITTLE_ENDIAN);
var loBytes = wBytes.and((byte) 0xF).sub((byte) 8);
var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length())
.mul(loBytes.castShape(F_SPECIES, 0));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length())
.mul(hiBytes.castShape(F_SPECIES, 0));
val = sum0.add(sum2).fma(wScale, val);
}
case 256 -> {
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length())
.mul(loBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length())
.mul(loBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length())
.mul(hiBytes.castShape(F_SPECIES, 0));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length())
.mul(hiBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var tmp = i == 0 ? loBytes : hiBytes;
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length())
.mul(tmp.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length())
.mul(tmp.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length())
.mul(tmp.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length())
.mul(tmp.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
}
final class Q8_0FloatTensor extends FloatTensor {
final int size;
final MemorySegment memorySegment;
public Q8_0FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@Override
int size() {
return size;
}
@Override
public void setFloat(int index, float value) {
throw new UnsupportedOperationException("setFloat");
}
@Override
FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
throw new UnsupportedOperationException("getFloatVector");
}
@Override
public GGMLType type() {
return GGMLType.Q8_0;
}
@Override
public float getFloat(int index) {
assert 0 <= index && index < size;
int blockIndex = index / GGMLType.Q8_0.getBlockSize();
int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize();
int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize();
byte quant = readByte(memorySegment, blockOffset + Float16.BYTES + withinBlockIndex);
float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
return quant * scale;
}
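// Q8_0 block layout (GGUF): 32 weights packed into 34 bytes:
//   bytes [0, 2)  : fp16 scale d
//   bytes [2, 34) : 32 signed int8 quants, one per weight, each dequantizing to quant * d.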
public static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN);
@Override
public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
if (FloatTensor.USE_VECTOR_API) {
return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size);
} else {
return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size);
}
}
private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
float result = 0f;
int j = 0;
// Align thisOffset + startIndex to type().getBlockSize().
assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2";
int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1));
if (alignmentBound > 0) {
result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound);
j += alignmentBound;
}
assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0;
FloatVector val = FloatVector.zero(F_SPECIES);
int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize();
int upperBound = size / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize();
for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) {
float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset));
var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue);
switch (F_SPECIES.vectorBitSize()) {
case 512 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment,
blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 1));
val = sum0.add(sum1).fma(wScale, val);
}
case 256 -> {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment,
blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
case 128 -> {
// This loop cannot be unrolled, why?
for (int i = 0; i < 2; ++i) {
var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, thiz.memorySegment,
blockOffset + Float16.BYTES + i * ByteVector.SPECIES_128.vectorByteSize(),
ByteOrder.LITTLE_ENDIAN);
var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 0));
var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 1));
var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 2));
var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length())
.mul(wBytes.castShape(F_SPECIES, 3));
val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val);
}
}
default -> throw new UnsupportedOperationException(F_SPECIES.toString());
}
}
result += val.reduceLanes(VectorOperators.ADD);
// Remaining entries.
if (j < size) {
result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j);
}
return result;
}
}
final class ArrayFloatTensor extends FloatTensor {
final float[] values;
ArrayFloatTensor(float[] values) {
this.values = values;
}
public static FloatTensor allocate(int... dims) {
int numberOfElements = FloatTensor.numberOfElements(dims);
return new ArrayFloatTensor(new float[numberOfElements]);
}
@Override
public int size() {
return values.length;
}
@Override
public float getFloat(int index) {
return values[index];
}
@Override
public void setFloat(int index, float value) {
values[index] = value;
}
@Override
public GGMLType type() {
return GGMLType.F32;
}
@Override
public FloatTensor fillInPlace(int thisOffset, int size, float value) {
Arrays.fill(values, thisOffset, thisOffset + size, value);
return this;
}
@Override
public FloatVector getFloatVector(VectorSpecies<Float> species, int index) {
if (!USE_VECTOR_API) {
throw new UnsupportedOperationException();
}
return FloatVector.fromArray(species, values, index);
}
}
final class RoPE {
public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,
boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];
int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
if (ropeScaling) {
// Llama 3.1 scaling
float loFreqWavelen = oldContextLength / loFreqFactor;
float hiFreqWavelen = oldContextLength / hiFreqFactor;
float wavelen = (float) (2.0 * Math.PI / freq);
if (wavelen < hiFreqWavelen) {
// leave freq unchanged for high-frequency components (short wavelengths)
} else if (wavelen > loFreqWavelen) {
freq = freq / scaleFactor;
} else {
float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
}
}
float val = pos * freq;
cr[n] = (float) Math.cos(val);
ci[n] = (float) Math.sin(val);
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}
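// Consumption sketch (assumed from the n = pos * headSize/2 + i/2 layout above): the forward
// pass rotates each pair (x0, x1) at position pos and even head dimension i as
//   x0' = x0 * cr[n] - x1 * ci[n]
//   x1' = x0 * ci[n] + x1 * cr[n]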
}
record CategoricalSampler(RandomGenerator rng) implements Sampler {
@Override
public int sampleToken(FloatTensor logits) {
// sample index from probabilities (they must sum to 1!)
float random0to1 = rng.nextFloat(1f);
float cdf = 0.0f;
for (int i = 0; i < logits.size(); i++) {
cdf += logits.getFloat(i);
if (random0to1 < cdf) {
return i;
}
}
return logits.size() - 1; // in case of rounding errors
}
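// e.g. for probabilities [0.5, 0.3, 0.2] and random0to1 = 0.6: cdf reaches 0.5 at i=0, then
// 0.8 at i=1, and 0.6 < 0.8 selects token 1.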
}
final class ToppSampler implements Sampler {
final int[] indices;
final float topp;
final RandomGenerator rng;
public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) {
this.indices = new int[maxNumberOfElements];
this.topp = topp;
this.rng = rng;
}
static void swap(int[] array, int from, int to) {
int tmp = array[from];
array[from] = array[to];
array[to] = tmp;
}
static void siftDown(int[] array, int from, int n, Comparator<Integer> comparator) {
int prev = from, next;
while ((next = 2 * prev + 1) < n) {
int r = 2 * prev + 2;
if (r < n && comparator.compare(array[r], array[next]) < 0) {
next = r;
}
if (comparator.compare(array[next], array[prev]) < 0) {
swap(array, prev, next);
prev = next;
} else {
break;
}
}
}
@Override
public int sampleToken(FloatTensor logits) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
Comparator<Integer> comparator = Comparator.comparingDouble(logits::getFloat).reversed();
int n = logits.size();
int head = 0;
int tail = n - 1;
// values smaller than (1 - topp) / (n - 1) cannot be part of the result:
// at most n - 1 entries fall below this cutoff, so together they carry less than
// (1 - topp) of the mass; the entries at or above the cutoff therefore already exceed
// topp, and the minimal nucleus is a subset of them. Crop the rest before sorting.
float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < indices.length; i++) {
if (logits.getFloat(i) >= cutoff) {
indices[head++] = i;
} else {
indices[tail--] = i;
}
}
int n0 = head;
// build heap O(n0)
for (int i = n0 / 2 - 1; i >= 0; --i) {
siftDown(indices, i, n0, comparator);
}
// truncate the list where cumulative probability of the largest k elements exceeds topp
// O(k lg n0)
float cumulativeProb = 0.0f;
int lastIndex = 0;
for (int i = n0 - 1; i >= 0; i--) {
swap(indices, 0, i);
cumulativeProb += logits.getFloat(indices[i]);
if (cumulativeProb > topp) {
lastIndex = i;
break; // we've exceeded topp by including lastIndex
}
siftDown(indices, 0, i - 1, comparator);
}
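// indices[lastIndex .. n0) now holds the nucleus in ascending probability order (largest at n0 - 1).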
// sample from the truncated list
float r = rng.nextFloat(1f) * cumulativeProb;
float cdf = 0.0f;
for (int i = n0 - 1; i >= lastIndex; i--) {
cdf += logits.getFloat(indices[i]);
if (r < cdf) {
return indices[i];
}
}
return indices[lastIndex]; // in case of rounding errors
}
}