/*
 * Artifact: gov.sandia.cognition.learning.algorithm.clustering.cluster.MiniBatchCentroidCluster
 * (Maven / Gradle / Ivy, part of cognitive-foundry)
 * A single jar with all the Cognitive Foundry components.
 */
/*
* File: MiniBatchCentroidCluster.java
* Authors: Jeff Piersol
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright October 20, 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.clustering.cluster;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;
import java.util.Collection;
import java.util.Collections;
/**
*
* @author Jeff Piersol
*/
public class MiniBatchCentroidCluster
extends CentroidCluster
{
/**
* The number of data points that have been used to calculate this centroid.
*/
private int numUpdates;
/**
*
* @param initialPoints
*/
public MiniBatchCentroidCluster(
final Collection extends Vector> initialPoints)
{
if (initialPoints.size() <= 0)
{
throw new IllegalArgumentException(
"You must provide at least one data point in order to create a cluster.");
}
// Create a centroid of the same type as input vectors
this.centroid = initialPoints.stream().findAny().get().clone();
this.centroid.zero();
this.updateCluster(initialPoints);
}
/**
* Updates the cluster for the given point.
*
* @param dataPoint
* The example to update for.
*/
public void updateCluster(Vector dataPoint)
{
updateCluster(Collections.singletonList(dataPoint));
}
/**
* Updates the clusters for all the given points.
*
* @param dataPoints
* The examples to update.
*/
public void updateCluster(Collection extends Vector> dataPoints)
{
int initNumUpdates = numUpdates;
this.numUpdates += dataPoints.size();
double finalEta = 1 / (double) numUpdates;
Vector shiftVector = DenseVectorFactoryMTJ.INSTANCE.createVector(
centroid.getDimensionality());
for (Vector sample : dataPoints)
{
shiftVector.plusEquals(sample);
}
// Move centroid towards data point
centroid.scaleEquals(initNumUpdates * finalEta);
centroid.scaledPlusEquals(finalEta, shiftVector);
}
}