All downloads are free. Search and download functionality uses the official Maven repository.

streaming.dsl.mmlib.algs.bigdl.BigDLClassifyExtParamsExtractor.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs.bigdl

import com.intel.analytics.bigdl.optim.{Loss, _}
import org.apache.spark.sql.mlsql.session.MLSQLException
import streaming.common.ScalaObjectReflect
import streaming.dsl.mmlib.algs.SQLBigDLClassifyExt
import org.json4s._
import org.json4s.jackson.JsonMethods


/**
 * Resolves the optimisation method requested through the `optimizeMethod` parameter.
 *
 * @param bigDLClassifyExt estimator whose param definition supplies the lookup key
 * @param _params          user parameters; the key looked up is the `optimizeMethod` param name
 *                         with its first two dot-separated segments removed (see `cleanGroupPrefix`)
 */
class OptimizeParamExtractor(bigDLClassifyExt: SQLBigDLClassifyExt, _params: Map[String, String]) extends BaseExtractor {
  /**
   * Instantiates the configured optimisation method, if any.
   *
   * @return the first method built by `OptimizeParamExtractor.extractOptimizeMethods`, or `None`
   *         when the parameter was not supplied
   * @throws MLSQLException if the configured name is not a supported optimisation method
   */
  def optimizeMethod = {
    val paramKey = cleanGroupPrefix(bigDLClassifyExt.optimizeMethod.name)
    val requestedNames = _params.collect { case (name, value) if name == paramKey => value }.toArray
    OptimizeParamExtractor.extractOptimizeMethods(requestedNames).headOption
  }
}

/**
 * Registry of the BigDL optimisation methods selectable through the `optimizeMethod` parameter,
 * plus helpers that turn their simple class names into instances.
 */
object OptimizeParamExtractor {
  /** Optimisation method classes accepted by `extractOptimizeMethods`, identified by simple class name. */
  private[bigdl] val optimizeMethodCandidates = List(
    classOf[Adam[Float]],
    classOf[Adamax[Float]],
    classOf[Adadelta[Float]],
    classOf[Ftrl[Float]],
    classOf[LBFGS[Float]],
    classOf[RMSprop[Float]],
    classOf[SGD[Float]]
  )

  /**
   * Instantiates one optimisation method per requested name, preserving request order.
   *
   * Each name must equal the simple class name of one of `optimizeMethodCandidates`
   * (e.g. `Adam`, `SGD`); surrounding whitespace is ignored. Every method is constructed
   * with no explicit arguments.
   *
   * @param methods simple class names of the requested optimisation methods
   * @return the instantiated methods, in the same order as `methods`
   * @throws MLSQLException if any name is not a supported optimisation method
   */
  def extractOptimizeMethods(methods: Array[String]) = {
    val requested = methods.map(_.trim)
    val supported = optimizeMethodCandidates.map(_.getSimpleName).toSet
    val unknown = requested.filterNot(supported.contains)
    if (unknown.nonEmpty) {
      throw new MLSQLException(
        s"""
           |${unknown.mkString(",")} do not exist. Available methods: ${optimizeMethodCandidatesStr}
           |${EvaluateParamsExtractor.helperStr}
         """.stripMargin)
    }
    requested.map { name =>
      optimizeMethodCandidates.find(_.getSimpleName == name) match {
        case Some(c) if c == classOf[Adam[Float]] => new Adam[Float]()
        case Some(c) if c == classOf[Adamax[Float]] => new Adamax[Float]()
        case Some(c) if c == classOf[Adadelta[Float]] => new Adadelta[Float]()
        case Some(c) if c == classOf[Ftrl[Float]] => new Ftrl[Float]()
        case Some(c) if c == classOf[LBFGS[Float]] => new LBFGS[Float]()
        case Some(c) if c == classOf[RMSprop[Float]] => new RMSprop[Float]()
        case Some(c) if c == classOf[SGD[Float]] => new SGD[Float]()
        // Guards against a candidate being added to the list without a matching constructor case.
        case _ => throw new MLSQLException(s"${name} is a known optimize method but cannot be instantiated")
      }
    }.toSeq
  }

  /** Supported optimisation method names joined with `|`, for help and error messages. */
  def optimizeMethodCandidatesStr = {
    optimizeMethodCandidates.map(f => f.getSimpleName).mkString("|")
  }

}

/**
 * Reads the criterion-related parameters (class weights, class count and criterion flags) of a
 * BigDLClassifyExt call. Each lookup key is the corresponding param name with its first two
 * dot-separated segments removed (see `cleanGroupPrefix`), except `classNum`, which is read verbatim.
 *
 * @param bigDLClassifyExt estimator whose param definitions supply the lookup keys
 * @param _params          user-supplied parameters
 */
class ClassWeightParamExtractor(bigDLClassifyExt: SQLBigDLClassifyExt, _params: Map[String, String]) extends BaseExtractor {

  /**
   * Parses the class-weight parameter, a JSON array of numbers such as `[1.0, 2.5]`.
   *
   * @return the parsed weights, or `None` when the parameter is absent
   * @throws MLSQLException if `classNum` is also set and differs from the number of weights
   */
  def weights = {
    implicit val formats = DefaultFormats
    val key = cleanGroupPrefix(bigDLClassifyExt.criterion_classWeight.name)
    val res = _params.get(key).map(raw => JsonMethods.parse(raw).extract[Array[Float]])
    // Only cross-check when both values were supplied; classNum is parsed once here.
    (classNum, res) match {
      case (Some(expected), Some(parsed)) if expected != parsed.length =>
        throw new MLSQLException(s"classWeight should have the same size with classNum(${expected})")
      case _ =>
    }
    res
  }

  /** The `classNum` parameter parsed as an `Int`, if supplied. */
  def classNum = {
    _params.get("classNum").map(_.toInt)
  }

  /** The criterion `sizeAverage` parameter parsed as a `Boolean`, if supplied. */
  def sizeAverage = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.criterion_sizeAverage.name)).map(_.toBoolean)
  }

  /** The criterion `logProbAsInput` parameter parsed as a `Boolean`, if supplied. */
  def logProbAsInput = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.criterion_logProbAsInput.name)).map(_.toBoolean)
  }

  /** The criterion `paddingValue` parameter parsed as an `Int`, if supplied. */
  def paddingValue = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.criterion_paddingValue.name)).map(_.toInt)
  }
}

/**
 * Reads the validation-related parameters of a BigDLClassifyExt call and assembles them into a
 * `BigDLEvaluateConfig`.
 *
 * @param bigDLClassifyExt estimator whose param definitions supply the recognised keys
 * @param _params          user-supplied parameters
 */
class EvaluateParamsExtractor(bigDLClassifyExt: SQLBigDLClassifyExt, _params: Map[String, String]) extends BaseExtractor {
  val bigDLClassifyExtParams = bigDLClassifyExt.params


  /**
   * Builds a BigDL `Trigger` from parameters named `evaluate.trigger.<factory>`.
   *
   * Every such key must correspond to a BigDLClassifyExt param (after `cleanGroupPrefix`). The last
   * segment names a method on BigDL's `Trigger` object, invoked reflectively either with no argument
   * or with one argument converted from the parameter's string value (Int, Double, Float, Boolean
   * or String). All matching keys are validated and built, but only the first one in the map's
   * iteration order is returned.
   *
   * @throws MLSQLException if a key is not recognised, the named trigger factory does not exist,
   *                        or its parameter type is not supported
   */
  private[bigdl] def trigger = {
    _params.filter { case (key, _) => key.startsWith("evaluate.trigger.") }.map { case (key, value) =>
      if (!bigDLClassifyExtParams.exists(p => cleanGroupPrefix(p.name) == key)) {
        throw new MLSQLException(
          s"""
             |${key} is not recognized by BigDLClassifyExt
             |
             |${EvaluateParamsExtractor.helperStr}
           """.stripMargin)
      }
      val triggerType = key.split("\\.").last
      val (clzz, instance) = ScalaObjectReflect.findObjectMethod(Trigger.getClass.getName.stripSuffix("$"))
      val method = clzz.getDeclaredMethods.find(_.getName == triggerType).getOrElse {
        throw new MLSQLException(
          s"""
             |Trigger.${triggerType} does not exist
             |
             |${EvaluateParamsExtractor.helperStr}
           """.stripMargin)
      }
      val paramTypes = method.getParameterTypes
      if (paramTypes.isEmpty) {
        method.invoke(instance).asInstanceOf[Trigger]
      } else {
        // For now, trigger factories take at most one parameter; convert the string value to its type.
        val argument: Any = paramTypes.head match {
          case t if t.isAssignableFrom(classOf[Int]) => value.toInt
          case t if t.isAssignableFrom(classOf[Double]) => value.toDouble
          case t if t.isAssignableFrom(classOf[Float]) => value.toFloat
          case t if t.isAssignableFrom(classOf[Boolean]) => value.toBoolean
          case t if t.isAssignableFrom(classOf[String]) => value
          case other =>
            throw new MLSQLException(s"Trigger.${triggerType} takes an unsupported parameter type ${other.getName}")
        }
        method.invoke(instance, argument.asInstanceOf[AnyRef]).asInstanceOf[Trigger]
      }

    }.headOption
  }

  /** The validation table name, if supplied. */
  private[bigdl] def evaluateTable = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.evaluate_table.name))
  }

  /**
   * The comma-separated validation method names, instantiated via
   * `EvaluateParamsExtractor.extractEvaluateMethods`, if supplied.
   */
  private[bigdl] def evaluateMethods = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.evaluate_methods.name)).map { raw =>
      EvaluateParamsExtractor.extractEvaluateMethods(raw.split(","))
    }
  }

  /** The validation batch size parsed as an `Int`, if supplied. */
  private[bigdl] def batchSize = {
    _params.get(cleanGroupPrefix(bigDLClassifyExt.evaluate_batchSize.name)).map(_.toInt)
  }


  /** Assembles every validation setting found in the parameters into one config value. */
  def bigDLEvaluateConfig = {
    BigDLEvaluateConfig(trigger, evaluateTable, evaluateMethods.map(f => f.toArray), batchSize)
  }

  /** Validation methods to fall back on: top-1 accuracy and loss. */
  def defaultEvaluateMethods = Array(
    new Top1Accuracy[Float](), new Loss[Float]()
  )
}

/** Helpers shared by the BigDL parameter extractors. */
trait BaseExtractor {
  /**
   * Removes the first two dot-separated segments of a parameter name, e.g.
   * `fitParam.0.evaluate.table` becomes `evaluate.table`. A name with at most two
   * segments yields an empty string.
   * NOTE(review): presumably the dropped segments are a group prefix; confirm against
   * how SQLBigDLClassifyExt names its params.
   */
  private[bigdl] def cleanGroupPrefix(str: String) = {
    val segments = str.split("\\.")
    segments.drop(2).mkString(".")
  }

}

/** Companion of the class-weight extractor; it currently defines no members. */
object ClassWeightParamExtractor

/**
 * Registry of the BigDL validation methods selectable through the evaluate-methods parameter,
 * plus the help text appended to parameter-related error messages.
 */
object EvaluateParamsExtractor {
  /** Validation method classes accepted by `extractEvaluateMethods`, identified by simple class name. */
  private[bigdl] val evaluateMethodCandidates = List(
    classOf[Loss[Float]],
    classOf[Top1Accuracy[Float]],
    classOf[MAE[Float]],
    classOf[Top5Accuracy[Float]],
    classOf[TreeNNAccuracy[Float]]
  )

  /**
   * Instantiates one validation method per requested name, preserving request order.
   *
   * Each name must equal the simple class name of one of `evaluateMethodCandidates`
   * (e.g. `Loss`, `Top1Accuracy`). Surrounding whitespace is ignored, so a list written as
   * `Loss, Top1Accuracy` and split on "," still resolves.
   *
   * @param methods simple class names of the requested validation methods
   * @return the instantiated methods, in the same order as `methods`
   * @throws MLSQLException if any name is not a supported validation method
   */
  def extractEvaluateMethods(methods: Array[String]) = {
    val requested = methods.map(_.trim)
    val supported = evaluateMethodCandidates.map(_.getSimpleName).toSet
    val unknown = requested.filterNot(supported.contains)
    if (unknown.nonEmpty) {
      throw new MLSQLException(
        s"""
           |${unknown.mkString(",")} do not exist. Available methods: ${evaluateMethodCandidatesStr}
           |${helperStr}
         """.stripMargin)
    }
    requested.map { name =>
      // Dispatch on class equality so each candidate maps to exactly its own constructor.
      evaluateMethodCandidates.find(_.getSimpleName == name) match {
        case Some(c) if c == classOf[Loss[Float]] => new Loss[Float]()
        case Some(c) if c == classOf[Top1Accuracy[Float]] => new Top1Accuracy[Float]()
        case Some(c) if c == classOf[MAE[Float]] => new MAE[Float]()
        case Some(c) if c == classOf[Top5Accuracy[Float]] => new Top5Accuracy[Float]()
        case Some(c) if c == classOf[TreeNNAccuracy[Float]] => new TreeNNAccuracy[Float]()
        // Guards against a candidate being added to the list without a matching constructor case.
        case _ => throw new MLSQLException(s"${name} is a known evaluate method but cannot be instantiated")
      }
    }.toSeq
  }

  /** Supported validation method names joined with `|`, for help and error messages. */
  def evaluateMethodCandidatesStr = {
    evaluateMethodCandidates.map(f => f.getSimpleName).mkString("|")
  }


  /** Hint appended to parameter errors, pointing users at the params listing statement. */
  val helperStr =
    s"""
       |Please use load statement to check params detail:
       |
       |```
       |load modelParams.`BigDLClassifyExt` as output;
       |```
     """.stripMargin
}

/**
 * Validation settings gathered from a BigDLClassifyExt call; every field is optional and is
 * `None` when the corresponding parameter was not supplied.
 *
 * @param trigger         trigger built from an `evaluate.trigger.*` parameter
 * @param validationTable name of the table holding validation data
 * @param vMethods        validation methods to compute
 * @param batchSize       batch size to use for validation
 */
case class BigDLEvaluateConfig(
    trigger: Option[Trigger],
    validationTable: Option[String],
    vMethods: Option[Array[ValidationMethod[Float]]],
    batchSize: Option[Int])




© 2015 - 2024 Weber Informatics LLC | Privacy Policy