com.intel.analytics.bigdl.example.imageclassification.MlUtils.scala
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.example.imageclassification

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dataset.Transformer
import com.intel.analytics.bigdl.dataset.image.{BGRImage, LocalLabeledImagePath}
import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.dataset.DataSet.SeqFileFolder
import org.apache.hadoop.io.Text
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import scopt.OptionParser

import scala.reflect.ClassTag

object MlUtils {

  // Per-channel mean and std used for normalization (the standard ImageNet
  // values) and the 224x224 input size used by the example.
  val testMean = (0.485, 0.456, 0.406)
  val testStd = (0.229, 0.224, 0.225)

  val imageSize = 224

  /**
   * Marker trait for the model type. There are two kinds of
   * model type: a Torch model [[TorchModel]] and a BigDL
   * model [[BigDlModel]].
   */
  sealed trait ModelType

  case object TorchModel extends ModelType

  case object BigDlModel extends ModelType

  case class PredictParams(
    folder: String = "./",
    batchSize: Int = 32,
    classNum: Int = 1000,
    isHdfs: Boolean = false,

    modelType: ModelType = BigDlModel,
    modelPath: String = "",
    showNum : Int = 100
  )

  val predictParser = new OptionParser[PredictParams]("BigDL Predict Example") {
    opt[String]('f', "folder")
      .text("where you put the test data")
      .action((x, c) => c.copy(folder = x))
      .required()

    opt[String]("modelPath")
      .text("model snapshot location")
      .action((x, c) => c.copy(modelPath = x))
      .required()

    opt[Int]('b', "batchSizePerCore")
      .text("batch size per core")
      .action((x, c) => c.copy(batchSize = x))

    opt[Int]("classNum")
      .text("class num")
      .action((x, c) => c.copy(classNum = x))

    opt[Boolean]("isHdfs")
      .text("whether the input data is from HDFS or not")
      .action((x, c) => c.copy(isHdfs = x))

    opt[Int]("showNum")
      .text("number of results to show")
      .action((x, c) => c.copy(showNum = x))

    opt[String]('t', "modelType")
      .text("torch, bigdl")
      .action((x, c) =>
        x.toLowerCase() match {
          case "torch" => c.copy(modelType = TorchModel)
          case "bigdl" => c.copy(modelType = BigDlModel)
          case _ =>
            throw new IllegalArgumentException("only torch, bigdl supported")
        }
      )
  }
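
  // Usage sketch (editor's addition, not part of the original file): how the
  // parser is expected to map command-line flags onto PredictParams. The paths
  // below are hypothetical; scopt prints usage and returns None on bad input.
  //
  //   val parsed: Option[PredictParams] = predictParser.parse(
  //     Seq("-f", "/tmp/images", "--modelPath", "/tmp/resnet.bigdl", "-t", "bigdl"),
  //     PredictParams())
  //   // => Some(PredictParams(folder = "/tmp/images", modelPath = "/tmp/resnet.bigdl",
  //   //        modelType = BigDlModel, ...))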

  def loadModel[@specialized(Float, Double) T : ClassTag](param : PredictParams)
    (implicit ev: TensorNumeric[T]): Module[T] = {
    val model = param.modelType match {
      case TorchModel =>
        Module.loadTorch[T](param.modelPath)
      case BigDlModel =>
        Module.load[T](param.modelPath)
      case _ => throw new IllegalArgumentException(s"unsupported model type: ${param.modelType}")
    }
    model
  }
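
  // Usage sketch (editor's addition): loading with Float precision. The
  // TensorNumeric[Float] evidence should resolve from the TensorNumeric
  // companion object; `param` stands for an already parsed PredictParams.
  //
  //   val model: Module[Float] = loadModel[Float](param)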

  /**
   * Holds the information of a single DataFrame record.
   *
   * @param features features extracted by the transformers
   * @param imageName image name
   */
  case class DfPoint(features: DenseVector, imageName: String)

  /**
   * [[ByteImage]] is a case class representing an image
   * in byte format.
   *
   * @param data image byte data
   * @param imageName image name
   */
  case class ByteImage(data: Array[Byte], imageName: String)

  def transformDF(data: DataFrame, f: Transformer[Row, DenseVector]): DataFrame = {
    // Run the transformer over the "data" column partition by partition to get
    // one feature vector per image.
    val vectorRdd = data.select("data").rdd.mapPartitions(f(_))
    // Zip the feature vectors back with the image names from the same partitions.
    val dataRDD = data.rdd.zipPartitions(vectorRdd) { (a, b) =>
      b.zip(a.map(_.getAs[String]("imageName")))
        .map(v => DfPoint(v._1, v._2))
    }
    data.sqlContext.createDataFrame(dataRDD)
  }
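
  // Usage sketch (editor's addition; `bytesToVectorTransformer` is a hypothetical
  // Transformer[Row, DenseVector] that turns the "data" column into features):
  //
  //   val featureDF = transformDF(imageDF, bytesToVectorTransformer)
  //   featureDF.select("imageName", "features").show(param.showNum)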

  def imagesLoad(paths: Array[LocalLabeledImagePath], scaleTo: Int): Array[ByteImage] = {
    paths.map { imageFile =>
      ByteImage(BGRImage.readImage(imageFile.path, scaleTo), imageFile.path.getFileName.toString)
    }
  }

  def imagesLoadSeq(url: String, sc: SparkContext, classNum: Int): RDD[ByteImage] = {
    sc.sequenceFile(url, classOf[Text], classOf[Text]).map(image => {
      ByteImage(image._2.copyBytes(), SeqFileFolder.readName(image._1))
    })
  }
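
  // Usage sketch (editor's addition): the two loaders mirror the isHdfs switch in
  // PredictParams. `sc` is a SparkContext, `param` a parsed PredictParams, and
  // `localPaths` a placeholder for an Array[LocalLabeledImagePath] built from
  // param.folder.
  //
  //   val images: RDD[ByteImage] =
  //     if (param.isHdfs) imagesLoadSeq(param.folder, sc, param.classNum)
  //     else sc.parallelize(imagesLoad(localPaths, imageSize))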
}



