All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.models.image.common.ImageModel.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.models.image.common

import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.zoo.feature.image.ImageSet
import com.intel.analytics.zoo.models.common.ZooModel
import com.intel.analytics.zoo.models.image.imageclassification.ImageClassifier
import com.intel.analytics.zoo.models.image.objectdetection.ObjectDetector
import org.apache.log4j.Logger

import scala.reflect.ClassTag

/**
 * Common parent of the image models shipped with Analytics Zoo.
 *
 * An instance may carry an [[ImageConfigure]] (assigned by `ImageModel.loadModel`) that is
 * used as the default prediction configuration.
 *
 * @tparam T Numeric type of the model parameters.
 */
abstract class ImageModel[T: ClassTag]()(implicit ev: TensorNumeric[T])
  extends ZooModel[Activity, Activity, T] {

  // Default prediction configuration; stays null until the companion object assigns it.
  private var config: ImageConfigure[T] = null

  /**
   * Runs prediction over an [[ImageSet]] and returns the predictions as a new [[ImageSet]].
   *
   * The configuration that is applied is chosen as follows: the `configure` argument when it
   * is non-null, otherwise the configuration stored on this model. When neither is available
   * the images are passed to `predictImage` as they are, with its default settings.
   *
   * With a configuration in effect, the steps are:
   *  - apply its `preProcessor` to the input, if one is set;
   *  - call `predictImage` with its `batchPerPartition` and `featurePaddingParam`;
   *  - apply its `postProcessor` to the prediction result, if one is set.
   *
   * @param image     The images to run prediction on.
   * @param configure Optional configuration overriding the one stored on this model.
   *                  Default is null.
   * @return The prediction result wrapped as an [[ImageSet]].
   */
  def predictImageSet(image: ImageSet, configure: ImageConfigure[T] = null):
  ImageSet = {
    // An explicitly supplied configure takes precedence over the stored one.
    Option(configure).orElse(Option(config)) match {
      case None =>
        ImageSet.fromImageFrame(predictImage(image.toImageFrame()))

      case Some(effective) =>
        // Pre-processing is optional: a null preProcessor leaves the input untouched.
        val prepared = effective.preProcessor match {
          case null => image
          case pre => image -> pre
        }

        val predicted = ImageSet.fromImageFrame(
          predictImage(
            prepared.toImageFrame(),
            batchPerPartition = effective.batchPerPartition,
            featurePaddingParam = effective.featurePaddingParam))

        // Post-processing is optional as well.
        effective.postProcessor match {
          case null => predicted
          case post => predicted -> post
        }
    }
  }

  /**
   * @return The configuration stored on this model, or null when none has been assigned.
   */
  def getConfig(): ImageConfigure[T] = config

}

object ImageModel {

  val logger: Logger = Logger.getLogger(getClass)

  /**
   * Load a pre-trained image model (with weights).
   *
   * The module is read with `Module.loadModule`. If the loaded module already is an
   * [[ImageModel]] it is returned as is; otherwise it is wrapped in a new model chosen by
   * `modelType`, which also takes over the loaded module's name. In both cases the
   * model's configuration is set to `ImageConfigure.parse` of the loaded module's name.
   *
   * @param path The path for the pre-trained model.
   *             Local file system, HDFS and Amazon S3 are supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx".
   *             Amazon S3 path should be like "s3a://bucket/xxx".
   * @param weightPath The path for pre-trained weights if any. Default is null.
   * @param modelType The kind of wrapper to create when the loaded module is not already an
   *                  [[ImageModel]]: "objectdetection" or "imageclassification"
   *                  (case-insensitive). Ignored when the loaded module is an [[ImageModel]].
   *                  Default is "".
   * @tparam T Numeric type of parameter(e.g. weight, bias). Only support float/double now.
   * @return The loaded model, with its configuration derived from the model name.
   * @throws IllegalArgumentException if the loaded module is not an [[ImageModel]] and
   *                                  `modelType` is not one of the supported values.
   */
  def loadModel[T: ClassTag](path: String, weightPath: String = null, modelType: String = "")
    (implicit ev: TensorNumeric[T]): ImageModel[T] = {
    val model = Module.loadModule[T](path, weightPath)
    val imageModel: ImageModel[T] = model match {
      case existing: ImageModel[T] => existing
      case _ =>
        // Locale.ROOT keeps the comparison independent of the JVM's default locale
        // (e.g. the Turkish dotless-i mapping); a null modelType is treated as unknown
        // so that it is reported through the IllegalArgumentException below.
        val normalizedType =
          Option(modelType).getOrElse("").toLowerCase(java.util.Locale.ROOT)
        val specificModel = normalizedType match {
          case "objectdetection" => new ObjectDetector[T]()
          case "imageclassification" => new ImageClassifier[T]()
          case _ =>
            // Build the message once so the log entry and the exception always agree.
            val message = s"model type $modelType is not defined in Analytics zoo. " +
              "Only 'imageclassification' and 'objectdetection' are currently supported."
            logger.error(message)
            throw new IllegalArgumentException(message)
        }
        specificModel.addModel(model)
        specificModel.setName(model.getName())
    }
    imageModel.config = ImageConfigure.parse(model.getName())
    imageModel
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy