/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.regression;
import java.util.Arrays;
import smile.math.distance.Metric;
import smile.math.matrix.QRDecomposition;
import smile.math.rbf.GaussianRadialBasis;
import smile.math.rbf.RadialBasisFunction;
import smile.util.SmileUtils;
/**
* Radial basis function network. A radial basis function network is an
* artificial neural network that uses radial basis functions as activation
* functions. Its output is a linear combination of radial basis functions of
* the inputs. RBF networks are used in function approximation, time series
* prediction, and control.
*
* In its basic form, a radial basis function network takes the form
*
* y(x) = Σ w<sub>i</sub> φ(||x-c<sub>i</sub>||)
*
* where the approximating function y(x) is represented as a sum of N radial
* basis functions φ, each associated with a different center c<sub>i</sub>
* and weighted by an appropriate coefficient w<sub>i</sub>. The distance is
* usually taken to be the Euclidean distance. The weights w<sub>i</sub> can be
* estimated by the matrix methods of linear least squares, because the
* approximating function is linear in the weights.
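*
* Concretely, given n training inputs x<sub>i</sub> and m centers c<sub>j</sub>,
* the weights are chosen to minimize ||G w - y||, where
* G<sub>ij</sub> = φ(||x<sub>i</sub>-c<sub>j</sub>||). This is exactly the
* system assembled and solved (via QR decomposition) in the constructors below.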
*
* The points c<sub>i</sub> are often called the centers of the RBF network.
* They can be randomly selected from the training data, learned by a
* clustering method (e.g. k-means), or learned together with the weights in a
* supervised learning process (e.g. error-correction learning).
*
* Popular choices for φ include the Gaussian function and the so-called
* thin plate splines. The advantage of the thin plate splines is that their
* conditioning is invariant under scalings. Gaussian, multi-quadric and
* inverse multi-quadric are infinitely smooth and involve a scale or shape
* parameter, r<sub>0</sub> > 0. Decreasing r<sub>0</sub> tends to flatten the
* basis function. For a given function, the quality of approximation may
* strongly depend on this parameter. In particular, increasing r<sub>0</sub>
* has the effect of better conditioning (the separation distance of the
* scaled points increases).
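*
* For instance, Gaussian bases with two different (purely illustrative) widths
* can be created via the GaussianRadialBasis class imported by this file; the
* constructor argument is assumed here to be the Gaussian width:
* <pre>{@code
* RadialBasisFunction narrow = new GaussianRadialBasis(0.5);
* RadialBasisFunction wide = new GaussianRadialBasis(5.0);
* }</pre>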
*
* A variant of RBF networks is the normalized radial basis function (NRBF)
* network, in which the basis functions are required to sum to unity. NRBF
* arises more naturally from a Bayesian statistical perspective. However,
* there is no evidence that the NRBF method is consistently superior to the
* RBF method, or vice versa.
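*
* In the normalized form the output is divided by the sum of the basis
* responses, i.e.
*
* y(x) = Σ w<sub>i</sub> φ(||x-c<sub>i</sub>||) / Σ φ(||x-c<sub>i</sub>||)
*
* which is the quantity computed by {@link #predict(Object)} when the
* normalized flag is set.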
*
*
* <h2>References</h2>
* <ol>
* <li> Simon Haykin. Neural Networks: A Comprehensive Foundation (2nd edition). 1999.</li>
* <li> T. Poggio and F. Girosi. Networks for approximation and learning. Proc. IEEE 78(9):1484-1487, 1990.</li>
* <li> Nabil Benoudjit and Michel Verleysen. On the kernel widths in radial-basis function networks. Neural Processing Letters, 2003.</li>
* </ol>
*
*
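* A minimal usage sketch (the Euclidean metric, the toy sine data and the
* ad hoc choice of centers below are illustrative assumptions, not part of
* this class):
* <pre>{@code
* double[][] x = new double[50][1];
* double[] y = new double[50];
* for (int i = 0; i < x.length; i++) {
*     x[i][0] = i * 0.2;
*     y[i] = Math.sin(x[i][0]);
* }
* // use every fifth training point as a center
* double[][] centers = new double[10][];
* for (int i = 0; i < centers.length; i++) {
*     centers[i] = x[5 * i];
* }
* RBFNetwork<double[]> net = new RBFNetwork<>(x, y,
*         new smile.math.distance.EuclideanDistance(),
*         new GaussianRadialBasis(1.0), centers);
* double yhat = net.predict(new double[] {1.0});
* }</pre>
*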
* @see RadialBasisFunction
* @see SVR
*
* @author Haifeng Li
*/
public class RBFNetwork<T> implements Regression<T> {
/**
* The centers of RBF functions.
*/
private T[] centers;
/**
* The linear weights.
*/
private double[] w;
/**
* The distance functor.
*/
private Metric<T> distance;
/**
* The radial basis functions.
*/
private RadialBasisFunction[] rbf;
/**
* True to fit a normalized RBF network.
*/
private boolean normalized;
/**
* Trainer for RBF networks.
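*
* A typical fluent use, assuming the caller supplies the training arrays x and
* y (the Euclidean metric and the scale value 2.0 are illustrative):
* <pre>{@code
* RBFNetwork<double[]> model = new RBFNetwork.Trainer<double[]>(new smile.math.distance.EuclideanDistance())
*         .setRBF(new GaussianRadialBasis(2.0), 20)
*         .setNormalized(true)
*         .train(x, y);
* }</pre>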
*/
public static class Trainer<T> extends RegressionTrainer<T> {
/**
* The number of centers.
*/
private int m = 10;
/**
* The distance metric functor.
*/
private Metric<T> distance;
/**
* The radial basis functions.
*/
private RadialBasisFunction[] rbf;
/**
* True to fit a normalized RBF network.
*/
private boolean normalized = false;
/**
* Constructor.
*
* @param distance the distance metric functor.
*/
public Trainer(Metric<T> distance) {
this.distance = distance;
}
/**
* Sets the radial basis function.
* @param rbf the radial basis function.
* @param m the number of basis functions.
*/
public Trainer<T> setRBF(RadialBasisFunction rbf, int m) {
this.m = m;
this.rbf = rep(rbf, m);
return this;
}
/**
* Sets the radial basis functions.
* @param rbf the radial basis functions.
*/
public Trainer<T> setRBF(RadialBasisFunction[] rbf) {
this.m = rbf.length;
this.rbf = rbf;
return this;
}
/**
* Sets the number of centers.
* @param m the number of centers.
*/
public Trainer<T> setNumCenters(int m) {
this.m = m;
return this;
}
/**
* Sets whether to fit a normalized RBF network.
* @param normalized true to fit a normalized RBF network.
*/
public Trainer<T> setNormalized(boolean normalized) {
this.normalized = normalized;
return this;
}
@Override
public RBFNetwork<T> train(T[] x, double[] y) {
@SuppressWarnings("unchecked")
T[] centers = (T[]) java.lang.reflect.Array.newInstance(x.getClass().getComponentType(), m);
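// Pick m centers from the training data and fit a Gaussian basis to them.
// The fitted Gaussian basis is used only when no basis functions were
// supplied via setRBF; otherwise the learned centers are paired with the
// user-supplied basis functions.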
GaussianRadialBasis gaussian = SmileUtils.learnGaussianRadialBasis(x, centers, distance);
if (rbf == null) {
return new RBFNetwork<>(x, y, distance, gaussian, centers, normalized);
} else {
return new RBFNetwork<>(x, y, distance, rbf, centers, normalized);
}
}
/**
* Learns an RBF network with the given centers.
*
* @param x training samples.
* @param y the response variable.
* @param centers the centers of RBF functions.
* @return a trained RBF network.
*/
public RBFNetwork<T> train(T[] x, double[] y, T[] centers) {
return new RBFNetwork<>(x, y, distance, rbf, centers, normalized);
}
}
/**
* Constructor. Learn a regular RBF network without normalization.
* @param x the training dataset.
* @param y the response variable.
* @param distance the distance functor.
* @param rbf the radial basis function.
* @param centers the centers of RBF functions.
*/
public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers) {
this(x, y, distance, rbf, centers, false);
}
/**
* Constructor. Learn a regular RBF network without normalization.
* @param x the training dataset.
* @param y the response variable.
* @param distance the distance functor.
* @param rbf the radial basis functions.
* @param centers the centers of RBF functions.
*/
public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers) {
this(x, y, distance, rbf, centers, false);
}
/**
* Constructor.
* @param x the training dataset.
* @param y the response variable.
* @param distance the distance functor.
* @param rbf the radial basis function.
* @param centers the centers of RBF functions.
* @param normalized true for the normalized RBF network.
*/
public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers, boolean normalized) {
this(x, y, distance, rep(rbf, centers.length), centers, normalized);
}
/**
* Returns an array of radial basis functions filled with the given function.
* @param rbf the radial basis function to fill the array with.
* @param k the size of the array.
* @return an array of k references to the given radial basis function.
*/
private static RadialBasisFunction[] rep(RadialBasisFunction rbf, int k) {
RadialBasisFunction[] arr = new RadialBasisFunction[k];
Arrays.fill(arr, rbf);
return arr;
}
/**
* Constructor.
* @param x the training dataset.
* @param y the response variable.
* @param distance the distance functor.
* @param rbf the radial basis functions.
* @param centers the centers of RBF functions.
* @param normalized true for the normalized RBF network.
*/
public RBFNetwork(T[] x, double[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers, boolean normalized) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (rbf.length != centers.length) {
throw new IllegalArgumentException(String.format("The sizes of RBF functions and centers don't match: %d != %d", rbf.length, centers.length));
}
this.centers = centers;
this.distance = distance;
this.rbf = rbf;
this.normalized = normalized;
int n = x.length;
int m = rbf.length;
double[][] G = new double[n][m];
double[] b = new double[n];
w = new double[m];
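// Build the n-by-m design matrix G with G[i][j] = rbf_j(d(x_i, c_j)) and the
// right-hand side b. For the normalized network the response is scaled by the
// row sum of the basis values, so the same linear system applies.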
for (int i = 0; i < n; i++) {
double sum = 0.0;
for (int j = 0; j < m; j++) {
G[i][j] = rbf[j].f(distance.d(x[i], centers[j]));
sum += G[i][j];
}
if (normalized) {
b[i] = sum * y[i];
} else {
b[i] = y[i];
}
}
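// Solve the linear least squares problem min ||G w - b|| for the weights w
// via QR decomposition.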
QRDecomposition qr = new QRDecomposition(G, true);
qr.solve(b, w);
}
@Override
public double predict(T x) {
double sum = 0.0, sumw = 0.0;
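// Accumulate the weighted sum of basis responses and, for the normalized
// network, the plain sum used as the denominator.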
for (int i = 0; i < rbf.length; i++) {
double f = rbf[i].f(distance.d(x, centers[i]));
sumw += w[i] * f;
sum += f;
}
return normalized ? sumw / sum : sumw;
}
}