de.datexis.sector.encoder.ClassEncoder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of texoo-sector Show documentation
Show all versions of texoo-sector Show documentation
TeXoo module for Topic Segmentation and Classification (SECTOR)
package de.datexis.sector.encoder;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.preprocess.LowercasePreprocessor;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map.Entry;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;
/**
*
* @author Sebastian Arnold
*/
public class ClassEncoder extends LookupCacheEncoder {
private static final TokenPreProcess preprocessor = new LowercasePreprocessor();
public static final String ID = "CLS";
public ClassEncoder() {
super(ID);
log = LoggerFactory.getLogger(ClassEncoder.class);
}
public ClassEncoder(String id) {
super(id);
log = LoggerFactory.getLogger(ClassEncoder.class);
}
@Override
public String getName() {
return "Classification Encoder";
}
@Override
public INDArray encode(Span classLabel) {
return encode(classLabel.getText());
}
@Override
public long getEmbeddingVectorSize() {
return vocab.numWords();
}
/*@Override
public String getWord(int index) {
String word = super.getWord(index);
return (word != null ? word : "unknown");
}*/
@Override
public INDArray encode(String classLabel) {
return oneHot(classLabel);
}
public int getIndex(String word) {
String w = preprocessor.preProcess(word);
return vocab.indexOf(w);
}
public INDArray oneHot(String word) {
INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
int i = getIndex(word);
if(i>=0) vector.put(i, 0, 1.0);
else log.warn("could not encode class '{}'. is it contained in training set?", word);
return vector;
}
public boolean isUnknown(String classLabel) {
String w = preprocessor.preProcess(classLabel);
return !vocab.containsWord(w);
}
@Override
public void trainModel(Collection documents) {
throw new UnsupportedOperationException("cannot train classification on Documents");
}
public void trainModelUsingHead(Iterable classes) {
trainModel(classes, 0);
// stop after head of distribution (mean value reached)
double val = 0;
for(VocabularyWord word : vocab.words()) {
val += word.getCount();
}
vocab.truncateVocabulary((int)(val / vocab.numWords()));
vocab.updateHuffmanCodes();
appendTrainLog("truncated to " + vocab.numWords() + " classes");
}
public void trainModel(Iterable classes, int minClassFrequency) {
appendTrainLog("Training " + getName() + " model...");
setModel(null);
timer.start();
String w;
totalWords = 0;
for(String s : classes) {
w = preprocessor.preProcess(s);
totalWords++;
if(w.isEmpty()) continue;
if(!vocab.containsWord(w)) {
vocab.addWord(w);
} else {
vocab.incrementWordCounter(w);
}
}
int total = vocab.numWords();
vocab.truncateVocabulary(minClassFrequency);
vocab.updateHuffmanCodes();
timer.stop();
appendTrainLog("trained " + vocab.numWords() + " classes (" + total + " total)", timer.getLong());
setModelAvailable(true);
}
@Override
public String getNearestNeighbour(INDArray v) {
Collection knn = getNearestNeighbours(v, 1);
if(knn.isEmpty()) return null;
else return knn.iterator().next();
}
@Override
public Collection getNearestNeighbours(INDArray v, int k) {
// create copy
final Double[] data = new Double[(int) v.length()];
for(int j=0; j result = new ArrayList<>(k);
for(int i=0; i max) {
index = j;
max = data[j];
data[j] = Double.MIN_VALUE;
}
}
result.add(getWord(index));
}
return result;
}
public Collection> getNearestNeighbourEntries(INDArray v, int k) {
// create copy
final Double[] data = new Double[(int) v.length()];
for(int j=0; j> result = new ArrayList<>(k);
for(int i=0; i max) {
index = j;
max = data[j];
data[j] = Double.MIN_VALUE;
}
}
result.add(new AbstractMap.SimpleEntry<>(getWord(index),v.getDouble(index)));
}
return result;
}
}