All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.metric.AUC.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package org.apache.spark.angel.ml.metric

import com.tencent.angel.ml.math2.matrix.Matrix
import com.tencent.angel.ml.math2.vector.{IntDoubleVector, IntFloatVector}
import org.apache.spark.rdd.RDD

import scala.collection.mutable.ArrayBuffer

class AUC extends Metric {

  override def cal(scores: RDD[(Matrix, Matrix)]): Double = {
    // format to (Double, Double) RDD
    val pairs = scores.mapPartitions {
      case iter =>
        val pairs = new ArrayBuffer[(Double, Double)]
        iter.foreach { case (results, labels) =>
          val predict = results.getCol(1)
          val label = labels.getCol(0).asInstanceOf[IntFloatVector]
          val size = label.size()
          predict match {
            case double: IntDoubleVector =>
              for (i <- 0 until size)
                pairs.append((label.get(i), double.get(i)))
            case float: IntFloatVector =>
              for (i <- 0 until size)
                pairs.append((label.get(i), float.get(i)))
          }
        }
        pairs.iterator
    }

    calculate(pairs)
  }

  def calculate(pairs: RDD[(Double, Double)]): Double = {
    // sort by predict
    val sorted = pairs.sortBy(f => f._2)
    sorted.cache()

    val numTotal = sorted.count()
    val numPositive = sorted.filter(f => f._1 > 0).count()
    val numNegetive = numTotal - numPositive

    // calculate the summation of ranks for positive samples
    val sumRanks = sorted.zipWithIndex().filter(f => f._1._1.toInt == 1).map(f => f._2 + 1).reduce(_ + _)
    val auc = sumRanks * 1.0 / numPositive / numNegetive - (numPositive + 1.0) / 2.0 / numNegetive

    sorted.unpersist()
    auc
  }

  override def toString: String = "AUC"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy