/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import scala.collection.JavaConverters._
import scala.collection.immutable.ListMap
import scala.collection.mutable

/** Language Identification and Detection using CNN and RNN architectures in TensorFlow.
 *
 * The models are trained on large datasets such as Wikipedia and Tatoeba. The output is a
 * language code in Wiki Code style: https://en.wikipedia.org/wiki/List_of_Wikipedias
*
* @param tensorflow
LanguageDetectorDL model wrapped in a TensorflowWrapper
* @param configProtoBytes
* Configuration for TensorFlow session
* @param orderedLanguages
* ordered ListMap of language codes detectable by this trained model
* @param orderedAlphabets
ordered ListMap of alphabet characters used to encode the input text
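*
* A minimal usage sketch (the wrapper and both maps below are hypothetical
* placeholders; real values ship with a trained LanguageDetectorDL model):
* {{{
* val ld = new TensorflowLD(
*   tensorflow = wrapper, // a loaded TensorflowWrapper
*   orderedLanguages = ListMap("en" -> 0, "fr" -> 1),
*   orderedAlphabets = ListMap("a" -> 0, "b" -> 1, "c" -> 2))
* val annotations = ld.predict(sentences) // sentences: Seq[Sentence]
* }}}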
*/
class TensorflowLD(
val tensorflow: TensorflowWrapper,
configProtoBytes: Option[Array[Byte]] = None,
orderedLanguages: ListMap[String, Int],
orderedAlphabets: ListMap[String, Int])
extends Serializable {
  // Tensor names of the graph input and of the softmax output
  private val inputKey = "inputs:0"
  private val outputKey = "output/Softmax:0"

  // LD models from Spark NLP 2.7.0 expect fixed-length input sequences of 150 characters
  private val maxSentenceLength = 150
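
  /** Lowercases the text and strips punctuation, tabs, and newlines.
    *
    * Illustrative behavior (hypothetical input):
    * {{{
    * cleanText(List("Hello, World!")) // -> List("hello world")
    * }}}
    */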
  def cleanText(docs: List[String]): List[String] = {
    // Regex character class matching any punctuation character, tab, or newline.
    // The brackets matter: passed bare to replaceAll, the string would be parsed
    // as a single sequential pattern and would strip almost nothing.
    val rmChars = "[!\"#$%&()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~\\t\\n]"
    docs.map(_.replaceAll(rmChars, "").toLowerCase())
  }
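
  /** Encodes each sentence as a fixed-length row of 1-based character IDs looked up
    * in `orderedAlphabets`, zero-padded to `maxSentenceLength`. Characters outside
    * the alphabet are dropped.
    *
    * A sketch of the expected layout, assuming an alphabet ListMap("a" -> 0, "b" -> 1):
    * {{{
    * // "ab" encodes to Array(1.0f, 2.0f, 0.0f, ..., 0.0f) of length 150
    * }}}
    */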
def encode(docs: Seq[Sentence]): Array[Array[Float]] = {
val charsArr = orderedAlphabets.keys.toArray
docs.map { x =>
val chars = cleanText(x.content.map(_.toString).toList).take(maxSentenceLength)
val tokens = mutable.ArrayBuffer[Float]()
      chars.foreach { char =>
        // indexOf returns -1 for characters outside the model's alphabet; skip those
        val charID = charsArr.indexOf(char)
        if (charID >= 0) {
          // shift IDs by one so that 0 is reserved for padding
          tokens.append(charID + 1.0f)
        }
      }
      // zero-pad up to the fixed model input length
      val diff = maxSentenceLength - tokens.length
      tokens.toArray ++ Array.fill(diff)(0.0f)
}.toArray
}
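
  /** Feeds a batch of encoded sentences to the TensorFlow session and returns one
    * softmax vector of `outputSize` language probabilities per input row.
    */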
def tag(inputs: Array[Array[Float]], inputSize: Int, outputSize: Int): Array[Array[Float]] = {
val tensors = new TensorResources()
val tokenBuffers = tensors.createFloatBuffer(inputs.length * inputSize)
val shape = Array(inputs.length.toLong, inputSize)
    inputs.zipWithIndex.foreach { case (sentence, idx) =>
      // each row occupies a contiguous block of `inputSize` floats in the flat buffer
      val offset = idx * inputSize
      tokenBuffers.offset(offset).write(sentence)
    }
val runner = tensorflow.getTFSession(configProtoBytes = configProtoBytes).runner
val tokenTensors = tensors.createFloatBufferTensor(shape, tokenBuffers)
runner
.feed(inputKey, tokenTensors)
.fetch(outputKey)
val outs = runner.run().asScala
val predictions = TensorResources.extractFloats(outs.head).grouped(outputSize).toArray
tensors.clearSession(outs)
tensors.clearTensors()
predictions
}
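
  /** Detects the language of each sentence, or of all sentences combined when
    * `coalesceSentences` is true. If the best score is below `threshold`, the result
    * falls back to `thresholdLabel`.
    *
    * A hedged sketch of one output annotation (field values are illustrative):
    * {{{
    * // Annotation(LANGUAGE, 0, 24, "en", Map("sentence" -> "0", "en" -> "0.98", ...))
    * }}}
    */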
def predict(
documents: Seq[Sentence],
threshold: Float = 0.01f,
thresholdLabel: String = "unk",
coalesceSentences: Boolean = false): Array[Annotation] = {
val sentences = encode(documents)
    val outputDimension = orderedLanguages.size
val scores = tag(sentences, maxSentenceLength, outputDimension)
    val langLabels = orderedLanguages.keys.toArray
val outputs = scores.map(x => x.zip(langLabels))
    if (coalesceSentences) {
      // average each language's score across all sentences, then keep the best language
      val avgScores =
        outputs.flatMap(x => x.toList).groupBy(_._2).mapValues(_.map(_._1).sum / outputs.length)
val maxResult = avgScores.maxBy(_._2)
val finalLabel = if (maxResult._2 >= threshold) maxResult._1 else thresholdLabel
Array(
Annotation(
annotatorType = AnnotatorType.LANGUAGE,
begin = documents.head.start,
end = documents.last.end,
result = finalLabel,
metadata = Map("sentence" -> documents.head.index.toString) ++ avgScores.flatMap(x =>
Map(x._1 -> x._2.toString))))
} else {
      outputs.zip(documents).map { case (score, sentence) =>
        // keep the highest-scoring language for this sentence
        val maxResult = score.maxBy(_._1)
        val finalLabel = if (maxResult._1 >= threshold) maxResult._2 else thresholdLabel
Annotation(
annotatorType = AnnotatorType.LANGUAGE,
begin = sentence.start,
end = sentence.end,
result = finalLabel,
metadata = Map("sentence" -> sentence.index.toString) ++ score.flatMap(x =>
Map(x._2 -> x._1.toString)))
}
}
}
}