All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.function.vector.GeneralizedLinearModel Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                GeneralizedLinearModel.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright February 28, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.function.scalar.IdentityScalarFunction;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;

/**
 * A VectorizableVectorFunction that is a matrix multiply followed by a
 * VectorFunction... a no-hidden-layer neural network.
 * <p>
 * Evaluation computes
 * {@code squashingFunction.evaluate( discriminant.evaluate( input ) )}.
 * The parameter vector exposed through {@link #convertToVector()} and
 * {@link #convertFromVector(Vector)} is that of the discriminant only; the
 * squashing function contributes no parameters here.
 *
 * @author Kevin R. Dixon
 * @since  1.0
 *
 */
@PublicationReference(
    author="Wikipedia",
    title="Generalized linear model",
    type=PublicationType.WebPage,
    year=2011,
    url="http://en.wikipedia.org/wiki/Generalized_linear_model"
)
public class GeneralizedLinearModel
    extends AbstractCloneableSerializable
    implements VectorizableVectorFunction,
    VectorInputEvaluator,
    VectorOutputEvaluator
{

    /**
     * Discriminant that multiplies an input by the internal matrix
     */
    private MultivariateDiscriminant discriminant;

    /**
     * VectorFunction that is applied to the output of the matrix multiply
     */
    private VectorFunction squashingFunction;

    /**
     * Default constructor. Creates a 1x1 model with an identity function for
     * the output.
     */
    public GeneralizedLinearModel()
    {
        this(1, 1, new IdentityScalarFunction());
    }

    /**
     * Creates a new instance of GeneralizedLinearModel
     * 
     * @param numInputs
     * Number of inputs of the function (number of matrix columns)
     * @param numOutputs
     * Number of outputs of the function (number of matrix rows)
     * @param scalarFunction 
     * Function to apply to each output
     */
    public GeneralizedLinearModel(
        int numInputs,
        int numOutputs,
        UnivariateScalarFunction scalarFunction )
    {
        this( new MultivariateDiscriminant( numInputs, numOutputs ),
            new ElementWiseVectorFunction( scalarFunction ) );
    }
    
    /**
     * Creates a new instance of GeneralizedLinearModel. Both arguments are
     * stored by reference (no copies are made).
     * 
     * @param discriminant
     * Discriminant that multiplies an input by the internal matrix
     * @param squashingFunction 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public GeneralizedLinearModel(
        MultivariateDiscriminant discriminant,
        VectorFunction squashingFunction )
    {
        super();
        
        this.setDiscriminant( discriminant );
        this.setSquashingFunction( squashingFunction );
    }

    /**
     * Copy constructor. The discriminant of {@code other} is cloned, while the
     * squashing function is shared by reference with {@code other}.
     * 
     * @param other GeneralizedLinearModel to copy
     */
    public GeneralizedLinearModel(
        GeneralizedLinearModel other )
    {
        // Use ObjectUtil.cloneSafe, exactly as clone() does, instead of
        // calling clone() directly on a discriminant that may not be set.
        // NOTE(review): cloneSafe is presumed to pass null through -- confirm.
        this( ObjectUtil.cloneSafe( other.getDiscriminant() ),
            other.getSquashingFunction() );
    }

    /**
     * Creates a new instance of GeneralizedLinearModel
     * 
     * @param matrixMultiply
     * Discriminant that multiplies an input by the internal matrix
     * @param scalarSquashingFunction
     * scalar function that is applied to the output of the matrix multiply;
     * it is wrapped in an ElementWiseVectorFunction
     */
    public GeneralizedLinearModel(
        MultivariateDiscriminant matrixMultiply,
        UnivariateScalarFunction scalarSquashingFunction )
    {
        this( matrixMultiply, new ElementWiseVectorFunction(
            scalarSquashingFunction ) );
    }

    /**
     * Getter for discriminant
     * @return 
     * Discriminant that multiplies an input by the internal matrix
     */
    public MultivariateDiscriminant getDiscriminant()
    {
        return this.discriminant;
    }

    /**
     * Setter for discriminant
     * @param discriminant
     * Discriminant that multiplies an input by the internal matrix
     */
    public void setDiscriminant(
        MultivariateDiscriminant discriminant )
    {
        this.discriminant = discriminant;
    }

    /**
     * Getter for squashingFunction
     * @return 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public VectorFunction getSquashingFunction()
    {
        return this.squashingFunction;
    }

    /**
     * Setter for squashingFunction
     * @param squashingFunction 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public void setSquashingFunction(
        VectorFunction squashingFunction )
    {
        this.squashingFunction = squashingFunction;
    }

    /**
     * Converts the parameters of this model into a Vector by delegating to
     * the discriminant. The squashing function is not included.
     * 
     * @return
     * Result of the discriminant's {@code convertToVector()}
     */
    public Vector convertToVector()
    {
        return this.getDiscriminant().convertToVector();
    }

    /**
     * Uploads the given parameters into this model by delegating to the
     * discriminant. The squashing function is not modified.
     * 
     * @param parameters
     * Parameters handed to the discriminant's {@code convertFromVector}
     */
    public void convertFromVector(
        Vector parameters )
    {
        this.getDiscriminant().convertFromVector( parameters );
    }

    /**
     * Evaluates the model: the discriminant is applied to the input and the
     * squashing function is applied to that result.
     * 
     * @param input
     * Input to hand to the discriminant
     * @return
     * Output of the squashing function
     */
    public Vector evaluate(
        Vector input )
    {
        return this.squashingFunction.evaluate(
            this.discriminant.evaluate( input ) );
    }

    /**
     * Clones this model. The discriminant is cloned explicitly; the squashing
     * function is not cloned by this method.
     * NOTE(review): the squashing function is presumably shared with the
     * original after super.clone() -- confirm against
     * AbstractCloneableSerializable.
     * 
     * @return
     * Clone of this model with its own copy of the discriminant
     */
    @Override
    public GeneralizedLinearModel clone()
    {
        GeneralizedLinearModel clone =
            (GeneralizedLinearModel) super.clone();
        clone.setDiscriminant(
            ObjectUtil.cloneSafe(this.getDiscriminant()) );
        return clone;
    }

    /**
     * Describes the model as the class of the squashing function followed,
     * on separate lines, by the weights held by the discriminant. Unset
     * (null) members are printed as "null" rather than causing an exception.
     * 
     * @return
     * Human-readable description of this model
     */
    @Override
    public String toString()
    {
        final VectorFunction squashing = this.getSquashingFunction();
        final MultivariateDiscriminant weights = this.getDiscriminant();

        // A line break separates the class name from the "Weights:" label;
        // without it the two run together in the output.
        return "Squashing: "
            + ((squashing == null) ? null : squashing.getClass())
            + "\nWeights:\n"
            + ((weights == null) ? null : weights.getDiscriminant());
    }

    /**
     * Gets the input dimensionality, as reported by the discriminant.
     * 
     * @return
     * Result of the discriminant's {@code getInputDimensionality()}
     */
    public int getInputDimensionality()
    {
        return this.getDiscriminant().getInputDimensionality();
    }

    /**
     * Gets the output dimensionality, as reported by the discriminant.
     * 
     * @return
     * Result of the discriminant's {@code getOutputDimensionality()}
     */
    public int getOutputDimensionality()
    {
        return this.getDiscriminant().getOutputDimensionality();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy