All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MultivariateDiscriminant.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright April 26, 2006, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviewResponse;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableDifferentiableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;

/**
 * A linear vector function that evaluates to {@code discriminant * input}.
 * The elements of the discriminant matrix are exposed as the parameters of
 * the function, so that learning algorithms can vectorize the function and
 * differentiate it with respect to either its input or its parameters.
 *
 * @author Kevin R. Dixon
 * @since 1.0
 */
@CodeReview(
    reviewer="Justin Basilico",
    date="2006-10-06",
    changesNeeded=true,
    comments={
        "Can you just add a comment for why the differentiation code is correct?",
        "Otherwise, class looks fine."
    },
    response=@CodeReviewResponse(
        respondent="Kevin R. Dixon",
        date="2006-10-06",
        moreChangesNeeded=false,
        comments="Added in-code comment describing the derivation of the differentiation formulae."
    )
)
public class MultivariateDiscriminant
    extends AbstractCloneableSerializable
    implements VectorizableDifferentiableVectorFunction,
    VectorInputEvaluator,
    VectorOutputEvaluator,
    GradientDescendable
{

    /** Internal matrix to premultiply input vectors by. */
    private Matrix discriminant;

    /**
     * Default constructor. Delegates to the two-argument constructor with
     * one input and one output.
     */
    public MultivariateDiscriminant()
    {
        this( 1, 1 );
    }

    /**
     * Creates a new MultivariateDiscriminant whose discriminant is the
     * matrix returned by the default factory's
     * {@code createIdentity(numOutputs, numInputs)}.
     *
     * @param numInputs
     * Number of inputs of the function (number of matrix columns)
     * @param numOutputs
     * Number of outputs of the function (number of matrix rows)
     */
    public MultivariateDiscriminant(
        final int numInputs,
        final int numOutputs )
    {
        this( MatrixFactory.getDefault().createIdentity(numOutputs, numInputs) );
    }
    
    /**
     * Creates a new instance of MultivariateDiscriminant that uses the given
     * matrix directly (the matrix is stored by reference, not copied).
     *
     * @param discriminant internal matrix to premultiply input vectors by.
     */
    public MultivariateDiscriminant(
        final Matrix discriminant )
    {
        this.setDiscriminant( discriminant );
    }

    /**
     * Copy constructor. The new instance receives a clone of the other
     * instance's discriminant, so the two do not share the matrix.
     *
     * @param other MultivariateDiscriminant to copy
     */
    public MultivariateDiscriminant(
        final MultivariateDiscriminant other )
    {
        this( other.getDiscriminant().clone() );
    }

    /**
     * Creates a copy of this object whose discriminant is a clone of this
     * object's discriminant.
     *
     * @return copy of this object with its own discriminant matrix
     */
    @Override
    public MultivariateDiscriminant clone()
    {
        MultivariateDiscriminant clone =
            (MultivariateDiscriminant) super.clone();
        clone.setDiscriminant( this.getDiscriminant().clone() );
        return clone;
    }

    /**
     * Getter for discriminant. Returns the internal matrix itself, not a
     * copy.
     *
     * @return internal matrix to premultiply input vectors by
     */
    public Matrix getDiscriminant()
    {
        return this.discriminant;
    }

    /**
     * Setter for discriminant. The matrix is stored by reference.
     *
     * @param discriminant internal matrix to premultiply input vectors by
     */
    protected void setDiscriminant(
        final Matrix discriminant )
    {
        this.discriminant = discriminant;
    }

    /**
     * Converts the discriminant into a parameter vector by delegating to the
     * matrix's own {@code convertToVector()}.
     * <p>
     * NOTE(review): the element ordering is whatever the Matrix
     * implementation produces. {@link #computeParameterGradient(Matrix, Vector)}
     * lays out its columns for column-stacked parameters, so the ordering is
     * expected to be column-stacked; confirm against the Matrix
     * documentation.
     *
     * @return stacked Vector representing the discriminant
     */
    @Override
    public Vector convertToVector()
    {
        return this.discriminant.convertToVector();
    }

    /**
     * Uploads the discriminant from a stacked vector of parameters by
     * delegating to the matrix's own {@code convertFromVector()}. The
     * existing discriminant matrix is updated in place.
     *
     * @param parameters stacked version of the discriminant, in the same
     * ordering that {@link #convertToVector()} produces
     */
    @Override
    public void convertFromVector(
        Vector parameters )
    {
        this.discriminant.convertFromVector( parameters );
    }

    /**
     * Evaluates the function by premultiplying the input by the
     * discriminant.
     *
     * @param input Vector to multiply by the discriminant
     * @return the product {@code discriminant * input}
     */
    @Override
    public Vector evaluate(
        final Vector input )
    {
        return this.discriminant.times( input );
    }

    /**
     * Differentiates the output with respect to the input. Because the
     * function is {@code discriminant * input}, the derivative does not
     * depend on the input and is the discriminant itself.
     *
     * @param input input about which to differentiate (not used)
     * @return the internal discriminant matrix itself, not a copy
     */
    @Override
    public Matrix differentiate(
        final Vector input )
    {
        return this.getDiscriminant();
    }

    /**
     * Computes the derivative of the output with respect to the parameters
     * (the elements of the discriminant) for the given input.
     *
     * @param input input about which to compute the gradient
     * @return gradient matrix, see
     * {@link #computeParameterGradient(Matrix, Vector)}
     */
    @Override
    public Matrix computeParameterGradient(
        final Vector input )
    {
        return computeParameterGradient(this.discriminant, input);
    }

    /**
     * Computes the parameter gradient of the given matrix post-multiplied
     * by the input Vector. For an M-by-N matrix the result has M rows and
     * M*N columns, one column per matrix element.
     *
     * @param matrix
     * Matrix to pre-multiply the input Vector.
     * @param input
     * Vector to post-multiply the Matrix. Its dimensionality must equal the
     * number of columns of the matrix.
     * @return
     * Derivative of the outputs (matrix times input) with respect to the
     * matrix elements, with the matrix elements taken in column-stacked
     * order.
     * @throws IllegalArgumentException
     * If the dimensionality of the input does not equal the number of
     * columns of the matrix.
     */
    public static Matrix computeParameterGradient(
        final Matrix matrix,
        final Vector input )
    {
        final int numRows = matrix.getNumRows();
        final int numColumns = matrix.getNumColumns();

        // Without this check, an input that is too long would silently
        // produce a gradient that ignores its trailing elements.
        final int inputDimensionality = input.getDimensionality();
        if (inputDimensionality != numColumns)
        {
            throw new IllegalArgumentException(
                "Input dimensionality (" + inputDimensionality
                + ") must equal the number of matrix columns ("
                + numColumns + ")" );
        }

        final Matrix gradient = MatrixFactory.getSparseDefault().createMatrix(
            numRows, numRows * numColumns );

        // Derivation of this gradient assumes that the parameters are
        // column-stacked from the underlying matrices: parameter
        // (j * numRows + i) is matrix element (i, j). Output i is the sum
        // over j of element (i, j) times x_j, so its derivative with
        // respect to element (i, j) is x_j, and zero with respect to every
        // element of any other row.
        // This makes a gradient which looks like (3 rows, 2 columns):
        // [ x0 0   0   x1  0   0]
        // [ 0  x0  0   0   x1  0]
        // [ 0  0   x0  0   0   x1]
        // And so forth...
        for (int j = 0; j < numColumns; j++)
        {
            final double inputValue = input.getElement( j );
            if (inputValue == 0.0)
            {
                // Leave these entries at the freshly created matrix's
                // default, which the layout above takes to be zero, rather
                // than writing explicit zeros into the sparse matrix.
                continue;
            }

            final int blockStart = j * numRows;
            for (int i = 0; i < numRows; i++)
            {
                gradient.setElement( i, blockStart + i, inputValue );
            }
        }

        return gradient;

    }


    /**
     * Returns the String representation of the discriminant matrix.
     *
     * @return result of {@code toString()} on the discriminant
     */
    @Override
    public String toString()
    {
        return this.getDiscriminant().toString();
    }

    /**
     * Gets the dimensionality of the inputs to this function.
     *
     * @return number of columns of the discriminant
     */
    @Override
    public int getInputDimensionality()
    {
        return this.getDiscriminant().getNumColumns();
    }

    /**
     * Gets the dimensionality of the outputs of this function.
     *
     * @return number of rows of the discriminant
     */
    @Override
    public int getOutputDimensionality()
    {
        return this.getDiscriminant().getNumRows();
    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy