All downloads are free. Search and download functionality uses the official Maven repository.

de.datexis.encoder.EncoderSet Maven / Gradle / Ivy

package de.datexis.encoder;

import de.datexis.model.Span;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.LoggerFactory;

/**
 * A set of Encoders whose output vectors are concatenated into a single input vector.
 *
 * <p>Encoders are applied in insertion order: the first Encoder's vector occupies the first
 * entries of the result, the second Encoder's vector follows directly after it, and so on.
 *
 * <p>The combined vector size is cached when Encoders are added. If an Encoder's vector size
 * changes afterwards (e.g. an Encoder that was added while still uninitialized), call
 * {@link #updateVectorSize()} before encoding so the cached size matches again.
 *
 * @author sarnold
 */
public class EncoderSet implements Iterable<Encoder>, IEncoder {
  
  protected static final org.slf4j.Logger log = LoggerFactory.getLogger(EncoderSet.class);
  
  /** The Encoders of this set, in concatenation order. */
  protected List<Encoder> encoders;
  
  /** Cached sum of the vector sizes of all Encoders; refreshed by {@link #updateVectorSize()}. */
  protected int size;
  
  /**
   * Creates a set containing the given Encoders in the given order.
   *
   * @param encoders the Encoders to concatenate; may be empty
   * @throws NullPointerException if any of the given Encoders is null
   */
  public EncoderSet(Encoder... encoders) {
    this.encoders = new ArrayList<>(encoders.length);
    this.size = 0;
    for(Encoder enc : encoders) {
      addEncoder(enc);
    }
  }
  
  /**
   * Appends an Encoder to the end of this set and adds its current vector size to the cached
   * total. Logs a warning if the Encoder reports a vector size of 0.
   *
   * <p>This method is {@code final} because it is called from the constructor.
   *
   * @param e the Encoder to add
   * @throws NullPointerException if {@code e} is null; the set is left unchanged in that case
   */
  public final void addEncoder(Encoder e) {
    // Validate before mutating, so a null argument cannot leave a null entry in the list.
    Objects.requireNonNull(e, "encoder");
    final long encSize = e.getEmbeddingVectorSize();
    if(encSize == 0) {
      log.warn("Adding uninitialized Encoder {}", e.getName());
    }
    encoders.add(e);
    this.size += encSize;
  }
  
  /**
   * Recalculates the cached vector size from the current sizes of all Encoders.
   * Call this after an Encoder in this set has changed its vector size.
   */
  public void updateVectorSize() {
    // Sum into a local first, so the cached size is only replaced once the sum is complete.
    int total = 0;
    for(Encoder enc : this.encoders) {
      total += enc.getEmbeddingVectorSize();
    }
    this.size = total;
  }
  
  /**
   * Returns the cached combined vector size of all Encoders in this set.
   *
   * @return the sum of the Encoders' vector sizes as of the last add/update
   */
  @Override
  public long getEmbeddingVectorSize() {
    return size;
  }
  
  /**
   * Returns the Encoders of this set as an {@link Iterable}.
   *
   * <p>Note: this is the backing list itself, not a copy. Changes made through it are not
   * reflected in the cached vector size until {@link #updateVectorSize()} is called.
   *
   * @return the Encoders in concatenation order
   */
  public Iterable<Encoder> iterable() {
    return encoders;
  }

  /**
   * Returns an iterator over the Encoders of this set, in concatenation order.
   *
   * <p>Note: {@code remove()} on the returned iterator modifies this set without updating the
   * cached vector size; call {@link #updateVectorSize()} afterwards.
   */
  @Override
  public Iterator<Encoder> iterator() {
    return encoders.iterator();
  }
  
  /**
   * Encodes a given String using all Encoders and concatenates the results.
   * Does not save the intermediate results to the Tokens.
   *
   * @param word the String to encode
   * @return a vector of length {@link #getEmbeddingVectorSize()}
   */
  public INDArray encode(String word) {
    return concatenate(enc -> enc.encode(word));
  }
  
  /**
   * Encodes the given Spans using all Encoders and concatenates the results.
   *
   * @param spans the Spans to encode
   * @return a vector of length {@link #getEmbeddingVectorSize()}
   */
  public INDArray encode(Iterable<? extends Span> spans) {
    return concatenate(enc -> enc.encode(spans));
  }
  
  /**
   * Encodes a single Span by encoding its {@code toString()} representation with
   * {@link #encode(String)}.
   * NOTE(review): relies on {@code Span.toString()} yielding the text to encode — confirm.
   */
  @Override
  public INDArray encode(Span span) {
    return encode(span.toString());
  }
  
  /**
   * Applies {@code encodeFn} to every Encoder in order and writes each resulting vector into the
   * next consecutive slice of a freshly allocated result array.
   */
  private INDArray concatenate(Function<Encoder, INDArray> encodeFn) {
    final INDArray result = Nd4j.create(getEmbeddingVectorSize());
    long offset = 0; // long, because Encoder sizes are long; avoids silent int narrowing
    for(Encoder enc : encoders) {
      final INDArray vec = encodeFn.apply(enc);
      final long encSize = enc.getEmbeddingVectorSize();
      result.get(NDArrayIndex.interval(offset, offset + encSize)).assign(vec);
      offset += encSize;
    }
    return result;
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy