// com.etsy.conjecture.model.UpdateableMulticlassLinearModel (Maven / Gradle / Ivy listing)
package com.etsy.conjecture.model;
import static com.google.common.base.Preconditions.checkArgument;
import gnu.trove.function.TDoubleFunction;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import com.etsy.conjecture.Utilities;
import com.etsy.conjecture.data.MulticlassLabel;
import com.etsy.conjecture.data.LabeledInstance;
import com.etsy.conjecture.data.BinaryLabeledInstance;
import com.etsy.conjecture.data.MulticlassLabeledInstance;
import com.etsy.conjecture.data.MulticlassPrediction;
import com.etsy.conjecture.data.LazyVector;
import com.etsy.conjecture.data.StringKeyedVector;
import com.etsy.conjecture.data.RealValuedLabel;
import com.etsy.conjecture.data.BinaryLabel;
public class UpdateableMulticlassLinearModel implements
UpdateableModel,
Comparable, Serializable {
private static final long serialVersionUID = 8549108867384062857L;
protected String modelType;
private String argString = "NOT SET";
protected long epoch;
protected Map> param = new HashMap>();
public UpdateableMulticlassLinearModel(Map> param) {
this.param = param;
this.epoch = 0;
this.modelType = this.getModelType();
}
public void setArgString(String s) {
argString = s;
}
public String getArgString() {
return argString;
}
public void setModelType(String modelType) {
this.modelType = modelType;
}
public String getModelType() {
return modelType;
}
public Iterator> decompose() {
throw new UnsupportedOperationException("not done yet");
}
public void setParameter(String name, double value) {
throw new UnsupportedOperationException("not done yet");
}
public void reScale(double scale) {
for (String cat : param.keySet()) {
param.get(cat).param.mul(scale);
}
}
public void setFreezeFeatureSet(boolean freeze) {
for (Map.Entry> e : param.entrySet()) {
e.getValue().param.setFreezeKeySet(freeze);
}
}
/**
* Minibatch gradient update
*/
public void update(Collection> instances) {
for (LabeledInstance instance : instances) {
update(instance);
}
}
/**
* Single gradient update.
*/
public void update(LabeledInstance instance) {
for (Map.Entry> e : param.entrySet()) {
String category = e.getKey();
UpdateableLinearModel model = e.getValue();
double label = e.getKey().equals(instance.getLabel().getLabel()) ? 1.0 : 0.0;
BinaryLabeledInstance blInstance = new BinaryLabeledInstance(label, instance.getVector());
model.update(blInstance);
}
epoch++;
}
@Override
public MulticlassPrediction predict(StringKeyedVector instance) {
Map scores = new HashMap();
double normalization = 0;
for (Map.Entry> e : param.entrySet()) {
double prediction = ((RealValuedLabel)e.getValue().predict(instance)).getValue();
scores.put(e.getKey(), prediction);
normalization += prediction;
}
for (Map.Entry e : scores.entrySet()) {
scores.put(e.getKey(), e.getValue() / normalization);
}
return new MulticlassPrediction(scores);
}
public void merge(UpdateableMulticlassLinearModel model, double scale) {
for (String cat : param.keySet()) {
param.get(cat).param.addScaled(model.param.get(cat).param, scale);
}
epoch += model.epoch;
}
public void teardown() {
for (Map.Entry> e : param.entrySet()) {
e.getValue().teardown();
}
}
public long getEpoch() {
return epoch;
}
public void setEpoch(long e) {
epoch = e;
}
// what to do here?
@Override
public int compareTo(UpdateableMulticlassLinearModel inputModel) {
return (int)Math.signum(inputModel.getEpoch() - getEpoch());
}
public void thresholdParameters(double t) {
for (UpdateableLinearModel m : param.values()) {
for (Iterator> it = m.param.iterator(); it
.hasNext();) {
if (Math.abs(it.next().getValue()) < t) {
it.remove();
}
}
}
}
public String explainPrediction(StringKeyedVector x) {
return explainPrediction(x, -1);
}
public String explainPrediction(StringKeyedVector x, int n) {
throw new UnsupportedOperationException("not done yet");
}
}