com.davidbracewell.apollo.ml.sequence.WindowedLearner Maven / Gradle / Ivy
package com.davidbracewell.apollo.ml.sequence;
import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.classification.ClassifierLearner;
import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.io.QuietIO;
import lombok.NonNull;
import java.util.Map;
/**
* Greedy learner that wraps a {@link ClassifierLearner}.
*
* @author David B. Bracewell
*/
public class WindowedLearner extends SequenceLabelerLearner {
  private static final long serialVersionUID = 3783930856969307606L;

  /**
   * The wrapped classifier learner. Training, reset and all parameter accessors on this class are
   * delegated to it.
   */
  private final ClassifierLearner learner;

  /**
   * Instantiates a new Windowed learner.
   *
   * @param learner the classifier learner that trains the underlying classifier; must not be null
   * @throws NullPointerException if {@code learner} is null (checked eagerly so misuse fails here
   *                              rather than later during training or parameter access)
   */
  public WindowedLearner(@NonNull ClassifierLearner learner) {
    this.learner = learner;
  }

  /**
   * Trains a {@link WindowedLabeler} on the given sequence dataset.
   *
   * <p>The labeler is created with the dataset's label encoder, feature encoder and
   * preprocessors, plus this learner's transition features and validator. Every sequence is then
   * expanded into classification instances via the transition features. The wrapped
   * {@link ClassifierLearner} is trained on those instances. The trained classifier, and its
   * encoder pair, are installed on the labeler.
   *
   * @param dataset the sequence dataset to train on; it is closed by this method
   * @return the trained windowed labeler
   */
  @Override
  protected SequenceLabeler trainImpl(Dataset dataset) {
    WindowedLabeler wl = new WindowedLabeler(dataset.getLabelEncoder(),
                                             dataset.getFeatureEncoder(),
                                             dataset.getPreprocessors(),
                                             getTransitionFeatures(),
                                             getValidator());
    // Flatten each sequence into one classification instance per element as produced by the
    // transition features.
    Dataset nd = Dataset.classification()
                        .source(dataset.stream()
                                       .flatMap(sequence -> getTransitionFeatures().toInstances(sequence)
                                                                                   .stream()));
    // The source dataset is released before the classifier is trained.
    // NOTE(review): confirm that `nd` does not lazily read from `dataset` after it is closed.
    QuietIO.closeQuietly(dataset);
    wl.classifier = learner.train(nd);
    // Use the classifier's encoders so labeling decodes with the same encoders used in training.
    wl.encoderPair = wl.classifier.getEncoderPair();
    return wl;
  }

  /**
   * Resets the wrapped classifier learner.
   */
  @Override
  public void reset() {
    learner.reset();
  }

  /**
   * Returns the parameters of the wrapped classifier learner.
   *
   * @return the wrapped learner's parameters
   */
  @Override
  public Map getParameters() {
    return learner.getParameters();
  }

  /**
   * Sets the parameters of the wrapped classifier learner.
   *
   * @param parameters the parameters to pass to the wrapped learner; must not be null
   */
  @Override
  public void setParameters(@NonNull Map parameters) {
    learner.setParameters(parameters);
  }

  /**
   * Sets a single parameter on the wrapped classifier learner.
   *
   * @param name  the parameter name
   * @param value the parameter value
   */
  @Override
  public void setParameter(String name, Object value) {
    learner.setParameter(name, value);
  }

  /**
   * Returns a single parameter value from the wrapped classifier learner.
   *
   * @param name the parameter name
   * @return the value reported by the wrapped learner
   */
  @Override
  public Object getParameter(String name) {
    return learner.getParameter(name);
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy