All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.Reshape.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.bigdl.utils.serializer.converters.DataConverter
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLModule}
import com.intel.analytics.bigdl.utils.Shape

import scala.reflect.ClassTag
import scala.reflect.runtime.universe

/**
 * The `forward(input)` reshape the input tensor into a
 * `size(0) * size(1) * ...` tensor, taking the elements row-wise.
 *
 * @param size the reshape size
 * @param batchMode It is a optional argument. If it is set to `Some(true)`,
 *                  the first dimension of input is considered as batch dimension,
 *                  and thus keep this dimension size fixed. This is necessary
 *                  when dealing with batch sizes of one. When set to `Some(false)`,
 *                  it forces the entire input (including the first dimension) to be reshaped
 *                  to the input size. Default is `None`, which means the module considers
 *                  inputs with more elements than the product of provided sizes (size(0) *
 *                  size(1) * ..) to be batches, otherwise in no batch mode.
 *
 */
@SerialVersionUID(- 830146931795053244L)
class Reshape[T: ClassTag](
  val size: Array[Int], var batchMode: Option[Boolean] = None)(
  implicit ev: TensorNumeric[T]) extends TensorModule[T]  {
  val batchSize = new Array[Int](size.length + 1)
  var nElement: Int = 1
  for (i <- 1 to size.length) {
    batchSize(i) = size(i - 1)
    nElement *= size(i - 1)
  }

  // whether share the storage between input and output
  // in this layer, if input is contiguous, inplace is true. otherwise, inplace is false
  private var inplace: Boolean = true

  override def updateOutput(input: Tensor[T]): Tensor[T] = {

    if ((batchMode.nonEmpty && !batchMode.get) ||
          (input.nElement() == nElement && batchMode.isEmpty && input.size(1) != 1)) {
      require(input.nElement() == nElement, s"element number must match Reshape size. " +
        s"But In ${this.getName()} : element number is: ${ input.nElement() } , " +
        s"reshape size is: ${nElement}")
      if (input.isContiguous()) output =
        input.view(size)
      else {
        output = input.contiguous().view(size)
        inplace = false
      }
    }
    else {
      require(input.nElement() == nElement * input.size(1),
        s"element number must match Reshape size. " +
          s"But In ${this.getName()} : element number is: ${ input.nElement() } , " +
          s"reshape size is: ${ nElement * input.size(1) }")
      batchSize(0) = input.size(1)
      if (input.isContiguous()) {
        output = input.view(batchSize)
      } else {
        output = input.contiguous().view(batchSize)
        inplace = false
      }
    }
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (gradOutput.isContiguous()) {
      gradInput = gradOutput.view(input.size())
    } else {
      gradInput = gradOutput.contiguous().view(input.size())
    }
    gradInput
  }

  override def equals(obj: Any): Boolean = {

    if (!super.equals(obj)) {
      return false
    }

    if (!obj.isInstanceOf[Reshape[T]]) {
      return false
    }
    val other = obj.asInstanceOf[Reshape[T]]
    if (this.eq(other)) {
      return true
    }

    var i = 0
    while (i < batchSize.length) {
      if (batchSize(i) != other.batchSize(i)) {
        return false
      }
      i += 1
    }
    nElement == other.nElement &&
      batchMode == other.batchMode
  }

  override def hashCode() : Int = {
    val seed = 37
    var hash = super.hashCode()
    var i = 0
    while (i < batchSize.length) {
      hash = hash * seed + batchSize(i).hashCode()
      i += 1
    }
    hash = hash * seed + nElement.hashCode()
    hash = hash * seed + batchMode.hashCode()

    hash
  }

  override def toString(): String = {
    s"${getPrintName}(${size.mkString("x")})"
  }

  override def clearState(): this.type = {
    if (!inplace) {
      super.clearState()
    }
    this
  }

  override def computeOutputShape(inputShape: Shape): Shape = {
    val input = inputShape.toSingle().toArray
    val output = if ((batchMode.nonEmpty && !batchMode.get) ||
        (input.product == nElement && batchMode.isEmpty && input(0) != 1)) {
      size
    } else {
      Array(input(0)) ++ batchSize.slice(1, batchSize.length)
    }
    Shape(output)
  }
}

object Reshape extends ModuleSerializable {
  def apply[T: ClassTag](
      size: Array[Int],
      batchMode: Option[Boolean] = None)(implicit ev: TensorNumeric[T]) : Reshape[T] = {
    new Reshape[T](size, batchMode)
  }

  override def doLoadModule[T: ClassTag](context: DeserializeContext)
    (implicit ev: TensorNumeric[T]) : AbstractModule[Activity, Activity, T] = {

    val attrMap = context.bigdlModule.getAttrMap
    val size = DataConverter.getAttributeValue(context, attrMap.get("size")).
      asInstanceOf[Array[Int]]
    val batchModeV = DataConverter.getAttributeValue(context, attrMap.get("batchMode")).
      asInstanceOf[Int]
    var batchMode : Option[Boolean] = None
    if (batchModeV == 1) {
      batchMode = Some(false)
    } else if (batchModeV == 2) {
      batchMode = Some(true)
    }
    Reshape(size, batchMode).asInstanceOf[AbstractModule[Activity, Activity, T]]
  }

  override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
                                           reshapeBuilder : BigDLModule.Builder)
                                           (implicit ev: TensorNumeric[T]) : Unit = {

    val reshape = context.moduleData.module.asInstanceOf[Reshape[T]]

    val sizeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, sizeBuilder, reshape.size,
      universe.typeOf[Array[Int]])
    reshapeBuilder.putAttr("size", sizeBuilder.build)

    var batchMode = 0
    if (reshape.batchMode != None) {
      batchMode = if (reshape.batchMode.get == false) 1 else 2
    }
    val batchModeBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, batchModeBuilder, batchMode,
      universe.typeOf[Int])
    reshapeBuilder.putAttr("batchMode", batchModeBuilder.build)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy