All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sgd.AbstractOnlineLogisticRegression Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

import com.google.common.base.Preconditions;

/**
 * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
 * response to a feature vector.  This classifier uses 1 of n-1 coding where the 0-th category
 * is not stored explicitly.
 * 

* Provides the SGD based algorithm for learning a logistic regression, but omits all * annealing of learning rates. Any extension of this abstract class must define the overall * and per-term annealing for themselves. */ public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner { // coefficients for the classification. This is a dense matrix // that is (numCategories-1) x numFeatures protected Matrix beta; // number of categories we are classifying. This should the number of rows of beta plus one. protected int numCategories; protected int step; // information about how long since coefficient rows were updated. This allows lazy regularization. protected Vector updateSteps; // information about how many updates we have had on a location. This allows per-term // annealing a la confidence weighted learning. protected Vector updateCounts; // weight of the prior on beta private double lambda = 1.0e-5; protected PriorFunction prior; // can we ignore any further regularization when doing classification? private boolean sealed; // by default we don't do any fancy training private Gradient gradient = new DefaultGradient(); /** * Chainable configuration option. * * @param lambda New value of lambda, the weighting factor for the prior distribution. * @return This, so other configurations can be chained. */ public AbstractOnlineLogisticRegression lambda(double lambda) { this.lambda = lambda; return this; } /** * Computes the inverse link function, by default the logistic link function. * * @param v The output of the linear combination in a GLM. Note that the value * of v is disturbed. * @return A version of v with the link function applied. */ public static Vector link(Vector v) { double max = v.maxValue(); if (max >= 40) { // if max > 40, we subtract the large offset first // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off v.assign(Functions.minus(max)).assign(Functions.EXP); return v.divide(v.norm(1)); } else { v.assign(Functions.EXP); return v.divide(1 + v.norm(1)); } } /** * Computes the binomial logistic inverse link function. * * @param r The value to transform. * @return The logit of r. */ public static double link(double r) { if (r < 0.0) { double s = Math.exp(r); return s / (1.0 + s); } else { double s = Math.exp(-r); return 1.0 / (1.0 + s); } } @Override public Vector classifyNoLink(Vector instance) { // apply pending regularization to whichever coefficients matter regularize(instance); return beta.times(instance); } public double classifyScalarNoLink(Vector instance) { return beta.viewRow(0).dot(instance); } /** * Returns n-1 probabilities, one for each category but the 0-th. The probability of the 0-th * category is 1 - sum(this result). * * @param instance A vector of features to be classified. * @return A vector of probabilities, one for each of the first n-1 categories. */ @Override public Vector classify(Vector instance) { return link(classifyNoLink(instance)); } /** * Returns a single scalar probability in the case where we have two categories. Using this * method avoids an extra vector allocation as opposed to calling classify() or an extra two * vector allocations relative to classifyFull(). * * @param instance The vector of features to be classified. * @return The probability of the first of two categories. * @throws IllegalArgumentException If the classifier doesn't have two categories. */ @Override public double classifyScalar(Vector instance) { Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories"); // apply pending regularization to whichever coefficients matter regularize(instance); // result is a vector with one element so we can just use dot product return link(classifyScalarNoLink(instance)); } @Override public void train(long trackingKey, String groupKey, int actual, Vector instance) { unseal(); double learningRate = currentLearningRate(); // push coefficients back to zero based on the prior regularize(instance); // update each row of coefficients according to result Vector gradient = this.gradient.apply(groupKey, actual, instance, this); for (int i = 0; i < numCategories - 1; i++) { double gradientBase = gradient.get(i); // then we apply the gradientBase to the resulting element. for (Element updateLocation : instance.nonZeroes()) { int j = updateLocation.index(); double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j); beta.setQuick(i, j, newValue); } } // remember that these elements got updated for (Element element : instance.nonZeroes()) { int j = element.index(); updateSteps.setQuick(j, getStep()); updateCounts.incrementQuick(j, 1); } nextStep(); } @Override public void train(long trackingKey, int actual, Vector instance) { train(trackingKey, null, actual, instance); } @Override public void train(int actual, Vector instance) { train(0, null, actual, instance); } public void regularize(Vector instance) { if (updateSteps == null || isSealed()) { return; } // anneal learning rate double learningRate = currentLearningRate(); // here we lazily apply the prior to make up for our neglect for (int i = 0; i < numCategories - 1; i++) { for (Element updateLocation : instance.nonZeroes()) { int j = updateLocation.index(); double missingUpdates = getStep() - updateSteps.get(j); if (missingUpdates > 0) { double rate = getLambda() * learningRate * perTermLearningRate(j); double newValue = prior.age(beta.get(i, j), missingUpdates, rate); beta.set(i, j, newValue); updateSteps.set(j, getStep()); } } } } // these two abstract methods are how extensions can modify the basic learning behavior of this object. public abstract double perTermLearningRate(int j); public abstract double currentLearningRate(); public void setPrior(PriorFunction prior) { this.prior = prior; } public void setGradient(Gradient gradient) { this.gradient = gradient; } public PriorFunction getPrior() { return prior; } public Matrix getBeta() { close(); return beta; } public void setBeta(int i, int j, double betaIJ) { beta.set(i, j, betaIJ); } @Override public int numCategories() { return numCategories; } public int numFeatures() { return beta.numCols(); } public double getLambda() { return lambda; } public int getStep() { return step; } protected void nextStep() { step++; } public boolean isSealed() { return sealed; } protected void unseal() { sealed = false; } private void regularizeAll() { Vector all = new DenseVector(beta.numCols()); all.assign(1); regularize(all); } @Override public void close() { if (!sealed) { step++; regularizeAll(); sealed = true; } } public void copyFrom(AbstractOnlineLogisticRegression other) { // number of categories we are classifying. This should the number of rows of beta plus one. Preconditions.checkArgument(numCategories == other.numCategories, "Can't copy unless number of target categories is the same"); beta.assign(other.beta); step = other.step; updateSteps.assign(other.updateSteps); updateCounts.assign(other.updateCounts); } public boolean validModel() { double k = beta.aggregate(Functions.PLUS, new DoubleFunction() { @Override public double apply(double v) { return Double.isNaN(v) || Double.isInfinite(v) ? 1 : 0; } }); return k < 1; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy