All downloads are free. Search and download functionality uses the official Maven repository.

de.datexis.encoder.impl.OneHotEncoder Maven / Gradle / Ivy

package de.datexis.encoder.impl;

import de.datexis.model.Document;
import de.datexis.model.Token;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Span;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.util.ArrayList;
import java.util.Collection;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/**
 * A one-hot encoder
 * @author sarnold
 */
public class OneHotEncoder extends LookupCacheEncoder {
  
  private static final TokenPreProcess preprocessor = new MinimalLowercasePreprocessor();
  
  public OneHotEncoder() {
    super("1H");
    log = LoggerFactory.getLogger(OneHotEncoder.class);
  }
  
  public OneHotEncoder(String id) {
    super(id);
    log = LoggerFactory.getLogger(OneHotEncoder.class);
  }
  
  @Override
  public String getName() {
    return "1-hot Encoder";
  }

  @Override
  public INDArray encode(Span span) {
    return encode(span.getText());
  }

  @Override
  public INDArray encode(String word) {
    INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    String w = preprocessor.preProcess(word);
    int i = vocab.indexOf(w);
    if(i>=0) vector.put(i, 0, 1.0);
    return vector;
  }

  public boolean isUnknown(String word) {
    String w = preprocessor.preProcess(word);
    return !vocab.containsWord(w);
  }
    
  @Override
  public void trainModel(Collection documents) {
    trainModel(documents, 1);
  }

  /**
   * Trains the model for every word in the document
   * @param documents
   * @param minWordFrequency 
   */
  public void trainModel(Collection documents, int minWordFrequency) {
    appendTrainLog("Training " + getName() + " model...");
    setModel(null);
    timer.start();
    String w;
    totalWords = 0;
    for(Document doc : documents) {
      for(Token t : doc.getTokens()) {
        w = preprocessor.preProcess(t.getText());
        totalWords++;
        if(w.isEmpty()) continue;
        if(!vocab.containsWord(w)) {
          vocab.addWord(w);
        } else {
          vocab.incrementWordCounter(w);
        }
      }
    }
    int total = vocab.numWords();
    vocab.truncateVocabulary(minWordFrequency);
    vocab.updateHuffmanCodes();
    timer.stop();
    appendTrainLog("trained " + vocab.numWords() + " words (" +  total + " total)", timer.getLong());
    setModelAvailable(true);
  }
  
  @Override
  public Collection getNearestNeighbours(INDArray v, int n) {
    // create copy
    final Double[] data = new Double[(int)v.length()];
    for(int j=0; j result = new ArrayList<>(n);
    for(int i=0; i max) {
          index = j;
          max = data[j];
          data[j] = Double.MIN_VALUE;
        }
      }
      result.add(getWord(index));
    }
    return result;
	}
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy