com.yahoo.language.wordpiece.WordPieceEmbedder Maven / Gradle / Ivy
The newest version!
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.language.wordpiece;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.tools.Embed;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.language.wordpiece.WordPieceConfig;
import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* An implementation of the WordPiece embedder, usually used with BERT models,
* see https://arxiv.org/pdf/1609.08144v2.pdf
* Text is tokenized into tokens from a configured vocabulary,
* and optionally returned as a 1-d dense tensor of token ids.
*
* @author bratseth
*/
public class WordPieceEmbedder implements Embedder, Segmenter {
private final Map models;
private final Tokenizer tokenizer;
@Inject
public WordPieceEmbedder(WordPieceConfig config) {
this(new Builder(config));
}
private WordPieceEmbedder(Builder builder) {
super();
this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase
models = builder.getModels().entrySet()
.stream()
.map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue()))
.collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m));
if (models.isEmpty())
throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
}
/**
* Segments the given text into token segments from the WordPiece vocabulary.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
public List segment(String text, Language language) {
return resolveModelFrom(language).segment(text, tokenizer);
}
/**
* Segments the given text into token segments from the WordPiece vocabulary and returns the token ids.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
public List embed(String text, Context context) {
return resolveModelFrom(context.getLanguage()).embed(text, tokenizer);
}
/**
* Embeds text into a tensor.
*
* If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
* they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
* it will be truncated.
*
* If the tensor is any other type IllegalArgumentException is thrown.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @param context the context which specifies the language used to select a model
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
public Tensor embed(String text, Context context, TensorType type) {
return Embed.asTensor(text, this, context, type);
}
private Model resolveModelFrom(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
if (models.containsKey(language)) return models.get(language);
throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured");
}
public static final class Builder {
private String subwordPrefix = "##";
private final Map models = new EnumMap<>(Language.class);
public Builder() {}
public Builder(String defaultModelFile) {
addDefaultModel(new File(defaultModelFile).toPath());
}
private Builder(WordPieceConfig config) {
this.subwordPrefix = config.subwordPrefix();
for (WordPieceConfig.Model model : config.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
}
public Builder setSubwordPrefix(String prefix) {
this.subwordPrefix = subwordPrefix;
return this;
}
public String getSubwordPrefix() { return subwordPrefix; }
public void addModel(Language language, Path model) {
models.put(language, model);
}
/**
* Adds the model that will be used if the language is unknown, OR only one model is specified.
* The same as addModel(Language.UNKNOWN, model).
*/
public WordPieceEmbedder.Builder addDefaultModel(Path model) {
addModel(Language.UNKNOWN, model);
return this;
}
public Map getModels() { return models; }
public WordPieceEmbedder build() {
if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
return new WordPieceEmbedder(this);
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy