All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.language.wordpiece.WordPieceEmbedder Maven / Gradle / Ivy

The newest version!
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.language.wordpiece;

import com.yahoo.component.annotation.Inject;
import com.yahoo.language.tools.Embed;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.language.wordpiece.WordPieceConfig;

import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An implementation of the WordPiece embedder, usually used with BERT models,
 * see https://arxiv.org/pdf/1609.08144v2.pdf
 * Text is tokenized into tokens from a configured vocabulary,
 * and optionally returned as a 1-d dense tensor of token ids.
 *
 * @author bratseth
 */
public class WordPieceEmbedder implements Embedder, Segmenter {

    private final Map models;

    private final Tokenizer tokenizer;

    @Inject
    public WordPieceEmbedder(WordPieceConfig config) {
        this(new Builder(config));
    }

    private WordPieceEmbedder(Builder builder) {
        super();
        this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase
        models = builder.getModels().entrySet()
                        .stream()
                        .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue()))
                        .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m));
        if (models.isEmpty())
            throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
    }

    /**
     * Segments the given text into token segments from the WordPiece vocabulary.
     *
     * @param text the text to segment. The text should be of a language using space-separated words.
     * @return the list of zero or more token ids resulting from segmenting the input text
     */
    @Override
    public List segment(String text, Language language) {
        return resolveModelFrom(language).segment(text, tokenizer);
    }

    /**
     * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids.
     *
     * @param text the text to segment. The text should be of a language using space-separated words.
     * @param context the context which specifies the language used to select a model
     * @return the list of zero or more token ids resulting from segmenting the input text
     */
    @Override
    public List embed(String text, Context context) {
        return resolveModelFrom(context.getLanguage()).embed(text, tokenizer);
    }

    /**
     * 

Embeds text into a tensor.

* *

If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small * it will be truncated.

* *

If the tensor is any other type IllegalArgumentException is thrown.

* * @param text the text to segment. The text should be of a language using space-separated words. * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override public Tensor embed(String text, Context context, TensorType type) { return Embed.asTensor(text, this, context, type); } private Model resolveModelFrom(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured"); } public static final class Builder { private String subwordPrefix = "##"; private final Map models = new EnumMap<>(Language.class); public Builder() {} public Builder(String defaultModelFile) { addDefaultModel(new File(defaultModelFile).toPath()); } private Builder(WordPieceConfig config) { this.subwordPrefix = config.subwordPrefix(); for (WordPieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); } public Builder setSubwordPrefix(String prefix) { this.subwordPrefix = subwordPrefix; return this; } public String getSubwordPrefix() { return subwordPrefix; } public void addModel(Language language, Path model) { models.put(language, model); } /** * Adds the model that will be used if the language is unknown, OR only one model is specified. * The same as addModel(Language.UNKNOWN, model). */ public WordPieceEmbedder.Builder addDefaultModel(Path model) { addModel(Language.UNKNOWN, model); return this; } public Map getModels() { return models; } public WordPieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); return new WordPieceEmbedder(this); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy