All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.datexis.encoder.impl.LetterNGramEncoder Maven / Gradle / Ivy

package de.datexis.encoder.impl;

import de.datexis.model.Document;
import de.datexis.model.Token;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Span;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/**
 * A character n-gram encoder, mostly used for Letter-Trigrams
 * @author Sebastian Arnold 
 */
public class LetterNGramEncoder extends LookupCacheEncoder {
  
  /** the number of chars per gram */
  protected int n;
  
  public LetterNGramEncoder() {
    super("TRI");
    log = LoggerFactory.getLogger(LetterNGramEncoder.class);
  }
  
  public LetterNGramEncoder(String id) {
    super(id);
    log = LoggerFactory.getLogger(LetterNGramEncoder.class);
  }
  
  public LetterNGramEncoder(int n) {
    super("TRI");
    log = LoggerFactory.getLogger(LetterNGramEncoder.class);
    this.n = n;
  }
  
  @Override
  public String getName() {
    return Integer.toString(n)+"-gram Encoder";
  }

  public int getN() {
    return n;
  }
  
  public LetterNGramEncoder setN(int n) {
    this.n = n;
    return this;
  }

  @Override
  public INDArray encode(Span span) {
    return encode(span.getText());
  }

  @Override
  public INDArray encode(String phrase) {
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    List ngrams = generateNGrams(phrase);
    int i;
    for(String ngram : ngrams) {
       i = getIndex(ngram);
       // best results were seen with no normalization and 1.0 instead of word frequency
       if(i>=0) vector.put(i, 0, 1.0);//cache.wordFrequency(ngram));
    }
    return vector;//.div(vector.norm1(0));
  }

  @Override
  public boolean isUnknown(String word) {
    List ngrams = generateNGrams(word);
    for(String ngram : ngrams) {
      if(!vocab.containsWord(ngram)) return true;
    }
    return false;
  }

  /**
   * Reduce the input to keep only neccesary characters. Output is lowercase!
   * (printable ISO-8859-1 and Windows-1252 chars)
   * @param input
   * @return 
   */
  public String keepOnlyPrintableChars(String input) {
    // Letters, Numbers, Punctuation
    // http://www.regular-expressions.info/unicode.html#category
    return input.replaceAll("[^\\p{L}\\p{N}\\p{P}\\p{Sm}\\p{Sc}]", "").toLowerCase();
  }
  
  /**
   * Generates a list of n-grams for a token, surrounded by #.
   * Example token = "cat", n = 3: [#ca,cat,at#]
   * @param token The token to split, as String
   * @param n The number of characters to use
   * @return A list of n-grams
   */
  public List generateNGrams(String token, int n) {
    String word = "#" + keepOnlyPrintableChars(token) + "#";
    List arr = new ArrayList<>();
    for (int start = 0; start <= word.length() - n; start++) {
      arr.add(word.substring(start, start + n));
    }
    return arr;
  }
  
  public List generateNGrams(String token) {
    return generateNGrams(token, n);
  }
  
  public List getTrigramsFromProbabilityVector(INDArray vec) {
    ArrayList result = new ArrayList<>();
    for(int i=0; i 0.5) result.add(getWord(i));
    }
    return result;
  }
  
  @Override
  public void trainModel(Collection documents) {
    trainModel(documents, 1);
  }
  
  public void trainModel(Collection documents, int minWordFrequency) {
    appendTrainLog("Training " + getName() + " model...");
    setModel(null);
    timer.start();
    List ngrams;
    totalWords = 0;
    for(Document doc : documents) {
      for(Token t : doc.getTokens()) {
        ngrams = generateNGrams(t.getText());
        for(String w : ngrams) {
          totalWords++;
          //System.out.println(x + ": " + cache.containsWord(x));
          if(!vocab.containsWord(w)) {
            vocab.addWord(w);
          } else {
            vocab.incrementWordCounter(w);
          }
        }
      }
    }
    int total = vocab.numWords();
    vocab.truncateVocabulary(minWordFrequency); // TODO: truncate
    vocab.updateHuffmanCodes();
    timer.stop();
    appendTrainLog("trained " + vocab.numWords() + " " + n + "-grams (" +  total + " total)", timer.getLong());
    setModelAvailable(true);
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy