All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.function.categorization.FisherLinearDiscriminantBinaryCategorizer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                FisherLinearDiscriminantBinaryCategorizer.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright October 9, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.method.ReceiverOperatingCharacteristic;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;

/**
 * A Fisher Linear Discriminant classifier, which creates an optimal linear
 * separating plane between two Gaussian classes of different covariances.
 *
 * @author Kevin R. Dixon
 * @since  2.0
 */
@PublicationReference(
    author="Wikipedia",
    title="Linear discriminant analysis",
    type=PublicationType.WebPage,
    year=2009,
    url="http://en.wikipedia.org/wiki/Linear_discriminant_analysis#Fisher.27s_linear_discriminant"
)
public class FisherLinearDiscriminantBinaryCategorizer
    extends ScalarFunctionToBinaryCategorizerAdapter
{

    /**
     * Default constructor
     */
    public FisherLinearDiscriminantBinaryCategorizer()
    {
        this( (Vector) null, DEFAULT_THRESHOLD );
    }

    /** 
     * Creates a new of {@code FisherLinearDiscriminantBinaryCategorizer}.
     * 
     * @param   weightVector The weight vector.
     * @param   threshold The threshold.
     */
    public FisherLinearDiscriminantBinaryCategorizer(
        final Vector weightVector,
        final double threshold)
    {
        this(new LinearDiscriminant(weightVector),threshold);
    }

    /**
     * Creates a new of {@code FisherLinearDiscriminantBinaryCategorizer}.
     * 
     * @param   discriminant The linear discriminant to use.
     * @param   threshold The threshold.
     */
    public FisherLinearDiscriminantBinaryCategorizer(
        final LinearDiscriminant discriminant,
        final double threshold )
    {
        super(discriminant, threshold);
    }

    @Override
    public FisherLinearDiscriminantBinaryCategorizer clone()
    {
        return (FisherLinearDiscriminantBinaryCategorizer) super.clone();
    }

    /**
     * This class implements a closed form solver for the Fisher linear
     * discriminant binary categorizer.
     */
    public static class ClosedFormSolver
        extends AbstractCloneableSerializable
        implements SupervisedBatchLearner
    {
        /** The default covariance. */
        private double defaultCovariance;

        /**
         * Default constructor.
         */
        public ClosedFormSolver()
        {
            this( MultivariateGaussian.MaximumLikelihoodEstimator.DEFAULT_COVARIANCE );
        }

        /**
         * Creates a new {@code ClosedFormSolver}.
         * 
         * @param   defaultCovariance The default covariance.
         */
        public ClosedFormSolver(
            double defaultCovariance)
        {
            this.defaultCovariance = defaultCovariance;
        }

        public FisherLinearDiscriminantBinaryCategorizer learn(
            Collection> data)
        {
            return ClosedFormSolver.learn(data, this.defaultCovariance);
        }
        
        /**
         * Closed-form learning algorithm for a Fisher Linear Discriminant.
         * 
         * @param   data The data to learn the discriminant categorizer from.
         * @param   defaultCovariance The default covariance.
         * @return  A discriminant categorizer learned from the data.
         */
        public static FisherLinearDiscriminantBinaryCategorizer learn(
            Collection> data,
            final double defaultCovariance)
        {
            // Split the data into two classes based on their
            DefaultPair, LinkedList> pair =
                DatasetUtil.splitDatasets(data);
            LinkedList d1 = pair.getFirst();
            LinkedList d0 = pair.getSecond();

            // This is faster than estimating a MultivariateGaussian as
            // the Gaussian will automatically invert the covariance matrix
            // and cache that
            Pair r1 =
                MultivariateStatisticsUtil.computeMeanAndCovariance(d1);
            Vector m1 = r1.getFirst();
            Matrix c1 = r1.getSecond();

            Pair r0 =
                MultivariateStatisticsUtil.computeMeanAndCovariance(d0);
            Vector m0 = r0.getFirst();
            Matrix c0 = r0.getSecond();

            Matrix cinverse;
            if (defaultCovariance != 0.0)
            {
                int M = m0.getDimensionality();
                Matrix ci = MatrixFactory.getDefault().createIdentity(M, M).scale(defaultCovariance);
                cinverse = c0.plus(c1.plus(ci)).inverse();
            }
            else
            {
                cinverse = c0.plus(c1).inverse();
            }

            Vector weightVector = cinverse.times(m1.minus(m0));

            // Technically, the threshold is supposed to be zero, but we might
            // try to do better
            LinearDiscriminant discriminant =
                new LinearDiscriminant(weightVector);
            ArrayList> doubleData =
                new ArrayList>(data.size());
            for (InputOutputPair sample : data)
            {
                Double value = discriminant.evaluate(sample.getInput());
                doubleData.add(new DefaultInputOutputPair(
                    value, sample.getOutput()));
            }


            ReceiverOperatingCharacteristic roc =
                ReceiverOperatingCharacteristic.create(doubleData);

            ReceiverOperatingCharacteristic.Statistic stats =
                roc.computeStatistics();
            return new FisherLinearDiscriminantBinaryCategorizer(
                discriminant, stats.getOptimalThreshold().getClassifier().getThreshold() );
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy