All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.huggingface.tokenizers.HuggingFaceTokenizer Maven / Gradle / Ivy

There is a newer version: 0.31.1
Show newest version
/*
 * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.huggingface.tokenizers.jni.LibUtils;
import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.util.Ec2Utils;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@code HuggingFaceTokenizer} is a Huggingface tokenizer implementation of the {@link Tokenizer}
 * interface that converts sentences into tokens.
 *
 * <p>Each instance wraps a native tokenizer handle, which is released by {@link #close()}.
 */
public final class HuggingFaceTokenizer extends NativeResource<Long> implements Tokenizer {

    private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);

    /** Value used for {@code modelMaxLength} when the option is not supplied. */
    private static final int DEFAULT_MODEL_MAX_LENGTH = 512;

    private boolean addSpecialTokens;
    private TruncationStrategy truncation;
    private PaddingStrategy padding;
    private int maxLength;
    private int stride;
    private int padToMultipleOf;
    private int modelMaxLength;

    /**
     * Wraps a native tokenizer handle and applies the optional overrides.
     *
     * <p>The truncation/padding settings are first read back from the native tokenizer; any of
     * the keys {@code truncation}, {@code padding}, {@code maxLength}, {@code stride} and {@code
     * padToMultipleOf} present in {@code options} then replace them. {@code addSpecialTokens}
     * defaults to {@code true} and {@code modelMaxLength} to {@code 512}.
     *
     * @param handle the native tokenizer handle
     * @param options tokenizer options, may be {@code null}
     * @throws IllegalArgumentException if the truncation or padding value is not recognized
     */
    private HuggingFaceTokenizer(long handle, Map<String, String> options) {
        super(handle);

        // Start from the settings reported by the native tokenizer.
        String val = TokenizersLibrary.LIB.getTruncationStrategy(handle);
        truncation = TruncationStrategy.fromValue(val);
        val = TokenizersLibrary.LIB.getPaddingStrategy(handle);
        padding = PaddingStrategy.fromValue(val);
        maxLength = TokenizersLibrary.LIB.getMaxLength(handle);
        stride = TokenizersLibrary.LIB.getStride(handle);
        padToMultipleOf = TokenizersLibrary.LIB.getPadToMultipleOf(handle);

        addSpecialTokens = true;
        modelMaxLength = DEFAULT_MODEL_MAX_LENGTH;

        if (options != null) {
            val = options.getOrDefault("addSpecialTokens", "true");
            addSpecialTokens = Boolean.parseBoolean(val);
            modelMaxLength =
                    ArgumentsUtil.intValue(options, "modelMaxLength", DEFAULT_MODEL_MAX_LENGTH);
            if (options.containsKey("truncation")) {
                truncation = TruncationStrategy.fromValue(options.get("truncation"));
            }
            if (options.containsKey("padding")) {
                padding = PaddingStrategy.fromValue(options.get("padding"));
            }
            maxLength = ArgumentsUtil.intValue(options, "maxLength", maxLength);
            stride = ArgumentsUtil.intValue(options, "stride", stride);
            padToMultipleOf = ArgumentsUtil.intValue(options, "padToMultipleOf", padToMultipleOf);
        }

        // Push the effective configuration back to the native tokenizer.
        updateTruncationAndPadding();
    }

    /**
     * Creates a pre-trained {@code HuggingFaceTokenizer} instance from huggingface hub.
     *
     * @param name the name of the huggingface tokenizer
     * @return a {@code HuggingFaceTokenizer} instance
     */
    public static HuggingFaceTokenizer newInstance(String name) {
        return newInstance(name, null);
    }

    /**
     * Creates a pre-trained {@code HuggingFaceTokenizer} instance from huggingface hub.
     *
     * @param identifier the identifier of the huggingface tokenizer
     * @param options tokenizer options, may be {@code null}
     * @return a {@code HuggingFaceTokenizer} instance
     */
    public static HuggingFaceTokenizer newInstance(String identifier, Map<String, String> options) {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();

        long handle = TokenizersLibrary.LIB.createTokenizer(identifier);
        return new HuggingFaceTokenizer(handle, options);
    }

    /**
     * Creates a pre-trained {@code HuggingFaceTokenizer} instance from existing models.
     *
     * @param modelPath the directory or file path of the model location
     * @return a {@code HuggingFaceTokenizer} instance
     * @throws IOException when IO operation fails in loading a resource
     */
    public static HuggingFaceTokenizer newInstance(Path modelPath) throws IOException {
        return newInstance(modelPath, null);
    }

    /**
     * Creates a pre-trained {@code HuggingFaceTokenizer} instance from existing models.
     *
     * <p>When {@code modelPath} is a directory, the file {@code tokenizer.json} inside it is
     * loaded.
     *
     * @param modelPath the directory or file path of the model location
     * @param options tokenizer options, may be {@code null}
     * @return a {@code HuggingFaceTokenizer} instance
     * @throws IOException when IO operation fails in loading a resource
     */
    public static HuggingFaceTokenizer newInstance(Path modelPath, Map<String, String> options)
            throws IOException {
        if (Files.isDirectory(modelPath)) {
            modelPath = modelPath.resolve("tokenizer.json");
        }
        try (InputStream is = Files.newInputStream(modelPath)) {
            return newInstance(is, options);
        }
    }

    /**
     * Creates a pre-trained {@code HuggingFaceTokenizer} instance from {@code InputStream}.
     *
     * <p>The stream is read fully but is not closed by this method.
     *
     * @param is {@code InputStream} providing the tokenizer definition as text
     * @param options tokenizer options, may be {@code null}
     * @return a {@code HuggingFaceTokenizer} instance
     * @throws IOException when IO operation fails in loading a resource
     */
    public static HuggingFaceTokenizer newInstance(InputStream is, Map<String, String> options)
            throws IOException {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();
        String json = Utils.toString(is);

        long handle = TokenizersLibrary.LIB.createTokenizerFromString(json);
        return new HuggingFaceTokenizer(handle, options);
    }

    /** {@inheritDoc} */
    @Override
    public List<String> tokenize(String sentence) {
        Encoding encoding = encode(sentence);
        return Arrays.asList(encoding.getTokens());
    }

    /** {@inheritDoc} */
    @Override
    public String buildSentence(List<String> tokens) {
        // Joins the tokens with single spaces and glues every token that starts with "##" onto
        // the preceding one. Only this "##" continuation marker is undone.
        // TODO: other sub-word markers are left untouched; consider delegating to decode().
        return String.join(" ", tokens).replace(" ##", "").trim();
    }

    /**
     * Releases the native tokenizer.
     *
     * <p>The handle is swapped out atomically, so calling this method more than once is a no-op
     * after the first call.
     */
    @Override
    public void close() {
        Long pointer = handle.getAndSet(null);
        if (pointer != null) {
            TokenizersLibrary.LIB.deleteTokenizer(pointer);
        }
    }

    /**
     * Returns the {@code Encoding} of the input sentence.
     *
     * @param text the input sentence
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentence
     */
    public Encoding encode(String text, boolean addSpecialTokens) {
        long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
        return toEncoding(encoding);
    }

    /**
     * Returns the {@code Encoding} of the input sentence, using this tokenizer's configured
     * {@code addSpecialTokens} setting.
     *
     * @param text the input sentence
     * @return the {@code Encoding} of the input sentence
     */
    public Encoding encode(String text) {
        return encode(text, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentence pair.
     *
     * @param text the input sentence
     * @param textPair the second input sentence
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentence pair
     */
    public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
        long encoding =
                TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
        return toEncoding(encoding);
    }

    /**
     * Returns the {@code Encoding} of the input sentence pair, using this tokenizer's configured
     * {@code addSpecialTokens} setting.
     *
     * @param text the input sentence
     * @param textPair the second input sentence
     * @return the {@code Encoding} of the input sentence pair
     */
    public Encoding encode(String text, String textPair) {
        return encode(text, textPair, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentences.
     *
     * @param inputs the input sentences
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentences
     */
    public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
        String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
        return encode(array, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentences, using this tokenizer's configured
     * {@code addSpecialTokens} setting.
     *
     * @param inputs the input sentences
     * @return the {@code Encoding} of the input sentences
     */
    public Encoding encode(List<String> inputs) {
        return encode(inputs, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentences.
     *
     * @param inputs the input sentences
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentences
     */
    public Encoding encode(String[] inputs, boolean addSpecialTokens) {
        long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
        return toEncoding(encoding);
    }

    /**
     * Returns the {@code Encoding} of the input sentences, using this tokenizer's configured
     * {@code addSpecialTokens} setting.
     *
     * @param inputs the input sentences
     * @return the {@code Encoding} of the input sentences
     */
    public Encoding encode(String[] inputs) {
        return encode(inputs, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentence in batch.
     *
     * @param inputs the batch of input sentence
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentence in batch
     */
    public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
        String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
        return batchEncode(array, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentence in batch, using this tokenizer's
     * configured {@code addSpecialTokens} setting.
     *
     * @param inputs the batch of input sentence
     * @return the {@code Encoding} of the input sentence in batch
     */
    public Encoding[] batchEncode(List<String> inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input sentence in batch.
     *
     * @param inputs the batch of input sentence
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input sentence in batch
     */
    public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
        long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
        return toEncodings(encodings);
    }

    /**
     * Returns the {@code Encoding} of the input sentence in batch, using this tokenizer's
     * configured {@code addSpecialTokens} setting.
     *
     * @param inputs the batch of input sentence
     * @return the {@code Encoding} of the input sentence in batch
     */
    public Encoding[] batchEncode(String[] inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Returns the {@code Encoding} of the input text pair in batch.
     *
     * @param inputs the batch of input text pair; keys are the first texts, values the second
     * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
     *     model
     * @return the {@code Encoding} of the input text pair in batch
     */
    public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecialTokens) {
        String[] text = inputs.keyArray(Utils.EMPTY_ARRAY);
        String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY);
        long[] encodings =
                TokenizersLibrary.LIB.batchEncodePair(
                        getHandle(), text, textPair, addSpecialTokens);
        return toEncodings(encodings);
    }

    /**
     * Returns the {@code Encoding} of the input text pair in batch, using this tokenizer's
     * configured {@code addSpecialTokens} setting.
     *
     * @param inputs the batch of input text pair
     * @return the {@code Encoding} of the input text pair in batch
     */
    public Encoding[] batchEncode(PairList<String, String> inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Returns the decoded String from the input ids.
     *
     * @param ids the input ids
     * @param skipSpecialTokens whether to remove special tokens in the decoding
     * @return the decoded String from the input ids
     */
    public String decode(long[] ids, boolean skipSpecialTokens) {
        return TokenizersLibrary.LIB.decode(getHandle(), ids, skipSpecialTokens);
    }

    /**
     * Returns the decoded String from the input ids.
     *
     * <p>Special tokens are skipped exactly when this tokenizer is configured <em>not</em> to add
     * special tokens while encoding.
     *
     * @param ids the input ids
     * @return the decoded String from the input ids
     */
    public String decode(long[] ids) {
        return decode(ids, !addSpecialTokens);
    }

    /**
     * Returns the decoded Strings from the input batch ids.
     *
     * @param batchIds the batch of id sequences to decode
     * @param skipSpecialTokens whether to remove special tokens in the decoding
     * @return the decoded Strings from the input batch ids
     */
    public String[] batchDecode(long[][] batchIds, boolean skipSpecialTokens) {
        return TokenizersLibrary.LIB.batchDecode(getHandle(), batchIds, skipSpecialTokens);
    }

    /**
     * Returns the decoded Strings from the input batch ids.
     *
     * <p>Special tokens are skipped exactly when this tokenizer is configured <em>not</em> to add
     * special tokens while encoding.
     *
     * @param batchIds the batch of id sequences to decode
     * @return the decoded Strings from the input batch ids
     */
    public String[] batchDecode(long[][] batchIds) {
        return batchDecode(batchIds, !addSpecialTokens);
    }

    /**
     * Sets padding and truncation to true for batching.
     *
     * <p>Padding switches to {@code LONGEST} only when it is currently disabled, and truncation
     * switches to {@code LONGEST_FIRST} only when it is currently disabled; strategies that are
     * already enabled are left as they are.
     */
    public void enableBatch() {
        boolean changed = false;
        if (padding == PaddingStrategy.DO_NOT_PAD) {
            changed = true;
            padding = PaddingStrategy.LONGEST;
        }
        if (truncation == TruncationStrategy.DO_NOT_TRUNCATE) {
            changed = true;
            truncation = TruncationStrategy.LONGEST_FIRST;
        }
        if (changed) {
            updateTruncationAndPadding();
        }
    }

    /**
     * Creates a builder to build a {@code HuggingFaceTokenizer}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder to build a {@code HuggingFaceTokenizer}.
     *
     * @param arguments the models' arguments
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = builder();
        builder.configure(arguments);

        return builder;
    }

    /**
     * Reconciles {@code maxLength} with {@code modelMaxLength} and {@code padToMultipleOf}, then
     * pushes the truncation and padding configuration to the native tokenizer.
     *
     * <p>See: https://huggingface.co/docs/transformers/pad_truncation
     */
    private void updateTruncationAndPadding() {
        boolean isTruncate = truncation != TruncationStrategy.DO_NOT_TRUNCATE;
        if (padding == PaddingStrategy.MAX_LENGTH || isTruncate) {
            // A concrete maxLength is needed; -1 means "not specified".
            if (maxLength == -1) {
                logger.warn(
                        "maxLength is not explicitly specified, use modelMaxLength: {}",
                        modelMaxLength);
                maxLength = modelMaxLength;
            } else if (maxLength > modelMaxLength) {
                logger.warn(
                        "maxLength is greater than modelMaxLength, change to: {}", modelMaxLength);
                maxLength = modelMaxLength;
            }

            if (padding == PaddingStrategy.MAX_LENGTH && isTruncate && padToMultipleOf != 0) {
                int remainder = maxLength % padToMultipleOf;
                if (remainder != 0) {
                    // Round up to the next multiple, or down if that would exceed the model limit.
                    // NOTE(review): rounding down can reach 0 or below when padToMultipleOf is
                    // larger than modelMaxLength - confirm whether that should be rejected.
                    int newMaxLength = maxLength + padToMultipleOf - remainder;
                    if (newMaxLength > modelMaxLength) {
                        newMaxLength -= padToMultipleOf;
                    }
                    logger.warn(
                            "maxLength ({}) is not a multiple of padToMultipleOf ({}), change to:"
                                    + " {}",
                            maxLength,
                            padToMultipleOf,
                            newMaxLength);
                    maxLength = newMaxLength;
                }
            }
        }

        if (isTruncate) {
            TokenizersLibrary.LIB.setTruncation(getHandle(), maxLength, truncation.name(), stride);
        } else {
            TokenizersLibrary.LIB.disableTruncation(getHandle());
        }

        if (padding == PaddingStrategy.DO_NOT_PAD) {
            TokenizersLibrary.LIB.disablePadding(getHandle());
        } else {
            TokenizersLibrary.LIB.setPadding(
                    getHandle(), maxLength, padding.name(), padToMultipleOf);
        }
    }

    /**
     * Converts an array of native encoding handles, releasing each handle as it is converted.
     *
     * @param handles the native encoding handles
     * @return the converted encodings, in the same order
     */
    private Encoding[] toEncodings(long[] handles) {
        Encoding[] ret = new Encoding[handles.length];
        for (int i = 0; i < handles.length; ++i) {
            ret[i] = toEncoding(handles[i]);
        }
        return ret;
    }

    /**
     * Copies the content of a native encoding (and, recursively, its overflowing encodings) into
     * an {@code Encoding} and releases the native encoding.
     *
     * @param encoding the native encoding handle
     * @return the converted {@code Encoding}
     */
    private Encoding toEncoding(long encoding) {
        try {
            long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
            long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
            String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
            long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
            long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
            long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
            CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
            long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
            Encoding[] overflowing = toEncodings(overflowingHandles);

            return new Encoding(
                    ids,
                    typeIds,
                    tokens,
                    wordIds,
                    attentionMask,
                    specialTokenMask,
                    charSpans,
                    overflowing);
        } finally {
            // Always release the native encoding, even when one of the calls above fails.
            TokenizersLibrary.LIB.deleteEncoding(encoding);
        }
    }

    /** {@inheritDoc} */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        // Safety net for tokenizers that were never closed explicitly.
        close();
        super.finalize();
    }

    /** An enum to represent the different available truncation strategies. */
    private enum TruncationStrategy {
        LONGEST_FIRST,
        ONLY_FIRST,
        ONLY_SECOND,
        DO_NOT_TRUNCATE;

        /**
         * Converts the String to the matching TruncationStrategy type.
         *
         * <p>{@code "true"} maps to {@code LONGEST_FIRST} and {@code "false"} to {@code
         * DO_NOT_TRUNCATE}; any other value is matched against the constant names ignoring case.
         *
         * @param value the String to convert
         * @return the matching TruncationStrategy type
         * @throws IllegalArgumentException if the value does not match any TruncationStrategy type
         */
        static TruncationStrategy fromValue(String value) {
            if ("true".equals(value)) {
                return TruncationStrategy.LONGEST_FIRST;
            } else if ("false".equals(value)) {
                return TruncationStrategy.DO_NOT_TRUNCATE;
            }
            for (TruncationStrategy strategy : TruncationStrategy.values()) {
                if (strategy.name().equalsIgnoreCase(value)) {
                    return strategy;
                }
            }
            throw new IllegalArgumentException("Invalid TruncationStrategy: " + value);
        }
    }

    /** An enum to represent the different available padding strategies. */
    private enum PaddingStrategy {
        LONGEST,
        MAX_LENGTH,
        DO_NOT_PAD;

        /**
         * Converts the String to the matching PaddingStrategy type.
         *
         * <p>{@code "true"} maps to {@code LONGEST} and {@code "false"} to {@code DO_NOT_PAD}; any
         * other value is matched against the constant names ignoring case.
         *
         * @param value the String to convert
         * @return the matching PaddingStrategy type
         * @throws IllegalArgumentException if the value does not match any PaddingStrategy type
         */
        static PaddingStrategy fromValue(String value) {
            if ("true".equals(value)) {
                return PaddingStrategy.LONGEST;
            } else if ("false".equals(value)) {
                return PaddingStrategy.DO_NOT_PAD;
            }
            for (PaddingStrategy strategy : PaddingStrategy.values()) {
                if (strategy.name().equalsIgnoreCase(value)) {
                    return strategy;
                }
            }
            throw new IllegalArgumentException("Invalid PaddingStrategy: " + value);
        }
    }

    /** The builder for creating huggingface tokenizer. */
    public static final class Builder {

        private Path tokenizerPath;
        private NDManager manager;
        private final Map<String, String> options;

        Builder() {
            options = new ConcurrentHashMap<>();
            options.put("addSpecialTokens", "true");
        }

        /**
         * Sets the optional manager used to manage the lifecycle of the tokenizer.
         *
         * @param manager the {@link NDManager}
         * @return this builder
         */
        public Builder optManager(NDManager manager) {
            this.manager = manager;
            return this;
        }

        /**
         * Sets the name of the tokenizer.
         *
         * <p>When a name is set, {@link #build()} uses it in preference to the tokenizer path.
         *
         * @param tokenizerName the name of the tokenizer
         * @return this builder
         */
        public Builder optTokenizerName(String tokenizerName) {
            options.put("tokenizer", tokenizerName);
            return this;
        }

        /**
         * Sets the file path of the tokenizer.
         *
         * @param tokenizerPath the path of the tokenizer
         * @return this builder
         */
        public Builder optTokenizerPath(Path tokenizerPath) {
            this.tokenizerPath = tokenizerPath;
            return this;
        }

        /**
         * Sets if add special tokens.
         *
         * @param addSpecialTokens true to add special tokens
         * @return this builder
         */
        public Builder optAddSpecialTokens(boolean addSpecialTokens) {
            options.put("addSpecialTokens", String.valueOf(addSpecialTokens));
            return this;
        }

        /**
         * Enables or Disables default truncation behavior for the tokenizer.
         *
         * @param enabled whether to enable default truncation behavior
         * @return this builder
         */
        public Builder optTruncation(boolean enabled) {
            options.put("truncation", String.valueOf(enabled));
            return this;
        }

        /**
         * Enables truncation to only truncate the first item.
         *
         * @return this builder
         */
        public Builder optTruncateFirstOnly() {
            options.put("truncation", TruncationStrategy.ONLY_FIRST.name());
            return this;
        }

        /**
         * Enables truncation to only truncate the second item.
         *
         * @return this builder
         */
        public Builder optTruncateSecondOnly() {
            options.put("truncation", TruncationStrategy.ONLY_SECOND.name());
            return this;
        }

        /**
         * Enables or Disables default padding behavior for the tokenizer.
         *
         * @param enabled whether to enable default padding behavior
         * @return this builder
         */
        public Builder optPadding(boolean enabled) {
            options.put("padding", String.valueOf(enabled));
            return this;
        }

        /**
         * Enables padding to pad sequences to previously specified maxLength, or modelMaxLength if
         * not specified.
         *
         * @return this builder
         */
        public Builder optPadToMaxLength() {
            options.put("padding", PaddingStrategy.MAX_LENGTH.name());
            return this;
        }

        /**
         * Sets maxLength for padding and truncation.
         *
         * @param maxLength the length to truncate and/or pad sequences to
         * @return this builder
         */
        public Builder optMaxLength(int maxLength) {
            options.put("maxLength", String.valueOf(maxLength));
            return this;
        }

        /**
         * Sets padToMultipleOf for padding.
         *
         * @param padToMultipleOf the multiple of sequences should be padded to
         * @return this builder
         */
        public Builder optPadToMultipleOf(int padToMultipleOf) {
            options.put("padToMultipleOf", String.valueOf(padToMultipleOf));
            return this;
        }

        /**
         * Sets the stride to use in overflow overlap when truncating sequences longer than the
         * model supports.
         *
         * @param stride the number of tokens to overlap when truncating long sequences
         * @return this builder
         */
        public Builder optStride(int stride) {
            options.put("stride", String.valueOf(stride));
            return this;
        }

        /**
         * Configures the builder with the arguments.
         *
         * <p>Every value is stored as its {@code toString()} form, so values must be non-null.
         *
         * @param arguments the arguments
         */
        public void configure(Map<String, ?> arguments) {
            for (Map.Entry<String, ?> entry : arguments.entrySet()) {
                options.put(entry.getKey(), entry.getValue().toString());
            }
        }

        /**
         * Utility to make a tokenizer managed by the builder manager (if one is specified).
         *
         * @param tokenizer the tokenizer to manage
         * @return the updated tokenizer
         */
        private HuggingFaceTokenizer managed(HuggingFaceTokenizer tokenizer) {
            if (manager != null) {
                manager.attachInternal(tokenizer.getUid(), tokenizer);
            }
            return tokenizer;
        }

        /**
         * Builds the tokenizer.
         *
         * <p>The tokenizer name is used when one has been set; otherwise the tokenizer is loaded
         * from the tokenizer path.
         *
         * @return the new tokenizer
         * @throws IOException when IO operation fails in loading a resource
         * @throws IllegalArgumentException if neither a tokenizer name nor a tokenizer path is set
         */
        public HuggingFaceTokenizer build() throws IOException {
            String tokenizerName = options.get("tokenizer");
            if (tokenizerName != null) {
                return managed(HuggingFaceTokenizer.newInstance(tokenizerName, options));
            }
            if (tokenizerPath == null) {
                throw new IllegalArgumentException("Missing tokenizer path.");
            }
            return managed(HuggingFaceTokenizer.newInstance(tokenizerPath, options));
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy