/*
 * All downloads are FREE. Search and download functionality uses the official Maven repository.
 *
 * com.etsy.conjecture.model.UpdateableLinearModel (Maven / Gradle / Ivy)
 *
 * There is a newer version: 0.2.3. Show newest version.
 */
package com.etsy.conjecture.model;

import static com.google.common.base.Preconditions.checkArgument;
import gnu.trove.function.TDoubleFunction;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import com.etsy.conjecture.Utilities;
import com.etsy.conjecture.data.Label;
import com.etsy.conjecture.data.LabeledInstance;
import com.etsy.conjecture.data.LazyVector;
import com.etsy.conjecture.data.StringKeyedVector;

public abstract class UpdateableLinearModel implements
        UpdateableModel>,
        Comparable>, Serializable {

    private static final long serialVersionUID = 8549108867384062857L;
    protected LazyVector param;
    protected final String modelType;

    protected long epoch;

    protected SGDOptimizer optimizer;

    // parameters for gradient truncation
    // for more info, see:
    // http://jmlr.csail.mit.edu/papers/volume10/langford09a/langford09a.pdf
    protected int period = 0;
    protected double truncationUpdate = 0.1;
    protected double truncationThreshold = 0.0;

    private String argString = "NOT SET";

    public void setArgString(String s) {
        argString = s;
    }

    public String getArgString() {
        return argString;
    }

    public double dotWithParam(StringKeyedVector x) {
        return param.dot(x);
    }

    protected UpdateableLinearModel(SGDOptimizer optimizer) {
        this.optimizer = optimizer;
        this.param = new LazyVector(100, optimizer);
        epoch = 0;
        modelType = getModelType();
    }

    protected UpdateableLinearModel(StringKeyedVector param, SGDOptimizer optimizer) {
        this.optimizer = optimizer;
        optimizer.model = this;
        this.param = new LazyVector(param, optimizer);
        epoch = 0;
        modelType = getModelType();
    }

    /**
     *  Get a StringKeyedVector holding the gradient of the loss w.r.t. every model parameter.
     */
    public abstract StringKeyedVector getGradients(LabeledInstance instance);

    /**
     *  Minibatch gradient update
     */
    public void update(Collection> instances) {
        optimizer.model = this; // avoid serialization stackoverflow
        if (epoch > 0) {
            param.incrementIteration();
        }
        StringKeyedVector updates = optimizer.getUpdates(instances);
        param.add(updates);
    }

    /**
     *  Single gradient update
     */
    public void update(LabeledInstance instance) {
        optimizer.model = this; // avoid serialization stackoverflow
        if (epoch > 0) {
            param.incrementIteration();
        }
        StringKeyedVector update = optimizer.getUpdate(instance);
        param.add(update);
        truncate(instance);
        epoch++;
    }

    public abstract L predict(StringKeyedVector instance);

    public abstract double loss(LabeledInstance instance);

    protected abstract String getModelType();

    public Iterator> decompose() {
        return param.iterator();
    }

    public void setParameter(String name, double value) {
        param.setCoordinate(name, value);
    }

    public StringKeyedVector getParam() {
        return param;
    }

    public void reScale(double scale) {
        param.mul(scale);
    }

    public void setFreezeFeatureSet(boolean freeze) {
        param.setFreezeKeySet(freeze);
    }

    public void merge(UpdateableLinearModel model, double scaling) {
        param.addScaled(model.param, scaling);
        epoch += model.epoch;
    }

    public void teardown() {
        optimizer.teardown();
    }

    /**
     *  Decide based on period and epoch whether to truncate
     */
    public void truncate(LabeledInstance instance) {
        if (period > 0 && epoch > 0 && epoch % period == 0) {
                applyTruncation(instance.getVector());
        }
    }

    public void applyTruncation(StringKeyedVector instance) {
        final double update = this.optimizer.getDecreasingLearningRate(epoch) * truncationUpdate;
        final double threshold = truncationThreshold;

        TDoubleFunction truncFn = new TDoubleFunction() {
            public double execute(double parameter) {
                if (parameter > 0 && parameter < threshold) {
                    return Math.max(0, parameter - update);
                } else if (parameter < 0 && parameter > -threshold) {
                    return Math.min(0, parameter + update);
                } else {
                    return parameter;
                }
            }
        };

        param.transformValues(truncFn);
        param.removeZeroCoordinates();
    }

    public long getEpoch() {
        return epoch;
    }

    public void setEpoch(long e) {
        epoch = e;
    }

    public UpdateableLinearModel setTruncationPeriod(int period) {
        checkArgument(period >= 0, "period must be non-negative, given: %s",
                period);
        this.period = period;
        return this;
    }

    public UpdateableLinearModel setTruncationThreshold(double threshold) {
        checkArgument(threshold >= 0, "update must be non-negative, given: %s",
                threshold);
        this.truncationThreshold = threshold;
        return this;
    }

    public UpdateableLinearModel setTruncationUpdate(double update) {
        checkArgument(update >= 0, "update must be non-negative, given: %s",
                update);
        this.truncationUpdate = update;
        return this;
    }

    @Override
    public int compareTo(UpdateableLinearModel inputModel) {
        return (int)Math.signum(inputModel.param.LPNorm(2d) - param.LPNorm(2d));
    }

    public void thresholdParameters(double t) {
        for (Iterator> it = param.iterator(); it
                .hasNext();) {
            if (Math.abs(it.next().getValue()) < t) {
                it.remove();
            }
        }
    }

    public String explainPrediction(StringKeyedVector x) {
        return explainPrediction(x, -1);
    }

    public String explainPrediction(StringKeyedVector x, int n) {
        StringBuilder out = new StringBuilder();
        Map weights = new HashMap();
        for (String dim : x.keySet()) {
            if (param.getCoordinate(dim) != 0.0) {
                weights.put(
                        dim,
                        Math.abs(x.getCoordinate(dim)
                                * param.getCoordinate(dim)));
            }
        }
        ArrayList keys = com.etsy.conjecture.Utilities
                .orderKeysByValue(weights, true);
        for (int i = 0; (n == -1 || i < n) && i < keys.size(); i++) {
            String k = keys.get(i);
            out.append(k + ":" + String.format("%.2f", x.getCoordinate(k))
                    + "->" + String.format("%.2f", param.getCoordinate(k))
                    + " ");
        }
        return out.toString();
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy