/*
* Copyright (c) 2010-2024 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.llm.llama;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.SubmissionPublisher;
import org.bytedeco.cuda.global.cudart;
import org.bytedeco.pytorch.TypeMeta;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.pytorch.global.torch_cuda;
import smile.deep.tensor.Device;
import smile.deep.tensor.Index;
import smile.deep.tensor.ScalarType;
import smile.deep.tensor.Tensor;
import smile.llm.CompletionPrediction;
import smile.llm.FinishReason;
import smile.llm.Message;
import smile.util.AutoScope;
/**
* LLaMA model specification.
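*
* <p>A minimal usage sketch; the checkpoint directory and tokenizer path
* below are placeholders for a locally downloaded Llama 3 model.
* <pre>{@code
* // Load the model on the local GPU with batch size 4 and a 2048-token context.
* Llama llama = Llama.build("Llama3-8B-Instruct", "Llama3-8B-Instruct/tokenizer.model", 4, 2048);
* // Complete a single prompt with nucleus sampling.
* CompletionPrediction[] predictions = llama.complete(
*         new String[] {"The theory of everything is"},
*         128,    // maxGenLen
*         0.6,    // temperature
*         0.9,    // topp
*         false,  // logprobs
*         null,   // seed
*         null);  // publisher (no streaming)
* System.out.println(predictions[0]);
* }</pre>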
*
* @author Haifeng Li
*/
public class Llama {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Llama.class);
/** The model family name. */
final String family = "meta/llama3";
/** The model instance name. */
final String name;
/** The transformer model. */
final Transformer model;
/** The tokenizer. */
final Tokenizer tokenizer;
/**
* Constructor.
* @param name the model name.
* @param model the transformer model.
* @param tokenizer the tokenizer.
*/
public Llama(String name, Transformer model, Tokenizer tokenizer) {
this.name = name;
this.model = model;
this.tokenizer = tokenizer;
}
@Override
public String toString() {
return String.format("%s/%s", family, name);
}
/**
* Returns the model family name.
* @return the model family name.
*/
public String family() {
return family;
}
/**
* Returns the model instance name.
* @return the model instance name.
*/
public String name() {
return name;
}
/**
* Builds a Llama instance by initializing and loading a model checkpoint.
* @param checkpointDir the directory path of checkpoint files.
* @param tokenizerPath the path of the tokenizer model file.
* @param maxBatchSize the maximum batch size for inference.
* @param maxSeqLen the maximum sequence length for input text.
* @return an instance of the Llama model.
* @throws IOException if the checkpoint or tokenizer model cannot be read.
*/
public static Llama build(String checkpointDir, String tokenizerPath, int maxBatchSize, int maxSeqLen) throws IOException {
return build(checkpointDir, tokenizerPath, maxBatchSize, maxSeqLen, null);
}
/**
* Builds a Llama instance by initializing and loading a model checkpoint.
* @param checkpointDir the directory path of checkpoint files.
* @param tokenizerPath the path of the tokenizer model file.
* @param maxBatchSize the maximum batch size for inference.
* @param maxSeqLen the maximum sequence length for input text.
* @param deviceId the optional CUDA device ID. If null, the LOCAL_RANK environment variable (default 0) determines the device.
* @return an instance of the Llama model.
* @throws IOException if the checkpoint or tokenizer model cannot be read.
*/
public static Llama build(String checkpointDir, String tokenizerPath, int maxBatchSize, int maxSeqLen, Integer deviceId) throws IOException {
File dir = new File(checkpointDir);
if (!dir.exists() || !dir.isDirectory()) {
throw new IllegalArgumentException("Checkpoint directory doesn't exist: " + checkpointDir);
}
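// Model parallelism follows the torchrun convention: WORLD_SIZE is the
// number of processes (one per checkpoint shard) and LOCAL_RANK selects
// this process's shard and default GPU.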
String worldSize = Objects.requireNonNullElse(System.getenv("WORLD_SIZE"), "1");
int modelParallelSize = Integer.parseInt(worldSize);
String localRank = Objects.requireNonNullElse(System.getenv("LOCAL_RANK"), "0");
int rank = Integer.parseInt(localRank);
if (deviceId == null) {
deviceId = rank;
}
var startTime = System.currentTimeMillis();
cudart.cuInit(0);
torch_cuda.set_device(deviceId.byteValue());
// half precision to lower memory usage.
var meta = new TypeMeta();
meta.put(Tensor.isBF16Supported() ? torch.ScalarType.BFloat16 : torch.ScalarType.Half);
torch.set_default_dtype(meta);
Device device = Device.CUDA(deviceId.byteValue());
var options = new Tensor.Options().device(device).requireGradients(false);
Tensor.setDefaultOptions(options);
var time = System.currentTimeMillis() - startTime;
logger.info("Initialized CUDA[{}]: {}.{} seconds", rank, time/1000, time%1000);
startTime = System.currentTimeMillis();
List<String> checkpoints = getCheckpoints(dir);
if (checkpoints.isEmpty()) {
throw new IllegalArgumentException("No checkpoint files found in " + checkpointDir);
}
if (checkpoints.size() != modelParallelSize) {
throw new IllegalStateException(String.format("Loading a checkpoint for MP=%d but world size is %d", checkpoints.size(), modelParallelSize));
}
var modelArgs = ModelArgs.from(checkpointDir + "/params.json", maxBatchSize, maxSeqLen);
var tokenizer = Tokenizer.of(tokenizerPath);
if (tokenizer.size() != modelArgs.vocabSize()) {
throw new IllegalStateException("Tokenizer and ModelArgs have different vocabulary size.");
}
var model = new Transformer(modelArgs, device);
model.eval();
Collections.sort(checkpoints);
var checkpoint = checkpoints.get(rank);
model.load(checkpoint);
time = System.currentTimeMillis() - startTime;
logger.info("Model {}[{}]: loaded in {}.{} seconds", checkpointDir, rank, time/1000, time%1000);
return new Llama(dir.getName(), model, tokenizer);
}
/**
* Returns the checkpoint file paths.
* @param dir the checkpoint directory.
* @return the checkpoint file paths.
*/
private static List<String> getCheckpoints(File dir) {
List<String> checkpoints = new ArrayList<>();
for (var file : dir.listFiles()) {
var path = file.getPath();
if (path.endsWith(".pt")) {
checkpoints.add(path);
}
}
return checkpoints;
}
/**
* Generates text sequences based on provided prompts. This method uses
* the provided prompts as a basis for generating text. It employs nucleus
* sampling to produce text with controlled randomness.
* @param prompts List of tokenized prompts, where each prompt is represented as a list of integers.
* @param maxGenLen Maximum length of the generated text sequence.
* @param temperature Temperature value for controlling randomness in sampling.
* @param topp Top-p probability threshold for nucleus sampling.
* @param logprobs Flag indicating whether to compute token log probabilities.
* @param seed the optional random number generation seed to sample deterministically.
* @param publisher an optional flow publisher that asynchronously issues generated chunks.
* The batch size must be 1.
* @return The generated text completions.
*/
public CompletionPrediction[] generate(int[][] prompts, int maxGenLen, double temperature, double topp, boolean logprobs, Long seed, SubmissionPublisher<String> publisher) {
int batchSize = prompts.length;
if (batchSize > model.params.maxBatchSize()) {
throw new IllegalArgumentException("The number of prompts is greater than max_batch_size");
}
if (publisher != null && batchSize > 1) {
throw new IllegalArgumentException("The batch size is > 1 while publisher is provided");
}
int minPromptLen = Integer.MAX_VALUE;
int maxPromptLen = Integer.MIN_VALUE;
for (var prompt : prompts) {
minPromptLen = Math.min(minPromptLen, prompt.length);
maxPromptLen = Math.max(maxPromptLen, prompt.length);
}
if (maxPromptLen > model.params.maxSeqLen()) {
throw new IllegalArgumentException("The prompt length is greater than max_seq_len");
}
// seed must be the same in all processes
if (seed != null) {
torch.manual_seed(seed);
}
try (var guard = Tensor.noGradGuard();
var scope = new AutoScope()) {
Tensor.push(scope);
int totalLen = Math.min(model.params.maxSeqLen(), maxGenLen + maxPromptLen);
int pad = tokenizer.pad();
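// Allocate a (batchSize x totalLen) token buffer filled with the pad token,
// then copy each prompt into the beginning of its row.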
Tensor tokens = Tensor.full(pad, batchSize, totalLen);
for (int i = 0; i < batchSize; i++) {
var prompt = Tensor.of(prompts[i]);
tokens.put_(prompt, Index.of(i), Index.slice(0, prompts[i].length));
}
Tensor tokenLogprobs = null;
if (logprobs) {
var options = new Tensor.Options().device(model.device()).requireGradients(false).dtype(ScalarType.Float32);
tokenLogprobs = Tensor.zeros(options, batchSize, totalLen);
}
Tensor eosReached = Tensor.of(new boolean[batchSize]);
Tensor inputTextMask = tokens.ne(pad);
Tensor stopTokens = Tensor.of(tokenizer.stopTokens());
tokens = tokens.to(model.device());
eosReached = eosReached.to(model.device());
inputTextMask = inputTextMask.to(model.device());
stopTokens = stopTokens.to(model.device());
int prevPos = 0;
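// If every prompt already fills the buffer, there is nothing left to
// generate; a single forward pass just scores the prompt tokens when
// log probabilities are requested.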
if (minPromptLen == totalLen) {
var logits = model.forward(tokens, prevPos);
if (logprobs) {
tokenLogprobs = Tensor.crossEntropy(logits.transpose(1, 2), tokens, "none", pad).neg_();
}
}
int chunkPos = minPromptLen;
for (int curPos = minPromptLen; curPos < totalLen; curPos++) {
try (var loopScope = new AutoScope()) {
Tensor.push(loopScope);
var logits = model.forward(tokens.get(Index.Colon, Index.slice(prevPos, curPos)), prevPos);
Tensor nextToken;
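// Nucleus (top-p) sampling: scale the last-position logits by the
// temperature, convert them to probabilities, and sample from the
// smallest set of tokens whose cumulative probability exceeds topp.
// With zero temperature, fall back to greedy argmax decoding.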
if (temperature > 0) {
var probs = logits.get(Index.Colon, Index.of(-1)).div(temperature).softmax(-1);
nextToken = probs.topp(topp);
} else {
nextToken = logits.get(Index.Colon, Index.of(-1)).argmax(-1, false);
}
nextToken = nextToken.reshape(-1);
// only replace token if prompt has already been generated
nextToken = Tensor.where(
inputTextMask.get(Index.Colon, Index.of(curPos)),
tokens.get(Index.Colon, Index.of(curPos)),
nextToken);
tokens.put_(nextToken, Index.Colon, Index.of(curPos));
if (logprobs) {
var entropy = Tensor.crossEntropy(
logits.transpose(1, 2),
tokens.get(Index.Colon, Index.slice(prevPos + 1, curPos + 1)),
"none", pad).neg_();
tokenLogprobs.put_(entropy, Index.Colon, Index.slice(prevPos + 1, curPos + 1));
}
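// A sequence is finished once a stop token is generated outside of
// its prompt region.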
var text = inputTextMask.get(Index.Colon, Index.of(curPos)).not();
var stop = nextToken.isin(stopTokens);
eosReached.or_(text.and_(stop));
prevPos = curPos;
// Free up memory at each iteration
Tensor.pop();
}
boolean eos = eosReached.all();
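// In streaming mode, decode and publish the newly generated tokens
// in chunks of roughly 20.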
if (publisher != null && (curPos - chunkPos >= 20 || curPos == totalLen-1 || eos)) {
int end = eos ? curPos : curPos + 1;
if (end > chunkPos) {
var longArray = tokens.get(Index.of(0), Index.slice(chunkPos, end)).to(Device.CPU()).longArray();
var completion = Arrays.stream(longArray).mapToInt(x -> (int) x).toArray();
try {
var chunk = tokenizer.tryDecode(completion);
publisher.submit(chunk);
chunkPos = curPos + 1;
} catch (Exception ex) {
logger.debug("Cannot decode a chunk", ex);
}
}
}
if (eos) break;
}
var longArray = tokens.to(Device.CPU()).longArray();
float[] logprobArray = null;
if (logprobs) {
logprobArray = tokenLogprobs.to(Device.CPU()).floatArray();
}
CompletionPrediction[] predictions = new CompletionPrediction[batchSize];
for (int i = 0; i < batchSize; i++) {
// cut to max gen len
int start = prompts[i].length;
var completion = Arrays.stream(longArray)
.skip((long) i * totalLen + start)
.mapToInt(x -> (int) x)
.limit(prompts[i].length + maxGenLen - start)
.toArray();
float[] probs = null;
if (logprobs) {
probs = Arrays.copyOfRange(logprobArray, i * totalLen + start, i * totalLen + prompts[i].length + maxGenLen);
}
// cut to after eos tok if any
boolean stop = false;
for (var stopToken : tokenizer.stopTokens()) {
for (int eosIdx = 0; eosIdx < completion.length; eosIdx++) {
if (completion[eosIdx] == stopToken) {
stop = true;
completion = Arrays.copyOf(completion, eosIdx);
if (logprobs) {
probs = Arrays.copyOf(probs, eosIdx);
}
break;
}
}
}
var reason = stop ? FinishReason.stop : FinishReason.length;
predictions[i] = new CompletionPrediction(name, tokenizer.decode(completion), prompts[i], completion, reason, probs);
}
Tensor.pop();
System.gc();
return predictions;
}
}
/**
* Performs text completion for a list of prompts.
* @param prompts List of text prompts.
* @param maxGenLen Maximum length of the generated text sequence.
* @param temperature Temperature value for controlling randomness in sampling.
* @param topp Top-p probability threshold for nucleus sampling.
* @param logprobs Flag indicating whether to compute token log probabilities.
* @param seed the optional random number generation seed to sample deterministically.
* @param publisher an optional flow publisher that asynchronously issues generated chunks.
* The batch size must be 1.
* @return The generated text completions.
*/
public CompletionPrediction[] complete(String[] prompts, int maxGenLen, double temperature, double topp, boolean logprobs, Long seed, SubmissionPublisher<String> publisher) {
int batchSize = prompts.length;
int[][] tokens = new int[batchSize][];
for (int i = 0; i < batchSize; i++) {
tokens[i] = tokenizer.encode(prompts[i], true, false);
}
return generate(tokens, maxGenLen, temperature, topp, logprobs, seed, publisher);
}
/**
* Generates assistant responses for a list of conversational dialogs.
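* <p>A minimal single-turn sketch, given a built {@code Llama} instance
* {@code llama}. The {@code Message} and {@code Role} constructors shown
* here are assumptions about the smile.llm API.
* <pre>{@code
* Message[][] dialogs = {
*     { new Message(Role.system, "You are a helpful assistant."),
*       new Message(Role.user, "What is the capital of France?") }
* };
* CompletionPrediction[] replies = llama.chat(dialogs, 128, 0.6, 0.9, false, null, null);
* System.out.println(replies[0]);
* }</pre>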
* @param dialogs List of conversational dialogs, where each dialog is a list of messages.
* @param maxGenLen Maximum length of the generated text sequence.
* @param temperature Temperature value for controlling randomness in sampling.
* @param topp Top-p probability threshold for nucleus sampling.
* @param logprobs Flag indicating whether to compute token log probabilities.
* @param seed the optional random number generation seed to sample deterministically.
* @param publisher an optional flow publisher that asynchronously issues generated chunks.
* The batch size must be 1.
* @return The generated chat responses.
*/
public CompletionPrediction[] chat(Message[][] dialogs, int maxGenLen, double temperature, double topp, boolean logprobs, Long seed, SubmissionPublisher<String> publisher) {
int batchSize = dialogs.length;
int[][] tokens = new int[batchSize][];
for (int i = 0; i < batchSize; i++) {
tokens[i] = tokenizer.encodeDialog(dialogs[i]);
}
return generate(tokens, maxGenLen, temperature, topp, logprobs, seed, publisher);
}
}