![JAR search and dependency download from the Maven repository](/logo.png)
gov.sandia.cognition.learning.function.vector.GeneralizedLinearModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: GeneralizedLinearModel.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright February 28, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.function.vector;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.function.scalar.IdentityScalarFunction;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
/**
 * A VectorizableVectorFunction that is a matrix multiply followed by a
 * VectorFunction... a no-hidden-layer neural network.
 * <p>
 * Evaluating the model feeds the input through the discriminant and then
 * hands that result to the squashing function. The parameter vector exposed
 * by {@link #convertToVector()} and {@link #convertFromVector(Vector)} is
 * delegated entirely to the discriminant; the squashing function contributes
 * no parameters through this class.
 *
 * @author Kevin R. Dixon
 * @since 1.0
 *
 */
@PublicationReference(
    author="Wikipedia",
    title="Generalized linear model",
    type=PublicationType.WebPage,
    year=2011,
    url="http://en.wikipedia.org/wiki/Generalized_linear_model"
)
public class GeneralizedLinearModel
    extends AbstractCloneableSerializable
    implements VectorizableVectorFunction,
    VectorInputEvaluator,
    VectorOutputEvaluator
{

    /**
     * GradientDescendable that multiplies an input by the internal matrix
     */
    private MultivariateDiscriminant discriminant;

    /**
     * VectorFunction that is applied to the output of the matrix multiply
     */
    private VectorFunction squashingFunction;

    /**
     * Default constructor. Creates a 1x1 model with an identity function for
     * the output.
     */
    public GeneralizedLinearModel()
    {
        this( 1, 1, new IdentityScalarFunction() );
    }

    /**
     * Creates a new instance of GeneralizedLinearModel
     *
     * @param numInputs
     * Number of inputs of the function (number of matrix columns)
     * @param numOutputs
     * Number of outputs of the function (number of matrix rows)
     * @param scalarFunction
     * Function to apply to each output
     */
    public GeneralizedLinearModel(
        int numInputs,
        int numOutputs,
        UnivariateScalarFunction scalarFunction )
    {
        this( new MultivariateDiscriminant( numInputs, numOutputs ),
            new ElementWiseVectorFunction( scalarFunction ) );
    }

    /**
     * Creates a new instance of GeneralizedLinearModel. Both arguments are
     * stored by reference (no copies are made).
     *
     * @param discriminant
     * GradientDescendable that multiplies an input by the internal matrix
     * @param squashingFunction
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public GeneralizedLinearModel(
        MultivariateDiscriminant discriminant,
        VectorFunction squashingFunction )
    {
        super();
        this.setDiscriminant( discriminant );
        this.setSquashingFunction( squashingFunction );
    }

    /**
     * Copy constructor. The discriminant of {@code other} is cloned, while
     * the squashing function reference is passed through unchanged, so the
     * new model and {@code other} use the same squashing function instance.
     *
     * @param other GeneralizedLinearModel to copy
     */
    public GeneralizedLinearModel(
        GeneralizedLinearModel other )
    {
        // Use the same cloning helper as clone() so both copy paths treat
        // the discriminant identically.
        // NOTE(review): cloneSafe presumably tolerates a null discriminant
        // (unlike a direct .clone() call) -- confirm against ObjectUtil.
        this( ObjectUtil.cloneSafe( other.getDiscriminant() ),
            other.getSquashingFunction() );
    }

    /**
     * Creates a new instance of GeneralizedLinearModel, wrapping the scalar
     * function in an ElementWiseVectorFunction.
     *
     * @param matrixMultiply
     * GradientDescendable that multiplies an input by the internal matrix
     * @param scalarSquashingFunction
     * scalar function that is applied to the output of the matrix multiply
     */
    public GeneralizedLinearModel(
        MultivariateDiscriminant matrixMultiply,
        UnivariateScalarFunction scalarSquashingFunction )
    {
        this( matrixMultiply,
            new ElementWiseVectorFunction( scalarSquashingFunction ) );
    }

    /**
     * Getter for discriminant
     * @return
     * GradientDescendable that multiplies an input by the internal matrix
     */
    public MultivariateDiscriminant getDiscriminant()
    {
        return this.discriminant;
    }

    /**
     * Setter for discriminant
     * @param discriminant
     * GradientDescendable that multiplies an input by the internal matrix
     */
    public void setDiscriminant(
        MultivariateDiscriminant discriminant )
    {
        this.discriminant = discriminant;
    }

    /**
     * Getter for squashingFunction
     * @return
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public VectorFunction getSquashingFunction()
    {
        return this.squashingFunction;
    }

    /**
     * Setter for squashingFunction
     * @param squashingFunction
     * VectorFunction that is applied to the output of the matrix multiply
     */
    public void setSquashingFunction(
        VectorFunction squashingFunction )
    {
        this.squashingFunction = squashingFunction;
    }

    /**
     * Converts the model parameters to a Vector by delegating to the
     * discriminant.
     *
     * @return
     * Result of the discriminant's convertToVector()
     */
    @Override
    public Vector convertToVector()
    {
        return this.getDiscriminant().convertToVector();
    }

    /**
     * Loads the model parameters from a Vector by delegating to the
     * discriminant.
     *
     * @param parameters
     * Parameters handed to the discriminant's convertFromVector()
     */
    @Override
    public void convertFromVector(
        Vector parameters )
    {
        this.getDiscriminant().convertFromVector( parameters );
    }

    /**
     * Evaluates the model: the discriminant is evaluated on the input and
     * the squashing function is evaluated on that result.
     *
     * @param input
     * Input to hand to the discriminant
     * @return
     * Output of the squashing function
     */
    @Override
    public Vector evaluate(
        Vector input )
    {
        final Vector discriminantOutput = this.discriminant.evaluate( input );
        return this.squashingFunction.evaluate( discriminantOutput );
    }

    /**
     * Clones this model. The discriminant is cloned via
     * ObjectUtil.cloneSafe; the squashing function is not touched here and
     * is left as produced by super.clone().
     *
     * @return
     * Clone of this model
     */
    @Override
    public GeneralizedLinearModel clone()
    {
        GeneralizedLinearModel clone =
            (GeneralizedLinearModel) super.clone();
        clone.setDiscriminant(
            ObjectUtil.cloneSafe( this.getDiscriminant() ) );
        return clone;
    }

    /**
     * Describes the model as the class of the squashing function followed,
     * on a new line, by the discriminant's weights. A component that has
     * not been set is printed as "null" rather than causing an exception.
     *
     * @return
     * Human-readable description of this model
     */
    @Override
    public String toString()
    {
        final VectorFunction squashing = this.getSquashingFunction();
        final MultivariateDiscriminant weightSource = this.getDiscriminant();

        // Typed as Object so string concatenation formats them exactly as
        // before for non-null values and prints "null" otherwise.
        final Object squashingClass =
            (squashing == null) ? null : squashing.getClass();
        final Object weights =
            (weightSource == null) ? null : weightSource.getDiscriminant();

        // The newline before "Weights:" keeps the class name from running
        // straight into the label.
        return "Squashing: " + squashingClass + "\nWeights:\n" + weights;
    }

    /**
     * Gets the input dimensionality by delegating to the discriminant.
     *
     * @return
     * Input dimensionality reported by the discriminant
     */
    @Override
    public int getInputDimensionality()
    {
        return this.getDiscriminant().getInputDimensionality();
    }

    /**
     * Gets the output dimensionality by delegating to the discriminant.
     *
     * @return
     * Output dimensionality reported by the discriminant
     */
    @Override
    public int getOutputDimensionality()
    {
        return this.getDiscriminant().getOutputDimensionality();
    }

}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy