com.etsy.conjecture.model.ClusteringModel Maven / Gradle / Ivy
package com.etsy.conjecture.model;
import com.etsy.conjecture.data.LabeledInstance;
import com.etsy.conjecture.data.ClusterLabel;
import com.etsy.conjecture.data.MulticlassPrediction;
import com.etsy.conjecture.data.StringKeyedVector;
import com.etsy.conjecture.data.LazyVector;
import com.etsy.conjecture.Utilities;
import com.etsy.conjecture.data.Label;
import com.etsy.conjecture.data.LabeledInstance;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import static com.google.common.base.Preconditions.checkArgument;
public abstract class ClusteringModel implements UpdateableModel>, Serializable {
static final long serialVersionUID = 666L;
protected double projectionErrorTolerance = 0.01;
protected double projectionBallRadius = 1.0;
protected int numClusters = 100;
protected Map param = new HashMap();
public void update(LabeledInstance instance) {
update(instance.getVector());
}
public void update(Collection> instances) {
for(LabeledInstance instance : instances) {
update(instance.getVector());
}
}
public abstract void update(StringKeyedVector instance);
public abstract ClusterLabel predict(StringKeyedVector instance);
protected ClusteringModel() {
Map init_param = new HashMap();
for (int i = 0; i < numClusters; i++) {
init_param.put(Integer.toString(i), new StringKeyedVector());
}
this.param = init_param;
}
protected ClusteringModel(HashMap param) {
Map init_param = new HashMap();
Iterator it = param.entrySet().iterator();
while (it.hasNext()) {
Map.Entry pairs = (Map.Entry)it.next();
init_param.put(pairs.getKey(), pairs.getValue());
it.remove();
}
this.param = init_param;
}
public void setFreezeFeatureSet(boolean freeze) {
for(Map.Entry e : param.entrySet()) {
e.getValue().setFreezeKeySet(freeze);
}
}
public void reScale(double scale) {
for(String cat : param.keySet()) {
param.get(cat).mul(scale);
}
}
public void merge(ClusteringModel model, double scale) {
for(String cat : param.keySet()) {
param.get(cat).addScaled(model.param.get(cat), scale);
}
}
public ClusteringModel setNumClusters(int k) {
checkArgument(k >= 0, "number of clusters must be non-negative, given: %s", k);
this.numClusters = k;
return this;
}
public ClusteringModel setL1ProjectionErrorTolerance(double e) {
checkArgument(e >= 0, "error tolerance must be non-negative, given: %s", e);
this.projectionErrorTolerance = e;
return this;
}
public ClusteringModel setL1ProjectionBallRadius(double r) {
checkArgument(r >= 0, "radius must be non-negative, given: %s", r);
this.projectionBallRadius = r;
return this;
}
public Iterator> decompose() {
throw new UnsupportedOperationException("not done yet");
}
public void setParameter(String name, double value){
throw new UnsupportedOperationException("not done yet");
}
public long getEpoch() {
return 0;
}
public void setEpoch(long epoch) {
// this class doesnt care about epoch.
}
}