com.intel.analytics.bigdl.nn.SequenceBeamSearch.scala

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.utils.{T, Table}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Beam search to find the translated sequence with the highest probability.
 * @param vocabSize size of the vocabulary (number of tokens)
 * @param beamSize number of beams
 * @param alpha defining the strength of length normalization
 * @param maxDecodeLength maximum length of the decoded sequence
 * @param eosID id of eos token, used to determine when a sequence has finished
 * @param paddingValue id of padding token, used at the beginning and the end
 * @param numHiddenLayers number of hidden layers
 * @param hiddenSize size of hidden layer
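 *
 * A minimal usage sketch (the parameter values and the input / logit-function names below are
 * illustrative only, not prescribed by this class):
 * {{{
 *   val beamSearch = SequenceBeamSearch[Float](vocabSize = 32000, beamSize = 4,
 *     alpha = 0.6f, maxDecodeLength = 50, eosID = 1.0f, paddingValue = 0.0f,
 *     numHiddenLayers = 6, hiddenSize = 512)
 *   beamSearch.setLogitFn(myLogitsFn) // user-supplied decoding step, see setLogitFn
 *   // forward returns a Table of (finished sequences, finished scores)
 *   val result = beamSearch.forward(T(encoderOutputs, encoderDecoderAttentionBias))
 * }}}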
 */
class SequenceBeamSearch[T: ClassTag](
  val vocabSize: Int,
  val beamSize: Int,
  val alpha: Float,
  val maxDecodeLength: Int,
  val eosID: Float,
  val paddingValue: Float,
  val numHiddenLayers: Int,
  val hiddenSize: Int)(implicit ev: TensorNumeric[T])
  extends AbstractModule[Table, Activity, T] {

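  // Large negative value used to mask out scores; acts as negative infinity.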
  private val inf = 1e7f * (-1)
  private var batchSize = 0
  private val newFinishedFlags = Tensor[T]
  private var aliveLogProbs = Tensor[T]
  private var finishedSeq = Tensor[T]
  private var aliveSeq = Tensor[T]
  private var finishedFlags = Tensor[Boolean]
  private val finishedFlagsSeq = Tensor[T]
  private var finishedScores = Tensor[T]
  private val gatherTensor = Tensor[T]
  private val topkSeq = Tensor[T]
  private val topkLogProbs = Tensor[T]
  private val topkScore = Tensor[T]
  private val topkFlags = Tensor[T]
  private var symbolToLogits: (Tensor[T], Int, Int, Tensor[T], Tensor[T], Table)
    => (Tensor[T], Table) = null
  private val topkEncoder = Tensor[T]
  private val topkAttentionBias = Tensor[T]
  private var topkLayer = T()
  for (i <- 1 to numHiddenLayers) {
    val tensor1 = Tensor[T]
    val tensor2 = Tensor[T]
    topkLayer("layer_" ++ i.toString ++ "_k") = tensor1
    topkLayer("layer_" ++ i.toString ++ "_v") = tensor2
  }

  private def expandDim(tensor: Tensor[T], axis: Int): Tensor[T] = {
    val shape = tensor.size()
    val newShape = shape.toBuffer
    newShape.insert(axis, 1)
    tensor.reshape(newShape.toArray)
  }

  // Tiles a given tensor by beam_size.
  private def extendBeamSize(t: Tensor[T], beamSize: Int): Tensor[T] = {
    val tensor = expandDim(t, 1)
    val tileDim = Array.fill(tensor.dim())(1)
    tileDim(1) = beamSize
    tensor.repeatTensor(tileDim)
  }

  // Length penalty term from GNMT: ((5 + length) / 6) ^ alpha.
  private def lengthNormalization(alpha: Float, length: Int): T = {
    ev.pow(ev.fromType[Double]((5.0 + length) / 6.0), ev.fromType[Float](alpha))
  }

  private def boolToFloat(b: Boolean): T = {
    if (b) ev.one
    else ev.zero
  }

  private def floatToBool(f: T): Boolean = {
    if (f == 1.0) true
    else false
  }

  // Computes the "logical or" of elements across dimensions of a tensor.
  private def reduceAny(tensor: Tensor[Boolean]): Tensor[T] = {
    val tSize = tensor.size()
    val outputs = Tensor[T](tSize(0))
    for (i <- 1 to tSize(0)) {
      var valueAny = false
      for (j <- 1 to tSize(1)) {
        valueAny = valueAny || tensor.valueAt(i, j)
      }
      outputs.setValue(i, boolToFloat(valueAny))
    }
    outputs
  }

  // Computes the "logical and" of elements across dimensions of a tensor.
  private def reduceAll(tensor1: Tensor[T], tensor2: Tensor[T]): Boolean = {
    val sizeT = tensor1.size()
    var outputs = true
    for (i <- 1 to sizeT(0)) {
      outputs &&= ev.isGreater(tensor1.valueAt(i), tensor2.valueAt(i))
    }
    outputs
  }

  /**
   * Return whether to continue the search loop.
   * The loop should terminate when
   * 1) the decode length has been reached, or
   * 2) the worst score in the finished sequences is better than the best
   * score in the alive sequences (i.e. the finished sequences are provably
   * unchanging)
   *
   * @param state A map with the current loop state.
   * @return True if the loop should continue, False if it should terminate.
   */
  private def continueSearch(state: Map[String, Any]): Boolean = {
    val i = state("CUR_INDEX").asInstanceOf[Int]
    finishedFlags = state("FINISHED_FLAGS").asInstanceOf[Tensor[Boolean]]
    aliveLogProbs.copy(state("ALIVE_LOG_PROBS").asInstanceOf[Tensor[T]])
    finishedScores.resizeAs(state("FINISHED_SCORES").asInstanceOf[Tensor[T]])
      .copy(state("FINISHED_SCORES").asInstanceOf[Tensor[T]])
    val notAtMaxDecodeLength = i < maxDecodeLength
    val maxLengthNorm = lengthNormalization(alpha, maxDecodeLength)
    // Get the best possible scores from alive sequences.
    val bestAliveScores = aliveLogProbs.select(2, 1) / maxLengthNorm
    newFinishedFlags.applyFun[Boolean](finishedFlags, x => boolToFloat(x))
    finishedScores.cmul(newFinishedFlags)
    // Compute worst score in finished sequences for each batch element
    var lowestFinishedScores = finishedScores.min(2)._1
    lowestFinishedScores += (reduceAny(finishedFlags) * ev.fromType[Double](-1.0)
      + ev.fromType[Double](1.0)) * ev.fromType[Double](inf)
    lowestFinishedScores = lowestFinishedScores.reshape(Array(lowestFinishedScores.size()(0)))
    val worstFinishedScoreBetterThanBestAliveScore =
      reduceAll(lowestFinishedScores, bestAliveScores)
    notAtMaxDecodeLength && (!worstFinishedScoreBetterThanBestAliveScore)
  }

  // Reshapes the first two dimensions into a single dimension.
  private def flattenBeamDim(tensor: Tensor[T]): Tensor[T] = {
    val shape = tensor.size()
    val newShape = shape.toBuffer
    newShape(0) = shape(0) * shape(1)
    newShape.remove(1)
    tensor.reshape(newShape.toArray)
  }

  // Reshapes first dimension back to [batch_size, beam_size].
  private def unFlattenBeamDim(tensor: Tensor[T], batchSize: Int, beamSize: Int): Tensor[T] = {
    val shape = tensor.size()
    val newShape = shape.toBuffer
    newShape(0) = batchSize
    newShape.insert(1, beamSize)
    tensor.reshape(newShape.toArray)
  }

  // logits - log(sum(exp(logits)))
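  // i.e. a log-softmax over the last (vocabulary) dimension.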
  private def logProbFromLogits(logits: Tensor[T]): Tensor[T] = {
    val shape = logits.size()
    val getExp = Tensor[T](shape)
    getExp.applyFun[T](logits, x => ev.exp(x))
    val getSumExp = getExp.sum(3)
    val getLogSumExp = Tensor[T](getSumExp.size())
    getLogSumExp.applyFun[T](getSumExp, x => ev.log(x))
    logits - getLogSumExp.repeatTensor(Array(1, 1, shape(2)))
  }

  // Gather slices from tensor into outputs with shape specified by indices.
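  // Here `indices` holds (batch, beam) coordinate pairs with shape [batchSize, beamSize, 2];
  // the 0-based coordinates are converted to 1-based positions before indexing.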
  private def gatherNd(tensor: Tensor[T], indices: Tensor[T], outputs: Tensor[T]): Tensor[T] = {
    val shape1 = tensor.size()
    val shape2 = indices.size()
    var slices = new Array[T](0)
    if (shape1.length == 2) {
      outputs.resize(shape2(0), shape2(1))
      slices = new Array[T](shape2(0) * shape2(1))
      for (i <- 1 to shape2(0)) {
        for (j <- 1 to shape2(1)) {
          slices((i - 1) * shape2(1) + j - 1) = tensor.valueAt(ev.toType[Int](ev.plus
          (indices.valueAt(i, j, 1), ev.fromType[Float](1.0f))), ev.toType[Int]
            (ev.plus(indices.valueAt(i, j, 2), ev.fromType[Float](1.0f))))
        }
      }
    } else if (shape1.length == 3) {
      outputs.resize(shape2(0), shape2(1), shape1(2))
      for (i <- 1 to shape2(0)) {
        for (j <- 1 to shape2(1)) {
          slices ++= tensor
            .select(2, ev.toType[Int](ev.plus(indices.valueAt(i, j, 2), ev.fromType[Float](1.0f))))
            .select(1, ev.toType[Int](ev.plus(indices.valueAt(i, j, 1), ev.fromType[Float](1.0f))))
            .toArray()
        }
      }
    } else if (shape1.length == 4) {
      outputs.resize(shape2(0), shape2(1), shape1(2), shape1(3))
      for (i <- 1 to shape2(0)) {
        for (j <- 1 to shape2(1)) {
          slices ++= tensor
            .select(2, ev.toType[Int](ev.plus(indices.valueAt(i, j, 2), ev.fromType[Float](1.0f))))
            .select(1, ev.toType[Int](ev.plus(indices.valueAt(i, j, 1), ev.fromType[Float](1.0f))))
            .reshape(Array(shape1(2) * shape1(3))).toArray()
        }
      }
    } else if (shape1.length == 5) {
      outputs.resize(shape2(0), shape2(1), shape1(2), shape1(3), shape1(4))
      for (i <- 1 to shape2(0)) {
        for (j <- 1 to shape2(1)) {
          slices ++= tensor
            .select(2, ev.toType[Int](ev.plus(indices.valueAt(i, j, 2), ev.fromType[Float](1.0f))))
            .select(1, ev.toType[Int](ev.plus(indices.valueAt(i, j, 1), ev.fromType[Float](1.0f))))
            .reshape(Array(shape1(2) * shape1(3) * shape1(4))).toArray()
        }
      }
    }
    val outputData = outputs.storage().array()
    val outputOffset = outputs.storageOffset() - 1
    for(i <- slices.indices) {
      outputData(outputOffset + i) = slices(i)
    }
    shape1(0) = shape2(0)
    shape1(1) = shape2(1)
    outputs
  }

  // Concatenates tensor1 and tensor2 along one dimension.
  private def concat(tensor1: Tensor[T], tensor2: Tensor[T], dim: Int): Tensor[T] = {
    val shape1 = tensor1.size()
    val shape2 = tensor2.size()
    val array1 = tensor1.reshape(Array(shape1.product)).toArray()
    val array2 = tensor2.reshape(Array(shape2.product)).toArray()
    var outputsArray = new Array[T](0)
    var concatLength1 = 1
    var concatLength2 = 1
    for (i <- dim - 1 until shape1.length) {
      concatLength1 *= shape1(i)
    }
    for (i <- dim - 1 until shape2.length) {
      concatLength2 *= shape2(i)
    }
    val group1 = array1.grouped(concatLength1)
    val group2 = array2.grouped(concatLength2)
    while (group1.hasNext) {

      outputsArray ++= group1.next()
      outputsArray ++= group2.next()
    }
    val newShape = shape1
    newShape(dim - 1) = shape1(dim - 1) + shape2(dim - 1)
    Tensor(outputsArray, newShape)
  }

  // Gather beams from tensors.
  private def gatherBeams(nested: Tensor[T], beamIndices: Tensor[T],
    batchSize: Int, newBeamSize: Int): Tensor[T] = {
    val batchPos = (Tensor.range(0, batchSize * newBeamSize - 1, 1) / ev.fromType[Int](newBeamSize))
      .reshape(Array(batchSize, newBeamSize))
    val newBatchPos = batchPos.apply1(e => ev.floor(e))
    val coordinates = Tensor[T](batchSize, newBeamSize, 2)
    for (i <- 1 to batchSize) {
      for (j <- 1 to newBeamSize) {
        coordinates.setValue(i, j, 1, newBatchPos.valueAt(i, j))
        coordinates.setValue(i, j, 2, beamIndices.valueAt(i, j))
      }
    }
    gatherNd(nested, coordinates.asInstanceOf[Tensor[T]], gatherTensor)
  }

  // Gather top beams from nested structure.
  private def gatherTopkBeams(tensor: Tensor[T], scoreOrLogProb: Tensor[T],
    batchSize: Int, beamSize: Int): Tensor[T] = {
    val (_, topkIndexes) = scoreOrLogProb.topk(beamSize, -1, false)
    topkIndexes.apply1(e => ev.minus(e, ev.fromType[Float](1.0f)))
    gatherBeams(tensor, topkIndexes, batchSize, beamSize)
  }

  /**
   * symbolToLogits is a function applied to compute the decoding probabilities over the
   * whole vocabulary.
   * It takes 6 inputs: the current decoding sequence, the current decoding length,
   *                    the max decoding length, the encoder outputs, the encoder-decoder
   *                    attention bias, and the decoder attention values for each layer.
   * It returns: the probabilities and the decoder attention values for each layer.
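   *
   * A minimal sketch of such a function (`myDecoderStep` is a hypothetical user-supplied
   * decoder step, shown only to illustrate the expected signature):
   * {{{
   *   val logitsFn = (ids: Tensor[Float], i: Int, maxLength: Int, encoder: Tensor[Float],
   *     attentionBias: Tensor[Float], layerCache: Table) =>
   *     myDecoderStep(ids, i, maxLength, encoder, attentionBias, layerCache)
   *   beamSearch.setLogitFn(logitsFn)
   * }}}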
   */
  def setLogitFn(fn: (Tensor[T], Int, Int, Tensor[T], Tensor[T], Table) => (Tensor[T], Table)):
    SequenceBeamSearch[T] = {
    symbolToLogits = fn
    this
  }

  /**
   * Grow alive sequences by one token, and collect top 2*beam_size sequences.
   * 2*beam_size sequences are collected because some sequences may have reached
   * the EOS token. 2*beam_size ensures that at least beam_size sequences are
   * still alive.
   * @param state A map with the current loop state.
   * @return newSeq Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1]
   *         topkLogProbs probabilities of returned sequences [batch_size, 2 * beam_size]
   */
  private def growAliveSeq(state: Map[String, Any]): (Tensor[T], Tensor[T]) = {
    val i = state("CUR_INDEX").asInstanceOf[Int]
    aliveSeq = state("ALIVE_SEQ").asInstanceOf[Tensor[T]]
    aliveLogProbs = state("ALIVE_LOG_PROBS").asInstanceOf[Tensor[T]]
    val aliveEncoder = state("ENCODER").asInstanceOf[Tensor[T]]
    val aliveAttentionsBias = state("ATTENTION_BIAS").asInstanceOf[Tensor[T]]
    val aliveLayer = state("LAYER").asInstanceOf[Table]
    val beamsToKeep = 2 * beamSize
    // flatten the following variables so the first dimension becomes (batchSize * beamSize)
    val flatIds = flattenBeamDim(aliveSeq)
    val flatEncoder = flattenBeamDim(aliveEncoder)
    val flatAttentionBias = flattenBeamDim(aliveAttentionsBias)
    if (i > 0) {
      for (j <- 1 to numHiddenLayers) {
        val tensor1 = aliveLayer("layer_" ++ j.toString ++ "_k").asInstanceOf[Tensor[T]]
        val tensor2 = aliveLayer("layer_" ++ j.toString ++ "_v").asInstanceOf[Tensor[T]]
        aliveLayer("layer_" ++ j.toString ++ "_k") = flattenBeamDim(tensor1)
        aliveLayer("layer_" ++ j.toString ++ "_v") = flattenBeamDim(tensor2)
      }
    }
    // get logits for the next candidate IDs for the alive sequences.
    val (flatLogits, newLayer) = symbolToLogits(flatIds, i, maxDecodeLength, flatEncoder,
      flatAttentionBias, aliveLayer)
    // unflatten the following variables so the first dimensions become (batchSize, beamSize)
    val newFlatEncoder = unFlattenBeamDim(flatEncoder, batchSize, beamSize)
    val newAttentionBias = unFlattenBeamDim(flatAttentionBias, batchSize, beamSize)
    for (j <- 1 to numHiddenLayers) {
      val tensor1 = newLayer("layer_" ++ j.toString ++ "_k").asInstanceOf[Tensor[T]]
      val tensor2 = newLayer("layer_" ++ j.toString ++ "_v").asInstanceOf[Tensor[T]]
      newLayer("layer_" ++ j.toString ++ "_k") = unFlattenBeamDim(tensor1, batchSize, beamSize)
      newLayer("layer_" ++ j.toString ++ "_v") = unFlattenBeamDim(tensor2, batchSize, beamSize)
    }
    val logits = unFlattenBeamDim(flatLogits, batchSize, beamSize)
    val candidateLogProbs = logProbFromLogits(logits)
    // add new logProbs value to current alive sequence logProbs
    val logProbs = candidateLogProbs + expandDim(aliveLogProbs, 2)
      .repeatTensor(Array(1, 1, vocabSize))
    val flatLogProbs = logProbs.reshape(Array(logProbs.size().product
      / (beamSize * vocabSize), beamSize * vocabSize))
    // for each batch item, get the k candidates with the highest log probabilities.
    val (topkLogProbs, topkIndices) = flatLogProbs.topk(beamsToKeep, -1, false)
    topkIndices.apply1(e => ev.minus(e, ev.fromType[Float](1.0f)))
    val topkBeamIndices = (topkIndices / ev.fromType[Int](vocabSize)).apply1(e => ev.floor(e))
    // Extract the alive sequences that generate the highest log probabilities
    var gatherTmp = gatherBeams(aliveSeq, topkBeamIndices, batchSize, beamsToKeep)
    topkSeq.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherBeams(newFlatEncoder, topkBeamIndices, batchSize, beamsToKeep)
    topkEncoder.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherBeams(newAttentionBias, topkBeamIndices, batchSize, beamsToKeep)
    topkAttentionBias.resizeAs(gatherTmp).copy(gatherTmp)
    for (j <- 1 to numHiddenLayers) {
      val tensor1 = newLayer("layer_" ++ j.toString ++ "_k").asInstanceOf[Tensor[T]]
      val tensor2 = newLayer("layer_" ++ j.toString ++ "_v").asInstanceOf[Tensor[T]]
      gatherTmp = gatherBeams(tensor1, topkBeamIndices, batchSize, beamsToKeep)
      topkLayer("layer_" ++ j.toString ++ "_k").asInstanceOf[Tensor[T]]
        .resizeAs(gatherTmp).copy(gatherTmp)
      gatherTmp = gatherBeams(tensor2, topkBeamIndices, batchSize, beamsToKeep)
      topkLayer("layer_" ++ j.toString ++ "_v").asInstanceOf[Tensor[T]]
        .resizeAs(gatherTmp).copy(gatherTmp)
    }
    var topkIds = topkIndices.apply1(e => ev.fromType[Int](ev.toType[Int](e) % vocabSize + 1))
    topkIds = expandDim(topkIds, 2)
    val newSeq = concat(topkSeq, topkIds, 3)
    (newSeq, topkLogProbs)
  }

  /**
   * Gather the top k sequences that are still alive.
   * @param newSeq New sequences generated by growing the current alive sequences
   * @param newLogProbs Log probabilities of new sequences
   * @return map with alive keys
   */
  private def growNewAliveState(newSeq: Tensor[T], newLogProbs: Tensor[T]): Map[String, Any] = {
    finishedFlagsSeq.copy(newSeq.select(3, newSeq.size()(2)))
    finishedFlagsSeq.apply1(x => boolToFloat(ev.toType[Float](x) == eosID))
    val newLogProbs1 = newLogProbs + finishedFlagsSeq * ev.fromType[Double](inf)
    var gatherTmp = gatherTopkBeams(newSeq, newLogProbs1, batchSize, beamSize)
    aliveSeq.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherTopkBeams(newLogProbs1, newLogProbs1, batchSize, beamSize)
    topkLogProbs.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherTopkBeams(topkEncoder, newLogProbs1, batchSize, beamSize)
    topkEncoder.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherTopkBeams(topkAttentionBias, newLogProbs1, batchSize, beamSize)
    topkAttentionBias.resizeAs(gatherTmp).copy(gatherTmp)
    for (i <- 1 to numHiddenLayers) {
      val tensor1 = topkLayer("layer_" ++ i.toString ++ "_k").asInstanceOf[Tensor[T]]
      val tensor2 = topkLayer("layer_" ++ i.toString ++ "_v").asInstanceOf[Tensor[T]]
      gatherTmp = gatherTopkBeams(tensor1, newLogProbs1, batchSize, beamSize)
      tensor1.resizeAs(gatherTmp).copy(gatherTmp)
      gatherTmp = gatherTopkBeams(tensor2, newLogProbs1, batchSize, beamSize)
      tensor2.resizeAs(gatherTmp).copy(gatherTmp)
    }
    Map("ALIVE_SEQ" -> aliveSeq, "ALIVE_LOG_PROBS" -> topkLogProbs,
      "ENCODER" -> topkEncoder, "ATTENTION_BIAS" -> topkAttentionBias,
      "LAYER" -> topkLayer)
  }

  /**
   * Combine new and old finished sequences, and gather the top k sequences.
   * @param state A map with the current loop state.
   * @param newSeq New sequences generated by growing the current alive sequences
   * @param newLogProbs Log probabilities of new sequences
   * @return map with finished keys
   */
  private def getNewFinishedState(state: Map[String, Any], newSeq: Tensor[T],
    newLogProbs: Tensor[T]): Map[String, Any] = {
    val i = state("CUR_INDEX").asInstanceOf[Int]
    finishedSeq = state("FINISHED_SEQ").asInstanceOf[Tensor[T]]
    finishedScores = state("FINISHED_SCORES").asInstanceOf[Tensor[T]]
    finishedFlags = state("FINISHED_FLAGS").asInstanceOf[Tensor[Boolean]]
    // append a column of padding ids to finished_seq to increment the length.
    finishedSeq = concat(finishedSeq, Tensor[T](batchSize, beamSize, 1)
      .fill(ev.fromType[Float](paddingValue)), 3)
    val lengthNorm = lengthNormalization(alpha, i + 1)
    var newScores = newLogProbs / lengthNorm
    // Set the scores of the still-alive seq in new_seq to large negative values.
    newScores += (Tensor(finishedFlagsSeq.size()).fill(ev.fromType[Float](1.0f))
     - finishedFlagsSeq) * ev.fromType[Float](inf)
    // Combine sequences, scores, and flags.
    finishedSeq = concat(finishedSeq, newSeq, 2)
    finishedScores = concat(finishedScores, newScores, 2)
    var finishedFlags1 = Tensor[T](finishedFlags.size())
    finishedFlags1.applyFun[Boolean](finishedFlags, x => boolToFloat(x))
    finishedFlags1 = concat(finishedFlags1, finishedFlagsSeq, 2)
    var gatherTmp = gatherTopkBeams(finishedSeq, finishedScores, batchSize, beamSize)
    topkSeq.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherTopkBeams(finishedScores, finishedScores, batchSize, beamSize)
    topkScore.resizeAs(gatherTmp).copy(gatherTmp)
    gatherTmp = gatherTopkBeams(finishedFlags1, finishedScores, batchSize, beamSize)
    topkFlags.resizeAs(gatherTmp).copy(gatherTmp)
    val topFinishedFlags1 = topkFlags.reshape(Array(topkFlags.size().product))
      .toArray()
    val outputFlag = ArrayBuffer[Boolean]()
    for (ele <- topFinishedFlags1) {
      outputFlag.append(floatToBool(ele))
    }
    finishedFlags = Tensor(outputFlag.toArray, topkFlags.size())
    finishedSeq.resizeAs(topkSeq).copy(topkSeq)
    Map("FINISHED_SEQ" -> finishedSeq, "FINISHED_SCORES" -> topkScore,
      "FINISHED_FLAGS" -> finishedFlags)
  }

  /**
   * Grow alive sequences by a single ID. Sequences that have reached the EOS
   * token are marked as finished. The alive and finished sequences with the
   * highest log probabilities and scores are returned.
   */
  private def searchStep(state: Map[String, Any]): Map[String, Any] = {
    val (newSeq, newLogProbs) = growAliveSeq(state)
    val aliveState = growNewAliveState(newSeq, newLogProbs)
    val finishedState = getNewFinishedState(state, newSeq, newLogProbs)
    val newState: Map[String, Any] = Map("CUR_INDEX" -> (state("CUR_INDEX")
      .asInstanceOf[Int] + 1)) ++ aliveState ++ finishedState
    newState
  }

  /**
   * return initial state map
   *
   * @param encoderOutputs Sequences after encoding
   * @param encoderDecoderAttentionBias encoder decoder attention bias
   * @return map with states
   *         CUR_INDEX: Variable storing the loop index.
   *         ALIVE_SEQ: Top sequences that are alive for each batch item. Alive sequences are ones
   *         that have not generated an EOS token.
   *         ALIVE_LOG_PROBS: Log probabilities of each alive sequence.
   *         ENCODER: Sequences after encoding
   *         ATTENTION_BIAS: encoder decoder attention bias
   *         LAYER: decoder attention values for each layer.
   *         FINISHED_SEQ: Top finished sequences for each batch item.
   *         FINISHED_SCORES: Scores for each finished sequence. Score=log probability/length norm
   *         FINISHED_FLAGS: Flags indicating which sequences in the finished sequences
   *         are finished.
   */
  private def createInitialState(encoderOutputs: Tensor[T], encoderDecoderAttentionBias: Tensor[T]):
    Map[String, Any] = {
    batchSize = encoderOutputs.size()(0)
    newFinishedFlags.resize(batchSize, beamSize)
    aliveLogProbs.resize(batchSize, beamSize)
    finishedFlags.resize(batchSize, beamSize)
    finishedFlagsSeq.resize(batchSize, beamSize * 2)
    finishedScores.resize(batchSize, beamSize)
    val curIndex = 0
    val initialID = Tensor[T](Array(batchSize)).fill(ev.fromType[Float](paddingValue))
    var initialAliveSeq = extendBeamSize(initialID, beamSize)
    initialAliveSeq = expandDim(initialAliveSeq, 2)
    var initialLogProbs = Tensor[T](beamSize).apply1(e => ev.fromType[Float](inf))
    initialLogProbs.setValue(1, ev.fromType[Float](0.0f))
    initialLogProbs = initialLogProbs.repeatTensor(Array(batchSize, 1))
    val aliveEncoder = extendBeamSize(encoderOutputs, beamSize)
    val aliveAttentionsBias = extendBeamSize(encoderDecoderAttentionBias, beamSize)
    // Create aliveLayer storing decoder attention values for each layer.
    val aliveLayer = T()
    for (i <- 1 to numHiddenLayers) {
      val tensor1 = Tensor[T]()
      val tensor2 = Tensor[T]()
      aliveLayer("layer_" ++ i.toString ++ "_k") = tensor1
      aliveLayer("layer_" ++ i.toString ++ "_v") = tensor2
    }
    val initialFinishedSeq = Tensor[T](initialAliveSeq.size())
    val initialFinishedScores = Tensor.ones[T](batchSize, beamSize) * ev.fromType[Float](inf)
    val initialFinishedFlags = Tensor[Boolean](batchSize, beamSize)
    val state = Map("CUR_INDEX" -> curIndex,
      "ALIVE_SEQ" -> initialAliveSeq,
      "ALIVE_LOG_PROBS" -> initialLogProbs,
      "ENCODER" -> aliveEncoder,
      "ATTENTION_BIAS" -> aliveAttentionsBias,
      "LAYER" -> aliveLayer,
      "FINISHED_SEQ" -> initialFinishedSeq,
      "FINISHED_SCORES" -> initialFinishedScores,
      "FINISHED_FLAGS" -> initialFinishedFlags)
    state
  }

  // Where `tensor` holds 0, replaces the corresponding slice of `a` with the slice from `b`.
  private def where(tensor: Tensor[T], a: Tensor[T], b: Tensor[T]): Tensor[T] = {
    val arrayBool = tensor.toArray()
    val shape = a.size()
    for (i <- arrayBool.indices) {
      if (arrayBool(i) == 0) {
        if (shape.length == 3) {
          for (j <- 1 to shape(1)) {
            for (k <- 1 to shape(2)) {
              a.setValue(i + 1, j, k, b.valueAt(i + 1, j, k))
            }
          }
        } else {
          for (j <- 1 to shape(1)) {
            a.setValue(i + 1, j, b.valueAt(i + 1, j))
          }
        }
      }
    }
    a
  }

  override def updateOutput(input: Table): Activity = {
    val encoderOutputs = input[Tensor[T]](1)
    val encoderDecoderAttentionBias = input[Tensor[T]](2)
    require(symbolToLogits != null, "symbolToLogits is not set, please set it with setLogitFn")
    var state = createInitialState(encoderOutputs, encoderDecoderAttentionBias)
    while (continueSearch(state)) {
      state = searchStep(state)
    }
    val finishedState = state
    val aliveSeq = finishedState("ALIVE_SEQ").asInstanceOf[Tensor[T]]
    val aliveLogProbs = finishedState("ALIVE_LOG_PROBS").asInstanceOf[Tensor[T]]
    var finishedSeq = finishedState("FINISHED_SEQ").asInstanceOf[Tensor[T]]
    var finishedScores = finishedState("FINISHED_SCORES").asInstanceOf[Tensor[T]]
    val finishedFlags = finishedState("FINISHED_FLAGS").asInstanceOf[Tensor[Boolean]]
    finishedSeq = where(reduceAny(finishedFlags), finishedSeq, aliveSeq)
    finishedScores = where(reduceAny(finishedFlags), finishedScores, aliveLogProbs)
    output = T(finishedSeq, finishedScores)
    output
  }

  override def updateGradInput(input: Table, gradOutput: Activity): Table = {
    gradInput = gradOutput.toTable
    gradInput
  }

  override def clearState(): this.type = {
    super.clearState()
    batchSize = 0
    newFinishedFlags.set()
    aliveLogProbs.set()
    finishedSeq.set()
    aliveSeq.set()
    finishedFlags.set()
    finishedFlagsSeq.set()
    finishedScores.set()
    gatherTensor.set()
    topkSeq.set()
    topkLogProbs.set()
    topkScore.set()
    topkFlags.set()
    topkEncoder.set()
    topkAttentionBias.set()
    for (i <- 1 to numHiddenLayers) {
      topkLayer("layer_" ++ i.toString ++ "_k").asInstanceOf[Tensor[T]].set()
      topkLayer("layer_" ++ i.toString ++ "_v").asInstanceOf[Tensor[T]].set()
    }
    this
  }
}

object SequenceBeamSearch {
  def apply[@specialized(Float, Double) T: ClassTag](
    vocabSize: Int,
    beamSize: Int,
    alpha: Float,
    maxDecodeLength: Int,
    eosID: Float,
    paddingValue: Float,
    numHiddenLayers: Int,
    hiddenSize: Int)
  (implicit ev: TensorNumeric[T]): SequenceBeamSearch[T] = {
    new SequenceBeamSearch[T](
      vocabSize,
      beamSize,
      alpha,
      maxDecodeLength,
      eosID,
      paddingValue,
      numHiddenLayers,
      hiddenSize)
  }
}