
org.datacleaner.components.machinelearning.impl.NeuralNetTrainer Maven / Gradle / Ivy
/**
* DataCleaner (community edition)
* Copyright (C) 2014 Free Software Foundation, Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.datacleaner.components.machinelearning.impl;
import java.util.List;
import org.datacleaner.components.machinelearning.api.MLClassificationMetadata;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassificationTrainer;
import org.datacleaner.components.machinelearning.api.MLTrainerCallback;
import org.datacleaner.components.machinelearning.api.MLTrainingOptions;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.components.machinelearning.api.MLFeatureModifier;
import smile.classification.NeuralNetwork;
import smile.classification.NeuralNetwork.ActivationFunction;
import smile.classification.NeuralNetwork.ErrorFunction;
public class NeuralNetTrainer implements MLClassificationTrainer {
private final MLTrainingOptions trainingOptions;
private final int epochs;
private final ErrorFunction errorFunction;
private final ActivationFunction activationFunction;
// Private copy of the caller's array (see constructor) so that later mutation of the
// caller's array cannot silently change the configured network topology.
private final int[] hiddenNeuronPerLayer;
private final double learningRate;
private final double momentum;

/**
 * Creates a trainer holding the given neural-network configuration for later use by
 * {@link #train}.
 *
 * @param trainingOptions general training options for this trainer
 * @param epochs number of training epochs (presumably full passes over the training data;
 *        confirm against {@code train})
 * @param errorFunction the smile {@link ErrorFunction} to configure the network with
 * @param activationFunction the smile {@link ActivationFunction} to configure the network with
 * @param hiddenNeuronPerLayer neuron counts, presumably one entry per hidden layer. The array
 *        is copied, so changes the caller makes after construction have no effect. A
 *        {@code null} argument is stored as {@code null}.
 * @param learningRate the learning rate to train with
 * @param momentum the momentum to train with
 */
public NeuralNetTrainer(MLTrainingOptions trainingOptions, int epochs, ErrorFunction errorFunction,
        ActivationFunction activationFunction, int[] hiddenNeuronPerLayer, double learningRate, double momentum) {
    this.trainingOptions = trainingOptions;
    this.epochs = epochs;
    this.errorFunction = errorFunction;
    this.activationFunction = activationFunction;
    // Defensive copy: arrays are mutable, and storing the caller's reference would let the
    // caller alter this trainer's otherwise-final configuration.
    this.hiddenNeuronPerLayer = hiddenNeuronPerLayer == null ? null : hiddenNeuronPerLayer.clone();
    this.learningRate = learningRate;
    this.momentum = momentum;
}
@Override
public MLClassifier train(Iterable data, List featureModifiers,
MLTrainerCallback callback) {
final List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy