All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.function.vector.SquashedMatrixMultiplyVectorFunction Maven / Gradle / Ivy

/*
 * File:                SquashedMatrixMultiplyVectorFunction.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright February 28, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;

/**
 * A VectorizableVectorFunction that is a matrix multiply followed by a
 * VectorFunction (the "squashing" function), i.e. a no-hidden-layer neural
 * network: {@code f(x) = squash(W * x)}.
 * <p>
 * Only the weights of the matrix multiply are exposed as parameters through
 * {@link #convertToVector()} / {@link #convertFromVector(Vector)}; the
 * squashing function's state (if any) is not part of the parameter vector.
 *
 * @author Kevin R. Dixon
 * @since  1.0
 *
 */
public class SquashedMatrixMultiplyVectorFunction
    extends AbstractCloneableSerializable
    implements VectorizableVectorFunction,
    VectorInputEvaluator,
    VectorOutputEvaluator
{

    /**
     * GradientDescendable that multiplies an input by the internal matrix
     */
    private MatrixMultiplyVectorFunction matrixMultiply;

    /**
     * VectorFunction that is applied to the output of the matrix multiply
     */
    private VectorFunction squashingFunction;

    /**
     * Default constructor. Creates a function with one input and one output
     * whose output is squashed element-wise by an {@link AtanFunction}.
     */
    public SquashedMatrixMultiplyVectorFunction()
    {
        this( 1, 1, new AtanFunction() );
    }

    /**
     * Creates a new instance of SquashedMatrixMultiplyVectorFunction that
     * applies the given scalar function element-wise to the result of the
     * matrix multiply.
     * 
     * @param numInputs
     * Number of inputs of the function (number of matrix columns)
     * @param numOutputs
     * Number of outputs of the function (number of matrix rows)
     * @param scalarFunction 
     * Function to apply to each output
     */
    public SquashedMatrixMultiplyVectorFunction(
        int numInputs,
        int numOutputs,
        UnivariateScalarFunction scalarFunction )
    {
        this( new MatrixMultiplyVectorFunction( numInputs, numOutputs ),
            new ElementWiseVectorFunction( scalarFunction ) );
    }
    
    /**
     * Creates a new instance of SquashedMatrixMultiplyVectorFunction.
     * The arguments are stored by reference (not copied).
     *
     * @param matrixMultiply 
     * GradientDescendable that multiplies an input by the internal matrix
     * @param squashingFunction 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public SquashedMatrixMultiplyVectorFunction(
        MatrixMultiplyVectorFunction matrixMultiply,
        VectorFunction squashingFunction )
    {
        this.setMatrixMultiply( matrixMultiply );
        this.setSquashingFunction( squashingFunction );
    }

    /**
     * Copy constructor. The matrix multiply of {@code other} is cloned; the
     * squashing function is shared by reference with {@code other}.
     * A null matrix multiply on {@code other} is copied as null, matching
     * the behavior of {@link #clone()}.
     *
     * @param other SquashedMatrixMultiplyVectorFunction to copy
     */
    public SquashedMatrixMultiplyVectorFunction(
        SquashedMatrixMultiplyVectorFunction other )
    {
        // cloneSafe tolerates a null matrixMultiply, like clone() below.
        this( ObjectUtil.cloneSafe( other.getMatrixMultiply() ),
            other.getSquashingFunction() );
    }

    /**
     * Getter for matrixMultiply
     * @return 
     * GradientDescendable that multiplies an input by the internal matrix
     */
    public MatrixMultiplyVectorFunction getMatrixMultiply()
    {
        return this.matrixMultiply;
    }

    /**
     * Setter for matrixMultiply
     * @param matrixMultiply 
     * GradientDescendable that multiplies an input by the internal matrix
     */
    public void setMatrixMultiply(
        MatrixMultiplyVectorFunction matrixMultiply )
    {
        this.matrixMultiply = matrixMultiply;
    }

    /**
     * Getter for squashingFunction
     * @return 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public VectorFunction getSquashingFunction()
    {
        return this.squashingFunction;
    }

    /**
     * Setter for squashingFunction
     * @param squashingFunction 
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public void setSquashingFunction(
        VectorFunction squashingFunction )
    {
        this.squashingFunction = squashingFunction;
    }

    /**
     * Returns the parameters of this function, which are exactly the
     * parameters of the internal matrix multiply.
     *
     * @return
     * Parameter vector produced by the matrix multiply's convertToVector()
     */
    public Vector convertToVector()
    {
        return this.getMatrixMultiply().convertToVector();
    }

    /**
     * Sets the parameters of the internal matrix multiply from the given
     * vector. The squashing function is not modified.
     *
     * @param parameters
     * Parameter vector passed to the matrix multiply's convertFromVector()
     */
    public void convertFromVector(
        Vector parameters )
    {
        this.getMatrixMultiply().convertFromVector( parameters );
    }

    /**
     * Evaluates the function: multiplies the input by the internal matrix,
     * then applies the squashing function to the result.
     *
     * @param input
     * Input vector to the matrix multiply
     * @return
     * Squashed output of the matrix multiply
     */
    public Vector evaluate(
        Vector input )
    {
        return this.squashingFunction.evaluate(
            this.matrixMultiply.evaluate( input ) );
    }

    /**
     * Creates a clone of this function. The matrix multiply is cloned via
     * {@link ObjectUtil#cloneSafe}; the squashing function is not re-cloned
     * here (whatever {@code super.clone()} does with it is kept).
     *
     * @return
     * Clone of this function
     */
    @Override
    public SquashedMatrixMultiplyVectorFunction clone()
    {
        SquashedMatrixMultiplyVectorFunction clone =
            (SquashedMatrixMultiplyVectorFunction) super.clone();
        clone.setMatrixMultiply( 
            ObjectUtil.cloneSafe(this.getMatrixMultiply()) );
        return clone;
    }

    /**
     * Returns a human-readable description of the squashing function's class
     * followed, on a new line, by the weight matrix.
     *
     * @return
     * String description of this function
     */
    @Override
    public String toString()
    {
        // A newline separates the class name from the weights label; without
        // it the two run together (e.g. "...AtanFunctionWeights:").
        String retval = "Squashing: " + this.getSquashingFunction().getClass()
            + "\nWeights:\n" + this.getMatrixMultiply().getInternalMatrix();
        return retval;
    }

    /**
     * Returns the input dimensionality, delegated to the matrix multiply.
     *
     * @return
     * Input dimensionality of the matrix multiply
     */
    public int getInputDimensionality()
    {
        return this.getMatrixMultiply().getInputDimensionality();
    }

    /**
     * Returns the output dimensionality, delegated to the matrix multiply.
     * NOTE(review): this assumes the squashing function preserves vector
     * dimension (true for an element-wise function); confirm for other
     * squashing functions.
     *
     * @return
     * Output dimensionality of the matrix multiply
     */
    public int getOutputDimensionality()
    {
        return this.getMatrixMultiply().getOutputDimensionality();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy