All Downloads are FREE. Search and download functionalities are using the official Maven repository.

tri.ai.text.chunks.StandardTextChunker.kt Maven / Gradle / Ivy

/*-
 * #%L
 * tri.promptfx:promptkt
 * %%
 * Copyright (C) 2023 - 2024 Johns Hopkins University Applied Physics Laboratory
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package tri.ai.text.chunks

import java.text.BreakIterator

/**
 * Standard implementation of [TextChunker].
 *
 * Splits a document hierarchically: first on section breaks (consecutive line breaks), then on single line
 * breaks, then on sentence boundaries. At each level, adjacent short pieces are greedily recombined so that
 * chunks stay within [maxChunkSize] characters where possible. A single sentence longer than [maxChunkSize]
 * is kept whole, so a chunk may exceed the limit in that case.
 *
 * @property maxChunkSize target maximum number of characters per chunk
 */
class StandardTextChunker(
    val maxChunkSize: Int = 1000
) : TextChunker {

    /** Splits [doc] into chunks, merging adjacent short sections up to [maxChunkSize] characters. */
    override fun chunk(doc: TextDocument) =
        TextSection(doc).chunkBySections(combineShortSections = true)

    //region CODE FROM TextDocumentSectioner.kt

    /**
     * Basic chunking by splitting on whitespace.
     *
     * Words (maximal runs of non-whitespace characters) are packed greedily, in order, into sections. The
     * text of each section is its words joined by single spaces and is at most [maxChunkSize] characters.
     * A word longer than [maxChunkSize] is emitted as its own, oversized section.
     *
     * @return one pair per section:
     *   - the index range in [text] from the first character of the section's first word to the last
     *     character of its last word;
     *   - the section text, with words joined by single spaces.
     */
    fun chunkTextBySectionsSimple(text: String): List<Pair<IntRange, String>> {
        val sections = mutableListOf<Pair<IntRange, String>>()
        val currentSection = StringBuilder()
        var sectionStart = 0
        var sectionEnd = 0 // exclusive end index in [text] of the last word added to the current section

        for (word in Regex("\\S+").findAll(text)) {
            // close the current section if appending " <word>" would exceed the limit
            if (currentSection.isNotEmpty() && currentSection.length + 1 + word.value.length > maxChunkSize) {
                sections.add(Pair(sectionStart until sectionEnd, currentSection.toString()))
                currentSection.clear()
            }
            if (currentSection.isEmpty()) {
                sectionStart = word.range.first
            } else {
                currentSection.append(' ')
            }
            currentSection.append(word.value)
            sectionEnd = word.range.last + 1
        }

        if (currentSection.isNotEmpty()) {
            sections.add(Pair(sectionStart until sectionEnd, currentSection.toString()))
        }
        return sections
    }

    /**
     * Chunk into sections by section breaks. Any section still longer than [maxChunkSize] is split further
     * by paragraphs (see [chunkByParagraphs]).
     *
     * @param combineShortSections controls whether short sections are merged:
     *   - if true, this section is returned unchanged when it already fits, and adjacent short sections are
     *     merged while their combined span fits in [maxChunkSize];
     *   - if false, every section break produces a separate chunk.
     */
    fun TextSection.chunkBySections(combineShortSections: Boolean): List<TextSection> {
        // return chunk if it's short enough
        if (combineShortSections && text.length <= maxChunkSize)
            return listOf(this)

        // break into sections; a limit of 0 never admits a group of two, so sections stay separate
        val sections = splitOnSections()
            .recombine(if (combineShortSections) maxChunkSize else 0)

        // split up any sections that are too long
        return sections.flatMap { section ->
            if (section.text.length <= maxChunkSize) listOf(section) else section.chunkByParagraphs()
        }
    }

    /**
     * Chunk into paragraphs (split on line breaks), merging adjacent short paragraphs up to [maxChunkSize]
     * characters. Paragraphs that are still too long are split into sentences and re-merged the same way.
     */
    fun TextSection.chunkByParagraphs(): List<TextSection> {
        // return chunk if it's short enough
        if (text.length <= maxChunkSize)
            return listOf(this)

        return splitOnParagraphs()
            .recombine(maxChunkSize)
            .flatMap { paragraph ->
                if (paragraph.text.length <= maxChunkSize) listOf(paragraph)
                else paragraph.splitOnSentences().recombine(maxChunkSize)
            }
    }

    // TODO - make this find things like likely section headings
    /** Splits on runs of two or three line breaks (LF, CRLF or CR). Longer alternatives are listed first. */
    private fun TextSection.splitOnSections() =
        chunkByDividers(listOf("\n\n\n", "\r\n\r\n\r\n", "\r\r\r", "\n\n", "\r\n\r\n", "\r\r"))

    /** Splits on single line breaks (LF, CRLF or CR). */
    private fun TextSection.splitOnParagraphs() =
        chunkByDividers(listOf("\n", "\r\n", "\r"))

    /** Splits into sentences using [BreakIterator]'s default-locale sentence rules, dropping blank pieces. */
    private fun TextSection.splitOnSentences(): List<TextSection> {
        val sentences = mutableListOf<TextSection>()
        val iterator = BreakIterator.getSentenceInstance()
        iterator.setText(text)

        var start = iterator.first()
        var end = iterator.next()
        while (end != BreakIterator.DONE) {
            // the range covers the full, untrimmed boundary-to-boundary piece
            if (text.substring(start, end).isNotBlank()) {
                sentences.add(TextSection(doc, (range.first + start) until (range.first + end)))
            }
            start = end
            end = iterator.next()
        }
        return sentences
    }

    /**
     * Splits this section at every occurrence of any of [dividers]. Dividers are excluded from the resulting
     * chunks, and blank chunks are dropped. Regex alternation tries alternatives in order, so callers should
     * list longer dividers before their prefixes.
     */
    private fun TextSection.chunkByDividers(dividers: List<String>): List<TextSection> {
        val regex = Regex(dividers.joinToString(separator = "|") { Regex.escape(it) })
        val chunks = mutableListOf<TextSection>()
        var currentIndex = 0

        regex.findAll(text).forEach { match ->
            val chunkEnd = match.range.first
            if (chunkEnd > currentIndex) {
                // exclusive end: the divider's first character is not part of the chunk
                chunks.add(TextSection(doc, (range.first + currentIndex) until (range.first + chunkEnd)))
            }
            currentIndex = match.range.last + 1
        }

        if (currentIndex < text.length) {
            chunks.add(TextSection(doc, (range.first + currentIndex) until (range.first + text.length)))
        }

        return chunks.filter { it.text.isNotBlank() }
    }

    /**
     * Greedily merges runs of adjacent chunks whose combined span is at most [maxSize] characters.
     * Assumes chunks are in document order and non-overlapping; text between merged chunks is included.
     */
    private fun List<TextSection>.recombine(maxSize: Int) =
        chunkWhile { it.totalSize() <= maxSize }
            .map { it.concatenate() }

    /** Number of characters spanned from the first chunk's start to the last chunk's end. Assumes document order. */
    private fun List<TextSection>.totalSize() =
        last().range.last - first().range.first + 1

    /** Single section spanning from the first chunk's start to the last chunk's end. Assumes document order. */
    private fun List<TextSection>.concatenate() =
        TextSection(first().doc, first().range.first..last().range.last)

    //endregion

}

/**
 * Chunks a list into consecutive groups.
 *
 * Grouping is greedy: an element joins the current group while [op] accepts the extended group. Otherwise the
 * current group is closed and the element starts a new one. A single element is always emitted as its own
 * group, even if [op] rejects it. Concatenating the returned groups yields the original list.
 *
 * Example: [1, 2, 3, 4, 5, 6, 7] with op sum <= 6 -> [1, 2, 3] | [4] | [5] | [6] | [7]
 *
 * @param op predicate evaluated on candidate groups; the list passed in is mutated afterwards, so [op] should
 *   not retain it
 * @return the groups, in order
 */
fun <X> List<X>.chunkWhile(op: (List<X>) -> Boolean): List<List<X>> {
    val result = mutableListOf<List<X>>()
    var current = mutableListOf<X>()
    for (item in this) {
        current.add(item)
        if (!op(current)) {
            current = if (current.size == 1) {
                // a lone element is emitted on its own even though it fails the predicate
                result.add(current)
                mutableListOf()
            } else {
                // close the group without the new element, and start the next group with it
                result.add(current.dropLast(1))
                mutableListOf(item)
            }
        }
    }
    if (current.isNotEmpty()) {
        result.add(current)
    }
    return result
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy