All downloads are FREE. Search and download functionality uses the official Maven repository.

com.davidbracewell.apollo.ml.classification.BaggingLearner Maven / Gradle / Ivy

package com.davidbracewell.apollo.ml.classification;

import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.function.SerializableSupplier;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;

import java.util.ArrayList;

/**
 * 

Learner which takes random samples (with replacement) of the data to build a number of weaker models that each * vote in one ensemble model.

* * @author David B. Bracewell */ public class BaggingLearner extends ClassifierLearner { private static final long serialVersionUID = 1L; @Getter private SerializableSupplier learnerSupplier; @Getter @Setter private int numberOfBags; @Getter @Setter private int bagSize; /** * Instantiates a new Bagging learner. */ public BaggingLearner() { this.learnerSupplier = LibLinearLearner::new; this.numberOfBags = 10; this.bagSize = -1; } /** * Instantiates a new Bagging learner. * * @param learnerSupplier Supplier for the weak learner * @param numberOfBags the number of bags, or weak learners, to generate * @param bagSize the size of each random sample */ public BaggingLearner(@NonNull SerializableSupplier learnerSupplier, int numberOfBags, int bagSize) { this.learnerSupplier = learnerSupplier; this.numberOfBags = numberOfBags; this.bagSize = bagSize; } @Override public void reset() { } /** * Sets the supplier to use to generate weak learners * * @param learnerSupplier the learner supplier */ public void setLearnerSupplier(@NonNull SerializableSupplier learnerSupplier) { this.learnerSupplier = learnerSupplier; } @Override protected Classifier trainImpl(Dataset dataset) { Ensemble model = new Ensemble(dataset.getEncoderPair(), dataset.getPreprocessors()); dataset = dataset.shuffle(); model.models = new ArrayList<>(numberOfBags); final int targetBagSize = (bagSize <= 0) ? dataset.size() : bagSize; for (int i = 0; i < numberOfBags; i++) { model.models.add(learnerSupplier.get().train(dataset.sample(true, targetBagSize))); } return model; } }// END OF BaggingLearner




© 2015 - 2025 Weber Informatics LLC | Privacy Policy