All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.View.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * This module creates a new view of the input tensor using the sizes passed to the constructor.
 * The method setNumInputDims() allows specifying the expected number of dimensions of the
 * inputs of the module. This makes it possible to use minibatch inputs
 * when using a size of -1 for one of the dimensions.
 *
 * @param sizes sizes used to create the new view; every entry must be non-negative or -1,
 *              and at most one entry may be -1
 */
@SerialVersionUID(1238814703013238333L)
class View[T: ClassTag](val sizes: Array[Int])(
  implicit ev: TensorNumeric[T]) extends TensorModule[T] {

  /**
   * @return the sizes this module was constructed with (the same array instance, not a copy)
   */
  def getSize(): Array[Int] = sizes

  /**
   * Product of all non-negative entries of `sizes`; a -1 entry is left out of the product.
   * Computing it also validates `sizes`: any negative entry other than -1, or a second -1,
   * fails construction with an IllegalArgumentException.
   */
  val numElements: Int = {
    var product = 1
    var inferDim = false
    sizes.foreach { s =>
      if (s >= 0) {
        product *= s
      } else {
        require(s == -1, "size should be positive or -1")
        require(!inferDim, "only one dimension should be -1")
        inferDim = true
      }
    }
    product
  }

  // 0 means "not set": every dimension of the input is then treated as a sample dimension.
  private var numInputDims = 0

  /**
   * Sets the expected number of dimensions of the inputs of this module.
   *
   * @param numInputDims expected number of input dimensions; 0 leaves it unset
   * @return this module, to allow call chaining
   */
  def setNumInputDims(numInputDims: Int): this.type = {
    this.numInputDims = numInputDims
    this
  }

  /**
   * Multiplies `dims(from)` up to, but not including, `dims(until)`; returns 1 for an
   * empty range.
   */
  private def dimProduct(dims: Array[Int], from: Int, until: Int): Int = {
    var acc = 1
    var i = from
    while (i < until) {
      acc *= dims(i)
      i += 1
    }
    acc
  }

  /**
   * Computes the size of the extra leading (batch) dimension the output view needs.
   *
   * @param input           tensor about to be viewed
   * @param size            target sizes of the view
   * @param numberInputDims expected number of input dimensions, 0 if unset
   * @param numElements     product of the non-negative entries of `size`
   * @return the leading dimension to prepend to `size`, or -1 when `size` should be used
   *         as is
   */
  private def batchSize(
    input: Tensor[T], size: Array[Int], numberInputDims: Int, numElements: Int): Int = {
    val ind = input.nDimension()
    val isz = input.size()
    // With no declared number of input dimensions, all input dimensions are used.
    val maxDim = if (numberInputDims == 0) ind else numberInputDims

    // Fail with a descriptive message instead of an ArrayIndexOutOfBoundsException raised
    // while indexing the input sizes below.
    require(maxDim >= 0 && maxDim <= ind,
      s"numInputDims ($numberInputDims) is out of range for a ${ind}-dimensional input")
    // A 0 entry in the sizes makes numElements 0; guard the modulo and division below
    // so the failure is not a bare ArithmeticException.
    require(numElements != 0, "cannot create a view with a size of 0")

    // Number of elements held by the trailing maxDim dimensions of the input.
    val ine = dimProduct(isz, ind - maxDim, ind)
    require(ine % numElements == 0, "input view doesn't match desired view")

    // A -1 entry in the sizes resets this factor to 1, so that only the leading
    // (non-sample) dimensions of the input contribute to the batch size.
    val sampleFactor = if (size.contains(-1)) 1 else ine / numElements
    val bse = sampleFactor * dimProduct(isz, 0, ind - maxDim)

    if (bse == 1 && (numberInputDims == 0 || ind <= numberInputDims)) -1 else bse
  }

  /**
   * Views the input with the configured sizes, prepending a batch dimension when
   * `batchSize` reports one.
   *
   * @param input tensor to view
   * @return the viewed tensor, also stored in `output`
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    val bse = batchSize(input, this.sizes, this.numInputDims, this.numElements)
    if (bse != -1) {
      val newSizes: Array[Int] = bse +: this.sizes
      this.output = input.view(newSizes)
    } else {
      this.output = input.view(this.sizes)
    }

    this.output
  }

  def this(s: Int*)(implicit ev: TensorNumeric[T]) = this(s.toArray)

  /**
   * Views the gradient of the output with the sizes of the input.
   *
   * @param input      input of the forward pass
   * @param gradOutput gradient with respect to the output
   * @return the viewed gradient, also stored in `gradInput`
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    this.gradInput = gradOutput.view(input.size)
    this.gradInput
  }

  /**
   * @return the print name followed by the configured sizes joined with "x"
   */
  override def toString(): String = {
    s"${getPrintName}(${sizes.mkString("x")})"
  }
}

/**
 * Factory methods for [[View]].
 */
object View {
  /**
   * Builds a [[View]] from sizes given as individual arguments.
   *
   * @param sizes sizes of the view
   */
  def apply[@specialized(Float, Double) T: ClassTag](
      sizes: Int*)(implicit ev: TensorNumeric[T]) : View[T] =
    apply[T](sizes.toArray)

  /**
   * Builds a [[View]] from sizes given as an array.
   *
   * @param sizes sizes of the view
   */
  def apply[@specialized(Float, Double) T: ClassTag](
      sizes: Array[Int])(implicit ev: TensorNumeric[T]) : View[T] =
    new View[T](sizes)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy