smile.classification.OneVersusRest Maven / Gradle / Ivy
/*
* Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.classification;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;
/**
* One-vs-rest (or one-vs-all) strategy for reducing the problem of
* multiclass classification to multiple binary classification problems.
* It involves training a single classifier per class, with the samples
* of that class as positive samples and all other samples as negatives.
* This strategy requires the base classifiers to produce a real-valued
* confidence score for its decision, rather than just a class label;
* discrete class labels alone can lead to ambiguities, where multiple
* classes are predicted for a single sample.
*
* Making decisions means applying all classifiers to an unseen sample
* x and predicting the label k for which the corresponding classifier
* reports the highest confidence score.
*
* Although this strategy is popular, it is a heuristic that suffers
* from several problems. Firstly, the scale of the confidence values
* may differ between the binary classifiers. Second, even if the class
* distribution is balanced in the training set, the binary classification
* learners see unbalanced distributions because typically the set of
* negatives they see is much larger than the set of positives.
*
* @author Haifeng Li
*/
public class OneVersusRest extends AbstractClassifier {
private static final long serialVersionUID = 2L;
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(OneVersusRest.class);
/** The number of classes. */
private final int k;
/** The binary classifier. */
private final Classifier[] classifiers;
/** The probability estimation by Platt scaling. */
private final PlattScaling[] platt;
/**
* Constructor.
* @param classifiers the binary classifier for each one-vs-rest case.
* @param platt Platt scaling models.
*/
public OneVersusRest(Classifier[] classifiers, PlattScaling[] platt) {
this(classifiers, platt, IntSet.of(classifiers.length));
}
/**
* Constructor.
* @param classifiers the binary classifier for each one-vs-rest case.
* @param platt Platt scaling models.
* @param labels the class label encoder.
*/
public OneVersusRest(Classifier[] classifiers, PlattScaling[] platt, IntSet labels) {
super(labels);
this.classifiers = classifiers;
this.platt = platt;
this. k = classifiers.length;
}
/**
* Fits a multi-class model with binary classifiers.
* Use +1 and -1 as positive and negative class labels.
* @param x the training samples.
* @param y the training labels.
* @param trainer the lambda to train binary classifiers.
* @param the data type.
* @return the model.
*/
public static OneVersusRest fit(T[] x, int[] y, BiFunction> trainer) {
return fit(x, y, +1, -1, trainer);
}
/**
* Fits a multi-class model with binary classifiers.
* @param x the training samples.
* @param y the training labels.
* @param pos the class label for one case.
* @param neg the class label for rest cases.
* @param trainer the lambda to train binary classifiers.
* @param the data type.
* @return the model.
*/
@SuppressWarnings("unchecked")
public static OneVersusRest fit(T[] x, int[] y, int pos, int neg, BiFunction> trainer) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
ClassLabels codec = ClassLabels.fit(y);
int k = codec.k;
if (k <= 2) {
throw new IllegalArgumentException(String.format("Only %d classes", k));
}
int n = x.length;
int[] labels = codec.y;
Classifier[] classifiers = new Classifier[k];
PlattScaling[] platts = new PlattScaling[k];
IntStream.range(0, k).parallel().forEach(i -> {
int[] yi = new int[n];
for (int j = 0; j < n; j++) {
yi[j] = labels[j] == i ? pos : neg;
}
classifiers[i] = trainer.apply(x, yi);
try {
platts[i] = PlattScaling.fit(classifiers[i], x, yi);
} catch (UnsupportedOperationException ex) {
logger.info("The classifier doesn't support score function. Don't fit Platt scaling.");
}
});
return new OneVersusRest<>(classifiers, platts[0] == null ? null : platts);
}
/**
* Fits a multi-class model with binary data frame classifiers.
* @param formula a symbolic description of the model to be fitted.
* @param data the data frame of the explanatory and response variables.
* @param trainer the lambda to train binary classifiers.
* @return the model.
*/
public static DataFrameClassifier fit(Formula formula, DataFrame data, BiFunction trainer) {
Tuple[] x = data.stream().toArray(Tuple[]::new);
int[] y = formula.y(data).toIntArray();
OneVersusRest model = fit(x, y, 1, 0, (Tuple[] rows, int[] labels) -> {
DataFrame df = DataFrame.of(Arrays.asList(rows));
return trainer.apply(formula, df);
});
StructType schema = formula.x(data.get(0)).schema();
return new DataFrameClassifier() {
@Override
public int numClasses() {
return model.numClasses();
}
@Override
public int[] classes() {
return model.classes();
}
@Override
public int predict(Tuple x) {
return model.predict(x);
}
@Override
public Formula formula() {
return formula;
}
@Override
public StructType schema() {
return schema;
}
};
}
@Override
public int predict(T x) {
int y = 0;
double maxf = Double.NEGATIVE_INFINITY;
for (int i = 0; i < k; i++) {
double f = platt[i].scale(classifiers[i].score(x));
if (f > maxf) {
y = i;
maxf = f;
}
}
return classes.valueOf(y);
}
@Override
public boolean soft() {
return true;
}
@Override
public int predict(T x, double[] posteriori) {
if (platt == null) {
throw new UnsupportedOperationException("Platt scaling is not available");
}
for (int i = 0; i < k; i++) {
posteriori[i] = platt[i].scale(classifiers[i].score(x));
}
MathEx.unitize1(posteriori);
return classes.valueOf(MathEx.whichMax(posteriori));
}
}