All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.classification.RBFNetwork Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import smile.base.rbf.RBF;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.rbf.RadialBasisFunction;
import smile.util.IntSet;

import java.io.Serial;
import java.util.Properties;

/**
 * Radial basis function networks. A radial basis function network is an
 * artificial neural network that uses radial basis functions as activation
 * functions. It is a linear combination of radial basis functions. They are
 * used in function approximation, time series prediction, and control.
 * 

* In its basic form, radial basis function network is in the form *

* y(x) = Σ wi φ(||x-ci||) *

* where the approximating function y(x) is represented as a sum of N radial * basis functions φ, each associated with a different center ci, * and weighted by an appropriate coefficient wi. For distance, * one usually chooses Euclidean distance. The weights wi can * be estimated using the matrix methods of linear least squares, because * the approximating function is linear in the weights. *

* The centers ci can be randomly selected from training data, * or learned by some clustering method (e.g. k-means), or learned together * with weight parameters undergo a supervised learning processing * (e.g. error-correction learning). *

* The popular choices for φ comprise the Gaussian function and the * so-called thin plate splines. The advantage of the thin plate splines is that * their conditioning is invariant under scalings. Gaussian, multi-quadric * and inverse multi-quadric are infinitely smooth and involve a scale * or shape parameter, r0 {@code > 0}. Decreasing * r0 tends to flatten the basis function. For a * given function, the quality of approximation may strongly depend on this * parameter. In particular, increasing r0 has the * effect of better conditioning (the separation distance of the scaled points * increases). *

* A variant on RBF networks is normalized radial basis function (NRBF) * networks, in which we require the sum of the basis functions to be unity. * NRBF arises more naturally from a Bayesian statistical perspective. However, * there is no evidence that either the NRBF method is consistently superior * to the RBF method, or vice versa. *

* SVMs with Gaussian kernel have similar structure as RBF networks with * Gaussian radial basis functions. However, the SVM approach "automatically" * solves the network complexity problem since the size of the hidden layer * is obtained as the result of the QP procedure. Hidden neurons and * support vectors correspond to each other, so the center problems of * the RBF network is also solved, as the support vectors serve as the * basis function centers. It was reported that with similar number of support * vectors/centers, SVM shows better generalization performance than RBF * network when the training data size is relatively small. On the other hand, * RBF network gives better generalization performance than SVM on large * training data. * *

References

*
    *
  1. Simon Haykin. Neural Networks: A Comprehensive Foundation (2nd edition). 1999.
  2. *
  3. T. Poggio and F. Girosi. Networks for approximation and learning. Proc. IEEE 78(9):1484-1487, 1990.
  4. *
  5. Nabil Benoudjit and Michel Verleysen. On the kernel widths in radial-basis function networks. Neural Process, 2003.
  6. *
* * @see RadialBasisFunction * @see SVM * @see MLP * * @param the data type of model input objects. * * @author Haifeng Li */ public class RBFNetwork extends AbstractClassifier { @Serial private static final long serialVersionUID = 2L; /** * The number of classes. */ private final int k; /** * The linear weights. */ private final Matrix w; /** * The radial basis function. */ private final RBF[] rbf; /** * True to fit a normalized RBF network. */ private final boolean normalized; /** * Constructor. * @param k the number of classes. * @param rbf the radial basis functions. * @param w the weights of RBFs. * @param normalized True if this is a normalized RBF network. */ public RBFNetwork(int k, RBF[] rbf, Matrix w, boolean normalized) { this(k, rbf, w, normalized, IntSet.of(k)); } /** * Constructor. * @param k the number of classes. * @param rbf the radial basis functions. * @param w the weights of RBFs. * @param normalized True if this is a normalized RBF network. * @param labels the class label encoder. */ public RBFNetwork(int k, RBF[] rbf, Matrix w, boolean normalized, IntSet labels) { super(labels); this.k = k; this.rbf = rbf; this.w = w; this.normalized = normalized; } /** * Fits an RBF network. * * @param x training samples. * @param y training labels in [0, k), where k is the number of classes. * @param rbf the radial basis functions. * @param the data type. * @return the model. */ public static RBFNetwork fit(T[] x, int[] y, RBF[] rbf) { return fit(x, y, rbf, false); } /** * Fits an RBF network. * * @param x training samples. * @param y training labels in [0, k), where k is the number of classes. * @param rbf the radial basis functions. * @param normalized true for the normalized RBF network. * @param the data type. * @return the model. 
*/ public static RBFNetwork fit(T[] x, int[] y, RBF[] rbf, boolean normalized) { if (x.length != y.length) { throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length)); } ClassLabels codec = ClassLabels.fit(y); int k = codec.k; int n = x.length; int m = rbf.length; Matrix G = new Matrix(n, m+1); Matrix b = new Matrix(n, k); for (int i = 0; i < n; i++) { double sum = 0.0; for (int j = 0; j < m; j++) { double r = rbf[j].f(x[i]); G.set(i, j, r); sum += r; } G.set(i, m, 1); if (normalized) { b.set(i, codec.y[i], sum); } else { b.set(i, codec.y[i], 1); } } Matrix.QR qr = G.qr(true); qr.solve(b); return new RBFNetwork<>(k, rbf, b.submatrix(0, 0, m, k-1), normalized, codec.classes); } /** * Fits an RBF network. * @param x training samples. * @param y training labels. * @param params the hyperparameters. * @return the model. */ public static RBFNetwork fit(double[][] x, int[] y, Properties params) { int neurons = Integer.parseInt(params.getProperty("smile.rbf.neurons", "30")); boolean normalize = Boolean.parseBoolean(params.getProperty("smile.rbf.normalize", "false")); return RBFNetwork.fit(x, y, RBF.fit(x, neurons), normalize); } /** * Returns true if the model is normalized. * @return true if the model is normalized. */ public boolean isNormalized() { return normalized; } @Override public int predict(T x) { int m = rbf.length; double[] f = new double[m+1]; f[m] = 1.0; for (int i = 0; i < m; i++) { f[i] = rbf[i].f(x); } double[] sumw = new double[k]; w.tv(f, sumw); return classes.valueOf(MathEx.whichMax(sumw)); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy