All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.language.sentencepiece.SentencePieceEmbedder Maven / Gradle / Ivy

The newest version!
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.language.sentencepiece;

import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.tools.Embed;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * A native Java implementation of SentencePiece - see https://github.com/google/sentencepiece
 *
 * SentencePiece is a language-agnostic segmenter and embedder for neural nets.
 *
 * @author bratseth
 */
@Beta
public class SentencePieceEmbedder implements Segmenter, Embedder {

    private final Map models;

    private final SentencePieceAlgorithm algorithm;

    @Inject
    public SentencePieceEmbedder(SentencePieceConfig config) {
        this(new Builder(config));
    }

    public SentencePieceEmbedder(Builder builder) {
        algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());

        models = builder.getModels().entrySet()
                        .stream()
                        .map(e -> new Model(e.getKey(), e.getValue()))
                        .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
        if (models.isEmpty())
            throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured");
    }

    /**
     * Segments the given text into token segments using the SentencePiece algorithm
     *
     * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
     * @param language the model to use, or Language.UNKNOWN to use the default model if any
     * @return the list of zero or more tokens resulting from segmenting the input text
     */
    @Override
    public List segment(String rawInput, Language language) {
        String input = normalize(rawInput);
        var resultBuilder = new ResultBuilder>(new ArrayList<>()) {
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(input.substring(segmentStart, segmentEnd));
            }
        };
        segment(input, language, resultBuilder);
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids.
     *
     * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
     * @param context the context which specifies the language used to select a model
     * @return the list of zero or more token ids resulting from segmenting the input text
     */
    @Override
    public List embed(String rawInput, Embedder.Context context) {
        var resultBuilder = new ResultBuilder>(new ArrayList<>()) {
            public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
                result().add(segmentEnds[segmentEnd].id);
            }
        };
        segment(normalize(rawInput), context.getLanguage(), resultBuilder);
        Collections.reverse(resultBuilder.result());
        return resultBuilder.result();
    }

    /**
     * Converts the list of token id's into a text. The opposite operation of embed.
     *
     * @param tokens the list of tokens to decode to a string
     * @param context the context which specifies the language used to select a model
     * @return the string formed by decoding the tokens back to their string representation
     */
    @Override
    public String decode(List tokens, Embedder.Context context) {
        return decode(tokens, context, false);
    }

    public String decode(List tokens, Embedder.Context context, boolean skipControl) {
        Model model = resolveModelFrom(context.getLanguage());
        StringBuilder sb = new StringBuilder();
        for (var tokenId : tokens) {
            var token = model.tokenId2Token.get(tokenId);
            var skip = skipControl && token.type() == TokenType.control;
            if ( ! skip) {
                sb.append(token.text());
            }
        }
        return denormalize(sb.toString());
    }

    /**
     * 

Embeds text into a tensor.

* *

If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small * it will be truncated.

* *

If the tensor is any other type IllegalArgumentException is thrown.

* * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. * @param context the context which specifies the language used to select a model * @return the list of zero or more token ids resulting from segmenting the input text */ @Override public Tensor embed(String rawInput, Embedder.Context context, TensorType type) { return Embed.asTensor(rawInput, this, context, type); } private void segment(String input, Language language, ResultBuilder resultBuilder) { algorithm.segment(input, resultBuilder, resolveModelFrom(language)); } private Model resolveModelFrom(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); } public String normalize(String s) { StringBuilder b = new StringBuilder(s.length() + 1); boolean queuedSpace = true; // Always start by one space for (int i = 0; i < s.length(); i++) { char c = s.charAt(i); if (s.charAt(i) == ' ') { queuedSpace = true; } else { if (queuedSpace) { b.append(SentencePieceAlgorithm.spaceSymbol); queuedSpace = false; } b.append(c); } } return b.toString(); } public String denormalize(String s) { String result = s.replace(SentencePieceAlgorithm.spaceSymbol, ' '); return result.charAt(0) == ' ' ? result.substring(1) : result; // Skip first space } public static final class Builder { private final Map models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; public Builder() {} public Builder(String defaultModelFile) { addDefaultModel(new File(defaultModelFile).toPath()); } private Builder(SentencePieceConfig config) { collapseUnknowns = config.collapseUnknowns(); scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments : Scoring.highestScore; for (SentencePieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); } public void addModel(Language language, Path model) { models.put(language, model); } /** * Adds the model that will be used if the language is unknown, OR only one model is specified. * The same as addModel(Language.UNKNOWN, model). */ public Builder addDefaultModel(Path model) { addModel(Language.UNKNOWN, model); return this; } public Map getModels() { return models; } /** * Sets whether consecutive unknown character should be collapsed into one large unknown token (default) * or be returned as single character tokens. */ public Builder setCollapseUnknowns(boolean collapseUnknowns) { this.collapseUnknowns = collapseUnknowns; return this; } public boolean getCollapseUnknowns() { return collapseUnknowns; } /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */ public Builder setScoring(Scoring scoring) { this.scoring = scoring; return this; } public Scoring getScoring() { return scoring; } public SentencePieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); return new SentencePieceEmbedder(this); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy