// Source artifact: org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer (ignite-ml)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.regressions.linear;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import java.io.Serializable;
import java.util.Arrays;
/**
* Trainer of the linear regression model based on stochastic gradient descent algorithm.
*/
public class LinearRegressionSGDTrainer implements SingleLabelDatasetTrainer {
/** Update strategy. */
private final UpdatesStrategy super MultilayerPerceptron, P> updatesStgy;
/** Max number of iteration. */
private final int maxIterations;
/** Batch size. */
private final int batchSize;
/** Number of local iterations. */
private final int locIterations;
/** Seed for random generator. */
private final long seed;
/**
* Constructs a new instance of linear regression SGD trainer.
*
* @param updatesStgy Update strategy.
* @param maxIterations Max number of iteration.
* @param batchSize Batch size.
* @param locIterations Number of local iterations.
* @param seed Seed for random generator.
*/
public LinearRegressionSGDTrainer(UpdatesStrategy super MultilayerPerceptron, P> updatesStgy, int maxIterations,
int batchSize, int locIterations, long seed) {
this.updatesStgy = updatesStgy;
this.maxIterations = maxIterations;
this.batchSize = batchSize;
this.locIterations = locIterations;
this.seed = seed;
}
/** {@inheritDoc} */
@Override public LinearRegressionModel fit(DatasetBuilder datasetBuilder,
IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) {
IgniteFunction, MLPArchitecture> archSupplier = dataset -> {
int cols = dataset.compute(data -> {
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
}, (a, b) -> a == null ? b : a);
MLPArchitecture architecture = new MLPArchitecture(cols);
architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);
return architecture;
};
MLPTrainer> trainer = new MLPTrainer<>(
archSupplier,
LossFunctions.MSE,
updatesStgy,
maxIterations,
batchSize,
locIterations,
seed
);
IgniteBiFunction lbE = new IgniteBiFunction() {
@Override public double[] apply(K k, V v) {
return new double[]{lbExtractor.apply(k, v)};
}
};
MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);
double[] p = mlp.parameters().getStorage().data();
return new LinearRegressionModel(new DenseLocalOnHeapVector(
Arrays.copyOf(p, p.length - 1)),
p[p.length - 1]
);
}
}