All downloads are free. Search and download functionality uses the official Maven repository.

com.github.tjake.jlama.net.grpc.JlamaService Maven / Gradle / Ivy

/*
 * Copyright 2024 T Jake Luciani
 *
 * The Jlama Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.github.tjake.jlama.net.grpc;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.net.*;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.StreamObserver;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.*;

import jdk.incubator.vector.FloatVector;
import org.jctools.queues.MpmcArrayQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JlamaService extends JlamaServiceGrpc.JlamaServiceImplBase {
    private static final Logger logger = LoggerFactory.getLogger(JlamaService.class);
    private final AbstractModel model;
    private final int workerCount;
    private final ConcurrentMap workers;

    private final GeneratorGroup generatorGroup;

    private final ConcurrentMap>>> combinations;

    public JlamaService(AbstractModel model, int workerCount) {
        Preconditions.checkArgument(
            workerCount <= model.getConfig().numberOfKeyValueHeads,
            "Worker count must be less than or equal to number of KV heads"
        );
        this.model = model;
        this.workerCount = workerCount;
        this.workers = new ConcurrentHashMap<>();
        this.combinations = new ConcurrentHashMap<>();
        this.generatorGroup = new GeneratorGroup();
    }

    public void waitForReady() {
        while (true) {
            if (generatorGroup.generators.size() == workerCount) {
                generatorGroup.waitForReady();
                return;
            }
            Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
        }
    }

    public void shutdown() {
        for (Generator g : generatorGroup.generators) {
            try {
                g.responseObserver.onCompleted();
            } catch (Exception e) {
                logger.debug("Exception when shutting down", e);
            }
        }
    }

    /**
     * Register a worker with the coordinator.  The coordinator will return the offset and length of the embedding that
     * the worker is responsible for.
     */
    @Override
    public void register(RegisterRequest request, StreamObserver responseObserver) {
        synchronized (workers) {
            ByteBuffer bb = request.getWorkerid().asReadOnlyByteBuffer();
            UUID wid = new UUID(bb.getLong(), bb.getLong());
            if (workers.containsKey(wid)) {
                responseObserver.onNext(workers.get(wid));
                responseObserver.onCompleted();
            } else {

                if (workers.size() == workerCount) {
                    responseObserver.onError(new RuntimeException("Not accepting any more workers"));
                    return;
                }

                int workerNum = workers.size();

                RegisterResponse r = RegisterResponse.newBuilder()
                    .setModelShard(workerNum)
                    .setNumModelShards(workerCount)
                    .setLayerShard(0)
                    .setNumLayerShards(1)
                    .build();

                workers.put(wid, r);
                logger.info("Registered worker {} with workerNum {} of {}", wid, workerNum, workerCount);

                responseObserver.onNext(r);
                responseObserver.onCompleted();
            }
        }
    }

    public AbstractTensor generateNextOutput(UUID session, List tokenIds, int startPosition) {
        return generatorGroup.generateNextOutput(session, tokenIds, startPosition);
    }

    public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
        return generatorGroup.generateNextOutput(session, tokenId, position);
    }

    @Override
    public StreamObserver generate(StreamObserver responseObserver) {
        Generator generator = new Generator(responseObserver);
        generatorGroup.add(generator);
        logger.info("Added worker {}", generatorGroup.generators.size());
        return generator;
    }

    @Override
    public StreamObserver combine(StreamObserver responseObserver) {

        return new StreamObserver<>() {
            @Override
            public void onNext(CombineRequest request) {
                String key = String.format("%s:%d", UUID.nameUUIDFromBytes(request.getUuid().toByteArray()), request.getLayer());
                MpmcArrayQueue>> members = combinations.computeIfAbsent(
                    key,
                    k -> new MpmcArrayQueue<>(workerCount + 1)
                );
                members.add(Pair.of(request, responseObserver));

                // If we have all the workers, then we can calculate the result and send it back
                if (members.size() == workerCount && combinations.remove(key, members)) {
                    MemorySegment[] tensors = null;
                    for (Pair> f : members) {
                        if (f.left.getTensorCount() > 0) {
                            if (tensors == null) {
                                tensors = new MemorySegment[f.left.getTensorCount()];
                                for (int i = 0; i < tensors.length; i++) {
                                    ByteBuffer bb = ByteBuffer.wrap(f.left.getTensor(i).toByteArray()).order(ByteOrder.LITTLE_ENDIAN);
                                    tensors[i] = MemorySegment.ofBuffer(bb);
                                }
                            } else {
                                for (int i = 0; i < tensors.length; i++) {
                                    MemorySegment ms = MemorySegment.ofBuffer(
                                        f.left.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)
                                    );
                                    // Sum float buffers
                                    accumulateF32(tensors[i], ms, (int) tensors[i].byteSize() / Float.BYTES);
                                }
                            }
                        }
                    }

                    CombineResponse.Builder responseBuilder = CombineResponse.newBuilder();

                    if (tensors != null) {
                        for (int i = 0; i < tensors.length; i++)
                            responseBuilder = responseBuilder.addTensor(
                                UnsafeByteOperations.unsafeWrap(tensors[i].asByteBuffer().order(ByteOrder.LITTLE_ENDIAN))
                            );
                    }

                    CombineResponse response = responseBuilder.build();
                    for (Pair> f : members) {
                        f.right.onNext(response);
                    }

                    members.clear();
                }
            }

            @Override
            public void onError(Throwable throwable) {}

            @Override
            public void onCompleted() {}
        };
    }

    void accumulateF32(MemorySegment a, MemorySegment b, int length) {
        int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(length);
        int i = 0;

        for (; i < upperBound; i += FloatVector.SPECIES_PREFERRED.length()) {
            int fi = i * Float.BYTES;
            FloatVector va = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, a, fi, ByteOrder.LITTLE_ENDIAN);
            FloatVector vb = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, b, fi, ByteOrder.LITTLE_ENDIAN);
            va.add(vb).intoMemorySegment(a, fi, ByteOrder.LITTLE_ENDIAN);
        }

        // tail
        for (; i < length; i++) {
            a.set(ValueLayout.JAVA_FLOAT, i, a.get(ValueLayout.JAVA_FLOAT, i) + b.get(ValueLayout.JAVA_FLOAT, i));
        }
    }

    public class GeneratorGroup {
        private final List generators;

        private GeneratorGroup() {
            this.generators = new ArrayList<>();
        }

        private void add(Generator generator) {
            generators.add(generator);
        }

        public void waitForReady() {
            for (Generator g : generators) {
                Uninterruptibles.awaitUninterruptibly(g.isReady());
            }
        }

        public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
            return generateNextOutput(session, Collections.singletonList(tokenId), position);
        }

        public AbstractTensor generateNextOutput(UUID session, List tokenIds, int startPosition) {
            Preconditions.checkArgument(generators.size() == workerCount, "Missing workers %d", workers.size());
            ByteString sid = ByteString.copyFrom(
                ByteBuffer.allocate(128).putLong(session.getMostSignificantBits()).putLong(session.getLeastSignificantBits()).flip()
            );
            GenerateResponse gr = GenerateResponse.newBuilder()
                .setSession(sid)
                .addAllTokens(tokenIds)
                .setStartPosition(startPosition)
                .build();
            for (Generator g : generators) {
                g.registerLatch(session);
                g.responseObserver.onNext(gr);
            }

            AbstractTensor output = model.makeDenseTensor(model.getConfig().embeddingLength);

            for (int j = 0; j < workerCount; j++) {
                Generator g = generators.get(j);
                ByteString v = g.waitForOutput(session);
                RegisterResponse r = workers.get(g.workerId);

                if (j == 0) {
                    output.getMemorySegment().copyFrom(MemorySegment.ofBuffer(v.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)));
                }
            }

            // logger.info("Received output from worker {}", TensorOperationsProvider.get().sum(output));

            return output;
        }
    }

    class Generator implements StreamObserver {
        private static final Logger logger = LoggerFactory.getLogger(Generator.class);

        private volatile UUID workerId;
        private CountDownLatch readyLatch;
        private final StreamObserver responseObserver;
        private final ConcurrentMap outputs;
        private final ConcurrentMap outputLatches;

        public Generator(StreamObserver responseObserver) {
            this.workerId = null;
            this.readyLatch = new CountDownLatch(1);
            this.responseObserver = responseObserver;
            this.outputs = new ConcurrentHashMap<>();
            this.outputLatches = new ConcurrentHashMap<>();
        }

        @Override
        public void onNext(GenerateRequest generateRequest) {
            if (workerId == null) {
                ByteBuffer bb = generateRequest.getWorkerid().asReadOnlyByteBuffer();
                workerId = new UUID(bb.getLong(), bb.getLong());
                readyLatch.countDown();
                logger.info("Worker {} ready", workerId);
                return;
            }

            ByteBuffer bb = generateRequest.getSession().asReadOnlyByteBuffer();
            UUID session = new UUID(bb.getLong(), bb.getLong());

            if (outputs.containsKey(session)) {
                logger.error("Previous output not consumed from worker {}", workerId);
            }

            outputs.put(session, generateRequest.getTensor());

            if (outputLatches.containsKey(session)) {
                outputLatches.get(session).countDown();
            } else {
                logger.error("No latch registered for session {}", session);
            }
        }

        public void registerLatch(UUID session) {
            outputLatches.put(session, new CountDownLatch(1));
        }

        public ByteString waitForOutput(UUID session) {
            CountDownLatch latch = outputLatches.get(session);
            if (latch == null) throw new RuntimeException("No latch registered for session " + session);

            Uninterruptibles.awaitUninterruptibly(latch);
            ByteString output = outputs.get(session);
            if (output == null) throw new RuntimeException("No output received for session " + session);

            outputs.remove(session);
            return output;
        }

        @Override
        public void onError(Throwable throwable) {
            logger.error("Error encountered from worker {}", workerId, throwable);
        }

        @Override
        public void onCompleted() {
            logger.info("Worker {} completed", workerId);
        }

        public CountDownLatch isReady() {
            return readyLatch;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy