com.arosbio.ml.sampling.FoldedSampling Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of confai Show documentation
Show all versions of confai Show documentation
Conformal AI package, including all data IO, transformations, machine learning models and predictor classes, but excluding chemistry-dependent code.
/*
* Copyright (C) Aros Bio AB.
*
* CPSign is an Open Source Software that is dual licensed to allow you to choose a license that best suits your requirements:
*
* 1) GPLv3 (GNU General Public License Version 3) with Additional Terms, including an attribution clause as well as a limitation to use the software for commercial purposes.
*
* 2) CPSign Proprietary License that allows you to use CPSign for commercial activities, such as in a revenue-generating operation or environment, or integrate CPSign in your proprietary software without worrying about disclosing the source code of your proprietary software, which is required if you choose to use the software under GPLv3 license. See arosbio.com/cpsign/commercial-license for details.
*/
package com.arosbio.ml.sampling;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.arosbio.commons.CollectionUtils;
import com.arosbio.commons.GlobalConfig;
import com.arosbio.commons.TypeUtils;
import com.arosbio.commons.config.IntegerConfig;
import com.arosbio.data.Dataset;
import com.arosbio.data.splitting.FoldedSplitter;
import com.arosbio.ml.io.impl.PropertyNameSettings;
import com.arosbio.ml.sampling.impl.TrainSplitWrapper;
import com.google.common.collect.Range;
/**
 * Folded (k-fold) sampling strategy.
 *
 * <p>Holds two settings: the number of folds (at least 2) and the number of
 * times the folded sampling is repeated (at least 1, the default). The actual
 * splitting is delegated to a {@link FoldedSplitter} configured with shuffling
 * enabled and stratification disabled, wrapped in a {@link TrainSplitWrapper}.
 *
 * <p>Instances are mutable: the {@code with...} methods and
 * {@link #setConfigParameters(Map)} change the state of this object.
 */
public class FoldedSampling implements MultiSampling {

    public static final int ID = 3;
    public static final String NAME = "Folded";
    public static final int DEFAULT_NUM_SAMPLES = 10;
    public static final String[] CONFIG_NUM_SAMPLES_PARAM_NAMES = new String[] {"folds", "numSamples"};
    private static final String[] CONFIG_NUM_REPEATS_PARAM_NAMES = new String[] {"numRepeats", "nRep"};

    private int numFolds;
    private int numRepeat = 1;

    /** Creates a sampling using {@link #DEFAULT_NUM_SAMPLES} folds and a single repeat. */
    public FoldedSampling() {
        this(DEFAULT_NUM_SAMPLES);
    }

    /**
     * Creates a sampling with the given number of folds and a single repeat.
     *
     * @param numFolds number of folds, must be larger than 1
     * @throws IllegalArgumentException if {@code numFolds <= 1}
     */
    public FoldedSampling(int numFolds) {
        // Validate through a private static helper rather than the overridable
        // withNumSamples(..), so a subclass override cannot run on a half-built object.
        this.numFolds = requireValidFolds(numFolds);
    }

    public int getID() {
        return ID;
    }

    public String getName() {
        return NAME;
    }

    /**
     * Creates an independent copy carrying over both the number of folds and the
     * number of repeats, so that {@code clone().equals(this)} holds.
     *
     * @return a new {@code FoldedSampling} with the same settings as this one
     */
    @Override
    public FoldedSampling clone() {
        return new FoldedSampling(numFolds).withNumRepeats(numRepeat);
    }

    /**
     * Sets the number of folds.
     *
     * @param folds number of folds, must be larger than 1
     * @return this instance, for call chaining
     * @throws IllegalArgumentException if {@code folds <= 1}
     */
    public FoldedSampling withNumSamples(int folds) {
        this.numFolds = requireValidFolds(folds);
        return this;
    }

    @Override
    public int getNumSamples() {
        return numFolds;
    }

    /**
     * Sets how many times the folded sampling should be performed.
     *
     * @param num number of repeats, must be at least 1
     * @return this instance, for call chaining
     * @throws IllegalArgumentException if {@code num < 1}
     */
    public FoldedSampling withNumRepeats(int num) {
        if (num < 1) {
            throw new IllegalArgumentException("number of repeats must be at least 1 (i.e. performed once)");
        }
        this.numRepeat = num;
        return this;
    }

    public int getNumRepeats() {
        return numRepeat;
    }

    /**
     * Creates a split generator for the dataset, seeded with the RNG seed held by
     * {@link GlobalConfig}.
     */
    @Override
    public TrainSplitGenerator getIterator(Dataset dataset)
            throws IllegalArgumentException {
        return getIterator(dataset, GlobalConfig.getInstance().getRNGSeed());
    }

    /**
     * Creates a split generator for the dataset using an explicit seed. The
     * underlying {@link FoldedSplitter} is built with the current number of folds
     * and repeats, shuffling enabled and stratification disabled.
     */
    @Override
    public TrainSplitGenerator getIterator(Dataset dataset, long seed)
            throws IllegalArgumentException {
        return new TrainSplitWrapper(new FoldedSplitter.Builder()
                .numFolds(numFolds)
                .numRepeat(numRepeat)
                .seed(seed)
                .shuffle(true)
                .stratify(false)
                .findLabelRange(true)
                .build(dataset));
    }

    /**
     * Returns the settings of this sampling as a new map: the strategy ID plus
     * the number of folds and repeats, keyed by the first (primary) name of the
     * respective config parameter.
     */
    @Override
    public Map<String, Object> getProperties() {
        Map<String, Object> props = new HashMap<>();
        props.put(PropertyNameSettings.SAMPLING_STRATEGY_KEY, ID);
        props.put(CONFIG_NUM_SAMPLES_PARAM_NAMES[0], numFolds);
        props.put(CONFIG_NUM_REPEATS_PARAM_NAMES[0], numRepeat);
        return props;
    }

    @Override
    public boolean isFolded() {
        return true;
    }

    @Override
    public boolean isStratified() {
        return false;
    }

    /** Two instances are equal when both the number of folds and the number of repeats match. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof FoldedSampling)) {
            return false;
        }
        FoldedSampling other = (FoldedSampling) obj;
        return this.numFolds == other.numFolds && this.numRepeat == other.numRepeat;
    }

    /** Hash code built from the same two fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return 31 * numFolds + numRepeat;
    }

    @Override
    public String toString() {
        return "Folded sampling with " + numFolds + " splits";
    }

    /**
     * Describes the two configurable parameters: number of folds (default
     * {@link #DEFAULT_NUM_SAMPLES}, at least 2) and number of repeats (default 1,
     * at least 1).
     */
    // NOTE(review): the element type of the returned list is not visible in this
    // file, so the return type is kept exactly as declared - confirm against the
    // interface that declares this method.
    @Override
    public List getConfigParameters() {
        return Arrays.asList(
                new IntegerConfig.Builder(Arrays.asList(CONFIG_NUM_SAMPLES_PARAM_NAMES), DEFAULT_NUM_SAMPLES)
                        .range(Range.atLeast(2))
                        .description("Number of folds to split the dataset into").build(),
                new IntegerConfig.Builder(Arrays.asList(CONFIG_NUM_REPEATS_PARAM_NAMES), 1)
                        .range(Range.atLeast(1))
                        .description("Number of times the folded sampling should be performed, the default is doing a single k-fold sampling").build()
        );
    }

    /**
     * Updates the number of folds and/or repeats from the given parameters.
     * Parameter names are matched ignoring case; entries with a {@code null}
     * value and entries with unrecognized names are ignored.
     *
     * <p>All recognized entries are validated before any field is written, so a
     * failing call leaves this instance unchanged.
     *
     * @param params parameter names mapped to their new values
     * @throws IllegalArgumentException if a recognized parameter is not an
     *         integer, or is below its minimum (2 for folds, 1 for repeats)
     */
    @Override
    public void setConfigParameters(Map<String, Object> params)
            throws IllegalStateException, IllegalArgumentException {
        Map<String, Object> given = CollectionUtils.dropNullValues(params);

        // Built once, outside the loop; both lookups use the same list form.
        List<String> foldNames = Arrays.asList(CONFIG_NUM_SAMPLES_PARAM_NAMES);
        List<String> repeatNames = Arrays.asList(CONFIG_NUM_REPEATS_PARAM_NAMES);

        int newFolds = numFolds;
        int newRepeats = numRepeat;
        for (Map.Entry<String, Object> kv : given.entrySet()) {
            String key = kv.getKey();
            if (CollectionUtils.containsIgnoreCase(foldNames, key)) {
                newFolds = parseBoundedInt(key, kv.getValue(), 2);
            } else if (CollectionUtils.containsIgnoreCase(repeatNames, key)) {
                newRepeats = parseBoundedInt(key, kv.getValue(), 1);
            }
        }

        // Commit only after every recognized entry has passed validation.
        numFolds = newFolds;
        numRepeat = newRepeats;
    }

    /** Returns {@code folds} if it is a valid number of folds (larger than 1), otherwise throws. */
    private static int requireValidFolds(int folds) {
        if (folds <= 1) {
            throw new IllegalArgumentException("Number of samplings must be over 1 when using folded sampling");
        }
        return folds;
    }

    /** Converts a config value to an int and checks that it is at least {@code min}. */
    private static int parseBoundedInt(String name, Object value, int min) {
        if (!TypeUtils.isInt(value)) {
            throw new IllegalArgumentException("Parameter " + name + " must be an integer number, got: " + value);
        }
        int parsed = TypeUtils.asInt(value);
        if (parsed < min) {
            throw new IllegalArgumentException("Parameter " + name + " must be >=" + min);
        }
        return parsed;
    }
}