All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.client.CloudResources
import com.johnsnowlabs.client.util.CloudHelper
import com.johnsnowlabs.ml.crf.TextSentenceLabels
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.SystemUtils
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.tensorflow.Graph
import org.tensorflow.proto.framework.GraphDef

import java.io.File
import scala.collection.mutable
import scala.util.Random

/** This Named Entity Recognition annotator trains a generic NER model based on neural
  * networks.
  *
  * The architecture of the neural network is a Char CNNs - BiLSTM - CRF that achieves
  * state-of-the-art in most datasets.
  *
  * For instantiated/pretrained models, see [[NerDLModel]].
  *
  * The training data should be a labeled Spark Dataset, in the format of
  * [[com.johnsnowlabs.nlp.training.CoNLL CoNLL]] 2003 IOB with `Annotation` type columns. The
  * data should have columns of type `DOCUMENT, TOKEN, WORD_EMBEDDINGS` and an additional label
  * column of annotator type `NAMED_ENTITY`. Apart from the label, these columns can be produced,
  * for example, with
  *   - a [[com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector SentenceDetector]],
  *   - a [[com.johnsnowlabs.nlp.annotators.Tokenizer Tokenizer]] and
  *   - a [[com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel WordEmbeddingsModel]] (any
  *     embeddings can be chosen, e.g.
  *     [[com.johnsnowlabs.nlp.embeddings.BertEmbeddings BertEmbeddings]] for BERT based
  *     embeddings).
  *
  * Setting a test dataset to monitor model metrics can be done with `.setTestDataset`. The method
  * expects a path to a parquet file containing a dataframe that has the same required columns as
  * the training dataframe. The pre-processing steps for the training dataframe should also be
  * applied to the test dataframe. The following example will show how to create the test dataset
  * with a CoNLL dataset:
  *
  * {{{
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val embeddings = WordEmbeddingsModel
  *   .pretrained()
  *   .setInputCols("document", "token")
  *   .setOutputCol("embeddings")
  *
  * val preProcessingPipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
  *
  * val conll = CoNLL()
  * val Array(train, test) = conll
  *   .readDataset(spark, "src/test/resources/conll2003/eng.train")
  *   .randomSplit(Array(0.8, 0.2))
  *
  * preProcessingPipeline
  *   .fit(test)
  *   .transform(test)
  *   .write
  *   .mode("overwrite")
  *   .parquet("test_data")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("document", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setTestDataset("test_data")
  * }}}
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala NerDLSpec]].
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
  * import com.johnsnowlabs.nlp.training.CoNLL
  * import org.apache.spark.ml.Pipeline
  *
  * // This CoNLL dataset already includes a sentence, token and label
  * // column with their respective annotator types. If a custom dataset is used,
  * // these need to be defined with for example:
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * // Then the training can start
  * val embeddings = BertEmbeddings.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("sentence", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setMaxEpochs(1)
  *   .setRandomSeed(0)
  *   .setVerbose(0)
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   embeddings,
  *   nerTagger
  * ))
  *
  * // We use the sentences, tokens and labels from the CoNLL dataset
  * val conll = CoNLL()
  * val trainingData = conll.readDataset(spark, "src/test/resources/conll2003/eng.train")
  *
  * val pipelineModel = pipeline.fit(trainingData)
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfApproach NerCrfApproach]] for a generic CRF
  *   approach
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.NerConverter NerConverter]] to further process the
  *   results
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class NerDLApproach(override val uid: String)
    extends AnnotatorApproach[NerDLModel]
    with NerApproach[NerDLApproach]
    with Logging
    with ParamsAndFeaturesWritable
    with EvaluationDLParams {

  def this() = this(Identifiable.randomUID("NerDL"))

  override def getLogName: String = "NerDL"

  /** Trains Tensorflow based Char-CNN-BLSTM model */
  override val description = "Trains Tensorflow based Char-CNN-BLSTM model"

  /** Input annotator types: DOCUMENT, TOKEN, WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN, WORD_EMBEDDINGS)

  /** Output annotator types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: String = NAMED_ENTITY

  /** Learning Rate (Default: `1e-3f`)
    *
    * @group param
    */
  val lr = new FloatParam(this, "lr", "Learning Rate")

  /** Learning rate decay coefficient (Default: `0.005f`). Real Learning Rate calculates to `lr /
    * (1 + po * epoch)`
    *
    * @group param
    */
  val po = new FloatParam(
    this,
    "po",
    "Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)")

  /** Batch size (Default: `8`)
    *
    * @group param
    */
  val batchSize = new IntParam(this, "batchSize", "Batch size")

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group param
    */
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")

  /** Folder path that contain external graph files
    *
    * @group param
    */
  val graphFolder =
    new Param[String](this, "graphFolder", "Folder path that contain external graph files")

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** Whether to use contrib LSTM Cells (Default: `true`). Not compatible with Windows. Might
    * slightly improve accuracy. This param is deprecated and only exists for backward
    * compatibility
    *
    * @group param
    */
  val useContrib =
    new BooleanParam(this, "useContrib", "deprecated param - the value won't have any effect")

  /** Whether to include confidence scores in annotation metadata (Default: `false`)
    *
    * @group param
    */
  val includeConfidence = new BooleanParam(
    this,
    "includeConfidence",
    "Whether to include confidence scores in annotation metadata")

  /** whether to include all confidence scores in annotation metadata or just score of the
    * predicted tag
    *
    * @group param
    */
  val includeAllConfidenceScores = new BooleanParam(
    this,
    "includeAllConfidenceScores",
    "whether to include all confidence scores in annotation metadata")

  /** Whether to optimize for large datasets or not (Default: `false`). Enabling this option can
    * slow down training.
    *
    * @group param
    */
  val enableMemoryOptimizer = new BooleanParam(
    this,
    "enableMemoryOptimizer",
    "Whether to optimize for large datasets or not. Enabling this option can slow down training.")

  /** Whether to restore and use the model that has achieved the best performance at the end of
    * the training. The metric that is being monitored is F1 for testDataset and if it's not set
    * it will be validationSplit, and if it's not set finally looks for loss.
    *
    * @group param
    */
  val useBestModel = new BooleanParam(
    this,
    "useBestModel",
    "Whether to restore and use the model that has achieved the best performance at the end of the training.")

  /** Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model
    * This will fall back to loss if there is no validation or test dataset
    *
    * @group param
    */
  val bestModelMetric = new Param[String](
    this,
    "bestModelMetric",
    "Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model.")

  /** Learning Rate
    *
    * @group getParam
    */
  def getLr: Float = $(this.lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group getParam
    */
  def getPo: Float = $(this.po)

  /** Batch size
    *
    * @group getParam
    */
  def getBatchSize: Int = $(this.batchSize)

  /** Dropout coefficient
    *
    * @group getParam
    */
  def getDropout: Float = $(this.dropout)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group getParam
    */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @group getParam
    */
  def getUseContrib: Boolean = $(this.useContrib)

  /** Memory Optimizer
    *
    * @group getParam
    */
  def getEnableMemoryOptimizer: Boolean = $(this.enableMemoryOptimizer)

  /** useBestModel
    *
    * @group getParam
    */
  def getUseBestModel: Boolean = $(this.useBestModel)

  /** @group getParam */
  def getBestModelMetric: String = $(bestModelMetric)

  /** Learning Rate
    *
    * @group setParam
    */
  def setLr(lr: Float): NerDLApproach.this.type = set(this.lr, lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group setParam
    */
  def setPo(po: Float): NerDLApproach.this.type = set(this.po, po)

  /** Batch size
    *
    * @group setParam
    */
  def setBatchSize(batch: Int): NerDLApproach.this.type = set(this.batchSize, batch)

  /** Dropout coefficient
    *
    * @group setParam
    */
  def setDropout(dropout: Float): NerDLApproach.this.type = set(this.dropout, dropout)

  /** Folder path that contain external graph files
    *
    * @group setParam
    */
  def setGraphFolder(path: String): NerDLApproach.this.type = set(this.graphFolder, path)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group setParam
    */
  def setConfigProtoBytes(bytes: Array[Int]): NerDLApproach.this.type =
    set(this.configProtoBytes, bytes)

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @throws UnsupportedOperationException
    *   when `value` is true and the current OS is Windows
    * @group setParam
    */
  def setUseContrib(value: Boolean): NerDLApproach.this.type =
    if (value && SystemUtils.IS_OS_WINDOWS)
      throw new UnsupportedOperationException("Cannot set contrib in Windows")
    else set(useContrib, value)

  /** Whether to optimize for large datasets or not. Enabling this option can slow down training.
    *
    * @group setParam
    */
  def setEnableMemoryOptimizer(value: Boolean): NerDLApproach.this.type =
    set(this.enableMemoryOptimizer, value)

  /** Whether to include confidence scores in annotation metadata
    *
    * @group setParam
    */
  def setIncludeConfidence(value: Boolean): NerDLApproach.this.type =
    set(this.includeConfidence, value)

  /** whether to include confidence scores for all tags rather than just for the predicted one
    *
    * @group setParam
    */
  def setIncludeAllConfidenceScores(value: Boolean): this.type =
    set(this.includeAllConfidenceScores, value)

  /** @group setParam */
  def setUseBestModel(value: Boolean): NerDLApproach.this.type = set(this.useBestModel, value)

  /** Sets the metric used to pick the best model; must be one of `ModelMetrics.values`.
    *
    * @throws IllegalArgumentException
    *   if `value` is not an allowed metric
    * @group setParam
    */
  def setBestModelMetric(value: String): NerDLApproach.this.type = {
    require(
      ModelMetrics.values.contains(value),
      s"Invalid metric: $value. Allowed metrics are: ${ModelMetrics.values.mkString(", ")}")

    set(this.bestModelMetric, value)
  }

  setDefault(
    minEpochs -> 0,
    maxEpochs -> 70,
    lr -> 1e-3f,
    po -> 0.005f,
    batchSize -> 8,
    dropout -> 0.5f,
    useContrib -> true,
    includeConfidence -> false,
    includeAllConfidenceScores -> false,
    enableMemoryOptimizer -> false,
    useBestModel -> false,
    bestModelMetric -> ModelMetrics.loss)

  // NOTE(review): as a `val` this is computed once during construction, so a later
  // `setVerbose` does not update it; `train` reads `$(verbose)` directly instead.
  override val verboseLevel: Verbose.Level = Verbose($(verbose))

  /** Returns the embeddings length of the first token of the first sentence that has tokens, or
    * 1 when no sentence has any token.
    */
  def calculateEmbeddingsDim(sentences: Seq[WordpieceEmbeddingsSentence]): Int = {
    sentences
      .find(s => s.tokens.nonEmpty)
      .map(s => s.tokens.head.embeddings.length)
      .getOrElse(1)
  }

  /** Calls `LoadsContrib` to load contrib ops on the cluster and into the local TensorFlow
    * runtime before training starts.
    */
  override def beforeTraining(spark: SparkSession): Unit = {
    LoadsContrib.loadContribToCluster(spark)
    LoadsContrib.loadContribToTensorflow()
  }

  /** Trains a [[NerDLModel]] on `dataset`.
    *
    * The dataset is randomly split into a validation part (`validationSplit`) and a training
    * part. An optional test dataset is read from `testDataset`. Labels, characters and the
    * embeddings dimension are collected in one pass over the training split and used to choose
    * a compatible TensorFlow graph (from `graphFolder` if set, otherwise the bundled resources),
    * which is then trained with `startEpoch = 0` and `endEpoch = maxEpochs`.
    *
    * @param dataset
    *   training data containing the input annotation columns and the label column
    * @param recursivePipeline
    *   not used by this implementation
    * @return
    *   the trained [[NerDLModel]]
    * @throws IllegalArgumentException
    *   if `validationSplit` is outside `[0, 1]` or no suitable graph can be found
    */
  override def train(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): NerDLModel = {

    // Both bounds must hold at the same time.
    require(
      $(validationSplit) <= 1f && $(validationSplit) >= 0f,
      "The validationSplit must be between 0f and 1f")

    val train = dataset.toDF()

    val test = if (!isDefined(testDataset)) {
      train.limit(0) // keep the schema only
    } else {
      ResourceHelper.readSparkDataFrame($(testDataset))
    }

    val embeddingsRef =
      HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)

    val Array(validSplit, trainSplit) =
      train.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit)))

    // Train, validation and test iterators share the same columns and batching settings.
    def iteratorFuncFor(split: Dataset[Row]) =
      NerDLApproach.getIteratorFunc(
        split,
        inputColumns = getInputCols,
        labelColumn = $(labelColumn),
        batchSize = $(batchSize),
        enableMemoryOptimizer = $(enableMemoryOptimizer))

    val trainIteratorFunc = iteratorFuncFor(trainSplit)
    val validIteratorFunc = iteratorFuncFor(validSplit)
    val testIteratorFunc = iteratorFuncFor(test)

    val (labels, chars, embeddingsDim, dsLen) =
      NerDLApproach.getDataSetParams(trainIteratorFunc())

    val settings = DatasetEncoderParams(
      labels.toList,
      chars.toList,
      Array.fill(embeddingsDim)(0f).toList,
      embeddingsDim)
    val encoder = new NerDatasetEncoder(settings)

    // `chars.size + 1` is the same char count reported by NerDLApproach.getGraphParams
    // (presumably one extra slot for unknown/padding chars - TODO confirm).
    val graphFile = NerDLApproach.searchForSuitableGraph(
      labels.size,
      embeddingsDim,
      chars.size + 1,
      get(graphFolder))

    // Round-trip the graph file through a TF Graph and keep only the serialized GraphDef.
    // Both the resource stream and the native graph are released before training starts,
    // on success as well as on failure.
    val graphDefBytes: Array[Byte] = {
      val graphStream = ResourceHelper.getResourceStream(graphFile)
      try {
        val graph = new Graph()
        try {
          graph.importGraphDef(GraphDef.parseFrom(IOUtils.toByteArray(graphStream)))
          graph.toGraphDef.toByteArray
        } finally {
          graph.close()
        }
      } finally {
        graphStream.close()
      }
    }

    val tfWrapper = new TensorflowWrapper(
      Variables(Array.empty[Array[Byte]], Array.empty[Byte]),
      graphDefBytes)

    val ner = new TensorflowNer(tfWrapper, encoder, Verbose($(verbose)))
    if (isDefined(randomSeed)) {
      Random.setSeed($(randomSeed))
    }

    // Fresh iterators are requested here because the one used for getDataSetParams is consumed.
    val trainedTf = ner.train(
      trainIteratorFunc(),
      dsLen,
      validIteratorFunc(),
      (dsLen * $(validationSplit)).toLong,
      $(lr),
      $(po),
      $(dropout),
      $(batchSize),
      $(useBestModel),
      $(bestModelMetric),
      graphFileName = graphFile,
      test = testIteratorFunc(),
      startEpoch = 0,
      endEpoch = $(maxEpochs),
      configProtoBytes = getConfigProtoBytes,
      validationSplit = $(validationSplit),
      evaluationLogExtended = $(evaluationLogExtended),
      enableOutputLogs = $(enableOutputLogs),
      outputLogsPath = $(outputLogsPath),
      uuid = this.uid)

    val newWrapper =
      new TensorflowWrapper(
        TensorflowWrapper.extractVariablesSavedModel(trainedTf),
        tfWrapper.graph)

    val model = new NerDLModel()
      .setDatasetParams(ner.encoder.params)
      .setModelIfNotSet(dataset.sparkSession, newWrapper)
      .setIncludeConfidence($(includeConfidence))
      .setIncludeAllConfidenceScores($(includeAllConfidenceScores))
      .setStorageRef(embeddingsRef)

    if (get(configProtoBytes).isDefined)
      model.setConfigProtoBytes($(configProtoBytes))

    model
  }
}

/** Resolves which TensorFlow graph file to use for a NER training run. */
trait WithGraphResolver {

  /** Finds the first graph file compatible with the requested dimensions.
    *
    * Candidate files are listed from `localGraphPath` (a cloud path, a DBFS path, or any other
    * path copied to local) or, when it is not set, from the bundled `/ner-dl` resource
    * directory. Only files named like `blstm_<tags>_<embeddingsDim>_<third>_<nChars>.pb` are
    * considered; the third field is not checked here. A file is compatible when its embeddings
    * dimension equals `embeddingsNDims`, its tags field is at least `tags` and its chars field is
    * at least `nChars`. The first compatible file in listing order is returned.
    *
    * @param tags
    *   number of labels the graph must support
    * @param embeddingsNDims
    *   exact embeddings dimension the graph must have
    * @param nChars
    *   number of characters the graph must support
    * @param localGraphPath
    *   optional folder with user-provided graphs
    * @return
    *   path of the selected graph file
    * @throws IllegalArgumentException
    *   when no file matches the embeddings dimension, the tag count or the char count (each stage
    *   has its own message)
    */
  def searchForSuitableGraph(
      tags: Int,
      embeddingsNDims: Int,
      nChars: Int,
      localGraphPath: Option[String] = None): String = {

    // (path, (fileTags, fileEmbeddingsNDims, fileNChars)) for every well-formed graph file,
    // kept in listing order so the first-fit choice is preserved.
    val graphs: Seq[(String, (Int, Int, Int))] =
      getFiles(localGraphPath).flatMap { filePath =>
        parseGraphFileName(new File(filePath).getName).map(params => (filePath, params))
      }

    // 1. Filter graphs by embeddings dimension (exact match)
    val embeddingsFiltered = graphs.filter { case (_, (_, fileEmbeddingsNDims, _)) =>
      fileEmbeddingsNDims == embeddingsNDims
    }

    require(
      embeddingsFiltered.nonEmpty,
      s"Graph dimensions should be $embeddingsNDims: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 2. Filter by number of tags (graph must support at least `tags` labels)
    val tagsFiltered = embeddingsFiltered.filter { case (_, (fileTags, _, _)) =>
      tags <= fileTags
    }

    require(
      tagsFiltered.nonEmpty,
      s"Graph tags size should be $tags: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 3. Filter by number of chars (graph must support at least `nChars` characters)
    val charsFiltered = tagsFiltered.filter { case (_, (_, _, fileNChars)) =>
      nChars <= fileNChars
    }

    require(
      charsFiltered.nonEmpty,
      s"Graph chars size should be $nChars: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph")

    // Non-empty thanks to the require above.
    charsFiltered.head._1
  }

  /** Parses a graph file name of the form `blstm_<tags>_<embeddingsDim>_<third>_<nChars>.pb`
    * into `(tags, embeddingsDim, nChars)`.
    *
    * Returns `None` for names without the `blstm_` prefix, and for names whose first four
    * `_`-separated fields are missing or not integers, so stray files in a graph folder do not
    * abort the search.
    */
  private def parseGraphFileName(name: String): Option[(Int, Int, Int)] = {
    val graphPrefix = "blstm_"

    if (!name.startsWith(graphPrefix)) None
    else {
      val clean = name.replace(graphPrefix, "").replace(".pb", "")
      scala.util
        .Try(clean.split("_").take(4).map(_.toInt))
        .toOption
        .collect { case Array(fileTags, fileEmbeddingsNDims, _, fileNChars) =>
          (fileTags, fileEmbeddingsNDims, fileNChars)
        }
    }
  }

  /** Lists candidate graph file paths.
    *
    *   - cloud path: the bucket is downloaded to a local temp folder and its files are listed
    *   - DBFS path: the folder is listed directly
    *   - any other path: it is copied to local first, then listed
    *   - no path: the bundled `/ner-dl` resource directory is listed
    */
  private def getFiles(localGraphPath: Option[String]): Seq[String] =
    localGraphPath match {
      case Some(path) if CloudHelper.isCloudPath(path) =>
        val tmpDirectory = CloudResources.downloadBucketToLocalTmp(path).getPath
        ResourceHelper.listLocalFiles(tmpDirectory).map(_.getAbsolutePath)
      case Some(path) if OutputHelper.getFileSystem(path).getScheme == "dbfs" =>
        ResourceHelper.listLocalFiles(path).map(_.getAbsolutePath)
      case Some(path) =>
        ResourceHelper.listLocalFiles(ResourceHelper.copyToLocal(path)).map(_.getAbsolutePath)
      case None =>
        ResourceHelper.listResourceDirectory("/ner-dl")
    }

}

/** Companion object of [[NerDLApproach]]. It makes the approach readable from disk, provides graph
  * resolution through [[WithGraphResolver]], and exposes the dataset-scanning helpers used during
  * training. See the class documentation for usage.
  */
object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraphResolver {

  /** Builds a factory that returns a new batch iterator each time it is invoked.
    *
    * When `enableMemoryOptimizer` is false, the label column and the input columns are collected
    * to the driver right away, once, and every call iterates that in-memory array. Otherwise
    * every call iterates the dataset through `NerTagged.iterateOnDataframe`.
    *
    * @param dataset
    *   rows to iterate over
    * @param inputColumns
    *   annotation columns to read
    * @param labelColumn
    *   column holding the labels
    * @param batchSize
    *   batch size forwarded to `NerTagged`
    * @param enableMemoryOptimizer
    *   when true, the dataset is not collected to the driver
    * @return
    *   a zero-argument function producing a fresh iterator over batches
    */
  def getIteratorFunc(
      dataset: Dataset[Row],
      inputColumns: Array[String],
      labelColumn: String,
      batchSize: Int,
      enableMemoryOptimizer: Boolean)
      : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {

    if (!enableMemoryOptimizer) {
      // Collected outside the closure so repeated calls reuse the same rows.
      val collectedRows = dataset.select(labelColumn, inputColumns.toSeq: _*).collect()
      () => NerTagged.iterateOnArray(collectedRows, inputColumns, batchSize)
    } else { () =>
      NerTagged.iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)
    }
  }

  /** Consumes `dsIt` once and gathers the values needed to configure the encoder and the graph.
    *
    * @param dsIt
    *   batches of (labels, embedded sentence) pairs
    * @return
    *   the distinct labels; the distinct characters across all token texts; the embeddings length
    *   of the first token of the last sentence that had tokens (1 if none did); and the total
    *   number of (labels, sentence) pairs seen
    */
  def getDataSetParams(dsIt: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
      : (mutable.Set[String], mutable.Set[Char], Int, Long) = {

    val labelSet = mutable.Set.empty[String]
    val charSet = mutable.Set.empty[Char]
    var lastDim = 1
    var count = 0L

    // Single pass: the iterator may be backed by a DataFrame and is not replayable.
    dsIt.foreach { batch =>
      count += batch.length
      batch.foreach { case (textLabels, sentence) =>
        labelSet ++= textLabels.labels
        sentence.tokens.foreach(tok => charSet ++= tok.token.toCharArray)
        sentence.tokens.headOption.foreach(first => lastDim = first.embeddings.length)
      }
    }

    (labelSet, charSet, lastDim, count)
  }

  /** Computes the graph dimensions required for `dataset`.
    *
    * @param dataset
    *   training data containing `inputColumns` and `labelColumn`
    * @param inputColumns
    *   annotation columns (Java list, convenient for callers outside Scala)
    * @param labelColumn
    *   column holding the labels
    * @return
    *   `(number of labels, embeddings dimension, number of distinct chars + 1)`
    */
  def getGraphParams(
      dataset: Dataset[_],
      inputColumns: java.util.ArrayList[java.lang.String],
      labelColumn: String): (Int, Int, Int) = {

    val columnNames = inputColumns.toArray.map(_.toString)
    val iteratorFunc = getIteratorFunc(
      dataset.toDF(),
      columnNames,
      labelColumn,
      batchSize = 0,
      enableMemoryOptimizer = false)

    val (labelSet, charSet, dims, _) = getDataSetParams(iteratorFunc())
    (labelSet.size, dims, charSet.size + 1)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy