All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.models.image.common.ImageConfigure.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.models.image.common

import com.intel.analytics.bigdl.dataset.PaddingParam
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.transform.vision.image.{ImageFeature}
import com.intel.analytics.zoo.feature.common.{Preprocessing}
import com.intel.analytics.zoo.models.image.objectdetection.ObjectDetectionConfig
import com.intel.analytics.zoo.models.image.imageclassification.ImageClassificationConfig
import org.apache.log4j.Logger

import scala.reflect.ClassTag

/**
 * Configuration bundle consumed by the image predictor.
 *
 * Every field has a default, so callers may override only what they need.
 *
 * @param preProcessor        preprocessing applied to the ImageFrame before model
 *                            inference; defaults to `null`
 * @param postProcessor       postprocessing applied to the ImageFrame after model
 *                            inference; defaults to `null`
 * @param batchPerPartition   batch size per partition; defaults to 4
 * @param labelMap            label mapping from `Int` key to `String` value;
 *                            defaults to `null`
 * @param featurePaddingParam padding parameter for features when the inputs have
 *                            variant size; defaults to `None`
 * @param ev                  implicit [[TensorNumeric]] evidence for `T`
 * @tparam T element type; a `ClassTag` and a `TensorNumeric` must be available for it
 */
case class ImageConfigure[T: ClassTag](
    preProcessor: Preprocessing[ImageFeature, ImageFeature] = null,
    postProcessor: Preprocessing[ImageFeature, ImageFeature] = null,
    batchPerPartition: Int = 4,
    labelMap: Map[Int, String] = null,
    featurePaddingParam: Option[PaddingParam[T]] = None
)(implicit ev: TensorNumeric[T])

/**
 * Companion of [[ImageConfigure]]: resolves the default configuration of a model
 * from its tag.
 */
object ImageConfigure {
  /** Logger used to report tags for which no default configuration is created. */
  val logger: Logger = Logger.getLogger(getClass)

  /** Separator between the elements of a model tag. */
  val splitter: String = "_"

  /**
   * Get config for each model
   *
   * @param tag If model is published by analytics zoo, tag should be in
   *            'publisher_model_dataset_version' format,
   *            and publisher is required to be 'analytics-zoo'.
   *            If model is not published by analytics zoo, tag does not
   *            need to start with 'analytics-zoo', and format does not need to
   *            be in 'publisher_model_dataset_version' format
   * @return the configuration built by [[ObjectDetectionConfig]] or
   *         [[ImageClassificationConfig]] for the model named in the tag, or
   *         `null` (after logging a warning) when the tag does not start with
   *         'analytics-zoo' or names a model known to neither of them
   * @throws IllegalArgumentException if the tag starts with 'analytics-zoo' but
   *                                  splits into fewer than 4 elements
   */
  def parse[T: ClassTag](tag: String)(implicit ev: TensorNumeric[T]): ImageConfigure[T] = {
    // If tag does not start with "analytics-zoo", it's a third-party model.
    // Don't create default image configuration
    if (!tag.startsWith("analytics-zoo")) {
      logger.warn("Only support to create default image configuration for models published by " +
        "analytics zoo. Third-party models need to pass its own image configuration during " +
        "prediction")
      null
    } else {
      val splits = tag.split(splitter)
      // It's an analytics zoo model. The tag should be in
      // 'publisher_model_dataset_version' format
      require(splits.length >= 4,
        s"tag $tag needs at least 4 elements, publisher, model, dataset, version")
      val model = splits(1)
      val dataset = splits(2)
      val version = splits(3)
      // Lower-case with a fixed locale: the default-locale variant maps 'I' to a
      // dotless 'ı' under e.g. a Turkish locale, which would make the lookups
      // below miss for otherwise valid tags.
      model.toLowerCase(java.util.Locale.ROOT) match {
        case obModel if ObjectDetectionConfig.models.contains(obModel) =>
          ObjectDetectionConfig(obModel, dataset, version)
        case imcModel if ImageClassificationConfig.models.contains(imcModel) =>
          ImageClassificationConfig(imcModel, dataset, version)
        case _ =>
          logger.warn(s"$model is not defined in Analytics zoo.")
          null
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy