All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.utils.intermediate.IRConverter.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.utils.intermediate

import com.intel.analytics.bigdl.nn.Graph
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.mkldnn._
import com.intel.analytics.bigdl.tensor.{FloatType, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.{Module, utils}
import com.intel.analytics.bigdl.utils.{Engine, MklBlas, MklDnn, Node}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
 * Converts an [[IRGraph]] into a concrete graph for the current engine type.
 *
 * At construction time every node reachable from the IR graph inputs (following
 * `nextNodes`) is gathered depth-first; output nodes that were not reached that way
 * are appended afterwards.
 *
 * @param IRgraph the intermediate-representation graph to convert
 * @param ev numeric evidence for `T`; its type is checked before a dnn conversion
 * @tparam T numeric type of the graph
 */
private[bigdl] class IRConverter[T: ClassTag](IRgraph: IRGraph[T])(implicit ev: TensorNumeric[T]) {
  private val allNodes = new ArrayBuffer[Node[IRElement[T]]]
  private val irInputs = IRgraph.inputs.toArray
  private val irOutputs = IRgraph.outputs.toArray

  // Populate allNodes: first everything reachable from the inputs ...
  getNodes(irInputs, allNodes)
  // ... then any output node the traversal did not come across. The membership test is
  // evaluated per element, so an output listed twice is still appended only once.
  for (outputNode <- irOutputs if !allNodes.contains(outputNode)) {
    allNodes += outputNode
  }

  /**
   * Depth-first, pre-order collection of nodes into `nodesBuffer`.
   *
   * A node already present in the buffer is skipped, so each node is stored once.
   * An empty `inputs` simply performs no iterations.
   *
   * @param inputs nodes to start from
   * @param nodesBuffer accumulator that receives newly discovered nodes
   */
  private def getNodes(inputs: Seq[Node[IRElement[T]]],
                       nodesBuffer: ArrayBuffer[Node[IRElement[T]]]): Unit = {
    for (node <- inputs if !nodesBuffer.contains(node)) {
      nodesBuffer += node
      getNodes(node.nextNodes, nodesBuffer)
    }
  }

  /**
   * convert IRgraph to blas or dnn graph according to engine type
   * @return dnn graph or blas graph converted from ir graph
   * @throws UnsupportedOperationException if the engine type is neither MklBlas nor MklDnn
   */
  def toGraph() : Graph[T] = {
    Engine.getEngineType() match {
      case MklBlas =>
        require(IRToBlas[T].convertingCheck(allNodes.toArray),
          "IR graph can not be converted to Blas layer")
        toBlasGraph()

      case MklDnn =>
        require(ev.getType() == FloatType, "Mkldnn engine only supports float data")
        val floatNodes = allNodes.toArray.asInstanceOf[Array[Node[IRElement[Float]]]]
        require(IRToDnn[Float].convertingCheck(floatNodes),
          "IR graph can not be converted to Dnn layer")
        toDnnGraph()

      case _ =>
        throw new UnsupportedOperationException(
          s"Only support engineType mkldnn/mklblas, but get ${Engine.getEngineType()}")
    }
  }

  /**
   * Builds the dnn graph: converts all collected IR nodes, puts an `InputWrapper` node
   * in front of every converted input and attaches an `Output` node behind every
   * converted output whose element is not a `BlasWrapper`.
   *
   * @return the resulting `DnnGraph`, cast to `Graph[T]`
   */
  private def toDnnGraph(): Graph[T] = {
    val floatNodes = allNodes.toArray.asInstanceOf[Array[Node[IRElement[Float]]]]
    val converted = IRToDnn[Float].convert(floatNodes)

    // Looks up the converted counterpart of an IR node (fails if it is missing).
    def dnnNodeOf(irNode: Node[IRElement[T]]) =
      converted.get(irNode.asInstanceOf[Node[IRElement[Float]]]).get

    val dnnInputs = irInputs.map(n => dnnNodeOf(n))
    val dnnOutputs = irOutputs.map(n => dnnNodeOf(n))

    // add input node for dnn graph
    val realInputs = dnnInputs.map { dnnInput =>
      val wrapperNode = new Node[Module[Float]](new InputWrapper())
      dnnInput.from(wrapperNode)
      wrapperNode
    }

    // add output node for graph; the output format is picked by the output's position
    val realOutputs = dnnOutputs.zipWithIndex.map {
      case (model: Node[Module[Float]], index: Int) =>
        if (model.element.isInstanceOf[BlasWrapper]) {
          model
        } else {
          val formatNode = new Node[Module[Float]](Output(IRgraph.outputFormats(index)))
          model.add(formatNode)
        }
    }

    val floatVariables =
      IRgraph.variables.asInstanceOf[Option[(Array[Tensor[Float]], Array[Tensor[Float]])]]
    DnnGraph(realInputs, realOutputs, floatVariables, IRgraph.generateBackward)
      .asInstanceOf[Graph[T]]
  }

  /**
   * Builds the blas graph: converts all collected IR nodes and creates a dynamic graph
   * from the converted input and output nodes.
   *
   * @return the graph produced by `Graph.dynamic`
   */
  private def toBlasGraph(): Graph[T] = {
    val converted = IRToBlas[T].convert(allNodes.toArray)
    val blasInputs = irInputs.map(irNode => converted.get(irNode).get)
    val blasOutputs = irOutputs.map(irNode => converted.get(irNode).get)

    Graph.dynamic(blasInputs, blasOutputs, IRgraph.variables, IRgraph.generateBackward)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy