All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sgd.GradientMachine Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Random;

/**
 * Online gradient machine learner that tries to minimize the label ranking hinge loss.
 * Implements a gradient machine with one sigmoid hidden layer.
 * It tries to minimize the ranking loss of some given set of labels,
 * so this can be used for multi-class, multi-label
 * or auto-encoding of sparse data (e.g. text).
 */
public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {

  /** Serialization format version emitted by {@link #write} and required by {@link #readFields}. */
  public static final int WRITABLE_VERSION = 1;

  // the learning rate of the algorithm
  private double learningRate = 0.1;

  // the regularization term, a positive number that controls the size of the weight vector
  private double regularization = 0.1;

  // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
  // NOTE: stored, copied and serialized, but not read by any update in this class.
  private double sparsity = 0.1;

  // the sparsity learning rate.
  // NOTE: stored, copied and serialized, but not read by any update in this class.
  private double sparsityLearningRate = 0.1;

  // the number of features
  private int numFeatures = 10;
  // the number of hidden nodes
  private int numHidden = 100;
  // the number of output nodes
  private int numOutput = 2;

  // coefficients for the input to hidden layer.
  // There are numHidden Vectors of dimension numFeatures.
  private Vector[] hiddenWeights;

  // coefficients for the hidden to output layer.
  // There are numOutput Vectors of dimension numHidden.
  private Vector[] outputWeights;

  // hidden unit bias
  private Vector hiddenBias;

  // output unit bias
  private Vector outputBias;

  // random source used by the train(...) methods when sampling "bad" labels
  private final Random rnd;

  /**
   * Creates a machine whose weights and biases are all zero.
   * Call {@link #initWeights(Random)} to randomize the weights.
   *
   * @param numFeatures dimension of the input vectors.
   * @param numHidden   number of hidden units.
   * @param numOutput   number of output units (categories).
   */
  public GradientMachine(int numFeatures, int numHidden, int numOutput) {
    this.numFeatures = numFeatures;
    this.numHidden = numHidden;
    this.numOutput = numOutput;
    // The arrays are allocated with the declared element type (Vector) so that any Vector
    // implementation may be stored later without an ArrayStoreException.
    hiddenWeights = new Vector[numHidden];
    for (int i = 0; i < numHidden; i++) {
      hiddenWeights[i] = new DenseVector(numFeatures);
      hiddenWeights[i].assign(0);
    }
    hiddenBias = new DenseVector(numHidden);
    hiddenBias.assign(0);
    outputWeights = new Vector[numOutput];
    for (int i = 0; i < numOutput; i++) {
      outputWeights[i] = new DenseVector(numHidden);
      outputWeights[i].assign(0);
    }
    outputBias = new DenseVector(numOutput);
    outputBias.assign(0);
    rnd = RandomUtils.getRandom();
  }

  /**
   * Initialize weights. Every weight is drawn uniformly from
   * [-1/sqrt(fanIn), 1/sqrt(fanIn)), where fanIn is numFeatures for the hidden layer
   * and numHidden for the output layer. Biases are left untouched.
   *
   * @param gen random number generator.
   */
  public void initWeights(Random gen) {
    double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
    for (int i = 0; i < numHidden; i++) {
      for (int j = 0; j < numFeatures; j++) {
        double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
        hiddenWeights[i].setQuick(j, val);
      }
    }
    double outputFanIn = 1.0 / Math.sqrt(numHidden);
    for (int i = 0; i < numOutput; i++) {
      for (int j = 0; j < numHidden; j++) {
        double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
        outputWeights[i].setQuick(j, val);
      }
    }
  }

  /**
   * Chainable configuration option.
   *
   * @param learningRate New value of initial learning rate.
   * @return This, so other configurations can be chained.
   */
  public GradientMachine learningRate(double learningRate) {
    this.learningRate = learningRate;
    return this;
  }

  /**
   * Chainable configuration option.
   *
   * @param regularization A positive value that controls the weight vector size.
   * @return This, so other configurations can be chained.
   */
  public GradientMachine regularization(double regularization) {
    this.regularization = regularization;
    return this;
  }

  /**
   * Chainable configuration option.
   *
   * @param sparsity A value between zero and one that controls the fraction of hidden units
   *                 that are activated on average.
   * @return This, so other configurations can be chained.
   */
  public GradientMachine sparsity(double sparsity) {
    this.sparsity = sparsity;
    return this;
  }

  /**
   * Chainable configuration option.
   *
   * @param sparsityLearningRate New value of initial learning rate for sparsity.
   * @return This, so other configurations can be chained.
   */
  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
    this.sparsityLearningRate = sparsityLearningRate;
    return this;
  }

  /**
   * Overwrites this machine's dimensions, hyper-parameters, weights and biases with copies
   * of those held by {@code other}. The random number generator is not copied.
   *
   * @param other the machine to copy from.
   */
  public void copyFrom(GradientMachine other) {
    numFeatures = other.numFeatures;
    numHidden = other.numHidden;
    numOutput = other.numOutput;
    learningRate = other.learningRate;
    regularization = other.regularization;
    sparsity = other.sparsity;
    sparsityLearningRate = other.sparsityLearningRate;
    // clone() is declared to return Vector, so the arrays must be Vector[] rather than
    // DenseVector[] to stay safe for any Vector implementation.
    hiddenWeights = new Vector[numHidden];
    for (int i = 0; i < numHidden; i++) {
      hiddenWeights[i] = other.hiddenWeights[i].clone();
    }
    hiddenBias = other.hiddenBias.clone();
    outputWeights = new Vector[numOutput];
    for (int i = 0; i < numOutput; i++) {
      outputWeights[i] = other.outputWeights[i].clone();
    }
    outputBias = other.outputBias.clone();
  }

  /**
   * @return the number of output units.
   */
  @Override
  public int numCategories() {
    return numOutput;
  }

  /**
   * @return the dimension of the input vectors.
   */
  public int numFeatures() {
    return numFeatures;
  }

  /**
   * @return the number of hidden units.
   */
  public int numHidden() {
    return numHidden;
  }

  /**
   * Feeds forward from input to hidden unit.
   *
   * @param input the input feature vector.
   * @return Hidden unit activations.
   */
  public DenseVector inputToHidden(Vector input) {
    DenseVector activations = new DenseVector(numHidden);
    for (int i = 0; i < numHidden; i++) {
      activations.setQuick(i, hiddenWeights[i].dot(input));
    }
    activations.assign(hiddenBias, Functions.PLUS);
    // Clamp the pre-activations to [-40, 40] before squashing them with the sigmoid.
    activations.assign(Functions.min(40.0)).assign(Functions.max(-40));
    activations.assign(Functions.SIGMOID);
    return activations;
  }

  /**
   * Feeds forward from hidden to output. The output layer is linear:
   * each score is the dot product with the unit's weights plus its bias.
   *
   * @param hiddenActivation the hidden unit activations.
   * @return Output unit activations.
   */
  public DenseVector hiddenToOutput(Vector hiddenActivation) {
    DenseVector activations = new DenseVector(numOutput);
    for (int i = 0; i < numOutput; i++) {
      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
    }
    activations.assign(outputBias, Functions.PLUS);
    return activations;
  }

  /**
   * Updates using ranking loss.
   *
   * <p>For every good label, {@code numTrials} labels outside {@code goodLabels} are sampled
   * and the highest scoring one is kept. If the hinge loss
   * {@code 1 - goodScore + highestBadScore} is not negative, the output weights and biases of
   * both labels and all hidden weights are updated. Scores used here are the plain dot
   * products with the output weights; the output bias is not added.
   *
   * @param hiddenActivation the hidden unit's activation
   * @param goodLabels       the labels you want ranked above others.
   * @param numTrials        how many times you want to search for the highest scoring bad label.
   * @param gen              Random number generator.
   */
  public void updateRanking(Vector hiddenActivation,
                            Collection<Integer> goodLabels,
                            int numTrials,
                            Random gen) {
    // All the labels are good, do nothing.
    if (goodLabels.size() >= numOutput) {
      return;
    }
    for (Integer good : goodLabels) {
      double goodScore = outputWeights[good].dot(hiddenActivation);
      int highestBad = -1;
      double highestBadScore = Double.NEGATIVE_INFINITY;
      for (int i = 0; i < numTrials; i++) {
        // Rejection-sample a label that is not one of the good ones. This terminates because
        // the guard above ensures at least one label is not in goodLabels.
        int bad = gen.nextInt(numOutput);
        while (goodLabels.contains(bad)) {
          bad = gen.nextInt(numOutput);
        }
        double badScore = outputWeights[bad].dot(hiddenActivation);
        if (badScore > highestBadScore) {
          highestBadScore = badScore;
          highestBad = bad;
        }
      }
      int bad = highestBad;
      // With numTrials <= 0 no bad label is sampled; the loss is then negative infinity and
      // the update below is skipped, so the -1 index is never used.
      double loss = 1.0 - goodScore + highestBadScore;
      if (loss < 0.0) {
        continue;
      }
      // Note from the loss above the gradient dloss/dy , y being the label is -1 for good
      // and +1 for bad.
      // dy / dw is just w since  y = x' * w + b.
      // Hence by the chain rule, dloss / dw = dloss / dy * dy / dw = -w.
      // For the regularization part, 0.5 * lambda * w' w, the gradient is lambda * w.
      // dy / db = 1.
      // NOTE(review): for y = x' * w + b one would normally expect dy/dw = x (the hidden
      // activation); the code follows the derivation above and uses the weight vectors
      // themselves. Confirm the intent before changing the arithmetic.
      Vector gradGood = outputWeights[good].clone();
      gradGood.assign(Functions.NEGATE);
      // propHidden = w_bad - w_good: the signal propagated back to the hidden layer.
      Vector propHidden = gradGood.clone();
      Vector gradBad = outputWeights[bad].clone();
      propHidden.assign(gradBad, Functions.PLUS);
      gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
      outputWeights[good].assign(gradGood, Functions.PLUS);
      gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
      outputWeights[bad].assign(gradBad, Functions.PLUS);
      outputBias.setQuick(good, outputBias.get(good) + learningRate);
      outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
      // Gradient of sigmoid is s * (1 -s).
      Vector gradSig = hiddenActivation.clone();
      gradSig.assign(Functions.SIGMOIDGRADIENT);
      // Multiply by the change caused by the ranking loss.
      for (int i = 0; i < numHidden; i++) {
        gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
      }
      // Every weight of hidden unit i receives the same gradient term plus weight decay.
      for (int i = 0; i < numHidden; i++) {
        for (int j = 0; j < numFeatures; j++) {
          double v = hiddenWeights[i].get(j);
          v -= learningRate * (gradSig.get(i) + regularization * v);
          hiddenWeights[i].setQuick(j, v);
        }
      }
    }
  }

  /**
   * Classifies an instance as a one-hot vector over categories 1 .. n-1.
   * The highest scoring category gets 1.0 and all others 0; the entry for category 0 is
   * dropped, so the result is all zeros when category 0 scores highest.
   *
   * @param instance the feature vector to classify.
   * @return a view of size numCategories() - 1.
   */
  @Override
  public Vector classify(Vector instance) {
    Vector result = classifyNoLink(instance);
    // Find the max value's index.
    int max = result.maxValueIndex();
    result.assign(0);
    result.setQuick(max, 1.0);
    return result.viewPart(1, result.size() - 1);
  }

  /**
   * Runs a full forward pass and returns the raw output scores.
   *
   * @param instance the feature vector to classify.
   * @return one score per output unit.
   */
  @Override
  public Vector classifyNoLink(Vector instance) {
    DenseVector hidden = inputToHidden(instance);
    return hiddenToOutput(hidden);
  }

  /**
   * Binary decision based on the first two output scores.
   *
   * @param instance the feature vector to classify.
   * @return 0 if output 0 scores strictly higher than output 1, otherwise 1.
   */
  @Override
  public double classifyScalar(Vector instance) {
    Vector output = classifyNoLink(instance);
    if (output.get(0) > output.get(1)) {
      return 0;
    }
    return 1;
  }

  /**
   * Creates a new machine with the same dimensions and copies this machine's state into it
   * via {@link #copyFrom(GradientMachine)}.
   *
   * @return an independent copy of this machine.
   */
  public GradientMachine copy() {
    close();
    GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
    r.copyFrom(this);
    return r;
  }

  /**
   * Serializes the version, hyper-parameters, dimensions, biases and weights.
   * The field order must match {@link #readFields(DataInput)}.
   *
   * @param out the sink to write to.
   * @throws IOException if writing fails.
   */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(WRITABLE_VERSION);
    out.writeDouble(learningRate);
    out.writeDouble(regularization);
    out.writeDouble(sparsity);
    out.writeDouble(sparsityLearningRate);
    out.writeInt(numFeatures);
    out.writeInt(numHidden);
    out.writeInt(numOutput);
    VectorWritable.writeVector(out, hiddenBias);
    for (int i = 0; i < numHidden; i++) {
      VectorWritable.writeVector(out, hiddenWeights[i]);
    }
    VectorWritable.writeVector(out, outputBias);
    for (int i = 0; i < numOutput; i++) {
      VectorWritable.writeVector(out, outputWeights[i]);
    }
  }

  /**
   * Restores the state previously produced by {@link #write(DataOutput)}.
   *
   * @param in the source to read from.
   * @throws IOException if reading fails or the stored version is not {@link #WRITABLE_VERSION}.
   */
  @Override
  public void readFields(DataInput in) throws IOException {
    int version = in.readInt();
    if (version == WRITABLE_VERSION) {
      learningRate = in.readDouble();
      regularization = in.readDouble();
      sparsity = in.readDouble();
      sparsityLearningRate = in.readDouble();
      numFeatures = in.readInt();
      numHidden = in.readInt();
      numOutput = in.readInt();
      // readVector is declared to return Vector, so allocate Vector[] rather than
      // DenseVector[] to avoid an ArrayStoreException for other implementations.
      hiddenWeights = new Vector[numHidden];
      hiddenBias = VectorWritable.readVector(in);
      for (int i = 0; i < numHidden; i++) {
        hiddenWeights[i] = VectorWritable.readVector(in);
      }
      outputWeights = new Vector[numOutput];
      outputBias = VectorWritable.readVector(in);
      for (int i = 0; i < numOutput; i++) {
        outputWeights[i] = VectorWritable.readVector(in);
      }
    } else {
      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
    }
  }

  @Override
  public void close() {
    // This is an online classifier, nothing to do.
  }

  /**
   * Performs one ranking update that treats {@code actual} as the only good label, sampling
   * two candidate bad labels. {@code trackingKey} and {@code groupKey} are not used.
   *
   * @param trackingKey ignored.
   * @param groupKey    ignored.
   * @param actual      the index of the correct output unit.
   * @param instance    the feature vector of the training example.
   */
  @Override
  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
    Vector hiddenActivation = inputToHidden(instance);
    // The output activations are not needed here: updateRanking scores the labels itself.
    Collection<Integer> goodLabels = new HashSet<>();
    goodLabels.add(actual);
    updateRanking(hiddenActivation, goodLabels, 2, rnd);
  }

  /**
   * Same as {@link #train(long, String, int, Vector)} with a {@code null} group key.
   */
  @Override
  public void train(long trackingKey, int actual, Vector instance) {
    train(trackingKey, null, actual, instance);
  }

  /**
   * Same as {@link #train(long, String, int, Vector)} with tracking key 0 and a
   * {@code null} group key.
   */
  @Override
  public void train(int actual, Vector instance) {
    train(0, null, actual, instance);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy