All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.tjake.jlama.model.DistributedContext Maven / Gradle / Ivy

/*
 * Copyright 2024 T Jake Luciani
 *
 * The Jlama Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;

import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

public class DistributedContext {

    private final Config c;
    private final int modelShard;
    private final int numModelShards;
    private final int layerShard;
    private final int numLayerShards;

    private final Consumer> tensorSync;

    // Suppliers to store values that chance when offset is adjusted
    public final int embeddingSegmentStart;
    public final int embeddingSegmentLength;
    public final int embeddingSegmentEnd;

    public final int attentionSegmentStart;
    public final int attentionSegmentLength;
    public final int attentionSegmentEnd;

    public final int hiddenSegmentStart;
    public final int hiddenSegmentLength;
    public final int hiddenSegmentEnd;

    public final int kvSegmentStart;
    public final int kvSegmentLength;
    public final int kvSegmentEnd;

    public final int headStart;
    public final int headEnd;
    public final int groupHeadStart;
    public final int groupHeadEnd;

    public final int numberOfLayers;
    public final int layerStart;
    public final int layerEnd;

    private DistributedContext(
        Config c,
        int modelShard,
        int numModelShards,
        int layerShard,
        int numLayerShards,
        Consumer> tensorSync
    ) {
        this.c = c;
        this.modelShard = modelShard;
        this.numModelShards = numModelShards;
        this.layerShard = layerShard;
        this.numLayerShards = numLayerShards;
        this.tensorSync = tensorSync;

        this.numberOfLayers = c.numberOfLayers / numLayerShards;
        this.layerStart = numberOfLayers * layerShard;
        this.layerEnd = layerStart + numberOfLayers;

        this.embeddingSegmentLength = c.embeddingLength / numModelShards;
        this.embeddingSegmentStart = embeddingSegmentLength * modelShard;
        this.embeddingSegmentEnd = embeddingSegmentStart + embeddingSegmentLength;

        this.attentionSegmentLength = c.attentionLength / numModelShards;
        this.attentionSegmentStart = attentionSegmentLength * modelShard;
        this.attentionSegmentEnd = attentionSegmentStart + attentionSegmentLength;

        this.hiddenSegmentLength = c.hiddenLength / numModelShards;
        this.hiddenSegmentStart = hiddenSegmentLength * modelShard;
        this.hiddenSegmentEnd = hiddenSegmentStart + hiddenSegmentLength;

        this.kvSegmentStart = embeddingSegmentStart / c.headGroupSize;
        this.kvSegmentEnd = embeddingSegmentEnd / c.headGroupSize;
        this.kvSegmentLength = embeddingSegmentLength / c.headGroupSize;

        this.headStart = embeddingSegmentStart / c.headSize;
        this.headEnd = embeddingSegmentEnd / c.headSize;
        this.groupHeadStart = kvSegmentStart / c.headSize;
        this.groupHeadEnd = kvSegmentEnd / c.headSize;
    }

    public boolean hasModelShard() {
        return numModelShards > 1;
    }

    public void syncTensors(List tensors) {
        tensorSync.accept(tensors);
    }

    public int getShardOffsetForLength(int length) {
        return length / numModelShards * modelShard;
    }

    public int getShardLength(int length) {
        return length / numModelShards;
    }

    public static Builder builder(Config c) {
        return new Builder(c);
    }

    public static class Builder {
        private Config c;
        private int modelShard = 0;
        private int numModelShards = 1;
        private int layerShard = 0;
        private int numLayerShards = 1;
        private Consumer> tensorSync;

        public Builder(Config c) {
            this.c = c;
        }

        public Builder setModelShard(int modelShard) {
            this.modelShard = modelShard;
            return this;
        }

        public Builder setNumModelShards(int numModelShards) {
            this.numModelShards = numModelShards;
            return this;
        }

        public Builder setLayerShard(int layerShard) {
            this.layerShard = layerShard;
            return this;
        }

        public Builder setNumLayerShards(int numLayerShards) {
            this.numLayerShards = numLayerShards;
            return this;
        }

        public Builder setTensorSync(Consumer> tensorSync) {
            this.tensorSync = tensorSync;
            return this;
        }

        public DistributedContext build() {
            return new DistributedContext(c, modelShard, numModelShards, layerShard, numLayerShards, tensorSync);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy