All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLMnistLoaderExt.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import java.nio.ByteBuffer

import com.intel.analytics.bigdl.dataset.{ByteRecord, DataSet, DistributedDataSet, MiniBatch}
import com.intel.analytics.bigdl.dataset.image.{BytesToGreyImg, GreyImgNormalizer, GreyImgToBatch}
import com.intel.analytics.bigdl.models.lenet.Utils.{trainMean, trainStd}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Engine
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import streaming.common.HDFSOperator
import streaming.dsl.mmlib.{AlgType, ModelType, ProcessType, SQLAlg}
import streaming.dsl.mmlib.algs.param.BaseParams

class SQLMnistLoaderExt(override val uid: String) extends SQLAlg with BaseParams {

  def this() = this(BaseParams.randomUID())

  def load(featureFile: String, labelFile: String) = {

    val featureBuffer = ByteBuffer.wrap(HDFSOperator.readBytes(featureFile))
    val labelBuffer = ByteBuffer.wrap(HDFSOperator.readBytes(labelFile))

    val labelMagicNumber = labelBuffer.getInt()
    require(labelMagicNumber == 2049)
    val featureMagicNumber = featureBuffer.getInt()
    require(featureMagicNumber == 2051)
    val labelCount = labelBuffer.getInt()
    val featureCount = featureBuffer.getInt()
    require(labelCount == featureCount)
    val rowNum = featureBuffer.getInt()
    val colNum = featureBuffer.getInt()

    val result = new Array[ByteRecord](featureCount)
    var i = 0
    while (i < featureCount) {
      val img = new Array[Byte]((rowNum * colNum))
      var y = 0
      while (y < rowNum) {
        var x = 0
        while (x < colNum) {
          img(x + y * colNum) = featureBuffer.get()
          x += 1
        }
        y += 1
      }
      result(i) = ByteRecord(img, labelBuffer.get().toFloat + 1.0f)
      i += 1
    }

    result
  }

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    Engine.init
    params.get(mnistDir.name).
      map(m => set(mnistDir, m)).getOrElse {
      set(mnistDir, path)
      require($(mnistDir) != null, "mnistDir should not empty")
    }



    params.get(dataType.name).
      map(m => set(dataType, m)).
      getOrElse("")

    val trainData = $(mnistDir) + "/train-images-idx3-ubyte"
    val trainLabel = $(mnistDir) + "/train-labels-idx1-ubyte"
    val validationData = $(mnistDir) + "/t10k-images-idx3-ubyte"
    val validationLabel = $(mnistDir) + "/t10k-labels-idx1-ubyte"

    val data = if ($(dataType) == "train") trainData else validationData
    val validate = if ($(dataType) == "train") trainLabel else validationLabel

    val trainSet = DataSet.array(load(data, validate), df.sparkSession.sparkContext) ->
      BytesToGreyImg(28, 28) -> GreyImgNormalizer(trainMean, trainStd) -> GreyImgToBatch(1)

    val trainingRDD: RDD[Data[Float]] = trainSet.
      asInstanceOf[DistributedDataSet[MiniBatch[Float]]].data(false).map(batch => {
      val feature = batch.getInput().asInstanceOf[Tensor[Float]]
      val label = batch.getTarget().asInstanceOf[Tensor[Float]]
      Data[Float](feature.storage().array(), label.storage().array())
    })
    val trainingDF = df.sparkSession.createDataFrame(trainingRDD).toDF("features", "label")
    trainingDF
  }

  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    train(df, path, params)
  }

  override def modelType: ModelType = ProcessType

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    throw new MLSQLException(s"register statement is not support by ${getClass.getName}")
  }

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???

  final val mnistDir: Param[String] = new Param[String](this, "mnistDir", "mnistDir directory which contains 4 ubyte files")
  final val dataType: Param[String] = new Param[String](this, "dataType", "load train|validate data")
  set(dataType, "train")
}

private case class Data[T](featureData: Array[T], labelData: Array[T])




© 2015 - 2024 Weber Informatics LLC | Privacy Policy