// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// smile.base.svm.LASVM Maven / Gradle / Ivy

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.base.svm;

import java.io.Serial;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/**
 * LASVM is an approximate SVM solver that uses online approximation.
 * It reaches accuracies similar to that of a real SVM after performing
 * a single sequential pass through the training examples. Further
 * benefits can be achieved using selective sampling techniques to
 * choose which example should be considered next.
 * LASVM requires considerably less memory than a regular SVM solver.
 * This becomes a considerable speed advantage for large training sets.
 *
 * @param  the data type of model input objects.
 *
 * @author Haifeng Li
 */
public class LASVM implements Serializable {
    @Serial
    private static final long serialVersionUID = 2L;
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(LASVM.class);

    /**
     * The default value for K_tt + K_ss - 2 * K_ts if kernel is not positive.
     */
    private static final double TAU = 1E-12;

    /**
     * The kernel function.
     */
    private final MercerKernel kernel;
    /**
     * The soft margin penalty parameter for positive samples.
     */
    private final double Cp;
    /**
     * The soft margin penalty parameter for negative samples.
     */
    private final double Cn;
    /**
     * The tolerance of convergence test.
     */
    private final double tol;
    /**
     * Support vectors.
     */
    private final ArrayList> vectors = new ArrayList<>();
    /**
     * Threshold of decision function.
     */
    private double b = 0.0;

    /**
     * True if minmax() is already called after update.
     */
    private boolean minmaxflag = false;
    /*
     * Most violating pair.
     * argmin gi of m_i < alpha_i
     * argmax gi of alpha_i < M_i
     * where m_i = min{0, y_i * C}
     * and   M_i = max{0, y_i * C}
     */
    /** The most violating pair. */
    private SupportVector svmin = null;
    /** The most violating pair. */
    private SupportVector svmax = null;
    /** The gradient of most violating pair. */
    private double gmin = Double.MAX_VALUE;
    /** The gradient of most violating pair. */
    private double gmax = -Double.MAX_VALUE;

    /**
     * The training samples.
     */
    private T[] x;
    /**
     * The kernel matrix.
     */
    private double[][] K;

    /**
     * Constructor.
     * @param kernel the kernel.
     * @param C the soft margin penalty parameter.
     * @param tol the tolerance of convergence test.
     */
    public LASVM(MercerKernel kernel, double C, double tol) {
        this(kernel, C, C, tol);
    }

    /**
     * Constructor.
     * @param kernel the kernel.
     * @param Cp the soft margin penalty parameter for positive instances.
     * @param Cn the soft margin penalty parameter for negative instances.
     * @param tol the tolerance of convergence test.
     */
    public LASVM(MercerKernel kernel, double Cp, double Cn, double tol) {
        if (Cp < 0) {
            throw new IllegalArgumentException("Invalid C: " + Cp);
        }

        if (Cn < 0) {
            throw new IllegalArgumentException("Invalid C: " + Cn);
        }

        if (tol <= 0) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }

        this.kernel = kernel;
        this.Cp = Cp;
        this.Cn = Cn;
        this.tol = tol;
    }

    /**
     * Trains the model.
     * @param x training samples.
     * @param y training labels.
     * @param epochs the number of epochs, usually 1 or 2 is sufficient.
     * @return the model.
     */
    public KernelMachine  fit(T[] x, int[] y, int epochs) {
        this.x = x;
        this.K = new double[x.length][];

        // pick initial support vectors.
        init(x, y);

        // stochastic training
        int phase = Math.min(x.length, 1000);
        for (int epoch = 0, iter = 0; epoch < epochs; epoch++) {
            for (int i : MathEx.permutate(x.length)) {
                process(i, x[i], y[i]);

                do {
                    reprocess(tol); // at least one call to reprocess
                    minmax();
                } while (gmax - gmin > 1000);

                if (++iter % phase == 0) {
                    logger.info("{} iterations, {} support vectors", iter, vectors.size());
                }
            }
        }

        finish();

        int n = vectors.size();
        @SuppressWarnings("unchecked")
        T[] sv = (T[]) java.lang.reflect.Array.newInstance(x.getClass().getComponentType(), n);
        double[] alpha = new double[n];
        for (int i = 0; i < n; i++) {
            SupportVector v = vectors.get(i);
            sv[i] = v.x;
            alpha[i] = v.alpha;
        }

        return new KernelMachine<>(kernel, sv, alpha, b);
    }

    /**
     * Initialize the SVM with some instances as support vectors.
     */
    private void init(T[] x, int[] y) {
        int few = 5;
        int cp = 0, cn = 0;

        for (int i : MathEx.permutate(x.length)) {
            if (y[i] == 1 && cp < few) {
                if (process(i, x[i], y[i])) cp++;
            } else if (y[i] == -1 && cn < few) {
                if (process(i, x[i], y[i])) cn++;
            }

            if (cp >= few && cn >= few) break;
        }
    }

    /**
     * Finds the support vectors with smallest (of I_up) and largest (of I_down) gradients.
     */
    private void minmax() {
        if (minmaxflag) return;

        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;

        for (SupportVector v : vectors) {
            double gi = v.g;
            double ai = v.alpha;
            if (gi < gmin && ai > v.cmin) {
                svmin = v;
                gmin = gi;
            }
            if (gi > gmax && ai < v.cmax) {
                svmax = v;
                gmax = gi;
            }
        }

        minmaxflag = true;
    }

    /**
     * Returns the cached kernel value.
     * @param i the index of support vector.
     * @param j the index of support vector.
     * @return the kernel value.
     */
    private double k(int i, int j) {
        double k = Double.NaN;
        double[] ki = K[i];
        if (ki != null) {
            k = ki[j];
        }

        if (Double.isNaN(k)) {
            k = kernel.k(x[i], x[j]);
            if (ki != null) ki[j] = k;
        }

        return k;
    }

    /**
     * Sequential minimal optimization.
     * @param v1 the first vector of working set.
     * @param v2 the second vector of working set.
     * @param epsgr the tolerance of convergence test.
     * @return true if NOT pass convergence test.
     */
    private boolean smo(SupportVector v1, SupportVector v2, double epsgr) {
        // SMO working set selection
        // Determine coordinate to process
        if (v1 == null && v2 == null) {
            minmax();

            if (gmax > -gmin) {
                v2 = svmax;
            } else {
                v1 = svmin;
            }
        }

        // kernel(v1, v2)
        double k12 = Double.NaN;

        if (v2 == null) {
            // determine imax
            assert v1 != null;
            double km = v1.k;
            double gm = v1.g;
            double best = 0.0;
            for (SupportVector v : vectors) {
                double Z = v.g - gm;
                double k = k(v1.i, v.i);
                double curv = km + v.k - 2.0 * k;
                if (curv <= 0.0) curv = TAU;
                double mu = Z / curv;
                if ((mu > 0.0 && v.alpha < v.cmax) || (mu < 0.0 && v.alpha > v.cmin)) {
                    double gain = Z * mu;
                    if (gain > best) {
                        best = gain;
                        v2 = v;
                        k12 = k;
                    }
                }
            }
        }

        if (v1 == null) {
            // determine imin
            double km = v2.k;
            double gm = v2.g;
            double best = 0.0;
            for (SupportVector v : vectors) {
                double Z = gm - v.g;
                double k = k(v2.i, v.i);
                double curv = km + v.k - 2.0 * k;
                if (curv <= 0.0) curv = TAU;

                double mu = Z / curv;
                if ((mu > 0.0 && v.alpha > v.cmin) || (mu < 0.0 && v.alpha < v.cmax)) {
                    double gain = Z * mu;
                    if (gain > best) {
                        best = gain;
                        v1 = v;
                        k12 = k;
                    }
                }
            }
        }

        if (v1 == null || v2 == null) {
            return false;
        }

        if (Double.isNaN(k12)) {
            k12 = kernel.k(v1.x, v2.x);
        }

        // Perform update
        double step = getStep(v1, v2, k12);
        v1.alpha -= step;
        v2.alpha += step;
        for (SupportVector v : vectors) {
            v.g -= step * (k(v2.i, v.i) - k(v1.i, v.i));
        }

        // optimality test
        minmaxflag = false;
        minmax();

        b = (gmax + gmin) / 2;
        return gmax - gmin > epsgr;
    }

    /**
     * Calculates the maximal search step.
     * @param v1 the first vector of working set.
     * @param v2 the second vector of working set.
     * @param k12 the kernel value k(v1, v2).
     * @return the search step.
     */
    private double getStep(SupportVector v1, SupportVector v2, double k12) {
        // Determine curvature
        double curv = v1.k + v2.k - 2 * k12;
        if (curv <= 0.0) curv = TAU;

        double step = (v2.g - v1.g) / curv;

        // Determine maximal step
        if (step >= 0.0) {
            double delta = v1.alpha - v1.cmin;
            if (delta < step) {
                step = delta;
            }
            delta = v2.cmax - v2.alpha;
            if (delta < step) {
                step = delta;
            }
        } else {
            double delta = v2.cmin - v2.alpha;
            if (delta > step) {
                step = delta;
            }
            delta = v1.alpha - v1.cmax;
            if (delta > step) {
                step = delta;
            }
        }

        return step;
    }

    /**
     * Process a new sample.
     * @return true if x is added to support vectors.
     */
    private boolean process(int i, T x, int y) {
        if (y != 1 && y != -1) {
            throw new IllegalArgumentException("Invalid label: " + y);
        }

        // Bail out if already in expansion
        for (SupportVector v : vectors) {
            if (v.x == x) return false;
        }

        double[] cache = new double[K.length];
        Arrays.fill(cache, Double.NaN);

        // Compute gradient
        double g = y;
        for (SupportVector v : vectors) {
            // Parallel stream may cause unreproducible results due to
            // different numeric round-off because of different data
            // partitions (i.e. different number of cores/threads).
            // The speedup of parallel stream is also limited as
            // the number of support vectors is often small.
            double k = kernel.k(v.x, x);
            cache[v.i] = k;
            g -= v.alpha * k;
        }

        // Decide insertion
        minmax();
        if (gmin < gmax) {
            if ((y > 0 && g < gmin) || (y < 0 && g > gmax)) {
                return false;
            }
        }

        // Insert
        SupportVector v = new SupportVector<>(i, x, y, 0.0, g, Cp, Cn, kernel.k(x, x));
        vectors.add(v);
        K[i] = cache;

        // Process
        if (y > 0) {
            smo(null, v, 0.0);
        } else {
            smo(v, null, 0.0);
        }

        minmaxflag = false;
        return true;
    }

    /**
     * Reprocess support vectors.
     * @param epsgr the tolerance of convergence test.
     * @return true if NOT pass convergence test.
     */
    private boolean reprocess(double epsgr) {
        boolean status = smo(null, null, epsgr);
        evict();
        return status;
    }

    /**
     * Call reprocess until converge.
     */
    private void finish() {
        finish(tol, vectors.size());

        int bsv = 0;
        for (SupportVector v : vectors) {
            if (v.alpha == v.cmin || v.alpha == v.cmax) {
                bsv++;
            }
        }

        logger.info("{} samples, {} support vectors, {} bounded", x.length, vectors.size(), bsv);

    }

    /**
     * Call reprocess until converge.
     * @param epsgr the tolerance of convergence test.
     * @param maxIter the maximum number of iterations.
     */
    private void finish(double epsgr, int maxIter) {
        logger.info("Finalizing the training by reprocess.");
        for (int count = 1; count <= maxIter && smo(null, null, epsgr); count++) {
            if (count % 1000 == 0) {
                logger.info("{} reprocess iterations.", count);
            }
        }
        evict();
    }

    /**
     * Removes support vectors from the kernel expansion.
     * Online kernel classifiers usually experience considerable problems
     * with noisy data sets. Each iteration is likely to cause a mistake
     * because the best achievable error rate for such problems
     * is high. The number of support vectors increases very rapidly and
     * potentially causes overfitting and poor convergence. Support vector
     * removal criteria avoid this drawback.
     */
    private void evict() {
        minmax();

        vectors.removeIf(v -> {
            if (MathEx.isZero(v.alpha, 1E-4)) {
                if ((v.g >= gmax && 0 >= v.cmax) || (v.g <= gmin && 0 <= v.cmin)) {
                    K[v.i] = null;
                    return true;
                }
            }
            return false;
        });
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy