
package com.johnsnowlabs.nlp.annotators.ner.dl
import java.io.File
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.{AnnotatorApproach, ParamsAndFeaturesWritable}
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.annotators.ner.{NerApproach, Verbose}
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.commons.io.IOUtils
import org.apache.commons.lang.SystemUtils
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.tensorflow.Graph
import scala.util.Random
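
/**
  * Trains a TensorFlow-based Char-CNN-BLSTM model for Named Entity Recognition.
  *
  * Expects DOCUMENT, TOKEN and WORD_EMBEDDINGS annotations as input and produces NAMED_ENTITY
  * annotations; training labels are read from the column configured with `setLabelColumn`.
  *
  * Minimal usage sketch (illustrative only; it assumes `documentAssembler`, `tokenizer` and
  * `embeddings` stages that produce the required annotation types, plus a `trainingData`
  * DataFrame with a `label` column, none of which are defined in this file):
  *
  * {{{
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("document", "token", "embeddings")
  *   .setOutputCol("ner")
  *   .setLabelColumn("label")
  *   .setMaxEpochs(10)
  *   .setLr(1e-3f)
  *
  * val pipeline = new org.apache.spark.ml.Pipeline()
  *   .setStages(Array(documentAssembler, tokenizer, embeddings, nerTagger))
  * val nerModel = pipeline.fit(trainingData)
  * }}}
  */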
class NerDLApproach(override val uid: String)
  extends AnnotatorApproach[NerDLModel]
    with NerApproach[NerDLApproach]
    with Logging
    with ParamsAndFeaturesWritable {

  def this() = this(Identifiable.randomUID("NerDL"))

  override def getLogName: String = "NerDL"

  override val description = "Trains TensorFlow-based Char-CNN-BLSTM model"

  override val inputAnnotatorTypes = Array(DOCUMENT, TOKEN, WORD_EMBEDDINGS)
  override val outputAnnotatorType = NAMED_ENTITY
  val lr = new FloatParam(this, "lr", "Learning rate")
  val po = new FloatParam(this, "po", "Learning rate decay coefficient. Real learning rate = lr / (1 + po * epoch)")
  val batchSize = new IntParam(this, "batchSize", "Batch size")
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")
  val graphFolder = new Param[String](this, "graphFolder", "Folder path that contains external graph files")
  val configProtoBytes = new IntArrayParam(this, "configProtoBytes", "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")
  val useContrib = new BooleanParam(this, "useContrib", "Whether to use contrib LSTM cells. Not compatible with Windows. Might slightly improve accuracy.")

  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  def getUseContrib(): Boolean = $(this.useContrib)

  def setLr(lr: Float) = set(this.lr, lr)

  def setPo(po: Float) = set(this.po, po)

  def setBatchSize(batch: Int) = set(this.batchSize, batch)

  def setDropout(dropout: Float) = set(this.dropout, dropout)

  def setGraphFolder(path: String) = set(this.graphFolder, path)

  def setConfigProtoBytes(bytes: Array[Int]) = set(this.configProtoBytes, bytes)

  def setUseContrib(value: Boolean) =
    if (value && SystemUtils.IS_OS_WINDOWS) throw new UnsupportedOperationException("Cannot set contrib in Windows")
    else set(useContrib, value)
  setDefault(
    minEpochs -> 0,
    maxEpochs -> 70,
    lr -> 1e-3f,
    po -> 0.005f,
    batchSize -> 8,
    dropout -> 0.5f,
    verbose -> Verbose.Silent.id,
    useContrib -> { if (SystemUtils.IS_OS_WINDOWS) false else true }
  )

  override val verboseLevel = Verbose($(verbose))
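
  /** Infers the embeddings dimension from the first sentence that has at least one token; falls back to 1 if none is found. */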
  def calculateEmbeddingsDim(sentences: Seq[WordpieceEmbeddingsSentence]): Int = {
    sentences.find(s => s.tokens.nonEmpty)
      .map(s => s.tokens.head.embeddings.length)
      .getOrElse(1)
  }
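
  /** Makes the TensorFlow contrib ops available on the Spark cluster and in the local TensorFlow runtime before training starts. */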
  override def beforeTraining(spark: SparkSession): Unit = {
    LoadsContrib.loadContribToCluster(spark)
    LoadsContrib.loadContribToTensorflow()
  }
  override def train(dataset: Dataset[_], recursivePipeline: Option[PipelineModel]): NerDLModel = {

    val train = dataset.toDF()

    // Collect training sentences together with their labels
    val trainDataset = NerTagged.collectTrainingInstances(train, getInputCols, $(labelColumn))
    val trainSentences = trainDataset.map(r => r._2)

    // Build the label and character vocabularies from the training data
    val labels = trainDataset.flatMap(r => r._1.labels).distinct ++ Set("X")
    val chars = trainDataset.flatMap(r => r._2.tokens.flatMap(token => token.wordpiece.toCharArray)).distinct
    val embeddingsDim = calculateEmbeddingsDim(trainSentences)

    val settings = DatasetEncoderParams(labels.toList, chars.toList,
      Array.fill(embeddingsDim)(0f).toList, embeddingsDim)
    val encoder = new NerDatasetEncoder(settings)

    // Find a pre-built TensorFlow graph large enough for this many tags, chars and embedding dimensions
    val graphFile = NerDLApproach.searchForSuitableGraph(labels.length, embeddingsDim, chars.length, get(graphFolder), getUseContrib())

    val graph = new Graph()
    val graphStream = ResourceHelper.getResourceStream(graphFile)
    val graphBytesDef = IOUtils.toByteArray(graphStream)
    graph.importGraphDef(graphBytesDef)

    val tf = new TensorflowWrapper(Variables(Array.empty[Byte], Array.empty[Byte]), graph.toGraphDef)

    val ner = try {
      val model = new TensorflowNer(tf, encoder, $(batchSize), Verbose($(verbose)))
      if (isDefined(randomSeed)) {
        Random.setSeed($(randomSeed))
      }

      model.train(trainDataset, $(lr), $(po), $(batchSize), $(dropout), 0, $(maxEpochs), configProtoBytes = getConfigProtoBytes)
      model
    }
    catch {
      case e: Exception =>
        graph.close()
        throw e
    }

    // Extract the trained variables into a new wrapper that will be serialized with the model
    val newWrapper = new TensorflowWrapper(TensorflowWrapper.extractVariables(tf.getSession(configProtoBytes = getConfigProtoBytes)), tf.graph)

    val model = new NerDLModel()
      .setDatasetParams(ner.encoder.params)
      .setBatchSize($(batchSize))
      .setModelIfNotSet(dataset.sparkSession, newWrapper)

    if (get(configProtoBytes).isDefined)
      model.setConfigProtoBytes($(configProtoBytes))

    model
  }
}
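
/**
  * Resolves a pre-built TensorFlow graph file that can accommodate the training data.
  *
  * Graph files are expected to be named `blstm_{nTags}_{embeddingsDim}_{x}_{nChars}.pb`
  * (or `blstm-noncontrib_...` when contrib LSTM cells are not used); the third value is
  * parsed but ignored by this resolver. A graph is suitable when its embeddings dimension
  * matches exactly and its tag and char capacities are at least as large as required.
  */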
trait WithGraphResolver {

  def searchForSuitableGraph(tags: Int, embeddingsNDims: Int, nChars: Int, localGraphPath: Option[String] = None, loadContrib: Boolean = false): String = {
    val files = localGraphPath.map(path => ResourceHelper.listLocalFiles(ResourceHelper.copyToLocal(path)).map(_.getAbsolutePath))
      .getOrElse(ResourceHelper.listResourceDirectory("/ner-dl"))

    // 1. Filter graphs by embeddings dimension
    val embeddingsFiltered = files.map { filePath =>
      val file = new File(filePath)
      val name = file.getName
      val graphPrefix = if (loadContrib) "blstm_" else "blstm-noncontrib_"
      if (name.startsWith(graphPrefix)) {
        val clean = name.replace(graphPrefix, "").replace(".pb", "")
        val graphParams = clean.split("_").take(4).map(s => s.toInt)
        val Array(fileTags, fileEmbeddingsNDims, _, fileNChars) = graphParams

        if (embeddingsNDims == fileEmbeddingsNDims)
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
        else
          None
      }
      else {
        None
      }
    }

    require(embeddingsFiltered.exists(_.nonEmpty), s"Could not find a suitable TensorFlow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
      s"Generate a graph with the Python code in python/tensorflow/ner/create_models before usage and use the setGraphFolder param to point to the output.")

    // 2. Filter by number of tags
    val tagsFiltered = embeddingsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (tags > fileTags)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(tagsFiltered.exists(_.nonEmpty), s"Could not find a TensorFlow graph suitable for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
      s"Generate a graph with the Python code in python/tensorflow/ner/create_models before usage and use the setGraphFolder param to point to the output.")

    // 3. Filter by number of chars
    val charsFiltered = tagsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (nChars > fileNChars)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(charsFiltered.exists(_.nonEmpty), s"Could not find a TensorFlow graph suitable for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
      s"Generate a graph with the Python code before usage.")

    // Return the first graph file that survived all three filters
    for (i <- files.indices) {
      if (charsFiltered(i).nonEmpty)
        return files(i)
    }

    throw new IllegalStateException("Execution should never reach this point")
  }
}
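
/** Companion object; adds `DefaultParamsReadable` loading and the graph search from [[WithGraphResolver]] to [[NerDLApproach]]. */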
object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraphResolver