All downloads are free. Search and download functionality uses the official Maven repository.

edu.isi.nlp.MakeCrossValidationBatches Maven / Gradle / Ivy

The newest version!
package edu.isi.nlp;

import static com.google.common.base.Predicates.in;

import com.google.common.base.Charsets;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.io.Files;
import edu.isi.nlp.collections.CollectionUtils;
import edu.isi.nlp.collections.ListUtils;
import edu.isi.nlp.files.FileUtils;
import edu.isi.nlp.parameters.Parameters;
import edu.isi.nlp.symbols.Symbol;
import edu.isi.nlp.symbols.SymbolUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.Set;
import java.util.SortedMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Given a list of files and a number of splits, creates training/test file lists for
 * cross-validation. When the files cannot be evenly divided across all splits, extra files are
 * distributed as evenly as possible, starting with the first folds. For example, dividing 11 items
 * into three folds will result in folds of size (4, 4, 3).
 *
 * 

Normally, cross-validation batches are made such that as much data as possible is used in * training. However, if the {@code singleFoldTraining} parameter is set to true, a single fold is * used in training. Thus, across folds, each document appears exactly once in the training data and * exactly once in the testing data. In a standard cross-validation setup, each data point appears * in the training data {@code k-1} times for {@code k} folds and exactly once in the testing data. * *

For example, in a normal cross-validation setup with three folds (A, B, C) you have the * following test: train pairs: * *

 * A: (B, C)
 * B: (A, C)
 * C: (A, B)
 * 
* * In single fold training you have: * *
 * A: B
 * B: C
 * C: A
 * 
* *

Note that in the single fold case, each fold is tested by training on the "next" fold. This is * arbitrary, but as documents are shuffled before being placed in folds this is not problematic. */ public final class MakeCrossValidationBatches { private static final Logger log = LoggerFactory.getLogger(MakeCrossValidationBatches.class); private static final String PARAM_NAMESPACE = "com.bbn.bue.common.crossValidation."; private static final String PARAM_FILE_LIST = PARAM_NAMESPACE + "fileList"; private static final String PARAM_FILE_MAP = PARAM_NAMESPACE + "fileMap"; private static final String PARAM_NUM_BATCHES = PARAM_NAMESPACE + "numBatches"; private static final String PARAM_RANDOM_SEED = PARAM_NAMESPACE + "randomSeed"; private static final String PARAM_OUTPUT_DIR = PARAM_NAMESPACE + "outputDir"; private static final String PARAM_OUTPUT_NAME = PARAM_NAMESPACE + "outputName"; private static final String PARAM_SINGLE_FOLD_TRAINING = PARAM_NAMESPACE + "singleFoldTraining"; private MakeCrossValidationBatches() { throw new UnsupportedOperationException(); } public static void main(String[] argv) { // we wrap the main method in this way to // ensure a non-zero return value on failure try { trueMain(argv); } catch (Exception e) { e.printStackTrace(); System.exit(1); } } private static void errorExit(final String msg) { System.err.println("Error: " + msg); System.exit(1); } private static void trueMain(String[] argv) throws IOException { if (argv.length != 1) { errorExit("Usage: MakeCrossValidationBatches params"); } final Parameters parameters = Parameters.loadSerifStyle(new File(argv[0])); // Can run on map or list, but only one of the two. 
parameters.assertExactlyOneDefined(PARAM_FILE_LIST, PARAM_FILE_MAP); // Configure for map/list boolean useFileMap = false; final File sourceFiles; if (parameters.isPresent(PARAM_FILE_LIST)) { sourceFiles = parameters.getExistingFile(PARAM_FILE_LIST); } else if (parameters.isPresent(PARAM_FILE_MAP)) { useFileMap = true; sourceFiles = parameters.getExistingFile(PARAM_FILE_MAP); } else { throw new IllegalArgumentException("Impossible state reached"); } // Configure for single fold training. boolean singleFoldTraining = parameters.getOptionalBoolean(PARAM_SINGLE_FOLD_TRAINING).or(false); final File outputDirectory = parameters.getCreatableDirectory(PARAM_OUTPUT_DIR); final String outputName = parameters.getString(PARAM_OUTPUT_NAME); final int numBatches = parameters.getPositiveInteger(PARAM_NUM_BATCHES); final int randomSeed = parameters.getInteger(PARAM_RANDOM_SEED); if (numBatches < 1) { errorExit("Bad numBatches value: Need one or more batches to divide data into"); } final int maxBatch = numBatches - 1; final ImmutableMap docIdMap; if (useFileMap) { docIdMap = FileUtils.loadSymbolToFileMap(Files.asCharSource(sourceFiles, Charsets.UTF_8)); } else { // We load a file list but coerce it into a map final ImmutableList inputFiles = FileUtils.loadFileList(sourceFiles); docIdMap = Maps.uniqueIndex(inputFiles, FileToSymbolFunction.INSTANCE); // Check that nothing was lost in the conversion if (docIdMap.size() != inputFiles.size()) { errorExit("Input file list contains duplicate entries"); } } // Get the list of docids and shuffle them. In the case of using a file list, these are just // paths, not document ids, but they serve the same purpose. 
final ImmutableList docIds = ListUtils.shuffledCopy(docIdMap.keySet().asList(), new Random(randomSeed)); if (numBatches > docIds.size()) { errorExit("Bad numBatches value: Cannot create more batches than there are input files"); } // Divide into folds final ImmutableList> testFolds = CollectionUtils.partitionAlmostEvenly(docIds, numBatches); // Write out training/test data for each fold final ImmutableList.Builder foldLists = ImmutableList.builder(); final ImmutableList.Builder foldMaps = ImmutableList.builder(); int batchNum = 0; int totalDocs = 0; // Set up train folds final ImmutableList> trainFolds = createTrainFolds(testFolds, docIds, singleFoldTraining); // Loop over train/test folds Preconditions.checkState(trainFolds.size() == testFolds.size()); for (int i = 0; i < testFolds.size(); i++) { final ImmutableList testDocIds = testFolds.get(i); final ImmutableList trainDocIds = trainFolds.get(i); // Track the total test documents across folds to make sure nothing is lost. totalDocs += testDocIds.size(); // Create maps for training and test. These are sorted to avoid arbitrary ordering. final SortedMap trainDocIdMap = ImmutableSortedMap.copyOf( Maps.filterKeys(docIdMap, in(trainDocIds)), SymbolUtils.byStringOrdering()); final SortedMap testDocIdMap = ImmutableSortedMap.copyOf( Maps.filterKeys(docIdMap, in(testDocIds)), SymbolUtils.byStringOrdering()); // Don't write out the maps for file lists as the keys are not actually document IDs if (useFileMap) { final File trainingMapOutputFile = new File( outputDirectory, outputName + "." + StringUtils.padWithMax(batchNum, maxBatch) + ".training.docIDToFileMap"); FileUtils.writeSymbolToFileMap( trainDocIdMap, Files.asCharSink(trainingMapOutputFile, Charsets.UTF_8)); final File testMapOutputFile = new File( outputDirectory, outputName + "." 
+ StringUtils.padWithMax(batchNum, maxBatch) + ".test.docIDToFileMap"); FileUtils.writeSymbolToFileMap( testDocIdMap, Files.asCharSink(testMapOutputFile, Charsets.UTF_8)); foldMaps.add(testMapOutputFile); } // Write out file lists final ImmutableList trainingFilesForBatch = ImmutableList.copyOf(trainDocIdMap.values()); final ImmutableList testFilesForBatch = ImmutableList.copyOf(testDocIdMap.values()); final File trainingOutputFile = new File( outputDirectory, outputName + "." + StringUtils.padWithMax(batchNum, maxBatch) + ".training.list"); FileUtils.writeFileList( trainingFilesForBatch, Files.asCharSink(trainingOutputFile, Charsets.UTF_8)); final File testOutputFile = new File( outputDirectory, outputName + "." + StringUtils.padWithMax(batchNum, maxBatch) + ".test.list"); FileUtils.writeFileList(testFilesForBatch, Files.asCharSink(testOutputFile, Charsets.UTF_8)); foldLists.add(testOutputFile); ++batchNum; } if (totalDocs != docIdMap.size()) { errorExit("Test files created are not the same length as the input"); } // Write out lists of files/maps created FileUtils.writeFileList( foldLists.build(), Files.asCharSink(new File(outputDirectory, "folds.list"), Charsets.UTF_8)); if (useFileMap) { FileUtils.writeFileList( foldMaps.build(), Files.asCharSink(new File(outputDirectory, "folds.maplist"), Charsets.UTF_8)); } log.info( "Wrote {} cross validation batches from {} to directory {}", numBatches, sourceFiles.getAbsoluteFile(), outputDirectory.getAbsolutePath()); } private static ImmutableList shuffledDocIds( final int randomSeed, final ImmutableMap docIdMap) { final ArrayList docIds = Lists.newArrayList(docIdMap.keySet()); Collections.shuffle(docIds, new Random(randomSeed)); return ImmutableList.copyOf(docIds); } private static ImmutableList> createTrainFolds( final ImmutableList> testFolds, final ImmutableList docIds, final boolean singleFoldTraining) { final ImmutableList.Builder> ret = ImmutableList.builder(); for (int i = 0; i < testFolds.size(); i++) { final Set 
testDocIds = ImmutableSet.copyOf(testFolds.get(i)); final ImmutableList trainDocIds = singleFoldTraining // In the single fold training case, use the "next" fold as the training data. We use // the modulus to wrap around the list. ? testFolds.get((i + 1) % testFolds.size()) // In the normal case, just use all the remaining data for training. : Sets.difference(ImmutableSet.copyOf(docIds), testDocIds).immutableCopy().asList(); ret.add(trainDocIds); } return ret.build(); } enum FileToSymbolFunction implements Function { INSTANCE; @Override public Symbol apply(final File input) { return Symbol.from(input.getPath()); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy