
// org.dkpro.tc.crfsuite.CRFSuiteBatchCrossValidationReport Maven / Gradle / Ivy
/*******************************************************************************
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.dkpro.tc.crfsuite;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.text.StrTokenizer;
import org.apache.commons.math.stat.descriptive.moment.Mean;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math.stat.descriptive.summary.Sum;
import org.dkpro.lab.reporting.BatchReportBase;
import org.dkpro.lab.reporting.FlexTable;
import org.dkpro.lab.storage.StorageService;
import org.dkpro.lab.storage.impl.PropertiesAdapter;
import org.dkpro.lab.task.Task;
import org.dkpro.lab.task.TaskContextMetadata;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.core.util.ReportConstants;
import org.dkpro.tc.core.util.ReportUtils;
import org.dkpro.tc.ml.ExperimentCrossValidation;
/**
* Collects the final evaluation results in a cross validation setting.
*/
public class CRFSuiteBatchCrossValidationReport
extends BatchReportBase
implements Constants
{
private static final String foldAveraged = " (average over all folds)";
private static final String foldSum = " (sum over all folds)";
private static final List discriminatorsToExclude = Arrays.asList(new String[] {
"files_validation", "files_training" });
private static final List nonAveragedResultsMeasures = Arrays.asList(new String[] {
ReportConstants.CORRECT, ReportConstants.INCORRECT, ReportConstants.NUMBER_EXAMPLES,
ReportConstants.NUMBER_LABELS });
@Override
public void execute()
throws Exception
{
StorageService store = getContext().getStorageService();
FlexTable table = FlexTable.forClass(String.class);
Map> key2resultValues = new HashMap>();
for (TaskContextMetadata subcontext : getSubtasks()) {
String name = ExperimentCrossValidation.class.getSimpleName();
// one CV batch (which internally ran numFolds times)
if (subcontext.getLabel().startsWith(name)) {
Map discriminatorsMap = store.retrieveBinary(subcontext.getId(),
Task.DISCRIMINATORS_KEY, new PropertiesAdapter()).getMap();
File eval = store.locateKey(subcontext.getId(), EVAL_FILE_NAME + SUFFIX_CSV);
Map resultMap = new HashMap();
String[][] evalMatrix = null;
int i = 0;
for (String line : FileUtils.readLines(eval)) {
String[] tokenizedLine = StrTokenizer.getCSVInstance(line).getTokenArray();
if (evalMatrix == null) {
evalMatrix = new String[FileUtils.readLines(eval).size()][tokenizedLine.length];
}
evalMatrix[i] = tokenizedLine;
i++;
}
// columns
for (int j = 0; j < evalMatrix[0].length; j++) {
String header = evalMatrix[0][j];
String[] vals = new String[evalMatrix.length - 1];
// rows
for (int k = 1; k < evalMatrix.length; k++) {
if (evalMatrix[k][j].equals("null")) {
vals[k - 1] = String.valueOf(0.);
}
else {
vals[k - 1] = evalMatrix[k][j];
}
}
Mean mean = new Mean();
Sum sum = new Sum();
StandardDeviation std = new StandardDeviation();
double[] dVals = new double[vals.length];
Set sVals = new HashSet();
for (int k = 0; k < vals.length; k++) {
try {
dVals[k] = Double.parseDouble(vals[k]);
sVals = null;
}
catch (NumberFormatException e) {
dVals = null;
sVals.add(vals[k]);
}
}
if (dVals != null) {
if (nonAveragedResultsMeasures.contains(header)) {
resultMap.put(header + foldSum, String.valueOf(sum.evaluate(dVals)));
}
else {
resultMap.put(
header + foldAveraged,
String.valueOf(mean.evaluate(dVals) + "\u00B1"
+ String.valueOf(std.evaluate(dVals))));
}
}
else {
if (sVals.size() > 1) {
resultMap.put(header, "---");
}
else {
resultMap.put(header, vals[0]);
}
}
}
String key = getKey(discriminatorsMap);
List results;
if (key2resultValues.get(key) == null) {
results = new ArrayList();
}
else {
results = key2resultValues.get(key);
}
key2resultValues.put(key, results);
Map values = new HashMap();
Map cleanedDiscriminatorsMap = new HashMap();
for (String disc : discriminatorsMap.keySet()) {
if (!ReportUtils.containsExcludePattern(disc, discriminatorsToExclude)) {
cleanedDiscriminatorsMap.put(disc, discriminatorsMap.get(disc));
}
}
values.putAll(cleanedDiscriminatorsMap);
values.putAll(resultMap);
table.addRow(subcontext.getLabel(), values);
}
}
getContext().getLoggingService().message(getContextLabel(),
ReportUtils.getPerformanceOverview(table));
// Excel cannot cope with more than 255 columns
if (table.getColumnIds().length <= 255) {
getContext().storeBinary(EVAL_FILE_NAME + "_compact" + SUFFIX_EXCEL,
table.getExcelWriter());
}
getContext().storeBinary(EVAL_FILE_NAME + "_compact" + SUFFIX_CSV, table.getCsvWriter());
table.setCompact(false);
// Excel cannot cope with more than 255 columns
if (table.getColumnIds().length <= 255) {
getContext().storeBinary(EVAL_FILE_NAME + SUFFIX_EXCEL, table.getExcelWriter());
}
getContext().storeBinary(EVAL_FILE_NAME + SUFFIX_CSV, table.getCsvWriter());
// output the location of the batch evaluation folder
// otherwise it might be hard for novice users to locate this
File dummyFolder = store.locateKey(getContext().getId(), "dummy");
// TODO can we also do this without creating and deleting the dummy folder?
getContext().getLoggingService().message(getContextLabel(),
"Storing detailed results in:\n" + dummyFolder.getParent() + "\n");
dummyFolder.delete();
}
private String getKey(Map discriminatorsMap)
{
Set sortedDiscriminators = new TreeSet(discriminatorsMap.keySet());
List values = new ArrayList();
for (String discriminator : sortedDiscriminators) {
values.add(discriminatorsMap.get(discriminator));
}
return StringUtils.join(values, "_");
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy