All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.utils.intermediate.IRElement.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.utils.intermediate

import com.intel.analytics.bigdl.nn.MklInt8Convertible
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, DataFormat, TensorModule}
import com.intel.analytics.bigdl.optim.Regularizer
import com.intel.analytics.bigdl.tensor.{Tensor, TensorNumericMath}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * Base class of every intermediate-representation (IR) operator in this file.
 *
 * Concrete operators are case classes that only carry construction arguments.
 * `T` must be `Float` or `Double`: [[numerics]] is resolved eagerly, so
 * constructing an operator with any other element type throws.
 *
 * @tparam T element type of the operator
 */
sealed class IROperator[T: ClassTag] extends Serializable {
  /** Numeric implementation matching `T`, resolved once at construction time. */
  val numerics: TensorNumeric[T] = getNumerics(scala.reflect.classTag[T])

  /**
   * Look up the [[TensorNumeric]] instance for a class tag.
   *
   * The type parameter is deliberately named `N` rather than `T`: it is
   * independent of the class's `T` (e.g. it is also queried for a second
   * element type), and reusing `T` would shadow the class type parameter.
   *
   * @param tag class tag of the requested element type
   * @tparam N requested element type
   * @return `TensorNumeric.NumericFloat` for `Float`, `TensorNumeric.NumericDouble` for `Double`
   * @throws IllegalArgumentException for any other class tag
   */
  final def getNumerics[N](tag: ClassTag[N]) : TensorNumeric[N] = {
    tag match {
      case ClassTag.Float => TensorNumeric.NumericFloat.asInstanceOf[TensorNumeric[N]]
      case ClassTag.Double => TensorNumeric.NumericDouble.asInstanceOf[TensorNumeric[N]]
      case _ => throw new IllegalArgumentException(s"not supported class tag: ${tag}")
    }
  }

  /**
   * Class tags and numerics used by this operator, in matching order.
   * The default covers only `T`; operators with more type parameters override it.
   */
  def getClassTagNumerics() : (Array[ClassTag[_]], Array[TensorNumeric[_]]) = {
    (Array(scala.reflect.classTag[T]), Array(numerics))
  }

  /** Operator name: the simple runtime class name of the concrete operator. */
  def name: String = this.getClass.getSimpleName
}

/**
 * IR operator for a spatial max-pooling layer.
 *
 * Holds construction arguments only; no computation happens here. The argument
 * names presumably mirror the corresponding layer's constructor (kernel
 * `kW`/`kH`, stride `dW`/`dH`, padding `padW`/`padH`) — confirm against the converter.
 */
case class IRSpatialMaxPooling[T: ClassTag](
    kW: Int,
    kH: Int,
    dW: Int = 1,
    dH: Int = 1,
    padW: Int = 0,
    padH: Int = 0,
    format: DataFormat = DataFormat.NCHW,
    ceilMode: Boolean = false
) extends IROperator[T]

/**
 * IR operator for a spatial average-pooling layer.
 *
 * Holds construction arguments only. Names presumably follow the corresponding
 * layer's constructor (kernel, stride, padding, global/ceil modes, padding
 * inclusion, division flag, data format) — confirm against the converter.
 */
case class IRSpatialAveragePooling[T: ClassTag](
    kW: Int,
    kH: Int,
    dW: Int = 1,
    dH: Int = 1,
    padW: Int = 0,
    padH: Int = 0,
    globalPooling: Boolean = false,
    ceilMode: Boolean = false,
    countIncludePad: Boolean = true,
    divide: Boolean = true,
    format: DataFormat = DataFormat.NCHW
) extends IROperator[T]

/**
 * IR operator for a spatial convolution layer.
 *
 * Holds construction arguments only. Regularizers and initial tensors default
 * to `null`, meaning "not provided". The remaining names presumably mirror the
 * corresponding layer's constructor — confirm against the converter.
 */
case class IRSpatialConvolution[T: ClassTag](
    nInputPlane: Int,
    nOutputPlane: Int,
    kernelW: Int,
    kernelH: Int,
    strideW: Int = 1,
    strideH: Int = 1,
    padW: Int = 0,
    padH: Int = 0,
    nGroup: Int = 1,
    propagateBack: Boolean = true,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    initWeight: Tensor[T] = null,
    initBias: Tensor[T] = null,
    initGradWeight: Tensor[T] = null,
    initGradBias: Tensor[T] = null,
    withBias: Boolean = true,
    format: DataFormat = DataFormat.NCHW
) extends IROperator[T]

/**
 * IR operator for a "share" spatial convolution layer.
 *
 * Its argument list is identical to [[IRSpatialConvolution]]. It is a separate
 * type presumably so the converter can map it to a different layer — confirm.
 * Regularizers and initial tensors default to `null` ("not provided").
 */
case class IRSpatialShareConvolution[T: ClassTag](
    nInputPlane: Int,
    nOutputPlane: Int,
    kernelW: Int,
    kernelH: Int,
    strideW: Int = 1,
    strideH: Int = 1,
    padW: Int = 0,
    padH: Int = 0,
    nGroup: Int = 1,
    propagateBack: Boolean = true,
    wRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null,
    initWeight: Tensor[T] = null,
    initBias: Tensor[T] = null,
    initGradWeight: Tensor[T] = null,
    initGradBias: Tensor[T] = null,
    withBias: Boolean = true,
    format: DataFormat = DataFormat.NCHW
) extends IROperator[T]

/**
 * IR operator for a spatial batch-normalization layer.
 *
 * Holds construction arguments only. Tensor arguments (initial weights and
 * biases, their gradients, running statistics) default to `null` ("not provided").
 */
case class IRSpatialBatchNormalization[T: ClassTag](
    nOutput: Int,
    eps: Double = 1e-5,
    momentum: Double = 0.1,
    affine: Boolean = true,
    initWeight: Tensor[T] = null,
    initBias: Tensor[T] = null,
    initGradWeight: Tensor[T] = null,
    initGradBias: Tensor[T] = null,
    dataFormat: DataFormat = DataFormat.NCHW,
    runningMean: Tensor[T] = null,
    runningVar: Tensor[T] = null
) extends IROperator[T]

/** IR operator for an identity layer; it carries no arguments. */
case class IRIdentity[T: ClassTag]()
  extends IROperator[T]

/**
 * IR operator for a ReLU layer.
 *
 * @param ip presumably the "in-place" flag of the corresponding layer — confirm
 */
case class IRReLU[T: ClassTag](
    ip: Boolean = false
) extends IROperator[T]

/**
 * IR operator for a fully connected (linear) layer.
 *
 * Holds construction arguments only. Regularizers and initial tensors default
 * to `null` ("not provided").
 */
case class IRLinear[T: ClassTag](
    inputSize: Int, outputSize: Int,
    withBias: Boolean = true,
    wRegularizer: Regularizer[T] = null, bRegularizer: Regularizer[T] = null,
    initWeight: Tensor[T] = null, initBias: Tensor[T] = null,
    initGradWeight: Tensor[T] = null, initGradBias: Tensor[T] = null
) extends IROperator[T]

/**
 * IR operator for a cross-map local response normalization (LRN) layer.
 *
 * Holds construction arguments only; `size`, `alpha`, `beta` and `k` presumably
 * follow the corresponding layer's constructor — confirm against the converter.
 */
case class IRSpatialCrossMapLRN[T: ClassTag](
    size: Int = 5, alpha: Double = 1.0, beta: Double = 0.75, k: Double = 1.0,
    format: DataFormat = DataFormat.NCHW
) extends IROperator[T]

/** IR operator for a softmax layer; it carries no arguments. */
case class IRSoftMax[T: ClassTag]()
  extends IROperator[T]

/**
 * IR operator for a select-table layer.
 *
 * @param dimension presumably which table entry to select — confirm against the converter
 */
case class IRSelectTable[T: ClassTag](
    dimension: Int
) extends IROperator[T]

/**
 * IR operator for a CAddTable layer, parameterized by two element types `T` and `D`.
 *
 * Unlike other operators it reports two class tags/numerics from
 * [[getClassTagNumerics]], one for each type parameter. Both must be `Float`
 * or `Double`, or construction throws.
 *
 * @param inplace presumably the layer's in-place flag — confirm against the converter
 */
case class IRCAddTable[T: ClassTag, D: ClassTag](inplace: Boolean = false) extends IROperator[T] {
  // Field names and order are kept as-is: this class is Serializable without an
  // explicit serialVersionUID, so its private fields are part of the serialized form.
  private val ev = getNumerics(scala.reflect.classTag[T])
  private val ev2 = getNumerics(scala.reflect.classTag[D])

  /** Class tags and numerics for `T` then `D`, in matching order. */
  override def getClassTagNumerics() : (Array[ClassTag[_]], Array[TensorNumeric[_]]) = {
    val tagsOfTAndD = Array[ClassTag[_]](scala.reflect.classTag[T], scala.reflect.classTag[D])
    val numericsOfTAndD = Array[TensorNumeric[_]](ev, ev2)
    (tagsOfTAndD, numericsOfTAndD)
  }
}

/**
 * IR operator for a join-table layer.
 *
 * `dimension` and `nInputDims` presumably follow the corresponding layer's
 * constructor — confirm against the converter.
 */
case class IRJoinTable[T: ClassTag](
    dimension: Int,
    nInputDims: Int = 0
) extends IROperator[T]

/** IR operator for a concat-table layer; it carries no arguments. */
case class IRConcatTable[T: ClassTag]()
  extends IROperator[T]

/** IR operator for a graph input node; it carries no arguments. */
case class IRInput[T: ClassTag]()
  extends IROperator[T]

/**
 * Fallback IR operator that wraps an existing BLAS layer verbatim.
 *
 * Use it for any layer that has no dedicated IROperator case class, so the
 * layer can still take part in the IR.
 *
 * @param model the wrapped layer
 * @tparam T element type of the wrapped layer
 */
case class IRGeneralModule[T: ClassTag](
    model: AbstractModule[Activity, Activity, T]
) extends IROperator[T]

/**
 * A named IR node: an [[IROperator]] plus its (optional) parameter tensors.
 *
 * @param name        element name
 * @param op          operator described by this element
 * @param weights     weights and bias; `null` until set
 * @param gradWeights gradients of weights and bias; `null` until set
 * @tparam T element type of the operator and tensors
 */
private[bigdl] class IRElement[T: ClassTag](
  val name: String,
  val op: IROperator[T],
  private var weights: Tensor[T] = null,
  private var gradWeights: Tensor[T] = null) extends Serializable with MklInt8Convertible {

  /** Replace the stored weights-and-bias tensor. */
  def setWeights(weightsAndBias: Tensor[T]) : Unit = this.weights = weightsAndBias

  /** Replace the stored gradient tensor for weights and bias. */
  def setGradWeights(gradWeightsAndGradBias: Tensor[T]) : Unit =
    this.gradWeights = gradWeightsAndGradBias

  /** The pair (weights, gradWeights); either may be `null`. */
  def getParameters(): (Tensor[T], Tensor[T]) = {
    val w = this.weights
    val gw = this.gradWeights
    (w, gw)
  }

  /** Java-style accessor for [[name]]. */
  def getName() : String = name

  /** Java-style accessor for [[op]]. */
  def getOp() : IROperator[T] = op
}

object IRElement {
  /**
   * Build an [[IRElement]].
   *
   * @param name        element name
   * @param op          element operation, e.g. IRSpatialMaxPooling or IRGeneralModule
   * @param weights     weights and bias of the element; `null` when omitted
   * @param gradWeights gradients of weights and bias; `null` when omitted
   * @tparam T element type of the operator and tensors
   * @return a new element holding `op` and the given tensors
   */
  def apply[T: ClassTag](
      name: String,
      op: IROperator[T],
      weights: Tensor[T] = null,
      gradWeights: Tensor[T] = null): IRElement[T] = {
    new IRElement[T](name = name, op = op, weights = weights, gradWeights = gradWeights)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy