com.arosbio.data.splitting.FoldedSplitter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of confai Show documentation
Show all versions of confai Show documentation
Conformal AI package, including all data IO, transformations, machine learning models and predictor classes, but excluding chemistry-dependent code.
/*
* Copyright (C) Aros Bio AB.
*
* CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
*
* 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
*
* 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
*/
package com.arosbio.data.splitting;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.GlobalConfig;
import com.arosbio.data.DataRecord;
import com.arosbio.data.DataUtils;
import com.arosbio.data.Dataset;
import com.arosbio.data.Dataset.SubSet;
import com.google.common.collect.Range;
public class FoldedSplitter implements DataSplitter {
private static final Logger LOGGER = LoggerFactory.getLogger(FoldedSplitter.class);
// NOTE(review): the declarations written as "List>" below are not valid Java; the generic
// type arguments appear to have been stripped from this copy of the file (presumably
// nested lists of DataRecord - confirm against the upstream source).
// State
// Data-only clone of the dataset given to the constructor (created with Dataset.cloneDataOnly())
private final Dataset dataClone;
// Settings copied from the Builder when the splitter is constructed
private final boolean stratify, shuffle;
private final int numRepeats, numFolds;
private final long seed;
// Label range found by DataUtils.findLabelRange; stays null when the search was not requested or failed
private Range foundRange;
// Iteration state
/** The folds of the repetition currently being iterated, created lazily by {@code hasNext()} ({@code null} before that) */
private List> iteratorFolds;
// Index (within iteratorFolds) of the fold that the next call to next() returns
private int currentFold = 0;
// Index of the repetition that is currently being iterated
private int currentRepetition = 0;
/** Only populated in case stratify == true (as an unmodifiable list), otherwise {@code null} */
private final List> stratifiedRecs;
/**
 * Fluent builder for {@link FoldedSplitter} instances. Every setting has a default
 * value, and each setter returns the builder itself so that calls can be chained.
 * The no-argument methods are the matching getters.
 */
public static class Builder {

    /** Number of folds per repetition, default 10 */
    private int numFolds = 10;
    /** Number of times the k-fold split is repeated, default 1 */
    private int numRepeats = 1;
    /** Whether records are shuffled before being split, default {@code true} */
    private boolean shuffle = true;
    /** Whether stratified splitting should be used, default {@code false} */
    private boolean stratify = false;
    /** Whether the observed label range should be searched for, default {@code false} */
    private boolean findObservedLabelSpace = false;
    /** RNG seed, defaults to the seed of the {@link GlobalConfig} */
    private long seed = GlobalConfig.getInstance().getRNGSeed();
    /** Name used in the error message for too few records, default "k-fold CV" */
    private String name = "k-fold CV";

    /** @return the configured number of folds */
    public int numFolds() {
        return this.numFolds;
    }

    /**
     * Sets the number of folds.
     * @param folds the number of folds, at least 2
     * @return this builder
     * @throws IllegalArgumentException if {@code folds} is smaller than 2
     */
    public Builder numFolds(int folds) {
        if (folds >= 2) {
            this.numFolds = folds;
            return this;
        }
        throw new IllegalArgumentException("Invalid number of folds: " + folds + ", must be >=2");
    }

    /** @return {@code true} if shuffling is enabled */
    public boolean shuffle() {
        return this.shuffle;
    }

    /**
     * Enables or disables shuffling.
     * @param shuffle the new setting
     * @return this builder
     */
    public Builder shuffle(boolean shuffle) {
        this.shuffle = shuffle;
        return this;
    }

    /** @return the configured RNG seed */
    public long seed() {
        return this.seed;
    }

    /**
     * Sets the RNG seed.
     * @param seed the new seed
     * @return this builder
     */
    public Builder seed(long seed) {
        this.seed = seed;
        return this;
    }

    /** @return {@code true} if stratified splitting is enabled */
    public boolean stratify() {
        return this.stratify;
    }

    /**
     * Enables or disables stratified splitting.
     * @param stratify the new setting
     * @return this builder
     */
    public Builder stratify(boolean stratify) {
        this.stratify = stratify;
        return this;
    }

    /** @return the configured number of repetitions */
    public int numRepeat() {
        return this.numRepeats;
    }

    /**
     * Sets how many times the k-fold split is repeated.
     * @param num the number of repetitions, at least 1
     * @return this builder
     * @throws IllegalArgumentException if {@code num} is smaller than 1
     */
    public Builder numRepeat(int num) {
        if (num >= 1) {
            this.numRepeats = num;
            return this;
        }
        throw new IllegalArgumentException("Num repeats must be at least 1 (i.e. performed once)");
    }

    /** @return {@code true} if the observed label range should be searched for */
    public boolean findLabelRange() {
        return this.findObservedLabelSpace;
    }

    /**
     * Sets whether the splitter should try to find the observed label range.
     * @param findRange the new setting
     * @return this builder
     */
    public Builder findLabelRange(boolean findRange) {
        this.findObservedLabelSpace = findRange;
        return this;
    }

    /** @return the configured name */
    public String name() {
        return this.name;
    }

    /**
     * Sets the name of the splitter.
     * @param name the new name
     * @return this builder
     */
    public Builder name(String name) {
        this.name = name;
        return this;
    }

    /**
     * Creates a splitter for the given data using the current settings.
     * @param data the data to split
     * @return a new {@link FoldedSplitter}
     */
    public FoldedSplitter build(Dataset data) {
        return new FoldedSplitter(this, data);
    }
}
private FoldedSplitter(
final Builder b,
final Dataset data
) {
Objects.requireNonNull(data, "data must not be null");
if (data.getDataset().size() < b.numFolds) {
throw new IllegalArgumentException("Cannot run "+b.name + " with more folds than records!");
}
if (b.numFolds < 2)
throw new IllegalArgumentException("Number of folds must be >= 2");
if (!b.shuffle && b.numRepeats>1)
throw new IllegalArgumentException("Shuffling cannot be false if number of repeated k-fold splits is larger than 1");
this.numFolds = b.numFolds;
this.stratify = b.stratify;
this.shuffle = b.shuffle;
this.numRepeats = b.numRepeats;
this.seed = b.seed;
this.dataClone = data.cloneDataOnly();
if (stratify){
// do the stratification of data once, also verifies input is not of regression type
stratifiedRecs = Collections.unmodifiableList(DataUtils.stratify(dataClone.getDataset()));
} else {
stratifiedRecs = null;
}
if (b.findObservedLabelSpace){
try {
findObservedLabelSpace();
} catch (Exception e){
LOGGER.debug("attempted to find label-space but failed: {}",e.getMessage());
}
}
}
/**
 * Finds the label range of {@code dataClone} using {@code DataUtils.findLabelRange}
 * and stores it in {@code foundRange}. This is done once so that the range does not
 * have to be computed again later.
 *
 * @throws IllegalArgumentException if the range could not be determined; the
 *         original failure is attached as the cause
 */
private void findObservedLabelSpace(){
    try {
        foundRange = DataUtils.findLabelRange(dataClone);
        LOGGER.debug("found label-range: {}", foundRange);
    } catch (Exception e){
        LOGGER.debug("failed to find the observed label-range", e);
        // Chain the original exception so that its stack trace is not lost for callers
        throw new IllegalArgumentException("could not find the min and max observed values: " + e.getMessage(), e);
    }
}
/**
 * Returns the label range that was found when the splitter was created.
 *
 * @return the found range, or {@code null} if no range has been found
 */
public Range getObservedLabelSpace() {
    return this.foundRange;
}
/**
 * Generates the folds for a single repetition, using the seed returned by
 * {@code getSeedForRep(forRep)} for all shuffling in that repetition.
 * <p>
 * Without stratification the records are copied, shuffled (if {@code shuffle} is set)
 * and split into {@code numFolds} parts by {@code CollectionUtils.getDisjunctSets}.
 * With stratification each list in {@code stratifiedRecs} is shuffled (if set) and split
 * into {@code numFolds} parts, the parts are distributed over the folds, and finally the
 * records of every fold are shuffled.
 * <p>
 * NOTE(review): several lines of this method are damaged in this copy - text between
 * '&lt;' and '&gt;' seems to have been stripped (generic type arguments, loop conditions
 * and the loop bodies distributing the strata). Restore them from the upstream source.
 *
 * @param forRep the repetition to generate folds for
 * @return the folds of the repetition
 */
private List> getFoldsForRep(int forRep) {
long seedForRep = getSeedForRep(forRep);
LOGGER.debug("generating folds for repetition {} using stratify={}, shuffle={}, seed={}",
forRep, stratify, shuffle,seedForRep);
if (stratify) {
// Init the folds
// Records per fold, plus one (presumably the initial capacity of each fold list - confirm)
int initSize = dataClone.getDataset().size() / numFolds + 1;
List> folds = new ArrayList<>();
// NOTE(review): the next line is truncated - the loop condition and the start of its body are missing
for (int i=0; i(initSize));
}
// split the stratified datasets into the folds
for (List recs : stratifiedRecs) {
List tmp = recs;
// Shuffle if set to do so
if (shuffle){
tmp = new ArrayList<>(tmp); // Need copy before we shuffle
Collections.shuffle(tmp, new Random(seedForRep));
}
// Folds for each strata - note the first ones will be the largest ones (if not evenly divisible)
List> foldStrata = CollectionUtils.getDisjunctSets(tmp, numFolds, true);
if (tmp.size() % numFolds == 0){
// The folds will all have the same size, no need to worry about order
// NOTE(review): the next two lines are truncated - the loop bodies, the else-branch using sizeSort
// and the start of the final loop over the folds are missing
for (int i=0; i sizeSort = CollectionUtils.getSortedIndicesBySize(folds, true);
for (int i=0; i f : folds) {
// Shuffle the records of each fold, every fold with a new Random created from the same seed
Collections.shuffle(f, new Random(seedForRep));
}
return folds;
} else {
// Work on a copy so that the order of the records in dataClone is left untouched
List recs = new ArrayList<>(dataClone.getDataset());
if (shuffle)
Collections.shuffle(recs, new Random(seedForRep));
return CollectionUtils.getDisjunctSets(recs, numFolds, false);
}
}
/**
 * Checks whether more splits can be retrieved with {@link #next()}. The folds of the
 * first repetition are generated the first time this method is called, and once all
 * folds of a repetition have been used the folds of the following repetition are
 * generated, if there are more repetitions.
 *
 * @return {@code true} if there is at least one more split
 */
@Override
public boolean hasNext() {
    // Lazy init of the folds the very first time
    if (iteratorFolds == null) {
        iteratorFolds = getFoldsForRep(currentRepetition);
    }

    final boolean foldsLeftInRep = currentFold < iteratorFolds.size();
    if (foldsLeftInRep) {
        return true;
    }

    // The current repetition is finished - move on to the following one
    currentRepetition++;
    if (currentRepetition >= numRepeats) {
        // That was the last repetition
        return false;
    }

    iteratorFolds = getFoldsForRep(currentRepetition);
    currentFold = 0; // start from the first fold of the new repetition
    return true;
}
/**
 * Returns the split for the current fold of the current repetition and moves on to
 * the following fold.
 *
 * @return the next split
 * @throws NoSuchElementException if {@link #hasNext()} is {@code false}
 */
@Override
public DataSplit next() throws NoSuchElementException {
    if (hasNext()) {
        try {
            return get(currentFold, currentRepetition);
        } finally {
            // Runs both for a normal return and when get(..) throws
            currentFold += 1;
        }
    }
    throw new NoSuchElementException("No more folds!");
}
/**
 * Returns the split at a given index. The index counts folds across repetitions:
 * {@code index / numFolds} is the repetition and {@code index % numFolds} is the fold
 * within that repetition.
 *
 * @param index the index of the split, in the range {@code [0, getMaxSplitIndex()]}
 * @return the split at {@code index}
 * @throws NoSuchElementException if {@code index} is outside of the allowed range
 */
@Override
public DataSplit get(int index) throws NoSuchElementException {
    final boolean withinBounds = index >= 0 && index <= getMaxSplitIndex();
    if (!withinBounds) {
        throw new NoSuchElementException("Invalid index: " + index + ", it must be in the range [0,"+getMaxSplitIndex()+']');
    }
    final int repetition = index / numFolds;
    final int foldInRepetition = index % numFolds;
    return get(foldInRepetition, repetition);
}
private DataSplit get(int fold, int rep) throws NoSuchElementException {
LOGGER.debug("Generating fold {}/{} for repeat: {}",(fold+1),numFolds, rep);
List> theFolds = null;
if (rep == currentRepetition && iteratorFolds != null){
theFolds = iteratorFolds;
} else {
theFolds = getFoldsForRep(rep);
}
// Copy over the calibration and modeling exclusive data sets
Dataset first = new Dataset()
.withCalibrationExclusiveDataset(dataClone.getCalibrationExclusiveDataset().clone())
.withModelingExclusiveDataset(dataClone.getModelingExclusiveDataset().clone());
List firstDataOnly = new ArrayList<>(dataClone.getDataset().size());
List second = theFolds.get(fold);
for (int i=0; i