All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.examples.nnframes.finetune.TransferLearning.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.zoo.examples.nnframes.finetune

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.{CrossEntropyCriterion, Graph}
import com.intel.analytics.bigdl.optim.{Top1Accuracy, Trigger}
import com.intel.analytics.bigdl.utils.{LoggerFilter, Shape}
import com.intel.analytics.zoo.common.NNContext
import com.intel.analytics.zoo.pipeline.api.keras.layers._
import com.intel.analytics.zoo.pipeline.api.keras.models.Model
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import com.intel.analytics.zoo.feature.image.{RowToImageFeature, _}
import com.intel.analytics.zoo.pipeline.api.Net
import com.intel.analytics.zoo.pipeline.nnframes.{NNClassifier, NNClassifierModel, NNImageReader}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.{SparkConf, SparkContext}
import scopt.OptionParser


object Utils {

  /**
   * Command-line parameters for the transfer-learning example.
   *
   * @param modelPath path to the pre-trained BigDL Inception-v1 model file
   * @param dataPath  path to the folder of training images (cat/dog samples)
   * @param batchSize mini-batch size used for training and validation
   * @param nEpochs   number of training epochs
   */
  case class TrainParams(
                          modelPath: String = "/tmp/models/bigdl_inception-v1_imagenet_0.4.0.model",
                          dataPath: String = "/tmp/datasets/cat_dog/train_sampled",
                          batchSize: Int = 32,
                          nEpochs: Int = 2)

  // Program name fixed: this is the transfer-learning example, not the
  // ptbModel example the string was copy-pasted from. The name is shown
  // in scopt's generated usage/help text.
  val trainParser = new OptionParser[TrainParams]("Analytics Zoo Transfer Learning Example") {
    opt[String]('m', "modelPath")
      .text("pretrained model path")
      .action((x, c) => c.copy(modelPath = x))
    opt[String]('d', "dataPath")
      .text("training data path")
      .action((x, c) => c.copy(dataPath = x))
    opt[Int]('b', "batchSize")
      .text("batchSize")
      .action((x, c) => c.copy(batchSize = x))
    opt[Int]('e', "nEpochs")
      .text("epoch numbers")
      .action((x, c) => c.copy(nEpochs = x))
  }
}

object TransferLearning {

  LoggerFilter.redirectSparkInfoLogs()

  /**
   * Entry point: fine-tunes a pre-trained Inception-v1 on a cat/dog dataset
   * using a Spark ML pipeline, validates each epoch, and prints a small
   * sample of validation predictions.
   */
  def main(args: Array[String]): Unit = {
    // `foreach` (not `map`) — we only run side effects here and the previous
    // Option[Unit] result was silently discarded. Exits quietly on bad args
    // (scopt has already printed the usage message).
    Utils.trainParser.parse(args, Utils.TrainParams()).foreach { param =>

      val conf = new SparkConf()
        .setAppName("Transfer Learning Example")

      val sc = NNContext.initNNContext(conf)

      val model = getTransferLearningModel(param.modelPath)

      val (trainDf, valDf) = getImageData(param.dataPath, sc)

      // Standard Inception-v1 preprocessing: resize to 256x256, center-crop to
      // the 224x224 input size, subtract the per-channel ImageNet means
      // (123, 117, 104), then convert to a tensor feature.
      val featureTransformer = RowToImageFeature() -> ImageResize(256, 256) ->
                                   ImageCenterCrop(224, 224) ->
                                   ImageChannelNormalize(123, 117, 104) ->
                                   ImageMatToTensor() ->
                                   ImageFeatureToTensor()

      val classifier = NNClassifier(model, CrossEntropyCriterion[Float](), featureTransformer)
        .setFeaturesCol("image")
        .setLearningRate(0.003)
        .setBatchSize(param.batchSize)
        .setMaxEpoch(param.nEpochs)
        .setCachingSample(false)
        .setValidation(Trigger.everyEpoch, valDf, Array(new Top1Accuracy()), param.batchSize)

      val pipeline = new Pipeline().setStages(Array(classifier))

      val pipelineModel = pipeline.fit(trainDf)

      val predictions = pipelineModel.transform(valDf)

      // Show roughly 10% of the validation predictions. Sampling WITHOUT
      // replacement so the preview contains no duplicate rows (the previous
      // `withReplacement = true` could print the same row several times).
      predictions.select(col("image"), col("label"), col("prediction"))
        .sample(withReplacement = false, 0.1)
        .show(false)

      sc.stop()
    }
  }

  /**
   * Builds the fine-tuning model: loads the pre-trained Inception-v1,
   * truncates it after "pool5/drop_7x7_s1" (removing the original classifier),
   * freezes "pool4/3x3_s2" and everything before it, then appends a new
   * 2-class dense head (cat vs. dog).
   *
   * @param preTrainedPath path to the pre-trained BigDL model file
   * @return a Keras-style graph model ready for training
   */
  private def getTransferLearningModel(preTrainedPath: String): Module[Float] = {
    // Tip: Net.loadBigDL[Float](preTrainedPath).saveGraphTopology(somePath)
    // lets you visualize the topology in TensorBoard to pick the cut/freeze points.
    val inception = Net.loadBigDL[Float](preTrainedPath)
      .newGraph(output = "pool5/drop_7x7_s1") // remove layers after pool5/drop_7x7_s1
      .freezeUpTo("pool4/3x3_s2") // freeze layer pool4/3x3_s2 and the layers before it
      .toKeras()

    // Add a new classifier head on top of the (partially frozen) feature extractor.
    val input = Input[Float](inputShape = Shape(3, 224, 224))
    val feature = inception.inputs(input)
    val flatten = Flatten[Float]().inputs(feature)
    val logits = Dense[Float](2).inputs(flatten)

    Model[Float](input, logits)
  }

  /**
   * Reads images from `dataPath` and derives a label from each file name:
   * 1.0 for cats, 2.0 for everything else (labels are 1-based, as required
   * by CrossEntropyCriterion / ClassNLL). Returns an 80/20
   * (training, validation) split with a fixed seed for reproducibility.
   */
  private def getImageData(dataPath: String, sc: SparkContext) = {
    val createLabel = udf { row: Row =>
      if (row.getString(0).split("/").last.contains("cat")) 1.0 else 2.0 }
    val imagesDF = NNImageReader.readImages(dataPath, sc)
      .withColumn("label", createLabel(col("image")))
    val Array(validationDF, trainingDF) = imagesDF
      .randomSplit(Array(0.20, 0.80), seed = 1L)
    (trainingDF, validationDF)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy