All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.commons.math4.analysis.interpolation.LoessInterpolator Maven / Gradle / Ivy

Go to download

Statistical sampling library for use in virtdata libraries, based on apache commons math 4

There is a newer version: 5.17.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.analysis.interpolation;

import java.io.Serializable;
import java.util.Arrays;

import org.apache.commons.math4.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NoDataException;
import org.apache.commons.math4.exception.NonMonotonicSequenceException;
import org.apache.commons.math4.exception.NotFiniteNumberException;
import org.apache.commons.math4.exception.NotPositiveException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.exception.OutOfRangeException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.MathArrays;
import org.apache.commons.math4.util.MathUtils;

/**
 * Implements the 
 * Local Regression Algorithm (also Loess, Lowess) for interpolation of
 * real univariate functions.
 * 

* For reference, see * * William S. Cleveland - Robust Locally Weighted Regression and Smoothing * Scatterplots

*

* This class implements both the loess method and serves as an interpolation * adapter to it, allowing one to build a spline on the obtained loess fit.

* * @since 2.0 */ public class LoessInterpolator implements UnivariateInterpolator, Serializable { /** Default value of the bandwidth parameter. */ public static final double DEFAULT_BANDWIDTH = 0.3; /** Default value of the number of robustness iterations. */ public static final int DEFAULT_ROBUSTNESS_ITERS = 2; /** * Default value for accuracy. * @since 2.1 */ public static final double DEFAULT_ACCURACY = 1e-12; /** serializable version identifier. */ private static final long serialVersionUID = 5204927143605193821L; /** * The bandwidth parameter: when computing the loess fit at * a particular point, this fraction of source points closest * to the current point is taken into account for computing * a least-squares regression. *

* A sensible value is usually 0.25 to 0.5.

*/ private final double bandwidth; /** * The number of robustness iterations parameter: this many * robustness iterations are done. *

* A sensible value is usually 0 (just the initial fit without any * robustness iterations) to 4.

*/ private final int robustnessIters; /** * If the median residual at a certain robustness iteration * is less than this amount, no more iterations are done. */ private final double accuracy; /** * Constructs a new {@link LoessInterpolator} * with a bandwidth of {@link #DEFAULT_BANDWIDTH}, * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations * and an accuracy of {#link #DEFAULT_ACCURACY}. * See {@link #LoessInterpolator(double, int, double)} for an explanation of * the parameters. */ public LoessInterpolator() { this.bandwidth = DEFAULT_BANDWIDTH; this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS; this.accuracy = DEFAULT_ACCURACY; } /** * Construct a new {@link LoessInterpolator} * with given bandwidth and number of robustness iterations. *

* Calling this constructor is equivalent to calling {link {@link * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth, * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)} *

* * @param bandwidth when computing the loess fit at * a particular point, this fraction of source points closest * to the current point is taken into account for computing * a least-squares regression. * A sensible value is usually 0.25 to 0.5, the default value is * {@link #DEFAULT_BANDWIDTH}. * @param robustnessIters This many robustness iterations are done. * A sensible value is usually 0 (just the initial fit without any * robustness iterations) to 4, the default value is * {@link #DEFAULT_ROBUSTNESS_ITERS}. * @see #LoessInterpolator(double, int, double) */ public LoessInterpolator(double bandwidth, int robustnessIters) { this(bandwidth, robustnessIters, DEFAULT_ACCURACY); } /** * Construct a new {@link LoessInterpolator} * with given bandwidth, number of robustness iterations and accuracy. * * @param bandwidth when computing the loess fit at * a particular point, this fraction of source points closest * to the current point is taken into account for computing * a least-squares regression. * A sensible value is usually 0.25 to 0.5, the default value is * {@link #DEFAULT_BANDWIDTH}. * @param robustnessIters This many robustness iterations are done. * A sensible value is usually 0 (just the initial fit without any * robustness iterations) to 4, the default value is * {@link #DEFAULT_ROBUSTNESS_ITERS}. * @param accuracy If the median residual at a certain robustness iteration * is less than this amount, no more iterations are done. * @throws OutOfRangeException if bandwidth does not lie in the interval [0,1]. * @throws NotPositiveException if {@code robustnessIters} is negative. * @see #LoessInterpolator(double, int) * @since 2.1 */ public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) throws OutOfRangeException, NotPositiveException { if (bandwidth < 0 || bandwidth > 1) { throw new OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1); } this.bandwidth = bandwidth; if (robustnessIters < 0) { throw new NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters); } this.robustnessIters = robustnessIters; this.accuracy = accuracy; } /** * Compute an interpolating function by performing a loess fit * on the data at the original abscissae and then building a cubic spline * with a * {@link org.apache.commons.math4.analysis.interpolation.SplineInterpolator} * on the resulting fit. * * @param xval the arguments for the interpolation points * @param yval the values for the interpolation points * @return A cubic spline built upon a loess fit to the data at the original abscissae * @throws NonMonotonicSequenceException if {@code xval} not sorted in * strictly increasing order. * @throws DimensionMismatchException if {@code xval} and {@code yval} have * different sizes. * @throws NoDataException if {@code xval} or {@code yval} has zero size. * @throws NotFiniteNumberException if any of the arguments and values are * not finite real numbers. * @throws NumberIsTooSmallException if the bandwidth is too small to * accomodate the size of the input data (i.e. the bandwidth must be * larger than 2/n). */ @Override public final PolynomialSplineFunction interpolate(final double[] xval, final double[] yval) throws NonMonotonicSequenceException, DimensionMismatchException, NoDataException, NotFiniteNumberException, NumberIsTooSmallException { return new SplineInterpolator().interpolate(xval, smooth(xval, yval)); } /** * Compute a weighted loess fit on the data at the original abscissae. * * @param xval Arguments for the interpolation points. * @param yval Values for the interpolation points. * @param weights point weights: coefficients by which the robustness weight * of a point is multiplied. * @return the values of the loess fit at corresponding original abscissae. * @throws NonMonotonicSequenceException if {@code xval} not sorted in * strictly increasing order. * @throws DimensionMismatchException if {@code xval} and {@code yval} have * different sizes. * @throws NoDataException if {@code xval} or {@code yval} has zero size. * @throws NotFiniteNumberException if any of the arguments and values are not finite real numbers. * @throws NumberIsTooSmallException if the bandwidth is too small to * accomodate the size of the input data (i.e. the bandwidth must be * larger than 2/n). * @since 2.1 */ public final double[] smooth(final double[] xval, final double[] yval, final double[] weights) throws NonMonotonicSequenceException, DimensionMismatchException, NoDataException, NotFiniteNumberException, NumberIsTooSmallException { if (xval.length != yval.length) { throw new DimensionMismatchException(xval.length, yval.length); } final int n = xval.length; if (n == 0) { throw new NoDataException(); } checkAllFiniteReal(xval); checkAllFiniteReal(yval); checkAllFiniteReal(weights); MathArrays.checkOrder(xval); if (n == 1) { return new double[]{yval[0]}; } if (n == 2) { return new double[]{yval[0], yval[1]}; } int bandwidthInPoints = (int) (bandwidth * n); if (bandwidthInPoints < 2) { throw new NumberIsTooSmallException(LocalizedFormats.BANDWIDTH, bandwidthInPoints, 2, true); } final double[] res = new double[n]; final double[] residuals = new double[n]; final double[] sortedResiduals = new double[n]; final double[] robustnessWeights = new double[n]; // Do an initial fit and 'robustnessIters' robustness iterations. // This is equivalent to doing 'robustnessIters+1' robustness iterations // starting with all robustness weights set to 1. Arrays.fill(robustnessWeights, 1); for (int iter = 0; iter <= robustnessIters; ++iter) { final int[] bandwidthInterval = {0, bandwidthInPoints - 1}; // At each x, compute a local weighted linear regression for (int i = 0; i < n; ++i) { final double x = xval[i]; // Find out the interval of source points on which // a regression is to be made. if (i > 0) { updateBandwidthInterval(xval, weights, i, bandwidthInterval); } final int ileft = bandwidthInterval[0]; final int iright = bandwidthInterval[1]; // Compute the point of the bandwidth interval that is // farthest from x final int edge; if (xval[i] - xval[ileft] > xval[iright] - xval[i]) { edge = ileft; } else { edge = iright; } // Compute a least-squares linear fit weighted by // the product of robustness weights and the tricube // weight function. // See http://en.wikipedia.org/wiki/Linear_regression // (section "Univariate linear case") // and http://en.wikipedia.org/wiki/Weighted_least_squares // (section "Weighted least squares") double sumWeights = 0; double sumX = 0; double sumXSquared = 0; double sumY = 0; double sumXY = 0; double denom = FastMath.abs(1.0 / (xval[edge] - x)); for (int k = ileft; k <= iright; ++k) { final double xk = xval[k]; final double yk = yval[k]; final double dist = (k < i) ? x - xk : xk - x; final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k]; final double xkw = xk * w; sumWeights += w; sumX += xkw; sumXSquared += xk * xkw; sumY += yk * w; sumXY += yk * xkw; } final double meanX = sumX / sumWeights; final double meanY = sumY / sumWeights; final double meanXY = sumXY / sumWeights; final double meanXSquared = sumXSquared / sumWeights; final double beta; if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) { beta = 0; } else { beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX); } final double alpha = meanY - beta * meanX; res[i] = beta * x + alpha; residuals[i] = FastMath.abs(yval[i] - res[i]); } // No need to recompute the robustness weights at the last // iteration, they won't be needed anymore if (iter == robustnessIters) { break; } // Recompute the robustness weights. // Find the median residual. // An arraycopy and a sort are completely tractable here, // because the preceding loop is a lot more expensive System.arraycopy(residuals, 0, sortedResiduals, 0, n); Arrays.sort(sortedResiduals); final double medianResidual = sortedResiduals[n / 2]; if (FastMath.abs(medianResidual) < accuracy) { break; } for (int i = 0; i < n; ++i) { final double arg = residuals[i] / (6 * medianResidual); if (arg >= 1) { robustnessWeights[i] = 0; } else { final double w = 1 - arg * arg; robustnessWeights[i] = w * w; } } } return res; } /** * Compute a loess fit on the data at the original abscissae. * * @param xval the arguments for the interpolation points * @param yval the values for the interpolation points * @return values of the loess fit at corresponding original abscissae * @throws NonMonotonicSequenceException if {@code xval} not sorted in * strictly increasing order. * @throws DimensionMismatchException if {@code xval} and {@code yval} have * different sizes. * @throws NoDataException if {@code xval} or {@code yval} has zero size. * @throws NotFiniteNumberException if any of the arguments and values are * not finite real numbers. * @throws NumberIsTooSmallException if the bandwidth is too small to * accomodate the size of the input data (i.e. the bandwidth must be * larger than 2/n). */ public final double[] smooth(final double[] xval, final double[] yval) throws NonMonotonicSequenceException, DimensionMismatchException, NoDataException, NotFiniteNumberException, NumberIsTooSmallException { if (xval.length != yval.length) { throw new DimensionMismatchException(xval.length, yval.length); } final double[] unitWeights = new double[xval.length]; Arrays.fill(unitWeights, 1.0); return smooth(xval, yval, unitWeights); } /** * Given an index interval into xval that embraces a certain number of * points closest to {@code xval[i-1]}, update the interval so that it * embraces the same number of points closest to {@code xval[i]}, * ignoring zero weights. * * @param xval Arguments array. * @param weights Weights array. * @param i Index around which the new interval should be computed. * @param bandwidthInterval a two-element array {left, right} such that: * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])} * and * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}. * The array will be updated. */ private static void updateBandwidthInterval(final double[] xval, final double[] weights, final int i, final int[] bandwidthInterval) { final int left = bandwidthInterval[0]; final int right = bandwidthInterval[1]; // The right edge should be adjusted if the next point to the right // is closer to xval[i] than the leftmost point of the current interval int nextRight = nextNonzero(weights, right); if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) { int nextLeft = nextNonzero(weights, bandwidthInterval[0]); bandwidthInterval[0] = nextLeft; bandwidthInterval[1] = nextRight; } } /** * Return the smallest index {@code j} such that * {@code j > i && (j == weights.length || weights[j] != 0)}. * * @param weights Weights array. * @param i Index from which to start search. * @return the smallest compliant index. */ private static int nextNonzero(final double[] weights, final int i) { int j = i + 1; while(j < weights.length && weights[j] == 0) { ++j; } return j; } /** * Compute the * tricube * weight function * * @param x Argument. * @return (1 - |x|3)3 for |x| < 1, 0 otherwise. */ private static double tricube(final double x) { final double absX = FastMath.abs(x); if (absX >= 1.0) { return 0.0; } final double tmp = 1 - absX * absX * absX; return tmp * tmp * tmp; } /** * Check that all elements of an array are finite real numbers. * * @param values Values array. * @throws NotFiniteNumberException * if one of the values is not a finite real number. */ private static void checkAllFiniteReal(final double[] values) { for (int i = 0; i < values.length; i++) { MathUtils.checkFinite(values[i]); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy