All downloads are FREE. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.InferReshape.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Shape

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Reshape the input tensor with automatic size inference support.
 * Positive numbers in the `size` argument are used to reshape the input to the
 * corresponding dimension size.
 * There are also two special values allowed in `size`:
 *    a. `0` means keep the corresponding dimension size of the input unchanged,
 *       i.e., if the 1st dimension size of the input is 2,
 *       the 1st dimension size of the output will be set to 2 as well.
 *    b. `-1` means infer this dimension size from the other dimensions.
 *       This dimension size is calculated by keeping the number of output elements
 *       consistent with the input.
 *       At most one `-1` may appear in `size`; if there is none, the explicit and
 *       kept sizes must already describe the input exactly.
 * Any other negative value in `size` is rejected at construction time.
 *
 * For example,
 *    Input tensor with size: (4, 5, 6, 7)
 *    -> InferReshape(Array(4, 0, 3, -1))
 *    Output tensor with size: (4, 5, 3, 14)
 * The 1st and 3rd dims are set to the given sizes, the 2nd dim is kept unchanged,
 * and the last dim is inferred as 14 (= 4*5*6*7 / (4*5*3)).
 *
 * When `batchMode` is true, the first input dimension is treated as the batch
 * dimension: it is copied to the first output dimension, `size` describes the
 * output dimensions after it, and the inferred `-1` dimension is additionally
 * divided by the batch size.
 *
 * @param size      the target tensor size (without the batch dimension in batch mode)
 * @param batchMode whether in batch mode
 * @tparam T Numeric type ([[Float]] and [[Double]] are allowed)
 */
class InferReshape[T: ClassTag](
  size: Array[Int], var batchMode: Boolean = false)(
  implicit ev: TensorNumeric[T]) extends TensorModule[T] {
  // Output sizes. Explicit entries are filled once by init(); `0` and `-1` entries
  // (and the batch slot) are (re)written on every forward / shape computation.
  private var inferedSizes: Array[Int] = _
  // Offset of `size` inside `inferedSizes`: 1 in batch mode (slot 0 is the batch), else 0.
  private var startIndex = 0
  // Index in `inferedSizes` of the `-1` entry, or -1 if `size` contains no `-1`.
  private var inferIndex = -1
  // Product of the explicit (positive) entries of `size`.
  private var subTotal = 1
  // Becomes false once a non-contiguous input forced a contiguous copy in updateOutput.
  private var inPlace = true

  init()

  /**
   * Validates `size` and precomputes everything that does not depend on the input:
   * the explicit output sizes, their product and the position of the `-1` entry.
   */
  private def init(): Unit = {
    require(size.forall(_ >= -1),
      s"size entries must be positive, 0 or -1, but got (${size.mkString(", ")})")
    var minusOneCount = 0
    inferedSizes = if (batchMode) new Array[Int](size.length + 1) else new Array[Int](size.length)
    if (batchMode) startIndex = 1
    var i = 0
    while (i < size.length) {
      if (size(i) == -1) {
        minusOneCount += 1
        inferIndex = i + startIndex
      }
      else if (size(i) != 0) { // use the exact value in given size
        inferedSizes(i + startIndex) = size(i)
        subTotal *= size(i)
      }
      i += 1
    }
    // A `-1` is optional (see the `inferIndex != -1` checks below), but more than one
    // would make the inferred sizes ambiguous.
    require(minusOneCount <= 1, "at most a single value of -1 may be specified")
  }

  /**
   * Fills in the `0` / `-1` / batch entries of the output size from `input` and
   * returns a view of the input (or of a contiguous copy of it) with that size.
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    var total = subTotal
    var i = 0
    while (i < size.length) {
      if (size(i) == 0) { // use the same dim value as input
        // NOTE(review): this reads input dim i (1-based index i + 1) even in batch mode,
        // where the output slot is i + 1 and slot 0 holds the batch size; confirm whether
        // input.size(i + startIndex + 1) was intended. computeOutputShape mirrors this.
        inferedSizes(i + startIndex) = input.size(i + 1)
        total *= input.size(i + 1)
      }
      i += 1
    }
    require(total <= input.nElement(), "inferred size " +
      "dim product must be <= total input #elements, " +
      s"dim product($total) input(${input.nElement()})")
    if (inferIndex != -1) {
      inferedSizes(inferIndex) = input.nElement() / total
      // `size` has no batch entry, so the batch dimension is divided out separately.
      if (batchMode) inferedSizes(inferIndex) = inferedSizes(inferIndex) / input.size(1)
    }

    if (batchMode) {
      inferedSizes(0) = input.size(1)
    }

    if (input.isContiguous()) {
      output = input.view(inferedSizes)
    } else {
      output = input.contiguous().view(inferedSizes)
      inPlace = false
    }
    output
  }

  /** Reshapes `gradOutput` back to the input's size (via a contiguous copy if needed). */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (gradOutput.isContiguous()) {
      gradInput = gradOutput.view(input.size())
    } else {
      gradInput = gradOutput.contiguous().view(input.size())
    }
    gradInput
  }

  /**
   * Two modules are equal when the parent considers them equal and they have the same
   * output-size array and batch mode.
   * NOTE(review): `inferedSizes` is mutated by updateOutput for `0`/`-1` entries, so
   * equality (and hashCode) can change after a forward pass.
   */
  override def equals(obj: Any): Boolean = obj match {
    case other: InferReshape[_] =>
      super.equals(other) && (
        (this eq other) || (
          // sameElements also handles arrays of different length (no out-of-bounds).
          inferedSizes.sameElements(other.inferedSizes) &&
            batchMode == other.batchMode))
    case _ => false
  }

  /** Combines the parent hash with every output-size entry and the batch mode. */
  override def hashCode(): Int = {
    val seed = 37
    val sizesHash = inferedSizes.foldLeft(super.hashCode())((h, s) => h * seed + s.hashCode())
    sizesHash * seed + batchMode.hashCode()
  }

  override def toString(): String = {
    s"${getPrintName}(${
      size.mkString("x")
    })"
  }

  /**
   * Only releases state once updateOutput had to make a contiguous copy; while
   * `inPlace` is true the output was produced by `view` on the input itself.
   */
  override def clearState(): this.type = {
    if (!inPlace) {
      super.clearState()
    }
    this
  }

  /**
   * Computes the output shape for `inputShape` using the same rules as updateOutput
   * (without its element-count `require`).
   */
  override def computeOutputShape(inputShape: Shape): Shape = {
    val inputSize = inputShape.toSingle().toArray
    val outputSize = new ArrayBuffer[Int]()
    inferedSizes.foreach(outputSize.append(_))

    var total = subTotal
    var i = 0
    while (i < size.length) {
      if (size(i) == 0) { // use the same dim value as input
        outputSize(i + startIndex) = inputSize(i)
        total *= inputSize(i)
      }
      i += 1
    }
    if (inferIndex != -1) {
      outputSize(inferIndex) = inputSize.product / total
      if (batchMode) outputSize(inferIndex) = outputSize(inferIndex) / inputSize(0)
    }
    if (batchMode) outputSize(0) = inputSize(0)
    Shape(outputSize.toArray)
  }
}

object InferReshape {
  /**
   * Factory for [[InferReshape]].
   *
   * @param size      target output size; `0` and `-1` entries have special meaning
   * @param batchMode whether the first input dimension is a batch dimension
   * @tparam T element numeric type
   * @return a freshly constructed InferReshape module
   */
  def apply[@specialized(Float, Double) T: ClassTag](size: Array[Int], batchMode: Boolean = false)
    (implicit ev: TensorNumeric[T]): InferReshape[T] = {
    val module = new InferReshape[T](size = size, batchMode = batchMode)
    module
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy