org.apache.ctakes.ytex.kernel.FoldGeneratorImpl Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.ytex.kernel;
import org.apache.commons.cli.*;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFoldInstance;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import java.io.IOException;
import java.util.*;
/**
* utility generates cv fold splits, stores in db. Takes as a command line
* parameter -prop [property file]. Also reads properties from org.apache.ctakes.ytex.properties.
* Required properties:
*
* - org.apache.ctakes.ytex.corpusName
*
- instanceClassQuery
*
*
* Optional properties:
*
* - minPerClass default 1 minimum number of instances per class/fold. if not
* enough instances of a specific class, the instance will be repeated across
* folds. E.g. if you have only one example of one class, and 2 folds, that one
* example will be duplicated in both folds.
*
- rand random number seed, defaults to current time millis
*
- org.apache.ctakes.ytex.splitName default null cv_fold.split_name
*
- folds default 2
*
- runs default 5
*
*
* @author vijay
*/
public class FoldGeneratorImpl implements FoldGenerator {
private static final Logger LOGGER = LoggerFactory.getLogger( "FoldGeneratorImpl" );
/**
* iterate through the labels, split instances into folds
*
* @param mapClassToInstanceId
* @param nFolds
* @param nMinPerClass
* @param r
* @return list with nFolds sets of instance ids corresponding to the folds
*/
private static List> createFolds(
Map> mapClassToInstanceId, int nFolds,
int nMinPerClass, Random r) {
List> folds = new ArrayList>(nFolds);
Map>> mapLabelFolds = new HashMap>>();
for (Map.Entry> classToInstanceId : mapClassToInstanceId
.entrySet()) {
List instanceIds = classToInstanceId.getValue();
Collections.shuffle(instanceIds, r);
List> classFolds = new ArrayList>(nFolds);
int blockSize = instanceIds.size() / nFolds;
for (int i = 0; i < nFolds; i++) {
Set foldInstanceIds = new HashSet(blockSize);
if (instanceIds.size() <= nMinPerClass) {
// we don't have minPerClass for the given class
// just add all of them to each fold
foldInstanceIds.addAll(instanceIds);
} else if (blockSize < nMinPerClass) {
// too few of the given class - just randomly select
// nMinPerClass
double fraction = (double) nMinPerClass
/ (double) instanceIds.size();
// iterate through the list, start somewhere in the middle
int instanceIdIndex = (int) (r.nextDouble() * instanceIds
.size());
while (foldInstanceIds.size() < nMinPerClass) {
// go back to beginning of list if we hit the end
if (instanceIdIndex >= instanceIds.size()) {
instanceIdIndex = 0;
}
// randomly select this line
if (r.nextDouble() <= fraction) {
long instanceId = instanceIds.get(instanceIdIndex);
foldInstanceIds.add(instanceId);
}
// go to next line
instanceIdIndex++;
}
} else {
int nStart = i * blockSize;
int nEnd = (i == nFolds - 1) ? instanceIds.size() : nStart
+ blockSize;
for (int instanceIdIndex = nStart; instanceIdIndex < nEnd; instanceIdIndex++) {
foldInstanceIds.add(instanceIds.get(instanceIdIndex));
}
}
classFolds.add(foldInstanceIds);
}
mapLabelFolds.put(classToInstanceId.getKey(), classFolds);
}
for (int i = 0; i < nFolds; i++) {
Set foldInstanceIds = new HashSet();
for (List> labelFold : mapLabelFolds.values()) {
foldInstanceIds.addAll(labelFold.get(i));
}
folds.add(foldInstanceIds);
}
return folds;
}
@SuppressWarnings("static-access")
public static void main(String args[]) throws ParseException, IOException {
Options options = new Options();
options.addOption(OptionBuilder
.withArgName("prop")
.hasArg()
.withDescription(
"property file with query to retrieve instance id - label - class triples")
.create("prop"));
// OptionGroup group = new OptionGroup();
// group
// .addOption(OptionBuilder
// .withArgName("query")
// .hasArg()
// .withDescription(
// "query to retrieve instance id - label - class triples")
// .create("query"));
// group
// .addOption(OptionBuilder
// .withArgName("prop")
// .hasArg()
// .withDescription(
// "property file with query to retrieve instance id - label - class triples")
// .create("prop"));
// group.isRequired();
// options.addOptionGroup(group);
// options.addOption(OptionBuilder.withArgName("name").hasArg()
// .isRequired().withDescription("name. required").create("name"));
// options.addOption(OptionBuilder.withArgName("runs").hasArg()
// .withDescription("number of runs, default 1").create("runs"));
// options.addOption(OptionBuilder.withArgName("folds").hasArg()
// .withDescription("number of folds, default 4").create("folds"));
// options.addOption(OptionBuilder.withArgName("minPerClass").hasArg()
// .withDescription("minimum instances per class, default 1")
// .create("minPerClass"));
// options.addOption(OptionBuilder.withArgName("rand").hasArg()
// .withDescription(
// "random number seed; default current time in millis")
// .create("rand"));
try {
if (args.length == 0)
printHelp(options);
else {
CommandLineParser parser = new GnuParser();
CommandLine line = parser.parse(options, args);
String propFile = line.getOptionValue("prop");
Properties props = FileUtil.loadProperties(propFile, true);
// Integer rand = line.hasOption("rand") ? Integer.parseInt(line
// .getOptionValue("rand")) : null;
// int runs = Integer.parseInt(line.getOptionValue("runs",
// "1"));
// int minPerClass = Integer.parseInt(line.getOptionValue(
// "minPerClass", "1"));
// int folds = Integer.parseInt(line.getOptionValue("folds",
// "4"));
String corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
String splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
String query = props.getProperty("instanceClassQuery");
int folds = Integer.parseInt(props.getProperty("folds", "2"));
int runs = Integer.parseInt(props.getProperty("runs", "5"));
int minPerClass = Integer.parseInt(props.getProperty(
"minPerClass", "1"));
Integer rand = props.containsKey("rand") ? Integer
.parseInt(props.getProperty("rand")) : null;
boolean argsOk = true;
if (corpusName == null) {
LOGGER.error("missing parameter: org.apache.ctakes.ytex.corpusName");
argsOk = false;
}
if (query == null) {
LOGGER.error("missing parameter: instanceClassQuery");
argsOk = false;
}
if (!argsOk) {
printHelp(options);
System.exit(1);
} else {
KernelContextHolder
.getApplicationContext()
.getBean(FoldGenerator.class)
.generateRuns(corpusName, splitName, query, folds,
minPerClass, rand, runs);
}
}
} catch (ParseException pe) {
printHelp(options);
}
}
private static void printHelp(Options options) {
HelpFormatter formatter = new HelpFormatter();
formatter
.printHelp(
"java org.apache.ctakes.ytex.kernel.FoldGeneratorImpl splits training data into mxn training/test sets for mxn-fold cross validation",
options);
}
ClassifierEvaluationDao classifierEvaluationDao;
KernelUtil kernelUtil;
/**
* generate folds for a run
*
* @param labels
* @param instances
* @param corpusName
* @param splitName
* @param run
* @param query
* @param nFolds
* @param nMinPerClass
* @param r
*/
public void generateFolds(Set labels, InstanceData instances,
String corpusName, String splitName, int run, String query,
int nFolds, int nMinPerClass, Random r) {
for (String label : instances.getLabelToInstanceMap().keySet()) {
// there should not be any runs/folds/train test split - just unpeel
// until we get to the instance - class map
SortedMap>>> runMap = instances
.getLabelToInstanceMap().get(label);
SortedMap>> foldMap = runMap
.values().iterator().next();
SortedMap> trainMap = foldMap
.values().iterator().next();
SortedMap mapInstanceIdToClass = trainMap.values()
.iterator().next();
List> folds = createFolds(nFolds, nMinPerClass, r,
mapInstanceIdToClass);
// insert the folds
insertFolds(folds, corpusName, splitName, label, run);
}
}
/**
* inver the map of instance id to class, call createFolds
*
* @param nFolds
* @param nMinPerClass
* @param r
* @param mapInstanceIdToClass
* @return
*/
private List> createFolds(int nFolds, int nMinPerClass, Random r,
SortedMap mapInstanceIdToClass) {
// invert the mapInstanceIdToClass
Map> mapClassToInstanceId = new TreeMap>();
for (Map.Entry instance : mapInstanceIdToClass.entrySet()) {
String className = instance.getValue();
long instanceId = instance.getKey();
List classInstanceIds = mapClassToInstanceId.get(className);
if (classInstanceIds == null) {
classInstanceIds = new ArrayList();
mapClassToInstanceId.put(className, classInstanceIds);
}
classInstanceIds.add(instanceId);
}
// stratified split into folds
List> folds = createFolds(mapClassToInstanceId, nFolds,
nMinPerClass, r);
return folds;
}
/*
* (non-Javadoc)
*
* @see org.apache.ctakes.ytex.kernel.FoldGenerator#generateRuns(java.lang.String,
* java.lang.String, int, int, java.lang.Integer, int)
*/
@Override
public void generateRuns(String corpusName, String splitName, String query,
int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
Random r = new Random(nSeed != null ? nSeed
: System.currentTimeMillis());
SortedSet labels = new TreeSet();
InstanceData instances = kernelUtil.loadInstances(query);
this.getClassifierEvaluationDao().deleteCrossValidationFoldByName(
corpusName, splitName);
for (int run = 1; run <= nRuns; run++) {
generateFolds(labels, instances, corpusName, splitName, run, query,
nFolds, nMinPerClass, r);
}
}
public ClassifierEvaluationDao getClassifierEvaluationDao() {
return classifierEvaluationDao;
}
public KernelUtil getKernelUtil() {
return kernelUtil;
}
/**
* insert the folds into the database
*
* @param folds
* @param corpusName
* @param run
*/
private void insertFolds(List> folds, String corpusName,
String splitName, String label, int run) {
// iterate over fold numbers
for (int foldNum = 1; foldNum <= folds.size(); foldNum++) {
Set instanceIds = new HashSet();
// iterate over instances in each fold
for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
// add the instance, set the train flag
for (long instanceId : folds.get(trainFoldNum - 1))
instanceIds.add(new CrossValidationFoldInstance(instanceId,
trainFoldNum != foldNum));
}
classifierEvaluationDao.saveFold(new CrossValidationFold(
corpusName, splitName, label, run, foldNum, instanceIds));
// insert test set
// classifierEvaluationDao.saveFold(new CrossValidationFold(name,
// label, run, foldNum, false, folds.get(foldNum - 1)));
// insert training set
// Set trainInstances = new TreeSet();
// for (int trainFoldNum = 1; trainFoldNum <= folds.size();
// trainFoldNum++) {
// if (trainFoldNum != foldNum)
// trainInstances.addAll(folds.get(trainFoldNum - 1));
// }
// classifierEvaluationDao.saveFold(new CrossValidationFold(name,
// label, run, foldNum, true, trainInstances));
}
}
public void setClassifierEvaluationDao(
ClassifierEvaluationDao classifierEvaluationDao) {
this.classifierEvaluationDao = classifierEvaluationDao;
}
public void setKernelUtil(KernelUtil kernelUtil) {
this.kernelUtil = kernelUtil;
}
@Override
public SortedMap>>>> generateRuns(
SortedMap>>>> labelToInstanceMap,
int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
// allocate map to return
SortedMap>>>> labelToInstanceFoldMap = new TreeMap>>>>();
// initialize random seed
Random r = new Random(nSeed != null ? nSeed
: System.currentTimeMillis());
// iterate over labels
for (Map.Entry>>>> labelRun : labelToInstanceMap
.entrySet()) {
String label = labelRun.getKey();
// extract the instance id - class map
SortedMap instanceClassMap = labelRun.getValue()
.get(0).get(0).get(true);
// allocate the run to fold map
SortedMap>>> runMap = new TreeMap>>>();
labelToInstanceFoldMap.put(label, runMap);
// iterate over runs
for (int run = 1; run <= nRuns; run++) {
// generate folds for run
List> folds = createFolds(nFolds, nMinPerClass, r,
instanceClassMap);
SortedMap>> foldMap = new TreeMap>>();
// add the fold map to the run map
runMap.put(run, foldMap);
// iterate over folds
for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
// add train/test sets for the fold
SortedMap> trainTestMap = new TreeMap>();
foldMap.put(trainFoldNum, trainTestMap);
trainTestMap.put(true, new TreeMap());
trainTestMap.put(false, new TreeMap());
// populate the train/test sets
Set testIds = folds.get(trainFoldNum - 1);
// iterate over all instances
for (Map.Entry instanceClass : instanceClassMap
.entrySet()) {
long instanceId = instanceClass.getKey();
String clazz = instanceClass.getValue();
// add the instance to the test set if it is in testIds,
// else to the train set
trainTestMap.get(!testIds.contains(instanceId)).put(
instanceId, clazz);
}
}
}
}
return labelToInstanceFoldMap;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy