![JAR search and dependency download from the Maven repository](/logo.png)
eu.project.ttc.engines.EvalEngine Maven / Gradle / Ivy
/*******************************************************************************
* Copyright 2015 - CNRS (Centre National de Recherche Scientifique)
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*******************************************************************************/
package eu.project.ttc.engines;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ExternalResource;
import org.apache.uima.jcas.JCas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Joiner;
import com.google.common.base.Stopwatch;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import eu.project.ttc.engines.cleaner.TermProperty;
import eu.project.ttc.models.Term;
import eu.project.ttc.models.VariationPath;
import eu.project.ttc.resources.EvalTrace;
import eu.project.ttc.resources.EvalTrace.RecPoint;
import eu.project.ttc.resources.ReferenceTermList;
import eu.project.ttc.resources.ReferenceTermList.RTLTerm;
import eu.project.ttc.resources.TermIndexResource;
import eu.project.ttc.utils.FileUtils;
/**
 * An engine measuring the precision and recall of an extracted
 * terminology against a reference term list.
 *
 * <p>Nothing is done per document; the whole evaluation runs once in
 * {@link #collectionProcessComplete()}, which fills the eval trace and
 * then writes the optional log file and R data file.
 *
 * @author Damien Cram
 */
public class EvalEngine extends JCasAnnotator_ImplBase {
private static final Logger LOGGER = LoggerFactory.getLogger(EvalEngine.class);

/** Separator line printed between the sections of the log file. */
public static final String HORIZONTAL_RULE = "------------------------------------------------------------";

/** Key of the external resource holding the reference term list. */
public static final String REFERENCE_LIST = "ReferenceList";
@ExternalResource(key = REFERENCE_LIST, mandatory = true)
private ReferenceTermList rtl;

/** The term index whose terms are evaluated against the reference list. */
@ExternalResource(key = TermIndexResource.TERM_INDEX, mandatory = true)
private TermIndexResource termIndexResource;

/** Key of the external resource receiving the evaluation trace. */
public static final String EVAL_TRACE = "EvalTrace";
@ExternalResource(key=EVAL_TRACE, mandatory=true)
private EvalTrace evalTrace;

/** Optional parameter: path of the log file; no log is written when unset. */
public static final String OUTPUT_LOG_FILE = "OutputLogFile";
@ConfigurationParameter(name=OUTPUT_LOG_FILE, mandatory=false)
private String outputLogFile;

/** Optional parameter: path of the precision/recall data file; not written when unset. */
public static final String OUTPUT_R_FILE = "OutputRFile";
@ConfigurationParameter(name=OUTPUT_R_FILE, mandatory=false)
private String outputRFile;

/** Mandatory parameter: when true, variants of the reference terms also take part in the matching. */
public static final String RTL_WITH_VARIANTS = "RTLWithVariants";
@ConfigurationParameter(name=RTL_WITH_VARIANTS, mandatory=true)
private boolean rtlV;

/** Optional parameter: free text printed at the top of the log file when not empty. */
public static final String CUSTOM_LOG_HEADER_STRING = "CustomLogHeaderString";
@ConfigurationParameter(name=CUSTOM_LOG_HEADER_STRING, mandatory=false, defaultValue="")
private String customLogHeaderString;

/** Integer values iterated by generateRecPointIndexes (presumably the ranks plotted on the chart axis). */
public static final Collection<Integer> CHART_AXIS_POINTS = ImmutableSet.of(10, 50, 100, 200, 250, 500, 1000, 2000, 5000, 10000, 20000);

/** Reference terms not matched so far; initialised with the whole reference list by evaluate(). */
private Set<RTLTerm> rtlTermsNotFound;

/** Format of the "Date" line of the log file. */
private DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");

/** Accumulates the time spent matching terms in evaluate(); reported in the log file. */
private Stopwatch sw = Stopwatch.createUnstarted();
/**
 * Per-CAS hook, deliberately left empty: this engine performs all of its
 * work in {@link #collectionProcessComplete()}.
 *
 * @param arg0 the CAS handed over by the pipeline (ignored)
 * @throws AnalysisEngineProcessException declared by the overridden method only
 */
@Override
public void process(JCas arg0) throws AnalysisEngineProcessException {
    // intentionally empty
}
/**
 * Runs the evaluation once the whole collection has been processed:
 * initialises the eval trace, evaluates the extracted terms, then writes
 * the log file and the R data file.
 *
 * @throws AnalysisEngineProcessException declared by the overridden method
 */
@Override
public void collectionProcessComplete() throws AnalysisEngineProcessException {
    // The trace needs the size of the reference list up front,
    // otherwise the recall cannot be computed.
    final int referenceListSize = rtl.getTerms().size();
    evalTrace.setRtlSize(referenceListSize);

    evaluate();
    writeToLogFile();
    writeToRFile();
}
/**
 * Writes the precision/recall matrix to the file configured by
 * {@code outputRFile}; does nothing when that parameter is not set.
 *
 * <p>I/O failures are logged and not propagated.
 */
private void writeToRFile() {
    if (outputRFile == null) {
        return;
    }
    // try-with-resources closes the stream even when writing the matrix fails,
    // so the file handle no longer leaks on the exception path.
    try (PrintStream stream = new PrintStream(outputRFile)) {
        writePrecisionRecallMatrix(stream);
        stream.flush();
    } catch (IOException e) {
        LOGGER.error("File error", e);
        LOGGER.error("Could not write R chart to file {}", outputRFile);
    }
}
/**
 * Prints one tab-separated line (rank, precision, recall) per point
 * returned by {@code evalTrace.getChartAxisPoints(true)}.
 *
 * <p>Numbers are formatted with {@link Locale#ROOT} so that the decimal
 * separator is always a dot, whatever the default locale of the JVM; the
 * log header applies the same convention by replacing commas with dots.
 *
 * @param stream the stream to print to; it is neither flushed nor closed here
 * @throws IOException declared for callers; not thrown by the formatting itself
 */
private void writePrecisionRecallMatrix(PrintStream stream) throws IOException {
    for (RecPoint point : evalTrace.getChartAxisPoints(true)) {
        stream.format(Locale.ROOT, "%d\t%.2f\t%.2f\n",
                point.getRank(),
                point.getPrecision(),
                point.getRecall());
    }
}
/**
 * Writes the human-readable evaluation report to the file configured by
 * {@code outputLogFile}; does nothing when that parameter is not set.
 *
 * <p>The report starts with a header (RTL mode, term index name and size,
 * recall/precision at ranks 10, 100 and 1000, maximum recall, date, elapsed
 * time of the stopwatch) and ends with the precision/recall matrix.
 * I/O failures are logged and not propagated.
 */
private void writeToLogFile() {
if(outputLogFile != null) {
try {
// Total number of variation paths over all terms of the index (the argument 10
// is presumably the variant depth, cf. the commented-out line below).
int numVariationPaths = 0;
for(Term lcTerm:termIndexResource.getTermIndex().getTerms())
numVariationPaths += lcTerm.getVariationPaths(10).size();
// NOTE(review): the stream is only closed on the success path at the end of this try block.
PrintStream stream = new PrintStream(outputLogFile);
stream.println(HORIZONTAL_RULE);
if(!customLogHeaderString.isEmpty())
stream.println(customLogHeaderString);
// stream.format("RTL path: %s\n", rtl.getPath());
// stream.format("Variant depth: %d\n", 10);
stream.format("RTL Mode: %s\n", getModeString());
stream.format("LC term index: %s\n", termIndexResource.getTermIndex().getName());
stream.format("Num. LC terms: %d\tIncl. variants: %d\n", termIndexResource.getTermIndex().getTerms().size(), numVariationPaths);
// Recall and precision, as percentages, at three fixed ranks.
for(int i:new int[]{10,100,1000}) {
RecPoint p = evalTrace.getAtRank(i);
String str = String.format("R_%d: %.2f (p=%.2f)", i, p.getRecall()*100, p.getPrecision()*100);
// Replaces any comma produced by the default-locale formatting with a dot.
stream.println(str.replaceAll(",", "."));
}
stream.format("R_max: %.2f\n", evalTrace.getMaxRecall()*100);
stream.format("Date: %s\n", dateFormat.format(new Date()));
// Elapsed time of the stopwatch sw, in seconds.
stream.format("Generation time: %s\n", sw.elapsed(TimeUnit.SECONDS));
stream.println();
stream.println(HORIZONTAL_RULE);
stream.println("Extracted terms found in reference list");
stream.println(HORIZONTAL_RULE);
stream.println();
String termLineFormat = "%-8s%-40s\n";
// The already formatted header is passed as a format string; it contains no '%'.
stream.format(String.format(termLineFormat, "Rank", "terms"));
stream.format(termLineFormat, "---", "---");
for(RecPoint p:evalTrace.getChartAxisPoints(true)) {
// NOTE(review): the next line looks truncated. Everything between "i" and the
// declaration of sortedRefTermsNotFound (the inner loop, the end of the outer loop and
// probably the generic type arguments) appears to be missing; restore it from the
// upstream source before compiling.
for(int i=0; i sortedRefTermsNotFound = new TreeSet(new Comparator() {
@Override
public int compare(RTLTerm o1, RTLTerm o2) {
// Orders the reference terms by id.
return ComparisonChain.start().compare(
o1.getId(), o2.getId()
).result();
}
});
// Prints the reference terms that were not matched, one TSV line per term.
sortedRefTermsNotFound.addAll(rtlTermsNotFound);
for(RTLTerm refTerm:sortedRefTermsNotFound)
stream.println(refTerm.toTSVString());
stream.println();
stream.println(HORIZONTAL_RULE);
stream.println("Precision/Recall");
stream.println(HORIZONTAL_RULE);
stream.println();
writePrecisionRecallMatrix(stream);
stream.flush();
stream.close();
} catch (IOException e) {
LOGGER.error("File error", e);
LOGGER.error("Could not write eval logs to file {}", outputLogFile);
}
}
}
/**
 * Builds the label printed on the "RTL Mode" line of the log file.
 *
 * @return {@code "RTLv\t(RTL with variants)"} when {@code rtlV} is set,
 *         {@code "RTL\t(RTL without variants)"} otherwise
 */
private String getModeString() {
    final String modeName;
    final String preposition;
    if (rtlV) {
        modeName = "RTLv";
        preposition = "with";
    } else {
        modeName = "RTL";
        preposition = "without";
    }
    return String.format("%s\t(RTL %s variants)", modeName, preposition);
}
/**
 * Compares the terms of the term index, taken in the order defined by the
 * {@code TermProperty.WR} comparator, with the reference term list and
 * records one point per rank in the eval trace.
 *
 * <p>For every rank, the reference terms matched by the current term are
 * removed from {@code rtlTermsNotFound}, counted as true positives and
 * handed to the trace. Only the matching itself is timed by the stopwatch.
 */
private void evaluate() {
    LOGGER.info("Evaluating extracted terms against file {} [RTL with variants: {}]",
            FileUtils.getFileName(rtl.getPath()),
            rtlV);
    rtlTermsNotFound = Sets.newHashSet(rtl.asList());
    // Typed lists: with raw types, lc.get(rank) yields Object and cannot be assigned to a Term.
    List<Term> lc = Lists.newArrayList(termIndexResource.getTermIndex().getTerms());
    Collections.sort(lc, TermProperty.WR.getComparator(termIndexResource.getTermIndex(), true));
    generateRecPointIndexes(lc.size());
    List<RTLTerm> rtlTermsFound = Lists.newArrayList();
    int tp = 0; // running count of matched reference terms (true positives)
    for (int rank = 0; rank < lc.size(); rank++) {
        Term term = lc.get(rank);
        sw.start();
        for (RTLTerm rtlTerm : getMatchingRTLTerms(rtlTermsNotFound, term)) {
            rtlTermsFound.add(rtlTerm);
            rtlTermsNotFound.remove(rtlTerm);
            tp++;
            LOGGER.debug("For term \"{}\", found reference term \"{}\" ({})", term, rtlTerm, tp);
        }
        sw.stop();
        evalTrace.trace(rank, tp, rtlTermsFound);
        // A fresh list per rank: the one just passed to the trace is not reused.
        rtlTermsFound = Lists.newArrayList();
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Top {} extracted terms against reference list {}: p={} and r={}",
                    rank,
                    FileUtils.getFileName(rtl.getPath()),
                    String.format("%.2f", evalTrace.getLast().getPrecision()),
                    String.format("%.2f", evalTrace.getLast().getRecall())
            );
        }
    }
    LOGGER.info("Max recall after all {} extracted terms compared against ref list {}: {} (Eval time: {})",
            evalTrace.getLast().getRank(),
            FileUtils.getFileName(rtl.getPath()),
            String.format("%.2f", evalTrace.getLast().getRecall()),
            sw.elapsed(TimeUnit.MILLISECONDS)
    );
}
// Set rebuilt by generateRecPointIndexes(int).
private Set recPointIndexes;
/**
 * Resets {@code recPointIndexes} and iterates over {@code CHART_AXIS_POINTS}
 * (presumably to keep the points that are compatible with {@code maxIndex};
 * the rest of the body is missing, see the note below).
 *
 * @param maxIndex upper bound applied to the chart axis points
 */
private void generateRecPointIndexes(int maxIndex) {
recPointIndexes = Sets.newHashSet();
for(int i:CHART_AXIS_POINTS)
// NOTE(review): the next line looks truncated. The end of generateRecPointIndexes and the
// beginning of the getMatchingRTLTerms declaration (modifiers, return type, generic type
// arguments) appear to be missing; restore them from the upstream source before compiling.
// What follows is the body of getMatchingRTLTerms: it collects into a new set every term
// of rtl for which isMatch(rtlTerm, lcTerm) is true and returns that set.
if(i getMatchingRTLTerms(Collection rtl,
Term lcTerm) {
Collection matchingRTLTerms = Sets.newHashSet();
for(RTLTerm rtlTerm:rtl) {
if(isMatch(rtlTerm, lcTerm))
matchingRTLTerms.add(rtlTerm);
}
return matchingRTLTerms;
}
/**
 * Tells whether a reference term matches an extracted term, variants
 * included.
 *
 * <p>The reference side is {@code rtlTerm} itself plus, when {@code rtlV}
 * is set, its variants. The extracted side is {@code lcTerm} plus the
 * variant of each of its non-cyclic variation paths (requested with an
 * argument of 10). The two terms match as soon as one pair taken from both
 * sides satisfies {@link #isTermMatch(RTLTerm, Term)}.
 *
 * @param rtlTerm the reference term
 * @param lcTerm the extracted term
 * @return {@code true} if at least one pair matches, {@code false} otherwise
 */
private boolean isMatch(RTLTerm rtlTerm, Term lcTerm) {
    // Typed sets: the for-each loops below cannot compile over raw sets.
    Set<RTLTerm> rtlTerms = Sets.newHashSet(rtlTerm);
    if (rtlV) {
        rtlTerms.addAll(rtlTerm.getVariants());
    }
    Set<Term> lc = Sets.newHashSet();
    lc.add(lcTerm);
    for (VariationPath path : lcTerm.getVariationPaths(10)) {
        if (!path.isCycle()) {
            lc.add(path.getVariant());
        }
    }
    if (LOGGER.isTraceEnabled()) {
        LOGGER.trace("lc terms: {}", Joiner.on(',').join(lc));
        LOGGER.trace("rtl terms: {}", Joiner.on(',').join(rtlTerms));
    }
    for (RTLTerm rtlTerm2 : rtlTerms) {
        for (Term lcTerm2 : lc) {
            if (isTermMatch(rtlTerm2, lcTerm2)) {
                return true;
            }
        }
    }
    return false;
}
/**
 * Compares a single reference term with a single extracted term.
 *
 * @param refTerm the reference term
 * @param lcTerm the extracted term
 * @return {@code true} if the string of the reference term equals the
 *         lemma of the extracted term, ignoring case
 */
private boolean isTermMatch(RTLTerm refTerm, Term lcTerm) {
    // 1- test against lemma concatenation
    final String referenceString = refTerm.getString();
    final String candidateLemma = lcTerm.getLemma();
    return referenceString.equalsIgnoreCase(candidateLemma);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy