ai.djl.modality.nlp.bert.BertFullTokenizer
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp.bert;

import ai.djl.modality.nlp.NlpUtils;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.preprocess.LambdaProcessor;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import ai.djl.modality.nlp.preprocess.TextCleaner;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.modality.nlp.preprocess.UnicodeNormalizer;

import java.text.Normalizer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * BertFullTokenizer runs end to end tokenization of input text.
 *
 * <p>It will run basic preprocessors to clean the input text and then run {@link
 * WordpieceTokenizer} to split into word pieces.
 *
 * <p>Reference implementation: <a
 * href="https://github.com/google-research/bert/blob/master/tokenization.py">Google Research Bert
 * Tokenizer</a>
 */
public class BertFullTokenizer extends BertTokenizer {

    private Vocabulary vocabulary;
    private List<TextProcessor> basicBertPreprocessors;
    private WordpieceTokenizer wordpieceTokenizer;

    /**
     * Creates an instance of {@code BertFullTokenizer}.
     *
     * @param vocabulary the BERT vocabulary
     * @param lowerCase whether to convert tokens to lowercase
     */
    public BertFullTokenizer(Vocabulary vocabulary, boolean lowerCase) {
        this.vocabulary = vocabulary;
        basicBertPreprocessors = getPreprocessors(lowerCase);
        wordpieceTokenizer = new WordpieceTokenizer(vocabulary, "[UNK]", 200);
    }

    /**
     * Returns the {@link Vocabulary} used for tokenization.
     *
     * @return the {@link Vocabulary} used for tokenization
     */
    public Vocabulary getVocabulary() {
        return vocabulary;
    }

    /** {@inheritDoc} */
    @Override
    public List<String> tokenize(String input) {
        List<String> tokens = new ArrayList<>(Collections.singletonList(input));
        for (TextProcessor processor : basicBertPreprocessors) {
            tokens = processor.preprocess(tokens);
        }
        return wordpieceTokenizer.preprocess(tokens);
    }

    /** {@inheritDoc} */
    @Override
    public String buildSentence(List<String> tokens) {
        return String.join(" ", tokens).replace(" ##", "").trim();
    }

    /**
     * Gets a list of {@link TextProcessor}s to process input text for BERT models.
     *
     * @param lowerCase whether to convert input to lowercase
     * @return a list of {@code TextProcessor}s
     */
    public static List<TextProcessor> getPreprocessors(boolean lowerCase) {
        List<TextProcessor> processors = new ArrayList<>(10);
        processors.add(
                new TextCleaner(c -> c == 0 || c == 0xfffd || NlpUtils.isControl(c), '\0'));
        processors.add(new TextCleaner(NlpUtils::isWhiteSpace, ' '));
        processors.add(new LambdaProcessor(String::trim));
        processors.add(new SimpleTokenizer());
        if (lowerCase) {
            processors.add(new LowerCaseConvertor());
        }
        processors.add(new UnicodeNormalizer(Normalizer.Form.NFD));
        processors.add(
                new TextCleaner(c -> Character.getType(c) == Character.NON_SPACING_MARK, '\0'));
        processors.add(new PunctuationSeparator());
        processors.add(new LambdaProcessor(String::trim));
        return processors;
    }
}
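For reference, a minimal usage sketch of this class. The loadVocabulary helper below is a hypothetical placeholder, not part of DJL; a real application would construct a Vocabulary from the vocab.txt that ships with the BERT model being used.

import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;

import java.util.List;

public class BertFullTokenizerExample {

    public static void main(String[] args) {
        // Assumed to be loaded elsewhere, e.g. from the model's vocab.txt file.
        Vocabulary vocabulary = loadVocabulary();

        // lowerCase = true matches the "uncased" family of BERT models.
        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);

        // Cleans and lowercases the input, then splits it into word pieces such as "play", "##ing".
        List<String> tokens = tokenizer.tokenize("DJL makes BERT tokenization easy!");

        // Joins word pieces back together; " ##" continuation markers are merged into the previous token.
        String sentence = tokenizer.buildSentence(tokens);

        System.out.println(tokens);
        System.out.println(sentence);
    }

    private static Vocabulary loadVocabulary() {
        // Hypothetical helper: supply a Vocabulary implementation backed by the model's vocabulary file.
        throw new UnsupportedOperationException("load a BERT vocabulary here");
    }
}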




