org.apache.ignite.ml.nn.MLPTrainer
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.nn;

import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.initializers.RandomInitializer;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
import org.apache.ignite.ml.util.Utils;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Multilayer perceptron trainer based on partition based {@link Dataset}.
 *
 * @param <P> Type of model update used in this trainer.
 */
public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrainer<MultilayerPerceptron> {
    /** Multilayer perceptron architecture supplier that defines layers and activators. */
    private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;

    /** Loss function to be minimized during the training. */
    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Update strategy that defines how to update model parameters during the training. */
    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;

    /** Maximal number of iterations before the training will be stopped. */
    private final int maxIterations;

    /** Batch size (per every partition). */
    private final int batchSize;

    /** Maximal number of local iterations before synchronization. */
    private final int locIterations;

    /** Random initializer seed. */
    private final long seed;

    /**
     * Constructs a new instance of multilayer perceptron trainer.
     *
     * @param arch Multilayer perceptron architecture that defines layers and activators.
     * @param loss Loss function to be minimized during the training.
     * @param updatesStgy Update strategy that defines how to update model parameters during the training.
     * @param maxIterations Maximal number of iterations before the training will be stopped.
     * @param batchSize Batch size (per every partition).
     * @param locIterations Maximal number of local iterations before synchronization.
     * @param seed Random initializer seed.
     */
    public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
        UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
        int locIterations, long seed) {
        this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed);
    }

    /**
     * Constructs a new instance of multilayer perceptron trainer.
     *
     * @param archSupplier Multilayer perceptron architecture supplier that defines layers and activators.
     * @param loss Loss function to be minimized during the training.
     * @param updatesStgy Update strategy that defines how to update model parameters during the training.
     * @param maxIterations Maximal number of iterations before the training will be stopped.
     * @param batchSize Batch size (per every partition).
     * @param locIterations Maximal number of local iterations before synchronization.
     * @param seed Random initializer seed.
     */
    public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
        UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
        int locIterations, long seed) {
        this.archSupplier = archSupplier;
        this.loss = loss;
        this.updatesStgy = updatesStgy;
        this.maxIterations = maxIterations;
        this.batchSize = batchSize;
        this.locIterations = locIterations;
        this.seed = seed;
    }

    /** {@inheritDoc} */
    public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
        try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
            new EmptyContextBuilder<>(),
            new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
        )) {
            MLPArchitecture arch = archSupplier.apply(dataset);
            MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
            ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();

            for (int i = 0; i < maxIterations; i += locIterations) {
                MultilayerPerceptron finalMdl = mdl;
                int finalI = i;

                // Each global step runs locIterations of mini-batch training inside every
                // partition and then reduces the per-partition updates into a single one.
                List<P> totUp = dataset.compute(
                    data -> {
                        P update = updater.init(finalMdl, loss);

                        MultilayerPerceptron mlp = Utils.copy(finalMdl);

                        if (data.getFeatures() != null) {
                            List<P> updates = new ArrayList<>();

                            for (int locStep = 0; locStep < locIterations; locStep++) {
                                int[] rows = Utils.selectKDistinct(
                                    data.getRows(),
                                    Math.min(batchSize, data.getRows()),
                                    new Random(seed ^ (finalI * locStep))
                                );

                                double[] inputsBatch = batch(data.getFeatures(), rows, data.getRows());
                                double[] groundTruthBatch = batch(data.getLabels(), rows, data.getRows());

                                Matrix inputs = new DenseLocalOnHeapMatrix(inputsBatch, rows.length, 0);
                                Matrix groundTruth = new DenseLocalOnHeapMatrix(groundTruthBatch, rows.length, 0);

                                update = updater.calculateNewUpdate(
                                    mlp,
                                    update,
                                    locStep,
                                    inputs.transpose(),
                                    groundTruth.transpose()
                                );

                                mlp = updater.update(mlp, update);
                                updates.add(update);
                            }

                            List<P> res = new ArrayList<>();

                            res.add(updatesStgy.locStepUpdatesReducer().apply(updates));

                            return res;
                        }

                        return null;
                    },
                    // Merge per-partition results, tolerating partitions that had no data.
                    (a, b) -> {
                        if (a == null)
                            return b;
                        else if (b == null)
                            return a;
                        else {
                            a.addAll(b);
                            return a;
                        }
                    }
                );

                P update = updatesStgy.allUpdatesReducer().apply(totUp);

                mdl = updater.update(mdl, update);
            }

            return mdl;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds a batch of the data by fetching specified rows.
     *
     * @param data All data.
     * @param rows Rows to be fetched from the data.
     * @param totalRows Total number of rows in all data.
     * @return Batch data.
     */
    static double[] batch(double[] data, int[] rows, int totalRows) {
        // The data array is stored column-major: element (row, col) lives at data[col * totalRows + row].
        int cols = data.length / totalRows;

        double[] res = new double[cols * rows.length];

        for (int i = 0; i < rows.length; i++)
            for (int j = 0; j < cols; j++)
                res[j * rows.length + i] = data[j * totalRows + rows[i]];

        return res;
    }
}
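
For orientation, here is a minimal usage sketch (not part of the Ignite sources above) that trains the XOR function with this trainer. It assumes the 2.4-era companion APIs this file compiles against (Activators, LossFunctions, SimpleGDUpdateCalculator, SimpleGDParameterUpdate, LocalDatasetBuilder); package locations and constructor signatures changed in later Ignite releases, so treat the exact calls as assumptions rather than a version-exact recipe.

import java.util.HashMap;
import java.util.Map;

import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;

public class MLPTrainerXorSketch {
    public static void main(String[] args) {
        // XOR truth table, one point per entry: {x1, x2, label}.
        Map<Integer, double[]> xorData = new HashMap<>();
        xorData.put(0, new double[] {0, 0, 0});
        xorData.put(1, new double[] {0, 1, 1});
        xorData.put(2, new double[] {1, 0, 1});
        xorData.put(3, new double[] {1, 1, 0});

        // 2 inputs -> 10 ReLU neurons (with biases) -> 1 sigmoid output.
        MLPArchitecture arch = new MLPArchitecture(2)
            .withAddedLayer(10, true, Activators.RELU)
            .withAddedLayer(1, false, Activators.SIGMOID);

        // Plain gradient descent: sum updates over local steps, average across partitions.
        MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
            arch,
            LossFunctions.MSE,
            new UpdatesStrategy<>(
                new SimpleGDUpdateCalculator(0.1), // learning rate (assumed 2.4-era constructor)
                SimpleGDParameterUpdate::sumLocal,
                SimpleGDParameterUpdate::avg
            ),
            3000, // maxIterations
            4,    // batchSize per partition
            50,   // locIterations between synchronizations
            123L  // seed
        );

        MultilayerPerceptron mlp = trainer.fit(
            new LocalDatasetBuilder<>(xorData, 2),   // local (non-cache) dataset, 2 partitions
            (k, v) -> new double[] {v[0], v[1]},     // feature extractor
            (k, v) -> new double[] {v[2]}            // label extractor
        );

        // One sample per row; the trainer/model transpose to column vectors internally.
        Matrix predictions = mlp.apply(new DenseLocalOnHeapMatrix(
            new double[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1}}));

        System.out.println(predictions);
    }
}

Note on data layout: fit relies on SimpleLabeledDatasetData keeping features and labels column-major, which is exactly what the package-private batch helper assumes. For example, with totalRows = 4 and rows = {1, 3}, it copies data[j * 4 + 1] and data[j * 4 + 3] for every column j into a new 2-row column-major array.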




