com.intel.analytics.bigdl.ppml.fl.nn.NNServiceImpl.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.ppml.fl.nn
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
import java.util
import java.util.Map
import com.intel.analytics.bigdl.dllib.nn.{BCECriterion, MSECriterion, Sigmoid, View}
import com.intel.analytics.bigdl.dllib.optim.Top1Accuracy
import com.intel.analytics.bigdl.ppml.fl.base.DataHolder
import com.intel.analytics.bigdl.ppml.fl.common.FLPhase._
import com.intel.analytics.bigdl.ppml.fl.generated.NNServiceGrpc
import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto._
import com.intel.analytics.bigdl.ppml.fl.generated.NNServiceProto._
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.logging.log4j.LogManager
import java.util.concurrent.ConcurrentHashMap
import collection.JavaConverters._
/**
 * gRPC service that routes NN train / evaluate / predict requests to the
 * [[NNAggregator]] registered under the algorithm name carried by the request.
 *
 * Every handler sends exactly one response followed by `onCompleted()`:
 *  - code 0 when the aggregator holds no server data for the phase,
 *  - code 1 together with the server data otherwise,
 *  - code 1 with the exception stack trace as the response text on failure.
 *
 * NOTE(review): failures are reported with the same code (1) as successes; this is
 * preserved from the existing behavior - confirm against the client contract.
 *
 * @param clientNum value forwarded to every registered aggregator via `setClientNum`
 */
class NNServiceImpl(clientNum: Int) extends NNServiceGrpc.NNServiceImplBase {
  private val logger = LogManager.getLogger(getClass)

  // Populated once at construction and only read afterwards.
  private val aggregatorMap: Map[String, NNAggregator] = initAggregatorMap()

  /**
   * Builds the algorithm-name -> aggregator registry and applies `clientNum`
   * to every registered aggregator.
   */
  private def initAggregatorMap(): Map[String, NNAggregator] = {
    val aggregators = new util.HashMap[String, NNAggregator]
    aggregators.put("vfl_logistic_regression", VFLNNAggregator(1, Sigmoid[Float](),
      null, BCECriterion[Float](), Array(new Top1Accuracy())))
    aggregators.put("vfl_linear_regression", VFLNNAggregator(1, View[Float](),
      null, MSECriterion[Float](), Array(new Top1Accuracy())))
    aggregators.put("hfl_logistic_regression", new HFLNNAggregator())
    aggregators.values.asScala.foreach(_.setClientNum(clientNum))
    aggregators
  }

  /**
   * Looks up the aggregator registered for `algorithm`.
   *
   * @throws IllegalArgumentException if no aggregator is registered under that name,
   *                                  instead of letting callers hit a null dereference
   */
  private def aggregatorFor(algorithm: String): NNAggregator = {
    val aggregator = aggregatorMap.get(algorithm)
    if (aggregator == null) {
      val supported = aggregatorMap.keySet.asScala.toSeq.sorted.mkString(", ")
      throw new IllegalArgumentException(
        s"Unsupported algorithm: $algorithm, supported algorithms: $supported")
    }
    aggregator
  }

  /**
   * Renders the stack trace of `e`, logs it on the server under `context`,
   * and returns the rendered trace so it can be sent back to the client.
   */
  private def failureMessage(context: String, e: Exception): String = {
    val errorMsg = ExceptionUtils.getStackTrace(e)
    logger.error(s"$context failed: $errorMsg")
    errorMsg
  }

  /**
   * Stores the client's training data in the aggregator selected by the request's
   * algorithm and replies with the aggregator's current TRAIN server data.
   */
  override def train(request: TrainRequest,
                     responseObserver: StreamObserver[TrainResponse]): Unit = {
    val clientUUID = request.getClientuuid
    logger.debug("Server get train request from client: " + clientUUID)
    val data = request.getData
    val version = data.getMetaData.getVersion
    try {
      // Resolved inside the try so an unknown algorithm is reported to the client.
      val aggregator = aggregatorFor(request.getAlgorithm)
      aggregator.putClientData(TRAIN, clientUUID, version, new DataHolder(data))
      logger.debug(s"$clientUUID getting server new data to update local")
      val responseData = aggregator.getStorage(TRAIN).serverData
      if (responseData == null) {
        val response = "Data requested doesn't exist"
        responseObserver.onNext(TrainResponse.newBuilder.setResponse(response).setCode(0).build)
      } else {
        val response = "Download data successfully"
        responseObserver.onNext(TrainResponse.newBuilder.setResponse(response)
          .setData(responseData).setCode(1).build)
      }
      responseObserver.onCompleted()
    } catch {
      case e: Exception =>
        val errorMsg = failureMessage(s"Train request from client $clientUUID", e)
        val response = TrainResponse.newBuilder.setResponse(errorMsg).setCode(1).build
        responseObserver.onNext(response)
        responseObserver.onCompleted()
    }
  }

  /**
   * Stores the client's evaluation batch in the aggregator selected by the request's
   * algorithm and replies with the aggregator's current EVAL server data. When the
   * request's `return` flag is set, the aggregator's return message is attached too.
   */
  override def evaluate(request: EvaluateRequest,
                        responseObserver: StreamObserver[EvaluateResponse]): Unit = {
    val clientUUID = request.getClientuuid
    val data = request.getData
    val version = data.getMetaData.getVersion
    val hasReturn = request.getReturn
    try {
      // Resolved inside the try so an unknown algorithm is reported to the client.
      val aggregator = aggregatorFor(request.getAlgorithm)
      // The flag is set before the data is stored, as in the original ordering.
      aggregator.setShouldReturn(hasReturn)
      aggregator.putClientData(EVAL, clientUUID, version, new DataHolder(data))
      val responseData = aggregator.getStorage(EVAL).serverData
      if (responseData == null) {
        val response = "Data requested doesn't exist"
        responseObserver.onNext(EvaluateResponse.newBuilder.setResponse(response).setCode(0).build)
      } else if (hasReturn) {
        val response = "Evaluate finishes"
        // TODO: return type need to be refactor
        val resultString = aggregator.getReturnMessage
        responseObserver.onNext(EvaluateResponse.newBuilder
          .setResponse(response)
          .setMessage(resultString)
          .setData(responseData).setCode(1).build)
      } else {
        val response = "Evaluate batch uploaded successfully, continue to next batch"
        responseObserver.onNext(EvaluateResponse.newBuilder
          .setResponse(response).setData(responseData).setCode(1).build)
      }
      responseObserver.onCompleted()
    } catch {
      case e: Exception =>
        val errorMsg = failureMessage(s"Evaluate request from client $clientUUID", e)
        val response = EvaluateResponse.newBuilder.setResponse(errorMsg).setCode(1).build
        responseObserver.onNext(response)
        responseObserver.onCompleted()
    }
  }

  /**
   * Stores the client's prediction data in the aggregator selected by the request's
   * algorithm and replies with the aggregator's current PREDICT server data.
   */
  override def predict(request: PredictRequest,
                       responseObserver: StreamObserver[PredictResponse]): Unit = {
    val clientUUID = request.getClientuuid
    val data = request.getData
    val version = data.getMetaData.getVersion
    try {
      // Resolved inside the try so an unknown algorithm is reported to the client.
      val aggregator = aggregatorFor(request.getAlgorithm)
      aggregator.putClientData(PREDICT, clientUUID, version, new DataHolder(data))
      val responseData = aggregator.getStorage(PREDICT).serverData
      if (responseData == null) {
        val response = "Data requested doesn't exist"
        responseObserver.onNext(PredictResponse.newBuilder.setResponse(response).setCode(0).build)
      } else {
        val response = "Download data successfully"
        responseObserver.onNext(PredictResponse.newBuilder.setResponse(response)
          .setData(responseData).setCode(1).build)
      }
      responseObserver.onCompleted()
    } catch {
      case e: Exception =>
        val errorMsg = failureMessage(s"Predict request from client $clientUUID", e)
        val response = PredictResponse.newBuilder.setResponse(errorMsg).setCode(1).build
        responseObserver.onNext(response)
        responseObserver.onCompleted()
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy