org.apache.ctakes.ytex.kernel.FoldGeneratorImpl Maven / Gradle / Ivy
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.ytex.kernel;
import org.apache.commons.cli.*;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFoldInstance;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import java.util.*;
/**
 * Utility that generates cross-validation fold splits and stores them in the
 * database. Takes as a command line parameter -prop [property file]. Also reads
 * properties from an additional source (its name is missing from this copy of
 * the file).
 *
 * Required properties:
 * - org.apache.ctakes.ytex.corpusName
 * - instanceClassQuery
 *
 * Optional properties:
 * - minPerClass default 1 minimum number of instances per class/fold. If there
 * are not enough instances of a specific class, the instances will be repeated
 * across folds. E.g. if you have only one example of one class, and 2 folds,
 * that one example will be duplicated in both folds.
 * - rand random number seed, defaults to current time millis
 * - org.apache.ctakes.ytex.splitName default null cv_fold.split_name
 * - folds default 2
 * - runs default 5
 *
 * @author vijay
 */
public class FoldGeneratorImpl implements FoldGenerator {
// Class-wide SLF4J logger; it is registered under the literal name "FoldGeneratorImpl"
// (a plain string, not the class object), so that is the name logging configuration must use.
private static final Logger LOGGER = LoggerFactory.getLogger( "FoldGeneratorImpl" );
/**
 * iterate through the labels, split instances into folds
 *
 * @param mapClassToInstanceId
 * @param nFolds
 * @param nMinPerClass
 * @param r
 * @return list with nFolds sets of instance ids corresponding to the folds
 */
// Stratified fold construction: for every class, its instance ids are shuffled and spread over
// nFolds per-class sets; the per-class sets are then combined into the nFolds sets returned.
// NOTE(review): generic type arguments, closing braces and several statements appear to have
// been lost from this copy of the file; the gaps that are visible are flagged inline below.
private static List> createFolds(
Map> mapClassToInstanceId, int nFolds,
int nMinPerClass, Random r) {
// result list, presized to the number of folds
List> folds = new ArrayList>(nFolds);
// per-class fold lists, keyed by the key of each mapClassToInstanceId entry
Map>> mapLabelFolds = new HashMap>>();
for (Map.Entry> classToInstanceId : mapClassToInstanceId
.entrySet()) {
List instanceIds = classToInstanceId.getValue();
// shuffles the caller's list in place, using the supplied Random so a seeded run is repeatable
Collections.shuffle(instanceIds, r);
List> classFolds = new ArrayList>(nFolds);
// integer division: any remainder is given to the last fold (see nEnd below)
int blockSize = instanceIds.size() / nFolds;
for (int i = 0; i < nFolds; i++) {
Set foldInstanceIds = new HashSet(blockSize);
if (instanceIds.size() <= nMinPerClass) {
// we don't have minPerClass for the given class
// just add all of them to each fold
// NOTE(review): no statement adding the ids to foldInstanceIds is visible in this
// branch - it appears to be missing from this copy; confirm against upstream.
} else if (blockSize < nMinPerClass) {
// too few of the given class - just randomly select
// nMinPerClass
// fraction = chance of accepting each visited id (nMinPerClass / class size)
double fraction = (double) nMinPerClass
/ (double) instanceIds.size();
// iterate through the list, start somewhere in the middle
// NOTE(review): the expression below is cut off before its closing part.
int instanceIdIndex = (int) (r.nextDouble() * instanceIds
// loop until this fold holds nMinPerClass distinct ids (foldInstanceIds is a Set)
while (foldInstanceIds.size() < nMinPerClass) {
// go back to beginning of list if we hit the end
if (instanceIdIndex >= instanceIds.size()) {
instanceIdIndex = 0;
// randomly select this line
if (r.nextDouble() <= fraction) {
long instanceId = instanceIds.get(instanceIdIndex);
// NOTE(review): the statements that add instanceId to the fold and advance
// instanceIdIndex are not visible here - presumably dropped from this copy.
// go to next line
} else {
// enough instances: fold i takes the contiguous slice [nStart, nEnd) of the shuffled list
int nStart = i * blockSize;
// the last fold extends to the end of the list so leftover ids are not dropped
int nEnd = (i == nFolds - 1) ? instanceIds.size() : nStart
+ blockSize;
for (int instanceIdIndex = nStart; instanceIdIndex < nEnd; instanceIdIndex++) {
// NOTE(review): loop body (adding the id to foldInstanceIds) and the statement that
// appends foldInstanceIds to classFolds are not visible in this copy.
mapLabelFolds.put(classToInstanceId.getKey(), classFolds);
// second pass: one combined set per fold index
for (int i = 0; i < nFolds; i++) {
Set foldInstanceIds = new HashSet();
for (List> labelFold : mapLabelFolds.values()) {
// NOTE(review): presumably the i-th per-class set is merged into foldInstanceIds and the
// result appended to folds; those statements are missing here - confirm against upstream.
return folds;
// Command-line entry point: loads fold-generation settings from the property file named by
// the "prop" option, validates the required ones, and passes them to generateRuns.
// NOTE(review): several lines are missing from this copy - the option definition surrounding
// the "property file ..." description, the statement for the no-argument case, the receiver
// of .generateRuns(...), the bodies after the validation failure and the catch, and the
// closing braces. Restore them from the upstream source before compiling.
public static void main(String args[]) throws ParseException, IOException {
Options options = new Options();
"property file with query to retrieve instance id - label - class triples")
// Earlier per-flag command-line options, left commented out; the same settings are now
// read from the property file further down.
// OptionGroup group = new OptionGroup();
// group
// .addOption(OptionBuilder
// .withArgName("query")
// .hasArg()
// .withDescription(
// "query to retrieve instance id - label - class triples")
// .create("query"));
// group
// .addOption(OptionBuilder
// .withArgName("prop")
// .hasArg()
// .withDescription(
// "property file with query to retrieve instance id - label - class triples")
// .create("prop"));
// group.isRequired();
// options.addOptionGroup(group);
// options.addOption(OptionBuilder.withArgName("name").hasArg()
// .isRequired().withDescription("name. required").create("name"));
// options.addOption(OptionBuilder.withArgName("runs").hasArg()
// .withDescription("number of runs, default 1").create("runs"));
// options.addOption(OptionBuilder.withArgName("folds").hasArg()
// .withDescription("number of folds, default 4").create("folds"));
// options.addOption(OptionBuilder.withArgName("minPerClass").hasArg()
// .withDescription("minimum instances per class, default 1")
// .create("minPerClass"));
// options.addOption(OptionBuilder.withArgName("rand").hasArg()
// .withDescription(
// "random number seed; default current time in millis")
// .create("rand"));
try {
// NOTE(review): the statement executed when no arguments are given is not visible here
// (presumably a call to printHelp - confirm).
if (args.length == 0)
else {
CommandLineParser parser = new GnuParser();
CommandLine line = parser.parse(options, args);
String propFile = line.getOptionValue("prop");
// NOTE(review): meaning of the boolean flag is defined by FileUtil.loadProperties, which is
// not visible in this file.
Properties props = FileUtil.loadProperties(propFile, true);
// Integer rand = line.hasOption("rand") ? Integer.parseInt(line
// .getOptionValue("rand")) : null;
// int runs = Integer.parseInt(line.getOptionValue("runs",
// "1"));
// int minPerClass = Integer.parseInt(line.getOptionValue(
// "minPerClass", "1"));
// int folds = Integer.parseInt(line.getOptionValue("folds",
// "4"));
String corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
String splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
String query = props.getProperty("instanceClassQuery");
// optional numeric settings; defaults are 2 folds, 5 runs, 1 instance per class
int folds = Integer.parseInt(props.getProperty("folds", "2"));
int runs = Integer.parseInt(props.getProperty("runs", "5"));
int minPerClass = Integer.parseInt(props.getProperty(
"minPerClass", "1"));
// seed is parsed as an int, so it must fit the int range; null means "no seed supplied"
Integer rand = props.containsKey("rand") ? Integer
.parseInt(props.getProperty("rand")) : null;
// collect all missing required parameters before deciding, so each one gets logged
boolean argsOk = true;
if (corpusName == null) {
LOGGER.error("missing parameter: org.apache.ctakes.ytex.corpusName");
argsOk = false;
if (query == null) {
LOGGER.error("missing parameter: instanceClassQuery");
argsOk = false;
if (!argsOk) {
} else {
// NOTE(review): the object on which generateRuns is invoked is cut off in this copy.
.generateRuns(corpusName, splitName, query, folds,
minPerClass, rand, runs);
} catch (ParseException pe) {
// Prints command-line usage for the given options via a commons-cli HelpFormatter.
// NOTE(review): the call on the formatter that takes the usage string below is truncated in
// this copy; only its first argument is visible.
private static void printHelp(Options options) {
HelpFormatter formatter = new HelpFormatter();
"java org.apache.ctakes.ytex.kernel.FoldGeneratorImpl splits training data into mxn training/test sets for mxn-fold cross validation",
// Collaborators supplied through the setters of this class (package-private fields).
// DAO used by insertFolds to save each CrossValidationFold.
ClassifierEvaluationDao classifierEvaluationDao;
// Helper used by generateRuns(String, ...) to load the instances for a query.
KernelUtil kernelUtil;
/**
 * generate folds for a run
 *
 * @param labels
 * @param instances
 * @param corpusName
 * @param splitName
 * @param run
 * @param query
 * @param nFolds
 * @param nMinPerClass
 * @param r
 */
// For each label in the instance data: extract that label's instance-id-to-class map, split it
// into folds, and store the folds for the given run.
// NOTE(review): the parameters labels and query are not referenced in the lines visible here.
public void generateFolds(Set labels, InstanceData instances,
String corpusName, String splitName, int run, String query,
int nFolds, int nMinPerClass, Random r) {
for (String label : instances.getLabelToInstanceMap().keySet()) {
// there should not be any runs/folds/train test split - just unpeel
// until we get to the instance - class map
// NOTE(review): the right-hand sides of the four assignments below and the last argument of
// createFolds(...) are cut off in this copy; the exact accessor calls cannot be confirmed here.
SortedMap>>> runMap = instances
SortedMap>> foldMap = runMap
SortedMap> trainMap = foldMap
SortedMap mapInstanceIdToClass = trainMap.values()
List> folds = createFolds(nFolds, nMinPerClass, r,
// insert the folds
insertFolds(folds, corpusName, splitName, label, run);
/**
 * invert the map of instance id to class, call createFolds
 *
 * @param nFolds
 * @param nMinPerClass
 * @param r
 * @param mapInstanceIdToClass
 * @return the folds produced by the static createFolds for the inverted map
 */
// Adapter around the static createFolds: turns the instance-id-to-class map into a
// class-to-instance-ids map and delegates the stratified split to the static overload.
private List> createFolds(int nFolds, int nMinPerClass, Random r,
SortedMap mapInstanceIdToClass) {
// invert the mapInstanceIdToClass
// TreeMap keeps the classes in sorted key order when the static createFolds iterates them
Map> mapClassToInstanceId = new TreeMap>();
for (Map.Entry instance : mapInstanceIdToClass.entrySet()) {
String className = instance.getValue();
long instanceId = instance.getKey();
// create the per-class list lazily on first sight of the class
List classInstanceIds = mapClassToInstanceId.get(className);
if (classInstanceIds == null) {
classInstanceIds = new ArrayList();
mapClassToInstanceId.put(className, classInstanceIds);
// NOTE(review): the statement adding instanceId to classInstanceIds is not visible in this
// copy - presumably dropped; confirm against upstream.
// stratified split into folds
List> folds = createFolds(mapClassToInstanceId, nFolds,
nMinPerClass, r);
return folds;
/*
 * (non-Javadoc)
 *
 * @see org.apache.ctakes.ytex.kernel.FoldGenerator#generateRuns(java.lang.String,
 * java.lang.String, java.lang.String, int, int, java.lang.Integer, int)
 */
// Loads the instances returned by the query and generates folds for runs 1..nRuns,
// storing them under the given corpus and split name.
public void generateRuns(String corpusName, String splitName, String query,
int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
// a null seed falls back to the current time, so unseeded invocations differ between calls
Random r = new Random(nSeed != null ? nSeed
: System.currentTimeMillis());
// NOTE(review): labels is created empty and is not populated in the lines visible here.
SortedSet labels = new TreeSet();
InstanceData instances = kernelUtil.loadInstances(query);
// NOTE(review): the call that owns the argument line below is missing from this copy.
corpusName, splitName);
// the same Random instance is reused for every run, so later runs continue its sequence
for (int run = 1; run <= nRuns; run++) {
generateFolds(labels, instances, corpusName, splitName, run, query,
nFolds, nMinPerClass, r);
// Returns the DAO this generator uses to save folds.
public ClassifierEvaluationDao getClassifierEvaluationDao() {
return classifierEvaluationDao;
// Returns the KernelUtil this generator uses to load instances.
public KernelUtil getKernelUtil() {
return kernelUtil;
/**
 * insert the folds into the database
 *
 * @param folds
 * @param corpusName
 * @param splitName
 * @param label
 * @param run
 */
// Saves one CrossValidationFold per fold number; each saved fold lists the instances of all
// folds, flagged as train unless they belong to the fold currently being written.
// NOTE(review): closing braces are missing from this copy, so the exact nesting of the
// saveFold call is inferred (presumably once per foldNum) - confirm against upstream.
private void insertFolds(List> folds, String corpusName,
String splitName, String label, int run) {
// iterate over fold numbers
// fold numbers are 1-based; the list is indexed with foldNum - 1
for (int foldNum = 1; foldNum <= folds.size(); foldNum++) {
Set instanceIds = new HashSet();
// iterate over instances in each fold
for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
// add the instance, set the train flag
// flag is false only for the instances of fold foldNum, i.e. that fold is the test set
for (long instanceId : folds.get(trainFoldNum - 1))
instanceIds.add(new CrossValidationFoldInstance(instanceId,
trainFoldNum != foldNum));
classifierEvaluationDao.saveFold(new CrossValidationFold(
corpusName, splitName, label, run, foldNum, instanceIds));
// Older approach (separate test and training fold records), left commented out.
// insert test set
// classifierEvaluationDao.saveFold(new CrossValidationFold(name,
// label, run, foldNum, false, folds.get(foldNum - 1)));
// insert training set
// Set trainInstances = new TreeSet();
// for (int trainFoldNum = 1; trainFoldNum <= folds.size();
// trainFoldNum++) {
// if (trainFoldNum != foldNum)
// trainInstances.addAll(folds.get(trainFoldNum - 1));
// }
// classifierEvaluationDao.saveFold(new CrossValidationFold(name,
// label, run, foldNum, true, trainInstances));
// Sets the DAO that insertFolds uses to save folds.
public void setClassifierEvaluationDao(
ClassifierEvaluationDao classifierEvaluationDao) {
this.classifierEvaluationDao = classifierEvaluationDao;
// Sets the KernelUtil that generateRuns(String, ...) uses to load instances.
public void setKernelUtil(KernelUtil kernelUtil) {
this.kernelUtil = kernelUtil;
// In-memory variant of fold generation: for every label, builds a run -> fold -> train/test
// -> (instance id -> class) map from the supplied label-to-instance map and returns it.
// NOTE(review): generic type arguments, closing braces and parts of several statements are
// missing from this copy; the visible gaps are flagged inline.
public SortedMap>>>> generateRuns(
SortedMap>>>> labelToInstanceMap,
int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
// allocate map to return
SortedMap>>>> labelToInstanceFoldMap = new TreeMap>>>>();
// initialize random seed
// a null seed falls back to the current time; one Random is shared by all labels and runs
Random r = new Random(nSeed != null ? nSeed
: System.currentTimeMillis());
// iterate over labels
for (Map.Entry>>>> labelRun : labelToInstanceMap
.entrySet()) {
String label = labelRun.getKey();
// extract the instance id - class map
// NOTE(review): the accessor chain after getValue() is cut off in this copy.
SortedMap instanceClassMap = labelRun.getValue()
// allocate the run to fold map
SortedMap>>> runMap = new TreeMap>>>();
labelToInstanceFoldMap.put(label, runMap);
// iterate over runs
for (int run = 1; run <= nRuns; run++) {
// generate folds for run
// NOTE(review): the last argument of createFolds(...) is cut off in this copy.
List> folds = createFolds(nFolds, nMinPerClass, r,
SortedMap>> foldMap = new TreeMap>>();
// add the fold map to the run map
runMap.put(run, foldMap);
// iterate over folds
// despite its name, trainFoldNum selects the fold used as the TEST set (see testIds below)
for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
// add train/test sets for the fold
SortedMap> trainTestMap = new TreeMap>();
foldMap.put(trainFoldNum, trainTestMap);
// key true holds the train instances, key false the test instances
trainTestMap.put(true, new TreeMap());
trainTestMap.put(false, new TreeMap());
// populate the train/test sets
Set testIds = folds.get(trainFoldNum - 1);
// iterate over all instances
for (Map.Entry instanceClass : instanceClassMap
.entrySet()) {
long instanceId = instanceClass.getKey();
String clazz = instanceClass.getValue();
// add the instance to the test set if it is in testIds,
// else to the train set
// NOTE(review): the call that owns the argument line below is missing from this copy.
instanceId, clazz);
return labelToInstanceFoldMap;
© 2015 - 2025 Weber Informatics LLC | Privacy Policy