All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.intel.analytics.zoo.common.TensorOperation.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.common

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * Broadcasting helpers for element-wise operations between two tensors whose
 * shapes differ but are compatible when aligned on their trailing dimensions.
 */
object TensorOperation {

  /**
   * Computes the broadcast shape of two tensors.
   *
   * The shapes are aligned on their trailing dimensions. Every aligned pair of
   * extents must either be equal or contain a 1, and the result takes the larger
   * of the two. Leading dimensions that only exist on the higher-rank tensor are
   * copied unchanged.
   *
   * @param tensor first tensor
   * @param other  second tensor
   * @return the broadcast shape, with one entry per dimension of the higher-rank tensor
   * @throws IllegalArgumentException if an aligned pair of extents differs and
   *                                  neither extent is 1
   */
  def expandSize[T: ClassTag](tensor: Tensor[T], other: Tensor[T]): Array[Int] = {
    val mismatch =
      s"tensor size not match ${tensor.size.mkString("x")} ${other.size.mkString("x")}"
    // On equal rank `other` plays the "wide" role, exactly as the strict `>` dictates.
    val (wide, narrow) = if (tensor.dim() > other.dim()) (tensor, other) else (other, tensor)
    val rank = wide.nDimension()
    val lead = rank - narrow.nDimension()

    // Tensor dimensions are addressed 1-based, hence the `+ 1` on every lookup.
    Array.tabulate(rank) { axis =>
      val wideExtent = wide.size(axis + 1)
      if (axis < lead) {
        // Dimension exists only on the higher-rank tensor.
        wideExtent
      } else {
        val narrowExtent = narrow.size(axis + 1 - lead)
        require(
          wideExtent == narrowExtent || wideExtent == 1 || narrowExtent == 1,
          mismatch)
        math.max(wideExtent, narrowExtent)
      }
    }
  }

  /**
   * Strides that let `source` be viewed with `rank` dimensions: the stride is 0
   * for every leading dimension missing from `source` and for every dimension of
   * extent 1, and the tensor's own stride otherwise.
   */
  private def broadcastStrides[T](source: Tensor[T], rank: Int): Array[Int] = {
    val lead = rank - source.nDimension
    Array.tabulate(rank) { axis =>
      val ownAxis = axis + 1 - lead
      if (axis >= lead && source.size(ownAxis) != 1) source.stride(ownAxis) else 0
    }
  }

  /**
   * Expands `tensor2` to the broadcast shape of `tensor` and `tensor2`.
   *
   * The returned tensor is constructed on top of `tensor2`'s storage and storage
   * offset, using the broadcast shape and zero strides on broadcast dimensions.
   *
   * Side effect: when the element count of `tensor` differs from the element
   * count of the broadcast shape, `tensor` is re-pointed (via `set`) to a newly
   * allocated tensor of the broadcast shape that holds its expanded content.
   *
   * @param tensor  left operand; may be expanded in place
   * @param tensor2 right operand; not modified by this method
   * @return `tensor2` seen through the broadcast shape
   * @throws IllegalArgumentException if the two shapes cannot be broadcast together
   */
  def expandTensor[T: ClassTag](tensor: Tensor[T], tensor2: Tensor[T])
                               (implicit ev: TensorNumeric[T]): Tensor[T] = {
    val broadcastShape = expandSize(tensor, tensor2)
    val rank = broadcastShape.length

    // Build the view of tensor2 first, before `tensor` is possibly replaced below.
    val expandedOther = Tensor[T](
      tensor2.storage(),
      tensor2.storageOffset(),
      broadcastShape,
      broadcastStrides(tensor2, rank)
    )

    if (broadcastShape.product != tensor.nElement()) {
      val tensorView = Tensor[T](
        tensor.storage,
        tensor.storageOffset(),
        broadcastShape,
        broadcastStrides(tensor, rank)
      )
      // Adding the zero-stride view onto a freshly resized tensor materialises the expansion.
      val materialised = Tensor[T]().resize(broadcastShape).add(tensorView)
      tensor.set(materialised)
    }

    expandedOther
  }

  /**
   * Subtracts `tensor2` from `tensor` in place, with broadcasting.
   *
   * @param tensor  minuend; modified in place (and possibly expanded, see `expandTensor`)
   * @param tensor2 subtrahend
   * @return `tensor`
   */
  def subTensor[T: ClassTag](tensor: Tensor[T], tensor2: Tensor[T])
    (implicit ev: TensorNumeric[T]): Tensor[T] = {
    tensor.sub(expandTensor(tensor, tensor2).contiguous())
    tensor
  }

  /**
   * Divides `tensor` by `tensor2` in place, with broadcasting.
   *
   * @param tensor  dividend; modified in place (and possibly expanded, see `expandTensor`)
   * @param tensor2 divisor
   * @return `tensor`
   */
  def divTensor[T: ClassTag](tensor: Tensor[T], tensor2: Tensor[T])
    (implicit ev: TensorNumeric[T]): Tensor[T] = {
    tensor.div(expandTensor(tensor, tensor2).contiguous())
    tensor
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy