All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.reader.HTMLReader.scala Maven / Gradle / Ivy

There is a newer version: 6.0.3
Show newest version
/*
 * Copyright 2017-2024 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.johnsnowlabs.reader

import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.util.io.ResourceHelper.{isValidURL, validFile}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit, udf}
import org.jsoup.Jsoup
import org.jsoup.nodes.{Document, Element, Node, TextNode}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Reads HTML from files or URLs into a Spark DataFrame whose `html` column holds the
 * structured elements extracted from each document body.
 *
 * @param titleFontSize
 *   minimum `font-size` (in pt) at which an inline-styled element counts as formatted like a
 *   title
 * @param storeContent
 *   when true, `read(inputSource)` additionally selects a `content` column next to `html`
 */
class HTMLReader(titleFontSize: Int = 16, storeContent: Boolean = false) extends Serializable {

  // Session used by every read; importing its implicits enables the toDF/createDataset calls.
  private val spark = ResourceHelper.spark
  import spark.implicits._

  /**
   * Reads HTML from a file path or from a single URL and parses it into structured elements.
   *
   * A source that does not start with "http" and passes `validFile` is loaded with
   * `wholeTextFiles`; otherwise a source accepted by `isValidURL` is fetched over the network.
   *
   * @param inputSource
   *   file path or URL of the HTML to read
   * @return
   *   a DataFrame with columns (`path` | `url`, `html`), plus `content` between them when
   *   `storeContent` is enabled
   * @throws IllegalArgumentException
   *   if `inputSource` is neither a valid file nor a valid URL
   */
  def read(inputSource: String): DataFrame = {
    // Cheap prefix test first, so URLs never reach the file-system check.
    val looksRemote = inputSource.startsWith("http")

    if (!looksRemote && validFile(inputSource)) {
      val htmlDf = spark.sparkContext
        .wholeTextFiles(inputSource)
        .toDF("path", "content")
        .withColumn("html", parseHtmlUDF(col("content")))
      if (storeContent) htmlDf.select("path", "content", "html")
      else htmlDf.select("path", "html")
    } else if (isValidURL(inputSource)) {
      val urlDf = spark.createDataset(Seq(inputSource)).toDF("url")
      if (storeContent) {
        // The URL frame has no "content" column of its own, so selecting it used to fail.
        // Fetch the page text explicitly and parse that same text, so both columns
        // describe one download. Marked nondeterministic so Spark does not duplicate the fetch.
        val fetchContentUDF =
          udf((url: String) => Jsoup.connect(url).get().outerHtml()).asNondeterministic()
        urlDf
          .withColumn("content", fetchContentUDF(col("url")))
          .withColumn("html", parseHtmlUDF(col("content")))
          .select("url", "content", "html")
      } else {
        urlDf
          .withColumn("html", parseURLUDF(col("url")))
          .select("url", "html")
      }
    } else {
      throw new IllegalArgumentException(s"Invalid inputSource: $inputSource")
    }
  }

  /**
   * Fetches and parses several URLs at once.
   *
   * Entries rejected by `ResourceHelper.isValidURL` are dropped without raising an error.
   *
   * @param inputURLs
   *   candidate URLs to download
   * @return
   *   a DataFrame with a `url` column and the parsed `html` column, one row per accepted URL
   */
  def read(inputURLs: Array[String]): DataFrame = {
    val session = ResourceHelper.spark
    import session.implicits._

    val acceptedUrls: Seq[String] =
      inputURLs.filter(candidate => ResourceHelper.isValidURL(candidate)).toSeq

    val urlFrame = session.createDataset(acceptedUrls).toDF("url")
    urlFrame.withColumn("html", parseURLUDF(col("url")))
  }

  /** Spark UDF: parses an HTML string with Jsoup and extracts the elements under its body. */
  private val parseHtmlUDF = udf { (html: String) =>
    startTraversalFromBody(Jsoup.parse(html))
  }

  /** Spark UDF: downloads a URL with Jsoup and extracts the elements under the page body. */
  private val parseURLUDF = udf { (url: String) =>
    val fetchedPage = Jsoup.connect(url).get()
    startTraversalFromBody(fetchedPage)
  }

  /** Extracts the structured elements found under the `body` of the given document. */
  private def startTraversalFromBody(document: Document): Array[HTMLElement] =
    extractElements(document.body())

  /**
   * Per-node bookkeeping kept while a document is traversed.
   *
   * @param tagName
   *   tag name recorded when the node was registered, if any
   * @param hidden
   *   whether the node was detected as hidden
   * @param visited
   *   mutable flag set once the node has been consumed
   */
  private case class NodeMetadata(
      tagName: Option[String],
      hidden: Boolean,
      var visited: Boolean)

  /**
   * Walks the DOM below `root` and flattens it into an array of [[HTMLElement]]s, in the order
   * the nodes are encountered.
   *
   * Links, tables, paragraphs and title-like elements each produce one element and are not
   * descended into; any other tag is traversed recursively. Hidden nodes are skipped together
   * with their subtree. Every emitted element carries the current `pageNumber`, which starts at
   * 1 and is incremented by each `hr` whose inline style mentions "page-break".
   *
   * @param root
   *   node at which the traversal starts
   * @return
   *   the extracted elements
   */
  private def extractElements(root: Node): Array[HTMLElement] = {
    val elements = ArrayBuffer[HTMLElement]()
    val trackingNodes = mutable.Map[Node, NodeMetadata]()
    var pageNumber = 1

    // An element is hidden when its inline style or its attributes say so.
    def isNodeHidden(node: Node): Boolean = node match {
      case elem: Element =>
        // Strip whitespace so "display: none" is recognised just like "display:none".
        val style = elem.attr("style").toLowerCase.replaceAll("\\s+", "")
        val isHiddenByStyle =
          style.contains("display:none") || style.contains("visibility:hidden")
        val isHiddenByAttribute = elem.hasAttr("hidden") || elem.attr("aria-hidden") == "true"
        isHiddenByStyle || isHiddenByAttribute
      case _ => false
    }

    // Concatenates the visible text below the given nodes, marking every node it consumes.
    def collectTextFromNodes(nodes: List[Node]): String = {
      val textBuffer = ArrayBuffer[String]()

      def traverseAndCollect(node: Node): Unit = {
        val metadata = trackingNodes.getOrElseUpdate(
          node,
          NodeMetadata(tagName = getTagName(node), hidden = isNodeHidden(node), visited = true))
        if (!metadata.hidden) {
          node match {
            case textNode: TextNode =>
              metadata.visited = true
              val text = textNode.text().trim
              if (text.nonEmpty) textBuffer += text

            case elem: Element =>
              metadata.visited = true
              // The element's own text is reached through its TextNode children below.
              // Adding elem.ownText() here as well would emit that text twice.
              elem.childNodes().asScala.foreach(traverseAndCollect)

            case _ => // Ignore other node types
          }
        }
      }

      nodes.foreach(traverseAndCollect)
      textBuffer.mkString(" ").replaceAll("\\s+", " ").trim
    }

    def traverse(node: Node, tagName: Option[String]): Unit = {
      val metadata = trackingNodes.getOrElseUpdate(
        node,
        NodeMetadata(tagName = tagName, hidden = isNodeHidden(node), visited = false))

      // Pre-register direct children so their hidden state is cached up front.
      node.childNodes().asScala.foreach { childNode =>
        trackingNodes.getOrElseUpdate(
          childNode,
          NodeMetadata(tagName = tagName, hidden = isNodeHidden(childNode), visited = false))
      }

      // A hidden node contributes nothing and its subtree is not visited.
      if (!metadata.hidden) {
        node match {
          case element: Element =>
            val visitedNode = metadata.visited
            val pageMetadata: mutable.Map[String, String] =
              mutable.Map("pageNumber" -> pageNumber.toString)
            element.tagName() match {
              case "a" =>
                val href = element.attr("href").trim
                val linkText = element.text().trim
                if (href.nonEmpty && linkText.nonEmpty && !visitedNode) {
                  metadata.visited = true
                  elements += HTMLElement(
                    ElementType.LINK,
                    content = s"[$linkText]($href)",
                    metadata = pageMetadata)
                }
              case "table" =>
                val tableText = extractNestedTableContent(element).trim
                if (tableText.nonEmpty && !visitedNode) {
                  metadata.visited = true
                  elements += HTMLElement(
                    ElementType.TABLE,
                    content = tableText,
                    metadata = pageMetadata)
                }
              case "p" =>
                if (!visitedNode) {
                  metadata.visited = true
                  val paragraphType = classifyParagraphElement(element)
                  // Narrative text skips hidden descendants; the other kinds use Jsoup's text().
                  val paragraphText =
                    if (paragraphType == ElementType.NARRATIVE_TEXT)
                      collectTextFromNodes(element.childNodes().asScala.toList)
                    else element.text().trim
                  if (paragraphText.nonEmpty) {
                    elements += HTMLElement(
                      paragraphType,
                      content = paragraphText,
                      metadata = pageMetadata)
                  }
                }
              case _ if isTitleElement(element) && !visitedNode =>
                metadata.visited = true
                val titleText = element.text().trim
                if (titleText.nonEmpty) {
                  elements += HTMLElement(
                    ElementType.TITLE,
                    content = titleText,
                    metadata = pageMetadata)
                }
              case "hr" =>
                if (element.attr("style").toLowerCase.contains("page-break")) {
                  pageNumber = pageNumber + 1
                }
              case _ =>
                element.childNodes().asScala.foreach { childNode =>
                  traverse(childNode, getTagName(childNode))
                }
            }
          case _ => // Ignore other node types
        }
      }
    }

    // Start traversal from the root node
    traverse(root, getTagName(root))
    elements.toArray
  }

  /** Returns the tag name when `node` is an element, and `None` for every other kind of node. */
  private def getTagName(node: Node): Option[String] =
    Option(node).collect { case tagged: Element => tagged.tagName() }

  /**
   * Picks the element type for a paragraph: title first, then narrative text, and
   * uncategorized text when neither check applies.
   */
  private def classifyParagraphElement(element: Element): String =
    element match {
      case candidate if isTitleElement(candidate) => ElementType.TITLE
      case candidate if isTextElement(candidate) => ElementType.NARRATIVE_TEXT
      case _ => ElementType.UNCATEGORIZED_TEXT
    }

  /**
   * True when the element is not formatted like a title and either its inline style mentions
   * "text" or it is a `p` tag.
   */
  private def isTextElement(elem: Element): Boolean = {
    if (isFormattedAsTitle(elem)) false
    else {
      val styleMentionsText = elem.attr("style").toLowerCase.contains("text")
      styleMentionsText || elem.tagName().toLowerCase == "p"
    }
  }

  /**
   * Decides whether an element should be reported as a title.
   *
   * @param elem
   *   element to inspect
   * @return
   *   true for `title`, `h1`-`h3` and `header` tags, for a `p` that is formatted like a title,
   *   and for any other tag whose `role` attribute is "heading"
   */
  private def isTitleElement(elem: Element): Boolean = {
    val tag = elem.tagName().toLowerCase

    // Recognize titles from common title-related tags or formatted <p> elements
    tag match {
      case "title" | "h1" | "h2" | "h3" | "header" => true
      case "p" => isFormattedAsTitle(elem) // Check if <p> behaves like a title
      case _ => elem.attr("role").toLowerCase == "heading" // ARIA role="heading"
    }
  }

  /**
   * Checks the inline style for title-like formatting: bold text, or a font size of at least
   * `titleFontSize` pt. Centered alignment alone does not make an element a title.
   */
  private def isFormattedAsTitle(elem: Element): Boolean = {
    // Strip whitespace so "font-weight: bold" is recognised just like "font-weight:bold".
    val style = elem.attr("style").toLowerCase.replaceAll("\\s+", "")
    val isBold = style.contains("font-weight:bold")
    val isLargeFont = style.contains("font-size") && extractFontSize(style) >= titleFontSize
    isBold || isLargeFont
  }

  /**
   * Extracts the integer part of the first `font-size` given in pt from a lower-cased inline
   * style.
   *
   * @return
   *   the size in pt, or 0 when no pt size is present or the number does not fit an Int
   */
  private def extractFontSize(style: String): Int = {
    // Tolerates whitespace after the colon and a fractional part such as "12.5pt".
    val sizePattern = """font-size:\s*(\d+)(?:\.\d+)?pt""".r
    sizePattern
      .findFirstMatchIn(style)
      // An absurdly long digit run must not abort the whole job with NumberFormatException.
      .flatMap(m => scala.util.Try(m.group(1).toInt).toOption)
      .getOrElse(0)
  }

  /**
   * Collects the text directly owned by `elem` and by every element nested below it, in
   * depth-first order, joined with single spaces.
   */
  private def extractNestedTableContent(elem: Element): String = {
    val textBuffer = ArrayBuffer[String]()
    val processedElements = mutable.Set[Node]() // Set to track processed elements

    // Recursive function to collect text from the element and its children
    def collectText(node: Node): Unit = {
      node match {
        case childElem: Element =>
          if (!processedElements.contains(childElem)) {
            processedElements += childElem
            val directText = childElem.ownText().trim
            if (directText.nonEmpty) textBuffer += directText
            childElem.childNodes().asScala.foreach(collectText)
          }
        case _ => // Ignore other node types
      }
    }

    // Start the recursive text collection
    collectText(elem)
    textBuffer.mkString(" ")
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy