// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// com.johnsnowlabs.nlp.annotators.classifier.dl.SentimentDLApproach.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.AnnotatorType.{CATEGORY, SENTENCE_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, SparkSession}

import scala.util.Random

/** Trains a SentimentDL, an annotator for multi-class sentiment analysis.
  *
  * In natural language processing, sentiment analysis is the task of classifying the affective
  * state or subjective view of a text. A common example is if either a product review or tweet
  * can be interpreted positively or negatively.
  *
  * For the instantiated/pretrained models, see [[SentimentDLModel]].
  *
  * '''Notes''':
  *   - This annotator accepts a label column of a single item in either type of String, Int,
  *     Float, or Double. So positive sentiment can be expressed as either `"positive"` or `0`,
  *     negative sentiment as `"negative"` or `1`.
  *   - [[com.johnsnowlabs.nlp.embeddings.UniversalSentenceEncoder UniversalSentenceEncoder]],
  *     [[com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings BertSentenceEmbeddings]], or
  *     [[com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings SentenceEmbeddings]] can be used for
  *     the `inputCol`.
  *
  * Setting a test dataset to monitor model metrics can be done with `.setTestDataset`. The method
  * expects a path to a parquet file containing a dataframe that has the same required columns as
  * the training dataframe. The pre-processing steps for the training dataframe should also be
  * applied to the test dataframe. The following example will show how to create the test dataset:
  *
  * {{{
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val embeddings = UniversalSentenceEncoder.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentence_embeddings")
  *
  * val preProcessingPipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
  *
  * val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
  * preProcessingPipeline
  *   .fit(test)
  *   .transform(test)
  *   .write
  *   .mode("overwrite")
  *   .parquet("test_data")
  *
  * val classifier = new SentimentDLApproach()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCol("sentiment")
  *   .setLabelColumn("label")
  *   .setTestDataset("test_data")
  * }}}
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/SentimentDL_train_multiclass_sentiment_classifier.ipynb Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/SentimentDLTestSpec.scala SentimentDLTestSpec]].
  *
  * ==Example==
  * In this example, `sentiment.csv` is in the form
  * {{{
  * text,label
  * This movie is the best movie I have watched ever! In my opinion this movie can win an award.,0
  * This was a terrible movie! The acting was bad really bad!,1
  * }}}
  * The model can then be trained with
  * {{{
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotator.UniversalSentenceEncoder
  * import com.johnsnowlabs.nlp.annotators.classifier.dl.{SentimentDLApproach, SentimentDLModel}
  * import org.apache.spark.ml.Pipeline
  *
  * val smallCorpus = spark.read.option("header", "true").csv("src/test/resources/classifier/sentiment.csv")
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val useEmbeddings = UniversalSentenceEncoder.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentence_embeddings")
  *
  * val docClassifier = new SentimentDLApproach()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCol("sentiment")
  *   .setLabelColumn("label")
  *   .setBatchSize(32)
  *   .setMaxEpochs(1)
  *   .setLr(5e-3f)
  *   .setDropout(0.5f)
  *
  * val pipeline = new Pipeline()
  *   .setStages(
  *     Array(
  *       documentAssembler,
  *       useEmbeddings,
  *       docClassifier
  *     )
  *   )
  *
  * val pipelineModel = pipeline.fit(smallCorpus)
  * }}}
  *
  * @see
  *   [[ClassifierDLApproach]] for general multi-class (single-label) classification
  * @see
  *   [[MultiClassifierDLApproach]] for general multi-label classification
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class SentimentDLApproach(override val uid: String)
    extends AnnotatorApproach[SentimentDLModel]
    with ParamsAndFeaturesWritable
    with ClassifierEncoder {

  /** Creates the annotator with a random uid generated from the `SentimentDL` prefix. */
  def this() = this(Identifiable.randomUID("SentimentDL"))

  /** Short human-readable description of this annotator. */
  override val description = "Trains TensorFlow model for Sentiment Classification"

  /** Input Annotator Types: SENTENCE_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(SENTENCE_EMBEDDINGS)

  /** Output Annotator Types: CATEGORY
    *
    * @group anno
    */
  override val outputAnnotatorType: String = CATEGORY

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group param
    */
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")

  /** The minimum threshold for the final result otherwise it will be either neutral or the value
    * set in thresholdLabel (Default: `0.6f`)
    *
    * @group param
    */
  val threshold = new FloatParam(
    this,
    "threshold",
    "The minimum threshold for the final result otherwise it will be either neutral or the value set in thresholdLabel.")

  /** In case the score is less than threshold, what should be the label (Default: `"neutral"`)
    *
    * @group param
    */
  val thresholdLabel = new Param[String](
    this,
    "thresholdLabel",
    "In case the score is less than threshold, what should be the label. Default is neutral.")

  /** Sets the dropout coefficient.
    *
    * @group setParam
    */
  def setDropout(dropout: Float): SentimentDLApproach.this.type = set(this.dropout, dropout)

  /** Sets the minimum score threshold that is handed over to the trained model.
    *
    * @group setParam
    */
  def setThreshold(threshold: Float): SentimentDLApproach.this.type =
    set(this.threshold, threshold)

  /** Sets the label that is handed over to the trained model for scores below the threshold.
    *
    * @group setParam
    */
  def setThresholdLabel(label: String): SentimentDLApproach.this.type =
    set(this.thresholdLabel, label)

  /** Returns the dropout coefficient.
    *
    * @group getParam
    */
  def getDropout: Float = $(this.dropout)

  /** Returns the minimum score threshold.
    *
    * @group getParam
    */
  def getThreshold: Float = $(this.threshold)

  /** Returns the label used for scores below the threshold.
    *
    * @group getParam
    */
  def getThresholdLabel: String = $(this.thresholdLabel)

  // `maxEpochs`, `lr` and `batchSize` are not declared in this class; they are inherited.
  setDefault(
    maxEpochs -> 10,
    lr -> 5e-3f,
    dropout -> 0.5f,
    batchSize -> 64,
    threshold -> 0.6f,
    thresholdLabel -> "neutral")

  /** No preparation is needed before training. */
  override def beforeTraining(spark: SparkSession): Unit = {}

  /** Trains the TensorFlow sentiment classifier on `dataset` and wraps the result in a
    * [[SentimentDLModel]].
    *
    * @param dataset
    *   training data; it must contain the configured label column and the first input column
    * @param recursivePipeline
    *   not used by this implementation
    * @return
    *   a model carrying the training encoder parameters, the trained TensorFlow variables, the
    *   storage reference of the input embeddings, the threshold settings and, when set, the
    *   config proto bytes
    * @throws IllegalArgumentException
    *   if the label column is not of type String, Integer, Double, Float or Long
    */
  override def train(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): SentimentDLModel = {

    val labelColName = $(labelColumn)
    val labelColType = dataset.schema(labelColName).dataType
    val supportedLabelTypes: Set[DataType] =
      Set(StringType, IntegerType, DoubleType, FloatType, LongType)
    // Report the configured column name; interpolating the Param itself would print its
    // identifier rather than the column the user has to fix.
    require(
      supportedLabelTypes.contains(labelColType),
      s"The label column $labelColName type is $labelColType and it's not compatible. " +
        s"Compatible types are StringType, IntegerType, DoubleType, LongType, or FloatType. ")

    val (trainDataset, trainLabels) = buildDatasetWithLabels(dataset, getInputCols(0))
    val encoder = new ClassifierDatasetEncoder(ClassifierDatasetEncoderParams(tags = trainLabels))
    val trainInputs = extractInputs(encoder, trainDataset)

    // The optional test set is read from the path stored in `testDataset`.
    // NOTE(review): it is encoded with its own encoder built from the test labels, not with the
    // training encoder - confirm both assign the same indices to the same labels.
    val testInputs =
      if (isDefined(testDataset)) {
        val testDataFrame = ResourceHelper.readSparkDataFrame($(testDataset))
        val (test, testLabels) = buildDatasetWithLabels(testDataFrame, getInputCols(0))
        val testEncoder =
          new ClassifierDatasetEncoder(ClassifierDatasetEncoderParams(tags = testLabels))
        Option(extractInputs(testEncoder, test))
      } else None

    val tfWrapper: TensorflowWrapper = loadSavedModel()

    val classifier = new TensorflowSentiment(tensorflow = tfWrapper, encoder, Verbose($(verbose)))

    // Seeds the shared scala.util.Random instance when a seed was configured.
    if (isDefined(randomSeed)) {
      Random.setSeed($(randomSeed))
    }

    classifier.train(
      trainInputs,
      testInputs,
      lr = $(lr),
      batchSize = $(batchSize),
      dropout = $(dropout),
      endEpoch = $(maxEpochs),
      configProtoBytes = getConfigProtoBytes,
      validationSplit = $(validationSplit),
      evaluationLogExtended = $(evaluationLogExtended),
      enableOutputLogs = $(enableOutputLogs),
      outputLogsPath = $(outputLogsPath),
      uuid = this.uid)

    // Pair the variables extracted from the training session with the original graph.
    val newWrapper = new TensorflowWrapper(
      TensorflowWrapper.extractVariablesSavedModel(
        tfWrapper.getTFSession(configProtoBytes = getConfigProtoBytes)),
      tfWrapper.graph)

    val embeddingsRef = HasStorageRef.getStorageRefFromInput(
      dataset,
      $(inputCols),
      AnnotatorType.SENTENCE_EMBEDDINGS)

    val model = new SentimentDLModel()
      .setDatasetParams(classifier.encoder.params)
      .setModelIfNotSet(dataset.sparkSession, newWrapper)
      .setStorageRef(embeddingsRef)
      .setThreshold($(threshold))
      .setThresholdLabel($(thresholdLabel))

    // Only forward the TensorFlow config when the user actually set one.
    get(configProtoBytes).foreach(bytes => model.setConfigProtoBytes(bytes))

    model
  }

  /** Reads the zipped saved model at `/sentiment-dl` with the `serve` tag and replaces the
    * wrapper's variables with empty byte arrays before returning it.
    *
    * @return
    *   the TensorFlow wrapper that `train` passes to the classifier
    */
  def loadSavedModel(): TensorflowWrapper = {

    val wrapper =
      TensorflowWrapper.readZippedSavedModel(
        "/sentiment-dl",
        tags = Array("serve"),
        initAllTables = true)
    // NOTE(review): presumably the variables are cleared because `train` extracts the trained
    // ones from the session afterwards - confirm.
    wrapper.variables = Variables(Array.empty[Array[Byte]], Array.empty[Byte])
    wrapper
  }
}

/** This is the companion object of [[SentimentDLApproach]]. Please refer to that class for the
  * documentation.
  */
object SentimentDLApproach extends DefaultParamsReadable[SentimentDLApproach]

/** Reader for [[SentimentDLApproach]] under its original, mismatched name. It is kept so that
  * existing callers of `SentimentApproach` keep compiling; new code should use the
  * [[SentimentDLApproach]] companion object instead.
  */
object SentimentApproach extends DefaultParamsReadable[SentimentDLApproach]




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy