com.davidbracewell.apollo.ml.TrainTestSet Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of apollo Show documentation
Show all versions of apollo Show documentation
A machine learning library for Java.
The newest version!
package com.davidbracewell.apollo.ml;
import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.apollo.ml.preprocess.PreprocessorList;
import lombok.NonNull;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
/**
* Encapsulates a set of train/test splits.
*
* @param the example type parameter
* @author David B. Bracewell
*/
public class TrainTestSet implements Iterable>, Serializable {
private static final long serialVersionUID = 1L;
private final Set> splits = new HashSet<>();
/**
* Adds a split to the set.
*
* @param trainTestSplit the train test split
*/
public void add(@NonNull TrainTestSplit trainTestSplit) {
splits.add(trainTestSplit);
}
/**
* Processes each split as a train/test pair
*
* @param consumer the consumer to run over the train & test splits
*/
public void forEach(@NonNull BiConsumer, Dataset> consumer) {
forEach(tTrainTest -> consumer.accept(tTrainTest.getTrain(), tTrainTest.getTest()));
}
/**
* Evaluates the set using an Evaluation produced by the given supplier. The process begins by resetting the
* learner and the trains the learner using the training portion of the split. Once the model is built it is
* evaluated on the testing portion of the split.
*
* @param the model type parameter
* @param the evaluation type parameter
* @param learner the learner to use for training
* @param supplier supplies an evaluation metric.
* @return the result of evaluation
*/
public > R evaluate(@NonNull Learner learner, @NonNull Supplier supplier) {
R eval = supplier.get();
forEach(tt -> eval.merge(tt.evaluate(learner, supplier)));
return eval;
}
@Override
public Spliterator> spliterator() {
return splits.spliterator();
}
@Override
public Iterator> iterator() {
return splits.iterator();
}
/**
* Preprocess each of the training splits using the supplier of {@link PreprocessorList}
*
* @param supplier the supplier to produce a {@link PreprocessorList}
* @return this train test set
*/
public TrainTestSet preprocess(@NonNull Supplier> supplier) {
forEach((train, test) -> train.preprocess(supplier.get()));
return this;
}
}// END OF TrainTestSet
© 2015 - 2025 Weber Informatics LLC | Privacy Policy