
com.johnsnowlabs.nlp.ImageAssembler.scala

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp.annotators.cv.util.schema.ImageSchemaUtils
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}

/** Prepares images read by Spark into a format that Spark NLP can process. This component is
  * required for pipelines that work with images.
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.ImageAssembler
  * import org.apache.spark.ml.Pipeline
  *
  * val imageDF: DataFrame = spark.read
  *   .format("image")
  *   .option("dropInvalid", value = true)
  *   .load("src/test/resources/image/")
  *
  * val imageAssembler = new ImageAssembler()
  *   .setInputCol("image")
  *   .setOutputCol("image_assembler")
  *
  * val pipeline = new Pipeline().setStages(Array(imageAssembler))
  * val pipelineDF = pipeline.fit(imageDF).transform(imageDF)
  * pipelineDF.printSchema()
  * root
  *  |-- image_assembler: array (nullable = true)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- origin: string (nullable = true)
  *  |    |    |-- height: integer (nullable = false)
  *  |    |    |-- width: integer (nullable = false)
  *  |    |    |-- nChannels: integer (nullable = false)
  *  |    |    |-- mode: integer (nullable = false)
  *  |    |    |-- result: binary (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  * }}}
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ImageAssembler(override val uid: String)
    extends Transformer
    with DefaultParamsWritable
    with HasOutputAnnotatorType
    with HasOutputAnnotationCol {

  /** Output Annotator Type: IMAGE
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = IMAGE

  /** Input image column for processing
    *
    * @group param
    */
  val inputCol: Param[String] =
    new Param[String](this, "inputCol", "input image column for processing")

  /** Input image column for processing
    *
    * @group setParam
    */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Input image column for processing
    *
    * @group getParam
    */
  def getInputCol: String = $(inputCol)

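  // Defaults: the input column is Spark's "image" column and the output column is "image_assembler".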
  setDefault(inputCol -> IMAGE, outputCol -> "image_assembler")

  def this() = this(Identifiable.randomUID("ImageAssembler"))

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

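  /** Converts a decoded image row into a single [[AnnotationImage]]; returns an empty Seq when no image is present. */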
  private[nlp] def assemble(
      image: Option[ImageFields],
      metadata: Map[String, String]): Seq[AnnotationImage] = {

    if (image.isDefined) {
      Seq(
        AnnotationImage(
          annotatorType = outputAnnotatorType,
          origin = image.get.origin,
          height = image.get.height,
          width = image.get.width,
          nChannels = image.get.nChannels,
          mode = image.get.mode,
          result = image.get.data,
          metadata = metadata))
    } else Seq.empty

  }

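  /** Wraps [[assemble]] as a Spark UDF over a struct column that follows Spark's ImageSchema. */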
  private[nlp] def dfAssemble: UserDefinedFunction = udf { (image: ImageFields) =>
    // Apache Spark has only 1 image per row
    assemble(Some(image), Map("image" -> "0"))
  }

  /** Requirement for pipeline transformation validation; called on fit(). */
  override final def transformSchema(schema: StructType): StructType = {
    val metadataBuilder: MetadataBuilder = new MetadataBuilder()
    metadataBuilder.putString("annotatorType", outputAnnotatorType)
    val outputFields = schema.fields :+
      StructField(
        getOutputCol,
        ArrayType(AnnotationImage.dataType),
        nullable = false,
        metadataBuilder.build)
    StructType(outputFields)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    val metadataBuilder: MetadataBuilder = new MetadataBuilder()
    metadataBuilder.putString("annotatorType", outputAnnotatorType)
    require(
      dataset.schema.fields.exists(_.name == getInputCol),
      s"column $getInputCol is not present in your DataFrame")
    require(
      ImageSchemaUtils.isImage(dataset.schema(getInputCol)),
      s"column $getInputCol does not conform to Apache Spark's ImageSchema. Make sure you read your images via spark.read.format(\"image\").load(path)")

    val imageAnnotations = {
      dfAssemble(dataset($(inputCol)))
    }

    dataset.withColumn(getOutputCol, imageAnnotations.as(getOutputCol, metadataBuilder.build))
  }

}

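/** Mirrors the struct fields of Apache Spark's ImageSchema (origin, height, width, nChannels, mode, data). */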
private[nlp] case class ImageFields(
    origin: String,
    height: Int,
    width: Int,
    nChannels: Int,
    mode: Int,
    data: Array[Byte])

/** This is the companion object of [[ImageAssembler]]. Please refer to that class for the
  * documentation.
  */
object ImageAssembler extends DefaultParamsReadable[ImageAssembler]
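
Below is a minimal sketch of persisting and reloading the assembler, relying on the DefaultParamsWritable / DefaultParamsReadable mix-ins shown above. The stage configuration and the save path are illustrative assumptions, not part of this source.

import com.johnsnowlabs.nlp.ImageAssembler

// Configure the assembler as in the scaladoc example above.
val imageAssembler = new ImageAssembler()
  .setInputCol("image")
  .setOutputCol("image_assembler")

// DefaultParamsWritable provides write/save; the path is a placeholder.
imageAssembler.write.overwrite().save("/tmp/image_assembler_stage")

// DefaultParamsReadable on the companion object provides load.
val restored = ImageAssembler.load("/tmp/image_assembler_stage")
assert(restored.getInputCol == "image")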



