All downloads are free. Search and download functionality uses the official Maven repository.

com.davidbracewell.apollo.ml.sequence.WindowedLearner Maven / Gradle / Ivy

package com.davidbracewell.apollo.ml.sequence;

import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.classification.ClassifierLearner;
import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.io.QuietIO;
import lombok.NonNull;

import java.util.Map;

/**
 * 

Greedy learner that wraps a {@link ClassifierLearner}.

* * @author David B. Bracewell */ public class WindowedLearner extends SequenceLabelerLearner { private static final long serialVersionUID = 3783930856969307606L; private final ClassifierLearner learner; /** * Instantiates a new Windowed learner. * * @param learner the learner */ public WindowedLearner(ClassifierLearner learner) { this.learner = learner; } @Override protected SequenceLabeler trainImpl(Dataset dataset) { WindowedLabeler wl = new WindowedLabeler(dataset.getLabelEncoder(), dataset.getFeatureEncoder(), dataset.getPreprocessors(), getTransitionFeatures(), getValidator()); Dataset nd = Dataset.classification() .source(dataset.stream() .flatMap(sequence -> getTransitionFeatures().toInstances(sequence) .stream())); QuietIO.closeQuietly(dataset); wl.classifier = learner.train(nd); wl.encoderPair = wl.classifier.getEncoderPair(); return wl; } @Override public void reset() { learner.reset(); } @Override public Map getParameters() { return learner.getParameters(); } @Override public void setParameters(@NonNull Map parameters) { learner.setParameters(parameters); } @Override public void setParameter(String name, Object value) { learner.setParameter(name, value); } @Override public Object getParameter(String name) { return learner.getParameter(name); } }// END OF WindowedLearner




© 2015 - 2025 Weber Informatics LLC | Privacy Policy