All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.validation.metric.Recall Maven / Gradle / Ivy

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.validation.metric;

import java.io.Serial;
import smile.math.MathEx;

/**
 * Recall of classification, also known as sensitivity or true positive rate.
 * In the information retrieval area, sensitivity is called recall.
 * <p>
 * For binary classification (labels 0 and 1, with 1 as the positive class),
 * recall is {@code TP / (TP + FN)}: the fraction of actual positive samples
 * that are predicted as positive. For multi-class labels, an
 * {@link Averaging} strategy must be supplied to aggregate the result.
 *
 * @see Sensitivity
 *
 * @author Haifeng Li
 */
public class Recall implements ClassificationMetric {
    @Serial
    private static final long serialVersionUID = 2L;
    /** Default instance. */
    public static final Recall instance = new Recall();
    /** The aggregating strategy for multi-classes; {@code null} means binary. */
    private final Averaging strategy;

    /**
     * Constructor for binary classification (no averaging strategy).
     */
    public Recall() {
        this(null);
    }

    /**
     * Constructor.
     * @param strategy The aggregating strategy for multi-classes,
     *                 or {@code null} for binary classification.
     */
    public Recall(Averaging strategy) {
        this.strategy = strategy;
    }

    @Override
    public double score(int[] truth, int[] prediction) {
        return of(truth, prediction, strategy);
    }

    @Override
    public String toString() {
        return strategy == null ? "Recall" : strategy + "-Recall";
    }

    /**
     * Calculates the recall/sensitivity of binary classification.
     * Labels must be 0 or 1, where 1 is the positive class.
     * @param truth the ground truth.
     * @param prediction the prediction.
     * @return the metric, or NaN if there is no positive sample in the ground truth.
     * @throws IllegalArgumentException if any label is not 0 or 1, or the
     *         array lengths differ.
     */
    public static double of(int[] truth, int[] prediction) {
        for (int t : truth) {
            if (t != 0 && t != 1) {
                throw new IllegalArgumentException("Recall can only be applied to binary classification: " + t);
            }
        }

        for (int p : prediction) {
            if (p != 0 && p != 1) {
                throw new IllegalArgumentException("Recall can only be applied to binary classification: " + p);
            }
        }

        return of(truth, prediction, null);
    }

    /**
     * Calculates the recall/sensitivity.
     * <p>
     * With a {@code null} strategy, labels 0/1 are assumed and the result is
     * {@code TP / (number of samples whose truth is 1)}, which is NaN when the
     * ground truth has no positive sample. {@code Micro} counts all correct
     * predictions over all samples. {@code Macro} takes the unweighted mean of
     * per-class recalls. Note that a class present only in {@code prediction}
     * has an undefined (NaN) per-class recall, which propagates into the
     * Macro mean. {@code Weighted} weights each per-class recall by the class's
     * frequency in {@code truth}. Classes absent from {@code truth} have zero
     * weight and are skipped.
     *
     * @param truth the ground truth.
     * @param prediction the prediction.
     * @param strategy The aggregating strategy for multi-classes.
     * @return the metric.
     * @throws IllegalArgumentException if the array lengths differ, or there are
     *         more than two classes and {@code strategy} is {@code null}.
     */
    public static double of(int[] truth, int[] prediction, Averaging strategy) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, prediction.length));
        }

        int numClasses = Math.max(MathEx.max(truth), MathEx.max(prediction)) + 1;
        if (numClasses > 2 && strategy == null) {
            throw new IllegalArgumentException("Averaging strategy is null for multi-class");
        }

        // Per-class true positives are only needed for Macro/Weighted averaging;
        // the binary and Micro cases use a single counter.
        int length = strategy == Averaging.Macro || strategy == Averaging.Weighted ? numClasses : 1;
        int[] tp = new int[length];
        // The binary path reads size[1] (the number of actual positives), so
        // allocate at least two slots even when label 1 never occurs in either
        // array (e.g. all-zero input), which would otherwise make numClasses == 1.
        int[] size = new int[Math.max(numClasses, 2)];

        int n = truth.length;
        for (var target : truth) {
            ++size[target];
        }

        if (strategy == null) {
            for (int i = 0; i < n; i++) {
                if (prediction[i] == 1 && truth[i] == 1) {
                    tp[0]++;
                }
            }
        } else if (strategy == Averaging.Micro) {
            for (int i = 0; i < n; i++) {
                tp[0] += truth[i] == prediction[i] ?  1 : 0;
            }
        } else {
            for (int i = 0; i < n; i++) {
                tp[truth[i]] += truth[i] == prediction[i] ?  1 : 0;
            }
        }

        double[] recall = new double[tp.length];
        if (tp.length == 1) {
            // Binary: divide by the number of actual positives.
            // Micro: divide by the total number of samples.
            recall[0] = (double) tp[0] / (strategy == null ? size[1] : n);
        } else {
            for (int i = 0; i < tp.length; i++) {
                recall[i] = (double) tp[i] / size[i];
            }
        }

        if (strategy == Averaging.Macro) {
            return MathEx.mean(recall);
        } else if (strategy == Averaging.Weighted) {
            double weighted = 0.0;
            for (int i = 0; i < numClasses; i++) {
                // A class that occurs only in the predictions has no true samples:
                // its recall is 0/0 = NaN but its weight is 0, so skip it rather
                // than let NaN * 0 poison the whole sum.
                if (size[i] > 0) {
                    weighted += recall[i] * size[i];
                }
            }
            return weighted / n;
        }
        return recall[0];
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy