All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.modality.nlp.preprocess.SentenceLengthNormalizer Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp.preprocess;

import java.util.ArrayList;
import java.util.List;

/**
 * {@code SentenceLengthNormalizer} normalizes the length of all the input sentences to the
 * specified number of tokens.
 *
 * 

If the number of tokens in the input sentence is higher than the fixed length, the sentence is * truncated to the fixed number. If the number of tokens in the input sentence is fewer than the * fixed sentence length, padding tokens are inserted to make the length equal to the sentence * length. */ public class SentenceLengthNormalizer implements TextProcessor { private static final int DEFAULT_SENTENCE_LENGTH = 10; private static final String DEFAULT_PADDING_TOKEN = ""; private static final String DEFAULT_EOS_TOKEN = ""; private static final String DEFAULT_BOS_TOKEN = ""; private int sentenceLength; private boolean addEosBosTokens; private String paddingToken; private String eosToken; private String bosToken; private int lastValidLength = -1; /** Creates a {@link TextProcessor} that normalizes the length of the input. */ public SentenceLengthNormalizer() { this(DEFAULT_SENTENCE_LENGTH, false); } /** * Creates a {@link TextProcessor} that normalizes the length of the input to the given sentence * length. * * @param sentenceLength the sentence length * @param addEosBosTokens whether to add Eos and Bos tokens before normalizing sentence length */ public SentenceLengthNormalizer(int sentenceLength, boolean addEosBosTokens) { this( sentenceLength, addEosBosTokens, DEFAULT_PADDING_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN); } /** * Creates a {@link TextProcessor} that normalizes the length of the input to the given sentence * length. * * @param sentenceLength the sentence length * @param addEosBosTokens whether to add Eos and Bos tokens before normalizing sentence length * @param paddingToken the padding token to be used if the number of tokens in the input is less * than sentence length * @param eosToken the end of sentence token * @param bosToken the begining of sentence token */ public SentenceLengthNormalizer( int sentenceLength, boolean addEosBosTokens, String paddingToken, String eosToken, String bosToken) { this.sentenceLength = sentenceLength; this.addEosBosTokens = addEosBosTokens; this.paddingToken = paddingToken; this.eosToken = eosToken; this.bosToken = bosToken; } /** {@inheritDoc} */ @Override public List preprocess(List tokens) { List list = new ArrayList<>(sentenceLength); if (addEosBosTokens) { list.add(bosToken); } list.addAll(tokens); if (addEosBosTokens) { list.add(eosToken); } int size = list.size(); if (sentenceLength < size) { lastValidLength = sentenceLength; if (addEosBosTokens) { list.set(sentenceLength - 1, eosToken); } return list.subList(0, sentenceLength); } lastValidLength = size; for (int i = size; i < sentenceLength; ++i) { list.add(paddingToken); } return list; } /** * Returns the valid length of the sentence that was last served as input to {@link * SentenceLengthNormalizer#preprocess(List)}. If no sentences preprocess before calling this * method, it will -1. * * @return the valid length of the sentence that was last preprocessed */ public int getLastValidLength() { return lastValidLength; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy