All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.clustering.lda.cvb.ModelTrainer Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multithreaded LDA model trainer class, which primarily operates by running a "map/reduce"
 * operation, all in memory locally (ie not a hadoop job!) : the "map" operation is to take
 * the "read-only" {@link TopicModel} and use it to iteratively learn the p(topic|term, doc)
 * distribution for documents (this can be done in parallel across many documents, as the
 * "read-only" model is, well, read-only).  Then the outputs of this are "reduced" onto the
 * "write" model, and these updates are not parallelizable in the same way: individual
 * documents can't be added to the same entries in different threads at the same time, but
 * updates across many topics to the same term from the same document can be done in parallel,
 * so they are.
 *
 * Because computation is done asynchronously, when iteration is done, it's important to call
 * the stop() method, which blocks until work is complete.
 *
 * Setting the read model and the write model to be the same object may not quite work yet,
 * on account of parallelism badness.
 */
public class ModelTrainer {

  private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);

  private final int numTopics;
  private final int numTerms;
  // readModel and writeModel are swapped at the end of stop(), so they cannot be final.
  private TopicModel readModel;
  private TopicModel writeModel;
  // Both are (re)created by every call to start().
  private ThreadPoolExecutor threadPool;
  private BlockingQueue<Runnable> workQueue;
  private final int numTrainThreads;
  // True when the same TopicModel instance is used for both reading and writing.
  private final boolean isReadWrite;

  /**
   * Creates a trainer that reads from one model and accumulates updates into another.
   *
   * @param initialReadModel model used for inference
   * @param initialWriteModel model that receives the updates
   * @param numTrainThreads size of the training thread pool; the work queue holds ten times
   *                        this many pending tasks
   * @param numTopics number of rows of the per-document topic/term matrix given to each task
   * @param numTerms number of columns of the per-document topic/term matrix given to each task
   */
  public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel,
      int numTrainThreads, int numTopics, int numTerms) {
    this.readModel = initialReadModel;
    this.writeModel = initialWriteModel;
    this.numTrainThreads = numTrainThreads;
    this.numTopics = numTopics;
    this.numTerms = numTerms;
    isReadWrite = initialReadModel == initialWriteModel;
  }

  /**
   * WARNING: this constructor may not lead to good behavior.  What should be verified is that
   * the model updating process does not conflict with model reading.  It might work, but then
   * again, it might not!
   * @param model to be used for both reading (inference) and accumulating (learning)
   * @param numTrainThreads size of the training thread pool; the work queue holds ten times
   *                        this many pending tasks
   * @param numTopics number of rows of the per-document topic/term matrix given to each task
   * @param numTerms number of columns of the per-document topic/term matrix given to each task
   */
  public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
    this(model, model, numTrainThreads, numTopics, numTerms);
  }

  /**
   * @return the model currently used for reading; {@link #stop()} swaps it with the write model
   */
  public TopicModel getReadModel() {
    return readModel;
  }

  /**
   * Creates a fresh bounded work queue and a fixed-size thread pool, prestarts the worker
   * threads and resets the write model.
   */
  public void start() {
    log.info("Starting training threadpool with {} threads", numTrainThreads);
    workQueue = new ArrayBlockingQueue<>(numTrainThreads * 10);
    threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS,
        workQueue);
    threadPool.allowCoreThreadTimeOut(false);
    threadPool.prestartAllCoreThreads();
    writeModel.reset();
  }

  /**
   * Trains on every document of {@code matrix} with a single doc/topic iteration per document.
   *
   * @param matrix the documents, one per slice
   * @param docTopicCounts the per-document topic vectors, iterated in step with {@code matrix}
   */
  public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
    train(matrix, docTopicCounts, 1);
  }

  /**
   * Computes perplexity over all documents (a test fraction of 0 selects every document).
   *
   * @param matrix the documents, one per slice
   * @param docTopicCounts the per-document topic vectors, iterated in step with {@code matrix}
   * @return summed perplexity divided by the summed L1 norm of the documents
   */
  public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
    return calculatePerplexity(matrix, docTopicCounts, 0);
  }

  /**
   * Computes perplexity over the selected documents.  A document is selected when
   * {@code testFraction} is 0, or when its slice index modulo {@code 1 / testFraction} is 0.
   * Each selected document is first run synchronously through 10 doc/topic iterations without
   * updating the write model.
   *
   * @param matrix the documents, one per slice
   * @param docTopicCounts the per-document topic vectors, iterated in step with {@code matrix}
   * @param testFraction fraction of documents to sample, or 0 to use all of them
   * @return summed perplexity divided by the summed L1 norm of the selected documents; this is
   *         NaN when no document was selected, because both sums are then 0
   */
  public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts,
      double testFraction) {
    Iterator<MatrixSlice> docIterator = matrix.iterator();
    Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
    double perplexity = 0;
    double matrixNorm = 0;
    while (docIterator.hasNext() && docTopicIterator.hasNext()) {
      MatrixSlice docSlice = docIterator.next();
      MatrixSlice topicSlice = docTopicIterator.next();
      int docId = docSlice.index();
      Vector document = docSlice.vector();
      Vector topicDist = topicSlice.vector();
      if (testFraction == 0 || docId % (1 / testFraction) == 0) {
        trainSync(document, topicDist, false, 10);
        perplexity += readModel.perplexity(document, topicDist);
        matrixNorm += document.norm(1);
      }
    }
    return perplexity / matrixNorm;
  }

  /**
   * Runs one full training pass: calls {@link #start()}, feeds every document to the workers,
   * then calls {@link #stop()}.
   *
   * When the read and write models are the same object, documents are trained in batches of
   * up to {@code numTrainThreads} and the model is updated from the calling thread after each
   * batch completes.  Otherwise each document is queued individually and the worker updates
   * the write model itself.
   *
   * @param matrix the documents, one per slice
   * @param docTopicCounts the per-document topic vectors, iterated in step with {@code matrix}
   * @param numDocTopicIters number of doc/topic inference iterations to run per document
   */
  public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
    start();
    Iterator<MatrixSlice> docIterator = matrix.iterator();
    Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
    long startTime = System.nanoTime();
    int i = 0;
    double[] times = new double[100];
    // NOTE(review): the batch is keyed by the document Vector, so two documents that compare
    // equal within one batch would collapse into a single entry - confirm this is acceptable.
    Map<Vector, Vector> batch = new HashMap<>();
    int numTokensInBatch = 0;
    long batchStart = System.nanoTime();
    while (docIterator.hasNext() && docTopicIterator.hasNext()) {
      i++;
      Vector document = docIterator.next().vector();
      Vector topicDist = docTopicIterator.next().vector();
      if (isReadWrite) {
        // Always add the current document first so that none is dropped when the batch fills.
        batch.put(document, topicDist);
        if (log.isDebugEnabled()) {
          numTokensInBatch += document.getNumNondefaultElements();
        }
        if (batch.size() >= numTrainThreads) {
          batchStart = trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
          numTokensInBatch = 0;
        }
      } else {
        long start = System.nanoTime();
        train(document, topicDist, true, numDocTopicIters);
        if (log.isDebugEnabled()) {
          times[i % times.length] =
              (System.nanoTime() - start) / (1.0e6 * document.getNumNondefaultElements());
          if (i % 100 == 0) {
            long time = System.nanoTime() - startTime;
            log.debug("trained {} documents in {}ms", i, time / 1.0e6);
            if (i % 500 == 0) {
              // Sort a copy so the rolling buffer keeps its insertion slots intact.
              double[] sortedTimes = times.clone();
              Arrays.sort(sortedTimes);
              log.debug("training took median {}ms per token-instance",
                  sortedTimes[sortedTimes.length / 2]);
            }
          }
        }
      }
    }
    // Train whatever is left in a final, partially filled batch.
    if (!batch.isEmpty()) {
      trainAndClearBatch(batch, numDocTopicIters, numTokensInBatch, batchStart);
    }
    stop();
  }

  /**
   * Trains the given batch with model updates enabled, logs timing at debug level and empties
   * the batch so it can be refilled.
   *
   * @return the {@link System#nanoTime()} reading taken right after the batch finished
   */
  private long trainAndClearBatch(Map<Vector, Vector> batch, int numDocTopicIters,
      int numTokensInBatch, long batchStart) {
    int numDocs = batch.size();
    batchTrain(batch, true, numDocTopicIters);
    long time = System.nanoTime();
    log.debug("trained {} docs with {} tokens, start time {}, end time {}",
              numDocs, numTokensInBatch, batchStart, time);
    batch.clear();
    return time;
  }

  /**
   * Runs inference for every (document, topic vector) entry of {@code batch} on the thread
   * pool, waits for all of them, and then optionally applies each resulting per-document model
   * to the write model from the calling thread.  If the calling thread is interrupted while
   * waiting, the whole batch is retried; the interrupt status is restored before returning.
   * Requires {@link #start()} to have been called.
   *
   * @param batch documents mapped to their topic vectors
   * @param update whether to apply the results to the write model
   * @param numDocTopicsIters number of doc/topic inference iterations to run per document
   */
  public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
    boolean interrupted = false;
    try {
      while (true) {
        try {
          List<TrainerRunnable> runnables = new ArrayList<>();
          for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
            // writeModel is null here: updates are applied below, after all tasks finished.
            runnables.add(new TrainerRunnable(readModel, null, entry.getKey(),
                entry.getValue(), new SparseRowMatrix(numTopics, numTerms, true),
                numDocTopicsIters));
          }
          threadPool.invokeAll(runnables);
          if (update) {
            for (TrainerRunnable runnable : runnables) {
              writeModel.update(runnable.docTopicModel);
            }
          }
          break;
        } catch (InterruptedException e) {
          interrupted = true;
          log.warn("Interrupted during batch training, retrying!", e);
        }
      }
    } finally {
      if (interrupted) {
        // Do not swallow the interruption: hand it back to the caller once the work is done.
        Thread.currentThread().interrupt();
      }
    }
  }

  /**
   * Queues one document for asynchronous training, blocking while the work queue is full.
   * If the calling thread is interrupted while waiting, submission is retried; the interrupt
   * status is restored before returning.  Requires {@link #start()} to have been called.
   *
   * @param document the document to train on
   * @param docTopicCounts the topic vector of the document
   * @param update whether the worker should apply the result to the write model
   * @param numDocTopicIters number of doc/topic inference iterations to run
   */
  public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
    // Build the task once; a retry after an interrupt re-submits the same task.
    TrainerRunnable task = new TrainerRunnable(readModel, update
        ? writeModel
        : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true),
        numDocTopicIters);
    boolean interrupted = false;
    try {
      while (true) {
        try {
          workQueue.put(task);
          return;
        } catch (InterruptedException e) {
          interrupted = true;
          log.warn("Interrupted waiting to submit document to work queue: {}", document, e);
        }
      }
    } finally {
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
  }

  /**
   * Trains on one document in the calling thread, without using the thread pool.
   *
   * @param document the document to train on
   * @param docTopicCounts the topic vector of the document
   * @param update whether to apply the result to the write model
   * @param numDocTopicIters number of doc/topic inference iterations to run
   */
  public void trainSync(Vector document, Vector docTopicCounts, boolean update,
      int numDocTopicIters) {
    new TrainerRunnable(readModel, update
        ? writeModel
        : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run();
  }

  /**
   * Runs inference on one document in the calling thread, without updating the write model,
   * and returns the read model's perplexity for it.
   *
   * @param document the document to evaluate
   * @param docTopicCounts the topic vector of the document
   * @param numDocTopicIters number of doc/topic inference iterations to run first
   * @return the perplexity reported by the read model
   */
  public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
    TrainerRunnable runner =  new TrainerRunnable(readModel, null, document, docTopicCounts,
        new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters);
    return runner.call();
  }

  /**
   * Shuts the thread pool down, waits up to 60 seconds for queued work to finish, stops both
   * models and finally swaps the read and write models.  If the calling thread is interrupted
   * while waiting, the remaining steps (including the swap) are skipped and the interrupt
   * status is restored.
   */
  public void stop() {
    long startTime = System.nanoTime();
    log.info("Initiating stopping of training threadpool");
    try {
      threadPool.shutdown();
      if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
        log.warn("Threadpool timed out on await termination - jobs still running!");
      }
      long newTime = System.nanoTime();
      log.info("threadpool took: {}ms", (newTime - startTime) / 1.0e6);
      startTime = newTime;
      readModel.stop();
      newTime = System.nanoTime();
      log.info("readModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
      startTime = newTime;
      writeModel.stop();
      newTime = System.nanoTime();
      log.info("writeModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
      TopicModel tmpModel = writeModel;
      writeModel = readModel;
      readModel = tmpModel;
    } catch (InterruptedException e) {
      log.error("Interrupted shutting down!", e);
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Persists the current read model.
   *
   * @param outputPath where to write the model
   * @throws IOException propagated from {@code TopicModel.persist}
   */
  public void persist(Path outputPath) throws IOException {
    readModel.persist(outputPath, true);
  }

  /**
   * One unit of work: runs doc/topic inference for a single document against the read model
   * and, when a write model is supplied, pushes the resulting per-document model into it.
   * Usable as a {@link Runnable} (train only) or a {@link Callable} (train, then return the
   * document's perplexity).
   */
  private static final class TrainerRunnable implements Runnable, Callable<Double> {
    private final TopicModel readModel;
    // null means "infer only, do not update any model".
    private final TopicModel writeModel;
    private final Vector document;
    private final Vector docTopics;
    private final Matrix docTopicModel;
    private final int numDocTopicIters;

    private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document,
        Vector docTopics, Matrix docTopicModel, int numDocTopicIters) {
      this.readModel = readModel;
      this.writeModel = writeModel;
      this.document = document;
      this.docTopics = docTopics;
      this.docTopicModel = docTopicModel;
      this.numDocTopicIters = numDocTopicIters;
    }

    @Override
    public void run() {
      for (int i = 0; i < numDocTopicIters; i++) {
        // synchronous read-only call:
        readModel.trainDocTopicModel(document, docTopics, docTopicModel);
      }
      if (writeModel != null) {
        // parallel call which is read-only on the docTopicModel, and write-only on the writeModel
        // this method does not return until all rows of the docTopicModel have been submitted
        // to write work queues
        writeModel.update(docTopicModel);
      }
    }

    @Override
    public Double call() {
      run();
      return readModel.perplexity(document, docTopics);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy