
org.apache.lucene.benchmark.quality.QualityStats Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.quality;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Locale;
/** Results of quality benchmark run for a single query or for a set of queries. */
public class QualityStats {
/** Number of points for which precision is computed. */
public static final int MAX_POINTS = 20;
private double maxGoodPoints;
private double recall;
private double[] pAt;
private double pReleventSum = 0;
private double numPoints = 0;
private double numGoodPoints = 0;
private double mrr = 0;
private long searchTime;
private long docNamesExtractTime;
/** A certain rank in which a relevant doc was found. */
public static class RecallPoint {
private int rank;
private double recall;
private RecallPoint(int rank, double recall) {
this.rank = rank;
this.recall = recall;
}
/** Returns the rank: where on the list of returned docs this relevant doc appeared. */
public int getRank() {
return rank;
}
/** Returns the recall: how many relevant docs were returned up to this point, inclusive. */
public double getRecall() {
return recall;
}
}
private ArrayList recallPoints;
/**
* Construct a QualityStats object with anticipated maximal number of relevant hits.
*
* @param maxGoodPoints maximal possible relevant hits.
*/
public QualityStats(double maxGoodPoints, long searchTime) {
this.maxGoodPoints = maxGoodPoints;
this.searchTime = searchTime;
this.recallPoints = new ArrayList<>();
pAt = new double[MAX_POINTS + 1]; // pAt[0] unused.
}
/**
* Add a (possibly relevant) doc.
*
* @param n rank of the added doc (its ordinal position within the query results).
* @param isRelevant true if the added doc is relevant, false otherwise.
*/
public void addResult(int n, boolean isRelevant, long docNameExtractTime) {
if (Math.abs(numPoints + 1 - n) > 1E-6) {
throw new IllegalArgumentException("point " + n + " illegal after " + numPoints + " points!");
}
if (isRelevant) {
numGoodPoints += 1;
recallPoints.add(new RecallPoint(n, numGoodPoints));
if (recallPoints.size() == 1 && n <= 5) { // first point, but only within 5 top scores.
mrr = 1.0 / n;
}
}
numPoints = n;
double p = numGoodPoints / numPoints;
if (isRelevant) {
pReleventSum += p;
}
if (n < pAt.length) {
pAt[n] = p;
}
recall = maxGoodPoints <= 0 ? p : numGoodPoints / maxGoodPoints;
docNamesExtractTime += docNameExtractTime;
}
/**
* Return the precision at rank n: |{relevant hits within first n
hits}| / n
*
.
*
* @param n requested precision point, must be at least 1 and at most {@link #MAX_POINTS}.
*/
public double getPrecisionAt(int n) {
if (n < 1 || n > MAX_POINTS) {
throw new IllegalArgumentException(
"n=" + n + " - but it must be in [1," + MAX_POINTS + "] range!");
}
if (n > numPoints) {
return (numPoints * pAt[(int) numPoints]) / n;
}
return pAt[n];
}
/** Return the average precision at recall points. */
public double getAvp() {
return maxGoodPoints == 0 ? 0 : pReleventSum / maxGoodPoints;
}
/** Return the recall: |{relevant hits found}| / |{relevant hits existing}|. */
public double getRecall() {
return recall;
}
/**
* Log information on this QualityStats object.
*
* @param logger Logger.
* @param prefix prefix before each log line.
*/
public void log(String title, int paddLines, PrintWriter logger, String prefix) {
for (int i = 0; i < paddLines; i++) {
logger.println();
}
if (title != null && title.trim().length() > 0) {
logger.println(title);
}
prefix = prefix == null ? "" : prefix;
NumberFormat nf = NumberFormat.getInstance(Locale.ROOT);
nf.setMaximumFractionDigits(3);
nf.setMinimumFractionDigits(3);
nf.setGroupingUsed(true);
int M = 19;
logger.println(
prefix + format("Search Seconds: ", M) + fracFormat(nf.format((double) searchTime / 1000)));
logger.println(
prefix
+ format("DocName Seconds: ", M)
+ fracFormat(nf.format((double) docNamesExtractTime / 1000)));
logger.println(prefix + format("Num Points: ", M) + fracFormat(nf.format(numPoints)));
logger.println(prefix + format("Num Good Points: ", M) + fracFormat(nf.format(numGoodPoints)));
logger.println(prefix + format("Max Good Points: ", M) + fracFormat(nf.format(maxGoodPoints)));
logger.println(prefix + format("Average Precision: ", M) + fracFormat(nf.format(getAvp())));
logger.println(prefix + format("MRR: ", M) + fracFormat(nf.format(getMRR())));
logger.println(prefix + format("Recall: ", M) + fracFormat(nf.format(getRecall())));
for (int i = 1; i < (int) numPoints && i < pAt.length; i++) {
logger.println(
prefix
+ format("Precision At " + i + ": ", M)
+ fracFormat(nf.format(getPrecisionAt(i))));
}
for (int i = 0; i < paddLines; i++) {
logger.println();
}
}
private static String padd = " ";
private String format(String s, int minLen) {
s = (s == null ? "" : s);
int n = Math.max(minLen, s.length());
return (s + padd).substring(0, n);
}
private String fracFormat(String frac) {
int k = frac.indexOf('.');
String s1 = padd + frac.substring(0, k);
int n = Math.max(k, 6);
s1 = s1.substring(s1.length() - n);
return s1 + frac.substring(k);
}
/**
* Create a QualityStats object that is the average of the input QualityStats objects.
*
* @param stats array of input stats to be averaged.
* @return an average over the input stats.
*/
public static QualityStats average(QualityStats[] stats) {
QualityStats avg = new QualityStats(0, 0);
if (stats.length == 0) {
// weired, no stats to average!
return avg;
}
int m = 0; // queries with positive judgements
// aggregate
for (int i = 0; i < stats.length; i++) {
avg.searchTime += stats[i].searchTime;
avg.docNamesExtractTime += stats[i].docNamesExtractTime;
if (stats[i].maxGoodPoints > 0) {
m++;
avg.numGoodPoints += stats[i].numGoodPoints;
avg.numPoints += stats[i].numPoints;
avg.pReleventSum += stats[i].getAvp();
avg.recall += stats[i].recall;
avg.mrr += stats[i].getMRR();
avg.maxGoodPoints += stats[i].maxGoodPoints;
for (int j = 1; j < avg.pAt.length; j++) {
avg.pAt[j] += stats[i].getPrecisionAt(j);
}
}
}
assert m > 0 : "Fishy: no \"good\" queries!";
// take average: times go by all queries, other measures go by "good" queries only.
avg.searchTime /= stats.length;
avg.docNamesExtractTime /= stats.length;
avg.numGoodPoints /= m;
avg.numPoints /= m;
avg.recall /= m;
avg.mrr /= m;
avg.maxGoodPoints /= m;
for (int j = 1; j < avg.pAt.length; j++) {
avg.pAt[j] /= m;
}
avg.pReleventSum /= m; // this is actually avgp now
avg.pReleventSum *= avg.maxGoodPoints; // so that getAvgP() would be correct
return avg;
}
/**
* Returns the time it took to extract doc names for judging the measured query, in milliseconds.
*/
public long getDocNamesExtractTime() {
return docNamesExtractTime;
}
/**
* Returns the maximal number of good points. This is the number of relevant docs known by the
* judge for the measured query.
*/
public double getMaxGoodPoints() {
return maxGoodPoints;
}
/** Returns the number of good points (only relevant points). */
public double getNumGoodPoints() {
return numGoodPoints;
}
/** Returns the number of points (both relevant and irrelevant points). */
public double getNumPoints() {
return numPoints;
}
/** Returns the recallPoints. */
public RecallPoint[] getRecallPoints() {
return recallPoints.toArray(new RecallPoint[0]);
}
/**
* Returns the Mean reciprocal rank over the queries or RR for a single query.
*
* Reciprocal rank is defined as 1/r
where r
is the rank of the first
* correct result, or 0
if there are no correct results within the top 5 results.
*
*
This follows the definition in
* Question Answering - CNLP at the TREC-10 Question Answering Track.
*/
public double getMRR() {
return mrr;
}
/** Returns the search time in milliseconds for the measured query. */
public long getSearchTime() {
return searchTime;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy