// smile.feature.selection.SumSquaresRatio (Maven / Gradle / Ivy)
// The newest version!
/*
* Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.feature.selection;
import java.util.Arrays;
import java.util.stream.IntStream;
import smile.classification.ClassLabels;
import smile.data.DataFrame;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
/**
* The ratio of between-groups to within-groups sum of squares is a univariate
* feature ranking metric, which can be used as a feature selection criterion
* for multi-class classification problems. For each variable j, this ratio is
* <code>BSS(j) / WSS(j) = Σ<sub>i</sub>Σ<sub>k</sub> I(y<sub>i</sub> = k)(x<sub>kj</sub> - x<sub>·j</sub>)<sup>2</sup> / Σ<sub>i</sub>Σ<sub>k</sub> I(y<sub>i</sub> = k)(x<sub>ij</sub> - x<sub>kj</sub>)<sup>2</sup></code>,
* where I(·) is the indicator function, x<sub>·j</sub> denotes the average of variable j across all
* samples, x<sub>kj</sub> denotes the average of variable j across samples
* belonging to class k, and x<sub>ij</sub> is the value of variable j of sample i.
* Features with larger sum squares ratios separate the classes better and are therefore preferred for classification.
*
* References
*
* - S. Dudoit, J. Fridlyand and T. Speed. Comparison of discrimination methods for the classification of tumors using gene expression data. J Am Stat Assoc, 97:77-87, 2002.
*
*
* @author Haifeng Li
*/
public class SumSquaresRatio implements Comparable {
/** The feature name. */
public final String feature;
/** Sum squares ratio. */
public final double ssr;
/**
* Constructor.
* @param feature The feature name.
* @param ssr Sum squares ratio.
*/
public SumSquaresRatio(String feature, double ssr) {
this.feature = feature;
this.ssr = ssr;
}
@Override
public int compareTo(SumSquaresRatio other) {
return Double.compare(ssr, other.ssr);
}
@Override
public String toString() {
return String.format("SumSquaresRatio(%s, %.4f)", feature, ssr);
}
/**
* Calculates the sum squares ratio of numeric variables.
*
* @param data the data frame of the explanatory and response variables.
* @param clazz the column name of class labels.
* @return the sum squares ratio.
*/
public static SumSquaresRatio[] fit(DataFrame data, String clazz) {
BaseVector, ?, ?> y = data.column(clazz);
ClassLabels codec = ClassLabels.fit(y);
if (codec.k < 2) {
throw new UnsupportedOperationException("Invalid number of classes: " + codec.k);
}
int n = data.nrow();
int k = codec.k;
int[] nc = new int[k];
double[] condmu = new double[k];
for (int i = 0; i < n; i++) {
int yi = codec.y[i];
nc[yi]++;
}
StructType schema = data.schema();
return IntStream.range(0, schema.length()).mapToObj(j -> {
StructField field = schema.field(j);
if (field.isNumeric()) {
BaseVector, ?, ?> xj = data.column(j);
double mu = 0.0;
Arrays.fill(condmu, 0.0);
for (int i = 0; i < n; i++) {
int yi = codec.y[i];
double xij = xj.getDouble(i);
mu += xij;
condmu[yi] += xij;
}
mu /= n;
for (int i = 0; i < k; i++) {
condmu[i] /= nc[i];
}
double wss = 0.0;
double bss = 0.0;
for (int i = 0; i < n; i++) {
int yi = codec.y[i];
double xij = xj.getDouble(i);
bss += MathEx.pow2(condmu[yi] - mu);
wss += MathEx.pow2(xij - condmu[yi]);
}
return new SumSquaresRatio(field.name, bss / wss);
} else {
return null;
}
}).filter(s2n -> s2n != null && !s2n.feature.equals(clazz)).toArray(SumSquaresRatio[]::new);
}
}