All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apdplat.qa.system.QuestionAnsweringSystemImpl Maven / Gradle / Ivy

/**
 * 
 * APDPlat - Application Product Development Platform
 * Copyright (c) 2013, 杨尚川, [email protected]
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 */

package org.apdplat.qa.system;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.apdplat.qa.datasource.DataSource;
import org.apdplat.qa.datasource.FileDataSource;
import org.apdplat.qa.files.FilesConfig;
import org.apdplat.qa.model.CandidateAnswer;
import org.apdplat.qa.model.CandidateAnswerCollection;
import org.apdplat.qa.model.Evidence;
import org.apdplat.qa.model.Question;
import org.apdplat.qa.model.QuestionType;
import org.apdplat.qa.questiontypeanalysis.QuestionClassifier;
import org.apdplat.qa.score.answer.CandidateAnswerScore;
import org.apdplat.qa.score.answer.CombinationCandidateAnswerScore;
import org.apdplat.qa.score.answer.HotCandidateAnswerScore;
import org.apdplat.qa.score.answer.MoreTextualAlignmentCandidateAnswerScore;
import org.apdplat.qa.score.answer.RewindTextualAlignmentCandidateAnswerScore;
import org.apdplat.qa.score.answer.TermDistanceCandidateAnswerScore;
import org.apdplat.qa.score.answer.TermDistanceMiniCandidateAnswerScore;
import org.apdplat.qa.score.answer.TermFrequencyCandidateAnswerScore;
import org.apdplat.qa.score.answer.TextualAlignmentCandidateAnswerScore;
import org.apdplat.qa.score.evidence.BigramEvidenceScore;
import org.apdplat.qa.score.evidence.CombinationEvidenceScore;
import org.apdplat.qa.score.evidence.EvidenceScore;
import org.apdplat.qa.score.evidence.SkipBigramEvidenceScore;
import org.apdplat.qa.score.evidence.TermMatchEvidenceScore;
import org.apdplat.qa.select.CandidateAnswerSelect;
import org.apdplat.qa.select.CommonCandidateAnswerSelect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Question answering system assembled from pluggable components.
 *
 * <p>Before questions can be answered, these components have to be configured through the
 * setters:
 * <ol>
 *   <li>a {@link QuestionClassifier}, which assigns a {@link QuestionType} to every question;</li>
 *   <li>the {@link DataSource} questions are read from (only one data source can be used);</li>
 *   <li>the {@link CandidateAnswerSelect} that extracts candidate answers from evidence (only one
 *       extractor can be used);</li>
 *   <li>the {@link EvidenceScore} that scores evidence (several scorers can be combined, e.g.
 *       with {@link CombinationEvidenceScore});</li>
 *   <li>the {@link CandidateAnswerScore} that scores candidate answers (several scorers can be
 *       combined, e.g. with {@link CombinationCandidateAnswerScore}).</li>
 * </ol>
 *
 * <p>Every call to an {@code answerQuestion}/{@code answerQuestions} method adds to the
 * statistics kept by this instance (perfect / not perfect / wrong / unknown-type questions and
 * the MRR). That state is mutated without any synchronization.
 *
 * @author 杨尚川
 */
public class QuestionAnsweringSystemImpl implements QuestionAnsweringSystem {

    private static final Logger LOG = LoggerFactory.getLogger(QuestionAnsweringSystemImpl.class);

    /** 1-based sequence number of the next question to process; only used in log messages. */
    private int questionIndex = 1;
    /** Running sum of 1 / rank over every question whose expected-answer rank was positive. */
    private double reciprocalRankSum;

    /** Questions whose expected answer was ranked first. */
    private final List<Question> perfectQuestions = new ArrayList<>();
    /** Questions whose expected answer was ranked, but not first. */
    private final List<Question> notPerfectQuestions = new ArrayList<>();
    /** Questions that reported rank -1 (a wrong answer), plus every unknown-type question. */
    private final List<Question> wrongQuestions = new ArrayList<>();
    /** Questions the classifier typed as {@link QuestionType#NULL}. */
    private final List<Question> unknownTypeQuestions = new ArrayList<>();

    private QuestionClassifier questionClassifier;
    private DataSource dataSource;
    private CandidateAnswerSelect candidateAnswerSelect;
    private EvidenceScore evidenceScore;
    private CandidateAnswerScore candidateAnswerScore;

    @Override
    public void setQuestionClassifier(QuestionClassifier questionClassifier) {
        this.questionClassifier = questionClassifier;
    }

    @Override
    public QuestionClassifier getQuestionClassifier() {
        return this.questionClassifier;
    }

    @Override
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    @Override
    public DataSource getDataSource() {
        return this.dataSource;
    }

    @Override
    public void setCandidateAnswerSelect(CandidateAnswerSelect candidateAnswerSelect) {
        this.candidateAnswerSelect = candidateAnswerSelect;
    }

    @Override
    public CandidateAnswerSelect getCandidateAnswerSelect() {
        return this.candidateAnswerSelect;
    }

    @Override
    public void setEvidenceScore(EvidenceScore evidenceScore) {
        this.evidenceScore = evidenceScore;
    }

    @Override
    public EvidenceScore getEvidenceScore() {
        return this.evidenceScore;
    }

    @Override
    public void setCandidateAnswerScore(CandidateAnswerScore candidateAnswerScore) {
        this.candidateAnswerScore = candidateAnswerScore;
    }

    @Override
    public CandidateAnswerScore getCandidateAnswerScore() {
        return this.candidateAnswerScore;
    }

    /**
     * Returns the number of questions counted so far: perfect + not perfect + wrong. Unknown-type
     * questions are included because they are also recorded as wrong.
     */
    @Override
    public int getQuestionCount() {
        return getPerfectCount() + getNotPerfectCount() + getWrongCount();
    }

    @Override
    public int getPerfectCount() {
        return perfectQuestions.size();
    }

    @Override
    public int getNotPerfectCount() {
        return notPerfectQuestions.size();
    }

    @Override
    public int getWrongCount() {
        return wrongQuestions.size();
    }

    @Override
    public int getUnknownTypeCount() {
        return unknownTypeQuestions.size();
    }

    /**
     * Looks {@code questionStr} up in the data source and answers the resulting question.
     *
     * @param questionStr the question text passed to {@link DataSource#getQuestion}
     * @return the answered question, or {@code null} if the data source returned no question
     * @throws NullPointerException if no data source (or another required component) is configured
     */
    @Override
    public Question answerQuestion(String questionStr) {
        Objects.requireNonNull(dataSource, "dataSource is not configured");
        Question question = dataSource.getQuestion(questionStr);
        return question == null ? null : answerQuestion(question);
    }

    /**
     * Hands this system to the data source's {@code getAndAnswerQuestions} and returns its result.
     *
     * @throws NullPointerException if no data source is configured
     */
    @Override
    public List<Question> answerQuestions() {
        Objects.requireNonNull(dataSource, "dataSource is not configured");
        return dataSource.getAndAnswerQuestions(this);
    }

    /**
     * Answers a single question by running it through {@link #answerQuestions(List)}.
     *
     * @param question the question to answer; may be {@code null}
     * @return the first element of the processed list, or {@code null} if {@code question} is null
     */
    @Override
    public Question answerQuestion(Question question) {
        if (question == null) {
            return null;
        }
        List<Question> questions = new ArrayList<>(1);
        questions.add(question);
        return answerQuestions(questions).get(0);
    }

    /**
     * Classifies and scores every question, updates the statistics of this instance and logs a
     * summary (MRR and answer rates over all questions counted so far).
     *
     * @param questions the questions to answer; must not be {@code null}
     * @return the same {@code questions} list
     * @throws NullPointerException if {@code questions} is null or a required component is not
     *         configured
     */
    @Override
    public List<Question> answerQuestions(List<Question> questions) {
        Objects.requireNonNull(questions, "questions");
        for (Question question : questions) {
            answerOne(question);
        }
        LOG.info("");
        logSummary();
        LOG.info("");
        return questions;
    }

    /** Classifies one question, scores its evidence and candidates, and records its outcome. */
    private void answerOne(Question question) {
        Objects.requireNonNull(questionClassifier, "questionClassifier is not configured");
        // NOTE(review): the classifier's result is what gets scored and recorded, while the list
        // returned by answerQuestions(List) keeps the original element; this presumably relies on
        // classify() returning the same instance — confirm.
        Question classified = questionClassifier.classify(question);
        LOG.info("开始处理Question {}:{} 【问题类型:{}】",
                questionIndex++, classified.getQuestion(), classified.getQuestionType());
        if (classified.getQuestionType() == QuestionType.NULL) {
            unknownTypeQuestions.add(classified);
            // Questions of unknown type are counted as wrongly answered.
            wrongQuestions.add(classified);
            LOG.error("未知的问题类型,拒绝回答!!!");
            return;
        }
        int evidenceIndex = 1;
        for (Evidence evidence : classified.getEvidences()) {
            LOG.debug("开始处理Evidence {}", evidenceIndex++);
            processEvidence(classified, evidence);
            LOG.debug("");
        }
        logCandidateAnswers(classified);
        int rank = classified.getExpectAnswerRank();
        LOG.info("ExpectAnswerRank: {}", rank);
        LOG.info("");
        recordRank(classified, rank);
        logRunningTotals();
    }

    /**
     * Scores one piece of evidence, lets the selector extract candidate answers into it, and
     * scores those candidates when there are any.
     */
    private void processEvidence(Question question, Evidence evidence) {
        Objects.requireNonNull(evidenceScore, "evidenceScore is not configured");
        // The evidence score is stored on the evidence object itself.
        evidenceScore.score(question, evidence);

        LOG.debug("Evidence Detail");
        LOG.debug("Title:{}", evidence.getTitle());
        LOG.debug("Snippet:{}", evidence.getSnippet());
        LOG.debug("Score:{}", evidence.getScore());
        LOG.debug("Terms:{}", evidence.getTerms());

        Objects.requireNonNull(candidateAnswerSelect, "candidateAnswerSelect is not configured");
        // Extracted candidates are stored on the evidence object and read back right after.
        candidateAnswerSelect.select(question, evidence);
        CandidateAnswerCollection candidates = evidence.getCandidateAnswerCollection();
        if (candidates.isEmpty()) {
            LOG.debug("Evidence无候选答案");
            return;
        }
        LOG.debug("Evidence候选答案(未评分):");
        candidates.showAll();
        LOG.debug("");
        Objects.requireNonNull(candidateAnswerScore, "candidateAnswerScore is not configured");
        candidateAnswerScore.score(question, evidence, candidates);
        LOG.debug("Evidence候选答案(已评分):");
        candidates.showAll();
        LOG.debug("");
    }

    /** Logs every candidate answer collected for {@code question} together with its score. */
    private static void logCandidateAnswers(Question question) {
        LOG.info("************************************");
        LOG.info("************************************");
        LOG.info("Question {}", question.getQuestion());
        LOG.info("Question 候选答案:");
        for (CandidateAnswer candidateAnswer : question.getAllCandidateAnswer()) {
            LOG.info("{}  {}", candidateAnswer.getAnswer(), candidateAnswer.getScore());
        }
    }

    /**
     * Files {@code question} under the statistic matching {@code rank} (1 = perfect, greater than
     * 1 = not perfect, -1 = wrong) and adds {@code 1 / rank} to the MRR sum when rank is positive.
     */
    private void recordRank(Question question, int rank) {
        if (rank == 1) {
            perfectQuestions.add(question);
        } else if (rank > 1) {
            notPerfectQuestions.add(question);
        } else if (rank == -1) {
            wrongQuestions.add(question);
        }
        // NOTE(review): any other rank (0 or below -1) is not counted in any list, matching the
        // original behavior — confirm getExpectAnswerRank() never returns such values.
        if (rank > 0) {
            reciprocalRankSum += 1.0 / rank;
        }
    }

    /** Logs the running statistics after a question has been recorded. */
    private void logRunningTotals() {
        // "mrr" here is the running reciprocal-rank sum, not yet divided by the question count.
        LOG.info("mrr: {}", reciprocalRankSum);
        LOG.info("perfectCount: {}", getPerfectCount());
        LOG.info("notPerfectCount: {}", getNotPerfectCount());
        LOG.info("wrongCount: {}", getWrongCount());
        LOG.info("unknownTypeCount: {}", getUnknownTypeCount());
        LOG.info("questionCount: {}", getQuestionCount());
    }

    /** Logs the MRR and the answer-category rates, as percentages of all counted questions. */
    private void logSummary() {
        LOG.info("MRR:{}%", getMRR() * 100);
        LOG.info("回答完美率:{}%", percentOfQuestions(getPerfectCount()));
        LOG.info("回答不完美率:{}%", percentOfQuestions(getNotPerfectCount()));
        LOG.info("回答错误率:{}%", percentOfQuestions(getWrongCount()));
        LOG.info("未知类型率:{}%", percentOfQuestions(getUnknownTypeCount()));
    }

    /**
     * Returns {@code count} as a percentage of {@link #getQuestionCount()}, or 0 when no question
     * has been counted yet (instead of the NaN a 0/0 division would produce).
     */
    private double percentOfQuestions(int count) {
        int total = getQuestionCount();
        return total == 0 ? 0 : (double) count / total * 100;
    }

    /**
     * Returns the mean reciprocal rank over all questions counted so far, or 0 when no question has
     * been counted yet.
     */
    @Override
    public double getMRR() {
        int total = getQuestionCount();
        return total == 0 ? 0 : reciprocalRankSum / total;
    }

    @Override
    public void showPerfectQuestions() {
        LOG.info("回答完美的问题:");
        int i = 1;
        for (Question question : perfectQuestions) {
            LOG.info("{}、{} : {}", i++, question.getQuestion(), question.getExpectAnswerRank());
        }
    }

    @Override
    public void showNotPerfectQuestions() {
        LOG.info("回答不完美的问题:");
        int i = 1;
        for (Question question : notPerfectQuestions) {
            LOG.info("{}、{} : {}", i++, question.getQuestion(), question.getExpectAnswerRank());
        }
    }

    @Override
    public void showWrongQuestions() {
        LOG.info("回答错误的问题:");
        int i = 1;
        for (Question question : wrongQuestions) {
            LOG.info("{}、{}", i++, question.getQuestion());
        }
    }

    @Override
    public void showUnknownTypeQuestions() {
        LOG.info("未知类型的问题:");
        int i = 1;
        for (Question question : unknownTypeQuestions) {
            LOG.info("{}、{}", i++, question.getQuestion());
        }
    }

    /** Returns the live internal list of perfectly answered questions (not a copy). */
    @Override
    public List<Question> getPerfectQuestions() {
        return perfectQuestions;
    }

    /** Returns the live internal list of not-perfectly answered questions (not a copy). */
    @Override
    public List<Question> getNotPerfectQuestions() {
        return notPerfectQuestions;
    }

    /** Returns the live internal list of wrongly answered questions (not a copy). */
    @Override
    public List<Question> getWrongQuestions() {
        return wrongQuestions;
    }

    /** Returns the live internal list of unknown-type questions (not a copy). */
    @Override
    public List<Question> getUnknownTypeQuestions() {
        return unknownTypeQuestions;
    }

    /**
     * Demo entry point: wires a file data source of person-name material, the common candidate
     * answer extractor and combined evidence / candidate-answer scorers, then answers the data
     * source's questions.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // 1. Default weights for the scoring components.
        ScoreWeight scoreWeight = new ScoreWeight();

        // 2. Data source of the system (file data source of person-name material).
        DataSource dataSource = new FileDataSource(FilesConfig.personNameMaterial);

        // 3. Candidate answer extractor (only one extractor can be used).
        CandidateAnswerSelect candidateAnswerSelect = new CommonCandidateAnswerSelect();

        // 4. Evidence scoring components (several can be combined).
        // 4.1 Term-match scorer.
        EvidenceScore termMatchEvidenceScore = new TermMatchEvidenceScore();
        termMatchEvidenceScore.setScoreWeight(scoreWeight);
        // 4.2 Bigram model scorer.
        EvidenceScore bigramEvidenceScore = new BigramEvidenceScore();
        bigramEvidenceScore.setScoreWeight(scoreWeight);
        // 4.3 Skip-bigram model scorer.
        EvidenceScore skipBigramEvidenceScore = new SkipBigramEvidenceScore();
        skipBigramEvidenceScore.setScoreWeight(scoreWeight);
        // 4.4 Combination of the evidence scorers above.
        CombinationEvidenceScore combinationEvidenceScore = new CombinationEvidenceScore();
        combinationEvidenceScore.addEvidenceScore(termMatchEvidenceScore);
        combinationEvidenceScore.addEvidenceScore(bigramEvidenceScore);
        combinationEvidenceScore.addEvidenceScore(skipBigramEvidenceScore);

        // 5. Candidate answer scoring components (several can be combined).
        // 5.1 Term-frequency scorer.
        CandidateAnswerScore termFrequencyCandidateAnswerScore = new TermFrequencyCandidateAnswerScore();
        termFrequencyCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.2 Term-distance scorer.
        CandidateAnswerScore termDistanceCandidateAnswerScore = new TermDistanceCandidateAnswerScore();
        termDistanceCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.3 Term-distance scorer using only the shortest candidate-to-question-term distance.
        CandidateAnswerScore termDistanceMiniCandidateAnswerScore = new TermDistanceMiniCandidateAnswerScore();
        termDistanceMiniCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.4 Textual alignment scorer.
        CandidateAnswerScore textualAlignmentCandidateAnswerScore = new TextualAlignmentCandidateAnswerScore();
        textualAlignmentCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.5 Textual alignment scorer ("more" variant).
        CandidateAnswerScore moreTextualAlignmentCandidateAnswerScore = new MoreTextualAlignmentCandidateAnswerScore();
        moreTextualAlignmentCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.6 Rewind textual alignment scorer.
        CandidateAnswerScore rewindTextualAlignmentCandidateAnswerScore = new RewindTextualAlignmentCandidateAnswerScore();
        rewindTextualAlignmentCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.7 Hot-word scorer.
        CandidateAnswerScore hotCandidateAnswerScore = new HotCandidateAnswerScore();
        hotCandidateAnswerScore.setScoreWeight(scoreWeight);
        // 5.8 Combination of the candidate answer scorers above.
        CombinationCandidateAnswerScore combinationCandidateAnswerScore = new CombinationCandidateAnswerScore();
        combinationCandidateAnswerScore.addCandidateAnswerScore(termFrequencyCandidateAnswerScore);
        combinationCandidateAnswerScore.addCandidateAnswerScore(termDistanceCandidateAnswerScore);
        combinationCandidateAnswerScore.addCandidateAnswerScore(termDistanceMiniCandidateAnswerScore);
        combinationCandidateAnswerScore.addCandidateAnswerScore(textualAlignmentCandidateAnswerScore);
        combinationCandidateAnswerScore.addCandidateAnswerScore(moreTextualAlignmentCandidateAnswerScore);
        // The rewind textual alignment scorer is constructed above but currently left out:
        //combinationCandidateAnswerScore.addCandidateAnswerScore(rewindTextualAlignmentCandidateAnswerScore);
        combinationCandidateAnswerScore.addCandidateAnswerScore(hotCandidateAnswerScore);

        QuestionAnsweringSystem questionAnsweringSystem = new QuestionAnsweringSystemImpl();
        // NOTE(review): no QuestionClassifier is configured here, yet answerQuestions(List) needs
        // one for every question (it fails with "questionClassifier is not configured"); presumably
        // setQuestionClassifier(...) must be called before answering — confirm which implementation.
        questionAnsweringSystem.setDataSource(dataSource);
        questionAnsweringSystem.setCandidateAnswerSelect(candidateAnswerSelect);
        questionAnsweringSystem.setEvidenceScore(combinationEvidenceScore);
        questionAnsweringSystem.setCandidateAnswerScore(combinationCandidateAnswerScore);
        questionAnsweringSystem.answerQuestions();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy