All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.LookupTable.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import breeze.numerics.{abs, pow}
import com.intel.analytics.bigdl.nn.abstractnn.{Initializable, TensorModule}
import com.intel.analytics.bigdl.optim.Regularizer
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.RandomGenerator._
import com.intel.analytics.bigdl.utils.{Shape, T, Table}

import scala.reflect.ClassTag

/**
 * This layer is a particular case of a convolution, where the width of the convolution would be 1.
 * Input should be a 1D or 2D tensor filled with indices. Indices are corresponding to the position
 * in weight. For each index element of input, it outputs the selected index part of weight.
 * Elements of input should be in range of [1, nIndex].
 * This layer is often used in word embedding.
 *
 * A 1D input of length n produces an output of size n x nOutput; a 2D input of size b x n
 * produces an output of size b x n x nOutput.
 *
 * @param nIndex number of rows of the lookup table, i.e. the largest valid index
 * @param nOutput the last dimension size of output (the length of each weight row)
 * @param paddingValue padding value, default 0. Input elements equal to this value contribute
 *                     nothing to gradWeight.
 * @param maxNorm max norm, default Double.MaxValue (renormalization disabled). When set, every
 *                weight row referenced by the input is rescaled in place before the forward
 *                pass so that its normType-norm does not exceed maxNorm.
 * @param normType the p of the p-norm used for maxNorm renormalization, default 2
 * @param shouldScaleGradByFreq if true, the gradient accumulated into a weight row is divided by
 *                              the number of times that row's index occurs in the input,
 *                              default false
 * @param wRegularizer instance of [[Regularizer]]
 *                     (eg. L1 or L2 regularization), applied to the input weights matrices.
 * @param maskZero if true and paddingValue is non-zero, the weight row at index paddingValue is
 *                 zeroed before every forward pass, so input elements equal to paddingValue are
 *                 mapped to zero vectors.
 * @tparam T The numeric type in the criterion, usually which are [[Float]] or [[Double]]
 */
@SerialVersionUID( - 4832171200145114633L)
class LookupTable[T: ClassTag]
(val nIndex: Int, val nOutput: Int, val paddingValue: Double = 0,
  val maxNorm: Double = Double.MaxValue,
  val normType: Double = 2.0,
  shouldScaleGradByFreq: Boolean = false,
  var wRegularizer: Regularizer[T] = null,
  val maskZero: Boolean = false
)
(implicit ev: TensorNumeric[T]) extends TensorModule[T] with Initializable {

  var weight = Tensor[T](nIndex, nOutput)
  var gradWeight = Tensor[T](nIndex, nOutput).zero()

  private var inputBuffer = Tensor[T]()
  private var normBuffer = Tensor[T]()
  private val countBuffer = Tensor[T]()

  {
    val wInit = RandomNormal(0, 1)
    setInitMethod(weightInitMethod = wInit)
  }

  override def reset(): Unit = {
    weightInitMethod.init(weight, VariableFormat.Default)
  }

  /**
   * Fails with an IllegalArgumentException unless 1 <= index <= numRows.
   *
   * @param index one element of the input
   * @param numRows number of rows in the table being indexed
   */
  private def checkIndex(index: T, numRows: Int): Unit = {
    require(ev.isGreater(ev.fromType(numRows + 1), index),
      s"LookupTable: elements of input should be less than or equal to $nIndex")
    require(ev.isGreaterEq(index, ev.one),
      "LookupTable: elements of input should be greater than or equal to 1")
  }

  /**
   * Rescales, in place, every weight row referenced by `input` whose normType-norm exceeds
   * maxNorm. Does nothing when maxNorm is Double.MaxValue.
   */
  private def renorm(input : Tensor[T]): Unit = {
    if (Double.MaxValue != maxNorm) {
      normBuffer.resize(input.size()).copy(input)
      if (normBuffer.dim() == 2) {
        normBuffer = normBuffer.view(normBuffer.nElement())
      }
      require(weight.isContiguous(), "LookupTable: weight must be contiguous")
      require(normBuffer.isContiguous(), "LookupTable: input must be contiguous")
      require(normBuffer.nDimension() == 1, "LookupTable: idx must be a vector")
      require(normType > 0, "LookupTable: non-positive-norm not supported")

      val rowIdx = normBuffer.storage().array()
      val rowOffset = normBuffer.storageOffset() - 1
      val numEle = normBuffer.nElement()
      val stride = weight.stride(1)

      val gw = weight.storage().array()
      val gwOffset = weight.storageOffset() - 1

      var i = 0
      while (i < numEle) {
        checkIndex(rowIdx(i + rowOffset), weight.size(1))
        i += 1
      }

      // Renormalize each referenced row exactly once. Only the numEle slots starting at
      // rowOffset belong to the current input: the buffer's storage may be longer and still
      // hold indices from an earlier, larger input, so it must not be scanned or sorted as a
      // whole. Rows are disjoint, so the processing order does not matter.
      val distinctRows = (0 until numEle).map(j => ev.toType[Int](rowIdx(j + rowOffset))).distinct
      distinctRows.foreach { row =>
        renormRow(gw, (row - 1) * stride + gwOffset, stride, maxNorm, normType)
      }
    }
  }

  /**
   * Scales the `stride` consecutive elements of `row_data` starting at `offset` so that their
   * normType-norm is at most maxNorm. Rows whose norm is already within the limit are untouched.
   */
  private def renormRow(row_data: Array[T], offset: Int, stride: Int,
                        maxNorm: Double, normType: Double): Unit = {
    var norm = 0.0
    var j = 0
    while (j < stride) {
      val v = row_data(j + offset)
      if (normType == 1) {
        norm += ev.toType[Double](ev.abs(v))
      } else if (normType == 2) {
        norm += ev.toType[Double](ev.times(v, v))
      } else {
        norm += math.pow(abs(ev.toType[Double](v)), normType)
      }
      j += 1
    }
    norm = pow(norm, 1.0 / normType)

    // Keep the norm of weight smaller than maxNorm; 1e-7 guards against division by zero.
    if (norm > maxNorm) {
      val scale = ev.fromType(maxNorm / (norm + 1e-7))
      j = 0
      while (j < stride) {
        row_data(j + offset) = ev.times(row_data(j + offset), scale)
        j += 1
      }
    }
  }

  /**
   * Fills `count` so that, for every index k present in `input`, count(k) is the number of
   * occurrences of k in `input`. Entries for indices not present in `input` are left untouched.
   */
  private def resetCount(count: Tensor[T], input: Tensor[T]): Unit = {
    var i = 1
    val numEle = input.nElement()

    while (i <= numEle) {
      val k = ev.toType[Int](input.valueAt(i))
      count.update(k, ev.zero)
      i += 1
    }

    i = 1
    while (i <= numEle) {
      val k = ev.toType[Int](input.valueAt(i))
      count.update(k, ev.plus(count.valueAt(k), ev.one))
      i += 1
    }
  }

  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    if (maskZero && paddingValue != 0) {
      weight.select(1, paddingValue.toInt).zero()
    }
    require(input.dim() == 1 || input.dim() == 2,
      s"LookupTable: ${ErrorInfo.constrainInputAsVectorOrBatch}, input dim [${input.dim()}]"  )
    renorm(input)
    inputBuffer = input.contiguous()
    try {
      if (inputBuffer.dim() == 1) {
        output.index(1, inputBuffer, weight)
      } else {
        // dim() == 2 (checked above): gather rows for the flattened batch, then restore
        // the (batch, sequence) layout.
        output.index(1, inputBuffer.view(inputBuffer.nElement()), weight)
        output = output.view(inputBuffer.size(1), inputBuffer.size(2), weight.size(2))
      }
    } catch {
      case e: IllegalArgumentException =>
        // Keep the original exception as the cause so its stack trace is not lost.
        throw new IllegalArgumentException(
          s"LookupTable updateOutput get exception:${e.getMessage}\n" +
          s"please ensure elements of your input will not exceed ${nIndex}", e)
    }

    output
  }

  /** Indices are not differentiable, so gradInput is an all-zero tensor shaped like input. */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (!gradInput.isSameSizeAs(input)) {
      gradInput.resizeAs(input).zero()
    }
    gradInput
  }

  /**
   * Adds scaleW * gradOutput(i) to row input(i) of gradWeight for every non-padding input
   * element i, optionally divided by that index's frequency, then applies wRegularizer.
   */
  override def accGradParameters(input: Tensor[T], gradOutput: Tensor[T]): Unit = {
    inputBuffer = input.contiguous()
    require(gradWeight.isContiguous(), "LookupTable: gradWeight must be contiguous")
    require(inputBuffer.dim() == 1 || inputBuffer.dim() == 2,
      s"LookupTable: input must be a vector or matrix, input dim ${inputBuffer.dim()}" )

    // Flatten a batch of index sequences into one vector. The view must be assigned back:
    // resetCount reads the indices through single-index valueAt calls.
    if (inputBuffer.dim() == 2) {
      inputBuffer = inputBuffer.view(inputBuffer.nElement())
    }
    val _gradOutput = gradOutput.contiguous()

    // Per-row occurrence counts; only needed when gradients are scaled by frequency.
    val countData: Array[T] = if (shouldScaleGradByFreq) {
      countBuffer.resize(gradWeight.size(1))
      resetCount(countBuffer, inputBuffer)
      countBuffer.storage().array()
    } else {
      null
    }

    val inputData = inputBuffer.storage().array()
    val inputOffset = inputBuffer.storageOffset() - 1
    val numEle = inputBuffer.nElement()

    var i = 0
    while (i < numEle) {
      checkIndex(inputData(i + inputOffset), gradWeight.size(1))
      i += 1
    }

    if (scaleW != 0) {
      val gw = gradWeight.storage().array()
      val gwOffset = gradWeight.storageOffset() - 1
      val go = _gradOutput.storage().array()
      val goOffset = _gradOutput.storageOffset() - 1
      val stride = gradWeight.stride(1)

      i = 0
      while (i < numEle) {
        val index = inputData(i + inputOffset)
        // Padding entries contribute nothing to the gradient.
        if (ev.toType[Double](index) != paddingValue) {
          val k = ev.toType[Int](index) - 1
          val scale =
            if (null != countData) scaleW / ev.toType[Double](countData(k)) else scaleW
          // BLAS-style axpy: gradWeight row k += scale * gradOutput row i
          ev.axpy(stride, ev.fromType(scale), go, i * stride + goOffset, 1,
            gw, k * stride + gwOffset, 1)
        }
        i += 1
      }

      if (null != wRegularizer) {
        wRegularizer.accRegularization(weight, gradWeight, scaleW)
      }
    }
  }

  override def toString(): String = {
    val s = s"${getPrintName}" +
      s"(nIndex=$nIndex,nOutput=$nOutput,paddingValue=$paddingValue,normType=$normType"
    if (maxNorm == Double.MaxValue) {
      s + ")"
    } else {
      s + s" ,maxNorm=$maxNorm)"
    }
  }

  override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
    (Array(this.weight), Array(this.gradWeight))
  }

  override def clearState() : this.type = {
    super.clearState()
    // inputBuffer may alias the caller's input tensor (Tensor.contiguous() presumably returns
    // the tensor itself when it is already contiguous), so drop the reference instead of
    // calling set() on it, which could clear the caller's data.
    inputBuffer = Tensor[T]()
    countBuffer.set()
    normBuffer.set()
    this
  }

  override def canEqual(other: Any): Boolean = other.isInstanceOf[LookupTable[T]]

  override def equals(other: Any): Boolean = other match {
    case that: LookupTable[T] =>
      super.equals(that) &&
        (that canEqual this) &&
        weight == that.weight &&
        gradWeight == that.gradWeight &&
        nIndex == that.nIndex &&
        nOutput == that.nOutput &&
        paddingValue == that.paddingValue &&
        maxNorm == that.maxNorm &&
        normType == that.normType
    case _ => false
  }

  override def hashCode(): Int = {
    def getHashCode(a: Any): Int = if (a == null) 0 else a.hashCode()
    val state = Seq(super.hashCode(), weight, gradWeight, nIndex, nOutput,
      paddingValue, maxNorm, normType)
    state.map(getHashCode).foldLeft(0)((a, b) => 31 * a + b)
  }

  /** (n) -> (n, nOutput); (b, n) -> (b, n, nOutput). */
  override def computeOutputShape(inputShape: Shape): Shape = {
    val _inputSize = inputShape.toSingle().toArray
    if (_inputSize.length == 2) {
      Shape(Array(_inputSize(0), _inputSize(1), nOutput))
    } else Shape(Array(_inputSize(0), nOutput))
  }
}

object LookupTable {

  /**
   * Factory for [[LookupTable]]; every argument is forwarded unchanged to the class constructor.
   *
   * @param nIndex number of rows of the lookup table
   * @param nOutput length of each weight row (last dimension of the output)
   * @param paddingValue padding value, default 0
   * @param maxNorm max norm used for row renormalization, default Double.MaxValue
   * @param normType the p of the p-norm used with maxNorm, default 2.0
   * @param shouldScaleGradByFreq scale row gradients by index frequency, default false
   * @param wRegularizer optional weight regularizer, default null (none)
   * @param maskZero zero the paddingValue row before each forward pass, default false
   */
  def apply[@specialized(Float, Double)T: ClassTag](
      nIndex: Int,
      nOutput: Int,
      paddingValue: Double = 0,
      maxNorm: Double = Double.MaxValue,
      normType: Double = 2.0,
      shouldScaleGradByFreq: Boolean = false,
      wRegularizer: Regularizer[T] = null,
      maskZero: Boolean = false
  )(implicit ev: TensorNumeric[T]): LookupTable[T] = {
    new LookupTable[T](
      nIndex = nIndex,
      nOutput = nOutput,
      paddingValue = paddingValue,
      maxNorm = maxNorm,
      normType = normType,
      shouldScaleGradByFreq = shouldScaleGradByFreq,
      wRegularizer = wRegularizer,
      maskZero = maskZero
    )
  }
}






© 2015 - 2025 Weber Informatics LLC | Privacy Policy