// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// ai.vespa.embedding.SpladeEmbedder Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;

import java.nio.file.Paths;
import java.util.List;
import java.util.Map;

import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;

/**
 * A SPLADE embedder that embeds text into a 1-d mapped tensor. For interpretability, the tensor labels
 * are the subword strings from the wordpiece vocabulary that have a score above a threshold (default 0.0).
 *
 * @author bergum
 */
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {

    private final Embedder.Runtime runtime;
    private final String inputIdsName;
    private final String attentionMaskName;
    private final String tokenTypeIdsName;
    private final String outputName;
    private final double termScoreThreshold;
    private final boolean useCustomReduce;
    private final HuggingFaceTokenizer tokenizer;
    private final OnnxEvaluator evaluator;

    /**
     * Creates an embedder that uses the custom (unrolled) reduce implementation.
     *
     * @param onnx    the runtime used to load the transformer model
     * @param runtime the embedder runtime that receives sequence length and latency samples
     * @param config  the embedder configuration
     * @throws IllegalArgumentException if the model lacks one of the configured inputs or the configured output
     */
    @Inject
    public SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config) {
        this(onnx, runtime, config, true);
    }

    /**
     * Creates an embedder with an explicit choice of reduce implementation.
     *
     * @param onnx            the runtime used to load the transformer model
     * @param runtime         the embedder runtime that receives sequence length and latency samples
     * @param config          the embedder configuration
     * @param useCustomReduce true to use {@link #sparsifyCustomReduce}, false to use the generic tensor reduce
     * @throws IllegalArgumentException if the model lacks one of the configured inputs or the configured output
     */
    SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config, boolean useCustomReduce) {
        this.runtime = runtime;
        inputIdsName = config.transformerInputIds();
        attentionMaskName = config.transformerAttentionMask();
        outputName = config.transformerOutput();
        tokenTypeIdsName = config.transformerTokenTypeIds();
        termScoreThreshold = config.termScoreThreshold();
        this.useCustomReduce = useCustomReduce;

        var tokenizerPath = Paths.get(config.tokenizerPath().toString());
        var builder = new HuggingFaceTokenizer.Builder()
                .addSpecialTokens(true)
                .addDefaultModel(tokenizerPath)
                .setPadding(false);
        var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
        if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
            // Force truncation
            // to max length accepted by model if tokenizer.json contains no valid truncation configuration
            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
                    ? info.maxLength()
                    : config.transformerMaxTokens();
            builder.setTruncation(true).setMaxLength(maxLength);
        }
        this.tokenizer = builder.build();

        // A constructor that throws hands no reference to the caller, so deconstruct() can never be
        // invoked on this instance. Release whatever was already opened before propagating the failure.
        try {
            var onnxOpts = new OnnxEvaluatorOptions();
            if (config.transformerGpuDevice() >= 0) {
                onnxOpts.setGpuDevice(config.transformerGpuDevice());
            }
            onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
            onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
            evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
        } catch (RuntimeException e) {
            tokenizer.close();
            throw e;
        }
        try {
            validateModel();
        } catch (RuntimeException e) {
            evaluator.close();
            tokenizer.close();
            throw e;
        }
    }

    /**
     * Verifies that the loaded model exposes every configured input name and the configured output name.
     *
     * @throws IllegalArgumentException if a configured name is missing from the model
     */
    public void validateModel() {
        Map<String, TensorType> inputs = evaluator.getInputInfo();
        validateName(inputs, inputIdsName, "input");
        validateName(inputs, attentionMaskName, "input");
        // embed() always feeds the token type ids input, so check it up front as well
        validateName(inputs, tokenTypeIdsName, "input");
        Map<String, TensorType> outputs = evaluator.getOutputInfo();
        validateName(outputs, outputName, "output");
    }

    /**
     * Validates that the given tensor type is a 1-d mapped tensor.
     *
     * @param target the type to validate
     * @return true if the type is a 1-d mapped tensor
     */
    protected boolean verifyTensorType(TensorType target) {
        return target.dimensions().size() == 1 && target.dimensions().get(0).isMapped();
    }

    /**
     * Checks that {@code name} is one of the keys of {@code types}.
     *
     * @param types the model's inputs or outputs, keyed by name
     * @param name  the name that must be present
     * @param type  "input" or "output", used only in the error message
     * @throws IllegalArgumentException if {@code name} is not a key of {@code types}
     */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
                    "Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Not supported by this embedder.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Integer> embed(String text, Context context) {
        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
    }

    /**
     * Embeds the text into a 1-d mapped tensor: the text is tokenized, the token ids, attention mask and
     * token type ids are evaluated by the model, and the model output is sparsified into term/score cells.
     * The sequence length and the elapsed time in milliseconds are reported to the embedder runtime.
     *
     * @param text       the text to embed
     * @param context    the embedding context, supplying the language used for tokenization
     * @param tensorType the destination type, which must be a 1-d mapped tensor
     * @return a mapped tensor with the terms whose score is above the threshold
     * @throws IllegalArgumentException if {@code tensorType} is not a 1-d mapped tensor
     */
    @Override
    public Tensor embed(String text, Context context, TensorType tensorType) {
        if (!verifyTensorType(tensorType)) {
            throw new IllegalArgumentException("Invalid splade embedder tensor destination. " +
                                               "Wanted a mapped 1-d tensor, got " + tensorType);
        }
        var start = System.nanoTime();

        var encoding = tokenizer.encode(text, context.getLanguage());
        runtime.sampleSequenceLength(encoding.ids().size(), context);

        Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
        Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
        Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1");

        // Each 1-d input is expanded with a leading "d0" dimension before being handed to the model
        Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                attentionMaskName, attentionMask.expand("d0"),
                tokenTypeIdsName, tokenTypeIds.expand("d0"));
        IndexedTensor output = (IndexedTensor) evaluator.evaluate(inputs).get(outputName);
        Tensor spladeTensor = useCustomReduce
                ? sparsifyCustomReduce(output, tensorType)
                : sparsifyReduce(output, tensorType);
        runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
        return spladeTensor;
    }

    /**
     * Sparsify the output tensor by applying a threshold on the log of the relu of the output.
     * This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant.
     *
     * @param modelOutput the model output tensor; it is max-reduced over its "d0" and "d1" dimensions,
     *                    and the remaining dimension is indexed by vocabulary token id
     * @param tensorType the type of the destination tensor
     * @return A mapped tensor with the terms from the vocab that have a score above the threshold
     */
    private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) {
        // Max over the batch (size 1) and sequence dimensions leaves one value per vocabulary entry
        Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1");
        Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0)));
        IndexedTensor vocab = (IndexedTensor) logOfRelu;
        var builder = Tensor.Builder.of(tensorType);
        String dimension = tensorType.dimensions().get(0).name();
        // Single-element buffer reused for every decode call
        long[] tokens = new long[1];
        for (int i = 0; i < vocab.size(); i++) {
            var score = vocab.get(i);
            if (score > termScoreThreshold) {
                tokens[0] = i;
                String term = tokenizer.decode(tokens);
                builder.cell()
                        .label(dimension, term)
                        .value(score);
            }
        }
        return builder.build();
    }

    /**
     * Sparsify the model output tensor. This uses an unrolled custom reduce and is 15-20% faster than using
     * the generic tensor reduce.
     *
     * @param modelOutput the model output, a 3-dimensional indexed tensor of shape
     *                    [batch, sequence length, vocabulary size] where batch must be 1
     * @param tensorType the type of the destination tensor
     * @return A mapped tensor with the terms from the vocab that have a score above the threshold
     * @throws IllegalArgumentException if the tensor is not 3-dimensional, the batch size is not 1,
     *                                  or the sequence length or vocabulary size does not fit in an int
     */
    public Tensor sparsifyCustomReduce(IndexedTensor modelOutput, TensorType tensorType) {
        var builder = Tensor.Builder.of(tensorType);
        long[] shape = modelOutput.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
        }
        long batch = shape[0];
        if (batch != 1) {
            throw new IllegalArgumentException("Batch size must be 1");
        }
        if (shape[1] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int");
        }
        if (shape[2] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int");
        }
        int sequenceLength = (int) shape[1];
        int vocabSize = (int) shape[2];

        String dimension = tensorType.dimensions().get(0).name();
        // For each vocabulary entry, find the max value over all sequence positions
        long[] tokens = new long[1];
        DirectIndexedAddress directAddress = modelOutput.directAddress();
        directAddress.setIndex(0, 0);
        for (int v = 0; v < vocabSize; v++) {
            // Starting at 0 clamps negative values, which is the relu step
            double maxValue = 0.0d;
            directAddress.setIndex(2, v);
            long increment = directAddress.getStride(1);
            long directIndex = directAddress.getDirectIndex();
            for (int s = 0; s < sequenceLength; s++) {
                double value = modelOutput.get(directIndex + s * increment);
                if (value > maxValue) {
                    maxValue = value;
                }
            }
            double logOfRelu = Math.log(1 + maxValue);
            if (logOfRelu > termScoreThreshold) {
                tokens[0] = v;
                String term = tokenizer.decode(tokens);
                builder.cell()
                        .label(dimension, term)
                        .value(logOfRelu);
            }
        }
        return builder.build();
    }

    /**
     * Builds a 1-d indexed float tensor holding the given values in order.
     *
     * @param input     the values to store, one per cell
     * @param dimension the name of the single indexed dimension
     * @return an indexed tensor whose size equals {@code input.size()}
     */
    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
        int size = input.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell(input.get(i), i);
        }
        return builder.build();
    }

    /** Closes the model evaluator and the tokenizer. */
    @Override
    public void deconstruct() {
        evaluator.close();
        tokenizer.close();
    }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy