de.datexis.encoder.bert.BertRESTEncoder
package de.datexis.encoder.bert;

import com.google.gson.Gson;
import de.datexis.encoder.AbstractRESTEncoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.RESTAdapter;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.stream.Collectors;

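/**
 * Encoder that retrieves BERT embeddings from a model served over REST, e.g. a
 * bert-as-a-service style deployment wrapped by {@link BertRESTAdapter}.
 * A minimal usage sketch; host, port and embedding dimension below are illustrative
 * values, not defaults of this class:
 * <pre>{@code
 * List<Document> documents = ...; // Documents with sentences and tokens attached
 * BertRESTEncoder encoder = BertRESTEncoder.create("localhost", 8125, 768);
 * INDArray sentenceVectors = encoder.encodeMatrix(documents, 64, Sentence.class);
 * }</pre>
 */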
public class BertRESTEncoder extends AbstractRESTEncoder {

  protected BertRESTEncoder() {
    super("BERT");
  }
  
  public BertRESTEncoder(RESTAdapter restAdapter) {
    super("BERT", restAdapter);
  }

  public static BertRESTEncoder create(String domain, int port, int embeddingDimension) {
    return new BertRESTEncoder(new BertRESTAdapter(domain, port, embeddingDimension));
  }
  
  @Override
  public INDArray encode(String word) {
    throw new UnsupportedOperationException("BERT cannotbe used to encode single words");
  }
  
  @Override
  public INDArray encode(Span span) {
    throw new UnsupportedOperationException("please use encodeMatrix()");
  }
  
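  /**
   * Encodes all Documents into a matrix of sentence embeddings. If caching is enabled and every
   * sentence (up to maxTimeSteps) already carries a vector from this encoder, the cached vectors
   * are returned via the superclass implementation; otherwise all documents are re-encoded in
   * one pass over REST.
   */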
  @Override
  public INDArray encodeMatrix(List<Document> input, int maxTimeSteps, Class<? extends Span> timeStepClass) {
    // check if we have all the sentence vectors already
    if(isCachingEnabled() && timeStepClass.equals(Sentence.class) && input.stream()
      .flatMap(doc -> doc.streamSentences()
        .limit(maxTimeSteps)
        .map(sent -> sent.hasVector(this.getClass())))
      .allMatch(b -> b)) {
      // use the implementation that returns cached vectors
      return super.encodeMatrix(input, maxTimeSteps, timeStepClass);
    }
    return encodeDocumentsParallelNoTokenization(input, maxTimeSteps);
  }

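  /**
   * Serializes one Document into the JSON body expected by the BERT service: empty sentences are
   * dropped and each remaining sentence becomes one pre-tokenized sequence, truncated to
   * maxSequenceLength tokens. BaaSRequest is not defined in this file; judging from the fields
   * set below, Gson should emit a payload along the lines of
   * {@code {"id":0,"texts":[["Hello","world"]],"is_tokenized":true}}.
   */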
  private String documentToRequest(Document d, int id, int maxSequenceLength) {
    List<Sentence> sentences = d.getSentences();
    sentences = sentences.stream().filter(s -> !s.isEmpty()).collect(Collectors.toList());
    // size the array from the filtered list so that removed empty sentences do not leave null rows
    String[][] sequences = new String[sentences.size()][];
    for (int i = 0; i < sentences.size(); ++i) {
      String[] sequence = new String[Math.min(sentences.get(i).getTokens().size(), maxSequenceLength)];
      for (int j = 0; j < sentences.get(i).getTokens().size() && j < maxSequenceLength; ++j) {
        sequence[j] = sentences.get(i).getToken(j).getText();
      }
      sequences[i] = sequence;
    }
    Gson gson = new Gson();
    BaaSRequest req = new BaaSRequest();
    req.id = id;
    req.texts = sequences;
    req.is_tokenized = true;
    return gson.toJson(req);
  }

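  /**
   * Sends raw (non-tokenized) documents to the BERT service in parallel and collects one
   * embedding per sentence. Each vector is L2-normalized and written into a time-step matrix
   * sized documents.size() x getEmbeddingVectorSize() x maxSequenceLength; documents whose
   * request failed are skipped and their time steps are left empty.
   */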
  public INDArray encodeDocumentsParallelNoTokenization(List<Document> documents, int maxSequenceLength) {
    INDArray encoding = EncodingHelpers.createTimeStepMatrix((long) documents.size(), this.getEmbeddingVectorSize(), (long) maxSequenceLength);
    List<BertNonTokenizedResponse> responses = documents.parallelStream()
      .map(d -> {
        try {
          if(d != null && !d.getSentences().isEmpty())
            return ((BertRESTAdapter) this.restAdapter).simpleRequestNonTokenized(d, maxSequenceLength);
        } catch (IOException e) {
          e.printStackTrace();
          System.err.println("Error at document: " + d.getId());
        }
        return null;
      }).collect(Collectors.toList());

    int docIndex = 0;
    for (BertNonTokenizedResponse resp : responses) {
      if(resp != null) {
        // one row per sentence: L2-normalize the embedding and store it as time step t
        for(int t = 0; t < resp.result.length && t < maxSequenceLength; ++t) {
          INDArray vec = Transforms.unitVec(Nd4j.create(resp.result[t], new long[]{getEmbeddingVectorSize(), 1}));
          EncodingHelpers.putTimeStep(encoding, (long) docIndex, (long) t, vec);
          if(isCachingEnabled()) documents.get(docIndex).getSentence(t).putVector(this.getClass(), vec);
        }
      }
      docIndex++;
    }
    return encoding;
  }

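  /**
   * Tokenized variant: builds one JSON request per Document, posts them in parallel, and returns
   * per-token embeddings as [sentence][token][embedding dimension] arrays with the first and
   * last token embedding of each sentence stripped. Rough timings for request generation, the
   * REST round-trips and array post-processing are printed. Note that the toFill parameter is
   * currently unused.
   */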
  public ArrayList<double[][][]> encodeDocumentsParallel(Collection<Document> documents, int maxSequenceLength, INDArray toFill) throws InterruptedException {
    ArrayList<double[][][]> results = new ArrayList<>();
    ArrayList<String> requests = new ArrayList<>();

    Instant beforeRequestGen = Instant.now();
    int id = 0;
    for (Document doc : documents) {
      requests.add(documentToRequest(doc, id, maxSequenceLength));
      id++;
    }
    Instant afterRequestGen = Instant.now();
    long requestGenDur = Duration.between(beforeRequestGen, afterRequestGen).toMillis();

    Instant beforeRequest = Instant.now();
    List<BertResponse> responses = requests.parallelStream().map(req -> {
      try {
        return ((BertRESTAdapter)this.restAdapter).simpleRequest(req);
      } catch (IOException e) {
        e.printStackTrace();
      }
      return null;
    }).filter(Objects::nonNull).sorted(Comparator.comparingInt(resp -> resp.id)).collect(Collectors.toList());

    Instant afterRequest = Instant.now();
    long requestDur = Duration.between(beforeRequest, afterRequest).toMillis();

    // remove first and last element from sequence arrays
    Instant beforeArrayGen = Instant.now();
    for (BertResponse response : responses) {
      // response.result is indexed as [sentence][token][embedding dimension]
      double[][][] result = new double[response.result.length][][];
      for (int i = 0; i < response.result.length && i < maxSequenceLength; ++i) {
        // drop the first and last token embedding of each sentence, presumably the [CLS] and [SEP] markers
        result[i] = Arrays.copyOfRange(response.result[i], 1, response.result[i].length - 1);
      }
      results.add(result);
    }
    Instant afterArrayGen = Instant.now();
    long arrayGenDur = Duration.between(beforeArrayGen, afterArrayGen).toMillis();
    System.out.println("Request generation: " + requestGenDur + "\n" + "Requests: " + requestDur + "\n" + "Array generation: " + arrayGenDur);
    return results;
  }
  
}