All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.language.huggingface.Encoding Maven / Gradle / Ivy

The newest version!
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.language.huggingface;

import com.yahoo.api.annotations.Beta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Immutable, typed snapshot of a tokenizer encoding, converted from DJL's
 * {@code ai.djl.huggingface.tokenizers.Encoding} via {@link #from}.
 *
 * <p>Every list component is defensively copied with {@link List#copyOf}, so the
 * accessors return unmodifiable lists and neither {@code null} lists nor
 * {@code null} elements are accepted.
 *
 * @param ids              values of the source encoding's {@code getIds()}
 * @param typeIds          values of the source encoding's {@code getTypeIds()}
 * @param tokens           values of the source encoding's {@code getTokens()}
 * @param wordIds          values of the source encoding's {@code getWordIds()}
 * @param attentionMask    values of the source encoding's {@code getAttentionMask()}
 * @param specialTokenMask values of the source encoding's {@code getSpecialTokenMask()}
 * @param charTokenSpans   spans of the source encoding's {@code getCharTokenSpans()};
 *                         a missing span is represented by {@link CharSpan#NONE}
 * @param overflowing      converted encodings of the source encoding's {@code getOverflowing()}
 *
 * @author bjorncs
 */
@Beta
public record Encoding(
        List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
        List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {

    /**
     * A start/end pair as reported by the underlying tokenizer for a token.
     *
     * @param start start position of the span
     * @param end   end position of the span
     */
    public record CharSpan(int start, int end) {

        /** Sentinel used in place of a {@code null} span from the underlying tokenizer. */
        public static final CharSpan NONE = new CharSpan(-1, -1);

        /** Converts a DJL span, mapping {@code null} to {@link #NONE}. */
        static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) {
            if (s == null) return NONE;
            return new CharSpan(s.getStart(), s.getEnd());
        }

        /** Returns whether this span equals the {@link #NONE} sentinel. */
        public boolean isNone() { return this.equals(NONE); }
    }

    /**
     * Replaces every component with an unmodifiable copy.
     *
     * @throws NullPointerException if any list, or any element of a list, is {@code null}
     */
    public Encoding {
        ids = List.copyOf(ids);
        typeIds = List.copyOf(typeIds);
        tokens = List.copyOf(tokens);
        wordIds = List.copyOf(wordIds);
        attentionMask = List.copyOf(attentionMask);
        specialTokenMask = List.copyOf(specialTokenMask);
        charTokenSpans = List.copyOf(charTokenSpans);
        overflowing = List.copyOf(overflowing);
    }

    /**
     * Converts a DJL encoding, including its overflowing encodings (recursively),
     * into this record type.
     */
    static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) {
        return new Encoding(
                toList(e.getIds()),
                toList(e.getTypeIds()),
                List.of(e.getTokens()),
                toList(e.getWordIds()),
                toList(e.getAttentionMask()),
                toList(e.getSpecialTokenMask()),
                Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(),
                Arrays.stream(e.getOverflowing()).map(Encoding::from).toList());
    }

    /** Boxes a {@code long[]} into a list; a {@code null} array yields an empty list. */
    private static List<Long> toList(long[] array) {
        if (array == null) return List.of();
        List<Long> boxed = new ArrayList<>(array.length);
        for (long value : array) boxed.add(value);
        return boxed;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy