package com.medallia.word2vec;

import com.google.common.base.MoreObjects;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.Multiset;
import com.medallia.word2vec.util.AutoLog;
import org.apache.commons.logging.Log;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;

import java.util.List;
import java.util.Map;

/**
 * Builder pattern for training a new {@link Word2VecModel}
 *
 * <p>This is a port of the open source C implementation of word2vec.
 *
 * <p>Note that this isn't a completely faithful rewrite, specifically:
 *
 * <ul>
 *   <li>When building the vocabulary from the training file:
 *     <ul>
 *       <li>The original version does a reduction step when learning the vocabulary from the
 *           file once the vocab size hits 21 million words, removing any words that do not meet
 *           the minimum frequency threshold. This Java port has no such reduction step.</li>
 *       <li>The original version injects a {@code </s>} token into the vocabulary (with a word
 *           count of 0) as a substitute for newlines in the input file. This Java port's
 *           vocabulary excludes the token.</li>
 *     </ul>
 *   </li>
 *   <li>In partitioning the file for processing:
 *     <ul>
 *       <li>The original version assumes that sentences are delimited by newline characters and
 *           injects a sentence boundary per 1000 non-filtered tokens, i.e. tokens that are in
 *           the vocabulary and were not removed by the randomized sampling process. This Java
 *           port mimics that behavior for now.</li>
 *       <li>When the original version encounters an empty line in the input file, it
 *           re-processes the first word of the last non-empty line with a sentence length of 0
 *           and updates the random value. This Java port omits this behavior.</li>
 *     </ul>
 *   </li>
 *   <li>In the sampling function:
 *     <ul>
 *       <li>The original C documentation indicates that the range should be between 0 and 1e-5,
 *           but the default value is 1e-3. This Java port retains that confusing information.</li>
 *       <li>The random value generated for comparison to determine whether a token should be
 *           filtered uses a float. This Java port uses double precision for twice the fun.</li>
 *     </ul>
 *   </li>
 *   <li>In the distance function to find the nearest matches to a target query:
 *     <ul>
 *       <li>The original version includes an unnecessary normalization of the vector for the
 *           input query, which may lead to tiny inaccuracies. This Java port foregoes this
 *           superfluous operation.</li>
 *       <li>The original version has an O(n * k) algorithm for finding the top matches and is
 *           hardcoded to 40 matches. This Java port uses Google's lovely
 *           {@link com.google.common.collect.Ordering#greatestOf(java.util.Iterator, int)},
 *           which is O(n + k log k) and takes an arbitrary k (see the sketch below).</li>
 *     </ul>
 *   </li>
 *   <li>The k-means clustering option is excluded in the Java port.</li>
 * </ul>
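 *
 * <p>A small, hypothetical sketch of the {@code Ordering#greatestOf} selection mentioned above
 * (the values and k are made up for illustration):
 *
 * <pre>{@code
 * List<Double> scores = ImmutableList.of(0.9, 0.1, 0.7, 0.4);
 * // the two greatest elements, in descending order: [0.9, 0.7]
 * List<Double> top2 = Ordering.<Double>natural().greatestOf(scores.iterator(), 2);
 * }</pre>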
 *
 * <p>Please do not hesitate to peek at the source code. It should be readable, concise, and
 * correct. I ask you to reach out if it is not.
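 *
 * <p>For orientation, a minimal usage sketch. It assumes the builder is handed out by a static
 * factory such as {@code Word2VecModel.trainer()} (the constructor below is package-private, so
 * the exact entry point lives elsewhere in this package; the factory name is an assumption):
 *
 * <pre>{@code
 * List<List<String>> sentences = ImmutableList.of(
 *     ImmutableList.of("the", "quick", "brown", "fox"),
 *     ImmutableList.of("jumped", "over", "the", "lazy", "dog"));
 * Word2VecModel model = Word2VecModel.trainer() // assumed factory method
 *     .type(NeuralNetworkType.CBOW)
 *     .setLayerSize(100)
 *     .setWindowSize(5)
 *     .setNumIterations(5)
 *     .train(sentences);
 * }</pre>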

 */
public class Word2VecTrainerBuilder {
	private static final Log LOG = AutoLog.getLog();

	private Integer layerSize;
	private Integer windowSize;
	private Integer numThreads;
	private NeuralNetworkType type;
	private int negativeSamples;
	private boolean useHierarchicalSoftmax;
	private Multiset<String> vocab;
	private Integer minFrequency;
	private Double initialLearningRate;
	private Double downSampleRate;
	private Integer iterations;
	private TrainingProgressListener listener;

	Word2VecTrainerBuilder() {
	}

	/**
	 * Size of the layers in the neural network model
	 *
	 * <p>Defaults to 100
	 */
	public Word2VecTrainerBuilder setLayerSize(int layerSize) {
		Preconditions.checkArgument(layerSize > 0, "Value must be positive");
		this.layerSize = layerSize;
		return this;
	}

	/**
	 * Size of the window to consider
	 *
	 * <p>Default window size is 5 tokens
	 */
	public Word2VecTrainerBuilder setWindowSize(int windowSize) {
		Preconditions.checkArgument(windowSize > 0, "Value must be positive");
		this.windowSize = windowSize;
		return this;
	}

	/**
	 * Specify number of threads to use for parallelization
	 *
	 * <p>Defaults to {@link Runtime#availableProcessors()}
	 */
	public Word2VecTrainerBuilder useNumThreads(int numThreads) {
		Preconditions.checkArgument(numThreads > 0, "Value must be positive");
		this.numThreads = numThreads;
		return this;
	}

	/**
	 * Set the type of the neural network
	 *
	 * <p>By default, this builder uses {@link NeuralNetworkType#CBOW}, per the fallback in
	 * {@link #train}
	 *
	 * @see NeuralNetworkType
	 */
	public Word2VecTrainerBuilder type(NeuralNetworkType type) {
		this.type = Preconditions.checkNotNull(type);
		return this;
	}

	/**
	 * Specify to use hierarchical softmax
	 *
	 * <p>By default, word2vec does not use hierarchical softmax
	 */
	public Word2VecTrainerBuilder useHierarchicalSoftmax() {
		this.useHierarchicalSoftmax = true;
		return this;
	}

	/**
	 * Number of negative samples to use; common values are between 5 and 10
	 *
	 * <p>Defaults to 0
	 */
	public Word2VecTrainerBuilder useNegativeSamples(int negativeSamples) {
		Preconditions.checkArgument(negativeSamples >= 0, "Value must be non-negative");
		this.negativeSamples = negativeSamples;
		return this;
	}

	/**
	 * Use a pre-built vocabulary
	 *
	 * <p>If this is not specified, word2vec will attempt to learn a vocabulary from the training data
	 *
	 * @param vocab {@link Multiset} holding each token with its frequency (conceptually a
	 *              {@link Map} from token to frequency)
	 */
	public Word2VecTrainerBuilder useVocab(Multiset<String> vocab) {
		this.vocab = Preconditions.checkNotNull(vocab);
		return this;
	}

	/**
	 * Specify the minimum frequency for a valid token to be considered part of the vocabulary
	 *
	 * <p>Defaults to 5
	 */
	public Word2VecTrainerBuilder setMinVocabFrequency(int minFrequency) {
		Preconditions.checkArgument(minFrequency >= 0, "Value must be non-negative");
		this.minFrequency = minFrequency;
		return this;
	}

	/**
	 * Set the starting learning rate
	 *
	 * <p>Default is 0.025 for skip-gram and 0.05 for CBOW
	 */
	public Word2VecTrainerBuilder setInitialLearningRate(double initialLearningRate) {
		Preconditions.checkArgument(initialLearningRate >= 0, "Value must be non-negative");
		this.initialLearningRate = initialLearningRate;
		return this;
	}

	/**
	 * Set the threshold for occurrence of words. Those that appear with higher frequency in the
	 * training data, e.g. stopwords, will be randomly removed.
	 *
	 * <p>Default is 1e-3, useful range is (0, 1e-5)
	 */
	public Word2VecTrainerBuilder setDownSamplingRate(double downSampleRate) {
		Preconditions.checkArgument(downSampleRate >= 0, "Value must be non-negative");
		this.downSampleRate = downSampleRate;
		return this;
	}
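
	/*
	 * For reference, a sketch (not executed by this builder) of the sub-sampling test as found
	 * in the original C implementation, which consumes the rate set above. Here wordCount is a
	 * token's corpus frequency, totalWords the corpus size, and random a java.util.Random; the
	 * actual filtering lives in the trainer, and this port's exact code may differ:
	 *
	 *   double threshold = downSampleRate * totalWords;
	 *   double keepProbability = (Math.sqrt(wordCount / threshold) + 1) * threshold / wordCount;
	 *   boolean keep = keepProbability >= random.nextDouble(); // the C version draws a float
	 */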

	/** Set the number of iterations */
	public Word2VecTrainerBuilder setNumIterations(int iterations) {
		Preconditions.checkArgument(iterations > 0, "Value must be positive");
		this.iterations = iterations;
		return this;
	}

	/** Set a progress listener */
	public Word2VecTrainerBuilder setListener(TrainingProgressListener listener) {
		this.listener = listener;
		return this;
	}

	/** Train the model */
	public Word2VecModel train(Iterable<List<String>> sentences) throws InterruptedException {
		this.type = MoreObjects.firstNonNull(type, NeuralNetworkType.CBOW);
		this.initialLearningRate = MoreObjects.firstNonNull(initialLearningRate, type.getDefaultInitialLearningRate());
		if (this.numThreads == null)
			this.numThreads = Runtime.getRuntime().availableProcessors();
		this.iterations = MoreObjects.firstNonNull(iterations, 5);
		this.layerSize = MoreObjects.firstNonNull(layerSize, 100);
		this.windowSize = MoreObjects.firstNonNull(windowSize, 5);
		this.downSampleRate = MoreObjects.firstNonNull(downSampleRate, 0.001);
		this.minFrequency = MoreObjects.firstNonNull(minFrequency, 5);
		this.listener = MoreObjects.firstNonNull(listener, new TrainingProgressListener() {
			@Override
			public void update(Stage stage, double progress) {
				System.out.println(String.format("Stage %s, progress %s%%", stage, progress));
			}
		});
		Optional<Multiset<String>> vocab = this.vocab == null
				? Optional.<Multiset<String>>absent()
				: Optional.of(this.vocab);
		return new Word2VecTrainer(
				minFrequency,
				vocab,
				new NeuralNetworkConfig(
						type,
						numThreads,
						iterations,
						layerSize,
						windowSize,
						negativeSamples,
						downSampleRate,
						initialLearningRate,
						useHierarchicalSoftmax)
		).train(LOG, listener, sentences);
	}

	/** Listener for model training progress */
	public interface TrainingProgressListener {
		/** Sequential stages of processing */
		enum Stage {
			ACQUIRE_VOCAB,
			FILTER_SORT_VOCAB,
			CREATE_HUFFMAN_ENCODING,
			TRAIN_NEURAL_NETWORK,
		}

		/**
		 * Called during word2vec training
		 *
		 * <p>Note that this is called in a separate thread from the processing thread
		 *
		 * @param stage Current {@link Stage} of processing
		 * @param progress Progress of the current stage as a double value between 0 and 1
		 */
		void update(Stage stage, double progress);
	}
}
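
/*
 * Hypothetical example, not part of the library: a listener suitable for passing to
 * setListener(...) above. Per the interface contract, update(...) is invoked from a training
 * thread, so implementations should avoid unsynchronized mutable state.
 */
class ExampleProgressListener implements Word2VecTrainerBuilder.TrainingProgressListener {
	@Override
	public void update(Stage stage, double progress) {
		// progress is documented as a double between 0 and 1 for the current stage
		System.out.println(String.format("%s: %.3f", stage, progress));
	}
}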




