All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.pipeline.nnframes.NNImageReader.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.pipeline.nnframes

import com.intel.analytics.bigdl.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.transform.vision.image.{BytesToMat, ImageFeature, ImageFrame}
import org.apache.spark.SparkContext
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.opencv.core.CvType
import scala.language.existentials

/**
 * Definition for image data in a DataFrame
 */
object NNImageSchema {

  /**
   * Schema for the image column in a DataFrame. Image data is saved in an array of Bytes.
   */
  val byteSchema = StructType(
    StructField("origin", StringType, true) ::
      StructField("height", IntegerType, false) ::
      StructField("width", IntegerType, false) ::
      StructField("nChannels", IntegerType, false) ::
      // OpenCV-compatible type: CV_8UC3, CV_8UC1 in most cases
      StructField("mode", IntegerType, false) ::
      // Bytes in image file
      StructField("data", BinaryType, false) :: Nil)

  /**
   * Schema for the image column in a DataFrame. Image data is saved in an array of Floats.
   */
  val floatSchema = StructType(
    StructField("origin", StringType, true) ::
      StructField("height", IntegerType, false) ::
      StructField("width", IntegerType, false) ::
      StructField("nChannels", IntegerType, false) ::
      // OpenCV-compatible type: CV_32FC3 in most cases
      StructField("mode", IntegerType, false) ::
      // floats in OpenCV-compatible order: row-wise BGR in most cases
      StructField("data", new ArrayType(FloatType, false), false) :: Nil)

  private[zoo] def imf2Row(imf: ImageFeature): Row = {
    val (mode, data) = if (imf.contains(ImageFeature.imageTensor)) {
      val floatData = imf(ImageFeature.imageTensor).asInstanceOf[Tensor[Float]].storage().array()
      val cvType = imf.getChannel() match {
        case 1 => CvType.CV_32FC1
        case 3 => CvType.CV_32FC3
        case 4 => CvType.CV_32FC4
        case other => throw new IllegalArgumentException(s"Unsupported number of channels:" +
          s" $other in ${imf.uri()}. Only 1, 3 and 4 are supported.")
      }
      (cvType, floatData)
    } else if (imf.contains(ImageFeature.bytes)) {
      val bytesData = imf.bytes()
      val cvType = imf.getChannel() match {
        case 1 => CvType.CV_8UC1
        case 3 => CvType.CV_8UC3
        case 4 => CvType.CV_8UC4
        case other => throw new IllegalArgumentException(s"Unsupported number of channels:" +
          s" $other in ${imf.uri()}. Only 1, 3 and 4 are supported.")
      }
      (cvType, bytesData)
    } else {
      throw new IllegalArgumentException(s"ImageFeature should have imageTensor or bytes.")
    }

    Row(
      imf.uri(),
      imf.getHeight(),
      imf.getWidth(),
      imf.getChannel(),
      mode,
      data
    )
  }

  private[zoo] def row2IMF(row: Row): ImageFeature = {
    val (origin, h, w, c) = (row.getString(0), row.getInt(1), row.getInt(2), row.getInt(3))
    val imf = ImageFeature()
    imf.update(ImageFeature.uri, origin)
    imf.update(ImageFeature.size, (h, w, c))
    val storageType = row.getInt(4)
    storageType match {
      case CvType.CV_8UC3 | CvType.CV_8UC1 | CvType.CV_8UC4 =>
        imf.update(ImageFeature.bytes, row.getAs[Array[Byte]](5))
        BytesToMat().transform(imf)
      case CvType.CV_32FC3 | CvType.CV_32FC1 | CvType.CV_32FC4 =>
        val data = row.getSeq[Float](5).toArray
        val size = Array(h, w, c)
        val ten = Tensor(Storage(data)).resize(size)
        imf.update(ImageFeature.imageTensor, ten)
      case _ =>
        throw new IllegalArgumentException(s"Unsupported data type in imageColumn: $storageType")
    }
    imf
  }

  /**
   * Gets the origin of the image
   *
   * @return The origin of the image
   */
  def getOrigin(row: Row): String = row.getString(0)

  /**
   * Extracts the origin info from Image column and add it to a new column.
   *
   * @param imageDF input DataFrame that contains an image column
   * @param imageColumn image column name
   * @param originColumn name of new column
   * @return The origin of the image
   */
  def withOriginColumn(
      imageDF: DataFrame,
      imageColumn: String = "image",
      originColumn: String = "origin"): DataFrame = {
    val getPathUDF = udf { row: Row => getOrigin(row) }
    imageDF.withColumn(originColumn, getPathUDF(col(imageColumn)))
  }

}

/**
 * Primary DataFrame-based image loading interface, defining API to read images into DataFrame.
 */
object NNImageReader {

  /**
   * DataFrame with a single column of images named "image" (nullable)
   */
  private val imageColumnSchema =
    StructType(StructField("image", NNImageSchema.byteSchema, true) :: Nil)

  /**
   * Read the directory of images into DataFrame from the local or remote source.
   *
   * @param path Directory to the input data files, the path can be comma separated paths as the
   *             list of inputs. Wildcards path are supported similarly to sc.binaryFiles(path).
   * @param sc SparkContext to be used.
   * @param minPartitions Number of the DataFrame partitions,
   *                      if omitted uses defaultParallelism instead
   * @return DataFrame with a single column "image" of images;
   *         see DLImageSchema.byteSchema for the details
   */
  def readImages(path: String, sc: SparkContext, minPartitions: Int = 1): DataFrame = {
    val imageFrame = ImageFrame.read(path, sc, minPartitions)
    val rowRDD = imageFrame.toDistributed().rdd.map { imf =>
      Row(NNImageSchema.imf2Row(imf))
    }
    SQLContext.getOrCreate(sc).createDataFrame(rowRDD, imageColumnSchema)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy