// Class: org.apache.commons.math3.stat.correlation.KendallsCorrelation
// Listing taken from the virtdata-lib-realer artifact page (Maven / Gradle / Ivy).
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.stat.correlation;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import java.util.Arrays;
import java.util.Comparator;
/**
* Implementation of Kendall's Tau-b rank correlation.
*
* A pair of observations (x1, y1) and
* (x2, y2) are considered concordant if
* x1 < x2 and y1 < y2
* or x2 < x1 and y2 < y1.
* The pair is discordant if x1 < x2 and
* y2 < y1 or x2 < x1 and
* y1 < y2. If either x1 = x2
* or y1 = y2, the pair is neither concordant nor
* discordant.
*
* Kendall's Tau-b is defined as:
*
* taub = (nc - nd) / sqrt((n0 - n1) * (n0 - n2))
*
*
* where:
*
* - n0 = n * (n - 1) / 2
* - nc = Number of concordant pairs
* - nd = Number of discordant pairs
* - n1 = sum of ti * (ti - 1) / 2 for all i
* - n2 = sum of uj * (uj - 1) / 2 for all j
* - ti = Number of tied values in the ith group of ties in x
* - uj = Number of tied values in the jth group of ties in y
*
*
* This implementation uses the O(n log n) algorithm described in
* William R. Knight's 1966 paper "A Computer Method for Calculating
* Kendall's Tau with Ungrouped Data" in the Journal of the American
* Statistical Association.
*
* @see
* Kendall tau rank correlation coefficient (Wikipedia)
* @see A Computer
* Method for Calculating Kendall's Tau with Ungrouped Data
*
* @since 3.3
*/
public class KendallsCorrelation {
/** correlation matrix */
private final RealMatrix correlationMatrix;
/**
* Create a KendallsCorrelation instance without data.
*/
public KendallsCorrelation() {
correlationMatrix = null;
}
/**
* Create a KendallsCorrelation from a rectangular array
* whose columns represent values of variables to be correlated.
*
* @param data rectangular array with columns representing variables
* @throws IllegalArgumentException if the input data array is not
* rectangular with at least two rows and two columns.
*/
public KendallsCorrelation(double[][] data) {
this(MatrixUtils.createRealMatrix(data));
}
/**
* Create a KendallsCorrelation from a RealMatrix whose columns
* represent variables to be correlated.
*
* @param matrix matrix with columns representing variables to correlate
*/
public KendallsCorrelation(RealMatrix matrix) {
correlationMatrix = computeCorrelationMatrix(matrix);
}
/**
* Returns the correlation matrix.
*
* @return correlation matrix
*/
public RealMatrix getCorrelationMatrix() {
return correlationMatrix;
}
/**
* Computes the Kendall's Tau rank correlation matrix for the columns of
* the input matrix.
*
* @param matrix matrix with columns representing variables to correlate
* @return correlation matrix
*/
public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
int nVars = matrix.getColumnDimension();
RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
for (int i = 0; i < nVars; i++) {
for (int j = 0; j < i; j++) {
double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
outMatrix.setEntry(i, j, corr);
outMatrix.setEntry(j, i, corr);
}
outMatrix.setEntry(i, i, 1d);
}
return outMatrix;
}
/**
* Computes the Kendall's Tau rank correlation matrix for the columns of
* the input rectangular array. The columns of the array represent values
* of variables to be correlated.
*
* @param matrix matrix with columns representing variables to correlate
* @return correlation matrix
*/
public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
return computeCorrelationMatrix(new BlockRealMatrix(matrix));
}
/**
* Computes the Kendall's Tau rank correlation coefficient between the two arrays.
*
* @param xArray first data array
* @param yArray second data array
* @return Returns Kendall's Tau rank correlation coefficient for the two arrays
* @throws DimensionMismatchException if the arrays lengths do not match
*/
public double correlation(final double[] xArray, final double[] yArray)
throws DimensionMismatchException {
if (xArray.length != yArray.length) {
throw new DimensionMismatchException(xArray.length, yArray.length);
}
final int n = xArray.length;
final long numPairs = sum(n - 1);
@SuppressWarnings("unchecked")
Pair[] pairs = new Pair[n];
for (int i = 0; i < n; i++) {
pairs[i] = new Pair(xArray[i], yArray[i]);
}
Arrays.sort(pairs, new Comparator>() {
/** {@inheritDoc} */
public int compare(Pair pair1, Pair pair2) {
int compareFirst = pair1.getFirst().compareTo(pair2.getFirst());
return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond());
}
});
long tiedXPairs = 0;
long tiedXYPairs = 0;
long consecutiveXTies = 1;
long consecutiveXYTies = 1;
Pair prev = pairs[0];
for (int i = 1; i < n; i++) {
final Pair curr = pairs[i];
if (curr.getFirst().equals(prev.getFirst())) {
consecutiveXTies++;
if (curr.getSecond().equals(prev.getSecond())) {
consecutiveXYTies++;
} else {
tiedXYPairs += sum(consecutiveXYTies - 1);
consecutiveXYTies = 1;
}
} else {
tiedXPairs += sum(consecutiveXTies - 1);
consecutiveXTies = 1;
tiedXYPairs += sum(consecutiveXYTies - 1);
consecutiveXYTies = 1;
}
prev = curr;
}
tiedXPairs += sum(consecutiveXTies - 1);
tiedXYPairs += sum(consecutiveXYTies - 1);
long swaps = 0;
@SuppressWarnings("unchecked")
Pair[] pairsDestination = new Pair[n];
for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
for (int offset = 0; offset < n; offset += 2 * segmentSize) {
int i = offset;
final int iEnd = FastMath.min(i + segmentSize, n);
int j = iEnd;
final int jEnd = FastMath.min(j + segmentSize, n);
int copyLocation = offset;
while (i < iEnd || j < jEnd) {
if (i < iEnd) {
if (j < jEnd) {
if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) {
pairsDestination[copyLocation] = pairs[i];
i++;
} else {
pairsDestination[copyLocation] = pairs[j];
j++;
swaps += iEnd - i;
}
} else {
pairsDestination[copyLocation] = pairs[i];
i++;
}
} else {
pairsDestination[copyLocation] = pairs[j];
j++;
}
copyLocation++;
}
}
final Pair[] pairsTemp = pairs;
pairs = pairsDestination;
pairsDestination = pairsTemp;
}
long tiedYPairs = 0;
long consecutiveYTies = 1;
prev = pairs[0];
for (int i = 1; i < n; i++) {
final Pair curr = pairs[i];
if (curr.getSecond().equals(prev.getSecond())) {
consecutiveYTies++;
} else {
tiedYPairs += sum(consecutiveYTies - 1);
consecutiveYTies = 1;
}
prev = curr;
}
tiedYPairs += sum(consecutiveYTies - 1);
final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
}
/**
* Returns the sum of the number from 1 .. n according to Gauss' summation formula:
* \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
*
* @param n the summation end
* @return the sum of the number from 1 to n
*/
private static long sum(long n) {
return n * (n + 1) / 2l;
}
}