All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.datexis.retrieval.encoder.LSTMSentenceAnnotator Maven / Gradle / Ivy

package de.datexis.retrieval.encoder;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.IEncoder;
import de.datexis.retrieval.tagger.LSTMSentenceTagger;
import de.datexis.retrieval.tagger.ModelBuilder;
import de.datexis.tagger.Tagger;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This Annotator capsules a Sentence Embedding
 * @author Sebastian Arnold 
 */
public class LSTMSentenceAnnotator extends Annotator {
  
  protected final Logger log = LoggerFactory.getLogger(getClass());
  
  /** used by XML deserialization */
  public LSTMSentenceAnnotator() {
    super();
  }
  
  public LSTMSentenceAnnotator(Tagger root) {
    super(root);
  }
  
  protected LSTMSentenceAnnotator(AnnotatorComponent comp) {
    super(comp);
  }
  
  @Override
  public LSTMSentenceTagger getTagger() {
    return (LSTMSentenceTagger) super.getTagger();
  }
  
  public void trainModel(Resource trainingSentences) {
    getTagger().trainModel(trainingSentences);
  }
  
  public LSTMSentenceEncoder asEncoder() {
    return new LSTMSentenceEncoder(getTagger());
  }
  
  public static class Builder {
    
    LSTMSentenceAnnotator ann;
    LSTMSentenceTagger tagger;
    
    protected ILossFunction lossFunc = LossFunctions.LossFunction.MCXENT.getILossFunction();
    protected Activation activation = Activation.SOFTMAX;
    IEncoder inputEncoder, targetEncoder;
    
    private int examplesPerEpoch = -1;
    private int maxTimeSeriesLength = -1;
    private int lstmLayerSize = 256;
    private int embeddingLayerSize = 128;
    private double learningRate = 0.01;
    private double dropOut = 0.5;
    private int iterations = 1;
    private int batchSize = 16; // number of Examples until Sample/Test
    private int numEpochs = 1;
    
    private boolean enabletrainingUI = false;
    
    public Builder() {
      tagger = new LSTMSentenceTagger();
      ann = new LSTMSentenceAnnotator(tagger);
    }
    
    public Builder withId(String id) {
      this.tagger.setId(id);
      return this;
    }
  
    public Builder withDataset(String datasetName, WordHelpers.Language lang) {
      ann.getProvenance().setDataset(datasetName);
      ann.getProvenance().setLanguage(lang.toString().toLowerCase());
      return this;
    }
    
    public Builder withLossFunction(LossFunctions.LossFunction lossFunc, Activation activation) {
      this.lossFunc = lossFunc.getILossFunction();
      this.activation = activation;
      return this;
    }
    
    public Builder withLossFunction(ILossFunction lossFunc, Activation activation) {
      this.lossFunc = lossFunc;
      this.activation = activation;
      return this;
    }
    
    public Builder withModelParams(int lstmLayerSize, int embeddingLayerSize) {
      this.lstmLayerSize = lstmLayerSize;
      this.embeddingLayerSize = embeddingLayerSize;
      return this;
    }
    
    public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch, int batchSize, int numEpochs) {
      this.learningRate = learningRate;
      this.dropOut = dropOut;
      this.examplesPerEpoch = examplesPerEpoch;
      this.batchSize = batchSize;
      this.numEpochs = numEpochs;
      return this;
    }
    
    public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch, int maxTimeSeriesLength, int batchSize, int numEpochs) {
      this.learningRate = learningRate;
      this.dropOut = dropOut;
      this.examplesPerEpoch = examplesPerEpoch;
      this.batchSize = batchSize;
      this.maxTimeSeriesLength = maxTimeSeriesLength;
      this.numEpochs = numEpochs;
      return this;
    }
    
    public Builder withInputEncoders(String desc, Encoder inputEncoder) {
      this.inputEncoder = inputEncoder;
      tagger.setInputEncoders(inputEncoder);
      ann.getProvenance().setFeatures(desc);
      ann.addComponent(inputEncoder);
      return this;
    }
    
    public Builder withTargetEncoder(Encoder targetEncoder) {
      this.targetEncoder = targetEncoder;
      tagger.setTargetEncoder(targetEncoder);
      ann.addComponent(targetEncoder);
      return this;
    }
    
    public Builder enableTrainingUI(boolean enable) {
      this.enabletrainingUI = enable;
      return this;
    }
    
    public LSTMSentenceAnnotator build() {
      tagger.initializeNetwork(ModelBuilder.buildLSTMSentenceTagger(
        inputEncoder.getEmbeddingVectorSize(),
        lstmLayerSize,
        embeddingLayerSize,
        targetEncoder.getEmbeddingVectorSize(),
        iterations, learningRate, dropOut, lossFunc, activation)
      );
      if(enabletrainingUI) tagger.enableTrainingUI();
      tagger.setEmbeddingLayerSize(embeddingLayerSize);
      tagger.setTrainingParams(examplesPerEpoch, maxTimeSeriesLength, batchSize, numEpochs, true);
      ann.getProvenance().setTask(tagger.getId());
      tagger.setName(ann.getProvenance().toString());
      tagger.appendTrainLog(printParams());
      return ann;
    }
    
    private String printParams() {
      StringBuilder line = new StringBuilder();
      line.append("TRAINING PARAMS: ").append(tagger.getName()).append("\n");
      line.append("\nEncoders:\n");
      for(Encoder e : tagger.getEncoders()) {
        line.append(e.getId()).append("\t").append(e.getClass().getSimpleName()).append("\t").append(e.getEmbeddingVectorSize()).append("\n");
      }
      line.append("\nNetwork Params:\n");
      line.append("BLSTM").append("\t").append(lstmLayerSize).append("\n");
      line.append("EMB").append("\t").append(embeddingLayerSize).append("\n");
      line.append("\nTraining Params:\n");
      line.append("examples per epoch").append("\t").append(examplesPerEpoch).append("\n");
      line.append("max time series length").append("\t").append(maxTimeSeriesLength).append("\n");
      line.append("epochs").append("\t").append(numEpochs).append("\n");
      line.append("iterations").append("\t").append(iterations).append("\n");
      line.append("batch size").append("\t").append(batchSize).append("\n");
      line.append("learning rate").append("\t").append(learningRate).append("\n");
      line.append("dropout").append("\t").append(dropOut).append("\n");
      line.append("loss").append("\t").append(lossFunc.toString()).append("\n");
      line.append("\n");
      return line.toString();
    }
    
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy