All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.classification.OneVersusRest Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import java.io.Serial;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * One-vs-rest (or one-vs-all) strategy for reducing the problem of
 * multiclass classification to multiple binary classification problems.
 * It involves training a single classifier per class, with the samples
 * of that class as positive samples and all other samples as negatives.
 * This strategy requires the base classifiers to produce a real-valued
 * confidence score for its decision, rather than just a class label;
 * discrete class labels alone can lead to ambiguities, where multiple
 * classes are predicted for a single sample.
 * 

* Making decisions means applying all classifiers to an unseen sample * x and predicting the label k for which the corresponding classifier * reports the highest confidence score. *

* Although this strategy is popular, it is a heuristic that suffers * from several problems. Firstly, the scale of the confidence values * may differ between the binary classifiers. Second, even if the class * distribution is balanced in the training set, the binary classification * learners see unbalanced distributions because typically the set of * negatives they see is much larger than the set of positives. * * @param the data type of model input objects. * * @author Haifeng Li */ public class OneVersusRest extends AbstractClassifier { @Serial private static final long serialVersionUID = 2L; private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(OneVersusRest.class); /** The number of classes. */ private final int k; /** The binary classifier. */ private final Classifier[] classifiers; /** The probability estimation by Platt scaling. */ private final PlattScaling[] platt; /** * Constructor. * @param classifiers the binary classifier for each one-vs-rest case. * @param platt Platt scaling models. */ public OneVersusRest(Classifier[] classifiers, PlattScaling[] platt) { this(classifiers, platt, IntSet.of(classifiers.length)); } /** * Constructor. * @param classifiers the binary classifier for each one-vs-rest case. * @param platt Platt scaling models. * @param labels the class label encoder. */ public OneVersusRest(Classifier[] classifiers, PlattScaling[] platt, IntSet labels) { super(labels); this.classifiers = classifiers; this.platt = platt; this. k = classifiers.length; } /** * Fits a multi-class model with binary classifiers. * Use +1 and -1 as positive and negative class labels. * @param x the training samples. * @param y the training labels. * @param trainer the lambda to train binary classifiers. * @param the data type. * @return the model. */ public static OneVersusRest fit(T[] x, int[] y, BiFunction> trainer) { return fit(x, y, +1, -1, trainer); } /** * Fits a multi-class model with binary classifiers. 
* @param x the training samples. * @param y the training labels. * @param pos the class label for one case. * @param neg the class label for rest cases. * @param trainer the lambda to train binary classifiers. * @param the data type. * @return the model. */ @SuppressWarnings("unchecked") public static OneVersusRest fit(T[] x, int[] y, int pos, int neg, BiFunction> trainer) { if (x.length != y.length) { throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length)); } ClassLabels codec = ClassLabels.fit(y); int k = codec.k; if (k <= 2) { throw new IllegalArgumentException(String.format("Only %d classes", k)); } int n = x.length; int[] labels = codec.y; Classifier[] classifiers = new Classifier[k]; PlattScaling[] platts = new PlattScaling[k]; IntStream.range(0, k).parallel().forEach(i -> { int[] yi = new int[n]; for (int j = 0; j < n; j++) { yi[j] = labels[j] == i ? pos : neg; } classifiers[i] = trainer.apply(x, yi); try { platts[i] = PlattScaling.fit(classifiers[i], x, yi); } catch (UnsupportedOperationException ex) { logger.info("The classifier doesn't support score function. Don't fit Platt scaling."); } }); return new OneVersusRest<>(classifiers, platts[0] == null ? null : platts); } /** * Fits a multi-class model with binary data frame classifiers. * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param trainer the lambda to train binary classifiers. * @return the model. 
*/ public static DataFrameClassifier fit(Formula formula, DataFrame data, BiFunction trainer) { Tuple[] x = data.stream().toArray(Tuple[]::new); int[] y = formula.y(data).toIntArray(); OneVersusRest model = fit(x, y, 1, 0, (Tuple[] rows, int[] labels) -> { DataFrame df = DataFrame.of(Arrays.asList(rows)); return trainer.apply(formula, df); }); StructType schema = formula.x(data.get(0)).schema(); return new DataFrameClassifier() { @Override public int numClasses() { return model.numClasses(); } @Override public int[] classes() { return model.classes(); } @Override public int predict(Tuple x) { return model.predict(x); } @Override public Formula formula() { return formula; } @Override public StructType schema() { return schema; } }; } @Override public int predict(T x) { int y = 0; double maxf = Double.NEGATIVE_INFINITY; for (int i = 0; i < k; i++) { double f = platt[i].scale(classifiers[i].score(x)); if (f > maxf) { y = i; maxf = f; } } return classes.valueOf(y); } @Override public boolean soft() { return true; } @Override public int predict(T x, double[] posteriori) { if (platt == null) { throw new UnsupportedOperationException("Platt scaling is not available"); } for (int i = 0; i < k; i++) { posteriori[i] = platt[i].scale(classifiers[i].score(x)); } MathEx.unitize1(posteriori); return classes.valueOf(MathEx.whichMax(posteriori)); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy