package com.hankcs.hanlp.classification.classifiers;

import com.hankcs.hanlp.classification.utilities.MathUtility;
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
import com.hankcs.hanlp.classification.corpus.*;
import com.hankcs.hanlp.classification.features.ChiSquareFeatureExtractor;
import com.hankcs.hanlp.classification.features.BaseFeatureData;
import com.hankcs.hanlp.classification.models.AbstractModel;
import com.hankcs.hanlp.classification.models.NaiveBayesModel;

import static com.hankcs.hanlp.classification.utilities.Predefine.logger;

import java.util.*;

/**
 * A text classifier based on the multinomial Naive Bayes model.
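 * <p>
 * A minimal usage sketch, assuming the {@code MemoryDataSet} implementation of
 * {@code IDataSet} from the accompanying corpus package and its
 * {@code add(category, text)} method; the category names and texts below are
 * illustrative only:
 * <pre>{@code
 * IDataSet trainingSet = new MemoryDataSet();
 * trainingSet.add("sports", "the team won the championship game");
 * trainingSet.add("finance", "stocks rallied after the earnings report");
 *
 * NaiveBayesClassifier classifier = new NaiveBayesClassifier();
 * classifier.train(trainingSet);
 *
 * Map<String, Double> scores = classifier.predict("the striker scored twice");
 * }</pre>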
 */
public class NaiveBayesClassifier extends AbstractClassifier
{

    private NaiveBayesModel model;

    /**
     * Construct a Naive Bayes classifier from a training result; typically used to
     * load a previously trained classifier from disk.
     *
     * @param naiveBayesModel the trained model
     */
    public NaiveBayesClassifier(NaiveBayesModel naiveBayesModel)
    {
        this.model = naiveBayesModel;
    }

    /**
     * Construct a blank Naive Bayes classifier, typically in preparation for training.
     */
    public NaiveBayesClassifier()
    {
        this(null);
    }

    /**
     * Get the training result.
     *
     * @return the trained model
     */
    public NaiveBayesModel getNaiveBayesModel()
    {
        return model;
    }

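    /**
     * Train the model on a data set: select features with the chi-square test, then
     * estimate the category priors and the add-one-smoothed feature log likelihoods.
     *
     * @param dataSet the labeled training documents
     */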
    public void train(IDataSet dataSet)
    {
        logger.out("Original dataset size: %d\n", dataSet.size());
        // select the best features
        BaseFeatureData featureData = selectFeatures(dataSet);

        // initialize the data used by the classifier
        model = new NaiveBayesModel();
        model.n = featureData.n; // number of samples
        model.d = featureData.featureCategoryJointCount.length; // number of features

        model.c = featureData.categoryCounts.length; // number of categories
        model.logPriors = new TreeMap<Integer, Double>();

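        // Prior of each category, estimated by maximum likelihood:
        //   log P(c) = log(count(c) / n)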
        int sumCategory;
        for (int category = 0; category < featureData.categoryCounts.length; category++)
        {
            sumCategory = featureData.categoryCounts[category];
            model.logPriors.put(category, Math.log((double) sumCategory / model.n));
        }

        // Laplace smoothing (a.k.a. add-one smoothing): first total the feature
        // occurrences observed under each category
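        // The per-category totals form the denominator of the smoothed estimate
        // applied later: P(w|c) = (count(w,c) + 1) / (total(c) + d)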
        Map<Integer, Double> featureOccurrencesInCategory = new TreeMap<Integer, Double>();

        Double featureOccSum;
        for (Integer category : model.logPriors.keySet())
        {
            featureOccSum = 0.0;
            for (int feature = 0; feature < featureData.featureCategoryJointCount.length; feature++)
            {

                featureOccSum += featureData.featureCategoryJointCount[feature][category];
            }
            featureOccurrencesInCategory.put(category, featureOccSum);
        }

        // log-likelihood estimation
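        // For each (feature, category) pair the model stores
        //   log P(w|c) = log((count(w,c) + 1) / (total(c) + d))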
        int count;
        int[] featureCategoryCounts;
        double logLikelihood;
        for (Integer category : model.logPriors.keySet())
        {
            for (int feature = 0; feature < featureData.featureCategoryJointCount.length; feature++)
            {

                featureCategoryCounts = featureData.featureCategoryJointCount[feature];

                count = featureCategoryCounts[category];

                logLikelihood = Math.log((count + 1.0) / (featureOccurrencesInCategory.get(category) + model.d));
                if (!model.logLikelihoods.containsKey(feature))
                {
                    model.logLikelihoods.put(feature, new TreeMap<Integer, Double>());
                }
                model.logLikelihoods.get(feature).put(category, logLikelihood);
            }
        }
        logger.out("Naive Bayes statistics finished\n");
        model.catalog = dataSet.getCatalog().toArray();
        model.tokenizer = dataSet.getTokenizer();
        model.wordIdTrie = featureData.wordIdTrie;
    }

    public AbstractModel getModel()
    {
        return model;
    }

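    /**
     * Tokenize a piece of raw text and predict its category scores.
     *
     * @param text the text to classify
     * @return a map from each category name to its prediction score
     * @throws IllegalStateException    if no model has been trained or loaded
     * @throws IllegalArgumentException if text is null
     */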
    public Map<String, Double> predict(String text) throws IllegalArgumentException, IllegalStateException
    {
        if (model == null)
        {
            throw new IllegalStateException("Model not trained! Cannot perform prediction!");
        }
        if (text == null)
        {
            throw new IllegalArgumentException("Parameter text == null");
        }

        // tokenize the text and create a document
        Document doc = new Document(model.wordIdTrie, model.tokenizer.segment(text));

        return predict(doc);
    }

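    /**
     * Compute a prediction score for every category of the given document.
     *
     * @param document the document to score
     * @return one score per category, indexed consistently with the model catalog
     */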
    @Override
    public double[] categorize(Document document) throws IllegalArgumentException, IllegalStateException
    {
        Integer category;
        Integer feature;
        Integer occurrences;
        Double logprob;

        double[] predictionScores = new double[model.catalog.length];
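        // Score of category c for this document:
        //   score(c) = log P(c) + sum over features w of tf(w) * log P(w|c)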
        for (Map.Entry<Integer, Double> entry1 : model.logPriors.entrySet())
        {
            category = entry1.getKey();
            logprob = entry1.getValue(); // initialize the score with the category's log prior

            // for each feature in the document
            for (Map.Entry<Integer, int[]> entry2 : document.tfMap.entrySet())
            {
                feature = entry2.getKey();

                if (!model.logLikelihoods.containsKey(feature))
                {
                    continue; // skip features absent from the model
                }

                occurrences = entry2.getValue()[0]; // frequency of this feature in the document

                logprob += occurrences * model.logLikelihoods.get(feature).get(category); // weight the log likelihood by the frequency
            }
            predictionScores[category] = logprob;
        }

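        // Optionally convert the log scores into a normalized probability
        // distribution by exponentiating and normalizing them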
        if (configProbabilityEnabled) MathUtility.normalizeExp(predictionScores);
        return predictionScores;
    }

    /**
     * Collect feature statistics and perform feature selection, returning a
     * BaseFeatureData object used to compute the probabilities in the model.
     *
     * @param dataSet the training data set
     * @return the selected feature statistics
     */
    protected BaseFeatureData selectFeatures(IDataSet dataSet)
    {
        ChiSquareFeatureExtractor chiSquareFeatureExtractor = new ChiSquareFeatureExtractor();

        logger.start("Selecting features with the chi-square test...");
        // the BaseFeatureData object contains every feature in the corpus along with its statistics
        BaseFeatureData featureData = chiSquareFeatureExtractor.extractBasicFeatureData(dataSet); // collect the statistics

        // pass these statistics to the feature-selection algorithm to obtain the features and their scores
        Map<Integer, Double> selectedFeatures = chiSquareFeatureExtractor.chi_square(featureData);
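        // chi_square keeps only the features whose occurrence depends most strongly
        // on the category labels, as measured by the chi-square statistic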

        // drop the unselected features from the statistics and rebuild the feature-id trie
        int[][] featureCategoryJointCount = new int[selectedFeatures.size()][];
        featureData.wordIdTrie = new BinTrie<Integer>();
        String[] wordIdArray = dataSet.getLexicon().getWordIdArray();
        int p = -1;
        for (Integer feature : selectedFeatures.keySet())
        {
            featureCategoryJointCount[++p] = featureData.featureCategoryJointCount[feature];
            featureData.wordIdTrie.put(wordIdArray[feature], p);
        }
        logger.finish(", selected features: %d / %d = %.2f%%\n", featureCategoryJointCount.length,
                      featureData.featureCategoryJointCount.length,
                      featureCategoryJointCount.length / (double)featureData.featureCategoryJointCount.length * 100.);
        featureData.featureCategoryJointCount = featureCategoryJointCount;

        return featureData;
    }
}