All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.svm;

import java.util.concurrent.ThreadLocalRandom;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.jetbrains.annotations.NotNull;

/**
 * Base class for a soft-margin SVM linear classification trainer based on the communication-efficient distributed dual
 * coordinate ascent algorithm (CoCoA) with hinge-loss function. 

This trainer takes input as Labeled Dataset with -1 * and +1 labels for two classes and makes binary classification.

The paper about this algorithm could be found * here https://arxiv.org/abs/1409.1458. */ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetTrainer { /** Amount of outer SDCA algorithm iterations. */ private int amountOfIterations = 200; /** Amount of local SDCA algorithm iterations. */ private int amountOfLocIterations = 100; /** Regularization parameter. */ private double lambda = 0.4; /** * Trains model based on the specified data. * * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @return Model. */ @Override public SVMLinearBinaryClassificationModel fit(DatasetBuilder datasetBuilder, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor) { assert datasetBuilder != null; PartitionDataBuilder> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, lbExtractor ); Vector weights; try(Dataset> dataset = datasetBuilder.build( (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final int cols = dataset.compute(data -> data.colSize(), (a, b) -> a == null ? 
b : a); final int weightVectorSizeWithIntercept = cols + 1; weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); for (int i = 0; i < this.getAmountOfIterations(); i++) { Vector deltaWeights = calculateUpdates(weights, dataset); weights = weights.plus(deltaWeights); // creates new vector } } catch (Exception e) { throw new RuntimeException(e); } return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); } /** */ @NotNull private Vector initializeWeightsWithZeros(int vectorSize) { return new DenseLocalOnHeapVector(vectorSize); } /** */ private Vector calculateUpdates(Vector weights, Dataset> dataset) { return dataset.compute(data -> { Vector copiedWeights = weights.copy(); Vector deltaWeights = initializeWeightsWithZeros(weights.size()); final int amountOfObservation = data.rowSize(); Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation); Vector deltaAlphas = initializeWeightsWithZeros(amountOfObservation); for (int i = 0; i < this.getAmountOfLocIterations(); i++) { int randomIdx = ThreadLocalRandom.current().nextInt(amountOfObservation); Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx); copiedWeights = copiedWeights.plus(deltas.deltaWeights); // creates new vector deltaWeights = deltaWeights.plus(deltas.deltaWeights); // creates new vector tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha); deltaAlphas.set(randomIdx, deltaAlphas.get(randomIdx) + deltas.deltaAlpha); } return deltaWeights; }, (a, b) -> a == null ? 
b : a.plus(b)); } /** */ private Deltas getDeltas(LabeledDataset data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas, int randomIdx) { LabeledVector row = (LabeledVector)data.getRow(randomIdx); Double lb = (Double)row.label(); Vector v = makeVectorWithInterceptElement(row); double alpha = tmpAlphas.get(randomIdx); return maximize(lb, v, alpha, copiedWeights, amountOfObservation); } /** */ private Vector makeVectorWithInterceptElement(LabeledVector row) { Vector vec = row.features().like(row.features().size() + 1); vec.set(0, 1); // set intercept element for (int j = 0; j < row.features().size(); j++) vec.set(j + 1, row.features().get(j)); return vec; } /** */ private Deltas maximize(double lb, Vector v, double alpha, Vector weights, int amountOfObservation) { double gradient = calcGradient(lb, v, weights, amountOfObservation); double prjGrad = calculateProjectionGradient(alpha, gradient); return calcDeltas(lb, v, alpha, prjGrad, weights.size(), amountOfObservation); } /** */ private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize, int amountOfObservation) { if (gradient != 0.0) { double qii = v.dot(v); double newAlpha = calcNewAlpha(alpha, gradient, qii); Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.lambda() * amountOfObservation)); return new Deltas(newAlpha - alpha, deltaWeights); } else return new Deltas(0.0, initializeWeightsWithZeros(vectorSize)); } /** */ private double calcNewAlpha(double alpha, double gradient, double qii) { if (qii != 0.0) return Math.min(Math.max(alpha - (gradient / qii), 0.0), 1.0); else return 1.0; } /** */ private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) { double dotProduct = v.dot(weights); return (lb * dotProduct - 1.0) * (this.lambda() * amountOfObservation); } /** */ private double calculateProjectionGradient(double alpha, double gradient) { if (alpha <= 0.0) return Math.min(gradient, 0.0); else if (alpha >= 1.0) return 
Math.max(gradient, 0.0); else return gradient; } /** * Set up the regularization parameter. * @param lambda The regularization parameter. Should be more than 0.0. * @return Trainer with new lambda parameter value. */ public SVMLinearBinaryClassificationTrainer withLambda(double lambda) { assert lambda > 0.0; this.lambda = lambda; return this; } /** * Gets the regularization lambda. * @return The parameter value. */ public double lambda() { return lambda; } /** * Gets the amount of outer iterations of SCDA algorithm. * @return The parameter value. */ public int getAmountOfIterations() { return amountOfIterations; } /** * Set up the amount of outer iterations of SCDA algorithm. * @param amountOfIterations The parameter value. * @return Trainer with new amountOfIterations parameter value. */ public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) { this.amountOfIterations = amountOfIterations; return this; } /** * Gets the amount of local iterations of SCDA algorithm. * @return The parameter value. */ public int getAmountOfLocIterations() { return amountOfLocIterations; } /** * Set up the amount of local iterations of SCDA algorithm. * @param amountOfLocIterations The parameter value. * @return Trainer with new amountOfLocIterations parameter value. */ public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { this.amountOfLocIterations = amountOfLocIterations; return this; } } /** This is a helper class to handle pair results which are returned from the calculation method. */ class Deltas { /** */ public double deltaAlpha; /** */ public Vector deltaWeights; /** */ public Deltas(double deltaAlpha, Vector deltaWeights) { this.deltaAlpha = deltaAlpha; this.deltaWeights = deltaWeights; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy