All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLLeNet5Ext.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import com.intel.analytics.bigdl.dlframes.{DLClassifier, DLModel}
import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.models.utils.ModelBroadcast
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, Module}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Engine
import net.sf.json.JSONArray
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import streaming.common.HDFSOperator
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.classfication.BaseClassification
import streaming.dsl.mmlib.algs.param.BaseParams
import streaming.dsl.mmlib.algs.bigdl.BigDLFunctions

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

class SQLLeNet5Ext(override val uid: String) extends SQLAlg with MllibFunctions with BigDLFunctions with BaseClassification {

  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    Engine.init
    params.get(keepVersion.name).
      map(m => set(keepVersion, m.toBoolean)).
      getOrElse($(keepVersion))

    val eTable = params.get(evaluateTable.name)

    SQLPythonFunc.incrementVersion(path, $(keepVersion))
    val spark = df.sparkSession


    bigDLClassifyTrain[Float](df, path, params, (newFitParam) => {
      val model = LeNet5(classNum = newFitParam("classNum").toInt)
      val criterion = ClassNLLCriterion[Float]()
      val alg = new DLClassifier[Float](model, criterion,
        JSONArray.fromObject(newFitParam("featureSize")).map(f => f.asInstanceOf[Int]).toArray)
      alg
    }, (_model, newFitParam) => {
      eTable match {
        case Some(etable) =>
          val model = _model
          val evaluateTableDF = spark.table(etable)
          val predictions = model.transform(evaluateTableDF)
          multiclassClassificationEvaluate(predictions, (evaluator) => {
            evaluator.setLabelCol(newFitParam.getOrElse("labelCol", "label"))
            evaluator.setPredictionCol("prediction")
          })

        case None => List()
      }
    })

    formatOutput(getModelMetaData(spark, path))
  }


  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    Engine.init
    val models = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[DLModel[Float]]]
    models.head.transform(df)
  }

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    Engine.init
    val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
    val trainParams = sparkSession.read.parquet(metaPath + "/0").collect().head.getAs[Map[String, String]]("trainParams")

    val featureSize = JSONArray.fromObject(trainParams("featureSize")).map(f => f.asInstanceOf[Int]).toArray

    val model = Module.loadModule[Float](HDFSOperator.getFilePath(bestModelPath(0)))
    val dlModel = new DLModel[Float](model, featureSize)
    ArrayBuffer(dlModel)
  }

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    val dlmodel = _model.asInstanceOf[ArrayBuffer[DLModel[Float]]].head
    val featureSize = dlmodel.featureSize
    val modelBroadCast = ModelBroadcast[Float]().broadcast(sparkSession.sparkContext, dlmodel.model.evaluate())


    val f = (vec: Vector) => {

      val localModel = modelBroadCast.value()
      val featureTensor = Tensor(vec.toArray.map(f => f.toFloat), featureSize)
      val output = localModel.forward(featureTensor)
      val res = output.toTensor[Float].clone().storage().array().map(f => f.toDouble)
      Vectors.dense(res)
    }

    UserDefinedFunction(f, VectorType, Some(Seq(VectorType)))
  }


  override def explainParams(sparkSession: SparkSession): DataFrame = {
    _explainParams(sparkSession, () => {
      val model = LeNet5(classNum = 10)
      val criterion = ClassNLLCriterion[Float]()
      val alg = new DLClassifier[Float](model, criterion, Array(28, 28))
      alg
    })
  }

  override def modelType: ModelType = AlgType

  override def doc: Doc = Doc(MarkDownDoc,
    """
      |
      |load modelParams.`LeNet5Ext`  as output;
    """.stripMargin)

  override def codeExample: Code = Code(SQLCode,
    """
      |set json = '''{}''';
      |load jsonStr.`json` as emptyData;
      |
      |run emptyData as MnistLoaderExt.`` where
      |mnistDir="/Users/allwefantasy/Downloads/mnist"
      |as data;
      |
      |train data as LeNet5Ext.`/tmp/lenet` where
      |fitParam.0.featureSize="[28,28]"
      |and fitParam.0.classNum="10";
      |
      |predict data as LeNet5Ext.`/tmp/lenet`;
      |
      |register LeNet5Ext.`/tmp/lenet` as mnistPredict;
      |
      |select
      |vec_argmax(mnistPredict(vec_dense(features))) as predict_label,
      |label from data
      |as output;
    """.stripMargin)
}

case class BigDLDefaultConfig(batchSize: Int = 128, maxEpoch: Int = 1)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy