All downloads are free. The search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.Attention.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}

import scala.language.existentials
import scala.reflect.ClassTag

/**
 * Implementation of multiheaded attention and self-attention layers.
 *
 * The layer expects a table with exactly three elements:
 *  1. inputX, projected by the query layer,
 *  2. inputY, projected by the key and value layers (for self attention, inputX and
 *     inputY should be the same),
 *  3. either the attention bias tensor, or - for inference with a cache - a table
 *     `T(bias, cache)` where `cache` stores the previously computed keys and values
 *     under the names `<module name>_k` and `<module name>_v`.
 *
 * @param hiddenSize hidden size; must be evenly divisible by `numHeads`
 * @param numHeads heads number; must be positive
 * @param attentionDropout dropout setting for the attention weights. The internal dropout
 *                         layer is created with `initP = 1.0 - attentionDropout`.
 *                         NOTE(review): confirm against `Dropout` whether this value is
 *                         therefore a keep probability rather than a drop probability.
 * @tparam T The numeric type in this module parameters
 */
class Attention[T: ClassTag](val hiddenSize: Int, val numHeads: Int, val attentionDropout: Float)
  (implicit ev: TensorNumeric[T]) extends AbstractModule[Activity, Activity, T] {

  // Fail fast with a readable message instead of a division by zero or a reshape size
  // mismatch that would otherwise only show up when the heads are split.
  require(numHeads > 0 && hiddenSize % numHeads == 0,
    s"hiddenSize ($hiddenSize) must be evenly divisible by numHeads ($numHeads), " +
      "and numHeads must be positive")

  // for prediction
  private val joinK = nn.JoinTable[T](dimension = 2, nInputDims = -1)
  private val joinV = nn.JoinTable[T](dimension = 2, nInputDims = -1)

  private val queryLayer = TransformerOperation.dense(
    hiddenSize, hiddenSize, false, name = s"${this.getName()}_q")
  private val keyLayer = TransformerOperation.dense(
    hiddenSize, hiddenSize, false, name = s"${this.getName()}_k")
  private val valueLayer = TransformerOperation.dense(
    hiddenSize, hiddenSize, false, name = s"${this.getName()}_v")

  // Only the query splitter is created with mul = true.
  private val querySplitLayer = new SplitHeads(hiddenSize, numHeads, true)
  private val keySplitLayer = new SplitHeads(hiddenSize, numHeads)
  private val valueSplitLayer = new SplitHeads(hiddenSize, numHeads)

  private val contiguousQLayer = new Contiguous[T]()
  private val contiguousKLayer = new Contiguous[T]()
  private val contiguousVLayer = new Contiguous[T]()
  private val matmulLayer = MM(transB = true)
  private val caddLayer = CAddTable()
  private val softMaxLayer = TransformerOperation.softMax[T]()
  private val dropLayer = Dropout(initP = (1.0 - attentionDropout))
  private val matmulNoTransLayer = MM()
  // Recombine heads --> (batch_size, length, hidden_size)
  private val combineHeadsLayer = new CombineHeads()
  // Run the combined outputs through another linear projection layer.
  private val outputLayer = TransformerOperation.dense(
    hiddenSize, hiddenSize, false, name = s"${this.getName()}_output_transform")

  /**
   * Full graph (projections + attention) used when the third input is a bias tensor.
   * Backward, parameter access and train/evaluate switching all delegate to this graph.
   */
  private[bigdl] val model : Module[T] = {
    // InputX with shape (batch_size, length_x, hidden_size).
    // InputY with shape (batch_size, length_x, hidden_size)
    // for self attention, InputX and InputY should be the same.
    // Bias is attention bias that will be added to the result of the dot product.
    val inputX = Input()
    val inputY = Input()
    val inputBias = Input()

    val queryNode = queryLayer.inputs(inputX)
    val keyNode = keyLayer.inputs(inputY)
    val valueNode = valueLayer.inputs(inputY)

    val model = Graph(Array(inputX, inputY, inputBias),
      Array(createModule(queryNode, keyNode, valueNode, inputBias)))
    if (this.train) model.training() else model.evaluate()
  }

  /**
   * Attention-only graph used by the cached inference path. It takes the already projected
   * query, key and value plus the bias, and is wired from the same layer instances as
   * `model`.
   */
  private val graph: Module[T] = {
    val queryNode = Input()
    val keyNode = Input()
    val valueNode = Input()
    val inputBias = Input()

    Graph(Array(queryNode, keyNode, valueNode, inputBias),
      Array(createModule(queryNode, keyNode, valueNode, inputBias)))
  }

  /**
   * Wires the attention computation on top of the given projected nodes:
   * split heads -> contiguous -> Q * K^T -> add bias -> softmax -> dropout
   * -> multiply by V -> combine heads -> output projection.
   *
   * @param inputQuery node producing the projected query
   * @param inputKey node producing the projected key
   * @param inputValue node producing the projected value
   * @param inputBias node producing the attention bias added to the dot product
   * @return the node of the final output projection
   */
  private def createModule(inputQuery: ModuleNode[T], inputKey: ModuleNode[T],
    inputValue: ModuleNode[T], inputBias: ModuleNode[T]) : ModuleNode[T] = {
    val querySplit = querySplitLayer.inputs(inputQuery)
    val keySplit = keySplitLayer.inputs(inputKey)
    val valueSplit = valueSplitLayer.inputs(inputValue)

    val contiguousQ = contiguousQLayer.inputs(querySplit)
    val contiguousK = contiguousKLayer.inputs(keySplit)
    val contiguousV = contiguousVLayer.inputs(valueSplit)

    val matmul = matmulLayer.inputs(contiguousQ, contiguousK)
    val cadd = caddLayer.inputs(matmul, inputBias)
    val softMax = softMaxLayer.inputs(cadd)

    val drop = dropLayer.inputs(softMax)
    val matmulNoTrans = matmulNoTransLayer.inputs(drop, contiguousV)
    // Recombine heads --> (batch_size, length, hidden_size)
    val combineHeads = combineHeadsLayer.inputs(matmulNoTrans)
    outputLayer.inputs(combineHeads)
  }

  /**
   * Forward pass for inference with a key/value cache.
   *
   * cache: (Used during prediction) dictionary with tensors containing results of
   * previous attentions. The dictionary must have the items:
   * {"k": tensor with shape [batch_size, i, key_channels],
   * "v": tensor with shape [batch_size, i, value_channels]}
   * where i is the current decoded length.
   *
   * @param input table `T(inputX, inputY, T(bias, cache))`
   * @return the attention output computed by the attention-only graph
   */
  private def updateOutputCache(input: Activity): Activity = {
    require(!this.isTraining(), "Only support input cache for model inference")
    val inputTable = input.toTable
    val inputX = inputTable[Tensor[T]](1)
    val inputY = inputTable[Tensor[T]](2)
    val biasAndCache = inputTable[Table](3)
    val inputBias = biasAndCache.apply[Tensor[T]](1)
    val cache = biasAndCache.apply[Table](2)

    val keyName = this.getName() + "_k"
    val valueName = this.getName() + "_v"

    val query = queryLayer.forward(inputX).toTensor[T]

    // An empty cache table is neither read nor written.
    val hasCache = cache.length() > 0
    val cachedK = if (hasCache) Option(cache.apply[Tensor[T]](keyName)) else None
    val cachedV = if (hasCache) Option(cache.apply[Tensor[T]](valueName)) else None

    // Join the freshly projected key/value with the cached one along dimension 2
    // when a non-empty cached tensor is present.
    val newKey = keyLayer.forward(inputY).toTensor[T]
    val key = cachedK.filterNot(_.isEmpty) match {
      case Some(previous) => joinK.forward(T(newKey, previous))
      case None => newKey
    }
    val newValue = valueLayer.forward(inputY).toTensor[T]
    val value = cachedV.filterNot(_.isEmpty) match {
      case Some(previous) => joinV.forward(T(newValue, previous))
      case None => newValue
    }

    // update cache
    if (hasCache) {
      cache.update(keyName, key)
      cache.update(valueName, value)
    }
    output = graph.updateOutput(T(query, key, value, inputBias))
    output
  }

  /**
   * Forward pass. Dispatches on the type of the third input element: a tensor is treated as
   * the attention bias and runs the full graph, a table is treated as `T(bias, cache)` and
   * runs the cached inference path.
   *
   * @param input table with exactly three elements
   * @return the attention output
   * @throws IllegalArgumentException if the input does not have three elements or the third
   *                                  element is neither a tensor nor a table
   */
  override def updateOutput(input: Activity): Activity = {
    val inputTable = input.toTable
    require(inputTable.length() == 3,
      s"only support 3 inputs, but get ${inputTable.length()}")

    output = inputTable.apply[Activity](3) match {
      case _: Tensor[_] => model.updateOutput(input)
      case _: Table => updateOutputCache(input)
      case other =>
        // Previously this case fell through and returned the stale output of an earlier call.
        val typeName = Option(other).map(_.getClass.getName).getOrElse("null")
        throw new IllegalArgumentException(
          s"the third input should be a bias Tensor or a Table of (bias, cache), " +
            s"but get $typeName")
    }
    output
  }

  /** Backward pass; delegates to the full graph `model`. */
  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    gradInput = model.updateGradInput(input, gradOutput)
    gradInput
  }

  /** Accumulates parameter gradients; delegates to the full graph `model`. */
  override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
    model.accGradParameters(input, gradOutput)
  }

  /** Switches this module and the wrapped `model` to training mode. */
  override def training(): this.type = {
    train = true
    model.training()
    this
  }

  /** Switches this module and the wrapped `model` to evaluation mode. */
  override def evaluate(): this.type = {
    train = false
    model.evaluate()
    this
  }

  /** Extra parameters as reported by the wrapped `model`. */
  override def getExtraParameter(): Array[Tensor[T]] = {
    model.getExtraParameter()
  }

  /** Timing information as reported by the wrapped `model`. */
  override def getTimes(): Array[(AbstractModule[_ <: Activity, _ <: Activity, T], Long, Long)] = {
    model.getTimes()
  }

  /** Resets the timing information of the wrapped `model`. */
  override def resetTimes(): Unit = {
    model.resetTimes()
  }

  /** Weights and gradients as reported by the wrapped `model`. */
  override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
    model.parameters()
  }

  /** Parameters table as reported by the wrapped `model`. */
  override def getParametersTable(): Table = {
    model.getParametersTable()
  }

  /** Clears the state of the wrapped `model`. */
  override def clearState(): this.type = {
    model.clearState()
    this
  }
}

/**
 * Combine a tensor that has been split into heads.
 * input should be tensor with shape (batch_size, num_heads, length, hidden_size/num_heads)
 * output should be tensor with shape (batch_size, length, hidden_size)
 */
private[nn] class CombineHeads[T: ClassTag](implicit ev: TensorNumeric[T])
  extends TensorModule[T] {

  // The two dimensions that are swapped before merging / after un-merging.
  private val permutations: (Int, Int) = (2, 3)

  /**
   * Copies the input, swaps dimensions 2 and 3 and reshapes the result to
   * (input.size(1), input.size(3), input.size(2) * input.size(4)).
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    val (firstAxis, secondAxis) = permutations
    val mergedShape = Array(input.size(1), input.size(3), input.size(2) * input.size(4))

    output.resizeAs(input).copy(input)
    val swapped = output.transpose(firstAxis, secondAxis)
    output = swapped.reshape(mergedShape)
    output
  }

  /**
   * Views the (contiguous) gradient as (size(1), size(3), size(2), size(4)) of the input and
   * swaps dimensions 2 and 3 back, returning a contiguous tensor.
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    val (firstAxis, secondAxis) = permutations
    val swappedShape = Array(input.size(1), input.size(3), input.size(2), input.size(4))

    // view is only applied to a contiguous tensor
    val denseGrad = if (gradOutput.isContiguous()) gradOutput else gradOutput.contiguous()
    gradInput = denseGrad.view(swappedShape).transpose(firstAxis, secondAxis).contiguous()
    gradInput
  }
}

/**
 * Split x into different heads, and transpose the resulting value.
 * The tensor is transposed to ensure the inner dimensions hold the correct
 * values during the matrix multiplication.
 * input with shape (batch_size, length, hidden_size)
 * output with shape (batch_size, num_heads, length, hidden_size/num_heads)
 * @param hiddenSize hidden size; must be evenly divisible by `numHeads`
 * @param numHeads number of heads; must be positive
 * @param mul when true, the output is multiplied by (hiddenSize / numHeads) ^ -0.5
 * @tparam T The numeric type in this module parameters
 */
private[nn] class SplitHeads[T: ClassTag](val hiddenSize: Int, val numHeads: Int,
  val mul: Boolean = false)(implicit ev: TensorNumeric[T])
  extends TensorModule[T] {

  // Without this check the integer division below silently truncates (or throws a bare
  // ArithmeticException for numHeads == 0) and the error only surfaces at reshape time.
  require(numHeads > 0 && hiddenSize % numHeads == 0,
    s"hiddenSize ($hiddenSize) must be evenly divisible by numHeads ($numHeads), " +
      "and numHeads must be positive")

  // Size of each head.
  private val depth = hiddenSize / numHeads
  // Scale factor applied when mul is true.
  private val value = ev.fromType(math.pow(depth, -0.5))
  // The two dimensions that are swapped after splitting / before un-splitting.
  private val permutations: (Int, Int) = (2, 3)

  /**
   * Copies the input, reshapes it to (batch_size, length, numHeads, depth), swaps
   * dimensions 2 and 3 and, when `mul` is set, scales the result in place.
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    val batchSize = input.size(1)
    val length = input.size(2)

    output.resizeAs(input).copy(input)
    output = output.reshape(Array(batchSize, length, numHeads, depth))
      .transpose(permutations._1, permutations._2)
    if (mul) {
      output.mul(value)
    }
    output
  }

  /**
   * Copies the gradient (scaled by the same factor as the forward pass when `mul` is set),
   * swaps dimensions 2 and 3 back and resizes the result to the size of the input.
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (mul) {
      gradInput.resizeAs(gradOutput).zero().add(value, gradOutput)
    } else {
      gradInput.resizeAs(gradOutput).copy(gradOutput)
    }
    // Make the transposed gradient contiguous before resizing it to the input size.
    gradInput = gradInput.transpose(permutations._1, permutations._2).contiguous()
    gradInput.resize(input.size())
    gradInput
  }
}

object Attention {
  /**
   * Factory for [[Attention]].
   *
   * @param hiddenSize hidden size
   * @param numHeads heads number
   * @param attentionDropout dropout setting forwarded unchanged to the layer
   * @tparam T The numeric type in this module parameters
   * @return a new attention layer
   */
  def apply[@specialized(Float, Double) T: ClassTag]
  (hiddenSize: Int, numHeads: Int, attentionDropout: Float)
  (implicit ev: TensorNumeric[T]): Attention[T] = {
    new Attention[T](hiddenSize, numHeads, attentionDropout)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy