// All downloads are free. Search and download functionality uses the official Maven repository.

// ai.idylnlp.nlp.recognizer.DeepLearningEntityRecognizer Maven / Gradle / Ivy

// The newest version!
/*******************************************************************************
 * Copyright 2018 Mountain Fog, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations under
 * the License.
 ******************************************************************************/
package ai.idylnlp.nlp.recognizer;

import java.io.File;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

import ai.idylnlp.model.entity.Entity;
import ai.idylnlp.model.exceptions.EntityFinderException;
import ai.idylnlp.model.exceptions.ModelLoaderException;
import ai.idylnlp.model.manifest.SecondGenModelManifest;
import ai.idylnlp.model.nlp.AbstractEntityRecognizer;
import ai.idylnlp.model.nlp.SentenceSanitizer;
import ai.idylnlp.model.nlp.ner.EntityExtractionRequest;
import ai.idylnlp.model.nlp.ner.EntityExtractionResponse;
import ai.idylnlp.model.nlp.ner.EntityRecognizer;
import ai.idylnlp.nlp.sentence.sanitizers.DefaultSentenceSanitizer;
import com.neovisionaries.i18n.LanguageCode;

import ai.idylnlp.nlp.recognizer.configuration.DeepLearningEntityRecognizerConfiguration;
import ai.idylnlp.nlp.recognizer.deep.DeepLearningTokenNameFinder;
import opennlp.tools.namefind.TokenNameFinder;

/**
 * An {@link EntityRecognizer} that is powered by the
 * deeplearning4j framework. It uses a neural network
 * to perform entity extraction.
 *
 * @author Mountain Fog, Inc.
 *
 */
public class DeepLearningEntityRecognizer extends AbstractEntityRecognizer implements EntityRecognizer {

  private static final Logger LOGGER = LogManager.getLogger(DeepLearningEntityRecognizer.class);

  // Language -> (Type, [Network, Vectors])
  private Map>> loadedModels;

  public DeepLearningEntityRecognizer(DeepLearningEntityRecognizerConfiguration configuration) {

    super(configuration);

    loadedModels = new HashMap>>();

    for(String type : configuration.getEntityModels().keySet()) {

      Map> types = configuration.getEntityModels().get(type);

      for(LanguageCode language : types.keySet()) {

        for(SecondGenModelManifest modelManifest : types.get(language)) {

          if(!configuration.getBlacklistedModelIDs().contains(modelManifest.getModelId())) {

            try {

              final String modelFileName = new File(configuration.getEntityModelDirectory(), modelManifest.getModelFileName()).getAbsolutePath();

              // Load the network from the model file.
              LOGGER.info("Loading {} {} model from file: {}", language.getAlpha3().toString(), type, modelFileName);

              final File modelFile = new File(modelFileName);

              final MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(modelFile.getAbsolutePath());

              final String vectorsFileName = new File(configuration.getEntityModelDirectory(), modelManifest.getVectorsFileName()).getAbsolutePath();

              // Verify the vectors file exists.
              final File vectorsFile = new File(vectorsFileName);

              // Load the word vectors from the file.
              LOGGER.info("Loading vectors from file: {}", vectorsFileName);
              final WordVectors wordVectors = WordVectorSerializer.loadStaticModel(vectorsFile);

              final Map> t = new HashMap<>();
              t.put(type, new ImmutablePair(multiLayerNetwork, wordVectors));

              loadedModels.put(language, t);

            } catch (Exception ex) {

              LOGGER.error("Unable to load model: " + modelManifest.getModelFileName(), ex);

              getConfiguration().getBlacklistedModelIDs().add(modelManifest.getModelId());
              LOGGER.warn("Model {} is blacklisted. Loading will not be attempted until restart.", modelManifest.getModelFileName());

              // TODO: This should probably be made visible to the user somehow - maybe through the API?

            }

          } else {

            LOGGER.info("Model {} is blacklisted. Loading will not be attempted until restart.", modelManifest.getModelFileName());

          }

        }

      }

    }

  }

  @Override
  public EntityExtractionResponse extractEntities(EntityExtractionRequest request)
      throws EntityFinderException, ModelLoaderException {

    if(request.getText().length == 0) {

      throw new IllegalArgumentException("Input text cannot be empty.");

    }

    if(request.getConfidenceThreshold() < 0 || request.getConfidenceThreshold() > 100) {

      throw new IllegalArgumentException("Confidence threshold must be an integer between 0 and 100.");

    }

    SentenceSanitizer sentenceSanitizer = new DefaultSentenceSanitizer.Builder().lowerCase().removePunctuation().consolidateSpaces().build();

    // All of the extracted entities.
    Set entities = new LinkedHashSet();

    // Keep track of the extraction time.
    long startTime = System.currentTimeMillis();

    String types[] = {};

    if(!StringUtils.isEmpty(request.getType())) {
      types = request.getType().split(",");
    }

    for(String type : getConfiguration().getEntityModels().keySet()) {

      if(types.length == 0 || ArrayUtils.contains(types, type)) {

        LOGGER.trace("Processing entity class {}.", type);

        LanguageCode language = request.getLanguage();

        // The manifests of the models that will be used for this extraction.
        Set modelManifests = new HashSet();

        if(request.getLanguage() == null) {

          // TODO: Run all languages to support multilingual documents.
          Set languages = getConfiguration().getEntityModels().get(type).keySet();

          for(LanguageCode l : languages) {
            modelManifests.addAll(getConfiguration().getEntityModels().get(type).get(l));
          }

        } else {

          // We are doing a single language.

          Map> models = getConfiguration().getEntityModels().get(type);

          // If there are not any models for this entity type models will be null.
          if(models != null) {

            Set manifests = models.get(language);

            // If manifests is not null add those manifests to the set.
            if(manifests != null) {

              modelManifests.addAll(manifests);

            }

          }

        }

        if(CollectionUtils.isNotEmpty(modelManifests)) {

          for(SecondGenModelManifest modelManifest : modelManifests) {

            LOGGER.debug("{} has {} entity models.", type, modelManifests.size());

            String t = modelManifest.getType();

            // Get the network and word vectors for this language.
            LOGGER.info("Getting model for type {}, language {}", modelManifest.getLanguageCode().getAlpha3().toString(), t);

            ImmutablePair pair = loadedModels.get(modelManifest.getLanguageCode()).get(t);

            MultiLayerNetwork multiLayerNetwork = pair.getLeft();
            WordVectors wordVectors = pair.getRight();

            // Get the nameFinder for this model if it exists.
            TokenNameFinder nameFinder = nameFinders.get(modelManifest);

            if(nameFinder == null) {

              // Create a new namefinder and put it in the map.
              nameFinder = new DeepLearningTokenNameFinder(multiLayerNetwork, wordVectors,
                  modelManifest.getWindowSize(), getLabels(request.getType()));

              nameFinders.put(modelManifest, nameFinder);

            }

            Collection extractedEntities = findEntities(nameFinder, request, modelManifest, sentenceSanitizer);

            entities.addAll(extractedEntities);

          }

        } else {

          LOGGER.warn("No entity models available for language {}.", language.getAlpha3().toString());

        }

      }

    }

    long extractionTime = (System.currentTimeMillis() - startTime);

    // Create the response with the extracted entities and the time it took to extract them.
    EntityExtractionResponse response = new EntityExtractionResponse(entities, extractionTime, true);

    return response;

  }

  private String[] getLabels(String entityType) {

    return new String[] { entityType + "-start", entityType + "-cont", "other" };

  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy