All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.etsy.conjecture.model.KMeans Maven / Gradle / Ivy

There is a newer version: 0.2.3
Show newest version
package com.etsy.conjecture.model;

import com.etsy.conjecture.Utilities;
import com.etsy.conjecture.data.ClusterLabel;
import com.etsy.conjecture.data.StringKeyedVector;
import com.etsy.conjecture.data.ClusterPrediction;
import com.etsy.conjecture.Utilities;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import com.google.common.collect.Maps;

/**
 *  Implements sparse, streaming kmeans as described here:
 *  http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
 */ 
public class KMeans extends ClusteringModel {
  
  private static final long serialVersionUID = 1L;
  private Map clusterCounts = new HashMap();

  public KMeans(String[] categories) {
    for(String s : categories) {
        param.put(s, new StringKeyedVector());
        clusterCounts.put(s, 0.0);
    }
  }

  private Map predefinedCenters;

  public KMeans(Map centers) {
    this.predefinedCenters = Maps.newHashMap(centers);
    for(String key : predefinedCenters.keySet()) {
      param.put(key, predefinedCenters.get(key));
      clusterCounts.put(key, 0.0);
    }
  }

  public ClusterPrediction predict(StringKeyedVector instance) {
    Map scores = new HashMap();
    for(Map.Entry e : param.entrySet()) {
      scores.put(e.getKey(), e.getValue().dot(instance));
    }
    return new ClusterPrediction(scores);
  }

  public void update(StringKeyedVector instance) {
    // Get closest center to instance
    String closest_center = predict(instance).getLabel();
    // Update the per center count
    Double current_count = clusterCounts.get(closest_center);
    clusterCounts.put(closest_center, current_count+1.0);
    // Get per center learning rate
    Double learning_rate = 1.0/clusterCounts.get(closest_center);
    // take gradient step
    StringKeyedVector center = param.get(closest_center);
    center.mul(1-learning_rate);
    instance.mul(learning_rate);
    center.add(instance);
    l1Projection(center);
    param.put(closest_center, center);
  }

  public Double getCurrent(StringKeyedVector center, Double theta) {
    Double current = 0.0;
    for (double v : center.values()) {
      current += Math.max(0, Math.abs(v)-theta);
    }
    return current;
  }

  /*
   *  Use bisection to find an approximate value of theta
   */
  public Double findTheta(StringKeyedVector center, Double norm) {
    Double upper = center.max();
    Double lower = 0.0;
    Double current = norm;
    Double theta = 0.0;
    while (current > projectionBallRadius * (1 + projectionErrorTolerance)) {
      theta = (upper + lower)/2.0;
      current = getCurrent(center, theta);
      if (current <= projectionBallRadius) {
        upper = theta;
      } else {
        lower = theta;
      }
    }
    return theta;
  }

  public void doProjection(StringKeyedVector center, Double theta) {
    Iterator it = center.iterator();
        while (it.hasNext()) {
            Map.Entry pairs = (Map.Entry)it.next();
            String key = pairs.getKey();
            double value = pairs.getValue();
            double projectedValue = Math.signum(value) * Math.max(0.0, Math.abs(value) - theta);
            center.setCoordinate(key, projectedValue);
        }
  }

  /**
   * An e-Accurate projection to the L1 ball, described here:
   * http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf
   */
  public void l1Projection(StringKeyedVector center) {
    Double norm = center.LPNorm(1.0);
    if (norm <= projectionBallRadius + projectionErrorTolerance) {
      return;
    } else {
      Double theta = findTheta(center, norm);
      doProjection(center, theta);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy