All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.lime.SuperpixelTransformer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.lime

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.core.schema.ImageSchemaUtils
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{BinaryType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

/** Companion for [[SuperpixelTransformer]].
  *
  * Mixing in `DefaultParamsReadable` supplies the standard Spark ML reader, so persisted
  * transformers can be restored via `SuperpixelTransformer.read` / `SuperpixelTransformer.load`.
  */
object SuperpixelTransformer extends DefaultParamsReadable[SuperpixelTransformer] {
}

/** Mixin that contributes the `cellSize` parameter to a Spark ML `Params` component. */
trait HasCellSize extends Params {

  /** Number that controls the size of the superpixels. */
  val cellSize: DoubleParam =
    new DoubleParam(this, "cellSize", "Number that controls the size of the superpixels")

  /** Sets [[cellSize]] and returns this instance for chaining. */
  def setCellSize(v: Double): this.type = set(cellSize, v)

  /** Current (or default) value of [[cellSize]]. */
  def getCellSize: Double = $(cellSize)
}

/** Mixin that contributes the `modifier` parameter to a Spark ML `Params` component. */
trait HasModifier extends Params {

  /** Weight that controls the trade-off between spatial and color distance when
    * forming superpixels.
    */
  val modifier: DoubleParam =
    new DoubleParam(this, "modifier", "Controls the trade-off between spatial and color distance")

  /** Current (or default) value of [[modifier]]. */
  def getModifier: Double = $(modifier)

  /** Sets [[modifier]] and returns this instance for chaining. */
  def setModifier(v: Double): this.type = set(modifier, v)
}

/** A transformer that decomposes an image into its superpixels.
  *
  * The input column must hold either an image struct (as recognized by
  * `ImageSchemaUtils.isImage`) or raw `BinaryType` data; `transformSchema` rejects any other
  * type. The output column holds values with schema `SuperpixelData.Schema` and defaults to
  * `<uid>_output`.
  *
  * `cellSize` (default 16.0) and `modifier` (default 130.0) are forwarded to
  * `Superpixel.getSuperpixelUDF`.
  */
class SuperpixelTransformer(val uid: String) extends Transformer
  with HasInputCol with HasOutputCol
  with Wrappable with DefaultParamsWritable with HasCellSize with HasModifier {
  def this() = this(Identifiable.randomUID("SuperpixelTransformer"))

  setDefault(cellSize -> 16.0, modifier -> 130.0, outputCol -> s"${uid}_output")

  /** Computes the superpixel decomposition of `inputCol` into `outputCol`.
    *
    * Uses `withColumn`, so an existing column named `outputCol` is replaced.
    *
    * @param dataset input data; must contain the column named by `inputCol`
    * @return the input rows with `outputCol` added
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // The UDF is built from the input column's data type (presumably to choose between the
    // image-struct and binary decoding paths — see Superpixel.getSuperpixelUDF).
    val getSuperPixels = Superpixel.getSuperpixelUDF(
      dataset.schema(getInputCol).dataType, getCellSize, getModifier)

    dataset.toDF().withColumn(getOutputCol, getSuperPixels(col(getInputCol)))
  }

  /** Copies this transformer, including its params, overlaid with `extra`. */
  override def copy(extra: ParamMap): SuperpixelTransformer = defaultCopy(extra)

  /** Validates the input column type and appends the output column's schema.
    *
    * @param schema schema of the incoming dataset
    * @return `schema` with `outputCol` of type `SuperpixelData.Schema` appended
    * @throws AssertionError if `inputCol` is neither an image nor a binary column
    */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema(getInputCol).dataType
    assert(ImageSchemaUtils.isImage(inputType) || inputType == BinaryType,
      s"Input column '$getInputCol' must be an image or binary column, " +
        s"but has type ${inputType.simpleString}")
    schema.add(getOutputCol, SuperpixelData.Schema)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy