All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This may take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
The project price is only $1.
You can buy this project and download/modify it how often you want.
de.datexis.encoder.impl.BagOfWordsEncoder Maven / Gradle / Ivy
package de.datexis.encoder.impl;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.MinimalLowercaseNewlinePreprocessor;
import org.apache.commons.math3.util.Pair;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
* A Bag-Of-Words N-Hot Encoder with stopword and minFreq training
* @author sarnold
*/
public class BagOfWordsEncoder extends LookupCacheEncoder {
// Token normalizer applied to every word before vocabulary lookup and training
protected TokenPreProcess preprocessor = new MinimalLowercaseNewlinePreprocessor();
// Stopword helper for the configured language; (re)created in setLanguage()
protected WordHelpers wordHelpers;
// Language used for stopword filtering; set during training or via setLanguage()
protected WordHelpers.Language language;
/**
 * Creates an encoder with the default identifier "BOW".
 */
public BagOfWordsEncoder() {
this("BOW");
}
/**
 * Creates an encoder with a custom identifier and an empty vocabulary.
 * @param id identifier passed to the LookupCacheEncoder base class
 */
public BagOfWordsEncoder(String id) {
super(id);
log = LoggerFactory.getLogger(BagOfWordsEncoder.class);
// start with an empty vocabulary; filled by one of the trainModel() overloads
vocab = new VocabularyHolder.Builder().build();
}
/**
 * Returns the runtime class of the configured token preprocessor.
 * Uses a bounded wildcard instead of the raw {@code Class} type.
 * @return the concrete TokenPreProcess implementation class in use
 */
public Class<? extends TokenPreProcess> getPreprocessorClass() {
  return preprocessor.getClass();
}
/**
 * Loads a TokenPreProcess implementation by fully qualified class name and
 * installs a new instance of it as this encoder's preprocessor.
 * Restores the wildcard type that was garbled to {@code Class>} in the source.
 * @param preprocessor fully qualified class name of a TokenPreProcess implementation
 * @throws ClassNotFoundException if the named class cannot be found
 * @throws IllegalAccessException if the no-arg constructor is not accessible
 * @throws InstantiationException if the class cannot be instantiated (e.g. abstract)
 */
public void setPreprocessorClass(String preprocessor) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
  Class<?> clazz = Class.forName(preprocessor);
  // newInstance() is kept (despite deprecation in newer JDKs) because the
  // declared checked exceptions match this method's signature
  this.preprocessor = (TokenPreProcess) clazz.newInstance();
}
/**
 * Returns the token preprocessor instance; excluded from JSON serialization.
 */
@JsonIgnore
public TokenPreProcess getPreprocessor() {
return preprocessor;
}
/**
 * Replaces the token preprocessor used for all lookups and training.
 * @param preprocessor the preprocessor instance to use
 */
public void setPreprocessor(TokenPreProcess preprocessor) {
this.preprocessor = preprocessor;
}
/**
 * Human-readable name of this encoder.
 */
@Override
public String getName() {
return "Bag-of-words Encoder";
}
/**
 * Trains the vocabulary from all Tokens of the given Documents using the
 * defaults: minimum word frequency 1 and English stopword filtering.
 * Restores the {@code Collection<Document>} type parameter stripped by extraction.
 * @param documents documents whose tokens populate the vocabulary
 */
@Override
public void trainModel(Collection<Document> documents) {
  trainModel(documents, 1, WordHelpers.Language.EN);
}
/**
 * Trains the vocabulary from all Tokens of the given Documents.
 * Preprocesses every token, skips empty results and stopwords, counts word
 * frequencies and finally truncates the vocabulary to the given minimum frequency.
 * Restores the {@code Collection<Document>} type parameter stripped by extraction.
 * @param documents documents whose tokens populate the vocabulary
 * @param minWordFrequency words occurring fewer times than this are dropped
 * @param language language used for stopword filtering
 */
public void trainModel(Collection<Document> documents, int minWordFrequency, WordHelpers.Language language) {
  appendTrainLog("Training " + getName() + " model...");
  setModel(null);
  totalWords = 0;
  timer.start();
  setLanguage(language);
  for(Document doc : documents) {
    for(Token t : doc.getTokens()) {
      String w = preprocessor.preProcess(t.getText());
      if(!w.isEmpty()) {
        totalWords++;
        if(!wordHelpers.isStopWord(w)) {
          // first occurrence adds the word, later ones only bump the counter
          if(!vocab.containsWord(w)) vocab.addWord(w);
          else vocab.incrementWordCounter(w);
        }
      }
    }
  }
  int total = vocab.numWords();
  vocab.truncateVocabulary(minWordFrequency);
  vocab.updateHuffmanCodes();
  timer.stop();
  appendTrainLog("trained " + vocab.numWords() + " words (" + total + " total)", timer.getLong());
  setModelAvailable(true);
}
/**
 * Trains the vocabulary from raw sentence Strings, splitting at spaces.
 * In addition to stopword filtering, words shorter than minWordLength are skipped.
 * Restores the {@code Iterable<String>} type parameter stripped by extraction
 * (the raw form would not compile against {@code for(String s : sentences)}).
 * @param sentences sentences to tokenize at whitespace
 * @param minWordFrequency words occurring fewer times than this are dropped
 * @param minWordLength minimum length (after preprocessing) for a word to count
 * @param language language used for stopword filtering
 */
public void trainModel(Iterable<String> sentences, int minWordFrequency, int minWordLength, WordHelpers.Language language) {
  appendTrainLog("Training " + getName() + " model...");
  setModel(null);
  totalWords = 0;
  timer.start();
  setLanguage(language);
  for(String s : sentences) {
    for(String t : WordHelpers.splitSpaces(s)) {
      String w = preprocessor.preProcess(t);
      if(!w.isEmpty()) {
        totalWords++;
        if(!wordHelpers.isStopWord(w) && w.length() >= minWordLength) {
          // first occurrence adds the word, later ones only bump the counter
          if(!vocab.containsWord(w)) vocab.addWord(w);
          else vocab.incrementWordCounter(w);
        }
      }
    }
  }
  int total = vocab.numWords();
  vocab.truncateVocabulary(minWordFrequency);
  vocab.updateHuffmanCodes();
  timer.stop();
  appendTrainLog("trained " + vocab.numWords() + " words (" + total + " total)", timer.getLong());
  setModelAvailable(true);
}
/**
 * Tests vocabulary membership on the preprocessed form of the word.
 */
@Override
public boolean isUnknown(String word) {
  final String normalized = preprocessor.preProcess(word);
  return super.isUnknown(normalized);
}
/**
 * Looks up the vocabulary index of the preprocessed form of the word.
 */
@Override
public int getIndex(String word) {
  final String normalized = preprocessor.preProcess(word);
  return super.getIndex(normalized);
}
/**
 * Looks up the training frequency of the preprocessed form of the word.
 */
@Override
public int getFrequency(String word) {
  final String normalized = preprocessor.preProcess(word);
  return super.getFrequency(normalized);
}
/**
 * Looks up the corpus probability of the preprocessed form of the word.
 */
@Override
public double getProbability(String word) {
  final String normalized = preprocessor.preProcess(word);
  return super.getProbability(normalized);
}
/**
 * Returns the language this encoder was configured/trained with.
 */
public WordHelpers.Language getLanguage() {
return language;
}
/**
 * Sets the language and rebuilds the stopword helper accordingly.
 * @param language language used for stopword filtering
 */
public void setLanguage(WordHelpers.Language language) {
this.language = language;
wordHelpers = new WordHelpers(language);
}
/**
* Encode a list of Tokens into an n-hot vector
* @param spans
* @return
*/
@Override
public INDArray encode(Iterable extends Span> spans) {
INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
int i;
// best results were seen with no normalization and 1.0 instead of word frequency
for(Span s : spans) {
i = getIndex(s.getText());
if(i>=0) vector.put(i, 0, 1.0);
}
return vector;
}
/**
 * Encode an array of words into an n-hot vector: every word found in the
 * vocabulary sets its component to 1.0; unknown words are ignored.
 */
protected INDArray encode(String[] words) {
  // best results were seen with no normalization and 1.0 instead of word frequency
  INDArray result = Nd4j.zeros(getEmbeddingVectorSize(), 1);
  for(String word : words) {
    final int index = getIndex(word);
    if(index >= 0) result.put(index, 0, 1.0);
  }
  return result;
}
/**
 * Encodes a single Span: a Token is encoded on its own, a Sentence via its
 * Tokens, any other Span via its raw text split at spaces.
 */
@Override
public INDArray encode(Span span) {
  if(span instanceof Token) {
    return encode(Collections.singletonList(span));
  }
  if(span instanceof Sentence) {
    return encode(((Sentence) span).getTokens());
  }
  return encode(span.getText());
}
/**
 * Encode a phrase, splitting at spaces.
 * @param phrase whitespace-separated words to encode
 * @return n-hot vector over the vocabulary
 */
@Override
public INDArray encode(String phrase) {
return encode(WordHelpers.splitSpaces(phrase));
}
/**
 * Tokenizes the String and encodes exactly one of its words as a one-hot
 * vector, chosen by weighted random sampling: rarer in-vocabulary words get a
 * higher sampling weight. Returns a zero vector if no candidate word is found.
 * Restores the {@code Pair<String, Double>} generics stripped by extraction
 * (the raw form would not compile at {@code countWeight += item.getValue()}).
 * @param phrase whitespace-separated words
 * @return one-hot vector for the sampled word, or all zeroes
 */
public INDArray encodeSubsampled(String phrase) {
  INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
  String[] tokens = WordHelpers.splitSpaces(phrase);
  if(tokens.length == 1) return encode(tokens[0]);
  List<Pair<String, Double>> itemWeights = new ArrayList<>(5);
  double completeWeight = 0.0;
  String w;
  for(String t : tokens) {
    w = preprocessor.preProcess(t);
    if(!w.isEmpty() && !wordHelpers.isStopWord(w)) {
      final double weight = samplingRate(super.getProbability(w));
      // samplingRate(0) == 1.0, so a weight of exactly 1 marks an out-of-vocab word
      if(weight == 1.) continue; // word not in vocab
      completeWeight += weight;
      itemWeights.add(new Pair<>(w, weight));
    }
  }
  // roulette-wheel selection over the accumulated weights
  double r = Math.random() * completeWeight;
  double countWeight = 0.0;
  for(Pair<String, Double> item : itemWeights) {
    countWeight += item.getValue();
    if(countWeight >= r) {
      int i = getIndex(item.getKey());
      if(i >= 0) vector.put(i, 0, 1.0);
      return vector;
    }
  }
  return vector; // return zeroes
}
/**
 * Returns the confidence (vector component) at index i.
 * @param v encoded vector
 * @param i vocabulary index
 */
public double getConfidence(INDArray v, int i) {
return v.getDouble(i);
}
/**
 * Returns the maximum component of the vector (max over dimension 0,
 * collapsed to a scalar via sumNumber()).
 */
public double getMaxConfidence(INDArray v) {
return v.max(0).sumNumber().doubleValue();
}
/**
 * Collects the preprocessed forms of all in-vocabulary Tokens into a Set;
 * unknown words are skipped and duplicates collapse.
 * Restores the {@code Set<String>}/{@code Iterable<Token>} type parameters
 * stripped by extraction (the raw form would not compile).
 * @param tokens tokens to normalize and filter against the vocabulary
 * @return distinct preprocessed, known words
 */
public Set<String> asString(Iterable<Token> tokens) {
  Set<String> result = new HashSet<>();
  for(Token t : tokens) {
    if(!isUnknown(t.getText())) result.add(preprocessor.preProcess(t.getText()));
  }
  return result;
}
/**
 * Returns the single word with the highest component in the given vector,
 * or null if no neighbour is found.
 * Restores the {@code Collection<String>} type parameter stripped by
 * extraction (raw {@code iterator().next()} would not compile as String).
 */
@Override
public String getNearestNeighbour(INDArray v) {
  Collection<String> knn = getNearestNeighbours(v, 1);
  if(knn.isEmpty()) return null;
  else return knn.iterator().next();
}
@Override
public Collection getNearestNeighbours(INDArray v, int n) {
// find maximum entries
INDArray[] sorted = Nd4j.sortWithIndices(v.dup(), 0, false); // index,value
if(sorted[0].sumNumber().doubleValue() == 0.) // TODO: sortWithIndices could be run on -1 / 0 / 1 ?
log.warn("NearestNeighbour on zero vector - please check vector alignment!");
INDArray idx = sorted[0]; // ranked indexes
// get top n
ArrayList result = new ArrayList<>(n);
for(int i=0; i 0.) result.add(getWord(idx.getInt(i)));
}
return result;
}
/**
 * Randomly decides whether to keep a word, with probability samplingRate(word);
 * frequent words are kept less often.
 */
public boolean keepWord(String word) {
return(Math.random() < samplingRate(word));
}
/**
* Sets words in a given target to 0 based on probabilities.
* http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
*/
public INDArray subsample(INDArray target) {
INDArray result = target.dup();
for(int i=0; i 0) {
if(!keepWord(getWord(i))) result.putScalar(i, 0.);
}
}
return result;
}
/**
 * word2vec-style subsampling rate for a word: (sqrt(p/0.001)+1)*(0.001/p).
 * Frequent words get a rate below 1, rare words above 1.
 * NOTE(review): if getProbability() returns 0 for unknown words, this yields
 * Infinity (word always kept) — confirm against LookupCacheEncoder.
 */
protected double samplingRate(String word) {
double p = getProbability(word);
return (Math.sqrt(p / 0.001) + 1) * (0.001 / p);
}
/**
 * Simplified subsampling rate 0.001/(0.001+p): equals 1.0 for p == 0
 * (used by encodeSubsampled to detect out-of-vocabulary words) and decreases
 * toward 0 as p grows.
 */
protected double samplingRate(double p) {
//return (Math.sqrt(p / 0.001) + 1) * (0.001 / p);
return 0.001 / (0.001 + p);
}
@JsonIgnore
public INDArray subsampleWeights() {
INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
String w;
for(int i=0; i