All downloads are free. Search and download functionality uses the official Maven repository.

io.github.metarank.ltrlib.metric.NDCG.scala Maven / Gradle / Ivy

There is a newer version: 0.2.6
Show newest version
package io.github.metarank.ltrlib.metric

import io.github.metarank.cfor._

/** Normalized Discounted Cumulative Gain, truncated at `cutoff` positions and averaged over groups.
  *
  * For every group the documents are ranked by predicted score (descending) and the DCG of the first
  * `cutoff` positions is divided by the DCG of the ideal ranking (labels sorted descending). The gain of
  * a document with label `l` is `2^l - 1` and the discount at zero-based position `p` is `log2(p + 2)`.
  *
  * @param cutoff
  *   maximum number of top-ranked positions taken into account per group
  */
case class NDCG(cutoff: Int) extends Metric {

  /** Computes the mean NDCG over all groups.
    *
    * Groups whose labels are all zero contribute 0 to the sum but are still counted in the denominator.
    * Groups whose ideal DCG is exactly zero (for example when `cutoff` is not positive) also contribute 0
    * instead of turning the whole result into NaN or Infinity.
    *
    * @param y
    *   true relevance labels, one inner array per group
    * @param yhat
    *   predicted scores, indexed the same way as `y`
    * @return
    *   the average NDCG across groups, or 0.0 when `y` contains no groups
    */
  override def eval(y: Array[Array[Double]], yhat: Array[Array[Double]]): Double = {
    if (y.isEmpty) {
      // Nothing to average over; avoid the 0.0 / 0 == NaN result.
      0.0
    } else {
      var total = 0.0
      cfor(y.indices) { group =>
        val labels = y(group)
        if (labels.exists(_ != 0.0)) {
          val limit        = math.min(labels.length, cutoff)
          val idealOrder   = labels.sorted.reverse
          val idealDcg     = dcgAt(idealOrder, limit)
          // A zero ideal DCG would yield 0/0 or x/0; treat such a group as scoring 0.
          if (idealDcg != 0.0) {
            val byPrediction = labels.zip(yhat(group)).sortBy(-_._2).map(_._1)
            total += dcgAt(byPrediction, limit) / idealDcg
          }
        }
      }
      total / y.length
    }
  }

  /** Sums `(2^label - 1) / log2(position + 2)` over the first `limit` entries of `ranked`. */
  private def dcgAt(ranked: Array[Double], limit: Int): Double = {
    val log2 = math.log10(2)
    var sum  = 0.0
    cfor(0 until limit) { pos =>
      sum += (math.pow(2.0, ranked(pos)) - 1.0) / (math.log10(pos + 2) / log2)
    }
    sum
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy