All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.fbk.utils.eval.PrecisionRecall Maven / Gradle / Ivy

The newest version!
package eu.fbk.utils.eval;

import com.google.common.base.Preconditions;
import com.google.common.collect.Ordering;

import java.io.Serializable;
import java.util.Objects;

public final class PrecisionRecall implements Serializable {

    private static final long serialVersionUID = 1L;

    public static final PrecisionRecall EMPTY = new PrecisionRecall(0.0, 0.0, 0.0, 0.0,
            Double.NaN, Double.NaN, Double.NaN);

    private final double tp;

    private final double fp;

    private final double fn;

    private final double tn;

    private final double precision;

    private final double recall;

    private final double accuracy;

    private PrecisionRecall(final double tp, final double fp, final double fn, final double tn,
            final double precision, final double recall, final double accuracy) {
        this.tp = tp;
        this.fp = fp;
        this.fn = fn;
        this.tn = tn;
        this.precision = precision;
        this.recall = recall;
        this.accuracy = accuracy;
    }

    public static PrecisionRecall forCounts(final double tp, final double fp, final double fn) {
        return forCounts(tp, fp, fn, 0.0);
    }

    public static PrecisionRecall forCounts(final double tp, final double fp, final double fn,
            final double tn) {
        if (tp < 0.0 || fp < 0.0 || fn < 0.0 || tn < 0.0) {
            throw new IllegalArgumentException("Invalid contingency table values: tp=" + tp
                    + ", fp=" + fp + ", fn=" + fn + ", tn=" + tn);
        } else if (tp == 0.0 && fp == 0.0 && fn == 0.0 && tn == 0.0) {
            return EMPTY;
        } else {
            final double c = tp + fp + fn + tn;
            final double a = c == 0.0 ? Double.NaN : (tp + tn) / c;
            final double p = tp + fp == 0.0 ? Double.NaN : tp / (tp + fp);
            final double r = tp + fn == 0.0 ? Double.NaN : tp / (tp + fn);
            return new PrecisionRecall(tp, fp, fn, tn, p, r, a);
        }
    }

    public static PrecisionRecall forMeasures(final double precision, final double recall,
            final double accuracy) {
        // Do not estimate TP, FP, FN, TN
        return forMeasures(precision, recall, accuracy, Double.NaN);
    }

    /**
     * Builds a {@code PrecisionRecall} object starting from the supplied precision p, recall r,
     * accuracy a, and count value. The method tries to recover the missing TP, FP, FN and TN
     * values, assigning them the value NaN where recovery is not possible.
     *
     * @param precision the precision value
     * @param recall    the recall value
     * @param accuracy  the accuracy value
     * @param count     the count value
     * @return the corresponding {@code PrecisionRecall} object
     */
    public static PrecisionRecall forMeasures(final double precision, final double recall,
            final double accuracy, final double count) {
        Preconditions.checkArgument(Double.isNaN(precision) || precision >= 0.0
                && precision <= 1.0);
        Preconditions.checkArgument(Double.isNaN(recall) || recall >= 0.0 && recall <= 1.0);
        Preconditions.checkArgument(Double.isNaN(accuracy) || accuracy >= 0.0 && accuracy <= 1.0);
        Preconditions.checkArgument(Double.isNaN(count) || count >= 0.0);
        if (count == Double.NaN || precision == 0.0 || recall == 0.0) {
            return new PrecisionRecall(Double.NaN, Double.NaN, Double.NaN, Double.NaN, precision,
                    recall, accuracy);
        } else if (precision == 1.0 && recall == 1.0 || accuracy == 1.0) {
            return new PrecisionRecall(Double.NaN, 0.0, 0.0, Double.NaN, precision, recall,
                    accuracy);
        } else {
            final double tp = count * (1 - accuracy) / (1 / precision + 1 / recall - 2);
            final double fp = tp * (1 / precision - 1);
            final double fn = tp * (1 / recall - 1);
            final double tn = count - tp - fp - fn;
            if (tn > 0) {
                return new PrecisionRecall(tp, fp, fn, tn, precision, recall, accuracy);
            }
        }
        return new PrecisionRecall(Double.NaN, Double.NaN, Double.NaN, Double.NaN, precision,
                recall, accuracy);
    }

    public static PrecisionRecall microAverage(final Iterable prs) {
        double tp = 0.0;
        double fp = 0.0;
        double fn = 0.0;
        double tn = 0.0;
        int num = 0;
        for (final PrecisionRecall pr : prs) {
            tp += pr.tp;
            fp += pr.fp;
            fn += pr.fn;
            tn += pr.tn;
            ++num;
        }
        return num == 0 ? EMPTY : forCounts(tp / num, fp / num, fn / num, tn / num);
    }

    public static PrecisionRecall macroAverage(final Iterable prs) {
        double p = 0.0;
        double r = 0.0;
        double a = 0.0;
        double count = 0.0;
        int num = 0;
        for (final PrecisionRecall pr : prs) {
            p += pr.getPrecision();
            r += pr.getRecall();
            a += pr.getAccuracy();
            count += pr.getCount();
            ++num;
        }
        return forMeasures(p, r, a, num == 0 ? 0.0 : count / num);
    }

    public double getTP() {
        return this.tp;
    }

    public double getFP() {
        return this.fp;
    }

    public double getFN() {
        return this.fn;
    }

    public double getTN() {
        return this.tn;
    }

    public double getCount() {
        return this.tp + this.fp + this.tn + this.fn;
    }

    public double getAccuracy() {
        return this.accuracy;
    }

    public double getPrecision() {
        return this.precision;
    }

    public double getRecall() {
        return this.recall;
    }

    public double getF1() {
        return this.precision + this.recall == 0.0 ? Double.NaN : 2 * this.precision * this.recall
                / (this.precision + this.recall);
    }

    public double get(final Measure measure) {
        switch (measure) {
        case ACCURACY:
            return getAccuracy();
        case PRECISION:
            return getPrecision();
        case RECALL:
            return getRecall();
        case F1:
            return getF1();
        default:
            throw new IllegalArgumentException("Invalid measure " + measure);
        }
    }

    @Override
    public boolean equals(final Object object) {
        if (object == this) {
            return true;
        }
        if (!(object instanceof PrecisionRecall)) {
            return false;
        }
        final PrecisionRecall other = (PrecisionRecall) object;
        return this.tp == other.tp && this.fp == other.fp && this.tn == other.tn
                && this.fn == other.fn && this.precision == other.precision
                && this.recall == other.recall && this.accuracy == other.accuracy;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.tp, this.fp, this.tn, this.fn, this.precision, this.recall,
                this.accuracy);
    }

    @Override
    public String toString() {
        return String.format("p=%.3f r=%.3f f1=%.3f a=%.3f tp=%.2f fp=%.2f fn=%.2f tn=%.2f",
                this.precision, this.recall, getF1(), this.accuracy, this.tp, this.fp, this.fn,
                this.tn);
    }

    public static Ordering comparator(final Measure measure,
            final boolean higherIsBetter) {
        return new Ordering() {

            @Override
            public int compare(final PrecisionRecall left, final PrecisionRecall right) {
                final double leftValue = left.get(measure);
                final double rightValue = right.get(measure);
                final int result = Double.compare(leftValue, rightValue);
                return higherIsBetter ? -result : result;
            }

        };
    }

    public static Evaluator evaluator() {
        return new Evaluator();
    }

    public enum Measure {

        ACCURACY,

        PRECISION,

        RECALL,

        F1

    }

    public static final class Evaluator {

        private double tp;

        private double fp;

        private double fn;

        private double tn;

        private PrecisionRecall pr;

        private Evaluator() {
            this.tp = 0.0;
            this.fp = 0.0;
            this.fn = 0.0;
            this.tn = 0.0;
            this.pr = null;
        }

        private static void checkNotNegative(final double value) {
            if (value < 0.0) {
                throw new IllegalArgumentException("Illegal negative value " + value);
            }
        }

        public Evaluator add(final PrecisionRecall pr) {
            if (pr.getCount() > 0) {
                synchronized (this) {
                    this.pr = null;
                    this.tp += pr.getTP();
                    this.fp += pr.getFP();
                    this.fn += pr.getFN();
                    this.tn += pr.getTN();
                }
            }
            return this;
        }

        public Evaluator add(final PrecisionRecall.Evaluator evaluator) {
            synchronized (evaluator) {
                if (evaluator.tp > 0.0 || evaluator.fp > 0.0 || evaluator.fn > 0.0
                        || evaluator.tn > 0.0) {
                    synchronized (this) {
                        this.pr = null;
                        this.tp += evaluator.tp;
                        this.fp += evaluator.fp;
                        this.fn += evaluator.fn;
                        this.tn += evaluator.tn;
                    }
                }
            }
            return this;
        }

        public Evaluator add(final double deltaTP, final double deltaFP, final double deltaFN) {
            checkNotNegative(deltaTP);
            checkNotNegative(deltaFP);
            checkNotNegative(deltaFN);
            if (deltaTP > 0.0 || deltaFP > 0.0 || deltaFN > 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.tp += deltaTP;
                    this.fp += deltaFP;
                    this.fn += deltaFN;
                }
            }
            return this;
        }

        public Evaluator add(final double deltaTP, final double deltaFP, final double deltaFN,
                final double deltaTN) {
            checkNotNegative(deltaTP);
            checkNotNegative(deltaFP);
            checkNotNegative(deltaFN);
            checkNotNegative(deltaTN);
            if (deltaTP > 0.0 || deltaFP > 0.0 || deltaFN > 0.0 || deltaTN > 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.tp += deltaTP;
                    this.fp += deltaFP;
                    this.fn += deltaFN;
                    this.tn += deltaTN;
                }
            }
            return this;
        }

        public Evaluator addTP(final double deltaTP) {
            checkNotNegative(deltaTP);
            if (deltaTP != 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.tp += deltaTP;
                }
            }
            return this;
        }

        public Evaluator addTP() {
            return addTP(1.0);
        }

        public Evaluator addFP(final double deltaFP) {
            checkNotNegative(deltaFP);
            if (deltaFP != 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.fp += deltaFP;
                }
            }
            return this;
        }

        public Evaluator addFP() {
            return addFP(1.0);
        }

        public Evaluator addFN(final double deltaFN) {
            checkNotNegative(deltaFN);
            if (deltaFN != 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.fn += deltaFN;
                }
            }
            return this;
        }

        public Evaluator addFN() {
            return addFN(1.0);
        }

        public Evaluator addTN(final double deltaTN) {
            checkNotNegative(deltaTN);
            if (deltaTN != 0.0) {
                synchronized (this) {
                    this.pr = null;
                    this.tn += deltaTN;
                }
            }
            return this;
        }

        public Evaluator addTN() {
            return addTN(1.0);
        }

        public synchronized PrecisionRecall getResult() {
            if (this.pr == null) {
                this.pr = forCounts(this.tp, this.fp, this.fn, this.tn);
            }
            return this.pr;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy