// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// com.johnsnowlabs.ml.crf.DatasetMetadata.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.crf

import com.johnsnowlabs.nlp.annotators.param.{SerializedAnnotatorComponent, WritableAnnotatorComponent}

/**
  * Statistics record kept for a single feature; a dataset metadata object stores
  * one of these per feature id.
  *
  * @param frequency integer counter (presumably how many times the feature was observed
  *                  in the training data - TODO confirm against the code that builds it)
  * @param sum       float accumulator (presumably the sum of the feature's values
  *                  - TODO confirm against the code that builds it)
  */
case class AttrStat(frequency: Int, sum: Float)


/**
  * Describes the feature space of a CRF dataset.
  *
  * Feature ids form one contiguous range: ids `[0, attrFeatures.length)` address
  * `attrFeatures`, and ids `[attrFeatures.length, featuresStat.length)` address
  * `transitions` (transition at index `i` has feature id `attrFeatures.length + i`).
  * `featuresStat` is indexed by that same feature id.
  *
  * Construction fails with `IllegalArgumentException` when
  *  - `featuresStat.length != attrFeatures.length + transitions.length`, or
  *  - some attribute or attribute feature is not stored at the index equal to its `id`.
  *
  * @param labels       labelId -> label name
  * @param attrs        attrId -> attribute metadata
  * @param attrFeatures featureId -> attribute feature
  * @param transitions  all possible transitions according to the train dataset
  * @param featuresStat featureId -> (attr, label) statistics
  */
class DatasetMetadata
(
  // labelId -> Label Name
  val labels: Array[String],

  // AttrId -> Attr Metadata
  val attrs: Array[Attr],

  // FeatureId -> AttrFeature
  val attrFeatures: Array[AttrFeature],

  // All possible transitions according to train dataset
  val transitions: Array[Transition],

  // FeatureId -> (attr, label) statistics
  val featuresStat: Array[AttrStat]

) extends WritableAnnotatorComponent {

  // Every feature (attr feature or transition) must have exactly one statistics entry.
  require(attrFeatures.length + transitions.length == featuresStat.length,
    s"Number of features ${featuresStat.length} should be equal to number of attr features ${attrFeatures.length}" +
      s" plus number of transition features ${transitions.length}")

  // Ids double as array indices, so each element has to sit at the position of its own id.
  for (i <- attrs.indices)
    require(attrs(i).id == i, s"Attribute ${attrs(i)} stored at index $i that does not equal to id")

  for (i <- attrFeatures.indices)
    require(attrFeatures(i).id == i, s"Feature ${attrFeatures(i)} stored at index $i that does not equal to id")

  /** Label Name -> Label Id (the position of the name in `labels`). */
  lazy val label2Id: Map[String, Int] = labels.zipWithIndex.toMap

  /** Attr Name -> Attr Id. */
  lazy val attr2Id: Map[String, Int] = attrs.map(attr => (attr.name, attr.id)).toMap

  /** (Attr Id, Label Id) -> Feature Id. */
  lazy val attrFeatures2Id: Map[(Int, Int), Int] =
    attrFeatures.map(f => ((f.attrId, f.label), f.id)).toMap

  /** Transition -> Feature Id; transition ids start right after the attr feature ids. */
  lazy val transFeature2Id: Map[Transition, Int] =
    transitions
      .zipWithIndex
      .map { case (transition, idx) => (transition, idx + attrFeatures.length) }
      .toMap

  /**
    * Attr Id -> attribute features that refer to that attribute.
    * Attributes without any feature map to an empty array.
    */
  lazy val attr2Features: IndexedSeq[Array[AttrFeature]] = {
    // groupBy keeps the original relative order inside each group; no extra sorting is needed
    // because the result is looked up by key.
    val featuresByAttr = attrFeatures.groupBy(_.attrId)

    attrs.indices.map(aId => featuresByAttr.getOrElse(aId, Array.empty[AttrFeature]))
  }

  /** Packs all fields into the serializable counterpart of this class. */
  override def serialize: SerializedAnnotatorComponent[_ <: WritableAnnotatorComponent] = {
    SerializedDatasetMetadata(
      labels.toList,
      attrs.toList,
      attrFeatures.toList,
      transitions.toList,
      featuresStat.toList
    )
  }

  /**
    * Leaves only the features whose ids are in `featureIds`.
    *
    * Kept attribute features are re-numbered `0, 1, ...` in the order their ids appear in
    * `featureIds`; kept transitions follow them, also in the order of appearance.
    * The statistics of the result are laid out in that same order (attribute features
    * first, then transitions), so they stay aligned with the new feature ids even when
    * `featureIds` interleaves attribute-feature ids and transition ids.
    *
    * @param featureIds feature ids to leave; each id must be a valid index into `featuresStat`,
    *                   otherwise the array access throws
    * @return new metadata with the same labels and attributes and the reduced feature set
    */
  def filterFeatures(featureIds: Seq[Int]): DatasetMetadata = {
    // Ids below attrFeatures.length address attr features, the rest address transitions.
    val (attrFeaturesIds, transFeaturesIds) = featureIds.partition(id => id < attrFeatures.length)

    // Re-number the surviving attr features so that id == index holds for the new metadata.
    val filteredAttrFeatures = attrFeaturesIds
      .zipWithIndex
      .map { case (oldId, newId) =>
        val oldFeature = attrFeatures(oldId)
        AttrFeature(newId, oldFeature.attrId, oldFeature.label)
      }
      .toArray

    val filteredTransFeatures = transFeaturesIds
      .map(id => transitions(id - attrFeatures.length))
      .toArray

    // Statistics must follow the new feature numbering (attr features, then transitions),
    // not the raw order of featureIds, otherwise stats and features get misaligned.
    val filteredStat = (attrFeaturesIds ++ transFeaturesIds)
      .map(id => featuresStat(id))
      .toArray

    new DatasetMetadata(labels, attrs, filteredAttrFeatures, filteredTransFeatures, filteredStat)
  }
}

/**
  * Serializable form of [[DatasetMetadata]]: the same five fields held as `Seq`s
  * instead of arrays.
  *
  * @param labels       labelId -> label name
  * @param attrs        attrId -> attribute metadata
  * @param attrFeatures featureId -> attribute feature
  * @param transitions  transitions of the dataset
  * @param featuresStat featureId -> statistics
  */
case class SerializedDatasetMetadata
(
  labels: Seq[String],
  attrs: Seq[Attr],
  attrFeatures: Seq[AttrFeature],
  transitions: Seq[Transition],
  featuresStat: Seq[AttrStat]
)
  extends SerializedAnnotatorComponent[DatasetMetadata]
{
  /** Rebuilds the [[DatasetMetadata]] by converting every sequence back to an array. */
  override def deserialize: DatasetMetadata =
    new DatasetMetadata(
      labels = labels.toArray,
      attrs = attrs.toArray,
      attrFeatures = attrFeatures.toArray,
      transitions = transitions.toArray,
      featuresStat = featuresStat.toArray
    )
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy