All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.sql.sources.mnist.MnistRelation.scala Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.spark.sql.sources.mnist


import java.net.URI
import java.nio.file.Paths

import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, Partition, TaskContext, Logging}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{PrunedScan, BaseRelation, RelationProvider}
import org.deeplearning4j.datasets.mnist.MnistManager
import org.deeplearning4j.spark.sql.types.VectorUDT
import org.nd4j.linalg.util.ArrayUtil

/**
 * Mnist dataset as a Spark SQL relation.
 *
 * Exposes two non-nullable columns, `label` (double) and `features` (vector), and supports
 * column pruning through [[PrunedScan]].
 *
 * @param imagesPath          location of the MNIST images file
 * @param labelsPath          location of the MNIST labels file
 * @param recordsPerPartition maximum number of records placed in each RDD partition
 * @param maxRecords          optional cap on the number of records exposed
 * @param sqlContext          context whose `SparkContext` backs every scan
 * @author Kai Sasaki, Eron Wright
 */
case class MnistRelation(
    imagesPath: Path,
    labelsPath: Path,
    recordsPerPartition: Int,
    maxRecords: Option[Int])(
    @transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedScan
  with Logging {

  override def schema: StructType = {
    val label = StructField("label", DoubleType, nullable = false)
    val features = StructField("features", VectorUDT(), nullable = false)
    StructType(List(label, features))
  }

  /** Scans only `requiredColumns` by delegating to an RDD over the underlying files. */
  override def buildScan(requiredColumns: Array[String]): RDD[Row] =
    new MnistRDD(
      sqlContext.sparkContext,
      imagesPath.toUri,
      labelsPath.toUri,
      requiredColumns,
      recordsPerPartition,
      maxRecords)
}

/** One RDD partition: the record indices from `startIndex` through `endIndex`, both inclusive. */
private class MnistPartition(
    override val index: Int,
    val startIndex: Int,
    val endIndex: Int)
  extends Partition
  with Serializable

/**
 * RDD over an MNIST images/labels file pair, read through [[MnistManager]].
 *
 * Partitions cover contiguous, non-overlapping ranges of record indices. Each row contains
 * only the columns named in `requiredColumns`, in that order. Spark's `PrunedScan` contract
 * interprets rows positionally against that order.
 *
 * @param sc                  the Spark context the RDD belongs to
 * @param imagesPath          URI of the images file; must resolve to a local file
 * @param labelsPath          URI of the labels file; must resolve to a local file
 * @param requiredColumns     output columns in output order (`label` and/or `features`)
 * @param recordsPerPartition maximum number of records per partition; must be positive
 * @param maxRecords          optional cap on the number of records read; negative caps act as 0
 */
private class MnistRDD(
    sc: SparkContext,
    private val imagesPath: URI,
    private val labelsPath: URI,
    private val requiredColumns: Array[String],
    recordsPerPartition: Int = 1000,
    maxRecords: Option[Int] = None
  )
  extends RDD[Row](sc, Nil) {

  // A non-positive size would divide by zero in getPartitions and request a bogus array size.
  require(recordsPerPartition > 0,
    s"recordsPerPartition must be positive, got $recordsPerPartition")

  override val partitioner = None

  /** Splits the first `min(labelCount, maxRecords)` records into ranges of `recordsPerPartition`. */
  override def getPartitions: Array[Partition] = {
    val manager = open()
    try {
      val numLabels = manager.getLabels.getCount
      // Clamp at 0 so a negative maxRecords yields no partitions instead of a negative array size.
      val numRecords = math.max(0, math.min(numLabels, maxRecords.getOrElse(numLabels)))
      val numPartitions = math.ceil(numRecords.toDouble / recordsPerPartition).toInt
      Array.tabulate[Partition](numPartitions) { i =>
        val startIndex = i * recordsPerPartition
        // Long arithmetic so (i + 1) * recordsPerPartition cannot overflow for huge partition sizes.
        val endIndex = math.min((i + 1L) * recordsPerPartition - 1, numRecords - 1L).toInt
        new MnistPartition(i, startIndex, endIndex)
      }
    } finally {
      manager.close()
    }
  }

  /**
   * Streams the rows of one partition.
   *
   * The manager is opened per task and closed by a task-completion listener. The returned
   * iterator is therefore only valid while the task runs, which is how Spark consumes it.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
    val manager = open()
    context.addTaskCompletionListener((_) => manager.close())

    val partition = split.asInstanceOf[MnistPartition]

    // Maps a column name to a thunk that reads that column's value for the next record.
    // NOTE(review): labels and images come from two separate files (see open()). This assumes
    // MnistManager advances each file independently, so per-row read order does not matter.
    def readerFor(column: String): () => Any = column match {
      case "label" => () => manager.readLabel().toDouble
      case "features" => () => Vectors.dense(ArrayUtil.flatten(manager.readImage()).map(_.toDouble))
      case other => throw new IllegalArgumentException(s"Unknown MNIST column: $other")
    }
    // Follow requiredColumns' order, because Spark reads each row positionally against it.
    val readers = requiredColumns.toList.map(readerFor)

    manager.setCurrent(partition.startIndex)

    // Lazy: each record is decoded when the consumer pulls it, instead of the whole
    // partition being materialized before the iterator is returned.
    (partition.startIndex to partition.endIndex).iterator.map { _ =>
      Row.fromSeq(readers.map(read => read()))
    }
  }

  /** Opens a manager over the local images and labels files. */
  private def open(): MnistManager = {
    // CAVEAT: MnistManager supports only local files at this time.
    val imagesFile = localPath(imagesPath)
    val labelsFile = localPath(labelsPath)
    new MnistManager(imagesFile, labelsFile)
  }

  /**
   * Resolves a URI to an absolute local filesystem path.
   *
   * Hadoop's `Path.toUri` yields scheme-less URIs for plain paths such as `/data/mnist`.
   * `Paths.get(URI)` rejects those with "Missing scheme", so they are treated as local paths.
   */
  private def localPath(uri: URI): String = {
    val path = if (uri.getScheme == null) Paths.get(uri.getPath) else Paths.get(uri)
    path.toFile.getAbsolutePath
  }
}

/**
 * Mnist dataset provider.
 *
 * Options:
 *  - `imagesPath` (required): location of the MNIST images file.
 *  - `labelsPath` (required): location of the MNIST labels file.
 *  - `recordsPerPartition` (optional, default 1000): positive number of records per partition.
 *  - `maxRecords` (optional): non-negative cap on the number of records read.
 */
class DefaultSource extends RelationProvider {
  import DefaultSource._

  private def checkImagesFilePath(parameters: Map[String, String]): String =
    parameters.getOrElse(ImagesPath,
      sys.error("'imagesPath' must be specified for mnist data"))

  private def checkLabelsFilePath(parameters: Map[String, String]): String =
    parameters.getOrElse(LabelsPath,
      sys.error("'labelsPath' must be specified for mnist labels"))

  private def checkRecordsPerPartition(parameters: Map[String, String]): Int = {
    val records = parseIntOption(parameters, RecordsPerPartition).getOrElse(1000)
    require(records > 0, s"'$RecordsPerPartition' must be positive, got $records")
    records
  }

  private def checkMaxRecords(parameters: Map[String, String]): Option[Int] = {
    val max = parseIntOption(parameters, MaxRecords)
    max.foreach(m => require(m >= 0, s"'$MaxRecords' must be non-negative, got $m"))
    max
  }

  /**
   * Parses an optional integer option.
   *
   * @return `None` when `key` is absent
   * @throws IllegalArgumentException naming `key` when the value is not a valid integer
   */
  private def parseIntOption(parameters: Map[String, String], key: String): Option[Int] =
    parameters.get(key).map { raw =>
      try Integer.parseInt(raw)
      catch {
        case e: NumberFormatException =>
          throw new IllegalArgumentException(s"'$key' must be an integer, got '$raw'", e)
      }
    }

  /** Builds an [[MnistRelation]] from the validated data source options. */
  override def createRelation(sqlContext: SQLContext,
      parameters: Map[String, String]): MnistRelation = {
    val imagesPath = new Path(checkImagesFilePath(parameters))
    val labelsPath = new Path(checkLabelsFilePath(parameters))
    val recordsPerPartition = checkRecordsPerPartition(parameters)
    val maxRecords = checkMaxRecords(parameters)
    new MnistRelation(imagesPath, labelsPath, recordsPerPartition, maxRecords)(sqlContext)
  }
}

/** Option keys understood by the MNIST data source. */
object DefaultSource {
  /** Location of the MNIST images file (required). */
  val ImagesPath: String = "imagesPath"

  /** Location of the MNIST labels file (required). */
  val LabelsPath: String = "labelsPath"

  /** Number of records per RDD partition (optional). */
  val RecordsPerPartition: String = "recordsPerPartition"

  /** Cap on the total number of records read (optional). */
  val MaxRecords: String = "maxRecords"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy