All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tri.ai.text.chunks.process.SmartTextChunker.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2024 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.text.chunks.process

import tri.ai.text.chunks.TextChunkInDoc
import tri.ai.text.chunks.TextChunk
import tri.ai.text.chunks.TextChunkRaw
import tri.util.fine
import java.text.BreakIterator

/**
 * Smart implementation of [TextChunker], attempting to automatically split up text into reasonable sections based on
 * natural breaks.
 *
 * Text is split hierarchically:
 * - first into sections, separated by consecutive line breaks (blank lines);
 * - then any section still longer than [maxChunkSize] into paragraphs, separated by single line breaks;
 * - finally any paragraph still too long into sentences, using [BreakIterator].
 *
 * At each level, adjacent pieces are merged back together as long as the merged span fits within [maxChunkSize].
 * A single sentence longer than [maxChunkSize] is not split further, so such a chunk may exceed the limit.
 */
class SmartTextChunker(
    /** Maximum chunk size in characters. */
    val maxChunkSize: Int = 1000
) : TextChunker {

    override fun chunk(doc: TextChunkRaw): List<TextChunk> =
        doc.chunkBySections(combineShortSections = true)

    //region CODE FROM TextDocumentSectioner.kt

    /**
     * Basic chunking by splitting on whitespace, independent of any document model.
     *
     * Words are maximal runs of non-whitespace characters. They are packed greedily into sections whose text
     * (the words joined by single spaces) is at most [maxChunkSize] characters. A single word longer than
     * [maxChunkSize] becomes a section of its own; no empty sections are produced.
     *
     * @return pairs of (index range in [text] from the section's first word through its last word, section text)
     */
    fun chunkTextBySectionsSimple(text: String): List<Pair<IntRange, String>> {
        val sections = mutableListOf<Pair<IntRange, String>>()
        val current = StringBuilder()
        var sectionStart = 0
        var sectionEnd = -1

        for (word in Regex("\\S+").findAll(text)) {
            // flush when appending " word" would overflow; an empty section always accepts the next word
            if (current.isNotEmpty() && current.length + 1 + word.value.length > maxChunkSize) {
                sections += (sectionStart..sectionEnd) to current.toString()
                current.clear()
            }
            if (current.isEmpty()) {
                sectionStart = word.range.first
            } else {
                current.append(' ')
            }
            current.append(word.value)
            // track real positions in the source text, so ranges stay correct with irregular whitespace
            sectionEnd = word.range.last
        }

        if (current.isNotEmpty()) {
            sections += (sectionStart..sectionEnd) to current.toString()
        }
        return sections
    }

    //endregion

    /**
     * Chunk into sections by section breaks, splitting oversized sections further by paragraphs and sentences.
     *
     * @param combineShortSections if true, a document that already fits in [maxChunkSize] is returned as a single
     *   chunk, and adjacent short sections are merged up to [maxChunkSize]; if false, every section is kept separate
     * @return chunks in document order, each referencing a range of this document's text
     */
    fun TextChunkRaw.chunkBySections(combineShortSections: Boolean): List<TextChunk> {
        // return chunk if it's short enough
        if (combineShortSections && text.length <= maxChunkSize)
            return listOf(TextChunkInDoc(text.indices))

        // break into sections and optionally concatenate short sections
        // (a limit of 0 makes every section its own group, i.e. no combining)
        val sections = TextChunkInDoc(text.indices)
            .splitOnSections(this)
            .recombine(if (combineShortSections) maxChunkSize else 0)

        // split up any sections that are too long
        val result = mutableListOf<TextChunk>()
        sections.forEach { section ->
            if (section.text(this).length <= maxChunkSize) {
                result += section
            } else {
                result += section.chunkByParagraphs(this)
            }
        }

        // log the first few chunks
        result.take(10).forEach { chunk ->
            fine("  ${(chunk as TextChunkInDoc).range} ${chunk.text(this).firstFiftyChars()}")
        }

        return result
    }

    /** Abbreviates text for logging: line breaks collapsed to spaces, truncated to 50 characters plus "...". */
    private fun String.firstFiftyChars() = replace("[\r\n]+".toRegex(), " ").let {
        if (it.length <= 50) it.trim() else (it.substring(0, 50).trim()+"...")
    }

    /**
     * Splits this chunk into pieces of at most [maxChunkSize] characters, preferring paragraph breaks and falling
     * back to sentence breaks for paragraphs that are still too long. Returns this chunk unchanged if it already fits.
     */
    fun TextChunkInDoc.chunkByParagraphs(doc: TextChunkRaw): List<TextChunk> {
        // return chunk if it's short enough
        if (text(doc).length <= maxChunkSize)
            return listOf(this)

        // break into paragraphs and concatenate short paragraphs
        val paragraphs = splitOnParagraphs(doc)
            .recombine(maxChunkSize)

        // split up any paragraphs that are too long
        val result = mutableListOf<TextChunk>()
        paragraphs.forEach { paragraph ->
            if (paragraph.text(doc).length <= maxChunkSize) {
                result += paragraph
            } else {
                result += paragraph.splitOnSentences(doc)
                    .recombine(maxChunkSize)
            }
        }
        return result
    }

    // TODO - make this find things like likely section headings
    /** Splits on runs of two or three line breaks (`\n`, `\r\n`, or `\r`). */
    private fun TextChunkInDoc.splitOnSections(doc: TextChunkRaw) =
        chunkByDividers(doc, listOf("\n\n\n", "\r\n\r\n\r\n", "\r\r\r", "\n\n", "\r\n\r\n", "\r\r"))

    /** Splits on single line breaks (`\n`, `\r\n`, or `\r`). */
    private fun TextChunkInDoc.splitOnParagraphs(doc: TextChunkRaw) =
        chunkByDividers(doc, listOf("\n", "\r\n", "\r"))

    /** Splits this chunk into sentences using the default-locale [BreakIterator], skipping blank sentences. */
    private fun TextChunkInDoc.splitOnSentences(doc: TextChunkRaw): List<TextChunk> {
        val chunkText = text(doc)
        val sentences = mutableListOf<TextChunk>()
        val iterator = BreakIterator.getSentenceInstance()
        iterator.setText(chunkText)

        var start = iterator.first()
        var end = iterator.next()

        while (end != BreakIterator.DONE) {
            // iterator offsets are relative to this chunk's text, not to the whole document
            val sentence = chunkText.substring(start, end).trim()
            if (sentence.isNotEmpty()) {
                // BreakIterator's end is exclusive; chunk ranges use an inclusive last index
                sentences.add(TextChunkInDoc(range.first + start, range.first + end - 1))
            }
            start = end
            end = iterator.next()
        }

        return sentences
    }

    /**
     * Splits this chunk at every occurrence of any of the literal [dividers].
     * Dividers are excluded from the resulting chunks, and blank chunks are dropped.
     * Earlier entries in [dividers] take precedence when several match at the same position.
     */
    private fun TextChunkInDoc.chunkByDividers(doc: TextChunkRaw, dividers: List<String>): List<TextChunk> {
        val regex = Regex(dividers.joinToString(separator = "|") { Regex.escape(it) })
        val chunkText = text(doc)
        val chunks = mutableListOf<TextChunk>()
        var currentIndex = 0

        regex.findAll(chunkText).forEach { match ->
            val dividerStart = match.range.first
            if (dividerStart > currentIndex) {
                // last index is inclusive, so stop just before the divider's first character
                chunks.add(TextChunkInDoc(range.first + currentIndex, range.first + dividerStart - 1))
            }
            currentIndex = match.range.last + 1
        }

        // trailing text after the last divider (offsets are relative to this chunk, not the whole document)
        if (currentIndex < chunkText.length) {
            chunks.add(TextChunkInDoc(range.first + currentIndex, range.first + chunkText.length - 1))
        }

        return chunks.filter { it.text(doc).isNotBlank() }
    }

    companion object {
        /**
         * Recombines shorter chunks into longer chunks, merging adjacent chunks while their combined span is at most
         * [maxChunkSize] characters. Assumes list of contiguous chunks in document order.
         */
        private fun List<TextChunk>.recombine(maxChunkSize: Int): List<TextChunkInDoc> =
            chunkWhile { it.totalSize() <= maxChunkSize }
                .map { it.concatenate() }

        /**
         * Number of characters spanned by the in-document chunks, from the smallest start to the largest end
         * (inclusive). Returns 0 when there are none.
         */
        private fun List<TextChunk>.totalSize(): Int {
            val ranges = filterIsInstance<TextChunkInDoc>().map { it.range }
            if (ranges.isEmpty()) return 0
            return ranges.maxOf { it.last } - ranges.minOf { it.first } + 1
        }

        /** Concatenates all in-document chunks into one chunk spanning from the smallest start to the largest end. */
        private fun List<TextChunk>.concatenate(): TextChunkInDoc {
            val ranges = filterIsInstance<TextChunkInDoc>().map { it.range }
            require(ranges.isNotEmpty()) { "Cannot concatenate an empty list of chunks" }
            return TextChunkInDoc(ranges.minOf { it.first }, ranges.maxOf { it.last })
        }

        /**
         * Chunks a list into sections, each of which are either size 1, or are the largest sublist for which the given predicate is true.
         * Example: [1, 2, 3, 4, 5, 6, 7] with op sum <= 6 -> [1, 2, 3] | [4] | [5] | [6] | [7]
         * An item that fails the predicate on its own still becomes a singleton section; no section is empty.
         */
        fun <T> List<T>.chunkWhile(op: (List<T>) -> Boolean): List<List<T>> {
            val result = mutableListOf<List<T>>()
            var current = mutableListOf<T>()
            for (item in this) {
                current.add(item)
                if (!op(current)) {
                    current = if (current.size == 1) {
                        // a lone item that fails is emitted by itself
                        result.add(current)
                        mutableListOf()
                    } else {
                        // emit everything before the item that broke the predicate, and start over from it
                        result.add(current.dropLast(1))
                        current.takeLast(1).toMutableList()
                    }
                }
            }
            if (current.isNotEmpty()) {
                result.add(current)
            }
            return result
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy