All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.classification.SimpleNaiveBayesClassifier Maven / Gradle / Ivy

There is a newer version: 10.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/**
 * A simplistic Lucene based NaiveBayes classifier, see http://en.wikipedia.org/wiki/Naive_Bayes_classifier
 *
 * @lucene.experimental
 */
public class SimpleNaiveBayesClassifier implements Classifier {

  /**
   * {@link org.apache.lucene.index.IndexReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
   * index
   */
  protected final IndexReader indexReader;

  /**
   * names of the fields to be used as input text
   */
  protected final String[] textFieldNames;

  /**
   * name of the field to be used as a class / category output
   */
  protected final String classFieldName;

  /**
   * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
   */
  protected final Analyzer analyzer;

  /**
   * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
   */
  protected final IndexSearcher indexSearcher;

  /**
   * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
   */
  protected final Query query;

  /**
   * Creates a new NaiveBayes classifier.
   *
   * @param indexReader     the reader on the index to be used for classification
   * @param analyzer       an {@link Analyzer} used to analyze unseen text
   * @param query          a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
   *                       if all the indexed docs should be used
   * @param classFieldName the name of the field used as the output for the classifier NOTE: must not be havely analyzed
   *                       as the returned class will be a token indexed for this field
   * @param textFieldNames the name of the fields used as the inputs for the classifier, NO boosting supported per field
   */
  public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
    this.indexReader = indexReader;
    this.indexSearcher = new IndexSearcher(this.indexReader);
    this.textFieldNames = textFieldNames;
    this.classFieldName = classFieldName;
    this.analyzer = analyzer;
    this.query = query;
  }

  @Override
  public ClassificationResult assignClass(String inputDocument) throws IOException {
    List> assignedClasses = assignClassNormalizedList(inputDocument);
    ClassificationResult assignedClass = null;
    double maxscore = -Double.MAX_VALUE;
    for (ClassificationResult c : assignedClasses) {
      if (c.getScore() > maxscore) {
        assignedClass = c;
        maxscore = c.getScore();
      }
    }
    return assignedClass;
  }

  @Override
  public List> getClasses(String text) throws IOException {
    List> assignedClasses = assignClassNormalizedList(text);
    Collections.sort(assignedClasses);
    return assignedClasses;
  }

  @Override
  public List> getClasses(String text, int max) throws IOException {
    List> assignedClasses = assignClassNormalizedList(text);
    Collections.sort(assignedClasses);
    return assignedClasses.subList(0, max);
  }

  /**
   * Calculate probabilities for all classes for a given input text
   * @param inputDocument the input text as a {@code String}
   * @return a {@code List} of {@code ClassificationResult}, one for each existing class
   * @throws IOException if assigning probabilities fails
   */
  protected List> assignClassNormalizedList(String inputDocument) throws IOException {
    List> assignedClasses = new ArrayList<>();

    Terms classes = MultiTerms.getTerms(indexReader, classFieldName);
    if (classes != null) {
      TermsEnum classesEnum = classes.iterator();
      BytesRef next;
      String[] tokenizedText = tokenize(inputDocument);
      int docsWithClassSize = countDocsWithClass();
      while ((next = classesEnum.next()) != null) {
        if (next.length > 0) {
          Term term = new Term(this.classFieldName, next);
          double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
          assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
        }
      }
    }
    // normalization; the values transforms to a 0-1 range
    return normClassificationResults(assignedClasses);
  }

  /**
   * count the number of documents in the index having at least a value for the 'class' field
   *
   * @return the no. of documents having a value for the 'class' field
   * @throws IOException if accessing to term vectors or search fails
   */
  protected int countDocsWithClass() throws IOException {
    Terms terms = MultiTerms.getTerms(this.indexReader, this.classFieldName);
    int docCount;
    if (terms == null || terms.getDocCount() == -1) { // in case codec doesn't support getDocCount
      TotalHitCountCollector classQueryCountCollector = new TotalHitCountCollector();
      BooleanQuery.Builder q = new BooleanQuery.Builder();
      q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, String.valueOf(WildcardQuery.WILDCARD_STRING))), BooleanClause.Occur.MUST));
      if (query != null) {
        q.add(query, BooleanClause.Occur.MUST);
      }
      indexSearcher.search(q.build(),
          classQueryCountCollector);
      docCount = classQueryCountCollector.getTotalHits();
    } else {
      docCount = terms.getDocCount();
    }
    return docCount;
  }

  /**
   * tokenize a String on this classifier's text fields and analyzer
   *
   * @param text the String representing an input text (to be classified)
   * @return a String array of the resulting tokens
   * @throws IOException if tokenization fails
   */
  protected String[] tokenize(String text) throws IOException {
    Collection result = new LinkedList<>();
    for (String textFieldName : textFieldNames) {
      try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
        CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
        tokenStream.reset();
        while (tokenStream.incrementToken()) {
          result.add(charTermAttribute.toString());
        }
        tokenStream.end();
      }
    }
    return result.toArray(new String[result.size()]);
  }

  private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException {
    // for each word
    double result = 0d;
    for (String word : tokenizedText) {
      // search with text:word AND class:c
      int hits = getWordFreqForClass(word, term);

      // num : count the no of times the word appears in documents of class c (+1)
      double num = hits + 1; // +1 is added because of add 1 smoothing

      // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
      double den = getTextTermFreqForClass(term) + docsWithClass;

      // P(w|c) = num/den
      double wordProbability = num / den;
      result += Math.log(wordProbability);
    }

    // log(P(d|c)) = log(P(w1|c))+...+log(P(wn|c))
    return result;
  }

  /**
   * Returns the average number of unique terms times the number of docs belonging to the input class
   * @param term the term representing the class
   * @return the average number of unique terms
   * @throws IOException if a low level I/O problem happens
   */
  private double getTextTermFreqForClass(Term term) throws IOException {
    double avgNumberOfUniqueTerms = 0;
    for (String textFieldName : textFieldNames) {
      Terms terms = MultiTerms.getTerms(indexReader, textFieldName);
      long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
      avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
    }
    int docsWithC = indexReader.docFreq(term);
    return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
  }

  /**
   * Returns the number of documents of the input class ( from the whole index or from a subset)
   * that contains the word ( in a specific field or in all the fields if no one selected)
   * @param word the token produced by the analyzer
   * @param term the term representing the class
   * @return the number of documents of the input class
   * @throws IOException if a low level I/O problem happens
   */
  private int getWordFreqForClass(String word, Term term) throws IOException {
    BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
    BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
    for (String textFieldName : textFieldNames) {
      subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
    }
    booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
    booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
    if (query != null) {
      booleanQuery.add(query, BooleanClause.Occur.MUST);
    }
    TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
    indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
    return totalHitCountCollector.getTotalHits();
  }

  private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
    return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
  }

  private int docCount(Term term) throws IOException {
    return indexReader.docFreq(term);
  }

  /**
   * Normalize the classification results based on the max score available
   * @param assignedClasses the list of assigned classes
   * @return the normalized results
   */
  protected ArrayList> normClassificationResults(List> assignedClasses) {
    // normalization; the values transforms to a 0-1 range
    ArrayList> returnList = new ArrayList<>();
    if (!assignedClasses.isEmpty()) {
      Collections.sort(assignedClasses);
      // this is a negative number closest to 0 = a
      double smax = assignedClasses.get(0).getScore();

      double sumLog = 0;
      // log(sum(exp(x_n-a)))
      for (ClassificationResult cr : assignedClasses) {
        // getScore-smax <=0 (both negative, smax is the smallest abs()
        sumLog += Math.exp(cr.getScore() - smax);
      }
      // loga=a+log(sum(exp(x_n-a))) = log(sum(exp(x_n)))
      double loga = smax;
      loga += Math.log(sumLog);

      // 1/sum*x = exp(log(x))*1/sum = exp(log(x)-log(sum))
      for (ClassificationResult cr : assignedClasses) {
        double scoreDiff = cr.getScore() - loga;
        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), Math.exp(scoreDiff)));
      }
    }
    return returnList;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy