All downloads are free. Search and download functionality uses the official Maven repository.

com.etsy.conjecture.model.LogisticRegression Maven / Gradle / Ivy

There is a newer version: 0.2.3
Show newest version
package com.etsy.conjecture.model;

import com.etsy.conjecture.Utilities;
import com.etsy.conjecture.data.BinaryLabel;
import com.etsy.conjecture.data.LabeledInstance;
import com.etsy.conjecture.data.StringKeyedVector;

/**
 * Logistic regression for binary classification with labels y in {-1, +1}.
 *
 * <p>The per-instance loss is the log-loss {@code log(1 + exp(-y * w.x))}, where {@code w} is
 * the model parameter vector {@code param} and {@code x} is the instance's feature vector.
 */
public class LogisticRegression extends UpdateableLinearModel {

    private static final long serialVersionUID = 1L;

    /**
     * Creates a model updated by the given optimizer.
     *
     * @param optimizer SGD optimizer, passed through to {@link UpdateableLinearModel}
     */
    public LogisticRegression(SGDOptimizer optimizer) {
        super(optimizer);
    }

    /**
     * Creates a model from an existing parameter vector.
     *
     * @param param parameter vector, passed through to {@link UpdateableLinearModel}
     * @param optimizer SGD optimizer, passed through to {@link UpdateableLinearModel}
     */
    public LogisticRegression(StringKeyedVector param, SGDOptimizer optimizer) {
        super(param, optimizer);
    }

    /**
     * Scores an instance by applying {@code Utilities.logistic} to {@code w.x}.
     *
     * @param instance feature vector to score
     * @return a {@link BinaryLabel} wrapping {@code logistic(w.x)}; presumably the estimated
     *     probability of the positive class (TODO confirm against {@code Utilities.logistic})
     */
    @Override
    public BinaryLabel predict(StringKeyedVector instance) {
        return new BinaryLabel(Utilities.logistic(instance.dot(param)));
    }

    /**
     * Computes the log-loss {@code log(1 + exp(-y * w.x))} for a labeled instance.
     *
     * <p>The value is evaluated in a numerically stable form. The naive expression overflows to
     * {@code Infinity} once {@code -y * w.x} exceeds about 709.78. It also rounds to exactly
     * {@code 0} for large positive margins, because {@code 1.0 + exp(-m) == 1.0} in double
     * precision.
     *
     * @param instance labeled instance; its label is read via {@code getAsPlusMinus()}
     * @return the non-negative log-loss of the current parameters on {@code instance}
     */
    @Override
    public double loss(LabeledInstance instance) {
        double inner = instance.getVector().dot(param);
        double label = instance.getLabel().getAsPlusMinus();
        return logLoss(label * inner);
    }

    /**
     * Computes the gradient of the log-loss with respect to the parameters.
     *
     * <p>The gradient is {@code x * (-y / (exp(y * w.x) + 1))}. The scale factor is finite for
     * every margin. When {@code exp} overflows, the factor becomes {@code -0.0}. When it
     * underflows, the factor tends to {@code -y}.
     *
     * @param instance labeled instance to differentiate the loss at
     * @return a copy of the instance's feature vector scaled (via {@code mul}) by the factor
     *     above; the instance's own vector is not passed to {@code mul}
     */
    @Override
    public StringKeyedVector getGradients(LabeledInstance instance) {
        StringKeyedVector features = instance.getVector();
        StringKeyedVector gradients = features.copy();
        double label = instance.getLabel().getAsPlusMinus();
        double inner = features.dot(param);
        double gradient = -label / (Math.exp(label * inner) + 1.0);
        gradients.mul(gradient);
        return gradients;
    }

    /**
     * Returns the identifier for this model type.
     *
     * @return the constant {@code "logistic_regression"}
     */
    protected String getModelType() {
        return "logistic_regression";
    }

    /**
     * Returns {@code log(1 + exp(-margin))} (softplus of {@code -margin}) without overflow and
     * without premature rounding to zero.
     *
     * @param margin the signed margin {@code y * w.x}
     * @return the log-loss for that margin; {@code NaN} in, {@code NaN} out
     */
    private static double logLoss(double margin) {
        if (margin > 0) {
            // exp(-margin) lies in (0, 1); log1p keeps precision when it is tiny.
            return Math.log1p(Math.exp(-margin));
        }
        // Rewrite as -margin + log(1 + exp(margin)) so exp never sees a large positive argument.
        return -margin + Math.log1p(Math.exp(margin));
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy