/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.dllib.nn

import breeze.linalg.*
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dllib.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.dllib.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLModule, DataType}
import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.DataConverter
import com.intel.analytics.bigdl.dllib.utils.serializer.{DeserializeContext, ModuleSerializable, ModuleSerializer, SerializeContext}
import com.intel.analytics.bigdl.dllib.utils.{Log4Error, T, Table}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.reflect.runtime._

/**
 * Transformer model from "Attention Is All You Need".
 * The Transformer model consists of an encoder and a decoder, both of which are stacks
 * of self-attention layers followed by feed-forward layers. The model yields
 * good results on a number of problems, especially in NLP and machine translation.
 * See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full
 * description of the model and the results obtained with its early version.
 * @param vocabSize vocabulary size, i.e. the size of the word embedding table.
 * @param hiddenSize hidden size of the model (embedding dimension and attention size).
 * @param numHeads number of attention heads.
 * @param filterSize inner (filter) size of the position-wise feed-forward network.
 * @param numHiddenlayers number of stacked encoder/decoder layers.
 * @param embeddingDropout dropout parameter used on the embedding / positional-encoding output.
 * @param attentionDropout dropout parameter used inside the attention layers.
 * @param ffnDropout dropout parameter used inside the feed-forward network.
 * @param paddingValue padding value for word embedding, default 0, which means no padding.
 * @param withShareWeightsLinear whether to add a linear layer that shares weights with the
 *        embedding layer.
 * @param transformerType transformer type, supporting LanguageModel and Translation.
 * @param beamSearch beam search layer used for decoding in translation inference.
 * @tparam T The numeric type of the parameters in this module.
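 *
 * A minimal construction sketch for the language-model case (hyper-parameters below are
 * illustrative only, and the usual Float TensorNumeric implicit is assumed to be in scope):
 * {{{
 *   val model = Transformer[Float](vocabSize = 16, hiddenSize = 4, numHeads = 2,
 *     filterSize = 8, numHiddenlayers = 1, embeddingDropout = 0.1f,
 *     attentionDropout = 0.1f, ffnDropout = 0.1f)
 *   // language-model input: [batch_size, length] word ids in [1, vocabSize]
 *   val output = model.forward(Tensor[Float](2, 6).fill(3f))
 * }}}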
 */
class Transformer[T: ClassTag](
   val vocabSize: Int,
   val hiddenSize: Int,
   val numHeads: Int,
   val filterSize: Int,
   val numHiddenlayers: Int,
   val embeddingDropout: Float,
   val attentionDropout: Float,
   val ffnDropout: Float,
   val paddingValue: Double = 0,
   val withShareWeightsLinear: Boolean = false,
   val transformerType: TransformerType = LanguageModel,
   val beamSearch: SequenceBeamSearch[T] = null)
  (implicit ev: TensorNumeric[T]) extends AbstractModule[Activity, Activity, T] {

  // for translation layers
  private[bigdl] var decoderStack: Module[T] = null
  private[bigdl] var encoderStack: Module[T] = null
  private[bigdl] var predictModel: Module[T] = null
  private var linearSharedWeigths : Module[T] = null
  // for symbols
  private val rangeBuffer = Tensor[T]()
  private val timeBuffer = Tensor[T]()
  private var decoderBiasBuffer = Tensor[T]()

  private val embeddingLayer = Sequential[T]().add(
    LookupTable[T](vocabSize, hiddenSize, paddingValue = paddingValue,
      maskZero = true).setName("embedding")).add(MulConstant(math.sqrt(hiddenSize)))

  private[bigdl] var model : Module[T] = {
    transformerType match {
      case LanguageModel => buildLM()
      case Translation => buildTranslation()
    }
  }

  private def createDecoder(): Module[T] = {
    val decoderInputNode = Input()
    val decoderSelfAttentionBiasNode = Input()
    val encoderOutputNode = Input()
    val encoderAttentionBiasNode = Input()

    Graph(Array(decoderInputNode, decoderSelfAttentionBiasNode,
      encoderOutputNode, encoderAttentionBiasNode),
      Array(block(numHiddenlayers, decoderInputNode, decoderSelfAttentionBiasNode,
        encoderOutputNode, encoderAttentionBiasNode, blockType = "decoder")))
  }

  private def createEncoder(): Module[T] = {
    val encoderInputNode = Input()
    val encoderAttentionBiasNode = Input()
    Graph(Array(encoderInputNode, encoderAttentionBiasNode),
      Array(block(numHiddenlayers, encoderInputNode, encoderAttentionBiasNode,
        blockType = "encoder")))
  }

  private def buildTranslation(): Module[T] = {
    // init layers
    val mask = new PaddingMask()
    if (linearSharedWeigths == null) {
      linearSharedWeigths = TimeDistributed(new Linear(
        inputSize = hiddenSize, outputSize = vocabSize, withBias = false)).asInstanceOf[Module[T]]
    }
    if (decoderStack == null) decoderStack = createDecoder()
    if (encoderStack == null) encoderStack = createEncoder()

    // input: int tensor with shape [batch_size, input_length].
    val inputNode = Input()
    // target: int tensor with shape [batch_size, target_length].
    val targetNode = Input()
    val attentionBias = mask.inputs(inputNode)
    val join = JoinTable(1, -1).inputs(inputNode, targetNode)
    val embeddingForTrain = embeddingLayer.inputs(join)
    val split = new SplitTensor(1, 2).inputs(embeddingForTrain)
    val embeddingInput = SelectTable(1).inputs(split)
    val embeddingOutput = SelectTable(2).inputs(split)

    // create encode
    val embeddingNode = Input()
    val paddingNode = Input()
    val encoderGraph = Graph(Array(embeddingNode, paddingNode),
      encode(embeddingNode, paddingNode))

    // create predict model
    val predictNode = Input()
    val attentionMask = mask.inputs(predictNode)
    val embeddingForPredict = embeddingLayer.inputs(predictNode)
    predictModel = Graph(predictNode,
      Array(encoderGraph.inputs(embeddingForPredict, attentionMask), attentionMask))

    // init beam search
    if (beamSearch != null) beamSearch.setLogitFn(symbols)

    // create training model
    val outputNode = decode(embeddingOutput,
      encoderGraph.inputs(embeddingInput, attentionBias), attentionBias)
    Graph(Array(inputNode, targetNode), outputNode)
  }

  private def buildLM(): Module[T] = {
    val inputNode = Input()
    val constantValue = math.sqrt(hiddenSize)
    val embeddingInput = MulConstant(constantValue).inputs(
      LookupTable[T](vocabSize, hiddenSize, paddingValue = paddingValue,
        maskZero = true).setName("embedding").inputs(inputNode))

    val decoderInput = new PositionEncodeWithShift().inputs(embeddingInput)
    val decoderSelfAttentionBias = new SelfAttentionMask().inputs(embeddingInput)
    val decoderInputDrop = Dropout(1 - embeddingDropout).inputs(decoderInput)

    val outputNode = block(numHiddenlayers, decoderInputDrop,
      decoderSelfAttentionBias, blockType = "decode")
    Graph(inputNode, outputNode)
  }

  private def updateOutputLM(input: Tensor[T]): Tensor[T] = {
    output = model.forward(input)
    if (withShareWeightsLinear) {
      shareWeights(true)
      output = linearSharedWeigths.updateOutput(model.output.toTensor[T])
    }
    output.toTensor[T]
  }

  private def shareWeights(share: Boolean): Unit = {
    if (share) {
      val embeddingParams = embeddingLayer.getParameters()
      val linearParams = linearSharedWeigths.getParameters()
      linearParams._1.copy(embeddingParams._1)
    }
  }

  /**
   * Logit function passed to beam search (see [[SequenceBeamSearch.setLogitFn]]).
   * @param Ids ids decoded so far, used to look up the decoder input for the current step.
   * @param i current decode step index.
   * @param maxDecodeLength maximum decode length.
   * @param encoder_outputs output from the encoder.
   * @param encoder_decoder_attention_bias encoder-decoder attention bias.
   * @param cacheValue cached k and v values for the attention layers.
   * @return logits for the current step and the updated cache.
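   *
   * A sketch of the expected call pattern during beam search (variable names below are
   * illustrative, not part of the public API):
   * {{{
   *   val (logits, updatedCache) = transformer.symbols(decodedIds, step, maxDecodeLength,
   *     encoderOutputs, encDecAttentionBias, cache)
   * }}}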
   */
  def symbols(Ids: Tensor[T], i: Int, maxDecodeLength: Int,
              encoder_outputs: Tensor[T], encoder_decoder_attention_bias: Tensor[T],
              cacheValue: Table): (Tensor[T], Table) = {
    val cache = T() // pass to attention layer
    for (m <- 1 to numHiddenlayers) {
      if (cacheValue.contains(s"layer_${m}_k")) {
        cache.update(s"decoder_self_attention_${m - 1}/self_attention_k",
          cacheValue(s"layer_${m}_k"))
        cache.update(s"decoder_self_attention_${m - 1}/self_attention_v",
          cacheValue(s"layer_${m}_v"))
      }
    }

    val length = maxDecodeLength + 1
    TransformerOperation.initRangeTensor(length, rangeBuffer)
    timeBuffer.resize(length, hiddenSize)
    val timeSignal = TransformerOperation.getPositionEncode(length, hiddenSize,
      rangeBuffer = rangeBuffer, outBuffer = timeBuffer)

    // size (1, 1, maxDecodeLength, maxDecodeLength)
    if (decoderBiasBuffer == null
      || decoderBiasBuffer.nElement() != maxDecodeLength * maxDecodeLength) {
      decoderBiasBuffer = Tensor[T](1, 1, maxDecodeLength, maxDecodeLength)
    }
    TransformerOperation.attentionBiasLowerTriangle(maxDecodeLength, decoderBiasBuffer)

    val decoder_input = Ids.narrow(2, i + 1, 1)
    val decoder_input_embedding = embeddingLayer.forward(decoder_input).toTensor[T]

    val timingTemp = timeSignal.select(1, i + 1)
    val decoder_input_add = decoder_input_embedding.add(timingTemp)

    val self_attention_bias = decoderBiasBuffer.select(3, i + 1)
      .select(3, i + 1).resize(Array(1, 1, 1, i + 1))

    val decoder_outputs = decoderStack.forward(T(decoder_input_add,
      T(self_attention_bias, cache), encoder_outputs, encoder_decoder_attention_bias)).toTensor[T]

    shareWeights(withShareWeightsLinear)
    val logits = this.linearSharedWeigths.forward(decoder_outputs).toTensor[T]

    for (m <- 1 to numHiddenlayers) {
      if (cacheValue.contains(s"layer_${m}_k")) {
        cacheValue.update(s"layer_${m}_k",
          cache(s"decoder_self_attention_${m - 1}/self_attention_k"))
        cacheValue.update(s"layer_${m}_v",
          cache(s"decoder_self_attention_${m - 1}/self_attention_v"))
      }
    }

    (logits.squeeze(2), cacheValue)
  }

  private def updateOutputTranslation(input: Activity): Activity = {
    if (input.isTensor) {
      Log4Error.invalidInputError(!this.isTraining(),
        "Input for Transformer should be tensor when doing translation prediction")
      // inference case: the first tensor is encoder_outputs, the other is attention_bias
      val res = predictModel.forward(input).toTable
      beamSearch.forward(T(res[Tensor[T]](1), res[Tensor[T]](2)))
      // the beam search output is a table: the first tensor is decoder_ids, the second is scores
      val decodedIds = beamSearch.output.toTable.apply[Tensor[T]](1).select(2, 1)
      val scores = beamSearch.output.toTable.apply[Tensor[T]](2).select(2, 1)
      output = T(decodedIds.narrow(2, 2, decodedIds.size(2) - 1), scores)
    } else {
      Log4Error.invalidInputError(input.toTable.length() == 2,
        s"Input should be two tensors when doing " +
        s"translation training, but got ${input.toTable.length()}")
      // training case
      output = model.forward(input)
      if (withShareWeightsLinear) {
        shareWeights(true)
        output = linearSharedWeigths.updateOutput(model.output.toTensor[T])
      }
    }
    output
  }

  override def updateOutput(input: Activity): Activity = {
    if (transformerType == Translation) {
      updateOutputTranslation(input)
    } else {
      updateOutputLM(input.toTensor[T])
    }
  }

  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    val grad = if (withShareWeightsLinear) {
      linearSharedWeigths.updateGradInput(model.output.toTensor[T], gradOutput.toTensor[T])
    } else gradOutput
    gradInput = model.updateGradInput(input, grad)
    gradInput
  }

  override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
    model.accGradParameters(input, gradOutput)
  }

  private[nn] def encode(inputs: ModuleNode[T], attentionBias: ModuleNode[T]): ModuleNode[T] = {
    // Prepare inputs to the layer stack by adding positional encodings and
    // applying dropout.
    val position = new PositionEncode().inputs(inputs)
    val encoderInput = CAddTable().inputs(inputs, position)
    val encoderInputDrop = Dropout(1 - embeddingDropout).inputs(encoderInput)
    encoderStack.inputs(encoderInputDrop, attentionBias)
  }

  private[nn] def decode(targets: ModuleNode[T],
                     encoderOutput: ModuleNode[T] = null,
                     attentionBias: ModuleNode[T] = null): ModuleNode[T] = {
    val decoderInput = new PositionEncodeWithShift().inputs(targets)
    val decoderSelfAttentionBias = new SelfAttentionMask().inputs(targets)
    val decoderInputDrop = Dropout(1 - embeddingDropout).inputs(decoderInput)

    decoderStack.inputs(Array(decoderInputDrop,
      decoderSelfAttentionBias, encoderOutput, attentionBias))
  }

  private[nn] def block(numLayers: Int,
                        decoderInput: ModuleNode[T],
                        decoderSelfAttentionBias: ModuleNode[T],
                        encoderOutput: ModuleNode[T] = null,
                        encoderAttentionBias: ModuleNode[T] = null,
                        blockType: String): ModuleNode[T] = {
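    // Each of the numLayers iterations applies pre-LayerNorm sublayers with residual
    // connections; the stack output is layer-normalized once more after the loop:
    //   x = x + Dropout(SelfAttention(LayerNorm(x)))
    //   x = x + Dropout(EncDecAttention(LayerNorm(x)))   // decoder blocks only
    //   x = x + Dropout(FeedForward(LayerNorm(x)))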

    var input = decoderInput
    var i = 0
    while (i < numLayers) {
      val selfAttention = new Attention[T](hiddenSize, numHeads, attentionDropout)
      val selfAttentionModel = processSelfAttention(
        selfAttention, input, decoderSelfAttentionBias,
        s"${blockType}_self_attention_${i}")
      input = selfAttentionModel

      if (encoderOutput != null && encoderAttentionBias != null) {
        val encdecAttention = new Attention[T](hiddenSize, numHeads, attentionDropout)
        val encdecAttentionModel = processEncDecAttention(
          encdecAttention, input, encoderOutput, encoderAttentionBias,
          s"${blockType}_encdec_attention_${i}")
        input = encdecAttentionModel
      }

      val ffn = new FeedForwardNetwork[T](hiddenSize, filterSize, ffnDropout)
      val ffnModel = processFFN(ffn, input, s"${blockType}_ffn_${i}")
      input = ffnModel

      i += 1
    }
    new LayerNormalization[T](hiddenSize).inputs(input)
  }

  private def processSelfAttention(layer: Module[T], decoderInput: ModuleNode[T],
    decoderSelfAttentionBias: ModuleNode[T], preName: String): ModuleNode[T] = {
    val norm = new LayerNormalization[T](hiddenSize).setName(preName + "/norm")
        .inputs(decoderInput)
    val drop = Dropout[T](1 - embeddingDropout).setName(preName + "/dropout")
        .inputs(layer.setName(preName + "/self_attention")
        .inputs(norm, norm, decoderSelfAttentionBias))
    CAddTable().inputs(decoderInput, drop)
  }

  private def processEncDecAttention(
    layer: Module[T],
    decoderInput: ModuleNode[T],
    encoderOutput: ModuleNode[T],
    attentionBias: ModuleNode[T], preName: String): ModuleNode[T] = {
    val norm = new LayerNormalization[T](hiddenSize).setName(preName + "/norm")
      .inputs(decoderInput)
    val drop = Dropout[T](1 - embeddingDropout).setName(preName + "/dropout")
      .inputs(layer.setName(preName + "/encdec_attention")
        .inputs(norm, encoderOutput, attentionBias))
    CAddTable().inputs(decoderInput, drop)
  }

  private def processFFN(layer: Module[T],
    decoderInput: ModuleNode[T], preName: String): ModuleNode[T] = {
    val norm = new LayerNormalization[T](hiddenSize).setName(preName + "/norm")
      .inputs(decoderInput)
    val drop = Dropout[T](1 - embeddingDropout).setName(preName + "/dropout")
      .inputs(layer.setName(preName + "/ffn").inputs(norm))
    CAddTable().inputs(decoderInput, drop)
  }

  override def clearState(): this.type = {
    if (withShareWeightsLinear) linearSharedWeigths.clearState()
    model.clearState()
    this
  }

  override def training(): this.type = {
    train = true
    model.training()
    this
  }

  override def evaluate(): this.type = {
    train = false
    model.evaluate()
    this
  }

  override def getExtraParameter(): Array[Tensor[T]] = {
    model.getExtraParameter()
  }

  override def getTimes(): Array[(AbstractModule[_ <: Activity, _ <: Activity, T], Long, Long)] = {
    model.getTimes()
  }

  override def resetTimes(): Unit = {
    model.resetTimes()
  }

  override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
    model.parameters()
  }

  override def getParametersTable(): Table = {
    model.getParametersTable()
  }
}

object Transformer extends ModuleSerializable {
  def apply[T: ClassTag](
     vocabSize: Int,
     hiddenSize: Int,
     numHeads: Int,
     filterSize: Int,
     numHiddenlayers: Int,
     embeddingDropout: Float,
     attentionDropout: Float,
     ffnDropout: Float,
     paddingValue: Double = 0,
     withShareWeightsLinear: Boolean = false,
     transformerType: TransformerType = LanguageModel,
     beamSearch: SequenceBeamSearch[T] = null)
   (implicit ev: TensorNumeric[T]): Transformer[T] = {
    new Transformer(vocabSize, hiddenSize, numHeads,
      filterSize, numHiddenlayers,
      embeddingDropout, attentionDropout, ffnDropout, paddingValue,
      withShareWeightsLinear, transformerType = transformerType, beamSearch)
  }

  override def doLoadModule[T: ClassTag](context: DeserializeContext)
    (implicit ev: TensorNumeric[T]) : AbstractModule[Activity, Activity, T] = {
    val attrMap = context.bigdlModule.getAttrMap

    val model = DataConverter
      .getAttributeValue(context, attrMap.get("model")).
      asInstanceOf[Module[T]]

    val vocabSize = DataConverter
      .getAttributeValue(context, attrMap.get("vocabSize"))
      .asInstanceOf[Int]

    val hiddenSize = DataConverter
      .getAttributeValue(context, attrMap.get("hiddenSize"))
      .asInstanceOf[Int]

    val numHeads = DataConverter
      .getAttributeValue(context, attrMap.get("numHeads"))
      .asInstanceOf[Int]

    val filterSize = DataConverter
      .getAttributeValue(context, attrMap.get("filterSize"))
      .asInstanceOf[Int]

    val numHiddenlayers = DataConverter
      .getAttributeValue(context, attrMap.get("numHiddenlayers"))
      .asInstanceOf[Int]

    val embeddingDropout = DataConverter
      .getAttributeValue(context, attrMap.get("embeddingDropout"))
      .asInstanceOf[Float]

    val attentionDropout = DataConverter
      .getAttributeValue(context, attrMap.get("attentionDropout"))
      .asInstanceOf[Float]

    val ffnDropout = DataConverter
      .getAttributeValue(context, attrMap.get("ffnDropout"))
      .asInstanceOf[Float]

    val beamSearch = DataConverter
      .getAttributeValue(context, attrMap.get("beamSearch"))
      .asInstanceOf[Module[T]]

    val paddingValue = DataConverter
      .getAttributeValue(context, attrMap.get("paddingValue"))
      .asInstanceOf[Double]

    val tag = DataConverter
      .getAttributeValue(context, attrMap.get("transformerType"))
      .asInstanceOf[Int]

    val withShareWeightsLinear = DataConverter
      .getAttributeValue(context, attrMap.get("withShareWeightsLinear"))
      .asInstanceOf[Boolean]

    val transformerType = tag match {
      case 1 => LanguageModel
      case 2 => Translation
      case _ => Log4Error.invalidInputError(false,
        s"Only support transformer tags 1 and 2, but got ${tag}")
        null
    }

    val transformer = Transformer(vocabSize, hiddenSize, numHeads, filterSize,
      numHiddenlayers, embeddingDropout, attentionDropout, ffnDropout, paddingValue,
      withShareWeightsLinear, transformerType, beamSearch.asInstanceOf[SequenceBeamSearch[T]])

    transformer.model = model
    transformer
  }

  override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
     transformerBuilder : BigDLModule.Builder)(implicit ev: TensorNumeric[T]) : Unit = {

    val transformer = context.moduleData.module.asInstanceOf[Transformer[T]]

    val modelBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, modelBuilder, transformer.model,
      ModuleSerializer.abstractModuleType)
    transformerBuilder.putAttr("model", modelBuilder.build)

    val vocabSizeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, vocabSizeBuilder,
      transformer.vocabSize, universe.typeOf[Int])
    transformerBuilder.putAttr("vocabSize", vocabSizeBuilder.build)

    val hiddenSizeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, hiddenSizeBuilder,
      transformer.hiddenSize, universe.typeOf[Int])
    transformerBuilder.putAttr("hiddenSize", hiddenSizeBuilder.build)

    val numHeadsBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, numHeadsBuilder,
      transformer.numHeads, universe.typeOf[Int])
    transformerBuilder.putAttr("numHeads", numHeadsBuilder.build)

    val filterSizeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, filterSizeBuilder,
      transformer.filterSize, universe.typeOf[Int])
    transformerBuilder.putAttr("filterSize", filterSizeBuilder.build)

    val numHiddenlayersBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, numHiddenlayersBuilder,
      transformer.numHiddenlayers, universe.typeOf[Int])
    transformerBuilder.putAttr("numHiddenlayers", numHiddenlayersBuilder.build)

    val embeddingDropoutBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, embeddingDropoutBuilder,
      transformer.embeddingDropout, universe.typeOf[Float])
    transformerBuilder.putAttr("embeddingDropout", embeddingDropoutBuilder.build)

    val attentionDropoutBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, attentionDropoutBuilder,
      transformer.attentionDropout, universe.typeOf[Float])
    transformerBuilder.putAttr("attentionDropout", attentionDropoutBuilder.build)

    val ffnDropoutBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, ffnDropoutBuilder,
      transformer.ffnDropout, universe.typeOf[Float])
    transformerBuilder.putAttr("ffnDropout", embeddingDropoutBuilder.build)

    val paddingValueBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, paddingValueBuilder,
      transformer.paddingValue, universe.typeOf[Double])
    transformerBuilder.putAttr("paddingValue", paddingValueBuilder.build)

    val shareWeightsBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, shareWeightsBuilder,
      transformer.withShareWeightsLinear, universe.typeOf[Boolean])
    transformerBuilder.putAttr("withShareWeightsLinear", shareWeightsBuilder.build)

    // for language model, marked as 1
    // for translation model, marked as 2
    val tag = transformer.transformerType match {
      case LanguageModel => 1
      case Translation => 2
      case _ => Log4Error.invalidInputError(false, s"Only support LanguageModel" +
        s"and Translation transformer type, but get ${transformer.transformerType}")
    }
    val transformerTypeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, transformerTypeBuilder,
      tag, universe.typeOf[Int])
    transformerBuilder.putAttr("transformerType", transformerTypeBuilder.build)

    val beamSearchBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, beamSearchBuilder,
      transformer.beamSearch, universe.typeOf[Module[_]])
    transformerBuilder.putAttr("beamSearch", beamSearchBuilder.build)
  }
}

/**
 * Return positional encoding.
 * Calculates the position encoding as a mix of sine and cosine functions with
 * geometrically increasing wavelengths.
 * Defined and formalized in "Attention Is All You Need", section 3.5.
 * @tparam T The numeric type of the parameters in this module
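 *
 * The encoding follows the standard formulation from the paper, where pos is the position,
 * i the channel index and d the channel (embedding) size:
 * {{{
 *   PE(pos, 2i)     = sin(pos / 10000^(2i / d))
 *   PE(pos, 2i + 1) = cos(pos / 10000^(2i / d))
 * }}}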
 */
private[nn] class PositionEncode[T: ClassTag](implicit ev: TensorNumeric[T])
  extends TensorModule[T] {
  @transient private var rangeBuffer : Tensor[T] = null

  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    val length = input.size(2)
    val channel = input.size(3)

    if (!output.isEmpty && output.nElement() == length * channel) return output

    if (rangeBuffer == null) rangeBuffer = Tensor[T]()
    TransformerOperation.initRangeTensor(length, rangeBuffer)

    output.resize(length, channel)
    TransformerOperation.getPositionEncode(length, channel,
      rangeBuffer = rangeBuffer, outBuffer = output)
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (!gradInput.isEmpty && gradInput.nElement() == input.nElement()) return gradInput
    gradInput.resizeAs(input).zero()
    gradInput
  }
}

// Returns the position encoding added to the input shifted right by one position
private[nn] class PositionEncodeWithShift[T: ClassTag](implicit ev: TensorNumeric[T])
  extends TensorModule[T] {

  @transient private var rangeBuffer : Tensor[T] = null
  @transient private var timeBuffer : Tensor[T] = null

  // input a Tensor with shape [batch, length, channels]
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    TransformerOperation.shiftRight3D(input, output)
    val length = output.size(2)
    val channel = output.size(3)

    if (rangeBuffer == null) rangeBuffer = Tensor[T]()
    if (timeBuffer == null) timeBuffer = Tensor[T]()

    if (timeBuffer.nElement() != length * channel) {
      TransformerOperation.initRangeTensor(length, rangeBuffer)
      timeBuffer = Tensor[T]().resize(length, channel)
      TransformerOperation.getPositionEncode(length, channel,
        rangeBuffer = rangeBuffer, outBuffer = timeBuffer)
    }

    val batchSize = input.size(1)
    var i = 1
    while (i <= batchSize) {
      output.select(1, i).add(timeBuffer)
      i += 1
    }
    return output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (gradInput == null) gradInput = Tensor[T]()
    gradInput.resizeAs(gradOutput).zero()
    val size = gradOutput.size(2)
    var i = 1
    while (i < size) {
      gradInput.select(2, i).copy(gradOutput.select(2, i + 1))
      i += 1
    }
    gradInput
  }
}

/**
 * Calculates a bias tensor from the padding values in the input tensor.
 * The bias is added to the pre-softmax multi-headed attention logits, which have
 * shape [batch_size, num_heads, length, length]. The bias is zero at non-padding
 * locations and -1e9 (effectively negative infinity) at padding locations.
 * @tparam T The numeric type of the parameters in this module
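 *
 * A value-level sketch, shapes elided (assuming, as the surrounding code suggests, that
 * id 0 marks a padding position):
 * {{{
 *   // input ids : 3    5     0     0
 *   // bias      : 0    0   -1e9  -1e9
 * }}}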
 */
private[nn] class PaddingMask[T: ClassTag](implicit ev: TensorNumeric[T])
  extends TensorModule[T] {
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    output.resizeAs(input).copy(input)
    output = TransformerOperation.getPaddingBias(output)
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    gradInput.resizeAs(input).zero()
    gradInput
  }
}

// This mask is to hide both <pad> and future words. Used in the decoder.
private[nn] class SelfAttentionMask[T: ClassTag](implicit ev: TensorNumeric[T])
  extends TensorModule[T] {

  private val maskValue = -1e9

  /**
   * Creates a bias tensor to be added to the attention logits.
   * Returns a tensor with shape (1, 1, length, length).
   * @param length sequence length of the decoder input
   * @tparam T the numeric type of the tensor elements
   * @return the lower-triangle attention bias tensor
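   *
   * For example, length = 3 produces the following (length, length) slice,
   * where 0 keeps a position and -1e9 masks a future position:
   * {{{
   *   0  -1e9  -1e9
   *   0     0  -1e9
   *   0     0     0
   * }}}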
   */
  private def attentionBiasLowerTriangle[T: ClassTag](
    length: Int, output: Tensor[T])(implicit ev: TensorNumeric[T]): Tensor[T] = {
    val arr = output.storage().array()
    for (i <- 0 until length) {
      var j = length - 1
      while (j > i) {
        // masked (future) positions get maskValue (-1e9); kept positions stay 0
        arr(i * length + j) = ev.fromType(maskValue)
        j -= 1
      }
    }
    output.resize(Array(1, 1, length, length))
  }

  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    if (!output.isEmpty && output.nElement() == input.nElement()) return output
    output.resize(input.size(2), input.size(2)).zero()
    attentionBiasLowerTriangle[T](input.size(2), output)
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (!gradInput.isEmpty && gradInput.nElement() == input.nElement()) return gradInput
    gradInput.resizeAs(input).zero()
    gradInput
  }
}

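/**
 * Splits the input tensor into `num` equal chunks along `dimension` (1-based) and returns
 * them as a Table; the backward pass re-joins the gradients with a JoinTable.
 * A minimal usage sketch (shapes chosen only for illustration):
 * {{{
 *   val split = new SplitTensor[Float](2, 2)
 *   val chunks = split.forward(Tensor[Float](4, 6, 8).rand()) // two [4, 3, 8] tensors
 * }}}
 */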
private[nn] class SplitTensor[T: ClassTag](dimension: Int, num: Int)
  (implicit ev: TensorNumeric[T]) extends AbstractModule[Tensor[T], Table, T] {

  private val innerLayer = new JoinTable[T](dimension, -1)

  override def updateOutput(input: Tensor[T]): Table = {
    output = T.array(input.split(input.size(dimension) / num, dimension))
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Table): Tensor[T] = {
    gradInput = innerLayer.forward(gradOutput).toTensor[T]
    gradInput
  }
}



