All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.clustering.cluster.MiniBatchCentroidCluster Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MiniBatchCentroidCluster.java
 * Authors:             Jeff Piersol
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright October 20, 2016, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 */
package gov.sandia.cognition.learning.algorithm.clustering.cluster;

import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;
import java.util.Collection;
import java.util.Collections;

/**
 * A {@link CentroidCluster} over {@link Vector} data whose centroid is the
 * running mean of every point that has been added to it. Points can be added
 * one at a time or as a mini-batch; the centroid is updated incrementally
 * without storing the points themselves.
 *
 * @author Jeff Piersol
 */
public class MiniBatchCentroidCluster
    extends CentroidCluster<Vector>
{

    /**
     * The number of data points that have been used to calculate this centroid.
     * Kept as a long so that a long stream of updates cannot overflow it.
     */
    private long numUpdates;

    /**
     * Creates a new cluster whose centroid is the mean of the given points.
     * The centroid is created by cloning one of the input points (and zeroing
     * it), so it has the same concrete vector type as the inputs.
     *
     * @param initialPoints
     *      The points used to initialize the centroid. Must contain at least
     *      one point; presumably all points share one dimensionality.
     * @throws IllegalArgumentException
     *      If initialPoints is empty.
     */
    public MiniBatchCentroidCluster(
        final Collection<? extends Vector> initialPoints)
    {
        if (initialPoints.isEmpty())
        {
            throw new IllegalArgumentException(
                "You must provide at least one data point in order to create a cluster.");
        }

        // Create a centroid of the same type as input vectors
        this.centroid = initialPoints.iterator().next().clone();
        this.centroid.zero();

        // Call the private implementation rather than the public, overridable
        // updateCluster so a subclass override never runs on a partially
        // constructed object.
        this.addPoints(initialPoints);
    }

    /**
     * Updates the cluster for the given point.
     * 
     * @param dataPoint 
     *      The example to update for.
     */
    public void updateCluster(
        final Vector dataPoint)
    {
        this.updateCluster(Collections.singletonList(dataPoint));
    }

    /**
     * Updates the cluster for all the given points, moving the centroid to
     * the mean of every point seen so far (including these).
     * 
     * @param dataPoints 
     *      The examples to update. An empty collection leaves the cluster
     *      unchanged.
     */
    public void updateCluster(
        final Collection<? extends Vector> dataPoints)
    {
        this.addPoints(dataPoints);
    }

    /**
     * Folds the given points into the running mean stored in the centroid
     * and advances the point count.
     *
     * @param dataPoints
     *      The points to incorporate.
     */
    private void addPoints(
        final Collection<? extends Vector> dataPoints)
    {
        if (dataPoints.isEmpty())
        {
            // Nothing to incorporate. Returning early skips a dense scratch
            // allocation and a pointless rescale of the centroid.
            return;
        }

        final long previousCount = this.numUpdates;
        this.numUpdates += dataPoints.size();
        final double totalCount = this.numUpdates;

        // Sum the batch into a scratch vector.
        final Vector batchSum = DenseVectorFactoryMTJ.INSTANCE.createVector(
            this.centroid.getDimensionality());
        for (Vector sample : dataPoints)
        {
            batchSum.plusEquals(sample);
        }

        // Move centroid towards the batch:
        //   newMean = (previousCount * oldMean + batchSum) / totalCount
        // Each weight is one correctly rounded division instead of a product
        // with (1 / totalCount), which can drift: 49 * (1.0 / 49) is
        // 0.9999999999999999, not 1.0.
        this.centroid.scaleEquals(previousCount / totalCount);
        this.centroid.scaledPlusEquals(1.0 / totalCount, batchSum);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy