All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.nn.transformer.ScaledDotProductAttentionBlock Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

/**
 * A Block implementing scaled product attention according to Vaswani et. al..
 *
 * 

Abbreviations used: * *

    *
  • E = embedding size *
  • B = batch size *
  • N = number of attention heads *
  • F = "from" sequence length (key/value sequence), the input sequence *
  • T = "to" sequence length (query sequence), e.g. the length of the output sequence *
  • S = a sequence length, either F or T *
  • H = Attention head size (= E / N) *
* *

In many use cases F=T. For self attention, the input is equal to the output. * *

This block can process input in four forms: * *

    *
  • Input size one: [Values] = [(B, F, E)], only input is used as key, query and value * (unmasked self attention), e.g. BERT *
  • Input size two: [Values, Mask] = [(B, F, E), (B, F, F)], first input is used as key, query * and value, masked self attention *
  • Input size three: [Keys, Queries, Values] = [(B, F, E), (B, T, E), (B, F, E)], inputs are * interpreted as keys, queries and values, unmasked attention *
  • Input size four: [Keys, Queries, Values, Mask] = [(B, F, E), (B, T, E), (B, F, E), (B, T, * F)], inputs are interpreted as keys, queries, values and an attention mask, full masked * attention. *
* *

Attention masks must contain a 1 for positions to keep and a 0 for positions to mask. */ // We name local variables for tensor dimensions as in the paper and the reference code. // While against the general code style, it makes things much easier readable here. @SuppressWarnings({ "LocalVariableName", "PMD.LocalVariableNamingConventions", "ParameterName", "PMD.FormalParameterNamingConventions" }) public final class ScaledDotProductAttentionBlock extends AbstractBlock { private static final byte VERSION = 1; /** Size of the Word-/Token-embeddings we use the attention on. */ private int embeddingSize; /** Number of attention heads. */ private int headCount; /** Pointwise Linear projection of the keys. */ private Linear keyProjection; /** Pointwise Linear projection of the queries. */ private Linear queryProjection; /** Pointwise Linear projection of the values. */ private Linear valueProjection; /** Pointwise Linear projection of the results. */ private Linear resultProjection; /** Dropout operation to be applied after probability calculation. */ private Dropout attentionProbsDropout; private ScaledDotProductAttentionBlock(Builder builder) { super(VERSION); this.embeddingSize = builder.embeddingSize; this.headCount = builder.headCount; this.keyProjection = addChildBlock("keyProjection", buildProjection()); this.queryProjection = addChildBlock("queryProjection", buildProjection()); this.valueProjection = addChildBlock("valueProjection", buildProjection()); this.resultProjection = addChildBlock("resultProjection", buildProjection()); this.attentionProbsDropout = addChildBlock( "probabilityDropout", Dropout.builder() .optProbability(builder.attentionProbsDropoutProb) .build()); } /** * Helper method to build a pointwise linear projection for the current embedding size. * * @return a linear projection with bias and an output size equal to the embedding size. 
*/ private Linear buildProjection() { return Linear.builder() .setOutChannels(embeddingSize) .optBias(true) .optFlatten(false) .build(); } /** * Pointwise Linear projection of the keys. * * @return Pointwise Linear projection of the keys. */ public Linear getKeyProjection() { return keyProjection; } /** * Pointwise Linear projection of the queries. * * @return Pointwise Linear projection of the queries. */ public Linear getQueryProjection() { return queryProjection; } /** * Pointwise Linear projection of the values. * * @return Pointwise Linear projection of the values. */ public Linear getValueProjection() { return valueProjection; } /** * Pointwise Linear projection of the results. * * @return Pointwise Linear projection of the results. */ public Linear getResultProjection() { return resultProjection; } @Override public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { // Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e. // the shape of the output is the shape of the input if (inputShapes.length == 1 || inputShapes.length == 2) { return new Shape[] {inputShapes[0]}; } else if (inputShapes.length == 3 || inputShapes.length == 4) { // For attention with a dedicated query, the output shape is the query shape return new Shape[] {inputShapes[1]}; } else { throw new IllegalArgumentException( "Invalid number of input shapes: " + inputShapes.length + ", must be 1-4."); } } @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { // The lookups are fed reshaped input where the batch size is combined with the sequence // length. // The linear layers only care about the 2nd dimension, so we set the first to -1. 
final Shape projectionShape = new Shape(-1L, embeddingSize); // We initialize the lookup with that reshaped input shape for (final Block projection : children.values()) { projection.initialize(manager, DataType.FLOAT32, projectionShape); } } /** * Utility function to reshape and transpose an input of the shape (B, S, E) into (B, N, S, H). * * @param projection projected embeddings * @param B batch size * @param S sequence length * @param N number of attention heads * @param H size of attention heads * @return the reshaped input */ private NDArray createAttentionHeadsFromEmbeddings( NDArray projection, long B, long S, long N, long H) { // Reshape projection to sequence & heads: (B, S, E) -> (B, S, N, H) NDArray sequenceAndHeads = projection.reshape(B, S, N, H); // Switch sequence idx & head index, so we have sequences of heads at the end return sequenceAndHeads.transpose(0, 2, 1, 3); } @Override public NDList forward( ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { // E=embedding size long E = embeddingSize; // B=batch size long B = inputs.head().getShape().get(0); // N=number of attention heads long N = headCount; // F=from sequence length long F; // T=to sequence length long T; // H=Attention head size (= E / N) long H = E / N; // Create key, query & value input based on input size NDList flattenedKeyInput; NDList flattenedQueryInput; NDList flattenedValueInput; NDArray attentionMask; if (inputs.size() < 3) { // self attention, either masked or unmasked F = inputs.head().getShape().get(1); T = F; flattenedKeyInput = new NDList(inputs.head()); flattenedQueryInput = flattenedKeyInput; flattenedValueInput = flattenedKeyInput; } else { // attention with separate key, query & value F = inputs.get(0).getShape().get(1); T = inputs.get(1).getShape().get(1); flattenedKeyInput = new NDList(inputs.get(0)); flattenedQueryInput = new NDList(inputs.get(1)); flattenedValueInput = new NDList(inputs.get(2)); } if (inputs.size() == 2 || 
inputs.size() == 4) { // we have an additional attention mask attentionMask = inputs.get(inputs.size() - 1); } else { attentionMask = null; } // apply projection for key, query and value, preserves shape: (B, S, E) NDList keys = keyProjection.forward(parameterStore, flattenedKeyInput, training, params); NDList queries = queryProjection.forward(parameterStore, flattenedQueryInput, training, params); NDList values = valueProjection.forward(parameterStore, flattenedValueInput, training, params); // reshape to (B, N, S, H) to create separate attention heads NDArray keyHeads = createAttentionHeadsFromEmbeddings(keys.head(), B, F, N, H); NDArray queryHeads = createAttentionHeadsFromEmbeddings(queries.head(), B, T, N, H); NDArray valueHeads = createAttentionHeadsFromEmbeddings(values.head(), B, F, N, H); // Apply attention by multiplying the key and query vectors: (B, N, T, F) // (For each entry in the sequence there is a weight for each other head in the sequence) NDArray attentionScores = queryHeads.matMul(keyHeads.transpose(0, 1, 3, 2)); // Normalize the scores with 1/sqrt(H) NDArray normalizedAttentionScores = attentionScores.mul(attentionScores.getManager().create(1f / (float) Math.sqrt(H))); // Apply masking if requested, mask has shape (B, T, F) if (attentionMask != null) { // expand mask to be used on all heads at once NDArray expandedMask = attentionMask.reshape(B, 1, T, F); // we turn the mask from ints into floats and turn all 1s into 0s and all // 0s int o a value of -10000. 
Adding this to the scores will push all unwanted // values towards -inf and keep the unmasked values unchanged NDArray maskOffset = expandedMask .toType(DataType.FLOAT32, false) .mul(expandedMask.getManager().create(-1f)) // turn 1 into -1 .add(expandedMask.getManager().create(1f)) // turn 0s to 1s, -1s to 0s .mul( expandedMask .getManager() .create(-100000f)); // turn 1s (original 0s) into // -100000 // adding the mask to the scores removes the scores of unwanted positions normalizedAttentionScores = normalizedAttentionScores.add(maskOffset); } // Then apply softmax to get a probability distribution, shape (B, N, T, F) NDArray attentionProbs = normalizedAttentionScores.softmax(3); // We apply dropout to the attention probabilities - this will remove entire tokens from the // result of a position, as their probability will be set to 0 NDArray attentionProbsAfterDropout = attentionProbsDropout .forward(parameterStore, new NDList(attentionProbs), training) .singletonOrThrow(); // The result of the attention mechanism is created by a weighted sum using the attention // probs. The new head is the weighted sum of the value heads. (B, N, T, H) NDArray attentionResult = attentionProbsAfterDropout.matMul(valueHeads); // Finally, the heads are reshaped and concatenated into an embedding again NDArray resultEmbeddings = attentionResult // (B, N, T, H) .transpose(0, 2, 1, 3) // -> (B, T, N, H) .reshape(B, T, E); // -> (B, T, E) // As a last step, we add another linear projection for each token to the embedding size NDList projectedEmbeddings = resultProjection.forward(parameterStore, new NDList(resultEmbeddings), training); // done! return new NDList(projectedEmbeddings); } /** * Creates a new Builder to build an Attention Block with. * * @return a new Builder to build an Attention Block with. */ public static Builder builder() { return new Builder(); } /** A builder for {@link ScaledDotProductAttentionBlock}s. 
*/ public static final class Builder { private int embeddingSize; private int headCount; private float attentionProbsDropoutProb = 0.1f; private Builder() {} /** * Sets the embedding Size to be used for the internal token representation. * * @param embeddingSize the embedding Size to be used for the internal token representation. * @return this builder */ public Builder setEmbeddingSize(int embeddingSize) { this.embeddingSize = embeddingSize; return this; } /** * Sets the number of attention Heads, must divide the embedding size without rest. I.e. if * embeddingSize = 10, a headCount of 3 would not be valid, a headCount of 1, 2 or 5 would * be. * * @param headCount the number of attention Heads * @return this builder */ public Builder setHeadCount(int headCount) { this.headCount = headCount; return this; } /** * Sets the probability of applying dropout to the attention probability distribution. This * dropout can randomly remove a complete token from the result at a position. * * @param attentionProbsDropoutProb the probability of applying dropout to the attention * probability distribution * @return this builder */ public Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb) { this.attentionProbsDropoutProb = attentionProbsDropoutProb; return this; } /** * Creates a new {@code ScaledDotProductAttentionBlock} with the current configuration. * * @return a new {@code ScaledDotProductAttentionBlock} with the current configuration. */ public ScaledDotProductAttentionBlock build() { if (embeddingSize < 1) { throw new IllegalStateException("Embedding size not initialized."); } if (headCount < 1) { throw new IllegalStateException("Head count not initialized."); } if (embeddingSize % headCount != 0) { throw new IllegalStateException( "Embedding Size (" + embeddingSize + ") is not divisible by head count (" + headCount + ")"); } return new ScaledDotProductAttentionBlock(this); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy