
// nlp-model-stanford.23.3.4.source-code.StanfordEntityClassifier.kt Maven / Gradle / Ivy
/*
* This file is part of the tock-corenlp distribution.
* (https://github.com/theopenconversationkit/tock-corenlp)
* Copyright (c) 2017 VSCT.
*
* tock-corenlp is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, version 3.
*
* tock-corenlp is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package ai.tock.nlp.stanford
import ai.tock.nlp.core.Entity
import ai.tock.nlp.core.EntityRecognition
import ai.tock.nlp.core.EntityValue
import ai.tock.nlp.core.IntOpenRange
import ai.tock.nlp.model.EntityCallContext
import ai.tock.nlp.model.EntityCallContextForEntity
import ai.tock.nlp.model.EntityCallContextForIntent
import ai.tock.nlp.model.EntityCallContextForSubEntities
import ai.tock.nlp.model.service.engine.EntityModelHolder
import ai.tock.nlp.model.service.engine.NlpEntityClassifier
import ai.tock.nlp.stanford.StanfordModelBuilder.ADJACENT_ENTITY_MARKER
import ai.tock.nlp.stanford.StanfordModelBuilder.TAB
import edu.stanford.nlp.ie.crf.CRFClassifier
import edu.stanford.nlp.ling.CoreAnnotations
import edu.stanford.nlp.ling.CoreLabel
import mu.KotlinLogging
/**
 * [NlpEntityClassifier] that labels tokens with the [CRFClassifier] stored in the
 * model holder's `nativeModel` and turns the labelled tokens into [EntityRecognition]s.
 *
 * Consecutive tokens that receive the same label are merged into a single entity span.
 * All entities found in one call share the same confidence value (see [getConfidence]).
 */
internal class StanfordEntityClassifier(model: EntityModelHolder) : NlpEntityClassifier(model) {

    /**
     * A span of the input text carrying a classifier label.
     *
     * @property start offset of the first character of the span in the input text
     * @property end offset just past the last character of the span
     * @property text the covered text
     * @property type the label assigned by the classifier (never `"O"`)
     */
    private data class Token(
        override val start: Int,
        override val end: Int,
        val text: String,
        val type: String
    ) : IntOpenRange

    /**
     * Dispatches on the kind of [context] to resolve labels into entities.
     *
     * @param context the call context; [EntityCallContextForEntity] is rejected
     * @param text the full input text the [tokens] were extracted from
     * @param tokens the tokens of [text], in order
     * @return the recognized entities, or an empty list if classification failed
     * @throws IllegalStateException if [context] is an [EntityCallContextForEntity]
     */
    override fun classifyEntities(
        context: EntityCallContext,
        text: String,
        tokens: Array<String>
    ): List<EntityRecognition> =
        when (context) {
            is EntityCallContextForIntent -> classifyEntities(context, text, tokens)
            is EntityCallContextForEntity -> error("EntityCallContextForEntity is not supported")
            is EntityCallContextForSubEntities -> classifyEntities(context, text, tokens)
        }

    /** Classifies [tokens], resolving each label through the sub entities of the context's entity type. */
    private fun classifyEntities(
        context: EntityCallContextForSubEntities,
        text: String,
        tokens: Array<String>
    ): List<EntityRecognition> =
        classifyEntities(text, tokens) { context.entityType.findSubEntity(it) }

    /** Classifies [tokens], resolving each label through the entities of the context's intent. */
    private fun classifyEntities(
        context: EntityCallContextForIntent,
        text: String,
        tokens: Array<String>
    ): List<EntityRecognition> =
        classifyEntities(text, tokens) { context.intent.getEntity(it) }

    /**
     * Runs the CRF classifier over [tokens] and maps every labelled span to an entity.
     *
     * Labels for which [entityFinder] returns `null` are logged and dropped.
     * Any exception raised while classifying is logged and results in an empty list.
     *
     * @param text the full input text, used to compute character offsets
     * @param tokens the tokens of [text], in order
     * @param entityFinder resolves a label (with the adjacent-entity marker removed) to an [Entity]
     */
    private fun classifyEntities(
        text: String,
        tokens: Array<String>,
        entityFinder: (String) -> Entity?
    ): List<EntityRecognition> =
        try {
            @Suppress("UNCHECKED_CAST")
            val classifier = model.nativeModel as CRFClassifier<CoreLabel>
            val evaluationData = getEvaluationData(tokens)
            val documents = classifier.makeObjectBankFromString(evaluationData, classifier.defaultReaderAndWriter())
            val document = documents.flatten()
            val classifiedLabels = classifier.classify(document)
            val confidence = getConfidence(classifier, classifiedLabels)

            extractEntityTokens(text, document).mapNotNull { token ->
                val entity = entityFinder(token.type.replaceFirst(adjacentMarkerRegex, ""))
                if (entity == null) {
                    logger.warn { "unknown entity role ${token.type}" }
                    null
                } else {
                    EntityRecognition(EntityValue(token.start, token.end, entity), confidence)
                }
            }
        } catch (e: Exception) {
            logger.error(e) { "error with $text and ${tokens.contentToString()}" }
            emptyList()
        }

    /**
     * Locates every word of [document] in [text] and merges runs of consecutive words
     * carrying the same non-`"O"` label into a single [Token].
     *
     * Words are searched left to right: each search starts right after the previous word,
     * so repeated words map to successive occurrences.
     */
    private fun extractEntityTokens(text: String, document: List<CoreLabel>): List<Token> {
        val entityTokens = mutableListOf<Token>()
        var previousToken: Token? = null

        // `remaining` is the part of `text` not yet consumed and `consumed` its offset in `text`.
        // The cursor is moved past a word only when the next word is processed, so the
        // last word never triggers an out-of-range substring on its own.
        var remaining = text
        var consumed = 0
        var pendingAdvance = 0

        for (label in document) {
            consumed += pendingAdvance
            remaining = remaining.substring(pendingAdvance)

            val word: String = label.word()
            val relativeStart = remaining.indexOf(word)
            pendingAdvance = relativeStart + word.length
            val start = consumed + relativeStart
            val end = start + word.length

            val entityRole: String = label.get(CoreAnnotations.AnswerAnnotation::class.java)
            if (entityRole == "O") {
                // an unlabelled word breaks any run in progress
                previousToken = null
                continue
            }

            val last = previousToken
            val current = if (last != null && last.type == entityRole) {
                // same label as the previous word: replace the last span by one extended up to this word
                entityTokens.removeAt(entityTokens.lastIndex)
                Token(last.start, end, text.substring(last.start, end), entityRole)
            } else {
                Token(start, end, word, entityRole)
            }
            entityTokens.add(current)
            previousToken = current
        }
        return entityTokens
    }

    /** Builds the classifier input: one line per token, formatted as token, [TAB], `O`. */
    private fun getEvaluationData(tokens: Array<String>): String =
        tokens.joinToString(separator = "") { "$it${TAB}O\n" }

    /**
     * Computes a single confidence value for the whole classified sequence.
     *
     * The value is one minus the mean, over all positions of the clique tree, of
     * `cliqueTree.prob(position, assignedLabelIndex)`, rounded to three decimals.
     * An empty clique tree yields `0.0`; any exception is logged and yields `0.1`.
     *
     * @param classifier the classifier that produced [classifiedLabels]
     * @param classifiedLabels the labels returned by the classifier, one per token
     */
    private fun getConfidence(classifier: CRFClassifier<CoreLabel>, classifiedLabels: List<CoreLabel>): Double =
        try {
            // TODO confidence by entity
            val cliqueTree = classifier.getCliqueTree(classifiedLabels)
            val positions = cliqueTree.length()
            var probSum = 0.0
            for (i in 0 until positions) {
                val answer = classifiedLabels[i].get(CoreAnnotations.AnswerAnnotation::class.java)
                // accumulate (not overwrite) so that the division below yields a mean
                probSum += cliqueTree.prob(i, classifier.classIndex.indexOf(answer))
            }
            // NOTE(review): presumably prob() is the probability of the assigned label, in which
            // case `1 - mean` reads as an uncertainty rather than a confidence - confirm the intended scale.
            // With zero positions this is NaN, which Math.round maps to 0.
            val prob = 1 - (probSum / positions)
            Math.round(prob * 1000) / 1000.0
        } catch (e: Exception) {
            logger.error(e) { "Exception during confidence calculation - skipped" }
            0.1
        }

    companion object {
        private val logger = KotlinLogging.logger {}
        private val adjacentMarkerRegex = ADJACENT_ENTITY_MARKER.toRegex()
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy