All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.project.ttc.engines.EvalEngine Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright 2015 - CNRS (Centre National de Recherche Scientifique)
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *******************************************************************************/
package eu.project.ttc.engines;

import java.io.IOException;
import java.io.PrintStream;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;

import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ExternalResource;
import org.apache.uima.jcas.JCas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Joiner;
import com.google.common.base.Stopwatch;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import eu.project.ttc.engines.cleaner.TermProperty;
import eu.project.ttc.models.Term;
import eu.project.ttc.models.VariationPath;
import eu.project.ttc.resources.EvalTrace;
import eu.project.ttc.resources.EvalTrace.RecPoint;
import eu.project.ttc.resources.ReferenceTermList;
import eu.project.ttc.resources.ReferenceTermList.RTLTerm;
import eu.project.ttc.resources.TermIndexResource;
import eu.project.ttc.utils.FileUtils;

/**
 * 
 * An engine that measures the precision and recall of
 * an extracted terminology against a reference term list.
 * 
 * @author Damien Cram
 *
 */
public class EvalEngine  extends JCasAnnotator_ImplBase {
	private static final Logger LOGGER = LoggerFactory.getLogger(EvalEngine.class);

	/** Separator line printed between the sections of the evaluation log file. */
	public static final String HORIZONTAL_RULE = "------------------------------------------------------------";

	/** Resource key of the reference term list the extracted terms are matched against. */
	public static final String REFERENCE_LIST = "ReferenceList";
	@ExternalResource(key = REFERENCE_LIST, mandatory = true)
	private ReferenceTermList rtl;

	/** Resource giving access to the index of extracted terms. */
	@ExternalResource(key = TermIndexResource.TERM_INDEX, mandatory = true)
	private TermIndexResource termIndexResource;
	
	/** Resource key of the trace that receives one entry per evaluated rank. */
	public static final String EVAL_TRACE = "EvalTrace";
	@ExternalResource(key=EVAL_TRACE, mandatory=true)
	private EvalTrace evalTrace;
	
	/** Optional parameter: name of the log file to write; no log file is written when unset. */
	public static final String OUTPUT_LOG_FILE = "OutputLogFile";
	@ConfigurationParameter(name=OUTPUT_LOG_FILE, mandatory=false)
	private String outputLogFile;

	/** Optional parameter: name of the precision/recall matrix file to write; skipped when unset. */
	public static final String OUTPUT_R_FILE = "OutputRFile";
	@ConfigurationParameter(name=OUTPUT_R_FILE, mandatory=false)
	private String outputRFile;

	/** Mandatory parameter: when true, the variants of a reference term also count as a match. */
	public static final String RTL_WITH_VARIANTS = "RTLWithVariants";

	@ConfigurationParameter(name=RTL_WITH_VARIANTS, mandatory=true)
	private boolean rtlV;
	
	/** Optional parameter: free text printed at the top of the log file when not empty. */
	public static final String CUSTOM_LOG_HEADER_STRING = "CustomLogHeaderString";
	@ConfigurationParameter(name=CUSTOM_LOG_HEADER_STRING, mandatory=false, defaultValue="")
	private String customLogHeaderString;

	/** Candidate ranks iterated by {@code generateRecPointIndexes}. */
	public static final Collection<Integer> CHART_AXIS_POINTS = ImmutableSet.of(10, 50, 100, 200, 250, 500, 1000, 2000, 5000, 10000, 20000);
	
	/** Reference terms that no extracted term has matched so far; reset by {@code evaluate()}. */
	private Set<RTLTerm> rtlTermsNotFound;
	/** Format of the "Date:" line of the log file. */
	private DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
	/** Accumulates the time spent matching extracted terms against the reference list. */
	private Stopwatch sw = Stopwatch.createUnstarted();
	
	/**
	 * Does nothing for individual documents: the evaluation is performed
	 * once, in {@link #collectionProcessComplete()}.
	 *
	 * @param arg0 the current CAS, ignored
	 */
	@Override
	public void process(JCas arg0) throws AnalysisEngineProcessException {
		// intentionally empty
	}

	
	
	/**
	 * Runs the evaluation and then writes the log file and the
	 * precision/recall file, in that order.
	 */
	@Override
	public void collectionProcessComplete()
			throws AnalysisEngineProcessException {
		// the trace must know the reference list size before recall can be computed
		final int referenceListSize = rtl.getTerms().size();
		this.evalTrace.setRtlSize(referenceListSize);

		evaluate();
		writeToLogFile();
		writeToRFile();
	}

	/**
	 * Writes the precision/recall matrix to the file named by the
	 * {@link #OUTPUT_R_FILE} parameter. Does nothing when that parameter is
	 * not set. I/O failures are logged and not propagated.
	 */
	private void writeToRFile() {
		if(outputRFile == null) {
			return;
		}
		// try-with-resources releases the file handle even when writing fails
		try (PrintStream stream = new PrintStream(outputRFile)) {
			writePrecisionRecallMatrix(stream);
			stream.flush();
		} catch (IOException e) {
			LOGGER.error("File error", e);
			LOGGER.error("Could not write R chart to file {}", outputRFile);
		}
	}

	/**
	 * Prints one line per point returned by
	 * {@code evalTrace.getChartAxisPoints(true)}: rank, precision and recall,
	 * separated by tabs, the two ratios with two decimals.
	 * <p>
	 * Numbers are formatted with {@link Locale#ROOT} so that the decimal
	 * separator is always a dot, whatever the default locale of the JVM.
	 *
	 * @param stream the stream to print to; it is neither flushed nor closed here
	 * @throws IOException declared for callers; not thrown by the current body
	 */
	private void writePrecisionRecallMatrix(PrintStream stream) throws IOException {
		for(RecPoint p:evalTrace.getChartAxisPoints(true)) {
			stream.format(Locale.ROOT, "%d\t%.2f\t%.2f\n", 
				p.getRank(),
				p.getPrecision(),
				p.getRecall()
			);
		}
	}

	/**
	 * Writes the evaluation report to the file named by outputLogFile, or
	 * does nothing when that field is null. The report starts with a header
	 * (mode, term index name, term counts, recall/precision at ranks 10, 100
	 * and 1000, max recall, date, matching time), followed by further
	 * sections and finally the precision/recall matrix. I/O failures are
	 * logged and not propagated.
	 *
	 * NOTE(review): the stream is only closed on the success path; an
	 * exception thrown while writing leaves it open.
	 */
	private void writeToLogFile() {
		if(outputLogFile != null) {
			try {
				// total number of variation paths (depth argument 10) over all extracted terms
				int numVariationPaths = 0;
				for(Term lcTerm:termIndexResource.getTermIndex().getTerms())
					numVariationPaths += lcTerm.getVariationPaths(10).size();
				PrintStream stream = new PrintStream(outputLogFile);
				stream.println(HORIZONTAL_RULE);
				if(!customLogHeaderString.isEmpty())
					stream.println(customLogHeaderString);
//				stream.format("RTL path: %s\n", rtl.getPath());
//				stream.format("Variant depth: %d\n", 10);
				
				stream.format("RTL Mode: %s\n", getModeString());
				stream.format("LC term index: %s\n", termIndexResource.getTermIndex().getName());
				stream.format("Num. LC terms: %d\tIncl. variants: %d\n", termIndexResource.getTermIndex().getTerms().size(), numVariationPaths);
				
				for(int i:new int[]{10,100,1000}) {
					RecPoint p = evalTrace.getAtRank(i);
					String str = String.format("R_%d: %.2f  (p=%.2f)", i, p.getRecall()*100, p.getPrecision()*100);
					// every comma becomes a dot; presumably to undo a comma decimal separator
					// produced by the default locale - TODO confirm
					stream.println(str.replaceAll(",", "."));
				}
				stream.format("R_max: %.2f\n", evalTrace.getMaxRecall()*100);
				stream.format("Date: %s\n", dateFormat.format(new Date()));
				stream.format("Generation time: %s\n", sw.elapsed(TimeUnit.SECONDS));
				stream.println();
				stream.println(HORIZONTAL_RULE);
				stream.println("Extracted terms found in reference list");
				stream.println(HORIZONTAL_RULE);
				stream.println();

				String termLineFormat = "%-8s%-40s\n";
				// NOTE(review): the already formatted header line is passed to stream.format
				// as a format string, i.e. it is formatted twice
				stream.format(String.format(termLineFormat, "Rank", "terms"));
				stream.format(termLineFormat, "---", "---");
				for(RecPoint p:evalTrace.getChartAxisPoints(true)) {
										
					// NOTE(review): the next line appears to be damaged: the header of an inner
					// "for" loop is fused with the declaration of a TreeSet. The loop condition
					// and body, the closing braces of both loops and whatever was printed before
					// the sorted set is filled are missing. Restore them from the original
					// source; they cannot be inferred from what is left here.
					for(int i=0; i sortedRefTermsNotFound = new TreeSet(new Comparator() {
					// orders reference terms by their id
					@Override
					public int compare(RTLTerm o1, RTLTerm o2) {
						return ComparisonChain.start().compare(
								o1.getId(), o2.getId()
								).result();
					}
				});
				sortedRefTermsNotFound.addAll(rtlTermsNotFound);
				 
				// one TSV line per reference term left in rtlTermsNotFound, in id order
				for(RTLTerm refTerm:sortedRefTermsNotFound)
					stream.println(refTerm.toTSVString());

				stream.println();
				stream.println(HORIZONTAL_RULE);
				stream.println("Precision/Recall");
				stream.println(HORIZONTAL_RULE);
				stream.println();
				
				writePrecisionRecallMatrix(stream);

				stream.flush();
				stream.close();		
			} catch (IOException e) {
				LOGGER.error("File error", e);
				LOGGER.error("Could not write eval logs to file {}", outputLogFile);
			}		
		}
	}

	/**
	 * Builds the text of the "RTL Mode" line of the log file.
	 *
	 * @return "RTLv" and "with" when reference variants are enabled,
	 *         "RTL" and "without" otherwise, laid out as
	 *         {@code <mode>\t(RTL <with|without> variants)}
	 */
	private String getModeString() {
		final String modeName;
		final String variantWord;
		if(rtlV) {
			modeName = "RTLv";
			variantWord = "with";
		} else {
			modeName = "RTL";
			variantWord = "without";
		}
		return String.format("%s\t(RTL %s variants)", modeName, variantWord);
	}

	/**
	 * Compares every extracted term, in the order given by the
	 * {@code TermProperty.WR} comparator, against the reference terms that
	 * have not been matched yet.
	 * <p>
	 * For each rank, the reference terms matched by the term at that rank
	 * are removed from {@code rtlTermsNotFound}, the running count of matches
	 * is increased, and one entry (rank, running count, terms matched at this
	 * rank) is recorded in the evaluation trace. Only the matching itself is
	 * timed by the stopwatch.
	 */
	private void evaluate() {
		LOGGER.info("Evaluating extracted terms against file {} [RTL with variants: {}]", 
				FileUtils.getFileName(rtl.getPath()),
				rtlV);
		rtlTermsNotFound = Sets.newHashSet(rtl.asList());
		List<Term> lc = Lists.newArrayList(termIndexResource.getTermIndex().getTerms());
		Collections.sort(lc, TermProperty.WR.getComparator(termIndexResource.getTermIndex(), true));
		generateRecPointIndexes(lc.size());
		
		// running number of reference terms matched so far
		int tp = 0;
		for(int rank = 0; rank < lc.size(); rank++) {
			Term term = lc.get(rank);
			// reference terms matched by this very rank only
			List<RTLTerm> rtlTermsFound = Lists.newArrayList();
			sw.start();
			// getMatchingRTLTerms returns a fresh collection, so removing from
			// rtlTermsNotFound while iterating over the result is safe
			for(RTLTerm rtlTerm:getMatchingRTLTerms(rtlTermsNotFound, term)) {
				rtlTermsFound.add(rtlTerm);
				rtlTermsNotFound.remove(rtlTerm);
				tp++;
				LOGGER.debug("For term \"{}\", found reference term \"{}\" ({})", term, rtlTerm, tp);
			}
			sw.stop();
			evalTrace.trace(rank, tp, rtlTermsFound);
			if(LOGGER.isDebugEnabled()) {
				LOGGER.debug("Top {} extracted terms against reference list {}: p={} and r={}",
					rank,
					FileUtils.getFileName(rtl.getPath()),
					String.format("%.2f", evalTrace.getLast().getPrecision()),
					String.format("%.2f", evalTrace.getLast().getRecall())
				);
			}
		}
		
		// NOTE(review): relies on evalTrace.getLast() being usable even when no
		// term was traced (empty term index) - confirm against EvalTrace
		LOGGER.info("Max recall after all {} extracted terms compared against ref list {}: {} (Eval time: {})",
			evalTrace.getLast().getRank(),
			FileUtils.getFileName(rtl.getPath()),
			String.format("%.2f", evalTrace.getLast().getRecall()),
			sw.elapsed(TimeUnit.MILLISECONDS)
			);
	}
	
	
	// Re-created empty at each call of generateRecPointIndexes; how it is
	// filled is not visible here (see the note below).
	private Set recPointIndexes;
	/*
	 * NOTE(review): this region is damaged. The end of
	 * generateRecPointIndexes(int) and the beginning of getMatchingRTLTerms
	 * (modifiers, return type) have been lost and the two methods are fused
	 * on a single line. What remains shows that:
	 *  - generateRecPointIndexes resets recPointIndexes and iterates over
	 *    CHART_AXIS_POINTS, testing each value i in a condition that is cut off;
	 *  - getMatchingRTLTerms(rtl, lcTerm) collects into a new set every term
	 *    of rtl for which isMatch(rtlTerm, lcTerm) is true, and returns it.
	 * Restore the missing code from the original source.
	 */
	private void generateRecPointIndexes(int maxIndex) {
		recPointIndexes = Sets.newHashSet();
		for(int i:CHART_AXIS_POINTS)
			if(i getMatchingRTLTerms(Collection rtl,
			Term lcTerm) {
		// fresh set: the caller may modify rtl while iterating over the result
		Collection matchingRTLTerms = Sets.newHashSet();
		for(RTLTerm rtlTerm:rtl) {
			if(isMatch(rtlTerm, lcTerm))
				matchingRTLTerms.add(rtlTerm);
		}
		return matchingRTLTerms;
	}

	/**
	 * Tells whether a reference term and an extracted term match, taking
	 * variants into account.
	 * <p>
	 * On the reference side, the candidates are the reference term itself
	 * plus, when reference variants are enabled ({@code rtlV}), all its
	 * variants. On the extracted side, the candidates are the term itself
	 * plus the variant at the end of each of its non-cyclic variation paths
	 * (depth argument 10). The terms match as soon as one pair of candidates
	 * satisfies {@link #isTermMatch(RTLTerm, Term)}.
	 * 
	 * @param rtlTerm the reference term
	 * @param lcTerm the extracted term
	 * @return true if at least one pair of candidates matches
	 */
	private boolean isMatch(RTLTerm rtlTerm, Term lcTerm) {
		Set<RTLTerm> rtlTerms = Sets.newHashSet(rtlTerm);
		if(rtlV) {
			rtlTerms.addAll(rtlTerm.getVariants());
		}
			
		Set<Term> lc = Sets.newHashSet();
		lc.add(lcTerm);
		for(VariationPath path:lcTerm.getVariationPaths(10)) {
			if(!path.isCycle()) {
				lc.add(path.getVariant());
			}
		}
		
		if(LOGGER.isTraceEnabled()) {
			LOGGER.trace("lc terms:  {}", Joiner.on(',').join(lc));
			LOGGER.trace("rtl terms: {}", Joiner.on(',').join(rtlTerms));
		}
		for(RTLTerm rtlTerm2:rtlTerms) {
			for(Term lcTerm2:lc) {
				if(isTermMatch(rtlTerm2, lcTerm2)) {
					return true;
				}
			}
		}
		
		return false;
	}


	/**
	 * Compares a single reference term with a single extracted term.
	 *
	 * @param refTerm the reference term
	 * @param lcTerm the extracted term
	 * @return true if the reference term string equals the lemma of the
	 *         extracted term, ignoring case
	 */
	private boolean isTermMatch(RTLTerm refTerm, Term lcTerm) {
		// the only criterion so far: case-insensitive test against the lemma
		final String referenceString = refTerm.getString();
		final String extractedLemma = lcTerm.getLemma();
		return referenceString.equalsIgnoreCase(extractedLemma);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy