com.github.tjake.jlama.net.grpc.JlamaService (jlama-net)
Jlama: A modern LLM inference engine for Java
/*
* Copyright 2024 T Jake Luciani
*
* The Jlama Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.github.tjake.jlama.net.grpc;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.net.*;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.*;
import jdk.incubator.vector.FloatVector;
import org.jctools.queues.MpmcArrayQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class JlamaService extends JlamaServiceGrpc.JlamaServiceImplBase {
private static final Logger logger = LoggerFactory.getLogger(JlamaService.class);
private final AbstractModel model;
private final int workerCount;
private final ConcurrentMap<UUID, RegisterResponse> workers;
private final GeneratorGroup generatorGroup;
private final ConcurrentMap<String, MpmcArrayQueue<Pair<CombineRequest, StreamObserver<CombineResponse>>>> combinations;
public JlamaService(AbstractModel model, int workerCount) {
Preconditions.checkArgument(
workerCount <= model.getConfig().numberOfKeyValueHeads,
"Worker count must be less than or equal to number of KV heads"
);
this.model = model;
this.workerCount = workerCount;
this.workers = new ConcurrentHashMap<>();
this.combinations = new ConcurrentHashMap<>();
this.generatorGroup = new GeneratorGroup();
}
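/**
 * Blocks until all {@code workerCount} workers have connected on the generate stream and
 * signaled that they are ready.
 */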
public void waitForReady() {
while (true) {
if (generatorGroup.generators.size() == workerCount) {
generatorGroup.waitForReady();
return;
}
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
}
}
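/**
 * Completes the response stream of every connected worker, signaling them to shut down.
 */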
public void shutdown() {
for (Generator g : generatorGroup.generators) {
try {
g.responseObserver.onCompleted();
} catch (Exception e) {
logger.debug("Exception when shutting down", e);
}
}
}
/**
* Register a worker with the coordinator. The coordinator will return the offset and length of the embedding that
* the worker is responsible for.
*/
@Override
public void register(RegisterRequest request, StreamObserver<RegisterResponse> responseObserver) {
synchronized (workers) {
ByteBuffer bb = request.getWorkerid().asReadOnlyByteBuffer();
UUID wid = new UUID(bb.getLong(), bb.getLong());
if (workers.containsKey(wid)) {
responseObserver.onNext(workers.get(wid));
responseObserver.onCompleted();
} else {
if (workers.size() == workerCount) {
responseObserver.onError(new RuntimeException("Not accepting any more workers"));
return;
}
int workerNum = workers.size();
RegisterResponse r = RegisterResponse.newBuilder()
.setModelShard(workerNum)
.setNumModelShards(workerCount)
.setLayerShard(0)
.setNumLayerShards(1)
.build();
workers.put(wid, r);
logger.info("Registered worker {} with workerNum {} of {}", wid, workerNum, workerCount);
responseObserver.onNext(r);
responseObserver.onCompleted();
}
}
}
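/**
 * Sends the given token(s) to every worker for the session and returns the combined
 * output embedding.
 */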
public AbstractTensor generateNextOutput(UUID session, List<Integer> tokenIds, int startPosition) {
return generatorGroup.generateNextOutput(session, tokenIds, startPosition);
}
public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
return generatorGroup.generateNextOutput(session, tokenId, position);
}
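/**
 * Opens the bidirectional generate stream for a worker: the coordinator pushes work to the
 * worker as {@code GenerateResponse} messages and receives the worker's results back as
 * {@code GenerateRequest} messages.
 */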
@Override
public StreamObserver<GenerateRequest> generate(StreamObserver<GenerateResponse> responseObserver) {
Generator generator = new Generator(responseObserver);
generatorGroup.add(generator);
logger.info("Added worker {}", generatorGroup.generators.size());
return generator;
}
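/**
 * Collects partial tensors from all workers for a given session and layer, sums them
 * element-wise, and broadcasts the same combined result back to every participant.
 */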
@Override
public StreamObserver<CombineRequest> combine(StreamObserver<CombineResponse> responseObserver) {
return new StreamObserver<>() {
@Override
public void onNext(CombineRequest request) {
String key = String.format("%s:%d", UUID.nameUUIDFromBytes(request.getUuid().toByteArray()), request.getLayer());
MpmcArrayQueue<Pair<CombineRequest, StreamObserver<CombineResponse>>> members = combinations.computeIfAbsent(
key,
k -> new MpmcArrayQueue<>(workerCount + 1)
);
members.add(Pair.of(request, responseObserver));
// If we have all the workers, then we can calculate the result and send it back
if (members.size() == workerCount && combinations.remove(key, members)) {
MemorySegment[] tensors = null;
for (Pair<CombineRequest, StreamObserver<CombineResponse>> f : members) {
if (f.left.getTensorCount() > 0) {
if (tensors == null) {
tensors = new MemorySegment[f.left.getTensorCount()];
for (int i = 0; i < tensors.length; i++) {
ByteBuffer bb = ByteBuffer.wrap(f.left.getTensor(i).toByteArray()).order(ByteOrder.LITTLE_ENDIAN);
tensors[i] = MemorySegment.ofBuffer(bb);
}
} else {
for (int i = 0; i < tensors.length; i++) {
MemorySegment ms = MemorySegment.ofBuffer(
f.left.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)
);
// Sum float buffers
accumulateF32(tensors[i], ms, (int) tensors[i].byteSize() / Float.BYTES);
}
}
}
}
CombineResponse.Builder responseBuilder = CombineResponse.newBuilder();
if (tensors != null) {
for (int i = 0; i < tensors.length; i++)
responseBuilder = responseBuilder.addTensor(
UnsafeByteOperations.unsafeWrap(tensors[i].asByteBuffer().order(ByteOrder.LITTLE_ENDIAN))
);
}
CombineResponse response = responseBuilder.build();
for (Pair<CombineRequest, StreamObserver<CombineResponse>> f : members) {
f.right.onNext(response);
}
members.clear();
}
}
@Override
public void onError(Throwable throwable) {}
@Override
public void onCompleted() {}
};
}
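// Element-wise float32 accumulation (a += b) using the preferred SIMD species, with a scalar tail.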
void accumulateF32(MemorySegment a, MemorySegment b, int length) {
int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(length);
int i = 0;
for (; i < upperBound; i += FloatVector.SPECIES_PREFERRED.length()) {
int fi = i * Float.BYTES;
FloatVector va = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, a, fi, ByteOrder.LITTLE_ENDIAN);
FloatVector vb = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, b, fi, ByteOrder.LITTLE_ENDIAN);
va.add(vb).intoMemorySegment(a, fi, ByteOrder.LITTLE_ENDIAN);
}
// tail: scalar loop over the remaining elements (MemorySegment offsets are in bytes)
for (; i < length; i++) {
long fi = (long) i * Float.BYTES;
a.set(ValueLayout.JAVA_FLOAT, fi, a.get(ValueLayout.JAVA_FLOAT, fi) + b.get(ValueLayout.JAVA_FLOAT, fi));
}
}
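/**
 * Tracks the set of connected workers and fans generate requests out to all of them.
 */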
public class GeneratorGroup {
private final List<Generator> generators;
private GeneratorGroup() {
this.generators = new ArrayList<>();
}
private void add(Generator generator) {
generators.add(generator);
}
public void waitForReady() {
for (Generator g : generators) {
Uninterruptibles.awaitUninterruptibly(g.isReady());
}
}
public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
return generateNextOutput(session, Collections.singletonList(tokenId), position);
}
public AbstractTensor generateNextOutput(UUID session, List<Integer> tokenIds, int startPosition) {
Preconditions.checkArgument(generators.size() == workerCount, "Missing workers %s", workers.size());
ByteString sid = ByteString.copyFrom(
ByteBuffer.allocate(128).putLong(session.getMostSignificantBits()).putLong(session.getLeastSignificantBits()).flip()
);
GenerateResponse gr = GenerateResponse.newBuilder()
.setSession(sid)
.addAllTokens(tokenIds)
.setStartPosition(startPosition)
.build();
for (Generator g : generators) {
g.registerLatch(session);
g.responseObserver.onNext(gr);
}
AbstractTensor output = model.makeDenseTensor(model.getConfig().embeddingLength);
for (int j = 0; j < workerCount; j++) {
Generator g = generators.get(j);
ByteString v = g.waitForOutput(session);
RegisterResponse r = workers.get(g.workerId);
if (j == 0) {
output.getMemorySegment().copyFrom(MemorySegment.ofBuffer(v.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)));
}
}
// logger.info("Received output from worker {}", TensorOperationsProvider.get().sum(output));
return output;
}
}
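/**
 * Server-side view of a single worker's generate stream. The first message identifies the
 * worker; subsequent messages carry per-session output tensors, which are handed off to
 * waiting callers via latches.
 */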
class Generator implements StreamObserver<GenerateRequest> {
private static final Logger logger = LoggerFactory.getLogger(Generator.class);
private volatile UUID workerId;
private CountDownLatch readyLatch;
private final StreamObserver<GenerateResponse> responseObserver;
private final ConcurrentMap<UUID, ByteString> outputs;
private final ConcurrentMap<UUID, CountDownLatch> outputLatches;
public Generator(StreamObserver<GenerateResponse> responseObserver) {
this.workerId = null;
this.readyLatch = new CountDownLatch(1);
this.responseObserver = responseObserver;
this.outputs = new ConcurrentHashMap<>();
this.outputLatches = new ConcurrentHashMap<>();
}
@Override
public void onNext(GenerateRequest generateRequest) {
if (workerId == null) {
ByteBuffer bb = generateRequest.getWorkerid().asReadOnlyByteBuffer();
workerId = new UUID(bb.getLong(), bb.getLong());
readyLatch.countDown();
logger.info("Worker {} ready", workerId);
return;
}
ByteBuffer bb = generateRequest.getSession().asReadOnlyByteBuffer();
UUID session = new UUID(bb.getLong(), bb.getLong());
if (outputs.containsKey(session)) {
logger.error("Previous output not consumed from worker {}", workerId);
}
outputs.put(session, generateRequest.getTensor());
if (outputLatches.containsKey(session)) {
outputLatches.get(session).countDown();
} else {
logger.error("No latch registered for session {}", session);
}
}
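// A latch is registered before work is dispatched; waitForOutput then blocks on it until the worker's result arrives.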
public void registerLatch(UUID session) {
outputLatches.put(session, new CountDownLatch(1));
}
public ByteString waitForOutput(UUID session) {
CountDownLatch latch = outputLatches.get(session);
if (latch == null) throw new RuntimeException("No latch registered for session " + session);
Uninterruptibles.awaitUninterruptibly(latch);
ByteString output = outputs.get(session);
if (output == null) throw new RuntimeException("No output received for session " + session);
outputs.remove(session);
return output;
}
@Override
public void onError(Throwable throwable) {
logger.error("Error encountered from worker {}", workerId, throwable);
}
@Override
public void onCompleted() {
logger.info("Worker {} completed", workerId);
}
public CountDownLatch isReady() {
return readyLatch;
}
}
}