All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.ppml.fl.python.PythonPPML.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.ppml.fl.python


import com.intel.analytics.bigdl.dllib.optim.ValidationResult
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.dllib.utils.python.api.{JTensor, PythonBigDL}
import com.intel.analytics.bigdl.ppml.fl.algorithms.{FGBoostRegression, PSI}
import com.intel.analytics.bigdl.ppml.fl.fgboost.FGBoostModel
import com.intel.analytics.bigdl.ppml.fl.utils.{FLClientClosable, TimingSupportive}

import java.util.{List => JList}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import com.intel.analytics.bigdl.dllib.utils.Log4Error
import com.intel.analytics.bigdl.ppml.fl.{FLClient, FLContext, FLServer, NNModel}

/**
 * Factory methods that build [[PythonPPML]] instances for a concrete numeric type.
 */
object PythonPPML {

  /** Builds a [[PythonPPML]] whose type parameter is `Float`. */
  def ofFloat(): PythonPPML[Float] = {
    new PythonPPML[Float]()
  }

  /** Builds a [[PythonPPML]] whose type parameter is `Double`. */
  def ofDouble(): PythonPPML[Double] = {
    new PythonPPML[Double]()
  }
}
/**
 * JVM-side facade over the PPML federated-learning classes (`FLContext`, `FLServer`,
 * `FLClient`, `PSI` and the FGBoost models). Most methods simply forward to the wrapped
 * object; several `create*` / `nn*` methods are empty placeholders.
 *
 * NOTE(review): presumably these are the entry points called from the Python bindings --
 * confirm against the Python side before renaming or re-typing anything here.
 *
 * @tparam T numeric type of this instance; a `TensorNumeric[T]` must be implicitly available
 */
class PythonPPML[T: ClassTag](implicit ev: TensorNumeric[T])
  extends PythonBigDL with TimingSupportive {

  /** Forwards to `FLContext.initFLContext` with the given id and target. */
  def initFLContext(id: Int, target: String): Unit = {
    FLContext.initFLContext(id, target)
  }

  /** Forwards to `FLContext.setPsiSalt`. */
  def setPsiSalt(psiSalt: String): Unit = {
    FLContext.setPsiSalt(psiSalt)
  }

  /** Creates a new `FLServer` with its default constructor. */
  def createFLServer(): FLServer = {
    new FLServer()
  }

  /** Placeholder: intentionally empty, nothing is created. */
  def createHflNN(): Unit = {}

  /** Placeholder: intentionally empty, nothing is created. */
  def createHflLogisticRegression(): Unit = {}

  /** Placeholder: intentionally empty, nothing is created. */
  def createHflLinearRegression(): Unit = {}

  /** Placeholder: intentionally empty, nothing is created. */
  def createVflLogisticRegression(): Unit = {}

  /** Placeholder: intentionally empty, nothing is created. */
  def createVflLinearRegression(): Unit = {}

  /**
   * Creates an `FGBoostRegression`.
   *
   * @param learningRate    learning rate; narrowed to `Float` before being passed on
   * @param maxDepth        passed through to `FGBoostRegression`
   * @param minChildSize    passed through to `FGBoostRegression`
   * @param serverModelPath passed through to `FGBoostRegression`
   * @return the newly constructed model
   */
  def createFGBoostRegression(learningRate: Double,
                              maxDepth: Int,
                              minChildSize: Int,
                              serverModelPath: String): FGBoostRegression = {
    new FGBoostRegression(learningRate.toFloat, maxDepth, minChildSize, serverModelPath)
  }

  /** Forwards to `fgBoost.loadServerModel(modelPath)`. */
  def fgBoostLoadServerModel(fgBoost: FGBoostModel, modelPath: String): Unit = {
    fgBoost.loadServerModel(modelPath)
  }

  /** Placeholder: intentionally empty, nothing is created. */
  def createFGBoostClassification(): Unit = {}

  /** Forwards to `flServer.build()`. */
  def flServerBuild(flServer: FLServer): Unit = {
    flServer.build()
  }

  /** Forwards to `flServer.start()`. */
  def flServerStart(flServer: FLServer): Unit = {
    flServer.start()
  }

  /** Forwards to `flServer.stop()`. */
  def flServerStop(flServer: FLServer): Unit = {
    flServer.stop()
  }

  /** Forwards to `flServer.setClientNum(clientNum)`. */
  def flServerSetClientNum(flServer: FLServer, clientNum: Int): Unit = {
    flServer.setClientNum(clientNum)
  }

  /** Forwards to `flServer.setPort(port)`. */
  def flServerSetPort(flServer: FLServer, port: Int): Unit = {
    flServer.setPort(port)
  }

  /** Forwards to `flServer.blockUntilShutdown()`. */
  def flServerBlockUntilShutdown(flServer: FLServer): Unit = {
    flServer.blockUntilShutdown()
  }

  /**
   * Creates an `FLClient`. The FLClient is not part of the user-facing API; the Python
   * binding for this method exists only for tests.
   *
   * @param target the FLClient target URL; when null, no target is set on the client
   * @return the newly created client
   */
  def createFLClient(target: String): FLClient = {
    val flClient = new FLClient()
    if (target != null) {
      flClient.setTarget(target)
    }
    flClient
  }

  /** Forwards to `flClientClosable.setFlClient(flClient)` and returns its result. */
  def flClientClosableSetFLClient(flClientClosable: FLClientClosable,
                                  flClient: FLClient): FLClientClosable = {
    flClientClosable.setFlClient(flClient)
  }

  /** Creates a new `PSI` with its default constructor. */
  def createPSI(): PSI = {
    new PSI()
  }

  /** Forwards to `psi.getSalt(secureCode)`; `secureCode` defaults to the empty string. */
  def psiGetSalt(psi: PSI, secureCode: String = ""): String = {
    psi.getSalt(secureCode)
  }

  /** Forwards to `psi.uploadSet(ids, salt)`. */
  def psiUploadSet(psi: PSI, ids: JList[String], salt: String): Unit = {
    psi.uploadSet(ids, salt)
  }

  /**
   * Forwards to `psi.downloadIntersection(maxtry, retry)`.
   *
   * NOTE(review): `maxtry` / `retry` are presumably the attempt limit and the wait between
   * attempts -- confirm against `PSI`.
   */
  def psiDownloadIntersection(psi: PSI,
                              maxtry: Int = 100, retry: Int = 3000): JList[String] = {
    psi.downloadIntersection(maxtry, retry)
  }

  /**
   * Forwards to `psi.getIntersection(ids, maxtry, retry)`.
   *
   * NOTE(review): `maxtry` / `retry` are presumably the attempt limit and the wait between
   * attempts -- confirm against `PSI`.
   */
  def psiGetIntersection(psi: PSI, ids: JList[String],
                         maxtry: Int = 100, retry: Int = 3000): JList[String] = {
    psi.getIntersection(ids, maxtry, retry)
  }

  /**
   * Splits a 2D `JTensor` into an array of 1D tensors: the flat storage is cut into
   * consecutive chunks of `shape(1)` elements (one chunk per row, assuming row-major storage).
   *
   * An invalid-operation error is reported through `Log4Error` when the input is not 2D or
   * when its second dimension is not positive.
   *
   * @param jTensor 2D input
   * @return one 1D tensor per chunk of `shape(1)` elements
   */
  def jTensorToTensorArray(jTensor: JTensor): Array[Tensor[Float]] = {
    Log4Error.invalidOperationError(jTensor.shape.length == 2,
      s"FGBoost only support 2D input, get dimension: ${jTensor.shape.length}")
    val featureNum = jTensor.shape(1)
    // grouped() rejects a non-positive size with an opaque IllegalArgumentException,
    // so fail early with a message that names the actual problem.
    Log4Error.invalidOperationError(featureNum > 0,
      s"FGBoost requires at least one feature per sample, get feature number: $featureNum")
    jTensor.storage.grouped(featureNum).map { row =>
      Tensor[Float](row, Array(row.length))
    }.toArray
  }

  /**
   * Converts `xTrain` with [[jTensorToTensorArray]] and passes the result to `model.fitAdd`,
   * timing both steps.
   *
   * @return whatever `model.fitAdd` returns
   */
  def fgBoostFitAdd(model: FGBoostModel, xTrain: JTensor): ArrayBuffer[Tensor[Float]] = {
    val tensorArray = timing("JVM JTensor to Array Tensor") {
      jTensorToTensorArray(xTrain)
    }
    timing("Add training batch to model") {
      model.fitAdd(tensorArray)
    }
  }

  /**
   * Calls `model.fitCall` with the label storage (null when `yTrain` is null) and the
   * requested number of boosting rounds, timing the call.
   */
  def fgBoostFitCall(model: FGBoostModel, yTrain: JTensor, boostRound: Int): Unit = {
    logger.info("start call fit")
    val labelArray = labelStorageOrNull(yTrain)
    timing("Call fit method") {
      model.fitCall(labelArray, boostRound)
    }
  }

  /**
   * Converts `feature` with [[jTensorToTensorArray]] and calls `model.fit`; a null `label`
   * is passed on as a null label array.
   */
  def fgBoostFit(model: FGBoostModel, feature: JTensor, label: JTensor, boostRound: Int): Unit = {
    val tensorArray = jTensorToTensorArray(feature)
    val labelArray = labelStorageOrNull(label)
    model.fit(tensorArray, labelArray, boostRound)
  }

  /**
   * Converts `feature` with [[jTensorToTensorArray]] and calls `model.evaluate`; a null
   * `label` is passed on as a null label array.
   *
   * @return the validation results produced by `model.evaluate`
   */
  def fgBoostEvaluate(model: FGBoostModel,
                      feature: JTensor, label: JTensor): Array[ValidationResult] = {
    val tensorArray = jTensorToTensorArray(feature)
    val labelArray = labelStorageOrNull(label)
    model.evaluate(tensorArray, labelArray)
  }

  /**
   * Runs `model.predict` on the converted features and packs the predictions into a single
   * float `JTensor` of shape `(number of predictions, length of the first prediction)`.
   *
   * When the model returns no predictions the result has shape `(0, 0)` and empty storage.
   */
  def fgBoostPredict(model: FGBoostModel, feature: JTensor): JTensor = {
    val tensorArray = jTensorToTensorArray(feature)
    val result = model.predict(tensorArray).map(_.storage().array())
    // result(0) would throw on an empty prediction array; fall back to a width of 0.
    val outputWidth = result.headOption.map(_.length).getOrElse(0)
    JTensor(result.flatten, Array(result.length, outputWidth), bigdlType = "float")
  }

  /** Forwards to `model.saveModel(dest)`. */
  def fgBoostRegressionSave(model: FGBoostRegression, dest: String): Unit = {
    model.saveModel(dest)
  }

  /** Forwards to `FGBoostRegression.loadModel(src)` and returns the loaded model. */
  def fgBoostRegressionLoad(src: String): FGBoostRegression = {
    FGBoostRegression.loadModel(src)
  }

  /** Placeholder: intentionally empty, the model is not touched. */
  def nnFit(model: NNModel): Unit = {}

  /** Placeholder: intentionally empty, the model is not touched. */
  def nnEvaluate(model: NNModel): Unit = {}

  /** Placeholder: intentionally empty, the model is not touched. */
  def nnPredict(model: NNModel): Unit = {}

  /**
   * Returns the flat storage of `label`, or null when `label` itself is null.
   * Null is kept (rather than Option) because the FGBoost model methods receive it as-is.
   */
  private def labelStorageOrNull(label: JTensor): Array[Float] = {
    if (label != null) label.storage else null
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy