/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.serialization

import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.annotators.spell.context.parser.VocabParser
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

import scala.reflect.ClassTag

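/** Base class for a named, serializable feature attached to a Spark NLP model.
  *
  * A feature registers itself with its owning [[HasFeatures]] model on construction and keeps
  * its value either as a plain reference or as a Spark
  * [[org.apache.spark.broadcast.Broadcast]] variable, depending on the `useBroadcast`
  * configuration. Subclasses provide two interchangeable persistence strategies, selected by
  * the serialization mode setting: "dataset" (Parquet via Spark Datasets) and "object" (Hadoop
  * object files).
  *
  * A minimal usage sketch; the model class and feature name here are hypothetical:
  * {{{
  * class MyModel extends HasFeatures {
  *   val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary")
  * }
  *
  * val model = new MyModel()
  * model.vocabulary.setValue(Some(Map("hello" -> 0, "world" -> 1)))
  * val vocab: Map[String, Int] = model.vocabulary.getOrDefault
  * }}}
  */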
abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag](
    model: HasFeatures,
    val name: String)
    extends Serializable {
  model.features.append(this)

  private val spark: SparkSession = ResourceHelper.spark

  val serializationMode: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.serializationMode)
  val useBroadcast: Boolean = ConfigLoader.getConfigBooleanValue(ConfigHelper.useBroadcast)
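  // The live value occupies exactly one of broadcastValue or rawValue, depending on the
  // useBroadcast setting; fallbackRawValue caches the result of the lazy default held in
  // fallbackLazyValue, and isProtected marks a feature as settable only once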
  final protected var broadcastValue: Option[Broadcast[TComplete]] = None

  final protected var rawValue: Option[TComplete] = None
  final protected var fallbackRawValue: Option[TComplete] = None

  final protected var fallbackLazyValue: Option[() => TComplete] = None
  final protected var isProtected: Boolean = false

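  // Persists a value under <path>/fields/<field> in the configured serialization mode:
  // "dataset" writes Parquet through a Spark Dataset, "object" writes a Hadoop object file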
  final def serialize(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit = {
    serializationMode match {
      case "dataset" => serializeDataset(spark, path, field, value)
      case "object" => serializeObject(spark, path, field, value)
      case _ =>
        throw new IllegalArgumentException(
          "Illegal performance.serialization setting. Can be 'dataset' or 'object'")
    }
  }

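  // Unchecked variant of serialize for callers that only hold the value typed as Any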
  final def serializeInfer(spark: SparkSession, path: String, field: String, value: Any): Unit =
    serialize(spark, path, field, value.asInstanceOf[TComplete])

  final def deserialize(spark: SparkSession, path: String, field: String): Option[_] = {
    if (broadcastValue.isDefined || rawValue.isDefined)
      throw new Exception(
        s"Trying to deserialize an already set value for ${this.name}. This should not happen.")
    serializationMode match {
      case "dataset" => deserializeDataset(spark, path, field)
      case "object" => deserializeObject(spark, path, field)
      case _ =>
        throw new IllegalArgumentException(
          "Illegal performance.serialization setting. Can be 'dataset' or 'object'")
    }
  }

  protected def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit

  protected def deserializeDataset(spark: SparkSession, path: String, field: String): Option[_]

  protected def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit

  protected def deserializeObject(spark: SparkSession, path: String, field: String): Option[_]

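  // All features of a model are stored side by side under <path>/fields/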
  final protected def getFieldPath(path: String, field: String): Path =
    Path.mergePaths(new Path(path), new Path("/fields/" + field))

  private def callAndSetFallback: Option[TComplete] = {
    fallbackRawValue = fallbackLazyValue.map(_())
    fallbackRawValue
  }

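  // Read path: a broadcast value takes precedence over a locally held one; orDefault
  // additionally evaluates (and caches) the lazy fallback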
  final def get: Option[TComplete] = {
    broadcastValue.map(_.value).orElse(rawValue)
  }

  final def orDefault: Option[TComplete] = {
    broadcastValue
      .map(_.value)
      .orElse(rawValue)
      .orElse(fallbackRawValue)
      .orElse(callAndSetFallback)
  }

  final def getOrDefault: TComplete = {
    orDefault
      .getOrElse(throw new Exception(s"feature $name is not set"))
  }

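  // Write path: a protected feature that is already set only warns and keeps its first value;
  // otherwise the new value replaces (and, when broadcasting, destroys) the previous one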
  final def setValue(value: Option[Any]): HasFeatures = {
    if (isProtected && isSet) {
      val warnString =
        s"Warning: The parameter ${this.name} is protected and can only be set once." +
          " For a pretrained model, this was done during the initialization process." +
          " If you are trying to train your own model, please check the documentation."
      println(warnString)
    } else {
      if (useBroadcast) {
        if (isSet) broadcastValue.get.destroy()
        broadcastValue =
          value.map(v => spark.sparkContext.broadcast[TComplete](v.asInstanceOf[TComplete]))
      } else {
        rawValue = value.map(_.asInstanceOf[TComplete])
      }
    }
    model
  }

  def setFallback(v: Option[() => TComplete]): HasFeatures = {
    fallbackLazyValue = v
    model
  }

  final def isSet: Boolean = {
    broadcastValue.isDefined || rawValue.isDefined
  }

  /** Sets this feature to be protected and only settable once.
    *
    * @return
    *   This Feature
    */
  final def setProtected(): this.type = {
    isProtected = true
    this
  }

}

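/** Feature holding a single value of type `TValue`. Dataset mode writes it as a one-row,
  * Kryo-encoded Parquet file; object mode writes a one-element Hadoop object file.
  */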
class StructFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, TValue](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(Seq(value)).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).first)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(Seq(value)).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].first)
    } else {
      None
    }
  }

}

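/** Feature holding a `Map[TKey, TValue]`, persisted as its key/value pairs and rebuilt with
  * `toMap` on load.
  */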
class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TKey, TValue, Map[TKey, TValue]](model, name) {

  implicit val encoder: Encoder[(TKey, TValue)] = Encoders.kryo[(TKey, TValue)]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Map[TKey, TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[(TKey, TValue)](dataPath.toString).collect.toMap)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Map[TKey, TValue]): Unit = {
    import spark.implicits._
    val dataPath = getFieldPath(path, field)
    value.toSeq.toDS.write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[(TKey, TValue)].collect.toMap)
    } else {
      None
    }
  }

}

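/** Feature holding an `Array[TValue]`; each element becomes one record of the persisted data. */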
class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, Array[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect)
    } else {
      None
    }
  }

}

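/** Feature holding a `Set[TValue]`, written out as a sequence and rebuilt with `toSet` on load.
  */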
class SetFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, Set[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect().toSet)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value.toSeq).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect.toSet)
    } else {
      None
    }
  }

}

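/** Feature holding the [[VocabParser]] (vocabulary transducer) of the context spell checker,
  * persisted as a single serialized object.
  */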
class TransducerFeature(model: HasFeatures, override val name: String)
    extends Feature[VocabParser, VocabParser, VocabParser](model, name) {

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      trans: VocabParser): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(Seq(trans), 1).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[VocabParser] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val parser = spark.sparkContext.objectFile[VocabParser](dataPath.toString).collect().head
      Some(parser)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      trans: VocabParser): Unit = {
    implicit val encoder: Encoder[VocabParser] = Encoders.kryo[VocabParser]
    val dataPath = getFieldPath(path, field)
    spark.createDataset(Seq(trans)).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[VocabParser] = {
    implicit val encoder: Encoder[VocabParser] = Encoders.kryo[VocabParser]
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val parser = spark.read.parquet(dataPath.toString).as[VocabParser].collect.head
      Some(parser)
    } else {
      None
    }
  }

}



