// com.tencent.angel.sona.ml.common.AngelSaverLoader.scala

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */

package com.tencent.angel.sona.ml.common
import com.tencent.angel.client.AngelPSClient
import com.tencent.angel.mlcore.conf.SharedConf
import com.tencent.angel.mlcore.utils.JsonUtils
import com.tencent.angel.sona.core.DriverContext
import org.apache.hadoop.fs.{FileSystem, Path}
import com.tencent.angel.sona.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLReader, MLUtils, MLWriter}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row

import scala.reflect.ClassTag

object AngelSaverLoader {

  case class ModelData(sharedConfStr: String, angelModelName: String) {
    def this(conf: SharedConf, angelModelName: String) {
      this(conf.toString(), angelModelName)
    }

    lazy val conf: SharedConf = SharedConf.fromString(sharedConfStr)
  }
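
  // Illustrative round trip (a sketch, not part of the API): ModelData keeps
  // the SharedConf as a string so the pair can be written to a one-row Parquet
  // file, and `conf` lazily parses it back. Assuming SharedConf can be built
  // with a no-arg constructor (an assumption here, not confirmed by this file):
  //
  //   val data = new ModelData(new SharedConf, "my_angel_model")
  //   val restored: SharedConf = data.conf  // SharedConf.fromString(data.sharedConfStr)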

  private[angel] class AngelModelWriter[Model <: AngelSparkModel](instance: Model)
    extends MLWriter with Serializable with Logging {
    override protected def saveImpl(path: String): Unit = {
      // 1. Save metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)

      // 2. Save the Angel model under the "angel" subdirectory of `path`
      val angelModelPath = new Path(path, "angel")
      instance.angelModel.saveModel(DriverContext.get().sparkMasterContext,
        MLUtils.getHDFSPath(angelModelPath.toString))

      try {
        val fs = angelModelPath.getFileSystem(DriverContext.get().getAngelClient.getConf)
        val gjson = fs.create(new Path(angelModelPath, "graph.json"), true)
        val jsonStr = JsonUtils.toJsonConfStr(instance.sharedConf, instance.angelModel.graph)
        gjson.writeBytes(jsonStr)
        gjson.flush()
        gjson.close()
      } catch {
        case e: Exception => logWarning("failed to write graph.json", e)
      }

      // For Spark 2.1: saving other information was temporarily cancelled due
      // to the following error:
      //   ERROR ApplicationMaster: User class threw exception: java.util.ServiceConfigurationError:
      //   org.apache.spark.sql.sources.DataSourceRegister: Provider
      //   org.apache.spark.sql.execution.datasources.csv.CSVFileFormat could not be instantiated

      // 3. prepare other information to save
      val modelData = ModelData(instance.sharedConf.toString(), instance.angelModelName)

      // 4. save other info to parquet file
      sparkSession.createDataFrame(Seq(modelData)).repartition(1)
        .write.parquet(new Path(path, "data").toString)
    }
  }
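
  // Usage sketch (hypothetical `MyAngelModel`, an AngelSparkModel subclass
  // assumed only for this example). The MLWriter base class is expected to
  // expose save(path), which delegates to saveImpl above, so persisting a
  // trained model writes "<path>/metadata", "<path>/angel/..." and
  // "<path>/data" in one call:
  //
  //   val trained: MyAngelModel = ...  // a fitted AngelSparkModel
  //   new AngelModelWriter(trained).save("hdfs://ns1/models/lr")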

  private[angel] class AngelModelReader[Model <: AngelSparkModel : ClassTag](implicit psClient: AngelPSClient)
    extends MLReader[Model] with Serializable with Logging {
    private val clz = implicitly[ClassTag[Model]].runtimeClass

    override def load(path: String): Model = {
      // 1. load metadata
      val metadata = DefaultParamsReader.loadMetadata(path, sc, clz.getName)

      // 2. read data from parquet in "path/data"
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.format("parquet").load(dataPath)
      val Row(confString: String, modelName: String) = data.select("sharedConfStr", "angelModelName").head()

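      // Reflectively invoke the concrete model's (uid: String, modelName: String)
      // constructor; the concrete class is recovered from the ClassTag.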
      val cstr = clz.getConstructor(classOf[String], classOf[String])
      val model = cstr.newInstance(metadata.uid, modelName).asInstanceOf[Model]

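      // Restore the SharedConf, rebuild the network, allocate its matrices on
      // the parameter servers, and load the saved Angel model from "<path>/angel".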
      val sharedConf = SharedConf.fromString(confString)
      val angelModelPath = new Path(path, "angel").toString
      val sparkEnvContext = model.sparkEnvContext
      model.angelModel.updateConf(sharedConf)
      model.angelModel
        .buildNetwork()
        .createMatrices(sparkEnvContext)
        .loadModel(sparkEnvContext, MLUtils.getHDFSPath(angelModelPath), null)

      // 3. set Params from metadata
      metadata.getAndSetParams(model)
      model
    }
  }
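
  // Usage sketch for loading (same hypothetical `MyAngelModel` as above): the
  // reader needs an implicit AngelPSClient plus a ClassTag for the concrete
  // model type, so it can reflectively call the (uid, modelName) constructor:
  //
  //   implicit val psClient: AngelPSClient = DriverContext.get().getAngelClient
  //   val restored = new AngelModelReader[MyAngelModel]().load("hdfs://ns1/models/lr")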

}