/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.pipeline.nnframes
import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import com.intel.analytics.bigdl.dataset.{SampleToMiniBatch, _}
import com.intel.analytics.bigdl.models.utils.ModelBroadcast
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.{Tensor, DoubleType => TensorDouble, FloatType => TensorFloat}
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleLoader
import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary}
import com.intel.analytics.bigdl.{Criterion, DataSet, Module}
import com.intel.analytics.zoo.feature.common.{Preprocessing, _}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.adapter.{HasFeaturesCol, HasPredictionCol, SchemaUtils}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{DLEstimatorBase, DLTransformerBase, DefaultParamsWriterWrapper}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
import org.json4s.JsonDSL._
import org.json4s.{DefaultFormats, JObject}
import scala.reflect.ClassTag
private[nnframes] trait HasBatchSize extends Params {
final val batchSize: IntParam = new IntParam(this, "batchSize", "batch size for training and prediction")
def getBatchSize: Int = $(batchSize)
}
private[nnframes] trait TrainingParams[@specialized(Float, Double) T] extends Params {
/**
* When to stop the training, passed in a [[Trigger]]. E.g. Trigger.maxIterations
*/
final val endWhen = new Param[Trigger](this, "endWhen", "Trigger to stop the training")
def getEndWhen: Trigger = $(endWhen)
/**
* learning rate for the optimizer in the NNEstimator.
* Default: 0.001
*/
final val learningRate = new DoubleParam(
this, "learningRate", "learning rate for the optimizer", ParamValidators.gt(0))
def getLearningRate: Double = $(learningRate)
/**
* learning rate decay for each iteration.
* Default: 0
*/
final val learningRateDecay = new DoubleParam(this, "learningRateDecay", "learning rate decay for each iteration")
def getLearningRateDecay: Double = $(learningRateDecay)
/**
* Maximum number of epochs for training; an epoch is one traversal of the training data.
* Default: 50
*/
final val maxEpoch = new IntParam(this, "maxEpoch", "maximum number of epochs", ParamValidators.gt(0))
def getMaxEpoch: Int = $(maxEpoch)
/**
* optimization method to be used. BigDL supports many optimization methods like Adam,
* SGD and LBFGS. Refer to package com.intel.analytics.bigdl.optim for all the options.
* Default: SGD
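*
* A minimal sketch (assumes a hypothetical `estimator: NNEstimator[Float]`; Adam comes from
* com.intel.analytics.bigdl.optim):
* {{{
*   estimator.setOptimMethod(new Adam[Float]())
* }}}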
*/
final val optimMethod = new Param[OptimMethod[T]](this, "optimMethod", "optimization method")
def getOptimMethod: OptimMethod[T] = $(optimMethod)
/**
* Constant gradient clipping thresholds.
*/
final val constantGradientClippingParams = new Param[(Float, Float)](this,
"constantGradientClippingParams", "min and max thresholds for constant gradient clipping")
/**
* L2 norm gradient clipping threshold.
*/
final val l2GradientClippingParams = new FloatParam(this,
"l2GradientClippingParams", "threshold for L2 norm gradient clipping")
/**
* whether to cache the Samples after preprocessing.
* Default: true
*/
final val cachingSample = new BooleanParam(
this, "cachingSample", "whether to cache the Samples after preprocessing")
def isCachingSample: Boolean = $(cachingSample)
}
/**
* Common trait for NNEstimator and NNModel
*/
private[nnframes] trait NNParams[@specialized(Float, Double) T] extends HasFeaturesCol
with HasPredictionCol with HasBatchSize {
final val samplePreprocessing = new Param[Preprocessing[Any, Sample[T]]](this,
"samplePreprocessing", "Preprocessing that converts the raw feature and label data to Sample[T]")
def getSamplePreprocessing: Preprocessing[Any, Sample[T]] = $(samplePreprocessing)
setDefault(batchSize -> 1)
}
/**
* [[NNEstimator]] extends [[org.apache.spark.ml.Estimator]] and supports training a BigDL
* model with Spark DataFrame data. It can be integrated into a standard Spark ML Pipeline
* to allow users to combine the components of BigDL and Spark MLlib.
*
* [[NNEstimator]] supports different feature and label data types through [[Preprocessing]]. We
* provide pre-defined [[Preprocessing]] for popular data types like Array or Vector in package
* [[com.intel.analytics.zoo.feature]], while users can also develop customized [[Preprocessing]].
* During fit, NNEstimator will extract feature and label data from input DataFrame and use
* the [[Preprocessing]] to prepare data for the model. Using the [[Preprocessing]] allows
* [[NNEstimator]] to cache only the raw data and decrease the memory consumption during feature
* conversion and training.
* More concrete examples are available in package [[com.intel.analytics.zoo.examples.nnframes]].
*
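* A minimal training sketch (illustrative only; assumes a hypothetical DataFrame `trainDF` with a
* "features" column of Array[Float] of length 2 and a numeric "label" column):
* {{{
*   import com.intel.analytics.bigdl.nn.{Linear, MSECriterion, Sequential}
*
*   val model = Sequential[Float]().add(Linear[Float](2, 1))
*   val criterion = MSECriterion[Float]()
*   val estimator = NNEstimator(model, criterion, Array(2), Array(1))
*     .setBatchSize(4)
*     .setMaxEpoch(10)
*   val nnModel = estimator.fit(trainDF) // returns an NNModel[Float]
* }}}
*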
* @param model BigDL module to be optimized
* @param criterion BigDL criterion
* @tparam T data type of BigDL Model
*/
class NNEstimator[T: ClassTag] private[zoo] (
@transient val model: Module[T],
val criterion : Criterion[T],
override val uid: String = Identifiable.randomUID("nnestimator")
)(implicit ev: TensorNumeric[T])
extends DLEstimatorBase[NNEstimator[T], NNModel[T]] with NNParams[T]
with TrainingParams[T] {
def setSamplePreprocessing[FF <: Any, LL <: Any](
value: Preprocessing[(FF, Option[LL]), Sample[T]]): this.type =
set(samplePreprocessing, value.asInstanceOf[Preprocessing[Any, Sample[T]]])
def setFeaturesCol(featuresColName: String): this.type = set(featuresCol, featuresColName)
def setLabelCol(labelColName : String) : this.type = set(labelCol, labelColName)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
def setBatchSize(value: Int): this.type = set(batchSize, value)
def setEndWhen(trigger: Trigger): this.type = set(endWhen, trigger)
def setLearningRate(value: Double): this.type = set(learningRate, value)
setDefault(learningRate -> 1e-3)
def setLearningRateDecay(value: Double): this.type = set(learningRateDecay, value)
setDefault(learningRateDecay -> 0.0)
def setMaxEpoch(value: Int): this.type = set(maxEpoch, value)
setDefault(maxEpoch -> 50)
def setOptimMethod(value: OptimMethod[T]): this.type = set(optimMethod, value)
set(optimMethod, new SGD[T])
def setConstantGradientClipping(min: Float, max: Float): this.type = {
set(constantGradientClippingParams, (min, max))
}
def setGradientClippingByL2Norm(clipNorm: Float): this.type = {
set(l2GradientClippingParams, clipNorm)
}
def setCachingSample(value: Boolean): this.type = {
set(cachingSample, value)
}
setDefault(cachingSample, true)
/**
* Clear the gradient clipping params; in this case, clipping will not be applied.
*/
def clearGradientClipping(): this.type = {
clear(l2GradientClippingParams)
clear(constantGradientClippingParams)
}
@transient private var trainSummary: Option[TrainSummary] = None
def getTrainSummary: Option[TrainSummary] = trainSummary
/**
* Statistics (LearningRate, Loss, Throughput, Parameters) collected during training on the
* training data, which can be visualized via TensorBoard.
* Use setTrainSummary to enable the training logger. The logs will be saved to
* logDir/appName/train as specified by the parameters of TrainSummary.
*
* Default: Not enabled
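*
* A sketch, assuming a hypothetical log directory and app name (arguments are (logDir, appName)):
* {{{
*   estimator.setTrainSummary(TrainSummary("/tmp/nnframes-logs", "myTrainingApp"))
* }}}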
*/
def setTrainSummary(value: TrainSummary): this.type = {
this.trainSummary = Some(value)
this
}
@transient private var validationSummary: Option[ValidationSummary] = None
/**
* Statistics (LearningRate, Loss, Throughput, Parameters) collected during training on the
* validation data, if validation data is set, which can be visualized via
* TensorBoard. Use setValidationSummary to enable the validation logger. The logs will be
* saved to logDir/appName/ as specified by the parameters of validationSummary.
*
* Default: None
*/
def getValidationSummary: Option[ValidationSummary] = validationSummary
/**
* Enable the validation summary.
*/
def setValidationSummary(value: ValidationSummary): this.type = {
this.validationSummary = Some(value)
this
}
@transient protected var validationTrigger: Option[Trigger] = None
@transient protected var validationDF: DataFrame = _
@transient protected var validationMethods: Array[ValidationMethod[T]] = _
@transient protected var validationBatchSize: Int = 0
/**
* Set up validation during training.
*
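* A sketch, assuming a hypothetical validation DataFrame `valDF` with the same schema as the
* training data:
* {{{
*   import com.intel.analytics.bigdl.optim.{Top1Accuracy, Trigger}
*   estimator.setValidation(Trigger.everyEpoch, valDF, Array(new Top1Accuracy[Float]()), 32)
* }}}
*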
* @param trigger how often to evaluate the validation set
* @param validationDF validation data set
* @param vMethods a set of validation methods [[ValidationMethod]]
* @param batchSize batch size for validation
* @return this estimator
*/
def setValidation(trigger: Trigger, validationDF: DataFrame,
vMethods : Array[ValidationMethod[T]], batchSize: Int)
: this.type = {
this.validationTrigger = Some(trigger)
this.validationDF = validationDF
this.validationMethods = vMethods
this.validationBatchSize = batchSize
this
}
/**
* get the validate configuration during training
*
* @return an Option of Tuple(ValidationTrigger, Validation data, Array[ValidationMethod[T] ],
* batch size)
*/
def getValidation: Option[(Trigger, DataFrame, Array[ValidationMethod[T]], Int)] = {
if (validationTrigger.isDefined) {
Some(validationTrigger.get, validationDF, validationMethods, validationBatchSize)
}
else {
None
}
}
protected def validateParams(schema : StructType): Unit = {
if (isSet(endWhen) && isSet(maxEpoch)) {
throw new IllegalArgumentException(s"endWhen and maxEpoch cannot both be set")
}
if (validationTrigger.isEmpty && validationSummary.isDefined) {
throw new IllegalArgumentException(
s"validationSummary is only valid if validation data is set.")
}
}
override def transformSchema(schema : StructType): StructType = {
validateParams(schema)
ev.getType() match {
case TensorDouble =>
SchemaUtils.appendColumn(schema, $(predictionCol), ArrayType(DoubleType, false))
case TensorFloat =>
SchemaUtils.appendColumn(schema, $(predictionCol), ArrayType(FloatType, false))
case _ => throw new Exception("Only support Double and Float for now")
}
}
private def getDataSet(
dataFrame: DataFrame,
batchSize: Int): DataSet[MiniBatch[T]] = {
val sp = $(samplePreprocessing).asInstanceOf[Preprocessing[(Any, Option[Any]), Sample[T]]]
val featureColIndex = dataFrame.schema.fieldIndex($(featuresCol))
val labelColIndex = if (dataFrame.columns.contains($(labelCol))) {
Some(dataFrame.schema.fieldIndex($(labelCol)))
} else {
None
}
val featureAndLabel = dataFrame.rdd.map { row =>
val features = row.get(featureColIndex)
val labels = labelColIndex match {
case Some(i) => Some(row.get(i))
case None => None
}
(features, labels)
}
val initialDataSet = if ($(cachingSample)) {
DataSet.rdd(sp.apply(featureAndLabel))
} else {
DataSet.rdd(featureAndLabel).transform(sp)
}
initialDataSet.transform(SampleToMiniBatch[T](batchSize))
}
protected override def internalFit(dataFrame: DataFrame): NNModel[T] = {
val trainingDataSet = getDataSet(dataFrame, $(batchSize))
val state = T("learningRate" -> $(learningRate), "learningRateDecay" -> $(learningRateDecay))
val endTrigger = if (isSet(endWhen)) $(endWhen) else Trigger.maxEpoch($(maxEpoch))
val optimizer = Optimizer(model, trainingDataSet, criterion)
.setState(state)
.setOptimMethod($(optimMethod))
.setEndWhen(endTrigger)
if (isSet(l2GradientClippingParams)) {
optimizer.setGradientClippingByl2Norm($(l2GradientClippingParams))
}
if (isSet(constantGradientClippingParams)) {
val constantClippingValues = $(constantGradientClippingParams)
optimizer.setConstantGradientClipping(constantClippingValues._1, constantClippingValues._2)
}
if (validationTrigger.isDefined) {
val validationSamples = getDataSet(validationDF, validationBatchSize)
optimizer.setValidation(
validationTrigger.get,
validationSamples,
validationMethods)
if (this.validationSummary.isDefined) {
optimizer.setValidationSummary(this.validationSummary.get)
}
}
if (this.trainSummary.isDefined) {
optimizer.setTrainSummary(this.trainSummary.get)
}
val optimizedModel = optimizer.optimize()
wrapBigDLModel(optimizedModel)
}
/**
* Subclasses can override this method to return the required model for different transform tasks.
*/
protected def wrapBigDLModel(m: Module[T]): NNModel[T] = {
val dlModel = new NNModel[T](m)
copyValues(dlModel.setParent(this))
val clonedTransformer = ToTuple() -> $(samplePreprocessing)
.asInstanceOf[Preprocessing[(Any, Option[Any]), Sample[T]]].clonePreprocessing()
dlModel.setSamplePreprocessing(clonedTransformer)
}
/**
* Return a deep copy of this NNEstimator.
* Note that trainSummary and validationSummary will not be copied to the new instance since
* currently they are not thread-safe.
*/
override def copy(extra: ParamMap): NNEstimator[T] = {
val copied = copyValues(
new NNEstimator[T](
model.cloneModule(),
criterion.cloneCriterion(),
this.uid
), extra)
if (this.validationTrigger.isDefined) {
copied.setValidation(
validationTrigger.get, validationDF, validationMethods.clone(), validationBatchSize)
}
copied
}
}
object NNEstimator {
/**
* Construct a [[NNEstimator]] with the default Preprocessing: a SeqToTensor for both the feature
* and label columns.
*
* @param model BigDL module to be optimized
* @param criterion BigDL criterion method
*/
def apply[T: ClassTag](
model: Module[T],
criterion: Criterion[T]
)(implicit ev: TensorNumeric[T]): NNEstimator[T] = {
new NNEstimator(model, criterion)
.setSamplePreprocessing(FeatureLabelPreprocessing(SeqToTensor(), SeqToTensor()))
}
/**
* Construct a [[NNEstimator]] with a feature size and label size. The constructor is useful
* when the feature column and label column contain the following data types:
* Float, Double, Int, Array[Float], Array[Double], Array[Int] and MLlib Vector. The feature and
* label data are converted to Tensors with the specified sizes before being sent to the model.
*
* @param model BigDL module to be optimized
* @param criterion BigDL criterion method
* @param featureSize The size (Tensor dimensions) of the feature data. E.g. for an image with
* width * height = 28 * 28, featureSize = Array(28, 28).
* @param labelSize The size (Tensor dimensions) of the label data.
*/
def apply[T: ClassTag](
model: Module[T],
criterion: Criterion[T],
featureSize : Array[Int],
labelSize : Array[Int]
)(implicit ev: TensorNumeric[T]): NNEstimator[T] = {
new NNEstimator(model, criterion)
.setSamplePreprocessing(FeatureLabelPreprocessing(
SeqToTensor(featureSize), SeqToTensor(labelSize))
)
}
/**
* Construct a [[NNEstimator]] with a feature Preprocessing and label Preprocessing.
*
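* A sketch using the SeqToTensor Preprocessing shipped with Analytics Zoo (`model` and
* `criterion` are hypothetical):
* {{{
*   val estimator = NNEstimator(model, criterion,
*     SeqToTensor(Array(28, 28)), SeqToTensor(Array(1)))
* }}}
*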
* @param model BigDL module to be optimized
* @param criterion BigDL criterion method
* @param featurePreprocessing Preprocessing[F, Tensor[T] ]
* @param labelPreprocessing Preprocessing[L, Tensor[T] ]
*/
def apply[F, L, T: ClassTag](
model: Module[T],
criterion: Criterion[T],
featurePreprocessing: Preprocessing[F, Tensor[T]],
labelPreprocessing: Preprocessing[L, Tensor[T]]
)(implicit ev: TensorNumeric[T]): NNEstimator[T] = {
new NNEstimator(model, criterion)
.setSamplePreprocessing(FeatureLabelPreprocessing(featurePreprocessing, labelPreprocessing))
}
/**
* Construct a [[NNEstimator]] with a featurePreprocessing only. The constructor is useful
* when both feature and label are derived from the same column of the original DataFrame.
*
* @param model BigDL module to be optimized
* @param criterion BigDL criterion method
* @param featurePreprocessing A [[Preprocessing]] that transforms the feature data to a
* Sample[T].
*/
def apply[F, T: ClassTag](
model: Module[T],
criterion: Criterion[T],
featurePreprocessing: Preprocessing[F, Sample[T]]
)(implicit ev: TensorNumeric[T]): NNEstimator[T] = {
new NNEstimator(model, criterion)
.setSamplePreprocessing(TupleToFeatureAdapter(featurePreprocessing))
}
}
/**
* [[NNModel]] extends Spark ML Transformer and supports BigDL model with Spark DataFrame data.
*
* [[NNModel]] supports different feature data types through [[Preprocessing]]. We
* provide pre-defined [[Preprocessing]] for popular data types like Array or Vector in package
* [[com.intel.analytics.zoo.feature]], while users can also develop
* customized [[Preprocessing]].
* During transform, [[NNModel]] will extract feature data from input DataFrame and use
* the [[Preprocessing]] to prepare data for the model.
*
* After transform, the prediction column contains the output of the model as Array[T], where
* T (Double or Float) is decided by the model type.
*
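* A transform sketch (assumes a hypothetical test DataFrame `testDF` with a "features" column and
* a trained `nnModel: NNModel[Float]`):
* {{{
*   val predicted = nnModel.setBatchSize(32).transform(testDF)
*   predicted.select("features", "prediction").show(5)
* }}}
*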
* @param model trained BigDL model to use in prediction.
*/
class NNModel[T: ClassTag] private[zoo] (
@transient val model: Module[T],
override val uid: String = "DLModel")(implicit ev: TensorNumeric[T])
extends DLTransformerBase[NNModel[T]] with NNParams[T]
with HasBatchSize with MLWritable {
def setFeaturesCol(featuresColName: String): this.type = set(featuresCol, featuresColName)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
def setBatchSize(value: Int): this.type = set(batchSize, value)
/**
* Set the Preprocessing.
* @param value A [[Preprocessing]] that transforms the feature data to a Sample[T].
*/
def setSamplePreprocessing[FF <: Any](value: Preprocessing[FF, Sample[T]]): this.type =
set(samplePreprocessing, value.asInstanceOf[Preprocessing[Any, Sample[T]]])
/**
* Perform a prediction on featureCol, and write result to the predictionCol.
*/
protected override def internalTransform(dataFrame: DataFrame): DataFrame = {
val featureColIndex = dataFrame.schema.fieldIndex($(featuresCol))
val sc = dataFrame.sqlContext.sparkContext
val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model.evaluate())
val localBatchSize = $(batchSize)
val featureTransformersBC = sc.broadcast($(samplePreprocessing))
val toBatchBC = sc.broadcast(SampleToMiniBatch[T](localBatchSize))
// concatenate the prediction and the other columns of the DataFrame, avoiding a zip between RDDs
val resultRDD = dataFrame.rdd.mapPartitions { rowIter =>
val localModel = modelBroadCast.value()
val featureSteps = featureTransformersBC.value.cloneTransformer()
val toBatch = toBatchBC.value.cloneTransformer()
rowIter.grouped(localBatchSize).flatMap { rowBatch =>
val featureSeq = rowBatch.map(r => r.get(featureColIndex))
val samples = featureSteps(featureSeq.iterator)
val predictions = toBatch(samples).flatMap { batch =>
val batchResult = localModel.forward(batch.getInput()).toTensor.squeeze()
if (batchResult.size().length == 2) {
batchResult.split(1).map(outputToPrediction)
} else if (batchResult.size().length == 1) {
Array(outputToPrediction(batchResult))
} else {
throw new RuntimeException(
"unexpected batchResult dimension: " + batchResult.size().mkString(", "))
}
}
rowBatch.toIterator.zip(predictions).map { case (row, predict) =>
Row.fromSeq(row.toSeq ++ Seq(predict))
}
}
}
val resultSchema = transformSchema(dataFrame.schema)
dataFrame.sqlContext.createDataFrame(resultRDD, resultSchema)
}
protected def outputToPrediction(output: Tensor[T]): Any = {
output.clone().storage().array()
}
override def transformSchema(schema : StructType): StructType = {
ev.getType() match {
case TensorDouble =>
SchemaUtils.appendColumn(schema, $(predictionCol), ArrayType(DoubleType, false))
case TensorFloat =>
SchemaUtils.appendColumn(schema, $(predictionCol), ArrayType(FloatType, false))
case _ => throw new Exception("Only support Double and Float for now")
}
}
override def copy(extra: ParamMap): NNModel[T] = {
val copied = new NNModel[T](model.cloneModule(), uid).setParent(parent)
copyValues(copied, extra)
}
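/**
* Save this NNModel through the standard Spark ML persistence API. A sketch, assuming `path`
* points to a writable location:
* {{{
*   nnModel.write.overwrite().save(path)
*   val restored = NNModel.load(path)
* }}}
*/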
override def write: MLWriter = new NNModel.NNModelWriter[T](this)
}
object NNModel extends MLReadable[NNModel[_]] {
/**
* Construct a [[NNModel]] with the default Preprocessing: SeqToTensor.
*
* @param model BigDL module to be used for prediction
*/
def apply[T: ClassTag](
model: Module[T]
)(implicit ev: TensorNumeric[T]): NNModel[T] = {
new NNModel(model)
.setSamplePreprocessing(SeqToTensor() -> TensorToSample())
}
/**
* Construct a [[NNModel]] with a feature size.
*
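* A sketch (assumes a hypothetical trained `module: Module[Float]`):
* {{{
*   val nnModel = NNModel(module, Array(28, 28))
* }}}
*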
* @param model BigDL module to be used for prediction
* @param featureSize The size (Tensor dimensions) of the feature data. E.g. for an image with
* width * height = 28 * 28, featureSize = Array(28, 28).
*/
def apply[T: ClassTag](
model: Module[T],
featureSize: Array[Int]
)(implicit ev: TensorNumeric[T]): NNModel[T] = {
new NNModel(model)
.setSamplePreprocessing(SeqToTensor(featureSize) -> TensorToSample())
}
/**
* Construct a [[NNModel]] with a feature Preprocessing.
*
* @param model BigDL module to be used for prediction
* @param featurePreprocessing Preprocessing[F, Tensor[T] ].
*/
def apply[F, T: ClassTag](
model: Module[T],
featurePreprocessing: Preprocessing[F, Tensor[T]]
)(implicit ev: TensorNumeric[T]): NNModel[T] = {
new NNModel(model).setSamplePreprocessing(featurePreprocessing -> TensorToSample())
}
import scala.language.existentials
implicit val format: DefaultFormats.type = DefaultFormats
private[nnframes] class NNModelReader() extends MLReader[NNModel[_]] {
override def load(path: String): NNModel[_] = {
val (meta, model, typeTag, feaTran) = NNModel.getMetaAndModel(path, sc)
val featureSize = (meta.metadata \ "featureSize").extract[Seq[Int]].toArray
val nnModel = typeTag match {
case "TensorDouble" =>
new NNModel[Double](model.asInstanceOf[Module[Double]])
.setSamplePreprocessing(feaTran.asInstanceOf[Preprocessing[Any, Sample[Double]]])
case "TensorFloat" =>
new NNModel[Float](model.asInstanceOf[Module[Float]])
.setSamplePreprocessing(feaTran.asInstanceOf[Preprocessing[Any, Sample[Float]]])
case _ =>
throw new Exception("Only support float and double for now")
}
DefaultParamsWriterWrapper.getAndSetParams(nnModel, meta)
nnModel
}
}
private[nnframes] def getMetaAndModel(path: String, sc: SparkContext) = {
val meta = DefaultParamsWriterWrapper.loadMetadata(path, sc)
val (modulePath, weightPath) =
new Path(path, "module").toString -> new Path(path, "weight").toString
val typeTag = (meta.metadata \ "tensorDataType").extract[String]
val model = typeTag match {
case "TensorDouble" =>
ModuleLoader.loadFromFile[Double](modulePath, weightPath)
case "TensorFloat" =>
ModuleLoader.loadFromFile[Float](modulePath, weightPath)
case _ =>
throw new Exception("Only support float and double for now")
}
val ois = new ObjectInputStream(
new FileInputStream(new Path(path, "samplePreprocessing").toString))
val featurePreprocessing = try {
ois.readObject.asInstanceOf[Preprocessing[Any, Any]]
} finally {
ois.close()
}
(meta, model, typeTag, featurePreprocessing)
}
class NNModelWriter[@specialized(Float, Double) T: ClassTag](
instance: NNModel[T])(implicit ev: TensorNumeric[T]) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
NNModel.saveImpl[T](instance, instance.model,
path, sc, shouldOverwrite)
}
}
/**
* Helper method for saving a NNModel to disk.
* For compatibility with Spark ML pipelines, the TensorDataType is stored separately in extraMetadata.
*
* @tparam T TensorDataType
* @param instance NNModel
* @param path Path to which to save the NNModel.
* @param extraMetadata Metadata such as featureSize.
*/
private[nnframes] def saveImpl[@specialized(Float, Double) T: ClassTag](
instance: NNModel[T],
module: Module[T],
path: String,
sc: SparkContext,
isOverWrite: Boolean = false,
extraMetadata: Option[JObject] = None)(implicit ev: TensorNumeric[T]): Unit = {
val tensorDataType = ev.getType() match {
case TensorDouble => "TensorDouble"
case TensorFloat => "TensorFloat"
case _ => throw new Exception("Only support Double and Float for now")
}
val extra = extraMetadata.getOrElse(JObject()) ~ ("tensorDataType" -> tensorDataType)
// bypass the default save for samplePreprocessing
val spCache = instance.getSamplePreprocessing
instance.clear(instance.samplePreprocessing)
DefaultParamsWriterWrapper.saveMetadata(instance, path, sc, Option(extra))
instance.setSamplePreprocessing(spCache)
val (modulePath, weightPath) =
new Path(path, "module").toString -> new Path(path, "weight").toString
module.saveModule(modulePath, weightPath, isOverWrite)
val fos = new FileOutputStream(new Path(path, "samplePreprocessing").toString)
val oos = new ObjectOutputStream(fos)
try {
oos.writeObject(spCache)
} finally {
oos.close()
}
}
override def read: MLReader[NNModel[_]] = new NNModelReader
override def load(path: String): NNModel[_] = read.load(path)
}