All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.ml.source.image.PatchedImageFileFormat.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.ml.source.image

import com.google.common.io.{ByteStreams, Closeables}
import com.microsoft.ml.spark.core.schema.ImageSchemaUtils
import com.microsoft.ml.spark.io.image.ImageUtils
import javax.imageio.ImageIO
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.spark.internal.Logging
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration

/**
 * Spark image data source registered under the short name `patchedImage`.
 *
 * Extends [[ImageFileFormat]] with:
 *  - a reader that wraps image decoding in [[catchFlakiness]], retrying on
 *    `NullPointerException` and treating any other decoding exception as an undecodable image;
 *  - a writer ([[ImageOutputWriter]]) that encodes each row's image to the file named by its
 *    path column (`filenames` by default).
 */
class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logging {

  override def shortName(): String = "patchedImage"

  /**
   * Requires `schema` to be the image schema (strict or nullable variant) extended with a
   * `filenames` string column.
   *
   * @throws IllegalArgumentException if `schema` matches neither accepted layout
   */
  private def verifySchema(schema: StructType): Unit = {
    val target = ImageSchema.imageSchema.add("filenames", StringType)
    val targetNullable = ImageSchemaUtils.ImageSchemaNullable.add("filenames", StringType)
    if (schema != target && schema != targetNullable) {
      // Report both accepted layouts, since either one passes the check above.
      throw new IllegalArgumentException(
        s"Image data source supports: " +
          s"\n\t$target" +
          s"\n\tor:" +
          s"\n\t$targetNullable" +
          s"\n\tyou have :" +
          s"\n\t$schema.")
    }
  }

  /**
   * Validates `dataSchema` and returns a factory that creates one [[ImageOutputWriter]] per
   * output file requested by Spark.
   *
   * Recognised options:
   *  - `imageCol` (default `"image"`): name of the image struct column;
   *  - `pathCol` (default `"filenames"`): name of the column holding each image's file name.
   */
  override def prepareWrite(sparkSession: SparkSession,
                            job: Job,
                            options: Map[String, String],
                            dataSchema: StructType): OutputWriterFactory = {
    verifySchema(dataSchema)
    new OutputWriterFactory {
      override def newInstance(
                                path: String,
                                dataSchema: StructType,
                                context: TaskAttemptContext): OutputWriter = {
        new ImageOutputWriter(
          path,
          dataSchema.fieldIndex(options.getOrElse("imageCol", "image")),
          dataSchema.fieldIndex(options.getOrElse("pathCol", "filenames")),
          context)
      }

      // No extension is appended here: the writer names each file from the path column.
      override def getFileExtension(context: TaskAttemptContext): String = {
        ""
      }
    }
  }

  /**
   * Evaluates `f`, re-evaluating it up to `times` more times if it throws a
   * `NullPointerException`.
   *
   * This is a workaround for an intermittent NPE attributed to a multithreading bug in the JVM.
   * Any other exception, or an NPE once the retries are exhausted, is logged and mapped to
   * `None`, so callers treat the input as unusable rather than failing.
   *
   * @param times       number of retries allowed after the first attempt
   * @param description human-readable description of `f`, used only in log messages
   * @param f           by-name computation, re-evaluated on every attempt
   */
  private def catchFlakiness[T](times: Int, description: String = "operation")
                               (f: => Option[T]): Option[T] = {
    try {
      f
    } catch {
      case e: NullPointerException if times >= 1 =>
        logWarning("caught null pointer exception due to jvm bug", e)
        catchFlakiness(times - 1, description)(f)
      case e: Exception =>
        // Keep the best-effort `None`, but do not swallow the failure silently.
        logWarning(s"$description failed, returning None: $e")
        None
    }
  }

  /**
   * Builds the per-file reader.
   *
   * Each file yields at most one row:
   *  - when `dropInvalid` is off and no columns are required, the file is not read and a single
   *    empty row is emitted;
   *  - otherwise the file is read fully and decoded through [[catchFlakiness]]. A failed decode
   *    yields nothing when `dropInvalid` is on, or an invalid-image row when it is off.
   */
  override protected def buildReader(sparkSession: SparkSession,
                                     dataSchema: StructType,
                                     partitionSchema: StructType,
                                     requiredSchema: StructType,
                                     filters: Seq[Filter],
                                     options: Map[String, String],
                                     hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    assert(
      requiredSchema.length <= 1,
      "Image data source only produces a single data column named \"image\".")

    // Ship the Hadoop configuration to executors once, rather than once per task closure.
    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    val imageSourceOptions = new ImageOptions(options)

    (file: PartitionedFile) => {
      val emptyUnsafeRow = new UnsafeRow(0)
      if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
        Iterator(emptyUnsafeRow)
      } else {
        val origin = file.filePath
        val path = new Path(origin)
        val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
        val stream = fs.open(path)
        val bytes = try {
          ByteStreams.toByteArray(stream)
        } finally {
          // `true` = swallow IOExceptions raised while closing.
          Closeables.close(stream, true)
        }
        val resultOpt =
          catchFlakiness(5, s"decoding image $origin")(ImageSchema.decode(origin, bytes))
        val filteredResult = if (imageSourceOptions.dropInvalid) {
          resultOpt.toIterator
        } else {
          Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
        }

        if (requiredSchema.isEmpty) {
          filteredResult.map(_ => emptyUnsafeRow)
        } else {
          val converter = RowEncoder(requiredSchema)
          filteredResult.map(row => converter.toRow(row))
        }
      }
    }
  }

  override def toString: String = "PatchedImageFileFormat"
}

/**
 * Writes each incoming row's image to its own file. The file name comes from the row's path
 * column and is resolved against a directory derived from `path`.
 *
 * The image format is the (lower-cased) extension of the target file name, for example `png`
 * or `jpg`. Encoding is done with `javax.imageio.ImageIO`.
 *
 * @param path     output file path handed to this writer by Spark; only used to locate the
 *                 directory that image files are written into
 * @param imageCol ordinal of the image struct column in incoming rows
 * @param pathCol  ordinal of the string column holding each image's relative file name
 * @param context  task attempt context providing the Hadoop configuration
 */
class ImageOutputWriter(val path: String,
                        val imageCol: Int,
                        val pathCol: Int,
                        val context: TaskAttemptContext)
  extends OutputWriter {

  private val hconf = context.getConfiguration

  private val fs = new Path(path).getFileSystem(hconf)

  // Directory five levels above `path`. NOTE(review): this presumably strips a
  // FileOutputCommitter staging layout (_temporary/<n>/_temporary/<attempt>/<file>) so that
  // images land directly in the final output directory — confirm for other committers.
  // It is lazy so that any failure still surfaces on the first write, as before.
  private lazy val nonTempPath: Path =
    new Path(path).getParent.getParent.getParent.getParent.getParent

  /** Lower-cased extension of `file`'s name, or "" if the name contains no '.'. */
  private def codecOf(file: Path): String = {
    val name = file.getName
    val dot = name.lastIndexOf('.')
    // Locale.ROOT avoids locale-sensitive case mapping (e.g. "GIF" in a Turkish locale).
    if (dot < 0) "" else name.substring(dot + 1).toLowerCase(java.util.Locale.ROOT)
  }

  /**
   * Encodes the row's image and writes it to `nonTempPath/<value of the path column>`.
   *
   * @throws AssertionError if the file name has no extension, or if `ImageIO.write` reports
   *                        that it could not write the image in that format
   */
  override def write(row: InternalRow): Unit = {
    val imgRow = row.getStruct(imageCol, 6)
    val bImg = ImageUtils.toBufferedImage(imgRow)
    val outputPath = new Path(nonTempPath, row.getString(pathCol))
    // Check the extension from the file name only, before creating the file, so that a bad name
    // neither picks up dots from parent directories nor leaves an empty file behind.
    val codec = codecOf(outputPath)
    assert(codec.nonEmpty, s"cannot infer image codec from file name: $outputPath")
    val os = fs.create(outputPath)
    try {
      val success = ImageIO.write(bImg, codec, os)
      assert(success, s"codec failed: $codec")
    } finally {
      os.close()
    }
  }

  /**
   * Nothing to release: each output stream is closed in `write`.
   *
   * `fs` is deliberately NOT closed. `Path.getFileSystem` returns Hadoop's JVM-wide cached
   * instance, and closing it would break every other task or writer sharing it
   * ("Filesystem closed").
   */
  override def close(): Unit = ()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy