package de.datexis.encoder.impl;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import de.datexis.common.ObjectSerializer;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.IdentityPreprocessor;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.*;
import java.util.stream.Collectors;

import static org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.fromPair;

/**
 * A Word2Vec model from http://deeplearning4j.org/word2vec.html
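 * <p>Example usage (a minimal sketch; {@code Resource.fromFile} and the file name are
 * assumptions for illustration):</p>
 * <pre>{@code
 * Word2VecEncoder encoder = Word2VecEncoder.load(Resource.fromFile("models", "embedding.txt.gz"));
 * INDArray vector = encoder.encode("example");   // average vector for the given text
 * long size = encoder.getEmbeddingVectorSize();  // dimensionality of the embedding
 * }</pre>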
 * @author Sebastian Arnold 
 */
@JsonIgnoreProperties(ignoreUnknown = true)
//@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "class")
public class Word2VecEncoder extends Encoder {

  private static final Logger log = LoggerFactory.getLogger(Word2VecEncoder.class);

  public enum ModelType { TEXT, BINARY, DL4J, GOOGLE }

  private final static Collection<String> FILENAMES_TEXT = Arrays.asList(".txt", ".txt.gz", ".vec");
  private final static Collection<String> FILENAMES_BINARY = Arrays.asList(".bin", ".bin.gz");
  private final static Collection<String> FILENAMES_DL4J = Arrays.asList(".zip");
  private final static Collection<String> FILENAMES_GOOGLE = Arrays.asList(".zip");
  
  private WordVectors vec;
  private long length;
  private String modelName;
  private boolean saveModelReference = false;
  private TokenPreProcess preprocessor = new IdentityPreprocessor();

  public Word2VecEncoder() {
    super("EMB");
  }
  
  public Word2VecEncoder(String id) {
    super(id);
  }

  public static Word2VecEncoder load(Resource path) throws IOException {
    Word2VecEncoder vec = new Word2VecEncoder();
    vec.loadModel(path);
    return vec;
  }
  
  /**
   * Load a dummy encoder that returns only zeros.
   */
  public static Word2VecEncoder loadDummyEncoder() {
    Word2VecEncoder vec = new Word2VecEncoder();
    Resource txt = Resource.fromJAR("encoder/word2vec.txt");
    try {
      vec.loadModel(txt);
    } catch(IOException e) {
      log.error("Could not load dummy encoder!", e);
    }
    return vec;
  }
  
  public void loadModelAsReference(Resource modelFile) throws IOException {
    loadModel(modelFile);
    saveModelReference = true;
  }
  
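  /**
   * Loads a pre-trained embedding model; the format is detected from the file name
   * (see {@link #getModelType(String)}). A minimal sketch ({@code Resource.fromFile}
   * and the file name are assumptions for illustration):
   * <pre>{@code
   * Word2VecEncoder encoder = new Word2VecEncoder();
   * encoder.loadModel(Resource.fromFile("models", "embedding.bin"));
   * }</pre>
   */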
  @Override
  public void loadModel(Resource modelFile) throws IOException {
    log.info("Loading Word2Vec model: {} with preprocessor {}", modelFile.getFileName(), getPreprocessorClass());
    switch(getModelType(modelFile.getFileName())) {
      default:   // unknown extensions are treated as text models
      case TEXT: vec = WordVectorSerializer.loadStaticModel(modelFile.toFile()); break;
      case BINARY: vec = Word2VecEncoder.loadBinaryModel(modelFile.getInputStream()); break;
      case DL4J: vec = WordVectorSerializer.loadStaticModel(modelFile.toFile()); break;
      case GOOGLE: vec = WordVectorSerializer.loadStaticModel(modelFile.toFile()); break;
    }
    int size = vec.vocab().numWords();
    INDArray example = vec.getWordVectorMatrix(vec.vocab().wordAtIndex(0));
    length = example.length();
    setModel(modelFile);
    setModelAvailable(true);
    log.info("Loaded Word2Vec model '{}' with {} vectors of size {}", modelFile.getFileName(), size, length);
  }
  
  @Override
  public void saveModel(Resource modelPath, String name) throws IOException {
    saveModel(modelPath, name, ModelType.BINARY);
  }
  
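  /**
   * Saves the model in the given format, together with a {@code config.json}, into the
   * given path. A minimal sketch (directory and model name are examples only):
   * <pre>{@code
   * encoder.saveModel(modelPath, "word2vec", ModelType.TEXT); // writes word2vec.txt.gz
   * }</pre>
   */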
  public void saveModel(Resource modelPath, String name, ModelType type) throws IOException {
    // rely on AnnotatorFactory to find the model in the search path
    if(saveModelReference) return;
    // TODO: we also need to save the input Token Preprocessor!
    Resource modelFile;
    ObjectSerializer.writeJSON(this, modelPath.resolve("config.json"));
    switch(type) {
      default:
      case BINARY: {
        modelFile = modelPath.resolve(name + ".bin");
        Word2VecEncoder.writeBinaryModel(vec, modelFile.getOutputStream());
      } break;
      case TEXT: {
        modelFile = modelPath.resolve(name + ".txt.gz");
        WordVectorSerializer.writeWordVectors((Word2Vec) vec, modelFile.getGZIPOutputStream());
      } break;
      case DL4J: {
        modelFile = modelPath.resolve(name+".zip");
        WordVectorSerializer.writeWord2VecModel((Word2Vec) vec, modelFile.getOutputStream());
      } break;
      case GOOGLE: {
        modelFile = null;
        log.error("Cannot write Google Model");
      } break;
    }
    setModel(modelFile);
  }

  public void setPreprocessor(TokenPreProcess preprocessor) {
    this.preprocessor = preprocessor;
  }
  
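  /**
   * Trains a new Word2Vec model on all sentences of the given documents, using default
   * hyperparameters (batch size 16, window 10, min word frequency 3, 256 dimensions,
   * 5 iterations, 1 epoch). A minimal sketch:
   * <pre>{@code
   * Word2VecEncoder encoder = new Word2VecEncoder();
   * encoder.trainModel(documents); // documents is a Collection<Document>
   * }</pre>
   */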
  @Override
  public void trainModel(Collection<Document> documents) {
    int batchSize = 16;
    int windowSize = 10;
    int minWordFrequency = 3;
    int layerSize = 256;
    int iterations = 5;
    int epochs = 1;
    trainModel(documents.stream().flatMap(d -> d.streamSentences()).collect(Collectors.toList()),
      batchSize, windowSize, minWordFrequency, layerSize, iterations, epochs, new ArrayList<>());
  }
  
  /*public void trainModel(Iterable sentences) {
    int batchSize = 1000;
    int windowSize = 5;
    int minWordFrequency = 2;
    int layerSize = 150;
    int iterations = 1;
    int epochs = 1;
    trainModel(sentences, batchSize, windowSize, minWordFrequency, layerSize, iterations, epochs, new ArrayList());
  }*/
  
  public void trainModel(Iterable<Sentence> sentences, int batchSize, int windowSize, int minWordFrequency, int layerSize, int iterations, int epochs, List<String> stopWords) {
    SentenceIterator iter = new SentenceStringIterator(sentences);
    trainModel(iter, batchSize, windowSize, minWordFrequency, layerSize, iterations, epochs, stopWords);
  }
  
  public void trainModel(SentenceIterator iter, int batchSize, int windowSize, int minWordFrequency, int layerSize, int iterations, int epochs, List<String> stopWords) {
  
    //CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true).setMemoryModel(MemoryModel.DELAYED);
    // PLEASE NOTE: For CUDA FP16 precision support is available
//    DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);

    // temp workaround for backend initialization
    Nd4j.create(1);

    /*CudaEnvironment.getInstance().getConfiguration()
        // key option enabled
        .allowMultiGPU(true)

        // we're allowing larger memory caches
        .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)

        // cross-device access is used for faster model averaging over pcie
        .allowCrossDeviceAccess(true);
    */
    
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(preprocessor);

    log.info("Building model....");
    vec = new org.deeplearning4j.models.word2vec.Word2Vec.Builder()
            .batchSize(batchSize) // number of words per minibatch
            .windowSize(windowSize)
            .minWordFrequency(minWordFrequency)
            .useAdaGrad(false)
            .layerSize(layerSize) // word feature vector size
            .seed(42)
            .iterations(iterations) // number of iterations to train
            .epochs(epochs)
            .stopWords(stopWords)
            .learningRate(0.025)
            .minLearningRate(0.001) // learning rate decays wrt # words; this is the floor
            .negativeSample(10) // negative sampling with 10 samples
            .iterate(iter)
            .tokenizerFactory(t)
            .build();

    log.info("Fitting Word2Vec model....");
    ((org.deeplearning4j.models.word2vec.Word2Vec) vec).fit();
  }

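  /**
   * Detects the model format from a file name, e.g. {@code getModelType("model.txt.gz")}
   * returns {@code TEXT} and {@code getModelType("model.bin")} returns {@code BINARY};
   * unknown extensions default to {@code TEXT}.
   */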
  public static ModelType getModelType(String filename) {
    String name = filename.toLowerCase();
    if(FILENAMES_TEXT.stream().anyMatch(ext -> name.endsWith(ext))) return ModelType.TEXT;
    else if(FILENAMES_BINARY.stream().anyMatch(ext -> name.endsWith(ext))) return ModelType.BINARY;
    else if(FILENAMES_DL4J.stream().anyMatch(ext -> name.endsWith(ext))) return ModelType.DL4J;
    // note: GOOGLE can never be returned, because ".zip" already matches the DL4J extensions above
    else if(FILENAMES_GOOGLE.stream().anyMatch(ext -> name.endsWith(ext))) return ModelType.GOOGLE;
    else return ModelType.TEXT;
  }
  
  public Class<? extends TokenPreProcess> getPreprocessorClass() {
    return preprocessor.getClass();
  }

  public void setPreprocessorClass(String preprocessor) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
    Class<?> clazz = Class.forName(preprocessor);
    this.preprocessor = (TokenPreProcess) clazz.newInstance();
  }
  
  @Override
  public String getName() {
    return modelName;
  }

  /**
   * Use this function to access word vectors.
   * @param word the word to look up (run through the preprocessor before lookup)
   * @return the word's embedding vector, or null if the word is not in the vocabulary
   */
  private INDArray getWordVector(String word) {
    return vec.getWordVectorMatrix(preprocessor.preProcess(word));
  }

  public boolean isUnknown(String word) {
    return !vec.hasWord(preprocessor.preProcess(word));
  }

  @Override
  public INDArray encode(Span span) {
    if(span instanceof Token) return encode(preprocessor.preProcess(span.getText()));
    else return encode(span.getText());
  }

  @Override
  public long getEmbeddingVectorSize() {
    return length;
  }

  /**
   * Encodes a word or whitespace-separated phrase as the average of its word vectors.
   * Returns a zero vector if no word was found.
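   * <p>Example: {@code encode("New York")} returns the mean of the vectors for
   * "New" and "York"; words missing from the vocabulary still count towards the divisor.</p>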
   * @param word the word or phrase to encode
   * @return the averaged embedding vector, or a zero vector if no word was found
   */
  @Override
  public INDArray encode(String word) {
    INDArray sum = Nd4j.zeros(getEmbeddingVectorSize(), 1);
    int len = 0;
    for(String w : WordHelpers.splitSpaces(word)) {
      if(w.trim().isEmpty()) continue;
      INDArray arr = vec.getWordVectorMatrix(preprocessor.preProcess(w));
      if(arr != null) sum.addi(arr.transpose());
      len++; // unknown words still count towards the divisor
    }
    return len == 0 ? sum : sum.div(len);
  }

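  /**
   * Returns the k words whose embeddings are closest to the given word in the model's
   * vector space, e.g. {@code getNearestNeighbours("king", 5)}.
   */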
  public Collection<String> getNearestNeighbours(String word, int k) {
    return vec.wordsNearest(preprocessor.preProcess(word), k);
  }

  public Collection<String> getNearestNeighbours(INDArray v, int k) {
    Counter<String> distances = new Counter<>();
    for(Object s : vec.vocab().words()) {
      String word = (String) s;
      INDArray otherVec = encode(word);
      double sim = Transforms.cosineSim(v, otherVec);
      distances.incrementCount(word, sim);
    }
    distances.keepTopNElements(k);
    return distances.keySetSorted();
  }

  public String getNearestNeighbour(INDArray v) {
    Collection<String> result = getNearestNeighbours(v, 1);
    if(result.isEmpty()) return "_";
    else return result.iterator().next();
  }

  /**
   * Writes the model to DATEXIS binary format: a sequence of records, each a UTF-encoded
   * word followed by its ND4J-serialized vector.
   * @param vec the word vectors to write
   * @param outputStream the stream to write to (closed by this method)
   */
  private static void writeBinaryModel(WordVectors vec, OutputStream outputStream) throws IOException {
    
    int words = 0;
    
    try(BufferedOutputStream buf = new BufferedOutputStream(outputStream);
         DataOutputStream writer = new DataOutputStream(buf)) {
      for(Object word : vec.vocab().words()) {
        if(word == null) continue;
        INDArray wordVector = vec.getWordVectorMatrix((String) word);
        log.trace("Write: " + word + " (size " + wordVector.length() + ")");
        writer.writeUTF((String) word);
        Nd4j.write(wordVector, writer);
        words++;
      }
      writer.flush();
    }
    
    log.info("Wrote " + words + " words with size " + vec.vectorSize());
    
  }
  
  /**
   * Loads the model from DATEXIS binary format (see {@link #writeBinaryModel}).
   * @param stream the stream to read from (closed by this method)
   * @return the loaded word vectors
   */
  private static WordVectors loadBinaryModel(InputStream stream) throws IOException {
    
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    List<INDArray> arrays = new ArrayList<>();
    int words = 0;
    
    try(BufferedInputStream buf = new BufferedInputStream(stream); 
         DataInputStream reader = new DataInputStream(buf)) {
      while(reader.available() > 0) {
        String word = reader.readUTF();
        INDArray row = Nd4j.read(reader);
        VocabWord vocabWord = new VocabWord(1.0, word);
        vocabWord.setIndex(cache.numWords());
        cache.addToken(vocabWord);
        cache.addWordToIndex(vocabWord.getIndex(), word);
        cache.putVocabWord(word);
        arrays.add(row);
        words++;
      }
    }

    log.info("Read " + words + " words with size " + arrays.get(0).columns());

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(arrays.get(0).columns())
            .cache(cache)
            .build();

    INDArray syn = Nd4j.vstack(arrays);

    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
    
  }
  
 /**
  * A simple sentence iterator used for Word2Vec training; yields the raw text of each Sentence.
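  * <p>Usage sketch (a minimal example; {@code sentences} stands for any {@code Iterable<Sentence>}):</p>
  * <pre>{@code
  * SentenceIterator it = new SentenceStringIterator(sentences);
  * while(it.hasNext()) System.out.println(it.nextSentence());
  * }</pre>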
  * @author sarnold
  */
 public class SentenceStringIterator implements SentenceIterator {

   private Iterator<Sentence> it;
   private Iterable<Sentence> sentences;
   private SentencePreProcessor spp;

   public SentenceStringIterator(Iterable<Sentence> sentences) {
     this.sentences = sentences;
     reset();
   }

   @Override
   public String nextSentence() {
     String text = it.next().getText();
     return spp == null ? text : spp.preProcess(text);
   }

   @Override
   public boolean hasNext() {
     return it.hasNext();
   }

   @Override
   public void reset() {
     it = sentences.iterator();
   }

   @Override
   public void finish() {
     // no resources to release
   }

   @Override
   public SentencePreProcessor getPreProcessor() {
     return this.spp;
   }

   @Override
   public void setPreProcessor(SentencePreProcessor spp) {
     this.spp = spp;
   }

 }

}