All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.datexis.encoder.LookupCacheEncoder Maven / Gradle / Ivy

package de.datexis.encoder;

import de.datexis.common.Resource;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

/**
 * Outline for a simple cache-based 1-hot encoder
 * @author Sebastian Arnold 
 */
public abstract class LookupCacheEncoder extends Encoder {

  /** a cache of all existing n-grams */
  protected VocabularyHolder vocab;
  
  protected int totalWords = 0;
  
  public LookupCacheEncoder() {
    this("");
  }
  
  public LookupCacheEncoder(String id) {
    super(id);
    log = LoggerFactory.getLogger(LookupCacheEncoder.class);
    vocab = new VocabularyHolder.Builder().build();
  }

  public int getTotalWords() {
    return totalWords;
  }

  public void setTotalWords(int totalWords) {
    this.totalWords = totalWords;
  }
  
  @Override
  //@JsonIgnore
  public long getEmbeddingVectorSize() {
    return vocab.numWords();
  }
  
  /**
   * Return the index of a word in the vocabulary.
   */
  public int getIndex(String word) {
    return vocab.indexOf(word);
  }
  
  /**
   * @return target vector for Vocabulary word or null vector if word dies not exist
   */
  public INDArray oneHot(String word) {
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    int i = getIndex(word);
    if(i>=0) vector.put(i, 0, 1.0);
    else log.warn("could not encode class '{}'. is it contained in training set?", word);
    return vector;
  }
  
  /**
   * Return the frequency of a word in the vocabulary.
   */
  public int getFrequency(String word) {
    VocabularyWord w = vocab.getVocabularyWordByString(word);
    return w != null ? w.getCount() : 0;
  }
  
  public double getProbability(String word) {
    return getFrequency(word) / (double) totalWords;
  }
  
  public double getConfidence(INDArray v, int i) {
    return v.getDouble(i);
  }
  
  public double getMaxConfidence(INDArray v) {
    return v.max(0).sumNumber().doubleValue();
  }
  
  /**
   * Return word from the vocabulary at a given index.
   */
  public String getWord(int index) {
    VocabularyWord w = vocab.getVocabularyWordByIdx(index);
    return w != null ? w.getWord() : null;
  }
  
  public boolean isUnknown(String word) {
    return !vocab.containsWord(word);
  }
  
  /**
   * Saves the model to .tsv.gz
   * @param modelPath
   * @param name 
   */
  @Override
  public void saveModel(Resource modelPath, String name) {
    Resource modelFile = modelPath.resolve(name + ".tsv.gz");
    try(OutputStreamWriter out = new OutputStreamWriter(modelFile.getGZIPOutputStream())) {
      // FIXME: simply serialize vocab!
      int i = 0;
      for(VocabularyWord w : vocab.getVocabulary()) {
        i++;
        out.write(w.getHuffmanNode().getIdx() + "\t" + w.getWord() + "\t" + w.getCount() + "\n");
      }
      setModel(modelFile);
      log.info("saved " + i + " words");
    } catch(IOException ex) {
      log.error(ex.toString());
    }
    Resource streamFile = modelPath.resolve(name + ".bin");
    try(ObjectOutputStream oos = new ObjectOutputStream(streamFile.getOutputStream())) {
      oos.writeObject(vocab);
    } catch(IOException ex) {
      log.error(ex.toString());
    }
  }
  
  @Override
  public void loadModel(Resource modelFile) throws IOException {
    /* unfortunately, VocabularyHolder is in fact not serializable yet
      if(modelFile.getFileName().endsWith(".bin")) {
      vocab = SerializationUtils.readObject(modelFile.getInputStream());
      vocab.updateHuffmanCodes();
      setModel(modelFile);
      setModelAvailable(true);
      log.info("loaded " + vocab.numWords() + " words from binary " + modelFile.toString());
    } else {*/
    try(BufferedReader fr = new BufferedReader(new InputStreamReader(modelFile.getInputStream(), "UTF-8"))) {
      String line;
      int i = 1000000000; // FIXME: we are abusing this VocabCache here. But it works (TM)
                          // No, seriously, this should only be used fpr legacy models in the future. Let's save as binary files.
      while((line = fr.readLine()) != null) {
        String[] tsv = line.split("\\t");
        VocabularyWord w = new VocabularyWord(tsv[1]);
        //w.setCount(Integer.parseInt(tsv[2]));
        int huffmanIdx = Integer.parseInt(tsv[0]);
        w.setCount(i-huffmanIdx); // dirty fix to override resort of huffman tree for same counts
        vocab.addWord(w);
      }
      vocab.updateHuffmanCodes();
      setModel(modelFile);
      setModelAvailable(true);
      log.info("loaded " + vocab.numWords() + " words from " + modelFile.toString());
    }
  }
  
  @JsonIgnore
  public List getWords() {
    return vocab.words().stream().map( a -> a.getWord()).collect(Collectors.toList());
  }
  
  /**
   * Return K nearest neighbours
   */
  public Collection getNearestNeighbours(String word, int k) {
    throw new UnsupportedOperationException("No nearest words in LookupCache.");
  }

  public String getNearestNeighbour(INDArray v) {
    throw new UnsupportedOperationException("No nearest words in LookupCache.");
  }

  public Collection getNearestNeighbours(INDArray v, int k) {
    throw new UnsupportedOperationException("No nearest words in LookupCache.");
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy