gov.sandia.cognition.learning.function.vector.GaussianContextRecognizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: GaussianContextRecognizer.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright July 3, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.function.vector;
import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.learning.algorithm.AnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.ParallelizedKMeansClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.GaussianCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.GaussianClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.GaussianClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.NeighborhoodGaussianClusterInitializer;
import gov.sandia.cognition.statistics.distribution.MixtureOfGaussians;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Collection;
/**
* Uses a MixtureOfGaussians to compute the probability of the different
* constituent MultivariateGaussians (that is, the contexts)
*
* @author Kevin R. Dixon
* @since 1.0
*
*/
public class GaussianContextRecognizer
extends AbstractCloneableSerializable
implements VectorFunction,
VectorInputEvaluator,
VectorOutputEvaluator
{
/**
* Underlying MixtureOfGaussians that computes context probabilities
*/
private MixtureOfGaussians.PDF gaussianMixture;
/** Creates a new instance of GaussianContextRecognizer */
public GaussianContextRecognizer()
{
this( (MixtureOfGaussians.PDF) null );
}
/**
* Creates a new instance of GaussianContextRecognizer
* @param gaussianMixture
* Underlying MixtureOfGaussians that computes context probabilities
*/
public GaussianContextRecognizer(
MixtureOfGaussians.PDF gaussianMixture )
{
this.setGaussianMixture( gaussianMixture );
}
/**
* Creates a new instance of GaussianContextRecognizer
* @param clusters
*/
public GaussianContextRecognizer(
Collection clusters )
{
this();
this.consumeClusters( clusters );
}
/**
* Copy constructor
* @param other
* GaussianContextRecognizer to clone
*/
public GaussianContextRecognizer(
GaussianContextRecognizer other )
{
this( other.getGaussianMixture().clone() );
}
@Override
public GaussianContextRecognizer clone()
{
GaussianContextRecognizer clone = (GaussianContextRecognizer) super.clone();
clone.setGaussianMixture( this.getGaussianMixture().clone() );
return clone;
}
@Override
public Vector evaluate(
Vector input )
{
return VectorFactory.getDefault().copyArray(
this.getGaussianMixture().computeRandomVariableProbabilities( input ) );
}
/**
* Uses the given clusters to populate the internal clusters of this
* @param clusters Clusters from which to create the data for this
*/
public void consumeClusters(
Collection clusters )
{
ArrayList gaussians =
new ArrayList( clusters.size() );
double[] prior = new double[ clusters.size() ];
int index = 0;
for (GaussianCluster cluster : clusters)
{
gaussians.add( new MultivariateGaussian.PDF( cluster.getGaussian() ) );
prior[index] = cluster.getMembers().size();
index++;
}
this.setGaussianMixture(
new MixtureOfGaussians.PDF( gaussians, prior ) );
}
/**
* Getter for gaussianMixture
* @return
* Underlying MixtureOfGaussians that computes context probabilities
*/
public MixtureOfGaussians.PDF getGaussianMixture()
{
return this.gaussianMixture;
}
/**
* Setter for gaussianMixture
* @param gaussianMixture
* Underlying MixtureOfGaussians that computes context probabilities
*/
public void setGaussianMixture(
MixtureOfGaussians.PDF gaussianMixture )
{
this.gaussianMixture = gaussianMixture;
}
@Override
public int getInputDimensionality()
{
return this.getGaussianMixture().getDimensionality();
}
@Override
public int getOutputDimensionality()
{
return this.getGaussianMixture().getDistributionCount();
}
/**
* Creates a GaussianContextRecognizer from a Dataset using
* a BatchClusterer
*/
public static class Learner
extends AnytimeAlgorithmWrapper, Collection>>
implements AnytimeBatchLearner,GaussianContextRecognizer>,
MeasurablePerformanceAlgorithm
{
/**
* Creates a new {@code Learner}.
*/
public Learner()
{
super();
// By default use a k-means cluster creator.
final GaussianClusterCreator clusterCreator = new GaussianClusterCreator();
final NeighborhoodGaussianClusterInitializer initializer =
new NeighborhoodGaussianClusterInitializer();
final ParallelizedKMeansClusterer kMeans =
new ParallelizedKMeansClusterer();
kMeans.setInitializer(initializer);
kMeans.setDivergenceFunction(new GaussianClusterDivergenceFunction());
kMeans.setCreator(clusterCreator);
this.setAlgorithm(kMeans);
}
/**
* Creates a new instance of Learner
* @param algorithm
* Clustering algorithm to create the GaussianContextRecognizer from
*/
public Learner(
KMeansClusterer algorithm )
{
super( algorithm );
}
// SteeringWheelAngle(LW), BrakeForce(MBRE_ESP), AcceleratorPedalPct(PW), DistronicAccelerationRelative(DTR_ObjIntrstDist), DistronicDistance(DTR_ObjIntrstDist), DistronicPositionRelative(DTR_ObjIntrstPosn), DisctronicImpactTime, DistronicSpeedRelative(DTR_ObjIntrstRelSpd), GPSElevation(GPS_HEIGHT), Gear(GPS_SOG), LateralAccelerationDerived, WheelSpeedDerived
@Override
public GaussianContextRecognizer learn(
Collection extends Vector> data )
{
this.getAlgorithm().learn( data );
return this.getResult();
}
@Override
public GaussianContextRecognizer getResult()
{
if( (this.getAlgorithm() == null) ||
(this.getAlgorithm().getResult() == null) )
{
return null;
}
else
{
return new GaussianContextRecognizer( this.getAlgorithm().getResult() );
}
}
@Override
public NamedValue extends Number> getPerformance()
{
if( this.getAlgorithm() instanceof MeasurablePerformanceAlgorithm )
{
return ((MeasurablePerformanceAlgorithm) this.getAlgorithm()).getPerformance();
}
else
{
return null;
}
}
@Override
public boolean getKeepGoing()
{
return this.getAlgorithm().getKeepGoing();
}
@Override
public Collection extends Vector> getData()
{
return this.getAlgorithm().getData();
}
}
}