com.tencent.angel.sona.ml.util.FeatureStats.scala

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */

package com.tencent.angel.sona.ml.util

import java.util

import com.tencent.angel.client.AngelPSClient
import com.tencent.angel.ml.matrix.MatrixContext
import com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult
import com.tencent.angel.mlcore.utils.MLException
import com.tencent.angel.ml.math2.storage.{IntDoubleSparseVectorStorage, LongDoubleSparseVectorStorage}
import com.tencent.angel.ml.math2.utils.RowType
import com.tencent.angel.ml.math2.vector.{IntDoubleVector, LongDoubleVector}
import com.tencent.angel.psagent.PSAgent
import com.tencent.angel.sona.core.ExecutorContext
import com.tencent.angel.sona.psf.StatsFunc
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap
import org.apache.spark.TaskContext
import org.apache.spark.linalg.{IntSparseVector, LongSparseVector}
import org.apache.spark.util.Example
import org.apache.spark.broadcast.Broadcast


/**
 * Collects feature statistics for a model through the Angel parameter server (PS).
 *
 * A single-row auxiliary matrix named "{uid}_stats" is created on the PS; executors accumulate
 * the non-zero feature values of their examples into it, and the driver can then query the PS
 * (via [[StatsFunc]]) for the number of valid features. The partition helpers also count the
 * number of examples per Spark partition.
 *
 * @param uid             identifier of the owning model, used to name the PS matrix
 * @param modelTypeString string form of the model's [[RowType]]
 * @param exeCtx          broadcast executor context used to obtain a [[PSAgent]] on executors
 */
class FeatureStats(val uid: String, val modelTypeString: String, val exeCtx: Broadcast[ExecutorContext]) extends Serializable {
  private val statsMatName = s"${uid}_stats"
  // Cached result of the PS query in getNumValidateFeatures; -1 means it has not been fetched yet.
  @transient private var numValidateFeatures: Long = -1

  // Parse the model type string into a RowType; only double-valued dense and sparse types are supported.
  private lazy val modelType: RowType = {
    modelTypeString match {
      case s if s.equalsIgnoreCase(RowType.T_DOUBLE_DENSE.toString) => RowType.T_DOUBLE_DENSE
      case s if s.equalsIgnoreCase(RowType.T_DOUBLE_SPARSE.toString) => RowType.T_DOUBLE_SPARSE
      case s if s.equalsIgnoreCase(RowType.T_DOUBLE_SPARSE_LONGKEY.toString) => RowType.T_DOUBLE_SPARSE_LONGKEY
      case _ => throw MLException(s"ModelType $modelTypeString is not supported!")
    }
  }

  /**
   * Creates the auxiliary 1 x numFeature statistics matrix on the parameter server. It must exist
   * before executors start calling [[partitionStatsWithPS]], which increments it.
   */
  def createPSMat(psClient: AngelPSClient, numFeature: Long): Unit = {
    // MatrixContext(name, rowNum, colNum, validIndexNum, maxRowNumInBlock, maxColNumInBlock, rowType);
    // -1 leaves validIndexNum and the block sizes unspecified.
    val matCtx: MatrixContext = new MatrixContext(statsMatName, 1, numFeature, -1,
      -1, -1, modelType)
    val list = new util.ArrayList[MatrixContext]()
    list.add(matCtx)
    psClient.createMatrices(list)
  }

  /**
   * Number of valid features as reported by [[StatsFunc]] over the stats matrix on the PS.
   * The result is fetched once and then cached in `numValidateFeatures`.
   */
  def getNumValidateFeatures(psAgent: PSAgent): Long = {
    if (numValidateFeatures == -1) {
      val statsMat = psAgent.getMatrixClient(statsMatName)
      val statsFunc = new StatsFunc(statsMat.getMatrixId)
      val getResult = statsMat.get(statsFunc)
      numValidateFeatures = getResult.asInstanceOf[ScalarAggrResult].getResult.toLong
    }

    numValidateFeatures
  }

  /**
   * Per-partition statistics with PS updates: accumulates the non-zero feature values of the
   * partition's examples into the stats matrix on the PS, flushing every `statsBatchSize`
   * examples, and returns a single-entry map from the partition id to its example count.
   */
  def partitionStatsWithPS(iter: Iterator[Example]): Iterator[util.HashMap[Int, Long]] = {
    val partitionId = TaskContext.getPartitionId()
    val partitionStat = new util.HashMap[Int, Long]()

    var count: Long = 0L
    val statsBatchSize = 10000
    // Batch accumulator; its concrete key type (int or long) is decided by the first example below.
    var hashMap: AnyRef = null

    // Consume the first example to choose the accumulator type, keeping its feature vector so the
    // final flush after the loop knows which vector type and dimension to build.
    val headSample = if (iter.hasNext) {
      val Example(_, _, features) = iter.next()
      count += 1

      features match {
        case feat: IntSparseVector =>
          val openHashMap = new Int2DoubleOpenHashMap(statsBatchSize)
          feat.foreachActive { case (idx: Long, value: Double) =>
            openHashMap.put(idx.toInt, value)
          }
          hashMap = openHashMap
        case feat: LongSparseVector =>
          val openHashMap = new Long2DoubleOpenHashMap(statsBatchSize)
          feat.foreachActive { case (idx: Long, value: Double) =>
            openHashMap.put(idx, value)
          }
          hashMap = openHashMap
        case _ => throw MLException("Dense Vector is not supported!")
      }

      features
    } else {
      partitionStat.put(partitionId, 0L)
      return Iterator.single[util.HashMap[Int, Long]](partitionStat)
    }

    // Executors reach the PS through a PSAgent obtained from the broadcast ExecutorContext.
    val psAgent = ExecutorContext.getPSAgent(exeCtx.value)
    val statsMat = psAgent.getMatrixClient(statsMatName)

    // Accumulate the remaining examples, pushing a batch to the PS every statsBatchSize examples.
    while (iter.hasNext) {
      val Example(_, _, features) = iter.next()
      features match {
        case feat: IntSparseVector =>
          val openHashMap = hashMap.asInstanceOf[Int2DoubleOpenHashMap]
          feat.foreachActive { case (idx: Long, value: Double) =>
            openHashMap.put(idx.toInt, value)
          }
        case feat: LongSparseVector =>
          val openHashMap = hashMap.asInstanceOf[Long2DoubleOpenHashMap]
          feat.foreachActive { case (idx: Long, value: Double) =>
            openHashMap.put(idx, value)
          }
        case _ => throw MLException("Dense Vector is not supported!")
      }
      count += 1

      // A full batch has been accumulated: add it to row 0 of the stats matrix and reset the accumulator.
      if (count % statsBatchSize == 0 && count != 0) {
        features match {
          case feat: IntSparseVector =>
            val openHashMap = hashMap.asInstanceOf[Int2DoubleOpenHashMap]
            val storage = new IntDoubleSparseVectorStorage(feat.size.toInt, openHashMap)
            val vector = new IntDoubleVector(0, 0, 0, feat.size.toInt, storage)
            statsMat.increment(0, vector, true)
            openHashMap.clear()
          case feat: LongSparseVector =>
            val openHashMap = hashMap.asInstanceOf[Long2DoubleOpenHashMap]
            val storage = new LongDoubleSparseVectorStorage(feat.size, openHashMap)
            val vector = new LongDoubleVector(0, 0, 0, feat.size, storage)
            statsMat.increment(0, vector, true)
            openHashMap.clear()
          case _ => throw MLException("Dense Vector is not supported!")
        }
      }
    }

    // Flush whatever remains in the accumulator; the head sample only determines the vector type and dimension.
    headSample match {
      case feat: IntSparseVector =>
        val openHashMap = hashMap.asInstanceOf[Int2DoubleOpenHashMap]
        val storage = new IntDoubleSparseVectorStorage(feat.size.toInt, openHashMap)
        val vector = new IntDoubleVector(0, 0, 0, feat.size.toInt, storage)
        statsMat.increment(0, vector, true)
        openHashMap.clear()
      case feat: LongSparseVector =>
        val openHashMap = hashMap.asInstanceOf[Long2DoubleOpenHashMap]
        val storage = new LongDoubleSparseVectorStorage(feat.size, openHashMap)
        val vector = new LongDoubleVector(0, 0, 0, feat.size, storage)
        statsMat.increment(0, vector, true)
        openHashMap.clear()
      case _ => throw MLException("Dense Vector is not supported!")
    }

    partitionStat.put(partitionId, count)
    Iterator.single[util.HashMap[Int, Long]](partitionStat)
  }

  /**
   * Per-partition statistics without PS updates: returns a single-entry map from the partition id
   * to the number of examples in the partition.
   */
  def partitionStats(iter: Iterator[Example]): Iterator[util.HashMap[Int, Long]] = {
    val partitionId = TaskContext.getPartitionId()
    val partitionStat = new util.HashMap[Int, Long]()

    partitionStat.put(partitionId, iter.length.toLong)
    Iterator.single[util.HashMap[Int, Long]](partitionStat)
  }

  /**
   * Merges two per-partition count maps. Partition ids are distinct, so `putAll` never
   * overwrites an existing entry.
   */
  def mergeMap(first: util.HashMap[Int, Long], second: util.HashMap[Int, Long]): util.HashMap[Int, Long] = {
    first.putAll(second)

    first
  }
}
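
A minimal, hypothetical sketch of how this class might be wired into a driver program follows. Everything in it, the object and method names, the RDD of Example instances, the AngelPSClient / PSAgent handles, the broadcast ExecutorContext, the "demoModel" uid and the model-type string, is an assumption made for illustration and is not part of this file; obtaining the PS handles follows sona's usual driver/executor setup, which is not shown here.

import java.util

import com.tencent.angel.client.AngelPSClient
import com.tencent.angel.psagent.PSAgent
import com.tencent.angel.sona.core.ExecutorContext
import com.tencent.angel.sona.ml.util.FeatureStats
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Example

object FeatureStatsExample {
  // Hypothetical driver-side wiring; all names and parameters here are assumptions, not sona API.
  def collectStats(examples: RDD[Example],
                   psClient: AngelPSClient,
                   psAgent: PSAgent,
                   exeCtx: Broadcast[ExecutorContext],
                   numFeature: Long): (Long, util.HashMap[Int, Long]) = {
    val stats = new FeatureStats("demoModel", "T_DOUBLE_SPARSE_LONGKEY", exeCtx)
    stats.createPSMat(psClient, numFeature)                       // allocate the 1 x numFeature stats matrix on the PS
    val perPartitionCounts = examples
      .mapPartitions(stats.partitionStatsWithPS)                  // executors push feature values and count examples
      .reduce(stats.mergeMap)                                     // merge the per-partition count maps on the driver
    val numValidFeatures = stats.getNumValidateFeatures(psAgent)  // ask the PS how many valid features were seen
    (numValidFeatures, perPartitionCounts)
  }
}

Because FeatureStats is Serializable and the executor-side closures only capture the broadcast ExecutorContext, the AngelPSClient and PSAgent handles stay on the driver.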



