All downloads are free. The search and download functionality uses the official Maven repository.

de.datexis.encoder.impl.BagOfWordsEncoder Maven / Gradle / Ivy

package de.datexis.encoder.impl;

import de.datexis.common.WordHelpers;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.MinimalLowercaseNewlinePreprocessor;
import org.apache.commons.math3.util.Pair;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * A Bag-Of-Words N-Hot Encoder with stopword and minFreq training
 * @author sarnold
 */
public class BagOfWordsEncoder extends LookupCacheEncoder {
  
  /** Normalizes a raw word before it is counted during training or looked up in the vocabulary. */
  protected TokenPreProcess preprocessor = new MinimalLowercaseNewlinePreprocessor();
  /** Helper used for stop word checks; created by {@link #setLanguage(WordHelpers.Language)}, so it stays unset until then. */
  protected WordHelpers wordHelpers;
  /** Language most recently passed to {@link #setLanguage(WordHelpers.Language)}. */
  protected WordHelpers.Language language;
  
  /**
   * Creates an encoder with the default id {@code "BOW"}.
   */
  public BagOfWordsEncoder() {
    this("BOW");
  }
  
  /**
   * Creates an encoder with the given id, starting from an empty vocabulary
   * and a logger bound to this class.
   * @param id identifier handed to the superclass constructor
   */
  public BagOfWordsEncoder(String id) {
    super(id);
    this.vocab = new VocabularyHolder.Builder().build();
    this.log = LoggerFactory.getLogger(BagOfWordsEncoder.class);
  }
  
  /**
   * @return the runtime class of the currently configured preprocessor
   */
  public Class getPreprocessorClass() {
    final TokenPreProcess current = this.preprocessor;
    return current.getClass();
  }
  
  /**
   * Replaces the preprocessor with a new instance of the named class.
   * The class is checked to be a {@link TokenPreProcess} before it is
   * instantiated, so the constructor of an unrelated class is never run.
   * @param preprocessor fully qualified name of a {@link TokenPreProcess} implementation
   * @throws ClassNotFoundException if no class with that name can be loaded
   * @throws IllegalAccessException if the class or its no-arg constructor is not accessible
   * @throws InstantiationException if the class cannot be instantiated
   * @throws ClassCastException if the named class does not implement {@link TokenPreProcess}
   */
  public void setPreprocessorClass(String preprocessor) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
    // asSubclass() rejects foreign types up front; the original cast only failed after construction
    Class<? extends TokenPreProcess> clazz = Class.forName(preprocessor).asSubclass(TokenPreProcess.class);
    // newInstance() is kept so that the declared checked exceptions stay unchanged for callers
    this.preprocessor = clazz.newInstance();
  }
  
  /**
   * @return the preprocessor instance currently in use
   */
  @JsonIgnore
  public TokenPreProcess getPreprocessor() {
    return this.preprocessor;
  }
  
  /**
   * Replaces the preprocessor used for training and lookups.
   * @param preprocessor the instance to use from now on; stored as given
   */
  public void setPreprocessor(TokenPreProcess preprocessor) {
    final TokenPreProcess replacement = preprocessor;
    this.preprocessor = replacement;
  }
  
  /**
   * @return the display name of this encoder, also used in the training log messages
   */
  @Override
  public String getName() {
    final String displayName = "Bag-of-words Encoder";
    return displayName;
  }

  /**
   * Trains the vocabulary on the given documents with a minimum word
   * frequency of 1 and English language settings.
   * @param documents the documents whose tokens are counted
   */
  @Override
  public void trainModel(Collection<Document> documents) {
    trainModel(documents, 1, WordHelpers.Language.EN);
  }
  
  /**
   * Trains the vocabulary from the tokens of the given documents.
   * Every token text is preprocessed; non-empty results count towards
   * {@code totalWords}, and those that are not stop words are added to the
   * vocabulary or have their counter incremented. Afterwards the vocabulary
   * is truncated with {@code minWordFrequency} and the model is marked available.
   * @param documents the documents whose tokens are counted
   * @param minWordFrequency threshold passed to {@code vocab.truncateVocabulary}
   * @param language language used to set up the stop word helper
   */
  public void trainModel(Collection<Document> documents, int minWordFrequency, WordHelpers.Language language) {
    appendTrainLog("Training " + getName() + " model...");
    setModel(null);
    totalWords = 0;
    timer.start();
    setLanguage(language);
    for(Document doc : documents) {
      for(Token t : doc.getTokens()) {
        String w = preprocessor.preProcess(t.getText());
        if(w.isEmpty()) continue;
        totalWords++; // stop words still count towards the total
        if(wordHelpers.isStopWord(w)) continue;
        if(!vocab.containsWord(w)) vocab.addWord(w);
        else vocab.incrementWordCounter(w);
      }
    }
    int total = vocab.numWords(); // vocabulary size before truncation, for the log line
    vocab.truncateVocabulary(minWordFrequency);
    vocab.updateHuffmanCodes();
    timer.stop();
    appendTrainLog("trained " + vocab.numWords() + " words (" +  total + " total)", timer.getLong());
    setModelAvailable(true);
  }
  
  /**
   * Trains the vocabulary from plain sentences, which are split at spaces.
   * Every piece is preprocessed; non-empty results count towards
   * {@code totalWords}, and those that are not stop words and have at least
   * {@code minWordLength} characters are added to the vocabulary or have
   * their counter incremented. Afterwards the vocabulary is truncated with
   * {@code minWordFrequency} and the model is marked available.
   * @param sentences the sentences to count words from
   * @param minWordFrequency threshold passed to {@code vocab.truncateVocabulary}
   * @param minWordLength minimum length of a preprocessed word to be counted
   * @param language language used to set up the stop word helper
   */
  public void trainModel(Iterable<String> sentences, int minWordFrequency, int minWordLength, WordHelpers.Language language) {
    appendTrainLog("Training " + getName() + " model...");
    setModel(null);
    totalWords = 0;
    timer.start();
    setLanguage(language);
    for(String s : sentences) {
      for(String t : WordHelpers.splitSpaces(s)) {
        String w = preprocessor.preProcess(t);
        if(w.isEmpty()) continue;
        totalWords++; // stop words and short words still count towards the total
        if(wordHelpers.isStopWord(w) || w.length() < minWordLength) continue;
        if(!vocab.containsWord(w)) vocab.addWord(w);
        else vocab.incrementWordCounter(w);
      }
    }
    int total = vocab.numWords(); // vocabulary size before truncation, for the log line
    vocab.truncateVocabulary(minWordFrequency);
    vocab.updateHuffmanCodes();
    timer.stop();
    appendTrainLog("trained " + vocab.numWords() + " words (" +  total + " total)", timer.getLong());
    setModelAvailable(true);
  }
  
  /**
   * Checks whether a word is unknown, after running it through the preprocessor.
   * @param word the raw word
   * @return the superclass result for the preprocessed word
   */
  @Override
  public boolean isUnknown(String word) {
    final String normalized = preprocessor.preProcess(word);
    return super.isUnknown(normalized);
  }
  
  /**
   * Looks up the index of a word, after running it through the preprocessor.
   * @param word the raw word
   * @return the superclass result for the preprocessed word
   */
  @Override
  public int getIndex(String word) {
    final String normalized = preprocessor.preProcess(word);
    return super.getIndex(normalized);
  }
  
  /**
   * Looks up the frequency of a word, after running it through the preprocessor.
   * @param word the raw word
   * @return the superclass result for the preprocessed word
   */
  @Override
  public int getFrequency(String word) {
    final String normalized = preprocessor.preProcess(word);
    return super.getFrequency(normalized);
  }
  
  /**
   * Looks up the probability of a word, after running it through the preprocessor.
   * @param word the raw word
   * @return the superclass result for the preprocessed word
   */
  @Override
  public double getProbability(String word) {
    final String normalized = preprocessor.preProcess(word);
    return super.getProbability(normalized);
  }

  /**
   * @return the language most recently set via {@link #setLanguage(WordHelpers.Language)}
   */
  public WordHelpers.Language getLanguage() {
    return this.language;
  }

  /**
   * Stores the language and creates a fresh {@link WordHelpers} for it.
   * @param language the language to use for stop word checks
   */
  public void setLanguage(WordHelpers.Language language) {
    final WordHelpers helpers = new WordHelpers(language);
    this.language = language;
    this.wordHelpers = helpers;
  }
  
  /**
   * Encodes a sequence of spans into an n-hot column vector.
   * The text of each span is looked up with {@link #getIndex(String)};
   * spans with a negative index are skipped, all others set their row to 1.0.
   * @param spans the spans (e.g. tokens) to encode
   * @return a vector with {@code getEmbeddingVectorSize()} rows and one column
   */
  @Override
  public INDArray encode(Iterable<? extends Span> spans) {
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    // best results were seen with no normalization and 1.0 instead of word frequency
    for(Span s : spans) {
      int i = getIndex(s.getText());
      if(i>=0) vector.put(i, 0, 1.0);
    }
    return vector;
  }
  
  /**
   * Encodes an array of words into an n-hot column vector.
   * Each word is looked up with {@link #getIndex(String)}; words with a
   * negative index are skipped, all others set their row to 1.0.
   * @param words the words to encode
   * @return a vector with {@code getEmbeddingVectorSize()} rows and one column
   */
  protected INDArray encode(String[] words) {
    final INDArray nHot = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    // no normalization and a constant 1.0 instead of the word frequency gave the best results
    for(int k = 0; k < words.length; k++) {
      final int index = getIndex(words[k]);
      if(index < 0) continue;
      nHot.put(index, 0, 1.0);
    }
    return nHot;
  }
  
  /**
   * Encodes a single span. A {@link Token} is encoded on its own, a
   * {@link Sentence} through its tokens, and any other span through its
   * text split at spaces.
   * @param span the span to encode
   * @return the n-hot vector for the span
   */
  @Override
  public INDArray encode(Span span) {
    if(span instanceof Token) {
      return encode(Arrays.asList(span));
    }
    if(span instanceof Sentence) {
      final Sentence sentence = (Sentence) span;
      return encode(sentence.getTokens());
    }
    return encode(span.getText());
  }

  /**
   * Encodes a phrase by splitting it at spaces and encoding the resulting words.
   * @param phrase the text to encode
   * @return the n-hot vector for the words of the phrase
   */
  @Override
  public INDArray encode(String phrase) {
    final String[] words = WordHelpers.splitSpaces(phrase);
    return encode(words);
  }
  
  /**
   * Splits the phrase at spaces and encodes a single word of it, drawn at
   * random with a weight of {@code samplingRate(probability)} per word.
   * Empty words, stop words and words whose weight is exactly 1.0 are not
   * candidates. A phrase consisting of one token is encoded directly.
   * @param phrase the text to sample a word from
   * @return a one-hot vector for the drawn word, or all zeroes if there was no candidate
   */
  public INDArray encodeSubsampled(String phrase) {
    String[] tokens = WordHelpers.splitSpaces(phrase);
    if(tokens.length == 1) return encode(tokens[0]);
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    List<Pair<String, Double>> itemWeights = new ArrayList<>(tokens.length);
    double completeWeight = 0.0;
    for(String t : tokens) {
      String w = preprocessor.preProcess(t);
      if(!w.isEmpty() && !wordHelpers.isStopWord(w)) {
        // w is already preprocessed, so bypass the preprocessing override
        final double weight = samplingRate(super.getProbability(w));
        if(weight == 1.) continue; // word not in vocab
        completeWeight += weight;
        itemWeights.add(new Pair<>(w, weight));
      }
    }
    // roulette wheel selection: pick the first item whose cumulative weight reaches r
    double r = Math.random() * completeWeight;
    double countWeight = 0.0;
    for(Pair<String, Double> item : itemWeights) {
      countWeight += item.getValue();
      if(countWeight >= r) {
        int i = getIndex(item.getKey());
        if(i>=0) vector.put(i, 0, 1.0);
        return vector;
      }
    }
    return vector; // return zeroes
  }
  
  /**
   * @param v an encoded vector
   * @param i position within the vector
   * @return the value stored in {@code v} at position {@code i}
   */
  public double getConfidence(INDArray v, int i) {
    final double value = v.getDouble(i);
    return value;
  }
  
  /**
   * @param v an encoded vector
   * @return the sum over the maxima of {@code v} along dimension 0
   */
  public double getMaxConfidence(INDArray v) {
    final INDArray maxima = v.max(0);
    final Number total = maxima.sumNumber();
    return total.doubleValue();
  }

  /**
   * Collects the preprocessed texts of all tokens that are not unknown
   * according to {@link #isUnknown(String)}.
   * @param tokens the tokens to inspect
   * @return the set of preprocessed known words; empty if none is known
   */
  public Set<String> asString(Iterable<Token> tokens) {
    Set<String> result = new HashSet<>();
    for(Token t : tokens) {
      if(!isUnknown(t.getText())) result.add(preprocessor.preProcess(t.getText()));
    }
    return result;
  }
  
  /**
   * Returns the first entry of {@code getNearestNeighbours(v, 1)}.
   * @param v an encoded vector
   * @return the first neighbour, or {@code null} if there is none
   */
  @Override
  public String getNearestNeighbour(INDArray v) {
    Collection<String> knn = getNearestNeighbours(v, 1);
    return knn.isEmpty() ? null : knn.iterator().next();
  }
  
  @Override
  public Collection getNearestNeighbours(INDArray v, int n) {
    // find maximum entries
    INDArray[] sorted = Nd4j.sortWithIndices(v.dup(), 0, false); // index,value
    if(sorted[0].sumNumber().doubleValue() == 0.) // TODO: sortWithIndices could be run on -1 / 0 / 1 ?
      log.warn("NearestNeighbour on zero vector - please check vector alignment!");
    INDArray idx = sorted[0]; // ranked indexes
    // get top n
    ArrayList result = new ArrayList<>(n);
    for(int i=0; i 0.) result.add(getWord(idx.getInt(i)));
    }
    return result;
	}

  /**
   * Randomly decides whether to keep a word.
   * @param word the word to decide on
   * @return {@code true} if a uniform random draw is below {@code samplingRate(word)}
   */
  public boolean keepWord(String word) {
    final double draw = Math.random();
    return draw < samplingRate(word);
  }

  /**
   * Sets words in a given target to 0 based on probabilities.
   * http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
   */
  public INDArray subsample(INDArray target) {
    INDArray result = target.dup();
    for(int i=0; i 0) {
        if(!keepWord(getWord(i))) result.putScalar(i, 0.);
      }
    }
    return result;
  }
  
  /**
   * Computes {@code (sqrt(p / 0.001) + 1) * (0.001 / p)} where {@code p} is
   * the result of {@link #getProbability(String)} for the word.
   * @param word the word to compute the rate for
   * @return the sampling rate used by {@link #keepWord(String)}
   */
  protected double samplingRate(String word) {
    final double probability = getProbability(word);
    final double factor = Math.sqrt(probability / 0.001) + 1;
    return factor * (0.001 / probability);
  }
  
  /**
   * Computes {@code 0.001 / (0.001 + p)}, which is exactly 1.0 for {@code p == 0}.
   * The formula of {@link #samplingRate(String)} is deliberately not used here.
   * @param p the probability of a word
   * @return the sampling weight for that probability
   */
  protected double samplingRate(double p) {
    final double threshold = 0.001;
    return threshold / (threshold + p);
  }

  @JsonIgnore
  public INDArray subsampleWeights() {
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    String w;
    for(int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy