All downloads are free. The search and download functionality uses the official Maven repository.

com.davidbracewell.apollo.ml.TrainTestSet Maven / Gradle / Ivy

The newest version!
package com.davidbracewell.apollo.ml;

import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.apollo.ml.preprocess.PreprocessorList;
import lombok.NonNull;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

/**
 * 

Encapsulates a set of train/test splits.

* * @param the example type parameter * @author David B. Bracewell */ public class TrainTestSet implements Iterable>, Serializable { private static final long serialVersionUID = 1L; private final Set> splits = new HashSet<>(); /** * Adds a split to the set. * * @param trainTestSplit the train test split */ public void add(@NonNull TrainTestSplit trainTestSplit) { splits.add(trainTestSplit); } /** * Processes each split as a train/test pair * * @param consumer the consumer to run over the train & test splits */ public void forEach(@NonNull BiConsumer, Dataset> consumer) { forEach(tTrainTest -> consumer.accept(tTrainTest.getTrain(), tTrainTest.getTest())); } /** *

Evaluates the set using an Evaluation produced by the given supplier. The process begins by resetting the * learner and the trains the learner using the training portion of the split. Once the model is built it is * evaluated on the testing portion of the split.

* * @param the model type parameter * @param the evaluation type parameter * @param learner the learner to use for training * @param supplier supplies an evaluation metric. * @return the result of evaluation */ public > R evaluate(@NonNull Learner learner, @NonNull Supplier supplier) { R eval = supplier.get(); forEach(tt -> eval.merge(tt.evaluate(learner, supplier))); return eval; } @Override public Spliterator> spliterator() { return splits.spliterator(); } @Override public Iterator> iterator() { return splits.iterator(); } /** * Preprocess each of the training splits using the supplier of {@link PreprocessorList} * * @param supplier the supplier to produce a {@link PreprocessorList} * @return this train test set */ public TrainTestSet preprocess(@NonNull Supplier> supplier) { forEach((train, test) -> train.preprocess(supplier.get())); return this; } }// END OF TrainTestSet




© 2015 - 2025 Weber Informatics LLC | Privacy Policy