All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.isi.nlp.PartitionData Maven / Gradle / Ivy

The newest version!
package edu.isi.nlp;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.base.Charsets;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.google.common.math.DoubleMath;
import edu.isi.nlp.collections.CollectionUtils;
import edu.isi.nlp.collections.ListUtils;
import edu.isi.nlp.files.FileUtils;
import edu.isi.nlp.parameters.Parameters;
import edu.isi.nlp.symbols.Symbol;
import java.io.File;
import java.io.IOException;
import java.math.RoundingMode;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Divides a data set into almost equally-sized partitions after optionally holding out a portion of
 * the data. For example, this could be used to divide into three equal partitions (33.3% of the
 * data each), to hold out 10% of the data and then split into three partitions (10% held out, 30%
 * in each partition), or do a simple train/test split (by specifying one partition).
 *
 * 

This uses the following parameters in the com.bbn.bue.common.partitionData namespace: * *

fileList or fileMap: Specify exactly of these to give the input as a file list or document id * to file map. If a list is specified, all output is in the form of file lists, etc. * *

holdOutProportion: The proportion of the data to hold out, for example .25 for 25%. May be * zero. holdOutFile: The file to write the held out file list/map to. Must be specified if * holdOutProportion is greater than zero. It is an error to specify this if holdOutProportion is * zero. * *

numPartitions: The number of partitions to create. Partitions are created after any held out * data is removed. This can be one to allow for a simple train/test split that is defined using * holdOutProportion. randomSeed: The random seed to use when shuffling the data before data is held * out and partitioned. * *

partitionOutputDir: The directory to write the file lists/maps that give each partition. * partitionListFile: The file to write the list of partition lists/maps to. partitionPrefix: The * prefix to give the filename of each partition. The output files will be of the format * %partitionOutputDir%/%partitionPrefix%.%partitionNumber%.{map,list} where the partition * number is zero-indexed. */ public final class PartitionData { private static final Logger log = LoggerFactory.getLogger(PartitionData.class); private static final String PARAM_NAMESPACE = "com.bbn.bue.common.partitionData."; // The input file list or map private static final String PARAM_FILE_LIST = PARAM_NAMESPACE + "fileList"; private static final String PARAM_FILE_MAP = PARAM_NAMESPACE + "fileMap"; // File to write out held out data to private static final String PARAM_HOLD_OUT_PATH = PARAM_NAMESPACE + "holdOutFile"; // Output directory for writing partitions private static final String PARAM_OUTPUT_DIR = PARAM_NAMESPACE + "partitionOutputDir"; // Prefix for file names of partition map/lists private static final String PARAM_PARTITION_LIST = PARAM_NAMESPACE + "partitionListFile"; // Prefix for file names of partition map/lists private static final String PARAM_PARTITION_PREFIX = PARAM_NAMESPACE + "partitionPrefix"; // The proportion of the data to hold out private static final String PARAM_HOLD_OUT = PARAM_NAMESPACE + "holdOutProportion"; // The number of partitions private static final String PARAM_PARTITIONS = PARAM_NAMESPACE + "numPartitions"; // Random seed for shuffling private static final String PARAM_RANDOM_SEED = PARAM_NAMESPACE + "randomSeed"; private PartitionData() { throw new UnsupportedOperationException(); } public static void main(String[] argv) { // we wrap the main method in this way to // ensure a non-zero return value on failure try { trueMain(argv); } catch (Exception e) { e.printStackTrace(); System.exit(1); } } private static void errorExit(final String msg) { System.err.println("Error: " + msg); System.exit(1); } private static void trueMain(String[] argv) throws IOException { if (argv.length != 1) { errorExit("Usage: PartitionData params"); } final Parameters parameters = Parameters.loadSerifStyle(new File(argv[0])); log.info("Running with parameters:\n" + parameters.dump()); // Can run on map or list, but only one of the two. parameters.assertExactlyOneDefined(PARAM_FILE_LIST, PARAM_FILE_MAP); // Configure for map/list // This will contain the file paths in the case of a file list, or the docids in the case // of a docid to file map. ImmutableList documents; // This is null if we are not using a map final ImmutableMap documentMap; if (parameters.isPresent(PARAM_FILE_LIST)) { final File fileList = parameters.getExistingFile(PARAM_FILE_LIST); log.info("Loading file list from {}", fileList); documentMap = null; documents = FluentIterable.from(FileUtils.loadFileList(fileList)) .transform(MakeCrossValidationBatches.FileToSymbolFunction.INSTANCE) .toList(); } else if (parameters.isPresent(PARAM_FILE_MAP)) { final File fileMap = parameters.getExistingFile(PARAM_FILE_MAP); log.info("Loading file map from {}", fileMap); documentMap = FileUtils.loadSymbolToFileMap(fileMap); documents = documentMap.keySet().asList(); } else { // Should be unreachable. Used instead of checkState to satisfy compiler. throw new IllegalStateException("Input is neither file list nor map"); } log.info("Loaded {} documents", documents.size()); final File outputDirectory = parameters.getCreatableDirectory(PARAM_OUTPUT_DIR); final File partitionListFile = parameters.getCreatableFile(PARAM_PARTITION_LIST); final Optional holdOutFile = parameters.getOptionalCreatableFile(PARAM_HOLD_OUT_PATH); final String partitionPrefix = parameters.getString(PARAM_PARTITION_PREFIX); final int nPartitions = parameters.getPositiveInteger(PARAM_PARTITIONS); final int randomSeed = parameters.getInteger(PARAM_RANDOM_SEED); final double holdOut = parameters.getProbability(PARAM_HOLD_OUT); // 1.0 is a valid probability but not a valid hold out value checkArgument(holdOut != 1.0, "Hold out proportion must be less than all of the data"); checkArgument( holdOutFile.isPresent() == (holdOut > 0.0), PARAM_HOLD_OUT + " must be specified if and only if hold out amount is greater than zero"); checkArgument( holdOut > 0.0 || nPartitions > 1, "Neither hold out nor more than one partition specified. Nothing to do."); // Figure out how much is held out final int nDocuments = documents.size(); final int nHeldOut = DoubleMath.roundToInt(nDocuments * holdOut, RoundingMode.HALF_UP); // Prevent requesting .99999 of 10 documents, which is all of them. checkArgument(nHeldOut < nDocuments, "Cannot hold out all documents"); // Prevent requesting .00001 of 10 documents, which is none of them. checkArgument( holdOut == 0.0 || nHeldOut > 0, "Hold out amount is non-zero but less than one document"); log.info("Holding out {} documents", nHeldOut); // Compute how much is left over after hold out final int nRemaining = nDocuments - nHeldOut; checkArgument( nRemaining >= nPartitions, "More partitions requested than number of non-held out documents"); log.info("Dividing {} documents into {} partitions", nRemaining, nPartitions); // Shuffle and replace the original to avoid incorrect references. documents = ListUtils.shuffledCopy(documents, new Random(randomSeed)); // Hold out beginning of list final ImmutableSet heldOut = ImmutableSet.copyOf(documents.subList(0, nHeldOut)); final ImmutableSet remaining = ImmutableSet.copyOf(documents.subList(nHeldOut, nDocuments)); checkState( heldOut.size() + remaining.size() == nDocuments, "Number of documents in held out and partitioned data differs from original number of documents"); int outputDocuments = 0; if (holdOut > 0.0) { // Already checked about that holdOutFile is present in this case outputDocuments += writeHoldOut(heldOut, holdOutFile.get(), documentMap); } outputDocuments += writePartitions( remaining, nPartitions, outputDirectory, partitionListFile, partitionPrefix, documentMap); checkState(nDocuments == outputDocuments, "Incorrect number of documents written"); } private static int writeHoldOut( final ImmutableSet heldOut, final File holdOutFile, final ImmutableMap documentMap) throws IOException { if (documentMap != null) { FileUtils.writeSymbolToFileMap( filterMapToKeysPreservingOrder(documentMap, heldOut), Files.asCharSink(holdOutFile, Charsets.UTF_8)); } else { FileUtils.writeFileList( Lists.transform(heldOut.asList(), SymbolToFileFunction.INSTANCE), Files.asCharSink(holdOutFile, Charsets.UTF_8)); } log.info("Wrote held out data to {}", holdOutFile); return heldOut.size(); } private static int writePartitions( final ImmutableSet remaining, final int nPartitions, final File outputDirectory, final File partitionListFile, final String partitionPrefix, final ImmutableMap documentMap) throws IOException { int outputDocuments = 0; // Partition remaining data final ImmutableList> partitions = CollectionUtils.partitionAlmostEvenly(remaining.asList(), nPartitions); // Write out each partition log.info("Writing partitions to directory {}", outputDirectory); int maxPartition = partitions.size(); final ImmutableList.Builder partitionFiles = ImmutableList.builder(); for (int partitionNum = 0; partitionNum < partitions.size(); partitionNum++) { final ImmutableList partition = partitions.get(partitionNum); // maxPartition -1 since partitions are indexed starting at 0 final String partitionFileName = partitionPrefix + '.' + StringUtils.padWithMax(partitionNum, maxPartition - 1); final File partitionFile; if (documentMap != null) { partitionFile = new File(outputDirectory, partitionFileName + ".map"); FileUtils.writeSymbolToFileMap( filterMapToKeysPreservingOrder(documentMap, partition), Files.asCharSink(partitionFile, Charsets.UTF_8)); } else { partitionFile = new File(outputDirectory, partitionFileName + ".list"); FileUtils.writeFileList( Lists.transform(partition, SymbolToFileFunction.INSTANCE), Files.asCharSink(partitionFile, Charsets.UTF_8)); } partitionFiles.add(partitionFile); outputDocuments += partition.size(); log.info("Wrote partition {} to {}", partitionNum, partitionFile); } // Write out lists of files/maps created FileUtils.writeFileList( partitionFiles.build(), Files.asCharSink(partitionListFile, Charsets.UTF_8)); log.info("Wrote partition list to {}", partitionListFile); return outputDocuments; } /** * Filters a map down to the specified keys such that the new map has the same iteration order as * the specified keys. */ private static ImmutableMap filterMapToKeysPreservingOrder( final ImmutableMap map, Iterable keys) { final ImmutableMap.Builder ret = ImmutableMap.builder(); for (final K key : keys) { final V value = map.get(key); checkArgument(value != null, "Key " + key + " not in map"); ret.put(key, value); } return ret.build(); } private enum SymbolToFileFunction implements Function { INSTANCE; @Override public File apply(final Symbol input) { return new File(input.asString()); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy