All Downloads are FREE. Search and download functionalities are using the official Maven repository.

opennlp.dl.namefinder.NameFinderDL Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.dl.namefinder;

import java.io.File;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;

/**
 * An implementation of {@link TokenNameFinder} that uses ONNX models.
 *
 * <p>The input tokens are joined into one text, split into sentences, cut into
 * (optionally overlapping) chunks of whitespace-separated words and tokenized
 * with WordPiece. Each chunk is run through the ONNX session, and every token
 * labelled {@link #B_PER} (plus any directly following {@link #I_PER} tokens)
 * is reported as a {@link Span} holding character offsets into the joined text.
 *
 * @see TokenNameFinder
 * @see InferenceOptions
 */
public class NameFinderDL extends AbstractDL implements TokenNameFinder {

  /** Label of a token that continues a person entity. */
  public static final String I_PER = "I-PER";

  /** Label of a token that starts a person entity. */
  public static final String B_PER = "B-PER";

  /** The separator token; a span consisting only of this token is never reported. */
  public static final String SEPARATOR = "[SEP]";

  /** Prefix that marks a WordPiece continuation token. */
  private static final String CHARS_TO_REPLACE = "##";

  private final SentenceDetector sentenceDetector;
  private final Map<Integer, String> ids2Labels;
  private final InferenceOptions inferenceOptions;

  /**
   * Instantiates a {@link TokenNameFinder name finder} using ONNX models
   * and default {@link InferenceOptions}.
   *
   * @param model The ONNX model file.
   * @param vocabulary The model file's vocabulary file.
   * @param ids2Labels The mapping of ids to labels.
   * @param sentenceDetector The {@link SentenceDetector} to be used.
   *
   * @throws OrtException Thrown if the {@code model} cannot be loaded.
   * @throws IOException Thrown if errors occurred loading the {@code model} or {@code vocabulary}.
   */
  public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels,
                      SentenceDetector sentenceDetector) throws IOException, OrtException {

    this(model, vocabulary, ids2Labels, new InferenceOptions(), sentenceDetector);

  }

  /**
   * Instantiates a {@link TokenNameFinder name finder} using ONNX models.
   *
   * @param model The ONNX model file.
   * @param vocabulary The model file's vocabulary file.
   * @param ids2Labels The mapping of ids to labels.
   * @param inferenceOptions {@link InferenceOptions} to control the inference.
   * @param sentenceDetector The {@link SentenceDetector} to be used.
   *
   * @throws OrtException Thrown if the {@code model} cannot be loaded.
   * @throws IOException Thrown if errors occurred loading the {@code model} or {@code vocabulary}.
   */
  public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels,
                      InferenceOptions inferenceOptions,
                      SentenceDetector sentenceDetector) throws IOException, OrtException {

    this.env = OrtEnvironment.getEnvironment();

    // The session options are intentionally not closed here; they are handed to the session.
    final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
    if (inferenceOptions.isGpu()) {
      sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
    }

    this.session = env.createSession(model.getPath(), sessionOptions);
    this.ids2Labels = ids2Labels;
    this.vocab = loadVocab(vocabulary);
    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
    this.inferenceOptions = inferenceOptions;
    this.sentenceDetector = sentenceDetector;

  }

  /**
   * Finds person names in the given tokens.
   *
   * @param input The tokens of the text; they are joined with single spaces before inference.
   * @return The detected spans. Offsets are character positions in the joined text,
   *     the span type is the matched text and the probability is the raw model score
   *     of the span's first token.
   * @throws RuntimeException Thrown if the ONNX inference fails.
   */
  @Override
  public Span[] find(String[] input) {

    final List<Span> spans = new LinkedList<>();

    // Join the tokens here because they will be tokenized using Wordpiece during inference.
    final String text = String.join(" ", input);

    final String[] sentences = sentenceDetector.sentDetect(text);

    for (String sentence : sentences) {

      // The WordPiece tokenized text. This changes the spacing in the text.
      for (final Tokens tokens : tokenize(sentence)) {

        final float[][][] v;
        try {
          v = infer(tokens);
        } catch (OrtException ex) {
          throw new RuntimeException("Error performing namefinder inference: " + ex.getMessage(), ex);
        }

        collectSpans(text, tokens.tokens(), v, spans);

      }

    }

    return spans.toArray(new Span[0]);

  }

  @Override
  public void clearAdaptiveData() {
    // No use in this implementation.
  }

  /**
   * Runs the ONNX session for one chunk and returns the first model output.
   * All tensors and the session result are closed before returning.
   */
  private float[][][] infer(Tokens tokens) throws OrtException {

    // The inputs to the ONNX model.
    final Map<String, OnnxTensor> inputs = new HashMap<>();

    try {

      inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids()),
          new long[] {1, tokens.ids().length}));

      if (inferenceOptions.isIncludeAttentionMask()) {
        inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
            LongBuffer.wrap(tokens.mask()), new long[] {1, tokens.mask().length}));
      }

      if (inferenceOptions.isIncludeTokenTypeIds()) {
        inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
            LongBuffer.wrap(tokens.types()), new long[] {1, tokens.types().length}));
      }

      // The outputs from the model. The value is read while the result is still open.
      try (OrtSession.Result result = session.run(inputs)) {
        return (float[][][]) result.get(0).getValue();
      }

    } finally {
      // Close the input tensors even when tensor creation or inference failed.
      for (OnnxTensor tensor : inputs.values()) {
        tensor.close();
      }
    }

  }

  /**
   * Turns the model output of one chunk into spans and appends them to {@code spans}.
   *
   * @param text The whole joined input text; used so character positions refer to it.
   * @param toks The WordPiece tokens of the chunk.
   * @param v The model output, indexed as {@code v[0][token][label]}.
   * @param spans The list that receives the detected spans.
   */
  private void collectSpans(String text, String[] toks, float[][][] v, List<Span> spans) {

    // Find consecutive B-PER and I-PER labels and combine the spans where necessary.
    // There are also B-LOC and I-LOC tags for locations that might be useful at some point.

    // Keep track of where the last span was so when there are multiple/duplicate
    // spans we can get the next one instead of the first one each time.
    int characterStart = 0;

    // Never read a token that does not exist, even if the model returned more rows.
    final int tokenCount = Math.min(v[0].length, toks.length);

    // We are looping over the vector for each word,
    // finding the index of the array that has the maximum value,
    // and then finding the token classification that corresponds to that index.
    for (int x = 0; x < tokenCount; x++) {

      final float[] arr = v[0][x];
      final int maxIndex = maxIndex(arr);
      final String label = ids2Labels.get(maxIndex);

      // Only the start of a person entity opens a span.
      if (!B_PER.equals(label)) {
        continue;
      }

      // TODO: Need to make sure this value is between 0 and 1?
      // Can we do thresholding without it between 0 and 1?
      final double confidence = arr[maxIndex];

      // Find the end index of the span in the array (where the label is not I-PER).
      final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels, toks);

      String spanText;

      if (spanEnd.index() != -1) {

        // Concatenate the tokens, then look the result up in the whole original input text
        // (not one of the splits) so the original casing and spacing are recovered.
        spanText = findByRegex(text, joinWordpieces(toks, x, spanEnd.index())).trim();

      } else {

        // No end index was found, so only the token itself is used.
        spanText = toks[x];

      }

      if (SEPARATOR.equals(spanText)) {
        continue;
      }

      spanText = spanText.replace(CHARS_TO_REPLACE, "");

      // Take the first occurrence at or after the previous hit.
      final int found = text.indexOf(spanText, characterStart);

      // TODO: This check should not be needed because the span was found.
      // If we aren't finding it now it's because there's a whitespace difference.
      if (found != -1) {

        spans.add(new Span(found, found + spanText.length(), spanText, confidence));

        // OP-1: Only move the search position one character past the start of the hit.
        // On a miss the position is left untouched so the search does not restart at 0.
        characterStart = found + 1;

      }

    }

  }

  /**
   * Joins the tokens {@code start..end} (inclusive) into one string. Continuation
   * tokens (prefixed with {@code ##}) are glued to the preceding token, other tokens
   * are separated by a single space, and no space is put before a period.
   * Look-ahead reads past the end of {@code toks} are treated as an empty token.
   */
  private static String joinWordpieces(String[] toks, int start, int end) {

    final StringBuilder sb = new StringBuilder();

    for (int i = start; i <= end && i < toks.length; i++) {

      // If the next token starts with ##, combine it with this token.
      if (tokenAt(toks, i + 1).startsWith(CHARS_TO_REPLACE)) {

        sb.append(toks[i]).append(toks[i + 1].replace(CHARS_TO_REPLACE, ""));

        // Append a space unless the next (next) token starts with ##.
        if (!tokenAt(toks, i + 2).startsWith(CHARS_TO_REPLACE)) {
          sb.append(" ");
        }

        // Skip the next token since we just included it in this iteration.
        i++;

      } else {

        sb.append(toks[i].replace(CHARS_TO_REPLACE, ""));

        // Append a space unless the next token is a period.
        if (!".".equals(tokenAt(toks, i + 1))) {
          sb.append(" ");
        }

      }

    }

    return sb.toString().trim();

  }

  /** Returns {@code toks[index]}, or an empty string when {@code index} is past the end. */
  private static String tokenAt(String[] toks, int index) {
    return index < toks.length ? toks[index] : "";
  }

  /**
   * Determines where the span starting at {@code startIndex} ends.
   *
   * <p>The returned index is the position of the last token before the first following
   * token that is not labelled {@link #I_PER}; this equals {@code startIndex} when the
   * very next token is not {@link #I_PER}. The index is {@code -1} when no following
   * token with a different label exists.
   *
   * @param v The model output, indexed as {@code v[0][token][label]}.
   * @param startIndex The index of the token that starts the span.
   * @param id2Labels The mapping of ids to labels.
   * @param tokens The WordPiece tokens of the chunk.
   * @return The end index together with a character end computed from the token lengths.
   */
  private SpanEnd findSpanEnd(float[][][] v, int startIndex, Map<Integer, String> id2Labels,
                              String[] tokens) {

    int index = -1;
    int characterEnd = 0;

    // Starts at the span start in the vector.
    // Looks at the next token to see if it is an I-PER.
    // Go until the next token is something other than I-PER.
    // When the next token is not I-PER, return the previous index.

    for (int x = startIndex + 1; x < v[0].length; x++) {

      // See if the next token has an I-PER label.
      final String nextTokenClassification = id2Labels.get(maxIndex(v[0][x]));

      if (!I_PER.equals(nextTokenClassification)) {
        index = x - 1;
        break;
      }

    }

    // Find where the span ends based on the tokens.
    for (int x = 1; x <= index && x < tokens.length; x++) {
      characterEnd += tokens[x].length();
    }

    // Account for the number of spaces (that is the number of tokens).
    // (One space per token.)
    characterEnd += index - 1;

    return new SpanEnd(index, characterEnd);

  }

  /**
   * Returns the index of the largest value in {@code arr}; the first one wins on ties.
   * Returns {@code -1} when no element is greater than negative infinity.
   */
  private int maxIndex(float[] arr) {

    double max = Float.NEGATIVE_INFINITY;
    int index = -1;

    for (int x = 0; x < arr.length; x++) {
      if (arr[x] > max) {
        index = x;
        max = arr[x];
      }
    }

    return index;

  }

  /**
   * Looks up {@code span} in {@code text}, ignoring case and allowing any run of
   * whitespace wherever {@code span} has a space.
   *
   * @return The first matching piece of {@code text}, or {@code span} itself if nothing matches.
   */
  private static String findByRegex(String text, String span) {

    // Quote every part literally so characters such as '.', '[', '*' or '(' inside a
    // token can neither break the pattern nor act as wildcards.
    final String[] parts = span.split(" ", -1);
    final StringBuilder regex = new StringBuilder();

    for (int i = 0; i < parts.length; i++) {
      if (i > 0) {
        regex.append("\\s+");
      }
      regex.append(Pattern.quote(parts[i]));
    }

    final Pattern pattern = Pattern.compile(regex.toString(), Pattern.CASE_INSENSITIVE);
    final Matcher matcher = pattern.matcher(text);

    if (matcher.find()) {
      return matcher.group(0);
    }

    // For some reason the regex match wasn't found. Just return the original span.
    return span;

  }

  /**
   * Splits {@code text} into chunks of whitespace-separated words and tokenizes each chunk.
   *
   * <p>Each chunk holds at most {@code getDocumentSplitSize()} words and consecutive
   * chunks share {@code getSplitOverlapSize()} words, both taken from the
   * {@link InferenceOptions}. Segmenting long input this way follows
   * https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
   *
   * @param text The text to tokenize.
   * @return One {@link Tokens} per chunk, with an all-ones mask and all-zero type ids.
   * @throws IllegalArgumentException Thrown if the split size is not positive or the
   *     overlap is not smaller than the split size (the chunking could not advance).
   */
  private List<Tokens> tokenize(final String text) {

    final List<Tokens> t = new LinkedList<>();

    final int splitSize = inferenceOptions.getDocumentSplitSize();
    final int overlapSize = inferenceOptions.getSplitOverlapSize();

    if (splitSize < 1 || overlapSize >= splitSize) {
      throw new IllegalArgumentException("The document split size (" + splitSize
          + ") must be positive and larger than the split overlap size (" + overlapSize + ").");
    }

    final String[] whitespaceTokenized = text.split("\\s+");

    // Each chunk starts (split size - overlap) words after the previous one.
    final int step = splitSize - overlapSize;

    for (int start = 0; start < whitespaceTokenized.length; start += step) {

      // Check the end so we don't go past the last word.
      final int end = Math.min(start + splitSize, whitespaceTokenized.length);

      // The group is that subsection of string.
      final String group = String.join(" ", Arrays.copyOfRange(whitespaceTokenized, start, end));

      // Now we can tokenize the group and continue.
      final String[] tokens = tokenizer.tokenize(group);

      final long[] ids = new long[tokens.length];

      for (int x = 0; x < tokens.length; x++) {
        ids[x] = vocab.get(tokens[x]);
      }

      final long[] mask = new long[ids.length];
      Arrays.fill(mask, 1);

      final long[] types = new long[ids.length];
      Arrays.fill(types, 0);

      t.add(new Tokens(tokens, ids, mask, types));

      // This chunk already reached the last word; a further chunk would only repeat
      // words that were just processed and report the same names a second time.
      if (end == whitespaceTokenized.length) {
        break;
      }

    }

    return t;

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy