All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.loss.AUCLogit Maven / Gradle / Ivy

package com.expleague.ml.loss;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.util.ArrayTools;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;

import java.io.PrintStream;

/**
 * Target function that scores a vector of predictions by the area under the ROC curve (AUC)
 * with respect to a fixed target vector. A sample is treated as positive when its target value
 * is strictly greater than zero and as negative otherwise.
 *
 * <p>The number of positive samples is counted once in the constructor; every evaluation then
 * ranks the samples by predicted score and sweeps the ranking once. Samples with equal scores
 * get no special treatment: they are processed in whatever order the sort leaves them.
 *
 * <p>igorkuralenok on 23.05.17.
 */
@SuppressWarnings("unused")
public class AUCLogit extends Func.Stub implements TargetFunc {
  protected final Vec target;
  private final DataSet owner;
  /** Number of entries of {@link #target} that are strictly positive. */
  private final int allPositive;

  /**
   * @param target per-sample labels; values {@code > 0} mark positive samples
   * @param owner  data set returned by {@link #owner()}
   */
  public AUCLogit(final Vec target, final DataSet owner) {
    this.target = target;
    this.owner = owner;
    int positive = 0;
    for (int i = 0; i < target.dim(); i++) {
      if (target.get(i) > 0)
        positive++;
    }
    allPositive = positive;
  }

  /**
   * Returns the sample indices ranked by the values of {@code array}. The values are copied
   * into a scratch array first, so {@code array} itself is not reordered.
   *
   * <p>NOTE(review): the sweep below relies on {@code ArrayTools.parallelSort} ordering the keys
   * ascending and applying the same permutation to {@code order} -- confirm against ArrayTools.
   */
  private int[] getOrdered(final Vec array) {
    final double[] keys = new double[array.dim()];
    array.toArray(keys, 0);
    final int[] order = ArrayTools.sequence(0, array.dim());
    ArrayTools.parallelSort(keys, order);
    return order;
  }

  /**
   * Walks the samples in ranked order and accumulates two AUC estimates.
   *
   * <p>Each positive sample passed lowers the true-positive rate; each negative sample passed
   * lowers the false-positive rate by one step and contributes one column to the area.
   *
   * @param x   predicted scores, indexed like {@link #target}
   * @param out when not {@code null}, receives one {@code FPR \t TPR \t score} line per negative
   *            sample
   * @return a two-element array: {@code [0]} is the rectangle ("bars") estimate,
   *         {@code [1]} is the trapezium estimate
   */
  private double[] sweep(final Vec x, final PrintStream out) {
    final int[] order = getOrdered(x);
    // Loop-invariant: everything that is not positive counts as negative.
    final int allNegative = x.dim() - allPositive;

    int trueNegative = 0;
    int falseNegative = 0;
    double bars = 0;
    double trapezium = 0;

    // The sweep starts where every sample is predicted positive, i.e. at FPR = 1 and TPR = 1.
    double prevFpr = 1;
    double prevTpr = 1;

    for (final int index : order) {
      if (target.get(index) > 0) {
        // A positive sample only changes the TPR seen by the next negative sample.
        falseNegative++;
        continue;
      }
      trueNegative++;

      final double falsePositive = allNegative - trueNegative;
      final double truePositive = allPositive - falseNegative;
      final double tpr = 1.0 * truePositive / allPositive;
      final double fpr = 1.0 * falsePositive / allNegative;

      if (out != null) {
        out.append(String.valueOf(fpr)).append("\t")
            .append(String.valueOf(tpr)).append("\t")
            .append(String.valueOf(x.get(index))).append("\n");
      }

      trapezium += (tpr + prevTpr) / 2 * (prevFpr - fpr);
      bars += tpr * (prevFpr - fpr);
      prevFpr = fpr;
      prevTpr = tpr;
    }
    return new double[]{bars, trapezium};
  }

  /**
   * Computes the rectangle-rule AUC of the scores {@code x} against the target.
   *
   * <p>Degenerate inputs are not special-cased: with no positive samples (and at least one
   * sample) the true-positive rate is {@code 0/0} and the result is {@code NaN}; with no
   * negative samples nothing is accumulated and the result is {@code 0}.
   *
   * @param x predicted scores, indexed like the target vector
   * @return the AUC estimate
   */
  @Override
  public double value(Vec x) {
    return sweep(x, null)[0];
  }

  /**
   * Writes the ROC points (one tab-separated {@code FPR, TPR, score} line per negative sample,
   * in ranked order) followed by a summary line with the rectangle and trapezium AUC estimates.
   *
   * @param x   predicted scores, indexed like the target vector
   * @param out stream the report is appended to
   */
  public void printResult(Vec x, PrintStream out) {
    final double[] auc = sweep(x, out);
    out.append("AUC Bars: ").append(String.valueOf(auc[0]))
        .append(" AUC Trapezium: ").append(String.valueOf(auc[1]))
        .append("\n");
  }

  /** @return the dimension of the target vector */
  @Override
  public int dim() {
    return target.dim();
  }

  /** @return the data set supplied at construction time */
  @Override
  public DataSet owner() {
    return owner;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy