package hex.word2vec;
import hex.ModelCategory;
import hex.ModelMetrics;
import water.Key;
import water.H2O;
import water.Futures;
import water.DKV;
import water.Iced;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.AppendableVec;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.parser.ValueString;
import water.util.ArrayUtils;
import water.util.Log;
import hex.Model;
import hex.word2vec.Word2VecModel.*;
import water.util.RandomUtils;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;
public class Word2VecModel extends Model<Word2VecModel, Word2VecModel.Word2VecParameters, Word2VecModel.Word2VecOutput> {
private volatile Word2VecModelInfo _modelInfo;
void setModelInfo(Word2VecModelInfo mi) { _modelInfo = mi; }
final public Word2VecModelInfo getModelInfo() { return _modelInfo; }
private Key _w2vKey;
public Word2VecModel(Key selfKey, Word2VecParameters params, Word2VecOutput output) {
super(selfKey, params, output);
_modelInfo = new Word2VecModelInfo(params);
assert(Arrays.equals(_key._kb, selfKey._kb));
}
@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
throw H2O.unimpl("No Model Metrics for Word2Vec.");
}
@Override public double[] score0(Chunk[] cs, int foo, double data[/*ncols*/], double preds[/*nclasses+1*/]) {
throw H2O.unimpl();
}
@Override protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/]) {
throw H2O.unimpl();
}
/**
 * Takes an input word and returns the word vector for that word.
*
* @param target - String of desired word
* @return float array containing the word vector values or null if
* the word isn't present in the vocabulary.
*/
public float[] transform(String target) {
    NonBlockingHashMap<ValueString, Integer> vocabHM = buildVocabHashMap();
Vec[] vs = ((Frame) _w2vKey.get()).vecs();
ValueString tmp = new ValueString(target);
return transform(tmp, vocabHM, vs);
}
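  /* A minimal usage sketch (the variable "w2v" is a hypothetical reference to a
     trained Word2VecModel; "dog" is assumed to have survived the _minWordFreq cut):

       float[] vec = w2v.transform("dog");   // null if "dog" is not in the vocabulary
       if (vec != null) assert vec.length == w2v._parms._vecSize;
  */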
  private float[] transform(ValueString tmp, NonBlockingHashMap<ValueString, Integer> vocabHM, Vec[] vs) {
final int vecSize = vs.length-1;
float[] vec = new float[vecSize];
if (!vocabHM.containsKey(tmp)) {
Log.warn("Target word " + tmp + " isn't in vocabulary.");
return null;
}
int row = vocabHM.get(tmp);
for(int i=0; i < vecSize; i++) vec[i] = (float) vs[i+1].at(row);
return vec;
}
  /**
   * Finds synonyms (i.e. the words whose word vectors have the
   * highest cosine similarity with the supplied word's vector).
   *
   * @param target String of the desired word
   * @param cnt number of synonyms to find
   * @return HashMap of the closest words mapped to their cosine similarity,
   *         or null if cnt is not positive or the target word is not in the vocabulary
   */
  public HashMap<String, Float> findSynonyms(String target, int cnt) {
    if (cnt > 0) {
      NonBlockingHashMap<ValueString, Integer> vocabHM = buildVocabHashMap();
      Vec[] vs = ((Frame) _w2vKey.get()).vecs();
      ValueString tmp = new ValueString(target);
      float[] tarVec = transform(tmp, vocabHM, vs);
      if (tarVec == null) return null; // target word not in the vocabulary
      return findSynonyms(tarVec, cnt, vs);
} else {
Log.err("Synonym count must be greater than 0.");
return null;
}
}
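  /* Usage sketch (hypothetical model reference "w2v"; assumes a trained model and an
     in-vocabulary target word):

       HashMap<String, Float> syn = w2v.findSynonyms("dog", 5);
       if (syn != null)
         for (java.util.Map.Entry<String, Float> e : syn.entrySet())
           System.out.println(e.getKey() + " -> " + e.getValue());

     Note the returned map is unordered; callers who want the synonyms ranked by
     similarity need to sort the entries by value themselves. */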
  /**
   * Finds synonyms (i.e. the words whose word vectors have the
   * highest cosine similarity) for a given word vector.
   *
   * @param tarVec word vector of a word
   * @param cnt number of synonyms to find
   * @return HashMap of the closest words mapped to their cosine similarity,
   *         or null if cnt is not positive
   */
  public HashMap<String, Float> findSynonyms(float[] tarVec, int cnt) {
    if (cnt > 0) {
      Vec[] vs = ((Frame) _w2vKey.get()).vecs();
      return findSynonyms(tarVec, cnt, vs);
    } else {
      Log.err("Synonym count must be greater than 0.");
      return null;
    }
  }
  private HashMap<String, Float> findSynonyms(float[] tarVec, int cnt, Vec[] vs) {
    final int vecSize = vs.length - 1, vocabSize = (int) vs[0].length();
    int[] matches = new int[cnt];
    float[] scores = new float[cnt];
    float[] curVec = new float[vecSize];
    HashMap<String, Float> res = new HashMap<>();
    if (tarVec.length != vecSize) {
      Log.warn("Target vector length differs from the vocab's vector length.");
      return null;
    }
    for (int i = 0; i < vocabSize; i++) {
      for (int j = 0; j < vecSize; j++) curVec[j] = (float) vs[j + 1].at(i);
      float score = cosineSimilarity(tarVec, curVec);
      // Insertion into the running top-cnt list; scores >= 0.999999 are skipped so
      // the target word itself is not reported as its own synonym.
      for (int j = 0; j < cnt; j++) {
        if (score > scores[j] && score < 0.999999) {
for (int k = cnt - 1; k > j; k--) {
scores[k] = scores[k - 1];
matches[k] = matches[k-1];
}
scores[j] = score;
matches[j] = i;
break;
}
}
}
for (int i=0; i < cnt; i++)
res.put(vs[0].atStr(new ValueString(), matches[i]).toString(), scores[i]);
return res;
}
/**
* Basic calculation of cosine similarity
* @param target - a word vector
* @param current - a word vector
* @return cosine similarity between the two word vectors
*/
public float cosineSimilarity(float[] target, float[] current) {
float dotProd = 0, tsqr = 0, csqr = 0;
for(int i=0; i< target.length; i++) {
dotProd += target[i] * current[i];
tsqr += Math.pow(target[i],2);
csqr += Math.pow(current[i],2);
}
return (float) (dotProd / (Math.sqrt(tsqr)*Math.sqrt(csqr)));
}
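  /* Tiny worked example with made-up vectors (not from any trained model):
       target = [1, 0], current = [1, 1]
       dot product       = 1*1 + 0*1 = 1
       ||target||        = sqrt(1)   = 1
       ||current||       = sqrt(2)   ≈ 1.414
       cosine similarity = 1 / (1 * 1.414) ≈ 0.707
     Vectors pointing the same way score 1.0; orthogonal vectors score 0.0. */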
  /**
   * Builds a hashmap for quick lookup of a word's row number
   * in the word-vector frame.
   */
  private NonBlockingHashMap<ValueString, Integer> buildVocabHashMap() {
    NonBlockingHashMap<ValueString, Integer> vocabHM;
    Vec word = ((Frame) _w2vKey.get()).vec(0);
    final int vocabSize = (int) ((Frame) _w2vKey.get()).numRows();
    vocabHM = new NonBlockingHashMap<>(vocabSize);
for(int i=0; i < vocabSize; i++) vocabHM.put(word.atStr(new ValueString(),i),i);
return vocabHM;
}
public void buildModelOutput() {
final int vecSize = _parms._vecSize;
Futures fs = new Futures();
String[] colNames = new String[vecSize];
Vec[] vecs = new Vec[vecSize];
Key keys[] = Vec.VectorGroup.VG_LEN1.addVecs(vecs.length);
//allocate
NewChunk cs[] = new NewChunk[vecs.length];
AppendableVec avs[] = new AppendableVec[vecs.length];
for (int i = 0; i < vecs.length; i++) {
avs[i] = new AppendableVec(keys[i]);
cs[i] = new NewChunk(avs[i], 0);
}
//fill in vector values
for( int i = 0; i < _modelInfo._vocabSize; i++ ) {
for (int j=0; j < vecSize; j++) {
cs[j].addNum(_modelInfo._syn0[i * vecSize + j]);
}
}
//finalize vectors
for (int i = 0; i < vecs.length; i++) {
      colNames[i] = "V" + i;
cs[i].close(0, fs);
vecs[i] = avs[i].close(fs);
}
fs.blockForPending();
Frame fr = new Frame(_w2vKey = Key.make("w2v"));
//FIXME this ties the word count frame to this one which means
//FIXME one can't be deleted without destroying the other
fr.add("Word", (_parms._vocabKey.get()).vec(0));
fr.add(colNames, vecs);
DKV.put(_w2vKey, fr);
}
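  /* The frame published under the "w2v" key therefore has one row per vocabulary word
     and vecSize + 1 columns: the string column "Word" taken from the vocab frame,
     followed by numeric columns "V0" .. "V<vecSize-1>" holding that word's learned
     _syn0 weights. transform() and findSynonyms() above read word vectors back out
     of this frame. */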
@Override public void delete() {
_parms._vocabKey.remove();
_w2vKey.remove();
remove();
super.delete();
}
public static class Word2VecParameters extends Model.Parameters {
static final int MAX_VEC_SIZE = 10000;
public Word2Vec.WordModel _wordModel = Word2Vec.WordModel.SkipGram;
public Word2Vec.NormModel _normModel = Word2Vec.NormModel.HSM;
    public Key<Frame> _vocabKey;
public int _minWordFreq = 5;
public int _vecSize = 100;
public int _windowSize = 5;
public int _epochs = 5;
public int _negSampleCnt = 5;
public float _initLearningRate = 0.05f;
public float _sentSampleRate = 1e-3f;
}
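  /* A minimal configuration sketch (only fields declared above are set; wiring up the
     training frame and launching the Word2Vec builder are outside the scope of this file):

       Word2VecParameters p = new Word2VecParameters();
       p._wordModel = Word2Vec.WordModel.SkipGram; // default, shown for clarity
       p._normModel = Word2Vec.NormModel.HSM;      // hierarchical softmax
       p._vecSize = 200;                           // 200-dimensional word vectors
       p._epochs = 10;
       p._minWordFreq = 5;                         // ignore words seen fewer than 5 times
  */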
public static class Word2VecOutput extends Model.Output{
public Word2Vec.WordModel _wordModel;
public Word2Vec.NormModel _normModel;
public int _minWordFreq, _vecSize, _windowSize, _epochs, _negSampleCnt;
public float _initLearningRate, _sentSampleRate;
public Word2VecOutput(Word2Vec b) { super(b);}
@Override public ModelCategory getModelCategory() {
return ModelCategory.Unknown;
}
}
public static class Word2VecModelInfo extends Iced {
static final int UNIGRAM_TABLE_SIZE = 10000000;
static final float UNIGRAM_POWER = 0.75F;
static final int MAX_CODE_LENGTH = 40;
long _trainFrameSize;
int _vocabSize;
float _curLearningRate;
float[] _syn0, _syn1;
int[] _uniTable = null;
int[][] _HBWTCode = null;
int[][] _HBWTPoint = null;
private Word2VecParameters _parameters;
public final Word2VecParameters getParams() { return _parameters; }
public Word2VecModelInfo() {}
public Word2VecModelInfo(final Word2VecParameters params) {
_parameters = params;
if(_parameters._vocabKey == null) {
_parameters._vocabKey = (new WordCountTask(_parameters._minWordFreq)).doAll(_parameters.train())._wordCountKey;
}
_vocabSize = (int) (_parameters._vocabKey.get()).numRows();
_trainFrameSize = getTrainFrameSize(_parameters.train());
//initialize weights to random values
Random rand = RandomUtils.getRNG(0xDECAF, 0xDA7A);
_syn1 = new float[_parameters._vecSize * _vocabSize];
_syn0 = new float[_parameters._vecSize * _vocabSize];
for (int i = 0; i < _parameters._vecSize * _vocabSize; i++) _syn0[i] = (rand.nextFloat() - 0.5f) / _parameters._vecSize;
if(_parameters._normModel == Word2Vec.NormModel.HSM)
buildHuffmanBinaryWordTree();
else // NegSampling
buildUnigramTable();
}
/**
* Set of functions to accumulate counts of how many
* words were processed so far.
*/
    private static long _localWordCnt = 0, _globalWordCnt = 0;
public synchronized void addLocallyProcessed(long p) { _localWordCnt += p; }
public synchronized long getLocallyProcessed() { return _localWordCnt; }
public synchronized void setLocallyProcessed(int p) { _localWordCnt = p; }
public synchronized void addGloballyProcessed(long p) { _globalWordCnt += p; }
public synchronized long getGloballyProcessed() { return _globalWordCnt; }
public synchronized long getTotalProcessed() { return _globalWordCnt + _localWordCnt; }
/**
* Used to add together the weight vectors between
* two map instances.
*
* @param other - parameters object from other map method
*/
protected void add(Word2VecModelInfo other) {
ArrayUtils.add(_syn0, other._syn0);
ArrayUtils.add(_syn1, other._syn1);
addLocallyProcessed(other.getLocallyProcessed());
}
/**
* Used to reduce the summations from map methods
* to an average across map/reduce threads.
*
* @param N - number of map/reduce threads to divide by
*/
protected void div(float N) {
if (N > 1) {
ArrayUtils.div(_syn0, N);
ArrayUtils.div(_syn1, N);
}
}
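    /* Sketch of how a caller might combine two copies of the weights coming out of
       separate map threads ("left" and "right" are hypothetical names; the real
       driver lives in the training task, which is not part of this file):

         left.add(right); // element-wise sum of _syn0/_syn1; processed-word counts accumulate
         left.div(2);     // average the summed weights over the two contributors
    */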
/**
* Calculates a new global learning rate for the next round
* of map/reduce calls.
* The learning rate is a coefficient that controls the amount that
* newly learned information affects current learned information.
*/
public void updateLearningRate() {
_curLearningRate = _parameters._initLearningRate * (1 - getTotalProcessed() / (float) (_parameters._epochs * _trainFrameSize + 1));
if (_curLearningRate < _parameters._initLearningRate * 0.0001F) _curLearningRate = _parameters._initLearningRate * 0.0001F;
}
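    /* Worked example with made-up numbers: _initLearningRate = 0.05, _epochs = 5,
       _trainFrameSize = 1,000,000 words, so the denominator is 5,000,001.
       After 2,500,000 words have been processed:
         rate = 0.05 * (1 - 2,500,000 / 5,000,001) ≈ 0.05 * 0.5 = 0.025
       The rate decays linearly toward zero and is floored at 0.05 * 0.0001 = 5e-6. */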
/**
* Generates a unigram table from the [word, count] vocab frame.
* The unigram table is needed for normalizing through negative sampling.
*
     * This design trades memory for speed and simplicity. It also breaks down for
     * smaller vocabularies. Alternatives should be explored.
*/
private void buildUnigramTable() {
float d = 0;
long vocabWordsPow = 0;
_uniTable = new int[UNIGRAM_TABLE_SIZE];
Vec wCount = (_parameters._vocabKey.get()).vec(1);
for (int i=0; i < wCount.length(); i++) vocabWordsPow += Math.pow(wCount.at8(i), UNIGRAM_POWER);
for (int i = 0, j =0; i < UNIGRAM_TABLE_SIZE; i++) {
_uniTable[i] = j;
if (j >= _vocabSize-1) j = 0;
if (i / (float) UNIGRAM_TABLE_SIZE > d)
d += Math.pow(wCount.at8(j++), UNIGRAM_POWER) / (float) vocabWordsPow;
}
}
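    /* Sketch of how the table is meant to be consumed during negative sampling
       (an assumption based on the standard word2vec scheme; the training code itself
       is not in this file). A uniform random index into the table yields word rows
       with probability roughly proportional to count^0.75:

         int negWordRow = _uniTable[rand.nextInt(UNIGRAM_TABLE_SIZE)];
    */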
/* Explored packing the unigram table into chunks for the benefit of
compression. The random access nature ended up increasing the run
time of a negative sampling run by ~50%.
private Key buildUnigramTable() {
Futures fs = new Futures();
Vec wCount, uniTblVec;
AppendableVec utAV = new AppendableVec(Vec.newKey());
NewChunk utNC = null;
long vocabWordsPow = 0;
float d = 0;
int chkIdx = 0;
wCount = ((Frame)_vocabKey.get()).vec(1);
for (int i=0; i < wCount.length(); i++) vocabWordsPow += Math.pow(wCount.at8(i), UNIGRAM_POWER);
for (int i = 0, j =0; i < UNIGRAM_TABLE_SIZE; i++) {
//allocate as needed
if ((i % Vec.CHUNK_SZ) == 0){
if (utNC != null) utNC.close(chkIdx++, fs);
utNC = new NewChunk(utAV, chkIdx);
}
utNC.addNum(j, 0);
if (i / (float) UNIGRAM_TABLE_SIZE > d) {
d += Math.pow(wCount.at8(++j), UNIGRAM_POWER) / (float) vocabWordsPow;
}
if (j >= _vocabSize) j = _vocabSize - 1;
}
//finalize vectors
utNC.close(chkIdx, fs);
uniTblVec = utAV.close(fs);
fs.blockForPending();
return uniTblVec._key;
} */
/**
* Generates the values for a Huffman binary tree
* from the [word, count] vocab frame.
*/
private void buildHuffmanBinaryWordTree() {
int min1i, min2i, pos1, pos2;
int[] point = new int[MAX_CODE_LENGTH];
int[] code = new int[MAX_CODE_LENGTH];
long[] count = new long[_vocabSize * 2 - 1];
int[] binary = new int[_vocabSize * 2 - 1];
int[] parent_node = new int[_vocabSize * 2 - 1];
Vec wCount = (_parameters._vocabKey.get()).vec(1);
_HBWTCode = new int[_vocabSize][];
_HBWTPoint = new int[_vocabSize][];
assert (_vocabSize == wCount.length());
for (int i = 0; i < _vocabSize; i++) count[i] = wCount.at8(i);
for (int i = _vocabSize; i < _vocabSize * 2 - 1; i++) count[i] = (long) 1e15;
pos1 = _vocabSize - 1;
pos2 = _vocabSize;
// Following algorithm constructs the Huffman tree by adding one node at a time
for (int i = 0; i < _vocabSize - 1; i++) {
// First, find two smallest nodes 'min1, min2'
if (pos1 >= 0) {
if (count[pos1] < count[pos2]) {
min1i = pos1;
pos1--;
} else {
min1i = pos2;
pos2++;
}
} else {
min1i = pos2;
pos2++;
}
if (pos1 >= 0) {
if (count[pos1] < count[pos2]) {
min2i = pos1;
pos1--;
} else {
min2i = pos2;
pos2++;
}
} else {
min2i = pos2;
pos2++;
}
count[_vocabSize + i] = count[min1i] + count[min2i];
parent_node[min1i] = _vocabSize + i;
parent_node[min2i] = _vocabSize + i;
binary[min2i] = 1;
}
// Now assign binary code to each vocabulary word
for (int j = 0; j < _vocabSize; j++) {
int k = j;
int m = 0;
while (true) {
int val = binary[k];
code[m] = val;
point[m] = k;
m++;
k = parent_node[k];
if (k == 0) break;
}
_HBWTCode[j] = new int[m];
_HBWTPoint[j] = new int[m + 1];
_HBWTPoint[j][0] = _vocabSize - 2;
for (int l = 0; l < m; l++) {
_HBWTCode[j][m - l - 1] = code[l];
_HBWTPoint[j][m - l] = point[l] - _vocabSize;
}
}
}
    /**
     * Calculates the number of words that Word2Vec will train on.
     * This value is needed to correctly decay the learning
     * rate during training. Rather than require the user to supply it,
     * it is computed here from the training frame.
     *
     * @param tf frame containing the words to train on
     * @return total number of words in the training frame
     */
private long getTrainFrameSize(Frame tf) {
long count=0;
for (Vec v: tf.vecs()) if(v.isString()) count += v.length();
return count;
}
}
}