All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.transform.vision.image.FeatureTransformer.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.transform.vision.image

import com.intel.analytics.bigdl.dataset.{ChainedTransformer, Transformer}
import com.intel.analytics.bigdl.opencv.OpenCV
import com.intel.analytics.bigdl.transform.vision.image.opencv.OpenCVMat
import org.apache.log4j.Logger

/**
 * FeatureTransformer is a [[Transformer]] that maps an ImageFeature to an ImageFeature.
 *
 * Subclasses normally override [[transformMat]]; [[transform]] wraps that hook with
 * validity checks, optional output-key storage and error handling.
 */
abstract class FeatureTransformer()
  extends Transformer[ImageFeature, ImageFeature] {

  import FeatureTransformer.logger

  // Optional key under which a copy of the transformed mat is stored on the feature.
  private var outKey: Option[String] = None

  /**
   * if true, catch the exception of the transformer to avoid crashing.
   * if false, interrupt the transformer when error happens
   */
  private[image] var ignoreException: Boolean = false

  /**
   * Set the output key under which a copy of the current transformed result is stored.
   * If no key is set, the result is not copied anywhere and will be overwritten by the
   * following transformer. The key must differ from `ImageFeature.mat`; otherwise
   * [[transform]] fails its requirement check.
   *
   * @param key output key
   * @return this transformer, to allow call chaining
   */
  def setOutKey(key: String): this.type = {
    outKey = Some(key)
    this
  }

  /**
   * Transform the mat held by the feature. The default implementation does nothing;
   * subclasses override it to do the actual work.
   *
   * @param feature ImageFeature to modify
   */
  protected def transformMat(feature: ImageFeature): Unit = {}

  /**
   * Transform a single feature.
   *
   * A feature whose `isValid` flag is false is returned untouched. Otherwise
   * [[transformMat]] is applied and, when an output key has been set, the resulting mat is
   * copied into the entry stored under that key (or cloned into a new entry).
   *
   * If an `Exception` is raised along the way the feature is marked invalid; the exception
   * is then logged and swallowed when `ignoreException` is true, and rethrown otherwise.
   *
   * @param feature ImageFeature
   * @return the same ImageFeature instance
   * @throws IllegalArgumentException if OpenCV is not loaded
   */
  def transform(feature: ImageFeature): ImageFeature = {
    require(OpenCV.isOpenCVLoaded, "opencv isn't loaded")
    if (feature.isValid) {
      try {
        transformMat(feature)
        outKey.foreach { key =>
          // Checked inside the try on purpose: a bad key is handled like any other
          // transform failure (feature invalidated, then logged or rethrown).
          require(key != ImageFeature.mat, s"the output key should not equal to" +
            s" ${ImageFeature.mat}, please give another name")
          if (feature.contains(key)) {
            // Reuse the mat already stored under the key instead of allocating a new one.
            val target = feature[OpenCVMat](key)
            feature.opencvMat().copyTo(target)
          } else {
            feature(key) = feature.opencvMat().clone()
          }
        }
      } catch {
        case e: Exception =>
          feature.isValid = false
          if (!ignoreException) throw e
          // Explicit type argument: without it the result type has nothing to be inferred
          // from here (presumably resolving to Nothing, which would fail at runtime).
          val path: Any =
            if (feature.contains(ImageFeature.uri)) feature[Any](ImageFeature.uri) else ""
          // Hand the exception to the logger so the stack trace lands next to the message.
          logger.warn(s"failed ${path} in transformer ${getClass}", e)
      }
    }
    feature
  }

  /**
   * Apply [[transform]] to every feature of the iterator.
   *
   * @param prev incoming features
   * @return iterator over the transformed features
   */
  override def apply(prev: Iterator[ImageFeature]): Iterator[ImageFeature] = {
    prev.map(transform)
  }

  /**
   * Apply this transformer to an ImageFrame by delegating to `imageFrame.transform`.
   *
   * @param imageFrame frame to transform
   * @return the ImageFrame returned by `imageFrame.transform`
   */
  def apply(imageFrame: ImageFrame): ImageFrame = {
    imageFrame.transform(this)
  }

  /**
   * Chain this transformer with another FeatureTransformer.
   *
   * @param other transformer to run after this one
   * @return a [[ChainedFeatureTransformer]] running this, then `other`
   */
  // scalastyle:off methodName
  // scalastyle:off noSpaceBeforeLeftBracket
  def -> (other: FeatureTransformer): FeatureTransformer = {
    new ChainedFeatureTransformer(this, other)
  }

  /**
   * Chain this transformer with a generic Transformer consuming ImageFeature.
   *
   * @param other transformer to run after this one
   * @return a ChainedTransformer running this, then `other`
   */
  // scalastyle:off methodName
  // scalastyle:off noSpaceBeforeLeftBracket
  override def -> [C](other: Transformer[ImageFeature, C]): Transformer[ImageFeature, C] = {
    new ChainedTransformer(this, other)
  }

  /**
   * catch the exception of the transformer to avoid crashing.
   *
   * @return this transformer, to allow call chaining
   */
  def enableIgnoreException(): this.type = {
    ignoreException = true
    this
  }
}

/**
 * Companion object holding the logger used by [[FeatureTransformer]] instances.
 */
object FeatureTransformer {
  /** log4j logger obtained for this companion object's class. */
  val logger: Logger = Logger.getLogger(this.getClass)
}

/**
 * A transformer that chains two FeatureTransformers together: `first` runs on the
 * feature, then `last` runs on the result.
 *
 * Because [[transform]] is overridden, per-transformer settings are applied by
 * forwarding them to the wrapped transformers rather than to the chain itself.
 *
 * @param first transformer applied first
 * @param last transformer applied to the output of `first`
 */
class ChainedFeatureTransformer(first: FeatureTransformer, last: FeatureTransformer) extends
  FeatureTransformer {

  /**
   * Run `first` and then `last` on the feature.
   *
   * @param prev feature to transform
   * @return the feature returned by `last`
   */
  override def transform(prev: ImageFeature): ImageFeature = {
    last.transform(first.transform(prev))
  }

  /**
   * Set the output key for the chain's result.
   *
   * The overridden [[transform]] never consults this chain's own output key, so the key
   * is forwarded to `last`, whose output is the output of the whole chain. This replaces
   * any output key previously set on `last`.
   *
   * @param key output key
   * @return this transformer, to allow call chaining
   */
  override def setOutKey(key: String): this.type = {
    last.setOutKey(key)
    this
  }

  /**
   * Enable exception ignoring on both wrapped transformers.
   *
   * @return this transformer, to allow call chaining
   */
  override def enableIgnoreException(): this.type = {
    first.enableIgnoreException()
    last.enableIgnoreException()
    this
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy