All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.mkldnn.SoftMax.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn.mkldnn

import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Query, Stream => DnnStream}
import com.intel.analytics.bigdl.nn
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.mkldnn.Phase.{InferencePhase, TrainingPhase}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.{DenseType, Tensor}
import com.intel.analytics.bigdl.utils.Shape

import scala.collection.mutable.ArrayBuffer

class SoftMax(val axis: Int = -1) extends MklDnnLayer {
  @transient private var updateOutputTensors: Array[Tensor[Float]] = _
  @transient private var updateOutputMemoryPrimitives: Array[Long] = _
  @transient private var updateGradInputTensors: Array[Tensor[Float]] = _
  @transient private var updateGradInputMemoryPrimitives: Array[Long] = _
  @transient private var modelPhase: Phase = null

  private var defaultAxis = 0

  private def initPhase(phase: Phase): Unit = {
    if (phase != null) return modelPhase = phase
    isTraining() match {
      case true =>
        modelPhase = TrainingPhase
      case false =>
        modelPhase = InferencePhase
    }
  }

  private def format(shape: Array[Int]): Int = {
    shape.length match {
      case 2 => Memory.Format.nc
      case 4 => Memory.Format.nchw
      case _ => throw new UnsupportedOperationException(s"${getName()} unsupported input shape")
    }
  }

  override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
    initPhase(phase)
    defaultAxis = inputs(0).shape.length match {
      case 1 => 0
      case 2 => 1
      case 3 => 0
      case 4 => 1
      case _ => throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected")
    }

    _inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32))

    val localInputFormat = if (inputs(0).shape.length == 3 &&
      inputs(0).layout == Memory.Format.ntc) {
      // note: here, the format and the true memory layout is not consistent.
      // for ntc input, we should reshape the `shape` and make the format to tnc
      val shape = Array(inputs(0).shape(1), inputs(0).shape(0), inputs(0).shape(2))
      NativeData(shape, Memory.Format.tnc)
    } else {
      _inputFormats(0)
    }

    val desc = MklDnnMemory.SoftMaxForwardDescInit(PropKind.Forward,
      localInputFormat.getMemoryDescription(), if (axis == -1) defaultAxis else axis)
    val forwardPrimDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L)

    _outputFormats = if (inputs(0).shape.length ==3 &&
      inputs(0).layout == Memory.Format.ntc) {
      // because set the input format as tnc first, we should set the output to ntc.
      Array(NativeData(inputs(0).shape, Memory.Format.ntc))
    } else {
      Array(MemoryData.primitiveOutput(forwardPrimDesc))
    }

    val srcs = Array(inputs(0).getPrimitive(runtime))
    val indexes = Array(0)
    val dsts = Array(_outputFormats(0).getPrimitive(runtime))

    val primitive = MklDnnMemory.PrimitiveCreate2(forwardPrimDesc, srcs, indexes,
      srcs.length, dsts, dsts.length)

    updateOutputPrimitives = Array(primitive)
    updateOutputMemoryPrimitives = srcs ++ dsts

    output = initTensor(_outputFormats(0))

    (_inputFormats, _outputFormats)
  }

  override private[mkldnn] def initBwdPrimitives(grad: Array[MemoryData], phase: Phase) = {
    val desc = MklDnnMemory.SoftMaxBackwardDescInit(PropKind.Backward,
      grad(0).getMemoryDescription(), outputFormats()(0).getMemoryDescription(),
      if (axis == -1) defaultAxis else axis)
    val primDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L)

    _gradOutputFormats = grad.clone()
    _gradInputFormats = Array(MemoryData.operationWant(primDesc, Query.DiffSrcPd))

    val srcs = Array(outputFormats()(0).getPrimitive(runtime), grad(0).getPrimitive(runtime))
    val indexes = Array(0)
    val dsts = Array(_gradInputFormats(0).getPrimitive(runtime))

    val primitive = MklDnnMemory.PrimitiveCreate2(primDesc, srcs, indexes,
      srcs.length, dsts, dsts.length)

    updateGradInputPrimitives = Array(primitive)
    updateGradInputMemoryPrimitives = srcs ++ dsts

    gradInput = initTensor(_gradInputFormats(0))

    (_gradInputFormats, _gradOutputFormats)
  }

  override def updateOutput(input: Activity): Activity = {
    if (updateOutputTensors == null) {
      val buffer = new ArrayBuffer[Tensor[Float]]()
      buffer.append(input.asInstanceOf[Tensor[Float]])
      buffer.append(output.asInstanceOf[Tensor[Float]])
      updateOutputTensors = buffer.toArray
    }

    input.toTensor[Float].getTensorType match {
      case DenseType => updateOutputTensors(0) = input.toTensor
      case _ =>
    }

    MklDnnOps.streamSubmit(runtime.stream, 1,
      updateOutputPrimitives,
      updateOutputPrimitives.length,
      updateOutputMemoryPrimitives, updateOutputTensors)
    output
  }

  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    if (updateGradInputTensors == null) {
      val buffer = new ArrayBuffer[Tensor[Float]]()
      buffer.append(output.asInstanceOf[Tensor[Float]])
      buffer.append(gradOutput.asInstanceOf[Tensor[Float]])
      buffer.append(gradInput.asInstanceOf[Tensor[Float]])

      updateGradInputTensors = buffer.toArray
    }

    gradOutput.toTensor[Float].getTensorType match {
      case DenseType => updateGradInputTensors(1) = gradOutput.toTensor
      case _ =>
    }

    MklDnnOps.streamSubmit(runtime.stream, 1,
      updateGradInputPrimitives,
      updateGradInputPrimitives.length,
      updateGradInputMemoryPrimitives, updateGradInputTensors
    )

    gradInput
  }

  override def computeOutputShape(inputShape: Shape): Shape = {
    inputShape
  }
}

object SoftMax {
  def apply(axis: Int = -1)(implicit ev: TensorNumeric[Float]): SoftMax = {
    new SoftMax(axis)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy