All downloads are free. Search and download functionality uses the official Maven repository.

de.datexis.encoder.Encoder Maven / Gradle / Ivy

package de.datexis.encoder;

import com.google.common.collect.Lists;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.annotator.IComponent;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * An Encoder converts text (Span) to embedding vectors (INDArray).
 * E.g. word embedding
 * <p>
 * This base class provides default implementations that encode each Span independently;
 * subclasses whose elements are not independent, are stateful, or support batching are
 * expected to override the respective methods.
 * @author Sebastian Arnold 
 */
public abstract class Encoder extends AnnotatorComponent implements IEncoder, IComponent {

  /**
   * If true, {@link #encodeMatrix(List, int, Class)} reuses vectors that are already attached
   * to a Span for this encoder class and attaches newly encoded vectors to the Span.
   */
  protected boolean enableCache = false;
  
  /**
   * Creates an Encoder with an empty id.
   */
  public Encoder() {
    this("");
  }

  /**
   * Creates an Encoder with the given id.
   * @param id the id assigned to this component
   */
  public Encoder(String id) {
    super(false);
    this.id = id;
  }
  
  /**
   * @return true if vector caching on Spans is enabled (not serialized to JSON)
   */
  @JsonIgnore
  public boolean isCachingEnabled() {
    return enableCache;
  }
  
  /**
   * Enables or disables caching of encoded vectors on Spans.
   * @param enableCache true to enable caching
   */
  public void setCachingEnabled(boolean enableCache) {
    this.enableCache = enableCache;
  }
  
  /**
   * Encode a fixed-size vector from multiple Spans by averaging the vectors of their texts.
   * Spans whose text encodes to null are skipped and do not count towards the average.
   * @param spans the Spans to encode
   * @return INDArray containing all Tokens combined; if no Span could be encoded, the
   *         freshly created (zero-initialized) vector is returned instead of dividing by zero
   */
  public INDArray encode(Iterable<? extends Span> spans) {
    INDArray avg = Nd4j.create(getEmbeddingVectorSize(), 1);
    int count = 0;
    for(Span s : spans) {
      INDArray vec = encode(s.getText());
      if(vec != null) {
        avg.addi(vec);
        count++;
      }
    }
    // Dividing by a count of 0 would turn every entry into NaN, so leave the accumulator untouched.
    return count > 0 ? avg.divi(count) : avg;
  }

  /**
   * Encodes each element in the input and attaches the vectors to the element.
   * Please override this if the elements of your encoders are not independent or stateful.
   * @param input - the Document that should be encoded
   * @param elementClass - the class of sub elements in the Document, e.g. Sentence.class
   * @throws IllegalArgumentException if elementClass is neither Token.class nor Sentence.class
   */
  public void encodeEach(Document input, Class<? extends Span> elementClass) {
    if(elementClass == Token.class) {
      input.streamTokens().forEach(t -> t.putVector(this.getClass(), encode(t)));
    } else if(elementClass == Sentence.class) {
      input.streamSentences().forEach(s -> s.putVector(this.getClass(), encode(s)));
    } else {
      throw new IllegalArgumentException("Cannot encode class " + elementClass + " from Document");
    }
  }

  /**
   * Encodes each element in the input and returns these vectors as matrix.
   * Please override this if the elements of your encoders are not independent or stateful.
   * <p>
   * Each Document is one batch entry; each of its Tokens or Sentences is one time step.
   * Elements at positions beyond maxTimeSteps are ignored. If timeStepClass is neither
   * Token.class nor Sentence.class, no time steps are written and the matrix is returned
   * as created by EncodingHelpers.createTimeStepMatrix.
   * When caching is enabled, vectors already attached to a Span are reused and newly
   * encoded vectors are attached to the Span.
   * @param input - the Documents that should be encoded
   * @param maxTimeSteps - the maximum number of time steps written per Document
   * @param timeStepClass - the class of sub elements in the Document, e.g. Sentence.class
   * @return the matrix created for input.size() examples, the embedding vector size and maxTimeSteps
   */
  public INDArray encodeMatrix(List<Document> input, int maxTimeSteps, Class<? extends Span> timeStepClass) {

    INDArray encoding = EncodingHelpers.createTimeStepMatrix(input.size(), getEmbeddingVectorSize(), maxTimeSteps);

    for(int batchIndex = 0; batchIndex < input.size(); batchIndex++) {
      
      Document example = input.get(batchIndex);

      // copy the requested elements into a random-access list; unsupported classes yield no time steps
      List<? extends Span> spansToEncode = Collections.emptyList();
      if(timeStepClass == Token.class) {
        spansToEncode = Lists.newArrayList(example.getTokens());
      } else if(timeStepClass == Sentence.class) {
        spansToEncode = Lists.newArrayList(example.getSentences());
      }

      for(int t = 0; t < spansToEncode.size() && t < maxTimeSteps; t++) {
        Span span = spansToEncode.get(t);
        INDArray vec;
        if(isCachingEnabled() && span.hasVector(this.getClass())) {
          // use cached vector
          vec = span.getVector(this.getClass());
        } else {
          vec = encode(span);
          if(isCachingEnabled()) {
            span.putVector(this.getClass(), vec);
          }
        }
        EncodingHelpers.putTimeStep(encoding, batchIndex, t, vec);
      }
      
    }
    return encoding;
  }

  /**
   * Encodes each element in the input and attaches the vectors to the element.
   * Please override this if the elements of your encoders are not independent or stateful.
   * Please override this if your encoder allows batches.
   * @param docs - the Documents that should be encoded
   * @param elementClass - the class of sub elements in the Document, e.g. Sentence.class
   * @throws IllegalArgumentException if elementClass is not supported for Documents
   */
  public void encodeEach(Collection<Document> docs, Class<? extends Span> elementClass) {
    for(Document doc : docs) {
      encodeEach(doc, elementClass);
    }
  }
  
  /**
   * Encodes each element in the input and attaches the vectors to the element.
   * Please override this if the elements of your encoders are not independent or stateful.
   * @param input - the Sentence that should be encoded
   * @param elementClass - the class of sub elements in the Sentence, e.g. Token.class
   * @throws IllegalArgumentException if elementClass is not Token.class
   */
  public void encodeEach(Sentence input, Class<? extends Span> elementClass) {
    if(elementClass == Token.class) {
      input.streamTokens().forEach(t -> t.putVector(this.getClass(), encode(t)));
    } else {
      throw new IllegalArgumentException("Cannot encode class " + elementClass + " from Sentence");
    }
  }
    
  /**
   * Trains this encoder's model on the given Documents.
   * @param documents the training Documents
   */
  public abstract void trainModel(Collection<Document> documents);
  
  /**
   * Collects the Stream into a List and delegates to {@link #trainModel(Collection)}.
   * @param documents the training Documents; the Stream is consumed by this call
   */
  public void trainModel(Stream<Document> documents) {
    trainModel(documents.collect(Collectors.toList()));
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy