All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.clustering.lda.cvb.TopicModel Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.Sampler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Thin wrapper around a {@link Matrix} of counts of occurrences of (topic, term) pairs.  Dividing
 * {code topicTermCount.viewRow(topic).get(term)} by the sum over the values for all terms in that
 * row yields p(term | topic).  Instead dividing it by all topic columns for that term yields
 * p(topic | term).
 *
 * Multithreading is enabled for the {@code update(Matrix)} method: this method is async, and
 * merely submits the matrix to a work queue.  When all work has been submitted,
 * {@code awaitTermination()} should be called, which will block until updates have been
 * accumulated.
 */
public class TopicModel implements Configurable, Iterable {
  
  private static final Logger log = LoggerFactory.getLogger(TopicModel.class);
  
  private final String[] dictionary;
  private final Matrix topicTermCounts;
  private final Vector topicSums;
  private final int numTopics;
  private final int numTerms;
  private final double eta;
  private final double alpha;

  private Configuration conf;

  private final Sampler sampler;
  private final int numThreads;
  private ThreadPoolExecutor threadPool;
  private Updater[] updaters;

  public int getNumTerms() {
    return numTerms;
  }

  public int getNumTopics() {
    return numTopics;
  }

  public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
      double modelWeight) {
    this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight);
  }

  public TopicModel(Configuration conf, double eta, double alpha,
      String[] dictionary, int numThreads, double modelWeight, Path... modelpath) throws IOException {
    this(loadModel(conf, modelpath), eta, alpha, dictionary, numThreads, modelWeight);
  }

  public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
      int numThreads, double modelWeight) {
    this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha, dictionary,
        numThreads, modelWeight);
  }

  public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random,
      String[] dictionary, int numThreads, double modelWeight) {
    this(randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads, modelWeight);
  }

  private TopicModel(Pair model, double eta, double alpha, String[] dict,
      int numThreads, double modelWeight) {
    this(model.getFirst(), model.getSecond(), eta, alpha, dict, numThreads, modelWeight);
  }

  public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
    String[] dictionary, double modelWeight) {
    this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight);
  }

  public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary,
      int numThreads, double modelWeight) {
    this(topicTermCounts, viewRowSums(topicTermCounts),
        eta, alpha, dictionary, numThreads, modelWeight);
  }

  public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
    String[] dictionary, int numThreads, double modelWeight) {
    this.dictionary = dictionary;
    this.topicTermCounts = topicTermCounts;
    this.topicSums = topicSums;
    this.numTopics = topicSums.size();
    this.numTerms = topicTermCounts.numCols();
    this.eta = eta;
    this.alpha = alpha;
    this.sampler = new Sampler(RandomUtils.getRandom());
    this.numThreads = numThreads;
    if (modelWeight != 1) {
      topicSums.assign(Functions.mult(modelWeight));
      for (int x = 0; x < numTopics; x++) {
        topicTermCounts.viewRow(x).assign(Functions.mult(modelWeight));
      }
    }
    initializeThreadPool();
  }

  private static Vector viewRowSums(Matrix m) {
    Vector v = new DenseVector(m.numRows());
    for (MatrixSlice slice : m) {
      v.set(slice.index(), slice.vector().norm(1));
    }
    return v;
  }

  private synchronized void initializeThreadPool() {
    if (threadPool != null) {
      threadPool.shutdown();
      try {
        threadPool.awaitTermination(100, TimeUnit.SECONDS);
      } catch (InterruptedException e) {
        log.error("Could not terminate all threads for TopicModel in time.", e);
      }
    }
    threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0, TimeUnit.SECONDS,
                                                           new ArrayBlockingQueue(numThreads * 10));
    threadPool.allowCoreThreadTimeOut(false);
    updaters = new Updater[numThreads];
    for (int i = 0; i < numThreads; i++) {
      updaters[i] = new Updater();
      threadPool.submit(updaters[i]);
    }
  }

  Matrix topicTermCounts() {
    return topicTermCounts;
  }

  @Override
  public Iterator iterator() {
    return topicTermCounts.iterateAll();
  }

  public Vector topicSums() {
    return topicSums;
  }

  private static Pair randomMatrix(int numTopics, int numTerms, Random random) {
    Matrix topicTermCounts = new DenseMatrix(numTopics, numTerms);
    Vector topicSums = new DenseVector(numTopics);
    if (random != null) {
      for (int x = 0; x < numTopics; x++) {
        for (int term = 0; term < numTerms; term++) {
          topicTermCounts.viewRow(x).set(term, random.nextDouble());
        }
      }
    }
    for (int x = 0; x < numTopics; x++) {
      topicSums.set(x, random == null ? 1.0 : topicTermCounts.viewRow(x).norm(1));
    }
    return Pair.of(topicTermCounts, topicSums);
  }

  public static Pair loadModel(Configuration conf, Path... modelPaths)
    throws IOException {
    int numTopics = -1;
    int numTerms = -1;
    List> rows = new ArrayList<>();
    for (Path modelPath : modelPaths) {
      for (Pair row
          : new SequenceFileIterable(modelPath, true, conf)) {
        rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
        numTopics = Math.max(numTopics, row.getFirst().get());
        if (numTerms < 0) {
          numTerms = row.getSecond().get().size();
        }
      }
    }
    if (rows.isEmpty()) {
      throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
    }
    numTopics++;
    Matrix model = new DenseMatrix(numTopics, numTerms);
    Vector topicSums = new DenseVector(numTopics);
    for (Pair pair : rows) {
      model.viewRow(pair.getFirst()).assign(pair.getSecond());
      topicSums.set(pair.getFirst(), pair.getSecond().norm(1));
    }
    return Pair.of(model, topicSums);
  }

  // NOTE: this is purely for debug purposes.  It is not performant to "toString()" a real model
  @Override
  public String toString() {
    StringBuilder buf = new StringBuilder();
    for (int x = 0; x < numTopics; x++) {
      String v = dictionary != null
          ? vectorToSortedString(topicTermCounts.viewRow(x).normalize(1), dictionary)
          : topicTermCounts.viewRow(x).asFormatString();
      buf.append(v).append('\n');
    }
    return buf.toString();
  }

  public int sampleTerm(Vector topicDistribution) {
    return sampler.sample(topicTermCounts.viewRow(sampler.sample(topicDistribution)));
  }

  public int sampleTerm(int topic) {
    return sampler.sample(topicTermCounts.viewRow(topic));
  }

  public synchronized void reset() {
    for (int x = 0; x < numTopics; x++) {
      topicTermCounts.assignRow(x, new SequentialAccessSparseVector(numTerms));
    }
    topicSums.assign(1.0);
    if (threadPool.isTerminated()) {
      initializeThreadPool();
    }
  }

  public synchronized void stop() {
    for (Updater updater : updaters) {
      updater.shutdown();
    }
    threadPool.shutdown();
    try {
      if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
        log.warn("Threadpool timed out on await termination - jobs still running!");
      }
    } catch (InterruptedException e) {
      log.error("Interrupted shutting down!", e);
    }
  }

  public void renormalize() {
    for (int x = 0; x < numTopics; x++) {
      topicTermCounts.assignRow(x, topicTermCounts.viewRow(x).normalize(1));
      topicSums.assign(1.0);
    }
  }

  public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
    // first calculate p(topic|term,document) for all terms in original, and all topics,
    // using p(term|topic) and p(topic|doc)
    pTopicGivenTerm(original, topics, docTopicModel);
    normalizeByTopic(docTopicModel);
    // now multiply, term-by-term, by the document, to get the weighted distribution of
    // term-topic pairs from this document.
    for (Element e : original.nonZeroes()) {
      for (int x = 0; x < numTopics; x++) {
        Vector docTopicModelRow = docTopicModel.viewRow(x);
        docTopicModelRow.setQuick(e.index(), docTopicModelRow.getQuick(e.index()) * e.get());
      }
    }
    // now recalculate \(p(topic|doc)\) by summing contributions from all of pTopicGivenTerm
    topics.assign(0.0);
    for (int x = 0; x < numTopics; x++) {
      topics.set(x, docTopicModel.viewRow(x).norm(1));
    }
    // now renormalize so that \(sum_x(p(x|doc))\) = 1
    topics.assign(Functions.mult(1 / topics.norm(1)));
  }

  public Vector infer(Vector original, Vector docTopics) {
    Vector pTerm = original.like();
    for (Element e : original.nonZeroes()) {
      int term = e.index();
      // p(a) = sum_x (p(a|x) * p(x|i))
      double pA = 0;
      for (int x = 0; x < numTopics; x++) {
        pA += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x);
      }
      pTerm.set(term, pA);
    }
    return pTerm;
  }

  public void update(Matrix docTopicCounts) {
    for (int x = 0; x < numTopics; x++) {
      updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x));
    }
  }

  public void updateTopic(int topic, Vector docTopicCounts) {
    topicTermCounts.viewRow(topic).assign(docTopicCounts, Functions.PLUS);
    topicSums.set(topic, topicSums.get(topic) + docTopicCounts.norm(1));
  }

  public void update(int termId, Vector topicCounts) {
    for (int x = 0; x < numTopics; x++) {
      Vector v = topicTermCounts.viewRow(x);
      v.set(termId, v.get(termId) + topicCounts.get(x));
    }
    topicSums.assign(topicCounts, Functions.PLUS);
  }

  public void persist(Path outputDir, boolean overwrite) throws IOException {
    FileSystem fs = outputDir.getFileSystem(conf);
    if (overwrite) {
      fs.delete(outputDir, true); // CHECK second arg
    }
    DistributedRowMatrixWriter.write(outputDir, conf, topicTermCounts);
  }

  /**
   * Computes {@code \(p(topic x | term a, document i)\)} distributions given input document {@code i}.
   * {@code \(pTGT[x][a]\)} is the (un-normalized) {@code \(p(x|a,i)\)}, or if docTopics is {@code null},
   * {@code \(p(a|x)\)} (also un-normalized).
   *
   * @param document doc-term vector encoding {@code \(w(term a|document i)\)}.
   * @param docTopics {@code docTopics[x]} is the overall weight of topic {@code x} in given
   *          document. If {@code null}, a topic weight of {@code 1.0} is used for all topics.
   * @param termTopicDist storage for output {@code \(p(x|a,i)\)} distributions.
   */
  private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) {
    // for each topic x
    for (int x = 0; x < numTopics; x++) {
      // get p(topic x | document i), or 1.0 if docTopics is null
      double topicWeight = docTopics == null ? 1.0 : docTopics.get(x);
      // get w(term a | topic x)
      Vector topicTermRow = topicTermCounts.viewRow(x);
      // get \sum_a w(term a | topic x)
      double topicSum = topicSums.get(x);
      // get p(topic x | term a) distribution to update
      Vector termTopicRow = termTopicDist.viewRow(x);

      // for each term a in document i with non-zero weight
      for (Element e : document.nonZeroes()) {
        int termIndex = e.index();

        // calc un-normalized p(topic x | term a, document i)
        double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * (topicWeight + alpha)
            / (topicSum + eta * numTerms);
        termTopicRow.set(termIndex, termTopicLikelihood);
      }
    }
  }

  /**
   * \(sum_x sum_a (c_ai * log(p(x|i) * p(a|x)))\)
   */
  public double perplexity(Vector document, Vector docTopics) {
    double perplexity = 0;
    double norm = docTopics.norm(1) + (docTopics.size() * alpha);
    for (Element e : document.nonZeroes()) {
      int term = e.index();
      double prob = 0;
      for (int x = 0; x < numTopics; x++) {
        double d = (docTopics.get(x) + alpha) / norm;
        double p = d * (topicTermCounts.viewRow(x).get(term) + eta)
                   / (topicSums.get(x) + eta * numTerms);
        prob += p;
      }
      perplexity += e.get() * Math.log(prob);
    }
    return -perplexity;
  }

  private void normalizeByTopic(Matrix perTopicSparseDistributions) {
    // then make sure that each of these is properly normalized by topic: sum_x(p(x|t,d)) = 1
    for (Element e : perTopicSparseDistributions.viewRow(0).nonZeroes()) {
      int a = e.index();
      double sum = 0;
      for (int x = 0; x < numTopics; x++) {
        sum += perTopicSparseDistributions.viewRow(x).get(a);
      }
      for (int x = 0; x < numTopics; x++) {
        perTopicSparseDistributions.viewRow(x).set(a,
            perTopicSparseDistributions.viewRow(x).get(a) / sum);
      }
    }
  }

  public static String vectorToSortedString(Vector vector, String[] dictionary) {
    List> vectorValues = new ArrayList<>(vector.getNumNondefaultElements());
    for (Element e : vector.nonZeroes()) {
      vectorValues.add(Pair.of(dictionary != null ? dictionary[e.index()] : String.valueOf(e.index()),
                               e.get()));
    }
    Collections.sort(vectorValues, new Comparator>() {
      @Override public int compare(Pair x, Pair y) {
        return y.getSecond().compareTo(x.getSecond());
      }
    });
    Iterator> listIt = vectorValues.iterator();
    StringBuilder bldr = new StringBuilder(2048);
    bldr.append('{');
    int i = 0;
    while (listIt.hasNext() && i < 25) {
      i++;
      Pair p = listIt.next();
      bldr.append(p.getFirst());
      bldr.append(':');
      bldr.append(p.getSecond());
      bldr.append(',');
    }
    if (bldr.length() > 1) {
      bldr.setCharAt(bldr.length() - 1, '}');
    }
    return bldr.toString();
  }

  @Override
  public void setConf(Configuration configuration) {
    this.conf = configuration;
  }

  @Override
  public Configuration getConf() {
    return conf;
  }

  private final class Updater implements Runnable {
    private final ArrayBlockingQueue> queue =
        new ArrayBlockingQueue<>(100);
    private boolean shutdown = false;
    private boolean shutdownComplete = false;

    public void shutdown() {
      try {
        synchronized (this) {
          while (!shutdownComplete) {
            shutdown = true;
            wait(10000L); // Arbitrarily, wait 10 seconds rather than forever for this
          }
        }
      } catch (InterruptedException e) {
        log.warn("Interrupted waiting to shutdown() : ", e);
      }
    }

    public boolean update(int topic, Vector v) {
      if (shutdown) { // maybe don't do this?
        throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks");
      }
      while (true) { // keep trying if interrupted
        try {
          // start async operation by submitting to the queue
          queue.put(Pair.of(topic, v));
          // return once you got access to the queue
          return true;
        } catch (InterruptedException e) {
          log.warn("Interrupted trying to queue update:", e);
        }
      }
    }

    @Override
    public void run() {
      while (!shutdown) {
        try {
          Pair pair = queue.poll(1, TimeUnit.SECONDS);
          if (pair != null) {
            updateTopic(pair.getFirst(), pair.getSecond());
          }
        } catch (InterruptedException e) {
          log.warn("Interrupted waiting to poll for update", e);
        }
      }
      // in shutdown mode, finish remaining tasks!
      for (Pair pair : queue) {
        updateTopic(pair.getFirst(), pair.getSecond());
      }
      synchronized (this) {
        shutdownComplete = true;
        notifyAll();
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy