All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.crf.DatasetReader.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.crf

import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream

import java.io.FileInputStream
import scala.collection.TraversableOnce
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.util.control.NonFatal

/** Labels of one sentence, one entry per token, in token order. */
case class TextSentenceLabels(labels: Seq[String])

/** Attributes of one sentence, one [[WordAttrs]] per token, in token order. */
case class TextSentenceAttrs(words: Seq[WordAttrs])

/** Attributes of a single token.
  *
  * @param strAttrs categorical attributes as `(name, value)` pairs
  * @param numAttrs numeric attributes, identified by their position; empty by default. Note
  *   that `Array` compares by reference, so two instances with equal but distinct arrays are
  *   not `==`.
  */
case class WordAttrs(
    strAttrs: Seq[(String, String)],
    numAttrs: Array[Float] = Array.empty)

/** Reads tab-separated, labelled CRF datasets (optionally gzip-compressed) and encodes them into
  * [[CrfDataset]] / [[Instance]] form, either building new metadata or reusing existing metadata.
  */
object DatasetReader {

  /** Opens `file` as a text source, decompressing it on the fly when its name ends with `.gz`.
    *
    * Both branches decode bytes with the implicit default `Codec`. The caller owns the returned
    * source and is responsible for closing it.
    *
    * @param file path of the file to open
    * @return a source over the (decompressed) text of `file`
    */
  private def getSource(file: String): Source = {
    if (file.endsWith(".gz")) {
      val fis = new FileInputStream(file)
      // If wrapping fails (e.g. the file is not valid gzip), nothing else holds `fis`, so close
      // it here before propagating the error instead of leaking the file handle.
      val zis =
        try new GzipCompressorInputStream(fis)
        catch {
          case NonFatal(e) =>
            fis.close()
            throw e
        }
      Source.fromInputStream(zis)
    } else {
      Source.fromFile(file)
    }
  }

  /** Splits one `name=value` attribute at its FIRST `=`.
    *
    * An attribute without `=` gets an empty value. Splitting only once keeps values that
    * themselves contain `=` (such as the token `=`) intact.
    */
  private def parseAttribute(feature: String): (String, String) = {
    val sep = feature.indexOf('=')
    if (sep < 0) (feature, "")
    else (feature.substring(0, sep), feature.substring(sep + 1))
  }

  /** Closes `source` and yields a single blank line that acts as a final sentence separator. */
  private def closingSentinel(source: Source): Iterator[String] = {
    source.close()
    Iterator.single("")
  }

  /** Reads a labelled dataset with one token per line and sentences separated by blank lines.
    *
    * Each token line is split on tabs: the first column is the label and every following column
    * is a `name=value` attribute (split at the first `=`; a column without `=` gets an empty
    * value). A line with at most one tab-separated column ends the current sentence. The last
    * sentence is emitted even if the file does not end with a blank line.
    *
    * @param file path to the dataset; `.gz` files are decompressed on the fly
    * @param skipLines number of leading lines to ignore
    * @return a lazy, single-pass sequence of (labels, attributes) pairs, one per sentence. The
    *   underlying file is closed once the sequence has been fully consumed.
    */
  private def readWithLabels(
      file: String,
      skipLines: Int = 0): TraversableOnce[(TextSentenceLabels, TextSentenceAttrs)] = {
    val source = getSource(file)

    // `++` takes its argument by name, so the sentinel (and the close) is evaluated only after
    // every real line has been read. The blank sentinel flushes a trailing unterminated sentence.
    val lines = source.getLines().drop(skipLines) ++ closingSentinel(source)

    var labels = new ArrayBuffer[String]()
    var tokens = new ArrayBuffer[WordAttrs]()

    // Emits the sentence accumulated so far (if any) and starts a new one. Fresh buffers are
    // allocated instead of clearing, because the emitted case classes keep the old ones.
    def addToResultIfExists(): Option[(TextSentenceLabels, TextSentenceAttrs)] = {
      if (tokens.nonEmpty) {
        val result = (TextSentenceLabels(labels), TextSentenceAttrs(tokens))

        labels = new ArrayBuffer[String]()
        tokens = new ArrayBuffer[WordAttrs]()
        Some(result)
      } else {
        None
      }
    }

    lines.flatMap { line =>
      val columns = line.split("\t")
      if (columns.length <= 1) {
        addToResultIfExists()
      } else {
        val attrValues = columns.drop(1).map(parseAttribute)

        tokens.append(WordAttrs(attrValues))
        labels.append(columns.head)
        None
      }
    }
  }

  /** Encodes labelled text sentences, building the dataset metadata while doing so.
    *
    * A fresh `DatasetEncoder` is used for the whole dataset. For every word, its string
    * attributes are rendered as `name=value` keys and passed to `DatasetEncoder.getFeatures`.
    * The call also receives the word's label, its numeric attributes, and the label of the
    * previous word in the same sentence (the encoder's `startLabel` for the first word). Labels
    * and words are paired with `zip`, so any surplus entries on either side are ignored.
    *
    * @param source sentences to encode; traversed once
    * @return the encoded instances together with the metadata accumulated while encoding them
    */
  def encodeDataset(
      source: TraversableOnce[(TextSentenceLabels, TextSentenceAttrs)]): CrfDataset = {
    val metadata = new DatasetEncoder()

    val instances = source.map { case (textLabels, textSentence) =>
      // Threaded through the (strict) map below so each call sees the preceding word's label.
      var prevLabel = metadata.startLabel
      val (labelIds, wordFeatures) = textLabels.labels
        .zip(textSentence.words)
        .map { case (label, word) =>
          val attrs = word.strAttrs.map { case (name, value) => name + "=" + value }
          val (labelId, features) =
            metadata.getFeatures(prevLabel, label, attrs, word.numAttrs)
          prevLabel = label

          (labelId, features)
        }
        .unzip

      (InstanceLabels(labelIds), Instance(wordFeatures))
    }.toArray

    CrfDataset(instances, metadata.getMetadata)
  }

  /** Maps each label to its id in `metadata.label2Id`; labels missing from it are encoded as -1. */
  private def encodeLabels(
      labels: TextSentenceLabels,
      metadata: DatasetMetadata): InstanceLabels = {
    val labelIds = labels.labels.map(text => metadata.label2Id.getOrElse(text, -1))
    InstanceLabels(labelIds)
  }

  /** Encodes one sentence with existing metadata, producing one sparse vector per word.
    *
    * String attributes are looked up as `name=value` keys and get the value `1f`. The numeric
    * attribute at position `i` is looked up under the key `"num" + i` and keeps its value.
    * Attributes whose key is absent from `metadata.attr2Id` are dropped. Entries are sorted
    * by attribute id and exact duplicates are removed.
    */
  def encodeSentence(sentence: TextSentenceAttrs, metadata: DatasetMetadata): Instance = {
    val items = sentence.words.map { word =>
      val strAttrs = word.strAttrs
        .flatMap { case (name, value) =>
          val key = name + "=" + value
          metadata.attr2Id.get(key)
        }
        .map((_, 1f))

      val numAttrs = word.numAttrs.zipWithIndex.flatMap { case (value, idx) =>
        val key = "num" + idx
        val attr = metadata.attr2Id.get(key)
        attr.map(attrName => (attrName, value))
      }

      val id2value = strAttrs ++ numAttrs

      val attrValues = id2value.sortBy(id => id._1).distinct.toArray
      new SparseArray(attrValues)
    }

    Instance(items)
  }

  /** Reads `file` and encodes it, building new metadata; the file is fully read and closed.
    *
    * @param file path to the dataset; `.gz` files are decompressed on the fly
    * @param skipLines number of leading lines to ignore
    */
  def readAndEncode(file: String, skipLines: Int): CrfDataset = {
    val textDataset = readWithLabels(file, skipLines)

    encodeDataset(textDataset)
  }

  /** Lazily reads `file` and encodes it with existing `metadata`.
    *
    * @param file path to the dataset; `.gz` files are decompressed on the fly
    * @param skipLines number of leading lines to ignore
    * @param metadata label and attribute id mappings to encode with
    * @return a single-pass sequence of encoded sentences; the underlying file is closed once
    *   the sequence has been fully consumed
    */
  def readAndEncode(
      file: String,
      skipLines: Int,
      metadata: DatasetMetadata): TraversableOnce[(InstanceLabels, Instance)] = {
    val textDataset = readWithLabels(file, skipLines)

    textDataset.map { case (sourceLabels, sourceInstance) =>
      val labels = encodeLabels(sourceLabels, metadata)
      val instance = encodeSentence(sourceInstance, metadata)
      (labels, instance)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy