All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.sql.sources.iris.IrisRelation.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.deeplearning4j.spark.sql.sources.iris

import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
import org.apache.spark.Logging
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.deeplearning4j.spark.sql.sources.canova.{CanovaRecordReaderAdapter, CanovaImageVectorizer}
import org.deeplearning4j.spark.sql.sources.lfw.LfwRelation
import org.deeplearning4j.spark.sql.sources.mapreduce.{PrunedReader, ColumnarRecordReader, LabelRecordReader, CachedStatus}
import org.deeplearning4j.spark.sql.types.VectorUDT

/**
 * Iris dataset as a Spark SQL relation.
 *
 * The data at `location` is read with `MLUtils.loadLibSVMFile`, so it is expected to be in
 * LibSVM format. Each labeled point is exposed as a row with up to two columns:
 *  - `label`: the label as a `Double`, tagged with nominal ML attribute metadata declaring
 *    3 distinct values;
 *  - `features`: the feature vector, converted to a dense vector.
 *
 * @param location   path to the LibSVM-formatted data, passed straight to `loadLibSVMFile`
 * @param sqlContext the SQL context that owns this relation (marked `@transient`)
 *
 * @author Eron Wright
 */
case class IrisRelation(location: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedScan with Logging {

  // Nominal label attribute with 3 values. The attribute "type" key must be present in the
  // metadata: without it Spark ML's Attribute.fromStructField reads the column as numeric and
  // ignores "num_vals", so classifiers cannot learn the number of classes from the schema.
  // NOTE(review): assumes labels in the data are encoded as 0.0, 1.0, 2.0 — confirm.
  private val labelMetadata = NominalAttribute.defaultAttr
    .withName("label")
    .withNumValues(3)
    .toMetadata()

  override def schema: StructType = StructType(
    StructField("label", DoubleType, nullable = false, metadata = labelMetadata) ::
      StructField("features", VectorUDT(), nullable = false) :: Nil)

  /**
   * Loads the data and projects each labeled point onto `requiredColumns`, in the order given.
   *
   * @param requiredColumns columns to produce; each must be `label` or `features`
   * @return one row per labeled point holding the requested column values
   * @throws IllegalArgumentException if a requested column is not part of the schema
   */
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    // Resolve one extractor per requested column on the driver, so an unknown name fails
    // fast with a descriptive message instead of a MatchError.
    val extractors: Array[LabeledPoint => Any] = requiredColumns.map {
      case "label" => (pt: LabeledPoint) => pt.label
      case "features" => (pt: LabeledPoint) => pt.features.toDense
      case other =>
        throw new IllegalArgumentException(s"Unknown column '$other' requested from Iris data")
    }

    val sc = sqlContext.sparkContext
    val baseRdd = MLUtils.loadLibSVMFile(sc, location)

    baseRdd.map { pt =>
      Row.fromSeq(extractors.map(extract => extract(pt)).toSeq)
    }
  }

  override def hashCode(): Int = 41 * (41 + location.hashCode) + schema.hashCode()

  override def equals(other: Any): Boolean = other match {
    case that: IrisRelation =>
      (this.location == that.location) && this.schema.equals(that.schema)
    case _ => false
  }
}

/**
 * Iris dataset provider.
 *
 * Spark SQL instantiates this class when the Iris source is requested. The only supported
 * option is `path`, which is mandatory and locates the data to load.
 */
class DefaultSource extends RelationProvider {

  /** Returns the value of the mandatory `path` option, failing immediately when it is absent. */
  private def checkPath(parameters: Map[String, String]): String =
    parameters.get("path") match {
      case Some(location) => location
      case None => sys.error("'path' must be specified for Iris data.")
    }

  /** Builds an [[IrisRelation]] over the configured `path`, bound to `sqlContext`. */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): IrisRelation =
    new IrisRelation(checkPath(parameters))(sqlContext)
}









© 2015 - 2025 Weber Informatics LLC | Privacy Policy