All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.openai.GPT4Tokenizer.kt Maven / Gradle / Ivy

There is a newer version: 1.0.33
Show newest version
package com.simiacryptus.openai

import com.simiacryptus.openai.GPT4CodecData.bpeRegex
import com.simiacryptus.util.JsonUtil
import java.nio.charset.Charset
import kotlin.reflect.javaType
import kotlin.reflect.typeOf
/**
 * Byte-level byte-pair-encoding (BPE) tokenizer.
 *
 * The vocabulary (token text to token id) is read from the `/gpt4.json` classpath resource and
 * the ordered merge list from [GPT4CodecData.bpeVocab]. Text is split with
 * [GPT4CodecData.bpeRegex], each piece is converted to UTF-8 bytes, every byte is mapped to a
 * stand-in character (see [bytesToUnicode]) and the merge rules are applied by [bpe].
 *
 * Instances keep per-instance caches in plain [HashMap]s with no synchronization.
 *
 * @param isCodex when `true`, 24 additional "merged spaces" tokens are appended to the vocabulary.
 */
@OptIn(ExperimentalStdlibApi::class)
class GPT4Tokenizer(isCodex: Boolean = false) {

    /** Converts text to its UTF-8 byte representation. */
    class TextEncoder {
        /** Returns the UTF-8 bytes of [text]. */
        fun encode(text: String): ByteArray = text.toByteArray(Charsets.UTF_8)
    }

    /** Converts UTF-8 bytes back to text. */
    class TextDecoder {
        /** Interprets [bytes] as UTF-8 and returns the resulting string. */
        fun decode(bytes: ByteArray): String = String(bytes, Charsets.UTF_8)
    }

    private val vocab: String = GPT4CodecData.bpeVocab
    private val nMergedSpaces: Int = if (isCodex) 24 else 0
    private val nVocab: Int = 50257 + nMergedSpaces

    /** Token text -> token id. */
    private val encodings: HashMap<String, Int> =
        JsonUtil.fromJson(codecJson, typeOf<HashMap<String, Int>>().javaType)

    /** Token id -> token text (inverse of [encodings]). */
    private val decodings: HashMap<Int, String> = HashMap()

    /** Byte value (0..255) -> stand-in character. */
    private val byteEncoder: HashMap<Int, String> = bytesToUnicode()

    /** Stand-in character -> byte value (inverse of [byteEncoder]). */
    private val byteDecoder: HashMap<String, Int> = HashMap()

    /** Merge pair -> rank; a lower rank is merged first. */
    private val bpeRanks: HashMap<Pair<String, String>, Int> = HashMap()

    /** Memoized results of [bpe]. */
    private val cache: HashMap<String, String> = HashMap()

    /** Memoized token ids per regex-split piece of input text. */
    private val encodeCache = HashMap<String, List<Int>>()

    private val textEncoder: TextEncoder = TextEncoder()
    private val textDecoder: TextDecoder = TextDecoder()

    init {
        initialize()
    }

    /**
     * Builds the merge-rank table and the inverse lookup tables.
     *
     * @throws Exception if the merge list is shorter than 100 characters, which is treated as a
     * failed load.
     */
    private fun initialize() {
        if (vocab.length < 100) {
            throw Exception("Tokenizer vocab file did not load correctly")
        }

        val vocabLines = vocab.split("\n")
        // The first and the last line are not treated as merge rules; lines that do not carry
        // two whitespace-separated symbols are skipped instead of failing on list[1].
        val bpeMerges: MutableList<Pair<String, String>> = vocabLines
            .subList(1, vocabLines.size - 1)
            .mapNotNull { line ->
                val parts = line.split(WHITESPACE).filter { part -> part.isNotBlank() }
                if (parts.size >= 2) Pair(parts[0], parts[1]) else null
            }
            .toMutableList()

        // add merged spaces for codex tokenizer
        if (nMergedSpaces > 0) {
            for (i in 1..nMergedSpaces) {
                for (j in 1..nMergedSpaces) {
                    if (i + j <= nMergedSpaces) {
                        // add() mutates the list; List.plus() would return a copy that is lost.
                        bpeMerges.add(Pair(MERGED_SPACE.repeat(i), MERGED_SPACE.repeat(j)))
                    }
                }
            }

            // Exactly nMergedSpaces extra ids: nVocab - nMergedSpaces .. nVocab - 1.
            for (i in 0 until nMergedSpaces) {
                encodings[MERGED_SPACE.repeat(i + 2)] = nVocab - nMergedSpaces + i
            }
        }

        for ((token, id) in encodings) {
            decodings[id] = token
        }

        for ((byteValue, symbol) in byteEncoder) {
            byteDecoder[symbol] = byteValue
        }

        bpeMerges.forEachIndexed { rank, pair -> bpeRanks[pair] = rank }
    }

    /**
     * Stores `first[i] -> second[i]` in [map] for every index of [first] and returns [map].
     *
     * @throws IndexOutOfBoundsException if [second] is shorter than [first].
     */
    fun <K, V> zip(map: HashMap<K, V>, first: List<K>, second: List<V>): HashMap<K, V> {
        for (idx in first.indices) {
            map[first[idx]] = second[idx]
        }
        return map
    }

    /**
     * Builds the byte -> stand-in character table.
     *
     * Bytes in `'!'..'~'`, `'\u00a1'..'\u00ac'` and `'\u00ae'..'\u00ff'` map to the character
     * with the same code; every other byte value is mapped, in ascending order, to the
     * characters starting at code 256.
     */
    fun bytesToUnicode(): HashMap<Int, String> {
        val bs = (range(ord("!"), ord("~") + 1) +
                range(ord("\u00a1"), ord("\u00ac") + 1) +
                range(ord("\u00ae"), ord("\u00ff") + 1)).toMutableList()

        val cs: MutableList<Int> = bs.toMutableList()
        var n = 0

        for (b in 0 until BYTE_VALUES) {
            if (b !in bs) {
                bs.add(b)
                cs.add(BYTE_VALUES + n)
                n += 1
            }
        }

        return zip(HashMap<Int, String>(), bs, cs.map { code -> chr(code) })
    }

    /**
     * Returns the set of adjacent symbol pairs in [word], in order of first occurrence.
     * A word with fewer than two symbols yields an empty set.
     */
    fun getPairs(word: List<String>): Set<Pair<String, String>> = word.zipWithNext().toSet()

    /**
     * Applies the merge rules to [token] and returns the resulting symbols joined by single
     * spaces. Results for tokens with at least two characters are memoized.
     *
     * @param token a piece of text already mapped to stand-in characters.
     */
    fun bpe(token: String): String {
        cache[token]?.let { return it }

        var word: List<String> = token.map { it.toString() }
        var pairs = getPairs(word)

        if (pairs.isEmpty()) {
            return token
        }

        while (true) {
            // Pick the known pair with the lowest rank; stop when no pair has a merge rule.
            val bigram = pairs.minByOrNull { bpeRanks[it] ?: Int.MAX_VALUE }
            if (bigram == null || bigram !in bpeRanks) {
                break
            }

            val (first, second) = bigram
            val merged: MutableList<String> = mutableListOf()
            var i = 0

            while (i < word.size) {
                val j = word.indexOf2(first, i)
                if (j == -1) {
                    merged.addAll(word.subList(i, word.size))
                    break
                }
                merged.addAll(word.subList(i, j))
                i = j

                // word[i] == first is guaranteed by indexOf2 at this point.
                if (i < word.size - 1 && word[i + 1] == second) {
                    merged.add(first + second)
                    i += 2
                } else {
                    merged.add(word[i])
                    i += 1
                }
            }

            word = merged
            if (word.size == 1) {
                break
            }
            pairs = getPairs(word)
        }

        val finalWord = word.joinToString(separator = " ")
        cache[token] = finalWord
        return finalWord
    }

    /**
     * Tokenizes [text] and returns the token ids.
     *
     * @throws NullPointerException if a symbol produced by [bpe] has no entry in the vocabulary.
     */
    fun encode(text: String): MutableList<Int> {
        val bpeTokens: MutableList<Int> = mutableListOf()

        for (token in TOKEN_PATTERN.findAll(text).map { it.value }) {
            val ids = encodeCache.getOrPut(token) {
                bpe(byteEncode(token))
                    .split(" ")
                    .map { symbol ->
                        // Same exception type the former `!!` raised, now with a usable message.
                        encodings[symbol]
                            ?: throw NullPointerException("No vocabulary entry for BPE symbol '$symbol'")
                    }
            }
            bpeTokens.addAll(ids)
        }

        return bpeTokens
    }

    /** Returns the UTF-8 bytes of [text]. */
    fun encodeUtf8(text: String): ByteArray = textEncoder.encode(text)

    /** Decodes UTF-8 [bytes] into a string. */
    fun decodeUtf8(bytes: ByteArray): String = textDecoder.decode(bytes)

    /**
     * Converts token ids back to text.
     *
     * Ids that are not in the vocabulary contribute nothing to the output; stand-in characters
     * without a byte mapping decode to byte 0. The collected bytes are interpreted as UTF-8.
     */
    fun decode(tokens: List<Int>): String {
        val text = tokens.joinToString(separator = "") { id -> decodings[id].orEmpty() }
        val bytes = ByteArray(text.length) { idx -> (byteDecoder[text[idx].toString()] ?: 0).toByte() }
        return String(bytes, Charsets.UTF_8)
    }

    /**
     * Returns the number of BPE symbols [input] splits into, without looking the symbols up in
     * the vocabulary.
     */
    fun estimateTokenCount(input: String): Int =
        TOKEN_PATTERN.findAll(input).sumOf { match -> bpe(byteEncode(match.value)).split(" ").size }

    /**
     * Encodes [text] and splits the token ids into consecutive chunks of at most
     * [maxTokensPerChunk] ids.
     *
     * @return one map per chunk with `"text"` (the decoded chunk) and `"bpe"` (its token ids).
     * @throws IllegalArgumentException if [maxTokensPerChunk] is not positive.
     */
    fun chunkText(text: String, maxTokensPerChunk: Int): MutableList<Map<String, Any>> {
        require(maxTokensPerChunk > 0) { "maxTokensPerChunk must be positive, was $maxTokensPerChunk" }
        return encode(text)
            .chunked(maxTokensPerChunk)
            .map { chunk -> mapOf<String, Any>("text" to decode(chunk), "bpe" to chunk) }
            .toMutableList()
    }

    /** Maps every UTF-8 byte of [token] to its stand-in character and concatenates the result. */
    private fun byteEncode(token: String): String =
        token.toByteArray(Charsets.UTF_8)
            .joinToString(separator = "") { b -> byteEncoder.getValue(b.toInt() and 0xFF) }

    companion object {

        /** Stand-in character used for the space byte in merge rules and vocabulary keys. */
        private const val MERGED_SPACE = "\u0120"

        /** Number of distinct byte values. */
        private const val BYTE_VALUES = 256

        /** Separator between the two symbols of a merge rule. */
        private val WHITESPACE = Regex("(\\s+)")

        /** Pattern that splits input text into pieces before BPE is applied; compiled once. */
        private val TOKEN_PATTERN: Regex = bpeRegex.toRegex()

        /** Content of the `/gpt4.json` classpath resource, or an empty string if it is missing. */
        val codecJson: String = GPT4Tokenizer::class.java.getResourceAsStream("/gpt4.json")
            ?.use { stream -> stream.readBytes().toString(Charsets.UTF_8) }
            ?: ""

        /**
         * Returns the index of the first occurrence of [element] at or after [minIndex],
         * or `-1` if there is none.
         */
        fun <E> List<E>.indexOf2(element: E, minIndex: Int): Int {
            for (i in minIndex until this.size) {
                if (this[i] == element) {
                    return i
                }
            }
            return -1
        }

        /** Returns the integers from `x` (inclusive) to `y` (exclusive). */
        val range: (Int, Int) -> List<Int> = { x, y -> (x until y).toList() }

        /** Returns the UTF-16 code of the first character of the given string. */
        val ord: (String) -> Int = { x -> x[0].code }

        /** Returns the one-character string for the given UTF-16 code. */
        val chr: (Int) -> String = { n -> n.toChar().toString() }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy