All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.modality.nlp.Vocabulary Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@code Vocabulary} is a collection of tokens. The primary purpose of a vocabulary is to map a
 * token to an index.
 */
public class Vocabulary {

    /** Bookkeeping for every reserved token and every token seen in the input sentences. */
    private final Map<String, TokenInfo> tokens = new ConcurrentHashMap<>();
    /** Tokens that reached {@code minFrequency}, stored at the position equal to their index. */
    private final List<String> indexToToken = new ArrayList<>();
    /** Reserved tokens (always including the unknown token); they are never frequency-counted. */
    private final Set<String> reservedTokens;
    private final int minFrequency;
    private final String unknownToken;

    /**
     * Create a {@code Vocabulary} object with a {@link VocabularyBuilder}.
     *
     * <p>The builder is only read: its reserved-token set is copied, so the same builder can be
     * reused or modified afterwards without affecting this vocabulary.
     *
     * @param builder the {@link VocabularyBuilder} to build the vocabulary with
     */
    public Vocabulary(VocabularyBuilder builder) {
        // Defensive copy: adding the unknown token below must not leak back into the builder.
        reservedTokens = new HashSet<>(builder.reservedTokens);
        minFrequency = builder.minFrequency;
        unknownToken = builder.unknownToken;
        reservedTokens.add(unknownToken);
        for (String token : reservedTokens) {
            // Reserved tokens carry a fixed sentinel index and have no slot in indexToToken.
            tokens.put(token, new TokenInfo(Integer.MAX_VALUE));
        }
        for (List<String> sentence : builder.sentences) {
            addAllTokens(sentence);
        }
    }

    /**
     * Counts one occurrence of the given token and assigns it an index the first time its
     * frequency reaches {@code minFrequency}. Reserved tokens are ignored.
     *
     * @param token the token that was observed
     */
    private void addToken(String token) {
        if (reservedTokens.contains(token)) {
            return;
        }
        TokenInfo tokenInfo = tokens.computeIfAbsent(token, key -> new TokenInfo());
        tokenInfo.frequency++;
        // Assign the index exactly once; re-adding on every later occurrence would create
        // duplicate entries in indexToToken and keep moving the token's index.
        if (tokenInfo.index < 0 && tokenInfo.frequency >= minFrequency) {
            tokenInfo.index = indexToToken.size();
            indexToToken.add(token);
        }
    }

    /**
     * Counts every token of the given collection, in iteration order.
     *
     * @param tokens the tokens to count
     */
    private void addAllTokens(Collection<String> tokens) {
        for (String token : tokens) {
            addToken(token);
        }
    }

    /**
     * Returns whether the given token is a known word.
     *
     * <p>A token is known if it is reserved or if it was seen at least {@code minFrequency}
     * times. Tokens seen less often, unseen tokens and {@code null} are not known.
     *
     * @param token the token
     * @return whether the given token is a known word
     */
    public boolean isKnownToken(String token) {
        if (token == null) {
            return false;
        }
        TokenInfo tokenInfo = tokens.get(token);
        return tokenInfo != null && tokenInfo.index >= 0;
    }

    /**
     * Return a {@link String} used for unseen or rarely-seen tokens.
     *
     * @return the {@link String} used for unseen or rarely-seen tokens
     */
    public String getUnknownToken() {
        return unknownToken;
    }

    /**
     * Returns the token corresponding to the given index.
     *
     * @param index the index
     * @return the token corresponding to the given index, or the unknown token if no token was
     *     assigned that index
     */
    public String getToken(int index) {
        if (index < 0 || index >= indexToToken.size()) {
            return unknownToken;
        }
        return indexToToken.get(index);
    }

    /**
     * Returns the index of the given token.
     *
     * @param token the token
     * @return the index of the given token; {@link Integer#MAX_VALUE} for reserved tokens; {@code
     *     0} if the token is {@code null}, unseen, or was seen fewer than {@code minFrequency}
     *     times
     */
    public long getIndex(String token) {
        if (token != null) {
            TokenInfo tokenInfo = tokens.get(token);
            // A negative index is the internal "not indexed yet" marker and must not escape.
            if (tokenInfo != null && tokenInfo.index >= 0) {
                return tokenInfo.index;
            }
        }
        // NOTE(review): 0 is also the index of the first indexed token; the fallback value is
        // kept as-is for backward compatibility.
        return 0;
    }

    /**
     * Returns the size of the {@code Vocabulary}.
     *
     * @return the number of reserved tokens plus the number of tokens that reached {@code
     *     minFrequency}
     */
    public int size() {
        return reservedTokens.size() + indexToToken.size();
    }

    /** Builder class that is used to build the {@link Vocabulary}. */
    public static class VocabularyBuilder {
        protected List<List<String>> sentences = new LinkedList<>();
        protected Set<String> reservedTokens = new HashSet<>();
        protected int minFrequency = 10;
        // NOTE(review): an empty default looks suspicious. Angle-bracketed text appears to have
        // been stripped from this copy of the file, so confirm the intended default value.
        protected String unknownToken = "";

        /**
         * Sets the optional parameter that specifies the minimum frequency to consider a token to
         * be part of the {@link Vocabulary}. Defaults to 10.
         *
         * @param minFrequency the minimum frequency to consider a token to be part of the {@link
         *     Vocabulary}
         * @return this {@code VocabularyBuilder}
         */
        public VocabularyBuilder optMinFrequency(int minFrequency) {
            this.minFrequency = minFrequency;
            return this;
        }

        /**
         * Sets the optional parameter that specifies the unknown token's string value.
         *
         * @param unknownToken the string value of the unknown token
         * @return this {@code VocabularyBuilder}
         */
        public VocabularyBuilder optUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Sets the optional parameter that sets the list of reserved tokens. Repeated calls
         * accumulate: the given tokens are added to the ones already registered.
         *
         * @param reservedTokens the list of reserved tokens
         * @return this {@code VocabularyBuilder}
         */
        public VocabularyBuilder optReservedTokens(Collection<String> reservedTokens) {
            this.reservedTokens.addAll(reservedTokens);
            return this;
        }

        /**
         * Adds the given sentence to the {@link Vocabulary}.
         *
         * @param sentence the sentence to be added, as a list of tokens
         * @return this {@code VocabularyBuilder}
         */
        public VocabularyBuilder add(List<String> sentence) {
            this.sentences.add(sentence);
            return this;
        }

        /**
         * Adds the given list of sentences to the {@link Vocabulary}.
         *
         * @param sentences the list of sentences to be added
         * @return this {@code VocabularyBuilder}
         */
        public VocabularyBuilder addAll(List<List<String>> sentences) {
            this.sentences.addAll(sentences);
            return this;
        }

        /**
         * Builds the {@link Vocabulary} object with the set arguments.
         *
         * @return the {@link Vocabulary} object built
         */
        public Vocabulary build() {
            return new Vocabulary(this);
        }
    }

    /**
     * {@code TokenInfo} represents the information stored in the {@link Vocabulary} about a given
     * token.
     */
    private static final class TokenInfo {
        /** Number of times the token was seen; stays 0 for reserved tokens. */
        int frequency;
        /** Assigned index, or -1 while the token has not reached the minimum frequency. */
        long index = -1;

        public TokenInfo() {}

        public TokenInfo(int index) {
            this.index = index;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy