package io.quarkiverse.langchain4j.llama3.copy;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
public record Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) {
public State createNewState(int batchsize) {
State state = new State(configuration(), batchsize);
state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
return state;
}
public static final class Configuration {
public final int dim; // transformer dimension
public final int hiddenDim; // for ffn layers
public final int numberOfLayers; // number of layers
public final int numberOfHeads; // number of query heads
public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery)
public final int vocabularySize; // vocabulary size (128256 for Llama 3 models)
public final int contextLength; // max sequence length
public final float rmsNormEps;
public final float ropeTheta;
public final int headSize;
public Configuration withContextLength(int newContextLength) {
if (newContextLength < 0) {
return this; // no change
}
return new Configuration(this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads,
this.numberOfKeyValueHeads, this.vocabularySize, newContextLength, this.rmsNormEps, this.ropeTheta);
}
public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads,
int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) {
this.dim = dim;
this.hiddenDim = hiddenDim;
this.numberOfLayers = numberOfLayers;
this.numberOfHeads = numberOfHeads;
this.numberOfKeyValueHeads = numberOfKeyValueHeads;
this.vocabularySize = vocabularySize;
this.contextLength = contextLength;
this.rmsNormEps = rmsNormEps;
this.ropeTheta = ropeTheta;
this.headSize = dim / numberOfHeads;
}
}
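// Illustrative only: what a Configuration would look like for a Llama 3 8B checkpoint, using the
// publicly documented 8B hyper-parameters (example values, not read from this class):
//
// Configuration llama3_8b = new Configuration(
//         4096,     // dim
//         14336,    // hiddenDim
//         32,       // numberOfLayers
//         32,       // numberOfHeads
//         8,        // numberOfKeyValueHeads (grouped-query attention)
//         128256,   // vocabularySize
//         8192,     // contextLength
//         1e-5f,    // rmsNormEps
//         500000f); // ropeTheta
// Configuration capped = llama3_8b.withContextLength(4096); // optionally cap the usable context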
public static final class Weights {
// token embedding table
public final FloatTensor token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
// weights for matmuls
public final FloatTensor[] wq; // (layer, dim, n_heads * head_size)
public final FloatTensor[] wk; // (layer, dim, n_kv_heads * head_size)
public final FloatTensor[] wv; // (layer, dim, n_kv_heads * head_size)
public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
public final FloatBuffer[] rms_ffn_weight; // (layer, dim)
// weights for ffn
public final FloatTensor[] w1; // (layer, hidden_dim, dim)
public final FloatTensor[] w2; // (layer, dim, hidden_dim)
public final FloatTensor[] w3; // (layer, hidden_dim, dim)
// final rmsnorm
public final FloatBuffer rms_final_weight; // (dim,)
// freq_cis for RoPE relative positional embeddings
public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
public final FloatTensor wcls; // (vocab_size, dim)
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk,
FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2,
FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag,
FloatTensor wcls) {
this.token_embedding_table = token_embedding_table;
this.rms_att_weight = rms_att_weight;
this.wq = wq;
this.wk = wk;
this.wv = wv;
this.wo = wo;
this.rms_ffn_weight = rms_ffn_weight;
this.w1 = w1;
this.w2 = w2;
this.w3 = w3;
this.rms_final_weight = rms_final_weight;
this.freq_cis_real = freq_cis_real;
this.freq_cis_imag = freq_cis_imag;
this.wcls = wcls;
}
}
public static final class State {
// current wave of activations
public final int batchsize;
public final FloatTensor[] x; // activation at current time stamp (dim,)
public final FloatTensor[] xb; // same, but inside a residual branch (dim,)
public final FloatTensor[] xb2; // an additional buffer just for convenience (dim,)
public final FloatTensor[] hb; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
public final FloatTensor[] q; // query (dim,)
public final FloatTensor[] k; // key (dim,)
public final FloatTensor[] v; // value (dim,)
public final FloatTensor[] att; // buffer for scores/attention values (n_heads, seq_len)
public final FloatTensor logits; // output logits
// kv cache
public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)
/** last index in previous block */
int idxPrevBlock;
public int latestToken;
State(Configuration config, int batchsize) {
this.batchsize = batchsize;
this.x = allocate(batchsize, config.dim);
this.xb = allocate(batchsize, config.dim);
this.xb2 = allocate(batchsize, config.dim);
this.hb = allocate(batchsize, config.hiddenDim);
this.hb2 = allocate(batchsize, config.hiddenDim);
this.q = allocate(batchsize, config.dim);
this.k = allocate(batchsize, config.dim);
this.v = allocate(batchsize, config.dim);
this.att = allocate(batchsize, config.numberOfHeads, config.contextLength);
idxPrevBlock = -1;
this.logits = ArrayFloatTensor.allocate(config.vocabularySize);
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim))
.limit(config.numberOfLayers).toArray(FloatTensor[]::new);
this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim))
.limit(config.numberOfLayers).toArray(FloatTensor[]::new);
}
}
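// Sizing note, derived from the fields above: with grouped-query attention
// kvDim = dim * numberOfKeyValueHeads / numberOfHeads, so each of the numberOfLayers key (and value)
// caches holds contextLength * kvDim floats. For the illustrative 8B configuration sketched earlier
// (dim=4096, 32 heads, 8 KV heads, contextLength=8192) that is kvDim=1024, i.e. 32 MiB per layer per
// cache at 4 bytes per float.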
static FloatTensor[] allocate(int numTokens, int... dims) {
return IntStream.range(0, numTokens)
.mapToObj(i -> ArrayFloatTensor.allocate(dims))
.toArray(FloatTensor[]::new);
}
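// RMSNorm as computed below: out[i] = weight[i] * x[i] / sqrt(mean(x[j]^2, j < size) + rmsNormEps).
// Unlike LayerNorm there is no mean subtraction and no bias term.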
static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
// calculate sum of squares
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
// normalize and scale
final float finalss = ss; // for the lambda
out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
}
static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
// a few convenience variables
Configuration config = model.configuration();
Weights weights = model.weights();
int dim = config.dim;
int headSize = config.headSize;
int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery
float sqrtHeadSize = (float) Math.sqrt(headSize);
final int nTokens = tokens.length;
// copy the token embedding into x
Parallel.parallelFor(0, nTokens, t -> weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim));
// forward all the layers
for (int l = 0; l < config.numberOfLayers; l++) {
// attention rmsnorm
// rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps);
final int curLayer = l;
Parallel.parallelFor(0, nTokens,
t -> rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps));
// qkv matmuls for this position
weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim);
weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
Parallel.parallelFor(0, nTokens, t -> {
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = vi == 0 ? state.q[t] : state.k[t]; // the vector to rotate (query or key)
float v0 = vec.getFloat(i);
float v1 = vec.getFloat(i + 1);
vec.setFloat(i, v0 * fcr - v1 * fci);
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
}
}
});
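// What the rotation above computes: each consecutive pair (vec[i], vec[i+1]) is treated as a complex
// number and multiplied by (fcr + i*fci), the cosine/sine precomputed in freq_cis_real/freq_cis_imag
// for absolute position (position + t) and pair index (i % headSize) / 2 (for plain RoPE the angle is
// (position + t) * ropeTheta^(-(i % headSize) / headSize)). Only the first kvDim entries have a key
// counterpart, hence rotn.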
// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
Parallel.parallelFor(0, nTokens, t -> {
state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim);
state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
});
// If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
if (!computeLogits && curLayer == config.numberOfLayers - 1) {
state.idxPrevBlock = nTokens - 1;
return null;
}
// multihead attention. iterate over all heads
Parallel.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> {
int token = (int) (ht / config.numberOfHeads);
int h = (int) (ht % config.numberOfHeads);
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;
// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength;
// iterate over all timesteps, including the current one
for (int t = 0; t <= position + token; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att[token].setFloat(attOffset + t, score);
}
// softmax the scores to get attention weights, from 0..position inclusively
state.att[token].softmaxInPlace(attOffset, position + token + 1);
// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb[token].fillInPlace(xbOffset, headSize, 0f);
for (int t = 0; t <= position + token; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att[token].getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});
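// Grouped-query attention recap for the block above: query head h reads keys/values written by KV
// head h / kvMul, scores are dot(q, k) / sqrt(headSize) over positions 0..position+token, softmaxed
// into weights, and the weighted sum of cached values is accumulated into this head's slice of xb.
// During batched prompt ingestion, token t of the batch attends up to absolute position position + t.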
// final matmul to get the output of the attention
weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim);
// residual connection back into x
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb2[t]);
});
// ffn rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps);
});
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim);
weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim);
// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
});
// elementwise multiply with w3(x)
Parallel.parallelFor(0, nTokens, t -> {
state.hb[t].multiplyInPlace(state.hb2[t]);
});
// final matmul to get the output of the ffn
weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim);
// residual connection
Parallel.parallelFor(0, nTokens, t -> {
state.x[t].addInPlace(state.xb[t]);
});
}
// final rmsnorm
Parallel.parallelFor(0, nTokens, t -> {
rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps);
});
// classifier into logits
weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim);
state.idxPrevBlock = nTokens - 1;
return state.logits;
}
/**
* LLM generation entry point: ingests the prompt tokens, then generates new tokens.
*
* All prompt tokens are ingested first, then inference starts and runs until a stop token is found
* or the token budget is exhausted. The returned tokens only include generated/inferred tokens.
*
* @param model model to run inference with (weights, configuration, tokenizer, ...)
* @param state state of the model, e.g. key/value caches; this is mutated by this call
* @param startPosition start prompt ingestion + inference at this position in the context, e.g. useful if state was kept
* across calls (chained generation); 0 means no previous context
* @param promptTokens prompt tokens to ingest; all prompt tokens are ingested, provided there is enough capacity left
* in the context
* @param stopTokens set of tokens that abort generation during inference; stop tokens do not affect prompt ingestion
* @param maxTokens maximum number of tokens to process (clamped to the {@link Configuration#contextLength context length}
* if this value is negative or greater than the {@link Configuration#contextLength context length})
* @param sampler {@link Sampler strategy} used to select tokens
* @param echo debugging flag; prints ALL tokens, prompt and inferred, to {@link System#err stderr}
* @param onTokenGenerated callback; if non-null, it is called every time a token is inferred, i.e. it is not called while
* ingesting prompt tokens
* @return list of generated/inferred tokens, including the stop token (if any); does not include any token from the
* prompt
*/
public static List<Integer> generateTokens(Llama model, State state, int startPosition, List<Integer> promptTokens,
Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
long startNanos = System.nanoTime();
long startGen = 0;
if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
maxTokens = model.configuration().contextLength;
}
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
int token = state.latestToken; // BOS?
int nextToken;
int promptIndex = 0;
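// Remaining prompt tokens are forwarded in batches of up to state.batchsize; logits are only computed
// for the batch containing the last prompt token. After that, one token per iteration is sampled from
// state.logits and fed back in, until a stop token is sampled or maxTokens positions are consumed.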
for (int position = startPosition; position < maxTokens; ++position) {
if (promptIndex < promptTokens.size()) {
final int nTokens = Math.min(maxTokens - position,
Math.min(promptTokens.size() - promptIndex, state.batchsize));
final int[] tokens = new int[nTokens];
for (int i = 0; i < nTokens; i++) {
tokens[i] = promptTokens.get(promptIndex + i);
if (echo) {
// log prompt token (different color?)
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(tokens[i]))));
}
}
if (echo) {
System.err.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex,
promptTokens.size(), Arrays.toString(tokens));
}
// Only compute logits on the very last batch.
boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
forward(model, state, tokens, position, computeLogits);
position += nTokens - 1; // -1 -> incremented later in the for loop
promptIndex += nTokens;
if (promptIndex < promptTokens.size()) {
continue;
}
startGen = System.nanoTime();
} else {
forward(model, state, new int[] { token }, position, true);
}
nextToken = sampler.sampleToken(state.logits);
if (echo) {
// log inferred token
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
}
generatedTokens.add(nextToken);
if (onTokenGenerated != null) {
onTokenGenerated.accept(nextToken);
}
if (stopTokens.contains(nextToken)) {
break;
}
state.latestToken = token = nextToken;
}
long elapsedNanos = System.nanoTime() - startNanos;
long promptNanos = startGen - startNanos;
long genNanos = elapsedNanos - startGen + startNanos;
System.err.printf("%nprompt: %.2f tokens/s (%d) generation: %.2f tokens/s (%d)%n",
promptTokens.size() / (promptNanos / 1_000_000_000.0), promptTokens.size(),
generatedTokens.size() / (genNanos / 1_000_000_000.0), generatedTokens.size());
return generatedTokens;
}
}
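// A minimal usage sketch (illustrative only: the loader, the prompt-encoding step and the "<|eot_id|>"
// stop token are assumptions about the surrounding package, not guaranteed by this file):
//
// Llama model = ...;                                // obtained from this package's model loader
// Sampler sampler = ...;                            // e.g. greedy/argmax or temperature sampling
// Llama.State state = model.createNewState(16);     // 16 = prompt-ingestion batch size
// List<Integer> promptTokens = ...;                 // prompt encoded with model.tokenizer()
// Set<Integer> stopTokens = Set.of(model.tokenizer().getSpecialTokens().get("<|eot_id|>"));
// List<Integer> generated = Llama.generateTokens(model, state, 0, promptTokens, stopTokens,
//         256, sampler, false, t -> System.out.print(model.tokenizer().decode(List.of(t))));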