All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.nn.transformer.BertMaskedLanguageModelBlock Maven / Gradle / Ivy

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;

/** Block for the bert masked language task. */
public class BertMaskedLanguageModelBlock extends AbstractBlock {

    private static final byte VERSION = 1;

    private Linear sequenceProjection;

    private BatchNorm sequenceNorm;

    private Parameter dictionaryBias;

    private Function hiddenActivation;

    /**
     * Creates a new block that applies the masked language task.
     *
     * @param bertBlock the bert block to create the task for
     * @param hiddenActivation the activation to use for the hidden layer
     */
    @SuppressWarnings("this-escape")
    public BertMaskedLanguageModelBlock(
            BertBlock bertBlock, Function hiddenActivation) {
        super(VERSION);
        this.sequenceProjection =
                addChildBlock(
                        "sequenceProjection",
                        Linear.builder()
                                .setUnits(bertBlock.getEmbeddingSize())
                                .optBias(true)
                                .build());
        this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build());
        this.dictionaryBias =
                addParameter(
                        Parameter.builder()
                                .setName("dictionaryBias")
                                .setType(Parameter.Type.BIAS)
                                .optShape(new Shape(bertBlock.getTokenDictionarySize()))
                                .build());
        this.hiddenActivation = hiddenActivation;
    }

    /**
     * Given a 3D array of shape (B, S, E) and a 2D array of shape (B, I) returns the flattened
     * lookup result of shape (B * I * E).
     *
     * @param sequences Sequences of embeddings
     * @param indices Indices into the sequences. The indices are relative within each sequence,
     *     i.e. [[0, 1],[0, 1]] would return the first two elements of two sequences.
     * @return The flattened result of gathering elements from the sequences
     */
    public static NDArray gatherFromIndices(NDArray sequences, NDArray indices) {
        int batchSize = (int) sequences.getShape().get(0);
        int sequenceLength = (int) sequences.getShape().get(1);
        int width = (int) sequences.getShape().get(2);
        int indicesPerSequence = (int) indices.getShape().get(1);
        // this creates a list of offsets for each sequence. Say sequence length is 16 and
        // batch size is 4, this creates [0, 16, 32, 48]. Each
        NDArray sequenceOffsets =
                indices.getManager()
                        .newSubManager(indices.getDevice())
                        .arange(0, batchSize) // [0, 1, 2, ..., batchSize - 1]
                        .mul(sequenceLength) // [0, 16, 32, ...]
                        .reshape(batchSize, 1); // [[0], [16], [32], ...]
        // The following adds the sequence offsets to every index for every sequence.
        // This works, because the single values in the sequence offsets are propagated
        NDArray absoluteIndices =
                indices.add(sequenceOffsets).reshape(1, (long) batchSize * indicesPerSequence);
        // Now we create one long sequence by appending all sequences
        NDArray flattenedSequences = sequences.reshape((long) batchSize * sequenceLength, width);
        // We use the absolute indices to gather the elements of the flattened sequences
        return flattenedSequences.gatherNd(absoluteIndices);
    }

    /** {@inheritDoc} */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        inputNames = Arrays.asList("sequence", "maskedIndices", "embeddingTable");
        int width = (int) inputShapes[0].get(2);
        sequenceProjection.initialize(manager, dataType, new Shape(-1, width));
        sequenceNorm.initialize(manager, dataType, new Shape(-1, width));
    }

    /** {@inheritDoc} */
    @Override
    protected NDList forwardInternal(
            ParameterStore ps, NDList inputs, boolean training, PairList params) {
        NDArray sequenceOutput = inputs.get(0); // (B, S, E)
        NDArray maskedIndices = inputs.get(1); // (B, I)
        NDArray embeddingTable = inputs.get(2); // (D, E)
        try (NDManager scope = NDManager.subManagerOf(sequenceOutput)) {
            scope.tempAttachAll(sequenceOutput, maskedIndices);
            NDArray gatheredTokens = gatherFromIndices(sequenceOutput, maskedIndices); // (B * I, E)
            NDArray projectedTokens =
                    hiddenActivation.apply(
                            sequenceProjection
                                    .forward(ps, new NDList(gatheredTokens), training)
                                    .head()); // (B * I, E)
            NDArray normalizedTokens =
                    sequenceNorm
                            .forward(ps, new NDList(projectedTokens), training)
                            .head(); // (B * I, E)
            // raw logits for each position to correspond to an entry in the embedding table
            NDArray embeddingTransposed = embeddingTable.transpose();
            embeddingTransposed.attach(gatheredTokens.getManager());
            NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D)
            // we add an offset for each dictionary entry
            NDArray logitsWithBias =
                    logits.add(
                            ps.getValue(
                                    dictionaryBias, logits.getDevice(), training)); // (B * I, D)
            // now we apply log Softmax to get proper log probabilities
            NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D)

            return scope.ret(new NDList(logProbs));
        }
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(final Shape[] inputShapes) {
        int batchSize = (int) inputShapes[0].get(0);
        int indexCount = (int) inputShapes[1].get(1);
        int dictionarySize = (int) inputShapes[2].get(0);
        return new Shape[] {new Shape((long) batchSize * indexCount, dictionarySize)};
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy