gov.sandia.cognition.learning.function.vector.DifferentiableGeneralizedLinearModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: DifferentiableGeneralizedLinearModel.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright February 28, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.function.vector;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.scalar.IdentityScalarFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.DifferentiableVectorFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
/**
* A GradientDescenable version of a GeneralizedLinearModel, in
* other words, a GeneralizedLinearModel where the squashing
* function is differentiable
*
* @author Kevin R. Dixon
* @since 1.0
*
*/
public class DifferentiableGeneralizedLinearModel
extends GeneralizedLinearModel
implements GradientDescendable, DifferentiableVectorFunction
{
/**
* Default Constructor. Creates a 1x1 model with a identity function for
* the output.
*/
public DifferentiableGeneralizedLinearModel()
{
this( 1, 1, new IdentityScalarFunction() );
}
/**
* Creates a new instance of GeneralizedLinearModel
*
* @param numInputs
* Number of inputs of the function (number of matrix columns)
* @param numOutputs
* Number of outputs of the function (number of matrix rows)
* @param scalarFunction
* Function to apply to each output
*/
public DifferentiableGeneralizedLinearModel(
int numInputs,
int numOutputs,
DifferentiableUnivariateScalarFunction scalarFunction )
{
this( new MultivariateDiscriminant( numInputs, numOutputs ),
new ElementWiseDifferentiableVectorFunction( scalarFunction ) );
}
/**
* Creates a new instance of DifferentiableGeneralizedLinearModel
* @param matrixMultiply
* GradientDescendable that multiplies an input by the internal matrix
* @param squashingFunction
* VectorFunction that is applied to the output of the matrix multiply
*/
public DifferentiableGeneralizedLinearModel(
MultivariateDiscriminant matrixMultiply,
DifferentiableVectorFunction squashingFunction )
{
super( matrixMultiply, squashingFunction );
}
/**
* Creates a new instance of DifferentiableGeneralizedLinearModel
* @param matrixMultiply
* GradientDescendable that multiplies an input by the internal matrix
* @param scalarSquashingFunction
* scalar function that is applied to the output of the matrix multiply
*/
public DifferentiableGeneralizedLinearModel(
MultivariateDiscriminant matrixMultiply,
DifferentiableUnivariateScalarFunction scalarSquashingFunction )
{
this( matrixMultiply, new ElementWiseDifferentiableVectorFunction(
scalarSquashingFunction ) );
}
/**
* Creates a new instance of DifferentiableGeneralizedLinearModel
*
* @param other DifferentiableGeneralizedLinearModel to copy
*/
public DifferentiableGeneralizedLinearModel(
DifferentiableGeneralizedLinearModel other )
{
super( other );
}
@Override
public DifferentiableVectorFunction getSquashingFunction()
{
return (DifferentiableVectorFunction) super.getSquashingFunction();
}
public Matrix computeParameterGradient(
Vector input )
{
Matrix gradient =
this.getDiscriminant().computeParameterGradient( input );
Vector y = this.getDiscriminant().evaluate( input );
Matrix derivative = this.getSquashingFunction().differentiate( y );
return derivative.times( gradient );
}
@Override
public DifferentiableGeneralizedLinearModel clone()
{
return (DifferentiableGeneralizedLinearModel) super.clone();
}
public Matrix differentiate(
Vector input )
{
Matrix dudx = this.getDiscriminant().differentiate( input );
Vector u = this.getDiscriminant().evaluate( input );
Matrix dydu = this.getSquashingFunction().differentiate( u );
return dydu.times( dudx );
}
}