package com.johnsnowlabs.nlp.serialization

import com.github.liblevenshtein.proto.LibLevenshteinProtos.DawgNode
import com.github.liblevenshtein.serialization.PlainTextSerializer
import com.github.liblevenshtein.transducer.{Candidate, ITransducer, Transducer}
import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.annotators.spell.context.parser.SpecialClassParser
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.util.ConfigLoader
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

import scala.reflect.ClassTag

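/**
  * Base class for model features: values owned by a [[HasFeatures]] model that can be broadcast
  * to executors and persisted alongside the model. Constructing a Feature registers it with its
  * owning model. Whether the value is kept as a Spark broadcast or as a plain reference is
  * controlled by `sparknlp.settings.useBroadcastForFeatures`, and the on-disk format ("dataset"
  * for parquet, "object" for Hadoop object files) by `sparknlp.settings.annotatorSerializationFormat`.
  */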
abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag](model: HasFeatures, val name: String) extends Serializable {
  model.features.append(this)

  private val config = ConfigLoader.retrieve
  private val spark = ResourceHelper.spark

  val serializationMode: String = config.getString("sparknlp.settings.annotatorSerializationFormat")
  val useBroadcast: Boolean = config.getBoolean("sparknlp.settings.useBroadcastForFeatures")

  final protected var broadcastValue: Option[Broadcast[TComplete]] = None

  final protected var rawValue: Option[TComplete] = None
  final protected var fallbackRawValue: Option[TComplete] = None

  final protected var fallbackLazyValue: Option[() => TComplete] = None

  final def serialize(spark: SparkSession, path: String, field: String, value: TComplete): Unit = {
    serializationMode match {
      case "dataset" => serializeDataset(spark, path, field, value)
      case "object" => serializeObject(spark, path, field, value)
      case _ => throw new IllegalArgumentException("Illegal annotatorSerializationFormat setting. Must be either 'dataset' or 'object'")
    }
  }

  final def serializeInfer(spark: SparkSession, path: String, field: String, value: Any): Unit =
    serialize(spark, path, field, value.asInstanceOf[TComplete])

  final def deserialize(spark: SparkSession, path: String, field: String): Option[_] = {
    if (broadcastValue.isDefined || rawValue.isDefined)
      throw new Exception(s"Trying to deserialize an already set value for ${this.name}. This should not happen.")
    serializationMode match {
      case "dataset" => deserializeDataset(spark, path, field)
      case "object" => deserializeObject(spark, path, field)
      case _ => throw new IllegalArgumentException("Illegal annotatorSerializationFormat setting. Must be either 'dataset' or 'object'")
    }
  }

  protected def serializeDataset(spark: SparkSession, path: String, field: String, value: TComplete): Unit

  protected def deserializeDataset(spark: SparkSession, path: String, field: String): Option[_]

  protected def serializeObject(spark: SparkSession, path: String, field: String, value: TComplete): Unit

  protected def deserializeObject(spark: SparkSession, path: String, field: String): Option[_]

  final protected def getFieldPath(path: String, field: String): Path =
    Path.mergePaths(new Path(path), new Path("/fields/" + field))

  private def callAndSetFallback: Option[TComplete] = {
    fallbackRawValue = fallbackLazyValue.map(_())
    fallbackRawValue
  }

  final def get: Option[TComplete] = {
    broadcastValue.map(_.value).orElse(rawValue)
  }

  final def orDefault: Option[TComplete] = {
    broadcastValue.map(_.value)
      .orElse(rawValue)
      .orElse(fallbackRawValue)
      .orElse(callAndSetFallback)
  }

  final def getOrDefault: TComplete = {
    orDefault
      .getOrElse(throw new Exception(s"feature $name is not set"))
  }

  final def setValue(value: Option[Any]): HasFeatures = {
    if (useBroadcast) {
      broadcastValue.foreach(_.destroy())
      broadcastValue = value.map(v => spark.sparkContext.broadcast[TComplete](v.asInstanceOf[TComplete]))
    } else {
      rawValue = value.map(_.asInstanceOf[TComplete])
    }
    model
  }

  def setFallback(v: Option[() => TComplete]): HasFeatures = {
    fallbackLazyValue = v
    model
  }

  final def isSet: Boolean = {
    broadcastValue.isDefined || rawValue.isDefined
  }

}

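/** A feature holding a single value of type `TValue`, persisted as a one-element dataset or object file. */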
class StructFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
  extends Feature[TValue, TValue, TValue](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(Seq(value)).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).first)
    } else {
      None
    }
  }

  override def serializeDataset(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(Seq(value)).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].first)
    } else {
      None
    }
  }

}
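
// Hypothetical usage sketch (not part of the original file): how a model mixing in HasFeatures
// might declare and drive a feature through setValue/getOrDefault. The class and feature names
// are illustrative only, and this assumes HasFeatures can be extended directly; real annotators
// wire features up through their own getters and setters.
class FeatureUsageSketch extends HasFeatures {
  // the Feature constructor registers this feature with the owning model automatically
  val exampleText: StructFeature[String] = new StructFeature[String](this, "exampleText")

  def setExampleText(value: String): this.type = {
    exampleText.setValue(Some(value)) // broadcast or raw, per sparknlp.settings.useBroadcastForFeatures
    this
  }

  // resolves the broadcast or raw value, then any fallback, and throws if nothing is set
  def getExampleText: String = exampleText.getOrDefault
}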

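/** A feature holding a `Map[TKey, TValue]`, persisted as `(key, value)` pairs. */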
class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override val name: String)
  extends Feature[TKey, TValue, Map[TKey, TValue]](model, name) {

  implicit val encoder: Encoder[(TKey, TValue)] = Encoders.kryo[(TKey, TValue)]

  override def serializeObject(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }



  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[(TKey, TValue)](dataPath.toString).collect.toMap)
    } else {
      None
    }
  }

  override def serializeDataset(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
    import spark.implicits._
    val dataPath = getFieldPath(path, field)
    value.toSeq.toDS.write.mode("overwrite").parquet(dataPath.toString)
  }



  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[(TKey, TValue)].collect.toMap)
    } else {
      None
    }
  }

}
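
// Hypothetical sketch (not part of the original file): using the lazy fallback defined on Feature.
// getOrDefault resolves, in order, the broadcast value, the raw value, the memoized fallback, and
// only then evaluates the fallback thunk. The model argument and names are illustrative only.
object MapFeatureFallbackSketch {
  def defaultVocabulary(model: HasFeatures): Map[String, Int] = {
    val vocabulary = new MapFeature[String, Int](model, "vocabularySketch")
    vocabulary.setFallback(Some(() => Map("unknown" -> 0))) // evaluated only because no value was set
    vocabulary.getOrDefault                                 // returns Map("unknown" -> 0) here
  }
}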

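/** A feature holding an `Array[TValue]`, persisted element by element. */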
class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
  extends Feature[TValue, TValue, Array[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
    } else {
      None
    }
  }

  override def serializeDataset(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect)
    } else {
      None
    }
  }

}
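
// Hypothetical sketch (not part of the original file): a persistence round trip through the
// dispatching serialize/deserialize defined on Feature. Whether this writes parquet or a Hadoop
// object file is decided by sparknlp.settings.annotatorSerializationFormat. The path, labels and
// feature names are illustrative only.
object ArrayFeatureRoundTripSketch {
  def roundTrip(spark: SparkSession, model: HasFeatures, sketchPath: String): Option[_] = {
    val written = new ArrayFeature[String](model, "labelsSketch")
    written.setValue(Some(Array("O", "B-PER", "I-PER")))
    written.serialize(spark, sketchPath, "labelsSketch", written.getOrDefault)

    // deserialize must be called on a feature whose value has not been set yet
    val restored = new ArrayFeature[String](model, "labelsRestoredSketch")
    restored.deserialize(spark, sketchPath, "labelsSketch")
  }
}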

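/** A feature holding a `Set[TValue]`, persisted as a sequence and rebuilt as a set on load. */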
class SetFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
  extends Feature[TValue, TValue, Set[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(spark: SparkSession, path: String, field: String, value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect().toSet)
    } else {
      None
    }
  }

  override def serializeDataset(spark: SparkSession, path: String, field: String, value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value.toSeq).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect.toSet)
    } else {
      None
    }
  }

}

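/**
  * A feature holding a liblevenshtein `ITransducer[Candidate]`. The transducer is persisted as the
  * raw bytes produced by `PlainTextSerializer` rather than as a Java-serialized object.
  */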
class TransducerFeature(model: HasFeatures, override val name: String)
  extends Feature[ITransducer[Candidate], ITransducer[Candidate], ITransducer[Candidate]](model, name) {

  override def serializeObject(spark: SparkSession, path: String, field: String, trans: ITransducer[Candidate]): Unit = {
    val serializer = new PlainTextSerializer
    val dataPath = getFieldPath(path, field)
    val bytes = serializer.serialize(trans)
    spark.sparkContext.parallelize(bytes.toSeq, 1).saveAsObjectFile(dataPath.toString)

  }

  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[ITransducer[Candidate]] = {
    val serializer = new PlainTextSerializer
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val bytes = spark.sparkContext.objectFile[Byte](dataPath.toString).collect()
      val deserialized = serializer.deserialize(classOf[Transducer[DawgNode, Candidate]], bytes)
      Some(deserialized)
    } else {
      None
    }
  }


  override def serializeDataset(spark: SparkSession, path: String, field: String, trans: ITransducer[Candidate]): Unit = {
    val serializer = new PlainTextSerializer
    import spark.implicits._
    val dataPath = getFieldPath(path, field)
    val bytes = serializer.serialize(trans)
    spark.createDataset(bytes.toSeq).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[ITransducer[Candidate]] = {
    val serializer = new PlainTextSerializer
    import spark.implicits._ // primitive Byte encoder, matching the write path in serializeDataset above
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val bytes = spark.read.parquet(dataPath.toString).as[Byte].collect
      val deserialized = serializer.deserialize(classOf[Transducer[DawgNode, Candidate]], bytes)
      Some(deserialized)
    } else {
      None
    }
  }
}
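
// Hypothetical sketch (not part of the original file): the in-memory byte round trip that
// serializeObject/deserializeObject above perform around their Hadoop I/O, using the same
// PlainTextSerializer calls. The helper name is illustrative only.
object TransducerBytesSketch {
  def roundTrip(trans: ITransducer[Candidate]): ITransducer[Candidate] = {
    val serializer = new PlainTextSerializer
    val bytes: Array[Byte] = serializer.serialize(trans) // transducer -> bytes
    // bytes -> transducer, reconstructed through the concrete Transducer implementation
    serializer.deserialize(classOf[Transducer[DawgNode, Candidate]], bytes)
  }
}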


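/**
  * A feature holding a sequence of `SpecialClassParser` instances (context spell checker special
  * classes). Each parser is stored next to its transducer: the parser object under `<label>` and
  * the transducer bytes under `<label>transducer` inside the field's directory.
  */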
class TransducerSeqFeature(model: HasFeatures, override val name: String)
  extends Feature[Seq[SpecialClassParser], Seq[SpecialClassParser], Seq[SpecialClassParser]](model, name) {

  implicit val encoder: Encoder[SpecialClassParser] = Encoders.kryo[SpecialClassParser]

  override def serializeObject(spark: SparkSession, path: String, field: String, specialClasses: Seq[SpecialClassParser]): Unit = {
    val dataPath = getFieldPath(path, field)
    val serializer = new PlainTextSerializer

    specialClasses.foreach { specialClass =>

      // Hadoop won't see files starting with '_'
      val label = specialClass.label.replaceAll("_", "-")

      val transducer = specialClass.transducer
      specialClass.setTransducer(null)
      // the parser object itself, with its transducer detached
      spark.sparkContext.parallelize(Seq(specialClass)).
        saveAsObjectFile(s"${dataPath.toString}/${label}")

      // the transducer is handled separately, as raw serialized bytes
      val transBytes = serializer.serialize(transducer)
      spark.sparkContext.parallelize(transBytes.toSeq, 1).
        saveAsObjectFile(s"${dataPath.toString}/${label}transducer")

      // reattach the transducer so that saving does not mutate the live model
      specialClass.setTransducer(transducer)
    }
  }

  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Seq[SpecialClassParser]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    val serializer = new PlainTextSerializer

    if (fs.exists(dataPath)) {
      val elements = fs.listStatus(dataPath)
      var result = Seq[SpecialClassParser]()
      elements.foreach { element =>
        val elementPath = element.getPath
        if (elementPath.getName.contains("transducer")) {
          // first the serialized transducer bytes
          val bytes = spark.sparkContext.objectFile[Byte](elementPath.toString).collect()
          val trans = serializer.deserialize(classOf[Transducer[DawgNode, Candidate]], bytes)
          // then the parser object itself, stored under the path without the "transducer" suffix
          val sc = spark.sparkContext.objectFile[SpecialClassParser](elementPath.toString.dropRight(10)).collect().head
          sc.setTransducer(trans)
          result = result :+ sc
        }
      }

      Some(result)
    } else {
      None
    }
  }

  override def serializeDataset(spark: SparkSession, path: String, field: String, specialClasses: Seq[SpecialClassParser]): Unit = {
    implicit val encoder: Encoder[SpecialClassParser] = Encoders.kryo[SpecialClassParser]

    import spark.implicits._
    val dataPath = getFieldPath(path, field)
    specialClasses.foreach { specialClass =>
      val serializer = new PlainTextSerializer

      // Hadoop won't see files starting with '_'
      val label = specialClass.label.replaceAll("_", "-")

      val transducer = specialClass.transducer
      specialClass.setTransducer(null)
      // the parser object itself, with its transducer detached
      spark.createDataset(Seq(specialClass)).
        write.mode("overwrite").
        parquet(s"${dataPath.toString}/${label}")

      // the transducer is handled separately, as raw serialized bytes
      val transBytes = serializer.serialize(transducer)
      spark.createDataset(transBytes.toSeq).
        write.mode("overwrite").
        parquet(s"${dataPath.toString}/${label}transducer")

      // reattach the transducer so that saving does not mutate the live model
      specialClass.setTransducer(transducer)
    }
  }

  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Seq[SpecialClassParser]] = {
    import spark.implicits._
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    val serializer = new PlainTextSerializer

    if (fs.exists(dataPath)) {
      val elements = fs.listFiles(dataPath, false)
      var result = Seq[SpecialClassParser]()
      while (elements.hasNext) {
        val elementPath = elements.next.getPath.toString
        if (elementPath.contains("transducer")) {
          // first the serialized transducer bytes
          val bytes = spark.read.parquet(elementPath).as[Byte].collect
          val trans = serializer.deserialize(classOf[Transducer[DawgNode, Candidate]], bytes)

          // then the parser object itself, stored under the path without the "transducer" suffix
          val sc = spark.read.parquet(elementPath.dropRight(10)).as[SpecialClassParser].collect.head
          sc.setTransducer(trans)
          result = result :+ sc
        }
      }
      Some(result)
    } else {
      None
    }
  }


}
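
// Hypothetical sketch (not part of the original file): the on-disk layout produced by
// TransducerSeqFeature above. Each special class is written twice, the parser object (with its
// transducer detached) under "<fields path>/<label>" and the transducer bytes under
// "<fields path>/<label>transducer"; underscores in labels become '-' because Hadoop ignores
// paths starting with '_'. The helper name is illustrative only.
object TransducerSeqLayoutSketch {
  def pathsFor(dataPath: Path, label: String): (String, String) = {
    val safeLabel = label.replaceAll("_", "-")
    (s"${dataPath.toString}/$safeLabel", s"${dataPath.toString}/${safeLabel}transducer")
  }
}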




