All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.pipeline.api.Net.scala Maven / Gradle / Ivy

There is a newer version: 0.6.0
Show newest version
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.pipeline.api

import java.nio.ByteOrder

import com.intel.analytics.bigdl.nn.Graph._
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, Initializable}
import com.intel.analytics.bigdl.nn.keras.KerasLayer
import com.intel.analytics.bigdl.nn.{Container, Graph, InitializationMethod}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.File
import com.intel.analytics.bigdl.utils.caffe.CaffeLoader
import com.intel.analytics.bigdl.utils.serializer.ModuleLoader
import com.intel.analytics.bigdl.utils.tf.{Session, TensorflowLoader}
import com.intel.analytics.zoo.pipeline.api.autograd.Variable
import com.intel.analytics.zoo.pipeline.api.keras.layers.WordEmbedding
import com.intel.analytics.zoo.pipeline.api.keras.models.{KerasNet, Model, Sequential}
import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, NetUtils}

import scala.reflect.ClassTag

/**
 * Mix-in that carries helper methods shared by layers.
 *
 * Both helpers cast `this` to a BigDL module type, so they are only usable on instances
 * that actually are such modules; a `ClassCastException` is raised otherwise.
 */
trait Net {

  /**
   * Tell whether this layer is frozen, i.e. both the weight scale and the bias scale of
   * the wrapped `labor` module are equal to zero.
   *
   * `this` is cast to a `KerasLayer`.
   *
   * @tparam T numeric type
   * @return true only when `getScaleW()` and `getScaleB()` of the labor are both 0
   */
  def isFrozen[T: ClassTag](): Boolean = {
    val wrapped = this.asInstanceOf[KerasLayer[Activity, Activity, T]].labor
    // The bias scale is consulted only when the weight scale is already zero.
    if (wrapped.getScaleW() == 0) {
      wrapped.getScaleB() == 0
    } else {
      false
    }
  }

  /**
   * Build graph: connect the given upstream variables to the current module.
   *
   * `this` is cast to an `AbstractModule` and the nodes of `vars` become its inputs.
   *
   * @param vars upstream variables
   * @tparam T numeric type
   * @return Variable wrapping the node created for the current module
   */
  def from[T: ClassTag](vars : Variable[T]*)(implicit ev: TensorNumeric[T]): Variable[T] = {
    val current = this.asInstanceOf[AbstractModule[Activity, Activity, T]]
    val upstreamNodes = vars.map(_.node)
    new Variable(current.inputs(upstreamNodes: _*))
  }
}

/**
 * Entry points for loading models of various formats (Analytics Zoo, BigDL, Torch, Caffe,
 * TensorFlow) plus a helper for setting initialization methods on a module tree.
 */
object Net {
  // Referencing these companion objects as bare statements forces their initialization
  // whenever `Net` itself is initialized.
  // NOTE(review): presumably this runs registration code (e.g. serializers) needed before
  // any load* call -- confirm before removing or reordering.
  Model
  Sequential
  GraphNet
  WordEmbedding

  /**
   * Set the weight/bias initialization methods on a module, descending into wrapped or
   * contained modules when the module itself is not `Initializable`.
   *
   * Resolution order:
   *  - an `Initializable` module gets the methods applied directly;
   *  - a `KerasLayer` delegates to its `labor`, keeping the caller's `throwException`;
   *  - any other `Container` applies the methods to each of its sub-modules, silently
   *    skipping sub-modules that do not support it;
   *  - anything else is unsupported.
   *
   * @param module module to configure
   * @param weightInitMethod initialization method for weights (may be null)
   * @param biasInitMethod initialization method for biases (may be null)
   * @param throwException whether an unsupported module raises instead of being ignored
   * @throws RuntimeException if the module is unsupported and `throwException` is true
   */
  def setInitMethod(module: AbstractModule[_, _, _],
      weightInitMethod: InitializationMethod = null,
      biasInitMethod: InitializationMethod = null, throwException: Boolean = true): Unit = {
    module match {
      case i: Initializable =>
        i.setInitMethod(weightInitMethod, biasInitMethod)
      case k: KerasLayer[_, _, _] =>
        setInitMethod(k.labor, weightInitMethod, biasInitMethod, throwException)
      case c: Container[_, _, _] => // Some KerasLayer may be constructed by multiple layers
        // Executed purely for its side effect, so foreach rather than map.
        c.modules.foreach { child =>
          setInitMethod(child, weightInitMethod, biasInitMethod, throwException = false)
        }
      case _ =>
        if (throwException) {
          throw new RuntimeException(s"$module does not support setInitMethod")
        }
    }
  }

  /**
   * Load model from path.
   *
   * @param path path to load module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @param weightPath : where weight is stored
   * @tparam T numeric type
   * @return An Analytics Zoo model.
   * @throws RuntimeException if the loaded module is not a `KerasNet`
   */
  def load[T: ClassTag](path : String,
      weightPath : String = null)(implicit ev: TensorNumeric[T])
  : KerasNet[T] = {
    ModuleLoader.loadFromFile(path, weightPath) match {
      // T is erased at runtime, so only the KerasNet class itself is checked here.
      case kerasNet: KerasNet[T @unchecked] => kerasNet
      case _ => throw new RuntimeException(
        "Not an Analytics Zoo Keras-style model. Please use loadBigDL, loadCaffe or loadTF instead")
    }
  }

  /**
   * Load BigDL model from path.
   *
   * @param path path to load module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @param weightPath : where weight is stored
   * @tparam T numeric type
   * @return model loaded from path, converted to a graph and wrapped in a GraphNet
   */
  def loadBigDL[T: ClassTag](path : String,
      weightPath : String = null)(implicit ev: TensorNumeric[T])
  : GraphNet[T] = {
    val graph = ModuleLoader.loadFromFile(path, weightPath).toGraph()
    new GraphNet(graph)
  }

  /**
   * Load Torch model from path.
   *
   * @param path path to load module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @tparam T numeric type
   * @return model loaded from path, converted to a graph and wrapped in a GraphNet
   */
  def loadTorch[T: ClassTag](path : String)(implicit ev: TensorNumeric[T]):
  GraphNet[T] = {
    val graph = File.loadTorch[AbstractModule[Activity, Activity, T]](path).toGraph()
    new GraphNet[T](graph)
  }

  /**
   * Load caffe trained model from prototxt and weight files
   * @param defPath  caffe model definition file path
   * @param modelPath caffe model binary file containing weight and bias
   * @tparam T numeric type
   * @return model loaded from the two files, wrapped in a GraphNet
   */
  def loadCaffe[T: ClassTag](defPath: String, modelPath: String)(
      implicit ev: TensorNumeric[T]): GraphNet[T] = {
    // Only the first element of the loader's result is used; it is cast to a Graph.
    val graph = CaffeLoader.loadCaffe[T](defPath, modelPath)._1
      .asInstanceOf[Graph[T]]
    new GraphNet[T](graph)
  }

  /**
   * Load tensorflow model from its saved protobuf file.
   * @param graphFile where is the protobuf model file
   * @param inputs input node names
   * @param outputs output node names, the output tensor order is same with the node order
   * @param byteOrder byte order in the tensorflow file. The default value is little endian
   * @param binFile where is the model variable file
   * @tparam T numeric type
   * @return model loaded from path
   */
  def loadTF[T: ClassTag](graphFile: String, inputs: Seq[String], outputs: Seq[String],
      byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN,
      binFile: Option[String] = None)(
      implicit ev: TensorNumeric[T]): GraphNet[T] = {

    val graph = TensorflowLoader.load(graphFile, inputs, outputs, byteOrder, binFile)
      .asInstanceOf[Graph[T]]
    new GraphNet[T](graph)
  }

  /**
   * Load TensorFlow model from exported folder.
   * @param folder The folder path which contains 'frozen_inference_graph.pb' and
   *               'graph_meta.json'.
   * @tparam T numeric type
   * @return model loaded from path
   */
  def loadTF[T: ClassTag](folder: String)
      (implicit ev: TensorNumeric[T]): GraphNet[T] = {
    val (model, meta) = NetUtils.processTFFolder(folder)
    // Input/output names from the meta file are passed through NetUtils.removePort before
    // being handed to the file-based loader; default byte order and no variable file apply.
    loadTF[T](model, NetUtils.removePort(meta.inputNames), NetUtils.removePort(meta.outputNames))
  }

  /**
   * Load tensorflow checkpoints.
   *
   * Thin wrapper that forwards all arguments to `TensorflowLoader.checkpoints`.
   *
   * @param graphFile graph file handed to the loader
   *                  (presumably the protobuf graph definition -- TODO confirm)
   * @param binFile binary file handed to the loader
   *                (presumably the saved variables -- TODO confirm)
   * @param byteOrder byte order handed to the loader. The default value is little endian
   * @tparam T numeric type
   * @return the Session produced by the loader
   */
  def loadTFCheckpoints[T: ClassTag](graphFile: String, binFile: String,
      byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN)(
      implicit ev: TensorNumeric[T]): Session[T] = {
    TensorflowLoader.checkpoints(graphFile, binFile, byteOrder)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy