All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor Maven / Gradle / Ivy

package org.deeplearning4j.text.tokenization.tokenizer.preprocessor;

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;

import java.text.Normalizer;
import java.util.List;
import java.util.Map;

/**
 * A preprocessor for cleaning/normalizing text before BERT WordPiece tokenization. Does the following:
 * <ol>
 *     <li>Optionally converts all characters to lower case</li>
 *     <li>Optionally strips accents off characters (NFD decomposition, then removal of non-spacing marks)</li>
 *     <li>Strips NUL, the Unicode replacement character (U+FFFD) and all control/format characters
 *         (tab, newline and carriage return are treated as whitespace, not control characters)</li>
 *     <li>Replaces whitespace characters with space ' ' (this includes newline, carriage return and tab)</li>
 *     <li>Optionally (when a vocabulary is supplied) removes characters that do not occur in any vocabulary entry</li>
 *     <li>Adds spaces before/after Chinese (CJK ideograph) characters</li>
 * </ol>
 */
public class BertWordPiecePreProcessor implements TokenPreProcess {

    /** The Unicode replacement character (U+FFFD); always removed from the input. */
    public static final char REPLACEMENT_CHAR = 0xfffd;

    protected final boolean lowerCase;
    protected final boolean stripAccents;
    /**
     * Every code point that occurs in at least one vocabulary key, or null when no vocabulary was supplied
     * (in which case no vocabulary-based character filtering is done).
     */
    protected final IntSet charSet;

    /**
     * Create a preprocessor that preserves case and accents and does not filter characters against a vocabulary.
     */
    public BertWordPiecePreProcessor() {
        this(false, false, null);
    }

    /**
     *
     * @param lowerCase If true: tokenization should convert all characters to lower case
     * @param stripAccents  If true: strip accents off characters. Usually same as lower case. Should be true when using "uncased" official BERT TensorFlow models
     * @param vocab Optional vocabulary (may be null). Only the keys are used. When non-null, any non-whitespace
     *              character that does not occur in at least one key is removed during preprocessing.
     */
    public BertWordPiecePreProcessor(boolean lowerCase, boolean stripAccents, Map<String, ?> vocab) {
        this.lowerCase = lowerCase;
        this.stripAccents = stripAccents;
        this.charSet = (vocab == null) ? null : collectCodePoints(vocab.keySet());
    }

    /** Collect every code point that occurs in any of the given words (null words are ignored). */
    private static IntSet collectCodePoints(Iterable<String> words) {
        IntSet out = new IntOpenHashSet();
        for (String word : words) {
            if (word == null) {
                continue;
            }
            for (int i = 0; i < word.length(); ) {
                int cp = word.codePointAt(i);
                out.add(cp);
                i += Character.charCount(cp);
            }
        }
        return out;
    }

    /**
     * Clean a piece of text as described in the class documentation.
     *
     * @param token Text to preprocess
     * @return The cleaned text. It may contain extra spaces (in place of whitespace characters, or around Chinese
     *         characters) and may be empty if every character was removed
     */
    @Override
    public String preProcess(String token) {
        if (stripAccents) {
            // Decompose accented characters into base character + combining mark(s); the marks are dropped below
            token = Normalizer.normalize(token, Normalizer.Form.NFD);
        }

        StringBuilder sb = new StringBuilder(token.length());
        int i = 0;
        while (i < token.length()) {
            int cp = token.codePointAt(i);
            i += Character.charCount(cp);

            //Remove control characters and (if requested) accents
            if (cp == 0 || cp == REPLACEMENT_CHAR || isControlCharacter(cp)
                    || (stripAccents && Character.getType(cp) == Character.NON_SPACING_MARK)) {
                continue;
            }

            //Convert to lower case if necessary
            if (lowerCase) {
                cp = Character.toLowerCase(cp);
            }

            //Replace whitespace chars with space (done before the vocab check, so whitespace is never dropped)
            if (isWhiteSpace(cp)) {
                sb.append(' ');
                continue;
            }

            //Skip unknown character (out-of-vocab - though this should rarely happen)
            if (charSet != null && !charSet.contains(cp)) {
                continue;
            }

            //Surround Chinese characters with spaces so each one becomes its own token
            if (isChineseCharacter(cp)) {
                sb.append(' ').appendCodePoint(cp).append(' ');
                continue;
            }

            //All other characters - keep
            sb.appendCodePoint(cp);
        }

        return sb.toString();
    }

    /**
     * Returns true if the code point is a control (Cc) or format (Cf) character. Tab, newline and carriage return
     * are NOT considered control characters here, as they are treated as whitespace.
     *
     * @param cp Unicode code point
     * @return True if the code point is a control/format character
     */
    public static boolean isControlCharacter(int cp) {
        //Treat newline/tab as whitespace
        if (cp == '\t' || cp == '\n' || cp == '\r') {
            return false;
        }
        int type = Character.getType(cp);
        return type == Character.CONTROL || type == Character.FORMAT;
    }

    /**
     * Returns true if the code point is tab, newline, carriage return or a space separator (Zs) character.
     *
     * @param cp Unicode code point
     * @return True if the code point should be treated as whitespace
     */
    public static boolean isWhiteSpace(int cp) {
        //Treat newline/tab as whitespace
        if (cp == '\t' || cp == '\n' || cp == '\r') {
            return true;
        }
        return Character.getType(cp) == Character.SPACE_SEPARATOR;
    }

    /**
     * Returns true if the code point lies in one of the CJK Unified Ideographs blocks (including Extensions A-E)
     * or the CJK Compatibility Ideographs blocks. Note that Japanese kana and Korean Hangul are not in these ranges.
     *
     * @param cp Unicode code point
     * @return True if the code point is a CJK ideograph
     */
    public static boolean isChineseCharacter(int cp) {
        // https://en.wikipedia.org/wiki/List_of_CJK_Unified_Ideographs,_part_1_of_4
        return (cp >= 0x4E00 && cp <= 0x9FFF) ||    // CJK Unified Ideographs
                (cp >= 0x3400 && cp <= 0x4DBF) ||   // Extension A
                (cp >= 0x20000 && cp <= 0x2A6DF) || // Extension B
                (cp >= 0x2A700 && cp <= 0x2B73F) || // Extension C
                (cp >= 0x2B740 && cp <= 0x2B81F) || // Extension D
                (cp >= 0x2B820 && cp <= 0x2CEAF) || // Extension E
                (cp >= 0xF900 && cp <= 0xFAFF) ||   // CJK Compatibility Ideographs
                (cp >= 0x2F800 && cp <= 0x2FA1F);   // CJK Compatibility Ideographs Supplement
    }


    /**
     * Reconstruct a String from WordPiece tokens.
     * Tokens starting with "##" are continuation pieces: they are appended directly (without the "##" prefix).
     * Every other token is separated from the preceding text by a single space, except "." which is appended
     * directly.
     *
     * @param tokens Tokens to join, in order
     * @return The reconstructed text
     */
    public static String reconstructFromTokens(List<String> tokens) {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (String s : tokens) {
            if (s.startsWith("##")) {
                sb.append(s, 2, s.length());
            } else {
                if (!first && !".".equals(s)) {
                    sb.append(' ');
                }
                sb.append(s);
            }
            // Cleared after ANY token - including a leading "##" continuation - so the next word gets a separator
            first = false;
        }
        return sb.toString();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy