All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.Reverse.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * Reverses the input Tensor along a given dimension.
 *
 * This is a `TensorModule`, so both the activation and the gradient are Tensors.
 * The backward pass applies the same reversal to `gradOutput`.
 *
 * @param dim the 1-based dimension to reverse along; it must satisfy
 *            `1 <= dim <= input.dim`, otherwise the forward/backward call fails
 * @param isInplace when true, the input (resp. gradOutput) tensor itself is reversed
 *                  and returned as output (resp. gradInput); when false, the result
 *                  is written to a separate tensor owned by this module
 * @param ev numeric type-class instance for `T`
 * @tparam T Numeric type. Only support float/double now
 */
class Reverse[T: ClassTag](dim: Int = 1, isInplace: Boolean = false)
  (implicit ev: TensorNumeric[T])
  extends TensorModule[T] {

  /** Scratch slice used only by the in-place reversal; allocated lazily on first use. */
  var buffer: Tensor[T] = null

  /**
   * Fails with an IllegalArgumentException when `dim` is not a valid 1-based
   * dimension index of `src`.
   */
  private def checkDim(src: Tensor[T], dim: Int): Unit = {
    require(dim > 0 && dim <= src.dim,
      s"Reverse: the designated dimension ${dim} to reverse input Tensor" +
        s" is out of index. The input.dim = ${src.dim}")
  }

  /**
   * Reverse the src Tensor w.r.t. the given dim and write the result to target.
   * E.g. src: (1,2,3; 4,5,6) and dim = 1
   *      target: (4,5,6; 1,2,3)
   *
   * @param src tensor to read from; it is not modified
   * @param target tensor that receives the result; it is resized to the shape of src
   * @param dim 1-based dimension to reverse along
   * @return target
   */
  private def reverseTensor(src: Tensor[T], target: Tensor[T], dim: Int): Tensor[T] = {
    checkDim(src, dim)

    val time = src.size(dim)
    target.resizeAs(src)
    var i = 1
    while (i <= time) {
      // slice i of the result is slice (time - i + 1) of the source
      target.select(dim, i).copy(src.select(dim, time - i + 1))
      i += 1
    }
    target
  }

  /**
   * Reverse the src Tensor w.r.t. the given dim in-place.
   * E.g. src: (1,2,3; 4,5,6) and dim = 1
   *      => src: (4,5,6; 1,2,3)
   *
   * @param src tensor to reverse; its contents are overwritten
   * @param dim 1-based dimension to reverse along
   * @return src
   */
  private def reverseTensor(src: Tensor[T], dim: Int): Tensor[T] = {
    checkDim(src, dim)
    if (buffer == null) buffer = Tensor[T]()

    val time = src.size(dim)
    // Only the first half of the slices needs visiting: each one is swapped with its
    // mirror, and the middle slice of an odd-sized dimension stays where it is.
    val half = time >> 1
    buffer.resizeAs(src.select(dim, 1))
    var i = 1
    while (i <= half) {
      val mirror = time - i + 1
      // three-way swap of slice i and its mirror through the scratch buffer
      buffer.copy(src.select(dim, mirror))
      src.select(dim, mirror).copy(src.select(dim, i))
      src.select(dim, i).copy(buffer)
      i += 1
    }
    src
  }

  /**
   * Forward pass: reverses `input` along `dim`.
   *
   * @param input tensor to reverse; overwritten when `isInplace` is true
   * @return the reversed tensor (the very same instance as `input` in in-place mode)
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    if (isInplace) {
      output = reverseTensor(input, dim)
    } else {
      if (output == null) output = Tensor[T]()
      reverseTensor(input, output, dim)
    }
    output
  }

  /**
   * Backward pass: reverses `gradOutput` along `dim`.
   *
   * @param input the forward input; not read here
   * @param gradOutput gradient w.r.t. the output; overwritten when `isInplace` is true
   * @return the reversed gradient (the very same instance as `gradOutput` in
   *         in-place mode)
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    if (isInplace) {
      gradInput = reverseTensor(gradOutput, dim)
    } else {
      if (gradInput == null) gradInput = Tensor[T]()
      reverseTensor(gradOutput, gradInput, dim)
    }
    gradInput
  }

  override def toString: String = "nn.Reverse"

  override def equals(other: Any): Boolean = super.equals(other)

  override def hashCode(): Int = super.hashCode()

  // Must test for Reverse itself: testing for an unrelated layer type would make
  // any canEqual-based equality reject every Reverse instance, including `this`.
  override def canEqual(other: Any): Boolean = other.isInstanceOf[Reverse[T]]

  /**
   * Releases cached state.
   *
   * In in-place mode `output`/`gradInput` are the caller's own tensors, so the
   * inherited clearing is skipped (presumably because it would reset tensors this
   * module does not own -- confirm against the parent implementation). The scratch
   * buffer always belongs to this module and is dropped in either mode; it is
   * re-created lazily by the next in-place reversal.
   *
   * @return this module
   */
  override def clearState(): this.type = {
    if (!isInplace) {
      super.clearState()
    }
    buffer = null
    this
  }

}

object Reverse {
  /**
   * Factory for [[Reverse]] layers.
   *
   * @param dimension passed to the layer as its `dim` constructor argument
   * @param isInplace passed to the layer unchanged
   * @param ev numeric type-class instance for `T`
   * @tparam T numeric element type
   * @return a new `Reverse[T]` built from the given arguments
   */
  def apply[@specialized(Float, Double) T: ClassTag](
      dimension: Int = 1,
      isInplace: Boolean = false)(implicit ev: TensorNumeric[T]): Reverse[T] =
    new Reverse[T](dimension, isInplace)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy