All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.commons.math3.stat.correlation.KendallsCorrelation Maven / Gradle / Ivy

There is a newer version: 2.12.15
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.stat.correlation;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;

import java.util.Arrays;
import java.util.Comparator;

/**
 * Implementation of Kendall's Tau-b rank correlation.
 * 

* A pair of observations (x1, y1) and * (x2, y2) are considered concordant if * x1 < x2 and y1 < y2 * or x2 < x1 and y2 < y1. * The pair is discordant if x1 < x2 and * y2 < y1 or x2 < x1 and * y1 < y2. If either x1 = x2 * or y1 = y2, the pair is neither concordant nor * discordant. *

* Kendall's Tau-b is defined as: *

 * taub = (nc - nd) / sqrt((n0 - n1) * (n0 - n2))
 * 
*

* where: *

    *
  • n0 = n * (n - 1) / 2
  • *
  • nc = Number of concordant pairs
  • *
  • nd = Number of discordant pairs
  • *
  • n1 = sum of ti * (ti - 1) / 2 for all i
  • *
  • n2 = sum of uj * (uj - 1) / 2 for all j
  • *
  • ti = Number of tied values in the ith group of ties in x
  • *
  • uj = Number of tied values in the jth group of ties in y
  • *
*

* This implementation uses the O(n log n) algorithm described in * William R. Knight's 1966 paper "A Computer Method for Calculating * Kendall's Tau with Ungrouped Data" in the Journal of the American * Statistical Association. * * @see * Kendall tau rank correlation coefficient (Wikipedia) * @see A Computer * Method for Calculating Kendall's Tau with Ungrouped Data * * @since 3.3 */ public class KendallsCorrelation { /** correlation matrix */ private final RealMatrix correlationMatrix; /** * Create a KendallsCorrelation instance without data. */ public KendallsCorrelation() { correlationMatrix = null; } /** * Create a KendallsCorrelation from a rectangular array * whose columns represent values of variables to be correlated. * * @param data rectangular array with columns representing variables * @throws IllegalArgumentException if the input data array is not * rectangular with at least two rows and two columns. */ public KendallsCorrelation(double[][] data) { this(MatrixUtils.createRealMatrix(data)); } /** * Create a KendallsCorrelation from a RealMatrix whose columns * represent variables to be correlated. * * @param matrix matrix with columns representing variables to correlate */ public KendallsCorrelation(RealMatrix matrix) { correlationMatrix = computeCorrelationMatrix(matrix); } /** * Returns the correlation matrix. * * @return correlation matrix */ public RealMatrix getCorrelationMatrix() { return correlationMatrix; } /** * Computes the Kendall's Tau rank correlation matrix for the columns of * the input matrix. 
* * @param matrix matrix with columns representing variables to correlate * @return correlation matrix */ public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { int nVars = matrix.getColumnDimension(); RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); for (int i = 0; i < nVars; i++) { for (int j = 0; j < i; j++) { double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); outMatrix.setEntry(i, j, corr); outMatrix.setEntry(j, i, corr); } outMatrix.setEntry(i, i, 1d); } return outMatrix; } /** * Computes the Kendall's Tau rank correlation matrix for the columns of * the input rectangular array. The columns of the array represent values * of variables to be correlated. * * @param matrix matrix with columns representing variables to correlate * @return correlation matrix */ public RealMatrix computeCorrelationMatrix(final double[][] matrix) { return computeCorrelationMatrix(new BlockRealMatrix(matrix)); } /** * Computes the Kendall's Tau rank correlation coefficient between the two arrays. * * @param xArray first data array * @param yArray second data array * @return Returns Kendall's Tau rank correlation coefficient for the two arrays * @throws DimensionMismatchException if the arrays lengths do not match */ public double correlation(final double[] xArray, final double[] yArray) throws DimensionMismatchException { if (xArray.length != yArray.length) { throw new DimensionMismatchException(xArray.length, yArray.length); } final int n = xArray.length; final long numPairs = sum(n - 1); @SuppressWarnings("unchecked") Pair[] pairs = new Pair[n]; for (int i = 0; i < n; i++) { pairs[i] = new Pair(xArray[i], yArray[i]); } Arrays.sort(pairs, new Comparator>() { /** {@inheritDoc} */ public int compare(Pair pair1, Pair pair2) { int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); return compareFirst != 0 ? 
compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); } }); long tiedXPairs = 0; long tiedXYPairs = 0; long consecutiveXTies = 1; long consecutiveXYTies = 1; Pair prev = pairs[0]; for (int i = 1; i < n; i++) { final Pair curr = pairs[i]; if (curr.getFirst().equals(prev.getFirst())) { consecutiveXTies++; if (curr.getSecond().equals(prev.getSecond())) { consecutiveXYTies++; } else { tiedXYPairs += sum(consecutiveXYTies - 1); consecutiveXYTies = 1; } } else { tiedXPairs += sum(consecutiveXTies - 1); consecutiveXTies = 1; tiedXYPairs += sum(consecutiveXYTies - 1); consecutiveXYTies = 1; } prev = curr; } tiedXPairs += sum(consecutiveXTies - 1); tiedXYPairs += sum(consecutiveXYTies - 1); long swaps = 0; @SuppressWarnings("unchecked") Pair[] pairsDestination = new Pair[n]; for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { for (int offset = 0; offset < n; offset += 2 * segmentSize) { int i = offset; final int iEnd = FastMath.min(i + segmentSize, n); int j = iEnd; final int jEnd = FastMath.min(j + segmentSize, n); int copyLocation = offset; while (i < iEnd || j < jEnd) { if (i < iEnd) { if (j < jEnd) { if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { pairsDestination[copyLocation] = pairs[i]; i++; } else { pairsDestination[copyLocation] = pairs[j]; j++; swaps += iEnd - i; } } else { pairsDestination[copyLocation] = pairs[i]; i++; } } else { pairsDestination[copyLocation] = pairs[j]; j++; } copyLocation++; } } final Pair[] pairsTemp = pairs; pairs = pairsDestination; pairsDestination = pairsTemp; } long tiedYPairs = 0; long consecutiveYTies = 1; prev = pairs[0]; for (int i = 1; i < n; i++) { final Pair curr = pairs[i]; if (curr.getSecond().equals(prev.getSecond())) { consecutiveYTies++; } else { tiedYPairs += sum(consecutiveYTies - 1); consecutiveYTies = 1; } prev = curr; } tiedYPairs += sum(consecutiveYTies - 1); final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; final double 
nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs); return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied); } /** * Returns the sum of the number from 1 .. n according to Gauss' summation formula: * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \] * * @param n the summation end * @return the sum of the number from 1 to n */ private static long sum(long n) { return n * (n + 1) / 2l; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy