All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.nn.transformer.BertBlock Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Implements the core bert model (without next sentence and masked language task) of bert.
 *
 * 

This closely follows the original Devlin et. * al. paper and its reference implementation. */ // We name local variables for tensor dimensions as in the paper and the reference code. // While against the general code style, it makes things much easier readable here. @SuppressWarnings({ "LocalFinalVariableName", "PMD.LocalVariableNamingConventions", "ParameterName", "PMD.FormalParameterNamingConventions" }) public final class BertBlock extends AbstractBlock { private static final byte VERSION = 1; private static final String PARAM_POSITION_EMBEDDING = "positionEmbedding"; private int embeddingSize; private int tokenDictionarySize; private int typeDictionarySize; private IdEmbedding tokenEmbedding; private IdEmbedding typeEmbedding; private Parameter positionEmebdding; private BatchNorm embeddingNorm; private Dropout embeddingDropout; private List transformerEncoderBlocks; private Linear pooling; private BertBlock(Builder builder) { super(VERSION); this.embeddingSize = builder.embeddingSize; // embedding for the input tokens this.tokenEmbedding = addChildBlock( "tokenEmbedding", new IdEmbedding.Builder() .setEmbeddingSize(builder.embeddingSize) .setDictionarySize(builder.tokenDictionarySize) .build()); this.tokenDictionarySize = builder.tokenDictionarySize; // embedding for the position this.positionEmebdding = addParameter( Parameter.builder() .setName(PARAM_POSITION_EMBEDDING) .setType(Parameter.Type.WEIGHT) .optShape( new Shape(builder.maxSequenceLength, builder.embeddingSize)) .build()); // embedding for the input types this.typeEmbedding = addChildBlock( "typeEmbedding", new IdEmbedding.Builder() .setEmbeddingSize(builder.embeddingSize) .setDictionarySize(builder.typeDictionarySize) .build()); this.typeDictionarySize = builder.typeDictionarySize; // normalizer for the embeddings this.embeddingNorm = addChildBlock("embeddingNorm", BatchNorm.builder().optAxis(2).build()); // dropout to apply after embedding normalization this.embeddingDropout = addChildBlock( 
"embeddingDropout", Dropout.builder().optRate(builder.hiddenDropoutProbability).build()); // the transformer blocks this.transformerEncoderBlocks = new ArrayList<>(builder.transformerBlockCount); for (int i = 0; i < builder.transformerBlockCount; ++i) { this.transformerEncoderBlocks.add( addChildBlock( "transformer_" + i, new TransformerEncoderBlock( builder.embeddingSize, builder.attentionHeadCount, builder.hiddenSize, 0.1f, Activation::gelu))); } // add projection for pooling layer this.pooling = addChildBlock( "poolingProjection", Linear.builder().setUnits(builder.embeddingSize).optBias(true).build()); } /** * Returns the token embedding used by this Bert model. * * @return the token embedding used by this Bert model */ public IdEmbedding getTokenEmbedding() { return this.tokenEmbedding; } /** * Returns the embedding size used for tokens. * * @return the embedding size used for tokens */ public int getEmbeddingSize() { return embeddingSize; } /** * Returns the size of the token dictionary. * * @return the size of the token dictionary */ public int getTokenDictionarySize() { return tokenDictionarySize; } /** * Returns the size of the type dictionary. * * @return the size of the type dictionary */ public int getTypeDictionarySize() { return typeDictionarySize; } /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes) { long batch = inputShapes[0].get(0); long seqLength = inputShapes[0].get(1); return new Shape[] { new Shape(batch, seqLength, embeddingSize), new Shape(batch, embeddingSize) }; } /** {@inheritDoc} */ @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... 
inputShapes) { super.beforeInitialize(inputShapes); inputNames = Arrays.asList("tokenIds", "typeIds", "masks"); Shape[] tokenShape = {inputShapes[0]}; Shape[] typeShape = {inputShapes[1]}; this.tokenEmbedding.initialize(manager, dataType, tokenShape); Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(tokenShape); this.typeEmbedding.initialize(manager, dataType, typeShape); this.embeddingNorm.initialize(manager, dataType, embeddingOutput); this.embeddingDropout.initialize(manager, dataType, embeddingOutput); for (final TransformerEncoderBlock tb : transformerEncoderBlocks) { tb.initialize(manager, dataType, embeddingOutput); } long batchSize = inputShapes[0].get(0); this.pooling.initialize(manager, dataType, new Shape(batchSize, embeddingSize)); } /** * Creates a 3D attention mask from a 2D tensor mask. * * @param ids 2D Tensor of shape (B, F) * @param mask 2D Tensor of shape (B, T) * @return float tensor of shape (B, F, T) */ public static NDArray createAttentionMaskFromInputMask(NDArray ids, NDArray mask) { long batchSize = ids.getShape().get(0); long fromSeqLength = ids.getShape().get(1); long toSeqLength = mask.getShape().get(1); // we ignore the actual content of the ids, we just create a "pseudo-mask" of ones for them NDArray broadcastOnes = ids.onesLike().toType(DataType.FLOAT32, false).reshape(batchSize, fromSeqLength, 1); // add empty dimension to multiply with broadcasted ones NDArray mask3D = mask.toType(DataType.FLOAT32, false).reshape(batchSize, 1, toSeqLength); return broadcastOnes.matMul(mask3D); } /** {@inheritDoc} */ @Override protected NDList forwardInternal( ParameterStore ps, NDList inputs, boolean training, PairList params) { // First input are the tokens. NDArray tokenIds = inputs.get(0); // Second are the token types (first / second sentence). 
NDArray typeIds = inputs.get(1); // Third are the masks for the input NDArray masks = inputs.get(2); NDManager initScope = NDManager.subManagerOf(tokenIds); initScope.tempAttachAll(inputs); // Create embeddings for inputs NDArray embeddedTokens = tokenEmbedding.forward(ps, new NDList(tokenIds), training).singletonOrThrow(); NDArray embeddedTypes = typeEmbedding.forward(ps, new NDList(typeIds), training).singletonOrThrow(); NDArray embeddedPositions = ps.getValue(positionEmebdding, tokenIds.getDevice(), training); // Merge them to one embedding by adding them // (We can just add the position embedding, even though it does not have a batch dimension: // the tensor is automagically "broadcast" i.e. repeated in the batch dimension. That // gives us the result we want: every embedding gets the same position embedding added // to it) NDArray embedding = embeddedTokens.add(embeddedTypes).add(embeddedPositions); // Apply normalization NDList normalizedEmbedding = embeddingNorm.forward(ps, new NDList(embedding), training); NDList dropoutEmbedding = embeddingDropout.forward(ps, normalizedEmbedding, training); // create 3D attention mask NDArray attentionMask = createAttentionMaskFromInputMask(tokenIds, masks); Shape maskShape = attentionMask.getShape(); NDArray offsetMask = attentionMask .reshape(maskShape.get(0), 1, maskShape.get(1), maskShape.get(2)) .toType(DataType.FLOAT32, false) .mul(-1f) // turn 1 into -1 .add(1f) // turn 0s to 1s, -1s to 0s .mul(-100000f); // turn 1s (original 0s) into -100000 // Run through all transformer blocks NDList lastOutput = dropoutEmbedding; initScope.ret(lastOutput); initScope.ret(offsetMask); initScope.close(); for (final TransformerEncoderBlock block : transformerEncoderBlocks) { NDList input = new NDList(lastOutput.head(), offsetMask); try (NDManager innerScope = NDManager.subManagerOf(input)) { innerScope.tempAttachAll(input); lastOutput = innerScope.ret(block.forward(ps, input, training)); } } // We also return the pooled output - 
this is an additional fully connected layer // only applied to the first token, assumed to be the CLS token to be used for training // classifiers. shape = (B, E) We apply a tanh activation to it. NDArray firstToken = lastOutput.head().get(new NDIndex(":,1,:")).squeeze(); NDArray pooledFirstToken = pooling.forward(ps, new NDList(firstToken), training).head().tanh(); lastOutput.add(pooledFirstToken); return lastOutput; } /** * Returns a new BertBlock builder. * * @return a new BertBlock builder. */ public static Builder builder() { return new Builder(); } /** The Builder to construct a {@link BertBlock} type of {@link Block}. */ public static final class Builder { int tokenDictionarySize; int typeDictionarySize = 16; int embeddingSize = 768; int transformerBlockCount = 12; int attentionHeadCount = 12; int hiddenSize = 4 * embeddingSize; float hiddenDropoutProbability = 0.1f; // float attentionDropoutProbability = 0.1f; int maxSequenceLength = 512; // float initializerRange = 0.02f; private Builder() {} /** * Sets the number of tokens in the dictionary. * * @param tokenDictionarySize the number of tokens in the dictionary * @return this builder */ public Builder setTokenDictionarySize(int tokenDictionarySize) { this.tokenDictionarySize = tokenDictionarySize; return this; } /** * Sets the number of possible token types. This should be a very small number (2-16). * * @param typeDictionarySize the number of possible token types. This should be a very small * number (2-16) * @return this builder */ public Builder optTypeDictionarySize(int typeDictionarySize) { this.typeDictionarySize = typeDictionarySize; return this; } /** * Sets the embedding size to use for input tokens. This size must be divisible by the * number of attention heads. * * @param embeddingSize the embedding size to use for input tokens. 
* @return this builder */ public Builder optEmbeddingSize(int embeddingSize) { this.embeddingSize = embeddingSize; return this; } /** * Sets the number of transformer blocks to use. * * @param transformerBlockCount the number of transformer blocks to use * @return this builder */ public Builder optTransformerBlockCount(int transformerBlockCount) { this.transformerBlockCount = transformerBlockCount; return this; } /** * Sets the number of attention heads to use in each transformer block. This number must * divide the embedding size without rest. * * @param attentionHeadCount the number of attention heads to use in each transformer block. * @return this builder */ public Builder optAttentionHeadCount(int attentionHeadCount) { this.attentionHeadCount = attentionHeadCount; return this; } /** * Sets the size of the hidden layers in the fully connected networks used. * * @param hiddenSize the size of the hidden layers in the fully connected networks used. * @return this builder */ public Builder optHiddenSize(int hiddenSize) { this.hiddenSize = hiddenSize; return this; } /** * Sets the dropout probabilty in the hidden fully connected networks. * * @param hiddenDropoutProbability the dropout probabilty in the hidden fully connected * networks. * @return this builder */ public Builder optHiddenDropoutProbability(float hiddenDropoutProbability) { this.hiddenDropoutProbability = hiddenDropoutProbability; return this; } /** * Sets the maximum sequence length this model can process. Memory and compute requirements * of the attention mechanism is O(n²), so large values can easily exhaust your GPU memory! * * @param maxSequenceLength the maximum sequence length this model can process. * @return this builder */ public Builder optMaxSequenceLength(int maxSequenceLength) { this.maxSequenceLength = maxSequenceLength; return this; } /** * Tiny config for testing on laptops. 
* * @return this builder */ public Builder nano() { typeDictionarySize = 2; embeddingSize = 256; transformerBlockCount = 4; attentionHeadCount = 4; hiddenSize = 4 * embeddingSize; hiddenDropoutProbability = 0.1f; // attentionDropoutProbability = 0.1f; maxSequenceLength = 128; // initializerRange = 0.02f; return this; } /** * Sets this builder's params to a minimal configuration that nevertheless performs quite * well. * * @return this builder */ public Builder micro() { typeDictionarySize = 2; embeddingSize = 512; transformerBlockCount = 12; attentionHeadCount = 8; hiddenSize = 4 * embeddingSize; hiddenDropoutProbability = 0.1f; // attentionDropoutProbability = 0.1f; maxSequenceLength = 128; // initializerRange = 0.02f; return this; } /** * Sets this builder's params to the BASE config of the original BERT paper. (except for the * dictionary size) * * @return this builder */ public Builder base() { typeDictionarySize = 16; embeddingSize = 768; transformerBlockCount = 12; attentionHeadCount = 12; hiddenSize = 4 * embeddingSize; hiddenDropoutProbability = 0.1f; // attentionDropoutProbability = 0.1f; maxSequenceLength = 256; // initializerRange = 0.02f; return this; } /** * Sets this builder's params to the LARGE config of the original BERT paper. (except for * the dictionary size) * * @return this builder */ public Builder large() { typeDictionarySize = 16; embeddingSize = 1024; transformerBlockCount = 24; attentionHeadCount = 16; hiddenSize = 4 * embeddingSize; hiddenDropoutProbability = 0.1f; // attentionDropoutProbability = 0.1f; maxSequenceLength = 512; // initializerRange = 0.02f; return this; } /** * Returns a new BertBlock with the parameters of this builder. * * @return a new BertBlock with the parameters of this builder. */ public BertBlock build() { if (tokenDictionarySize == 0) { throw new IllegalArgumentException("You must specify the dictionary size."); } return new BertBlock(this); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy