All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.collections.StorageSearchTrie.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.collections

import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.annotators.TokenizerModel
import com.johnsnowlabs.nlp.annotators.btm._
import com.johnsnowlabs.storage.{Database, StorageWriter}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/** Immutable collection used for fast phrase search over token sequences. Implementation of the
  * Aho-Corasick algorithm https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
  *
  * @param vocabReader
  *   maps a token to its word id (negative / `emptyValue` when the token is unknown)
  * @param edgesReader
  *   maps a `(node id, word id)` transition to the id of the child node
  * @param nodesReader
  *   maps a node id to the stored trie node (failure link `pi`, leaf flag, length in tokens and
  *   the nearest leaf on its failure chain, `lastLeaf`)
  */
class StorageSearchTrie(
    vocabReader: TMVocabReader,
    edgesReader: TMEdgesReader,
    nodesReader: TMNodesReader) {

  /** Finds every occurrence of every stored phrase in `text`.
    *
    * @param text
    *   tokens to search in
    * @return
    *   one `(begin, end)` pair per match; both values are inclusive token indices into `text`
    */
  def search(text: Seq[String]): Seq[(Int, Int)] = {
    val matches = new ArrayBuffer[(Int, Int)]()

    // Walks the leaf-suffix chain starting at `start` (terminated by -1) and records a match
    // ending at token `end` for every node on it that completes a phrase.
    def collectMatches(start: Int, end: Int): Unit = {
      var cursor = start
      while (cursor >= 0) {
        val current = nodesReader.lookup(cursor)
        if (current.isLeaf) matches.append((end - current.length + 1, end))
        cursor = current.lastLeaf
      }
    }

    var state = 0
    text.zipWithIndex.foreach { case (token, position) =>
      val wordId = vocabReader.lookup(token).getOrElse(vocabReader.emptyValue)
      if (wordId >= 0) {
        // Follow failure links until some suffix of the current match can be extended by wordId.
        var extended = false
        while (!extended && state > 0) {
          val next = edgesReader.lookup((state, wordId)).getOrElse(edgesReader.emptyValue)
          if (next >= 0) {
            state = next
            collectMatches(state, position)
            extended = true
          } else {
            state = nodesReader.lookup(state).pi
          }
        }

        if (!extended) {
          // Back at the root: start a new match if some phrase begins with this word.
          state = edgesReader.lookup((state, wordId)).getOrElse(0)
          collectMatches(state, position)
        }
      } else {
        // Unknown word: no stored phrase contains it, so restart from the root.
        state = 0
      }
    }

    matches
  }
}

object StorageSearchTrie {

  /** Builds the Aho-Corasick automaton for a list of phrases and persists it through `writers`.
    *
    * Each element of `inputFileLines` is one phrase. It is tokenized with `withTokenizer` when one
    * is given, otherwise split on single spaces (empty tokens are dropped). Three databases are
    * written:
    *   - `Database.TMVOCAB`: token -> word id
    *   - `Database.TMEDGES`: (parent node id, word id) -> child node id
    *   - `Database.TMNODES`: node id -> `TrieNode(pi, isLeaf, length, lastLeaf)`, where `pi` is
    *     the failure link, `length` the depth in tokens and `lastLeaf` the nearest node on the
    *     failure chain that ends a phrase (or -1)
    *
    * All three writers are closed before returning.
    *
    * @param inputFileLines
    *   phrases to index, one per element
    * @param writers
    *   storage writers keyed by database; must contain `TMVOCAB`, `TMEDGES` and `TMNODES`
    * @param withTokenizer
    *   optional tokenizer used to split each phrase into tokens
    */
  def load(
      inputFileLines: Iterator[String],
      writers: Map[Database.Name, StorageWriter[_]],
      withTokenizer: Option[TokenizerModel]): Unit = {

    val vocabrw = writers(Database.TMVOCAB).asInstanceOf[TMVocabReadWriter]
    val edgesrw = writers(Database.TMEDGES).asInstanceOf[TMEdgesReadWriter]
    val nodesrw = writers(Database.TMNODES).asInstanceOf[TMNodesWriter]

    // Per-node attributes indexed by node id. The trie starts with only the root (id 0).
    val parents = mutable.ArrayBuffer(0)
    val parentWord = mutable.ArrayBuffer(0)
    val isLeaf = mutable.ArrayBuffer(false)
    val length = mutable.ArrayBuffer(0)

    var vocabSize = 0

    // Returns the id of `w`, assigning the next free id only when the word is new
    // (previously the counter advanced on every lookup, leaving gaps in the id space).
    def vocabUpdate(w: String): Int =
      vocabrw.lookup(w).getOrElse {
        val id = vocabSize
        vocabrw.add(w, id)
        vocabSize += 1
        id
      }

    // Appends a child of `parentNodeId` reached through `wordId` and returns its id.
    def addNode(parentNodeId: Int, wordId: Int): Int = {
      parents.append(parentNodeId)
      parentWord.append(wordId)
      length.append(length(parentNodeId) + 1)
      isLeaf.append(false)
      parents.length - 1
    }

    // Insert every phrase as a path from the root; the node at the end of the path is a leaf.
    for (line <- inputFileLines) {
      val phrase = withTokenizer match {
        case Some(tokenizerModel) =>
          tokenizerModel.annotate(Seq(Annotation(line))).map(_.result).toArray
        case None =>
          // split(" ") yields empty tokens for leading/repeated spaces; they could never match.
          line.split(" ").filter(_.nonEmpty)
      }

      val lastNode = phrase.foldLeft(0) { (nodeId, word) =>
        val wordId = vocabUpdate(word)
        edgesrw.lookup((nodeId, wordId)).getOrElse {
          val child = addNode(nodeId, wordId)
          edgesrw.add((nodeId, wordId), child)
          child
        }
      }

      if (lastNode > 0)
        isLeaf(lastNode) = true
    }

    // Failure links (pi) and nearest-leaf links (lastLeaf). Both only read values of strictly
    // shallower nodes, so visiting nodes by increasing depth guarantees every value read is final.
    // Visiting by id is NOT enough: a pi target may have a larger id and still be uncomputed,
    // which used to cache a wrong lastLeaf of -1 and silently drop suffix matches.
    val nodeCount = parents.size
    val pi = Array.fill[Int](nodeCount)(0)
    val lastLeaf = Array.fill[Int](nodeCount)(-1)

    for (v <- (1 until nodeCount).sortBy(i => length(i))) {
      val wordId = parentWord(v)
      // Walk the parent's failure chain until some suffix can be extended by wordId.
      var candidate = parents(v)
      var target = 0
      while (candidate > 0 && target == 0) {
        candidate = pi(candidate)
        target = edgesrw.lookup((candidate, wordId)).getOrElse(0)
      }
      pi(v) = target
      lastLeaf(v) = if (isLeaf(target)) target else lastLeaf(target)
    }

    for (i <- 0 until nodeCount)
      nodesrw.add(i, TrieNode(pi(i), isLeaf(i), length(i), lastLeaf(i)))

    vocabrw.close()
    edgesrw.close()
    nodesrw.close()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy