All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.nlp.perceptron.EvaluateUtils.kt Maven / Gradle / Ivy

/*
 * Copyright 2018 mayabot.com authors. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.mayabot.nlp.perceptron

import com.mayabot.nlp.common.utils.CharNormUtils


/**
 * Simple evaluation that only considers label equality.
 *
 * For every sample the model decodes the sample's feature sequence and the
 * predicted labels are compared position by position with the sample's gold
 * labels.
 *
 * NOTE(review): `samples` is declared as a bare `List` in this copy of the
 * file; the element type argument appears to have been lost — confirm against
 * the original source.
 *
 * @param model model whose `decode` output is compared with the gold labels
 * @param samples samples exposing `label` and `featureSequence`
 * @return an [EvaluateResult] built from (label count, label count, matching label count)
 */
fun simpleEvaluate(model: PerceptronModel, samples: List): EvaluateResult {
    var labelCount = 0
    var matched = 0

    for (sample in samples) {
        val expected = sample.label
        labelCount += expected.size

        val predicted = model.decode(sample.featureSequence)
        // Count the positions where the prediction agrees with the gold label.
        matched += expected.indices.count { expected[it] == predicted[it] }
    }

    // Gold total and predicted total are the same number here: one label per position.
    return EvaluateResult(labelCount, labelCount, matched)
}

/**
 * Builds an [EvaluateFunction] that scores a word segmenter against
 * gold-standard segmented lines.
 *
 * Each evaluation line is passed through [CharNormUtils.convert] and split on
 * [split]; blank pieces are dropped and the rest are the gold words. The gold
 * words are joined back into unsegmented text, segmented with [textSegment],
 * and the two word lists are compared with [wordCorrect].
 *
 * Precision and recall are percentages (0..100). A metric whose denominator is
 * zero (no predicted words, no gold words, or precision + recall == 0 for F1)
 * is reported as 0.0 instead of NaN.
 *
 * @param textSegment segmenter under test: unsegmented text in, words out
 * @param split delimiter that separates gold words within a line
 * @param verbose when true, progress, scores and elapsed time are printed to standard output
 * @return an [EvaluateFunction] producing an [EvaluateResult] from precision and recall
 */
fun segmentEvaluateFunction(
        textSegment:(String)->List<String>,
        split:String = "﹍",
        verbose:Boolean = false):EvaluateFunction{

    return EvaluateFunction{evaluateSample->
        var goldTotal = 0
        var predTotal = 0
        var correct = 0

        if (verbose) System.out.print("Evaluating 0%")

        val t1 = System.currentTimeMillis()

        var count = 0
        for (line in evaluateSample) {

            val wordArray = CharNormUtils.convert(line).split(split).filter { it.isNotBlank() }
            goldTotal += wordArray.size

            // Rebuild the raw text so the segmenter sees the input without delimiters.
            val text = wordArray.joinToString(separator = "")
            val predArray = textSegment(text)
            predTotal += predArray.size

            correct += wordCorrect(wordArray, predArray)

            count++

            // Refresh the progress line only every 2000 samples to keep output cheap.
            if (verbose && count % 2000 == 0) {
                System.out.print("\rEvaluating ${"%.2f".format(count * 100.0 / evaluateSample.size)}%")
            }
        }

        // Percentage with a zero-denominator guard so empty input never yields NaN.
        fun percent(numerator: Int, denominator: Int): Double =
                if (denominator == 0) 0.0 else numerator * 100.0 / denominator

        val precision = percent(correct, predTotal)
        val recall = percent(correct, goldTotal)
        // When nothing is correct both terms are 0; F1 is defined as 0 in that case.
        val f1 = if (precision + recall == 0.0) 0.0 else 2.0 * precision * recall / (precision + recall)

        if (verbose) System.out.print("\r")

        val t2 = System.currentTimeMillis()

        if (verbose) {
            System.out.println("正确率(P) %.2f , 召回率(R) %.2f , F1 %.2f".format(precision, recall, f1))
            println("Evaluate use time ${t2 - t1} ms")
        }

        EvaluateResult(precision.toFloat(), recall.toFloat())
    }
}

/**
 * Counts how many words of a predicted segmentation are correct.
 *
 * Both lists are walked in parallel while tracking how many characters each
 * side has consumed. A predicted word is counted as correct only when it starts
 * at the same character offset as the current gold word and is equal to it.
 * When the offsets differ, the side that is behind advances until the two
 * offsets line up again.
 *
 * @param gold gold-standard words, in text order
 * @param pred predicted words, in text order
 * @return number of predicted words that match a gold word at the same offset
 */
fun wordCorrect(gold: List<String>, pred: List<String>): Int {
    var goldIndex = 0
    var predIndex = 0
    // Character offsets consumed so far on each side.
    var goldLen = 0
    var predLen = 0

    var correct = 0
    while (goldIndex < gold.size && predIndex < pred.size) {
        when {
            goldLen == predLen -> {
                // Both words start at the same offset: equal text means a correct word.
                if (gold[goldIndex] == pred[predIndex]) {
                    correct++
                }
                // Either way both sides consume their current word.
                goldLen += gold[goldIndex].length
                predLen += pred[predIndex].length
                goldIndex++
                predIndex++
            }
            goldLen < predLen -> {
                // Gold side is behind; let it catch up.
                goldLen += gold[goldIndex].length
                goldIndex++
            }
            else -> {
                // Predicted side is behind; let it catch up.
                predLen += pred[predIndex].length
                predIndex++
            }
        }
    }
    return correct
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy