All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.azure.synapse.ml.opencv.ImageSetAugmenter.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.opencv

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml._
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql._
import org.apache.spark.sql.types._

object ImageSetAugmenter extends DefaultParamsReadable[ImageSetAugmenter]

class ImageSetAugmenter(val uid: String) extends Transformer
  with HasInputCol with HasOutputCol with DefaultParamsWritable with Wrappable with SynapseMLLogging {
  logClass(FeatureNames.OpenCV)

  def this() = this(Identifiable.randomUID("ImageSetAugmenter"))

  // if the problem is symmetric, the training dataset is supplemented with the "flipped" images
  // by default, we assume the left-right symmetry, but not up-down
  val flipLeftRight = new BooleanParam(this, "flipLeftRight", "Symmetric Left-Right")
  setDefault(flipLeftRight->false)

  def setFlipLeftRight(value: Boolean): this.type = set(flipLeftRight, value)

  def getFlipLeftRight: Boolean = $(flipLeftRight)

  val flipUpDown = new BooleanParam(this, "flipUpDown", "Symmetric Up-Down")
  setDefault(flipUpDown->false)

  def setFlipUpDown(value: Boolean): this.type = set(flipUpDown, value)

  def getFlipUpDown: Boolean = $(flipUpDown)

  setDefault(flipLeftRight -> true, flipUpDown -> false,
    outputCol -> (uid + "_output"), inputCol -> "image")

  //flip images in a given column, keep the rest of the columns intact
  private def flipImages(dataset: Dataset[_], inCol: String, outCol: String, flipCode: Int): DataFrame = {
    val tr = new ImageTransformer()
      .flip(flipCode)
      .setInputCol(inCol)
      .setOutputCol(outCol)

    tr.transform(dataset)
  }

  def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      val df = dataset.toDF
      val dfID = df.withColumn(getOutputCol, new Column(getInputCol))

      val dfLR: Option[DataFrame] =
        if (!getFlipLeftRight) None
        else Some(flipImages(df, getInputCol, getOutputCol, Flip.flipLeftRight))

      val dfUD: Option[DataFrame] =
        if (!getFlipUpDown) None
        else Some(flipImages(df, getInputCol, getOutputCol, Flip.flipUpDown))

      List(dfLR, dfUD).flatten(x => x).foldLeft(dfID) { case (dfl, tdr) => dfl.union(tdr) }
    }, dataset.columns.length)

  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  def transformSchema(schema: StructType): StructType = {
    schema.add(getOutputCol, ImageSchema.columnSchema)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy