All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.ppml.fl.utils.ProtoUtils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.ppml.fl.utils

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.keras.models.InternalOptimizerUtil.getParametersFromModel
import com.intel.analytics.bigdl.dllib.utils.T
import com.intel.analytics.bigdl.ppml.fl.common.{FLPhase, Storage}
import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto._
import com.intel.analytics.bigdl.dllib.utils.{Table => DllibTable}
import com.intel.analytics.bigdl.ppml.fl.fgboost.common.Split
import com.intel.analytics.bigdl.ppml.fl.generated.FGBoostServiceProto.{BoostEval, DataSplit, PredictResponse, TreePredict}
import org.apache.logging.log4j.LogManager

import scala.reflect.ClassTag
import scala.util.Random
import scala.collection.JavaConverters._
import com.intel.analytics.bigdl.dllib.utils.Log4Error
import com.intel.analytics.bigdl.ppml.fl.FLClient
import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto

object ProtoUtils {
  private val logger = LogManager.getLogger(getClass)

  /**
   * Pack a model output activity (and optionally its target) into a [[TensorMap]] proto.
   *
   * @param output model output; converted to a Float tensor and stored under key "output"
   * @param target optional target activity; stored under key "target" when non-null
   * @param meta   optional metadata attached to the message when non-null
   * @return the built TensorMap proto
   */
  def outputTargetToTableProto(output: Activity,
                               target: Activity,
                               meta: MetaData = null): TensorMap = {
    val outputProto = toFloatTensor(output.toTensor[Float])

    val builder = TensorMap.newBuilder
      .putTensorMap("output", outputProto)
    if (meta != null) {
      builder.setMetaData(meta)
    }

    if (target != null) {
      builder.putTensorMap("target", toFloatTensor(target.toTensor[Float]))
    }
    builder.build()
  }

  /**
   * Split aggregated client data into (outputs, target).
   *
   * The "target" entry (if present) is copied into a fresh tensor; every other
   * entry is treated as an output. Exactly one output entry is supported today.
   *
   * @param storage per-client TensorMap protos collected by the server
   * @return the output table and the target tensor (empty tensor when no target)
   */
  def tableProtoToOutputTarget(storage: Storage[TensorMap]): (DllibTable, Tensor[Float]) = {
    val aggData = protoTableMapToTensorIterableMap(storage.clientData)
    val target = Tensor[Float]()
    if (aggData.contains("target")) {
      // Clients share the same target, so the first one is representative.
      val t = aggData("target").head
      target.resizeAs(t).copy(t)
    }
    // TODO: multiple input
    val outputs = aggData.filter(_._1 != "target")
    Log4Error.unKnowExceptionError(outputs.size == 1,
      s"outputs size should be 1, but got ${outputs.size}")

    (T.seq(outputs.values.head.toSeq), target)
  }

  /**
   * Convert a (clientId -> TensorMap proto) map into a
   * (tensor name -> tensors from every client) map, decoding each proto
   * FloatTensor back into a dllib [[Tensor]].
   */
  def protoTableMapToTensorIterableMap(inputMap: java.util.Map[Integer, FlBaseProto.TensorMap]):
    Map[String, Iterable[Tensor[Float]]] = {
    inputMap.asScala.mapValues(_.getTensorMapMap).values
      .flatMap(_.asScala).groupBy(_._1)
      .map { case (name, entries) =>
        val tensors = entries.map { case (_, proto) =>
          val values = proto.getTensorList.asScala.toArray.map(_.toFloat)
          val shape = proto.getShapeList.asScala.toArray.map(_.toInt)
          Tensor[Float](values, shape)
        }
        (name, tensors)
      }
  }

  /**
   * Build a FloatTensor proto from a flat float array and its shape.
   */
  def toFloatTensor(data: Array[Float], shape: Array[Int]): FloatTensor = {
    FloatTensor
      .newBuilder()
      .addAllTensor(data.map(float2Float).toIterable.asJava)
      .addAllShape(shape
        .map(int2Integer).toIterable.asJava)
      .build()
  }

  /**
   * Build a FloatTensor proto from a dllib tensor; only the tensor's own
   * elements are serialized, honoring the storage offset.
   */
  def toFloatTensor(t: Tensor[Float]): FloatTensor = {
    // Make the tensor contiguous first so a flat slice of its storage is valid.
    // NOTE: the offset of the *contiguous* tensor must be used here — when the
    // input is non-contiguous, contiguous() returns a fresh tensor whose
    // storage offset can differ from the original's.
    val contiguous = t.contiguous()
    val values = contiguous.storage().array()
      .slice(contiguous.storageOffset() - 1,
        contiguous.storageOffset() - 1 + contiguous.nElement())
    toFloatTensor(values, contiguous.size())
  }

  /** Build a 1-D FloatTensor proto from a flat float array. */
  def toFloatTensor(data: Array[Float]): FloatTensor = {
    toFloatTensor(data, Array(data.length))
  }

  /**
   * Serialize a model's flattened weights into a TensorMap proto under key
   * "weights", tagged with the given name and version.
   */
  def getModelWeightTable(model: Module[Float], version: Int, name: String = "test"): TensorMap = {
    val weights = getParametersFromModel(model)._1
    val metadata = MetaData.newBuilder
      .setName(name).setVersion(version).build
    val tensor =
      FloatTensor.newBuilder()
        .addAllTensor(weights.storage.toList.map(float2Float).asJava)
        .addAllShape(weights.size.toList.map(int2Integer).asJava)
        .build()
    TensorMap.newBuilder
      .putTensorMap("weights", tensor)
      .setMetaData(metadata)
      .build
  }


  /**
   * Overwrite a model's flattened weights with the "weights" tensor carried in
   * the proto. Fails with a descriptive error when the entry is missing
   * (instead of an opaque NullPointerException).
   */
  def updateModel(model: Module[Float],
                  modelData: TensorMap): Unit = {
    val weightBias = modelData.getTensorMapMap.get("weights")
    Log4Error.unKnowExceptionError(weightBias != null,
      "TensorMap does not contain a \"weights\" entry")
    val data = weightBias.getTensorList.asScala.map(Float2float).toArray
    val shape = weightBias.getShapeList.asScala.map(Integer2int).toArray
    getParametersFromModel(model)._1.copy(Tensor(data, shape))
  }

  /**
   * Decode the named entry of a TensorMap proto into a dllib tensor.
   * Fails with a descriptive error when the entry is missing.
   */
  def getTensor(name: String, modelData: TensorMap): Tensor[Float] = {
    val entry = modelData.getTensorMapMap.get(name)
    Log4Error.unKnowExceptionError(entry != null,
      s"TensorMap does not contain entry: $name")
    val data = entry.getTensorList.asScala.map(Float2float).toArray
    val shape = entry.getShapeList.asScala.map(Integer2int).toArray
    Tensor[Float](data, shape)
  }



  /**
   * Randomly partition `data` into `weight.length` splits whose sizes are
   * proportional to `weight`; the last split absorbs the rounding remainder so
   * the sizes always sum to `data.length`. Deterministic for a given seed.
   *
   * @param weight per-split fractions (expected to sum to ~1)
   * @param data   elements to distribute
   * @param seed   RNG seed for reproducibility
   */
  def randomSplit[T: ClassTag](weight: Array[Float],
                               data: Array[T],
                               seed: Int = 1): Array[Array[T]] = {
    val random = new Random(seed = seed)
    val lens = weight.map(v => (v * data.length).toInt)
    // Give the remainder to the last split so all elements are placed.
    lens(lens.length - 1) = data.length - lens.slice(0, lens.length - 1).sum
    val splits = lens.map(len => new Array[T](len))
    val counts = Array.fill(lens.length)(0)
    data.foreach { d =>
      // Pick a random split; linearly probe forward past splits already full.
      var idx = random.nextInt(weight.length)
      while (counts(idx) == lens(idx)) {
        idx = (idx + 1) % weight.length
      }
      splits(idx)(counts(idx)) = d
      counts(idx) += 1
    }
    splits
  }

  /**
   * Wrap per-record tree predictions into BoostEval protos; within each record
   * the tree predictions are sorted by tree ID for a deterministic order.
   */
  def toBoostEvals(localPredicts: Array[Map[String, Array[Boolean]]]): List[BoostEval] = {
    localPredicts.map { predict =>
      BoostEval.newBuilder()
        .addAllEvaluates(predict.toSeq.sortBy(_._1).map { case (treeId, preds) =>
          TreePredict.newBuilder().setTreeID(treeId)
            .addAllPredicts(preds.map(boolean2Boolean).toList.asJava)
            .build()
        }.toList.asJava)
        .build()
    }.toList
  }

  /** Extract the flat "predictResult" tensor values from a predict response. */
  def toArrayFloat(response: PredictResponse): Array[Float] = {
    response.getData.getTensorMapMap.get("predictResult")
      .getTensorList.asScala.toArray.map(_.toFloat)
  }

  /** Whether two floats are equal within a coarse tolerance of 0.1f. */
  def almostEqual(v1: Float, v2: Float): Boolean =
    math.abs(v1 - v2) <= 1e-1f
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy