All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.azure.synapse.ml.opencv.ImageTransformer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.opencv

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.schema.{BinaryFileSchema, ImageSchemaUtils}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.{ArrayMapParam, DataTypeParam}
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.ml.{ComplexParamsWritable, ImageInjections, Transformer}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse
import org.opencv.core._
import org.opencv.imgproc.Imgproc

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

//scalastyle:off field.name
/** Image processing stage.
  *
  * @param params Map of parameters
  */
abstract class ImageTransformerStage(params: Map[String, Any]) extends Serializable {
  def apply(image: Mat): Mat

  val stageName: String
}

object ImageTransformerStage {
  // every stage has a name like "resize", "normalize", "unroll"
  val stageNameKey = "action"

  def apply(stage: Map[String, Any]): ImageTransformerStage = {
    stage(stageNameKey) match {
      case ResizeImage.stageName => new ResizeImage(stage)
      case CropImage.stageName => new CropImage(stage)
      case ColorFormat.stageName => new ColorFormat(stage)
      case Blur.stageName => new Blur(stage)
      case Threshold.stageName => new Threshold(stage)
      case GaussianKernel.stageName => new GaussianKernel(stage)
      case Flip.stageName => new Flip(stage)
      case CenterCropImage.stageName => new CenterCropImage(stage)
      case unsupported: String => throw new IllegalArgumentException(s"unsupported transformation $unsupported")
    }
  }
}

/** Resizes the image. The parameters of the ParameterMap are:
  * "height" - the height of the resized image
  * "width" - the width of the resized image
  * "stageName"
  * "size" - the shorter side of the resized image if keep aspect ratio is true, otherwise,
  *          the side length of both height and width.
  * "keepAspectRatio" - if true, then the shorter side will be resized to "size" parameter
  * Please refer to [[http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html#resize OpenCV]]
  * for more information
  *
  * @param params ParameterMap of the parameters
  */
class ResizeImage(params: Map[String, Any]) extends ImageTransformerStage(params) {
  override val stageName: String = ResizeImage.stageName

  override def apply(image: Mat): Mat = {
    val resized = new Mat()

    val sz = if (params.isDefinedAt(ResizeImage.size)) {
      val specifiedSize = params(ResizeImage.size).asInstanceOf[Int]
      if (params(ResizeImage.keepAspectRatio).asInstanceOf[Boolean]) {
        val (originalWidth, originalHeight) = (image.width, image.height)
        val shorterSize = math.min(originalWidth, originalHeight)
        val ratio = 1.0 * specifiedSize / shorterSize
        val (targetWidth, targetHeight) = (math.round(ratio * originalWidth), math.round(ratio * originalHeight))
        new Size(targetWidth, targetHeight)
      } else {
        new Size(specifiedSize, specifiedSize)
      }
    } else {
      val height: Double = params(ResizeImage.height).asInstanceOf[Int].toDouble
      val width: Double = params(ResizeImage.width).asInstanceOf[Int].toDouble
      new Size(width, height)
    }

    Imgproc.resize(image, resized, sz)
    resized
  }
}

/** Resize object contains the information for resizing;
  * "height"
  * "width"
  * "stageName" = "resize"
  */
object ResizeImage {
  val stageName = "resize"
  val height = "height"
  val width = "width"
  val size = "size"
  val keepAspectRatio = "keepAspectRatio"
}

/** Crops the image for processing. The parameters are:
  * "x" - First dimension; start of crop
  * "y" - second dimension - start of crop
  * "height" -height of cropped image
  * "width" - width of cropped image
  * "stageName" - "crop"
  *
  * @param params ParameterMap of the dimensions for cropping
  */
class CropImage(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val x: Int = params(CropImage.x).asInstanceOf[Int]
  val y: Int = params(CropImage.y).asInstanceOf[Int]
  val height: Int = params(CropImage.height).asInstanceOf[Int]
  val width: Int = params(CropImage.width).asInstanceOf[Int]
  override val stageName: String = CropImage.stageName

  override def apply(image: Mat): Mat = {
    val rect = new Rect(x, y, width, height)
    new Mat(image, rect)
  }
}

object CropImage {
  val stageName = "crop"
  val x = "x"
  val y = "y"
  val height = "height"
  val width = "width"
}

class CenterCropImage(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val height: Int = params(CropImage.height).asInstanceOf[Int]
  val width: Int = params(CropImage.width).asInstanceOf[Int]

  override val stageName: String = CenterCropImage.stageName

  override def apply(image: Mat): Mat = {
    val (cropWidth, cropHeight) = (math.min(width, image.width), math.min(height, image.height))
    val (midX, midY) = (image.width / 2, image.height / 2)
    val rect = new Rect(midX - cropWidth / 2, midY - cropHeight / 2, cropWidth, cropHeight)
    new Mat(image, rect)
  }
}

object CenterCropImage {
  val stageName = "centercrop"
  val height = "height"
  val width = "width"
}

/** Converts an image from one color space to another, eg COLOR_BGR2GRAY. Refer to
  * [[http://docs.opencv.org/2.4/modules/imgproc/doc/miscellaneous_transformations.html#cvtcolor OpenCV]]
  * for more information.
  *
  * @param params Map of parameters and values
  */
class ColorFormat(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val format: Int = params(ColorFormat.format).asInstanceOf[Int]
  override val stageName: String = ColorFormat.stageName

  override def apply(image: Mat): Mat = {
    val dst = new Mat()
    Imgproc.cvtColor(image, dst, format)
    dst
  }
}

object ColorFormat {
  val stageName = "colorformat"
  val format = "format"
}

/** Flips the image
  *
  * @param params Map of parameters and values
  */
class Flip(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val flipCode: Int = params(Flip.flipCode).asInstanceOf[Int]
  override val stageName: String = Flip.stageName

  override def apply(image: Mat): Mat = {
    val dst = new Mat()
    Core.flip(image, dst, flipCode)
    dst
  }
}

object Flip {
  val stageName: String = "flip"
  val flipCode: String = "flipCode"

  val flipUpDown: Int = 0
  val flipLeftRight: Int = 1
  val flipBoth: Int = -1
}

/** Blurs the image using a box filter.
  * The params are a map of the dimensions of the blurring box. Please refer to
  * [[http://docs.opencv.org/2.4/modules/imgproc/doc/filtering.html#blur OpenCV]] for more information.
  *
  * @param params Map of parameters and values
  */
class Blur(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val height: Double = params(Blur.height).asInstanceOf[Double]
  val width: Double = params(Blur.width).asInstanceOf[Double]
  override val stageName: String = Blur.stageName

  override def apply(image: Mat): Mat = {
    val dst = new Mat()
    Imgproc.blur(image, dst, new Size(height, width))
    dst
  }
}

object Blur {
  val stageName: String = "blur"
  val height: String = "height"
  val width: String = "width"
}

/** Applies a threshold to each element of the image. Please refer to
  * [[http://docs.opencv.org/2.4/modules/imgproc/doc/miscellaneous_transformations.html#threshold threshold]] for
  * more information
  *
  * @param params Map of parameters and values
  */
class Threshold(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val threshold: Double = params(Threshold.threshold).asInstanceOf[Double]
  val maxVal: Double = params(Threshold.maxVal).asInstanceOf[Double]
  // EG Imgproc.THRESH_BINARY
  val thresholdType: Int = params(Threshold.thresholdType).asInstanceOf[Int]
  override val stageName: String = Threshold.stageName

  override def apply(image: Mat): Mat = {
    val dst = new Mat()
    Imgproc.threshold(image, dst, threshold, maxVal, thresholdType)
    dst
  }
}

object Threshold {
  val stageName: String = "threshold"
  val threshold: String = "threshold"
  val maxVal: String = "maxVal"
  val thresholdType: String = "type"
}

/** Applies gaussian kernel to blur the image. Please refer to
  * [[http://docs.opencv.org/2.4/modules/imgproc/doc/filtering.html#gaussianblur OpenCV]] for detailed information
  * about the parameters and their allowable values.
  *
  * @param params Map of parameter values containg the aperture and sigma for the kernel.
  */
class GaussianKernel(params: Map[String, Any]) extends ImageTransformerStage(params) {
  val appertureSize: Int = params(GaussianKernel.apertureSize).asInstanceOf[Int]
  val sigma: Double = params(GaussianKernel.sigma) match {
    case d: Double => d
    case i: Int => i.toDouble
  }
  override val stageName: String = GaussianKernel.stageName

  override def apply(image: Mat): Mat = {
    val dst = new Mat()
    val kernel = Imgproc.getGaussianKernel(appertureSize, sigma)
    Imgproc.filter2D(image, dst, -1, kernel)
    dst
  }
}

object GaussianKernel {
  val stageName: String = "gaussiankernel"
  val apertureSize: String = "apertureSize"
  val sigma: String = "sigma"
}

/** Pipelined image processing. */
object ImageTransformer extends DefaultParamsReadable[ImageTransformer] {

  override def load(path: String): ImageTransformer = super.load(path)

  /** Convert Spark image representation to OpenCV format. */
  private def row2mat(row: Row): (String, Mat) = {
    val path = ImageSchema.getOrigin(row)
    val height = ImageSchema.getHeight(row)
    val width = ImageSchema.getWidth(row)
    val ocvType = ImageSchema.getMode(row)
    val bytes = ImageSchema.getData(row)

    val img = new Mat(height, width, ocvType)
    img.put(0, 0, bytes)
    (path, img)
  }

  /** Convert from OpenCV format to Dataframe Row; unroll if needed. */
  private def mat2row(img: Mat, path: String = ""): Row = {
    val ocvBytes = new Array[Byte](img.total.toInt * img.elemSize.toInt)
    img.get(0, 0, ocvBytes) //extract OpenCV bytes
    Row(path, img.height, img.width, img.channels(), img.`type`, ocvBytes)
  }

  /**
   * Convert Spark image representation to OpenCV format.
   */
  def decodeImage(decodeMode: String)(r: Any): Option[(String, Mat)] = {
    Option(r).flatMap {
      row =>
        (row, decodeMode) match {
          case (row: Row, "binaryfile") =>
            val path = BinaryFileSchema.getPath(row)
            val bytes = BinaryFileSchema.getBytes(row)
            //early return if the image can't be decompressed
            ImageInjections.decode(path, bytes).map(_.getStruct(0))
          case (bytes: Array[Byte], "binary") =>
            //noinspection ScalaStyle
            ImageInjections.decode(null, bytes).map(_.getStruct(0))  //scalastyle:ignore null
          case (row: Row, "image") =>
            Some(row)
          case (_, mode) =>
            throw new MatchError(s"Unknown decoder mode $mode")
        }
    } map row2mat
  }

  /**
   * Apply all OpenCV transformation stages to a single image. Break on OpenCV errors.
   */
  def processImage(stages: Seq[ImageTransformerStage])(image: Mat): Mat = {
    stages.foldLeft(image) {
      case (imgInternal, stage) => stage.apply(imgInternal)
    }
  }

  /**
   * Extract channels from image.
   */
  def extractChannels(channelOrder: String, autoConvertToColor: Boolean)(image: Mat): Array[Mat] = {
    // OpenCV channel order is BGR - reverse the order if the intended order is RGB.
    // Also remove alpha channel if nChannels is 4.
    val channelOrderIsRgb = channelOrder.toLowerCase == "rgb"
    val converted = if (image.channels == 4) {
      // remove alpha channel and order color channels if necessary
      val dest = new Mat(image.rows, image.cols, CvType.CV_8UC3)
      val colorConversion = if (channelOrderIsRgb) Imgproc.COLOR_BGRA2RGB else Imgproc.COLOR_BGRA2BGR
      Imgproc.cvtColor(image, dest, colorConversion)
      dest
    } else if (image.channels == 3 && channelOrderIsRgb) {
      // Reorder channel if nChannel is 3 and intended tensor channel order is RGB.
      val dest = new Mat(image.rows, image.cols, CvType.CV_8UC3)
      Imgproc.cvtColor(image, dest, Imgproc.COLOR_BGR2RGB)
      dest
    } else if (image.channels == 1 && autoConvertToColor) {
      // Duplicate channels if nChannel is 1 and user indicated to auto-convert.
      val dest = new Mat(image.rows, image.cols, CvType.CV_8UC3)
      val colorConversion = if (channelOrderIsRgb) Imgproc.COLOR_GRAY2RGB else Imgproc.COLOR_GRAY2BGR
      Imgproc.cvtColor(image, dest, colorConversion)
      dest
    } else {
      image
    }

    val channelLength = converted.channels
    val channelMats = ListBuffer.fill(channelLength)(Mat.zeros(converted.rows, converted.cols, CvType.CV_8U))
    Core.split(converted, channelMats.asJava)

    channelMats.toArray
  }

  /**
   * Normalize each channel.
   */
  def normalizeChannels(means: Option[Array[Double]], stds: Option[Array[Double]], scaleFactor: Option[Double])
                       (channels: Array[Mat]): Array[Mat] = {
    val channelLength = channels.length
    val meansLength = if (means.isDefined) means.get.length else -1
    val stdLength = if (stds.isDefined) stds.get.length else -1
    require(means.forall(channelLength == _.length), s"channelLength: $channelLength, means length: $meansLength")
    require(stds.forall(channelLength == _.length), s"channelLength: $channelLength, stds length: $stdLength")

    channels
      .zip(means.getOrElse(Array.fill(channelLength)(0d)))
      .zip(stds.getOrElse(Array.fill(channelLength)(1d)))
      .map {
        case ((matrix: Mat, m: Double), sd: Double) =>
          val t = new Mat(matrix.rows, matrix.cols, CvType.CV_64F)
          matrix.convertTo(t, CvType.CV_64F)
          Core.multiply(t, new Scalar(scaleFactor.getOrElse(1d)), t) // Standardized
          Core.subtract(t, new Scalar(m), t) // Centered
          Core.divide(t, new Scalar(sd), t) // Normalized
          t
      }
  }

  private def to2DArray(m: Mat): Array[Array[Double]] = {
    val array = Array.ofDim[Double](m.rows, m.cols)
    array.indices foreach {
      i => m.get(i, 0, array(i))
    }

    array
  }

  /**
   * Convert channel matrices to tensor in the shape of (C * H * W)
   */
  def convertToTensor(matrices: Array[Mat]): Array[Array[Array[Double]]] = {
    matrices.map(to2DArray)
  }

  /**
   *  Convert from OpenCV format to Dataframe Row.
   */
  def encodeImage(path: String, image: Mat): Row = {
    mat2row(image, path)
  }
}

/** Image processing stage. Please refer to OpenCV for additional information
  *
  * @param uid The id of the module
  */
class ImageTransformer(val uid: String) extends Transformer
  with HasInputCol with HasOutputCol with Wrappable with ComplexParamsWritable with SynapseMLLogging {
  logClass(FeatureNames.OpenCV)

  import ImageTransformer._
  import ImageTransformerStage._

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("ImageTransformer"))

  val stages: ArrayMapParam = new ArrayMapParam(this, "stages", "Image transformation stages")

  def setStages(value: Array[Map[String, Any]]): this.type = set(stages, value)

  def setStages(value: java.util.ArrayList[java.util.HashMap[String, Any]]): this.type =
    set(stages, value.asScala.toArray.map(_.asScala.toMap))

  def setStages(jsonString: String): this.type = {
    implicit val formats: DefaultFormats.type = DefaultFormats
    this.setStages(parse(jsonString).extract[Array[Map[String, Any]]])
  }

  val emptyStages: Array[Map[String, Any]] = Array[Map[String, Any]]()

  def getStages: Array[Map[String, Any]] = if (isDefined(stages)) $(stages) else emptyStages

  private def addStage(stage: Map[String, Any]): this.type = set(stages, getStages :+ stage)

  val toTensor: BooleanParam = new BooleanParam(
    this,
    "toTensor",
    "Convert output image to tensor in the shape of (C * H * W)"
  )

  def getToTensor: Boolean = $(toTensor)
  def setToTensor(value: Boolean): this.type = this.set(toTensor, value)

  val ignoreDecodingErrors: BooleanParam = new BooleanParam(
    this,
    "ignoreDecodingErrors",
    "Whether to throw on decoding errors or just return null"
  )
  setDefault(ignoreDecodingErrors -> false)

  def getIgnoreDecodingErrors: Boolean = $(ignoreDecodingErrors)
  def setIgnoreDecodingErrors(value: Boolean): this.type = this.set(ignoreDecodingErrors, value)

  @transient
  private lazy val validElementTypes: Array[DataType] = Array(FloatType, DoubleType)
  val tensorElementType: DataTypeParam = new DataTypeParam(
    parent = this,
    name = "tensorElementType",
    doc = "The element data type for the output tensor. Only used when toTensor is set to true. " +
      "Valid values are DoubleType or FloatType. Default value: FloatType.",
    isValid = ParamValidators.inArray(validElementTypes)
  )

  def getTensorElementType: DataType = $(tensorElementType)
  def setTensorElementType(value: DataType): this.type = this.set(tensorElementType, value)

  val tensorChannelOrder: Param[String] = new Param[String](
    parent = this,
    name = "tensorChannelOrder",
    doc = "The color channel order of the output channels. Valid values are RGB and GBR. Default: RGB.",
    isValid = ParamValidators.inArray(Array("rgb", "RGB", "bgr", "BGR"))
  )

  def getTensorChannelOrder: String = $(tensorChannelOrder)
  def setTensorChannelOrder(value: String): this.type = this.set(tensorChannelOrder, value)

  val normalizeMean: DoubleArrayParam = new DoubleArrayParam(
    this,
    "normalizeMean",
    "The mean value to use for normalization for each channel. " +
      "The length of the array must match the number of channels of the input image."
  )

  def getNormalizeMean: Array[Double] = $(normalizeMean)
  def setNormalizeMean(value: Array[Double]): this.type = this.set(normalizeMean, value)

  val normalizeStd: DoubleArrayParam = new DoubleArrayParam(
    this,
    "normalizeStd",
    "The standard deviation to use for normalization for each channel. " +
      "The length of the array must match the number of channels of the input image."
  )

  def getNormalizeStd: Array[Double] = $(normalizeStd)
  def setNormalizeStd(value: Array[Double]): this.type = this.set(normalizeStd, value)

  val colorScaleFactor: DoubleParam = new DoubleParam(
    this,
    "colorScaleFactor",
    "The scale factor for color values. Used for normalization. " +
      "The color values will be multiplied with the scale factor.",
    ParamValidators.gt(0d)
  )

  def getColorScaleFactor: Double = $(colorScaleFactor)
  def setColorScaleFactor(value: Double): this.type = this.set(colorScaleFactor, value)

  val autoConvertToColor: BooleanParam = new BooleanParam(
    this,
    "autoConvertToColor",
    "Whether to automatically convert black and white images to color"
  )
  setDefault(autoConvertToColor -> false)

  def getAutoConvertToColor: Boolean = $(autoConvertToColor)
  def setAutoConvertToColor(value: Boolean): this.type = this.set(autoConvertToColor, value)

  setDefault(
    inputCol -> "image",
    outputCol -> (uid + "_output"),
    toTensor -> false,
    tensorChannelOrder -> "RGB",
    tensorElementType -> FloatType
  )

  def normalize(mean: Array[Double], std: Array[Double], colorScaleFactor: Double): this.type = {
    this
      .setToTensor(true)
      .setNormalizeMean(mean)
      .setNormalizeStd(std)
      .setColorScaleFactor(colorScaleFactor)
  }

  /**
   * For py4j invocation.
   */
  def normalize(mean: java.util.List[Double], std: java.util.List[Double], colorScaleFactor: Double): this.type = {
    this
      .setToTensor(true)
      .setNormalizeMean(mean.asScala.toArray)
      .setNormalizeStd(std.asScala.toArray)
      .setColorScaleFactor(colorScaleFactor)
  }

  def resize(height: Int, width: Int): this.type = {
    require(width >= 0 && height >= 0, "width and height should be non-negative")

    addStage(Map(stageNameKey -> ResizeImage.stageName,
      ResizeImage.width -> width,
      ResizeImage.height -> height))
  }

  /**
   * If keep aspect ratio is set to true, the shorter side of the image will be resized to the specified size.
   */
  def resize(size: Int, keepAspectRatio: Boolean): this.type = {
    require(size >= 0, "size should be non-negative")
    addStage(Map(stageNameKey -> ResizeImage.stageName,
      ResizeImage.size -> size,
      ResizeImage.keepAspectRatio -> keepAspectRatio
    ))
  }

  def crop(x: Int, y: Int, height: Int, width: Int): this.type = {
    require(x >= 0 && y >= 0 && width >= 0 && height >= 0, "crop values should be non-negative")

    addStage(Map(stageNameKey -> CropImage.stageName,
      CropImage.width -> width,
      CropImage.height -> height,
      CropImage.x -> x,
      CropImage.y -> y))
  }

  def centerCrop(height: Int, width: Int): this.type = {
    require(width >= 0 && height >= 0, "crop values should be non-negative")

    addStage(
      Map(
        stageNameKey -> CenterCropImage.stageName,
        CenterCropImage.width -> width,
        CenterCropImage.height -> height
      )
    )
  }

  def colorFormat(format: Int): this.type = {
    addStage(Map(stageNameKey -> ColorFormat.stageName, ColorFormat.format -> format))
  }

  def blur(height: Double, width: Double): this.type = {
    addStage(Map(stageNameKey -> Blur.stageName, Blur.height -> height, Blur.width -> width))
  }

  def threshold(threshold: Double, maxVal: Double, thresholdType: Int): this.type = {
    addStage(Map(stageNameKey -> Threshold.stageName,
      Threshold.maxVal -> maxVal,
      Threshold.threshold -> threshold,
      Threshold.thresholdType -> thresholdType))
  }

  /** Flips the image
    *
    * @param flipCode is a flag to specify how to flip the image:
    * - 0 means flipping around the x-axis (i.e. up-down)
    * - positive value (for example, 1) means flipping around y-axis (left-right)
    * - negative value (for example, -1) means flipping around both axes (diagonally)
    *                 See OpenCV documentation for details.
    * @return
    */
  def flip(flipCode: Int): this.type = {
    addStage(Map(stageNameKey -> Flip.stageName, Flip.flipCode -> flipCode))
  }

  def gaussianKernel(apertureSize: Int, sigma: Double): this.type = {
    addStage(Map(stageNameKey -> GaussianKernel.stageName,
      GaussianKernel.apertureSize -> apertureSize,
      GaussianKernel.sigma -> sigma))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({

      //  load native OpenCV library on each partition
      // TODO: figure out more elegant way
      val df = OpenCVUtils.loadOpenCV(dataset.toDF)
      val inputDataType = df.schema(getInputCol).dataType
      val decodeMode = getDecodeType(inputDataType)

      val transforms = getStages.map(ImageTransformerStage.apply)

      val outputColumnSchema = if ($(toTensor)) tensorUdfSchema else imageColumnSchema
      val processStep = processImage(transforms) _
      val extractStep = extractChannels(getTensorChannelOrder, getAutoConvertToColor) _
      val normalizeStep = normalizeChannels(get(normalizeMean), get(normalizeStd), get(colorScaleFactor)) _
      val toTensorStep = convertToTensor _

      val convertFunc = if ($(toTensor)) {
        inputRow: Any =>
          getDecodedImage(decodeMode)(inputRow) map {
            case (_, image) =>
              processStep
                .andThen(extractStep)
                .andThen(normalizeStep)
                .andThen(toTensorStep)
                .apply(image)
          }
      } else {
        inputRow: Any =>
          getDecodedImage(decodeMode)(inputRow) map {
            case (path, image) =>
              val encodeStep = encodeImage(path, _)
              processStep.andThen(encodeStep).apply(image)
          }
      }

      val convert = UDFUtils.oldUdf(convertFunc, outputColumnSchema)
      if ($(toTensor)) {
        df.withColumn(getOutputCol, convert(df(getInputCol)).cast(tensorColumnSchema))
      } else {
        df.withColumn(getOutputCol, convert(df(getInputCol)))
      }
    }, dataset.columns.length)
  }

  private def getDecodedImage(decodeMode: String)(r: Any): Option[(String, Mat)] = {
    try {
      decodeImage(decodeMode)(r)
    } catch {
      case e: MatchError =>
        throw e
      case e: Throwable =>
        if (getIgnoreDecodingErrors) {
          logWarning("Error decoding image", e)
          None
        } else throw e
    }
  }

  private def getDecodeType(inputDataType: DataType): String = {
    inputDataType match {
      case s if ImageSchemaUtils.isImage(s) => "image"
      case s if BinaryFileSchema.isBinaryFile(s) => "binaryfile"
      case s if s == BinaryType => "binary"
      case s =>
        throw new IllegalArgumentException(s"input column should have Image or BinaryFile type, got $s")
    }
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  private lazy val tensorUdfSchema = ArrayType(ArrayType(ArrayType(DoubleType)))
  private lazy val tensorColumnSchema =  ArrayType(ArrayType(ArrayType($(tensorElementType))))
  private lazy val imageColumnSchema = ImageSchema.columnSchema
  override def transformSchema(schema: StructType): StructType = {
    assert(!schema.fieldNames.contains(getOutputCol), s"Input schema already contains output field $getOutputCol")

    val outputColumnSchema = if ($(toTensor)) tensorColumnSchema else imageColumnSchema
    schema.add(getOutputCol, outputColumnSchema)
  }
}

//scalastyle:on field.name




© 2015 - 2025 Weber Informatics LLC | Privacy Policy