All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.utils.intermediate.IRToBlas.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.utils.intermediate

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.optim.DistriOptimizer._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{Node, T}

import scala.collection.mutable
import scala.reflect.ClassTag


abstract class ConvertBase[T, D] {
  /**
   * Replicates the edges of the source graph onto the converted nodes.
   *
   * Edges whose other endpoint has no entry in `nodeMap` are skipped.
   *
   * @param allNodes source nodes whose outgoing edges are copied
   * @param nodeMap mapping from each source node (element type T) to its converted
   *                node (element type D)
   */
  def cloneNode(allNodes: Array[Node[T]], nodeMap: mutable.HashMap[Node[T], Node[D]]): Unit = {
    // Pass 1: copy every outgoing edge whose target has a converted counterpart.
    allNodes.foreach { source =>
      source.nextNodesAndEdges.foreach { case (target, edge) =>
        nodeMap.get(target).foreach { convertedTarget =>
          nodeMap(source).add(convertedTarget, edge)
        }
      }
    }

    // Pass 2: for nodes with several predecessors, rebuild the incoming edges of the
    // converted node so they follow the order of the source node's predecessors.
    nodeMap.toArray.foreach { case (original, converted) =>
      if (original.prevNodesAndEdges.length > 1) {
        converted.removePrevEdges()
        original.prevNodesAndEdges.foreach { case (previous, edge) =>
          nodeMap.get(previous).foreach { convertedPrevious =>
            converted.from(convertedPrevious, edge)
          }
        }
      }
    }
  }

  /** Tells whether a single layer can be converted. */
  def convertLayerCheck(layer: T) : Boolean

  /** Converts a single layer from T to D. */
  def convertLayer(layer : T) : D

  /**
   * Checks that every node's element passes [[convertLayerCheck]].
   *
   * All nodes are inspected, even after a failure, so that each failing element is logged.
   *
   * @param allNodes nodes to check
   * @return true only when no element failed the check
   */
  def convertingCheck(allNodes: Array[Node[T]]) : Boolean = {
    allNodes.foldLeft(true) { (convertible, node) =>
      if (convertLayerCheck(node.element)) {
        convertible
      } else {
        logger.info(s"${node.element} convertion failed")
        false
      }
    }
  }

  /**
   * Converts every node's element and rebuilds the graph relations between the new nodes.
   *
   * @param allNodes nodes of the source graph
   * @return mapping from each source node to its converted node
   */
  def convert(allNodes: Array[Node[T]]): mutable.HashMap[Node[T], Node[D]] = {
    val nodeMap = new mutable.HashMap[Node[T], Node[D]]()
    for (node <- allNodes) {
      nodeMap(node) = new Node(convertLayer(node.element))
    }
    cloneNode(allNodes, nodeMap)
    nodeMap
  }
}

/**
 * Converter from IR elements to modules.
 *
 * An element whose operator is an `IRGeneralModule` is converted by returning the module
 * that operator carries. Any other element is converted reflectively, using the class in
 * `com.intel.analytics.bigdl.nn` named after the operator (see `className`).
 */
private[bigdl] class IRToBlas[T: ClassTag] extends ConvertBase[IRElement[T], Module[T]]{

  /**
   * Fully qualified name of the candidate `nn` class for the given element.
   *
   * The first two characters of the operator name are dropped (presumably an "IR"
   * prefix -- TODO confirm against the IR operator definitions).
   */
  private def className(layer: IRElement[T]): String = {
    val name = layer.getOp().name
    s"com.intel.analytics.bigdl.nn.${name.substring(2)}"
  }

  /**
   * Tells whether the element can be converted.
   *
   * @param layer IR element to check
   * @return true for general modules, or when the class lookup for `className(layer)`
   *         yields a non-null class
   */
  override def convertLayerCheck(layer: IRElement[T]): Boolean = layer.getOp() match {
    // Decided by the operator type alone, so the class-name lookup is skipped.
    case _: IRGeneralModule[T] => true
    case _ => ReflectionUtils.findClass(className(layer)) != null
  }

  /**
   * Converts the element into a module.
   *
   * @param layer IR element to convert
   * @return the wrapped model for general modules, otherwise the module built reflectively
   *         from the class named by `className(layer)`
   */
  override def convertLayer(layer : IRElement[T]) : Module[T] = layer.getOp() match {
    case general: IRGeneralModule[T] => general.model
    case _ => ReflectionUtils.reflectFromIR(layer, Class.forName(className(layer)))
  }
}

private[bigdl] object IRToBlas {
  /**
   * Creates a new [[IRToBlas]] converter for element type `T`.
   *
   * @param ev numeric evidence for `T`; required implicitly, not used by this factory itself
   */
  def apply[T: ClassTag](implicit ev: TensorNumeric[T]): IRToBlas[T] = {
    new IRToBlas[T]()
  }
}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy