All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.word2vec.Word2VecMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.word2vec;

import hex.genmodel.ModelMojoReader;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;

/**
 * MOJO reader for Word2Vec models.
 *
 * <p>Reads the {@code vocab_size} and {@code vec_size} properties, the binary {@code vectors}
 * blob and the {@code vocabulary} text, and fills the model's word-to-vector map from them.
 */
public class Word2VecMojoReader extends ModelMojoReader<Word2VecMojoModel> {

  /** Number of bytes consumed by each {@link ByteBuffer#getFloat()} call. */
  private static final int BYTES_PER_FLOAT = 4;

  /**
   * Returns the name of the model type handled by this reader.
   *
   * @return the constant string {@code "Word2Vec"}
   */
  @Override
  public String getModelName() {
    return "Word2Vec";
  }

  /**
   * Loads the embedding dimensions and all word vectors into the model.
   *
   * <p>The {@code vectors} blob must contain exactly {@code vocab_size * vec_size} floats. Floats
   * are consumed sequentially (in the buffer's default big-endian byte order), {@code vec_size}
   * at a time, and each resulting vector is paired with the next entry of {@code vocabulary}.
   *
   * @throws IOException if a dimension is missing or negative, if the blob size does not match
   *     the declared dimensions, if the vocabulary has more entries than there are vectors, or if
   *     the number of distinct words read differs from {@code vocab_size}
   */
  @Override
  protected void readModelData() throws IOException {
    final int vocabSize = readkv("vocab_size", -1);
    final int vecSize = readkv("vec_size", -1);

    // -1 is the "key missing" default; a negative value would otherwise surface as an unchecked
    // IllegalArgumentException from the HashMap constructor below.
    if (vocabSize < 0 || vecSize < 0) {
      throw new IOException(
          "Corrupted model, invalid dimensions: vocab_size=" + vocabSize + ", vec_size=" + vecSize);
    }

    _model._vecSize = vecSize;
    _model._embeddings = new HashMap<>(vocabSize);

    final byte[] rawVectors = readblob("vectors");
    // Widen to long: the int product can overflow and wrap around to a value that happens to
    // match the blob length.
    final long expectedBytes = (long) vocabSize * vecSize * BYTES_PER_FLOAT;
    if (rawVectors.length != expectedBytes) {
      throw new IOException("Corrupted vector representation, unexpected size: " + rawVectors.length);
    }
    final ByteBuffer bb = ByteBuffer.wrap(rawVectors);

    final long bytesPerVector = (long) vecSize * BYTES_PER_FLOAT;
    for (String word : readtext("vocabulary", true)) {
      // Check before allocating/reading so that surplus vocabulary entries are reported as a
      // corrupted model rather than as an unchecked BufferUnderflowException.
      if (bb.remaining() < bytesPerVector) {
        throw new IOException("Corrupted model, vocabulary contains more words than there are vectors");
      }
      final float[] vec = new float[vecSize];
      for (int i = 0; i < vecSize; i++) {
        vec[i] = bb.getFloat();
      }
      _model._embeddings.put(word, vec);
    }

    // Catches a vocabulary that is too short as well as duplicate words (which collapse in the map).
    if (_model._embeddings.size() != vocabSize) {
      throw new IOException("Corrupted model, unexpected number of words: " + _model._embeddings.size());
    }
  }

  /**
   * Creates a new {@link Word2VecMojoModel} from the given arguments.
   *
   * @param columns passed through to the model constructor
   * @param domains passed through to the model constructor
   * @param responseColumn passed through to the model constructor
   * @return a newly constructed model instance
   */
  @Override
  protected Word2VecMojoModel makeModel(String[] columns, String[][] domains, String responseColumn) {
    return new Word2VecMojoModel(columns, domains, responseColumn);
  }

  /**
   * Returns the MOJO version string reported by this reader.
   *
   * @return the constant string {@code "1.00"}
   */
  @Override
  public String mojoVersion() {
    return "1.00";
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy