All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.crf.DatasetReader.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.crf

import java.io.FileInputStream

import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream

import scala.collection.TraversableOnce
import scala.collection.mutable.ArrayBuffer
import scala.io.Source


/** Textual labels of a single sentence; DatasetReader zips them positionally with the sentence's words. */
case class TextSentenceLabels(labels: Seq[String])
/** Textual attributes of a single sentence, one [[WordAttrs]] entry per word. */
case class TextSentenceAttrs(words: Seq[WordAttrs])
/**
  * Attributes of a single word.
  *
  * @param strAttrs (name, value) pairs; DatasetReader joins them as "name=value" to form attribute keys
  * @param numAttrs numeric attribute values; DatasetReader.encodeSentence looks them up by the key "num" + index
  */
case class WordAttrs(strAttrs: Seq[(String, String)], numAttrs: Array[Float] = Array.empty)


/**
  * Reads tab-separated CRF datasets from disk and encodes them into numeric instances.
  *
  * Expected line format: `label<TAB>attr=value<TAB>attr=value...`, one word per line.
  * A line with at most one tab-separated field (for example an empty line) ends the
  * current sentence.
  */
object DatasetReader {

  /**
    * Opens `file` as a text source. Files whose name ends with ".gz" are decompressed on the fly.
    *
    * @param file path of the file to open
    * @return a source over the (decompressed) file content; the caller is responsible for closing it
    */
  private def getSource(file: String): Source = {
    if (file.endsWith(".gz")) {
      val fis = new FileInputStream(file)
      try {
        Source.fromInputStream(new GzipCompressorInputStream(fis))
      } catch {
        // Do not leak the raw file handle when the gzip wrapper cannot be created
        case scala.util.control.NonFatal(e) =>
          fis.close()
          throw e
      }
    } else {
      Source.fromFile(file)
    }
  }

  /**
    * Lazily reads sentences (labels plus word attributes) from `file`.
    *
    * The result is a one-pass lazy iterator. The underlying file is closed once the
    * iterator has been fully consumed; if iteration is abandoned early or fails, the
    * file is not closed by this method.
    *
    * @param file      path of the dataset file (optionally ".gz" compressed)
    * @param skipLines number of leading lines to ignore
    * @return one (labels, attributes) pair per sentence, in file order
    */
  private def readWithLabels(file: String, skipLines: Int = 0): TraversableOnce[(TextSentenceLabels, TextSentenceAttrs)] = {
    val source = getSource(file)

    // Synthetic empty line appended after the real content. It guarantees that the
    // last sentence is emitted even when the file does not end with a separator
    // line, and it closes the source at the moment the whole file has been read.
    val terminator: Iterator[String] = Iterator.single("").map { emptyLine =>
      source.close()
      emptyLine
    }

    val lines = source.getLines().drop(skipLines) ++ terminator

    // Accumulators for the sentence currently being read. Fresh buffers are created
    // after every emitted sentence because the emitted result keeps a reference to
    // the old ones.
    var labels = new ArrayBuffer[String]()
    var tokens = new ArrayBuffer[WordAttrs]()

    // Emits the accumulated sentence (if any) and resets the accumulators.
    def addToResultIfExists(): Option[(TextSentenceLabels, TextSentenceAttrs)] = {
      if (tokens.nonEmpty) {
        val result = (TextSentenceLabels(labels), TextSentenceAttrs(tokens))

        labels = new ArrayBuffer[String]()
        tokens = new ArrayBuffer[WordAttrs]()
        Some(result)
      }
      else {
        None
      }
    }

    lines.flatMap { line =>
      val fields = line.split("\t")
      if (fields.length <= 1) {
        // Sentence boundary
        addToResultIfExists()
      } else {
        val attrValues = fields
          .drop(1)
          .map { feature =>
            // Split on the first '=' only, so values that contain '=' are kept intact
            // and a feature consisting of just "=" does not produce an empty array.
            val attrValue = feature.split("=", 2)
            val attr = attrValue(0)
            val value = if (attrValue.length == 1) "" else attrValue(1)

            (attr, value)
          }

        tokens.append(WordAttrs(attrValues))
        labels.append(fields.head)
        None
      }
    }
  }

  /**
    * Encodes a textual dataset into a [[CrfDataset]], building the metadata on the fly
    * with a new `DatasetEncoder`.
    *
    * @param source sentences as (labels, attributes) pairs; labels and words are zipped positionally
    * @return the encoded instances together with the metadata collected by the encoder
    */
  def encodeDataset(source: TraversableOnce[(TextSentenceLabels, TextSentenceAttrs)]): CrfDataset = {
    val metadata = new DatasetEncoder()

    val instances = source.map { case (textLabels, textSentence) =>
      // The label of the previous word is passed to the encoder; each sentence starts from startLabel
      var prevLabel = metadata.startLabel
      val (labels, features) = textLabels.labels.zip(textSentence.words)
        .map { case (label, word) =>
          val attrs = word.strAttrs.map(a => a._1 + "=" + a._2)
          val (labelId, wordFeatures) = metadata.getFeatures(prevLabel, label, attrs, word.numAttrs)
          prevLabel = label

          (labelId, wordFeatures)
        }.unzip

      (InstanceLabels(labels), Instance(features))
    }.toArray

    CrfDataset(instances, metadata.getMetadata)
  }

  /**
    * Maps textual labels to their ids using `metadata.label2Id`.
    * Labels missing from the metadata are encoded as -1.
    */
  private def encodeLabels(labels: TextSentenceLabels, metadata: DatasetMetadata): InstanceLabels = {
    val labelIds = labels.labels.map(text => metadata.label2Id.getOrElse(text, -1))
    InstanceLabels(labelIds)
  }

  /**
    * Encodes the attributes of one sentence using already existing metadata.
    *
    * String attributes are looked up by the key "name=value" and get the value 1f;
    * numeric attributes are looked up by the key "num" + index and keep their own value.
    * Attributes that are not present in `metadata.attr2Id` are dropped.
    *
    * @param sentence textual attributes of the sentence
    * @param metadata metadata providing the attribute-to-id mapping
    * @return one sparse array per word, with entries sorted by attribute id
    */
  def encodeSentence(sentence: TextSentenceAttrs, metadata: DatasetMetadata): Instance = {
    val items = sentence.words.map { word =>
      val strAttrs = word.strAttrs.flatMap { case (name, value) =>
        val key = name + "=" + value
        metadata.attr2Id.get(key)
      }.map((_, 1f))

      val numAttrs = word.numAttrs.zipWithIndex.flatMap { case (value, idx) =>
        val key = "num" + idx
        val attr = metadata.attr2Id.get(key)
        attr.map(attrName => (attrName, value))
      }

      val id2value = strAttrs ++ numAttrs

      // distinct removes exact duplicates of (id, value) pairs only
      val attrValues = id2value.sortBy(id => id._1).distinct.toArray
      new SparseArray(attrValues)
    }

    Instance(items)
  }

  /**
    * Reads `file` and encodes it into a [[CrfDataset]], creating new metadata from its content.
    *
    * @param file      path of the dataset file (optionally ".gz" compressed)
    * @param skipLines number of leading lines to ignore
    */
  def readAndEncode(file: String, skipLines: Int): CrfDataset = {
    val textDataset = readWithLabels(file, skipLines)

    encodeDataset(textDataset)
  }

  /**
    * Reads `file` and encodes it with already existing `metadata`.
    *
    * The result is lazy: sentences are read and encoded as the returned collection is
    * traversed, and the file is closed once it has been traversed completely.
    *
    * @param file      path of the dataset file (optionally ".gz" compressed)
    * @param skipLines number of leading lines to ignore
    * @param metadata  metadata providing label and attribute ids
    * @return encoded (labels, instance) pairs, one per sentence
    */
  def readAndEncode(file: String, skipLines: Int, metadata: DatasetMetadata): TraversableOnce[(InstanceLabels, Instance)] = {
    val textDataset = readWithLabels(file, skipLines)

    textDataset.map { case (sourceLabels, sourceInstance) =>
      val labels = encodeLabels(sourceLabels, metadata)
      val instance = encodeSentence(sourceInstance, metadata)
      (labels, instance)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy