All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.lucene.classification.BooleanPerceptronClassifier Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentSkipListMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermVectors;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FSTCompiler;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;

/**
 * A perceptron (see http://en.wikipedia.org/wiki/Perceptron) based Boolean
 *  {@link org.apache.lucene.classification.Classifier}. The weights are calculated using
 * {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field and a per document
 * basis and then a corresponding {@link org.apache.lucene.util.fst.FST} is used for class
 * assignment.
 *
 * @lucene.experimental
 */
public class BooleanPerceptronClassifier implements Classifier {

  private final Double bias;
  private final Terms textTerms;
  private final Analyzer analyzer;
  private final String textFieldName;
  private FST fst;

  /**
   * Creates a {@link BooleanPerceptronClassifier}
   *
   * @param indexReader the reader on the index to be used for classification
   * @param analyzer an {@link Analyzer} used to analyze unseen text
   * @param query a {@link Query} to eventually filter the docs used for training the classifier, or
   *     {@code null} if all the indexed docs should be used
   * @param batchSize the size of the batch of docs to use for updating the perceptron weights
   * @param bias the bias used for class separation
   * @param classFieldName the name of the field used as the output for the classifier
   * @param textFieldName the name of the field used as input for the classifier
   * @throws IOException if the building of the underlying {@link FST} fails and / or {@link
   *     TermsEnum} for the text field cannot be found
   */
  public BooleanPerceptronClassifier(
      IndexReader indexReader,
      Analyzer analyzer,
      Query query,
      Integer batchSize,
      Double bias,
      String classFieldName,
      String textFieldName)
      throws IOException {
    this.textTerms = MultiTerms.getTerms(indexReader, textFieldName);

    if (textTerms == null) {
      throw new IOException("term vectors need to be available for field " + textFieldName);
    }

    this.analyzer = analyzer;
    this.textFieldName = textFieldName;

    if (bias == null || bias == 0d) {
      // automatic assign the bias to be the average total term freq
      double t =
          (double) indexReader.getSumTotalTermFreq(textFieldName)
              / (double) indexReader.getDocCount(textFieldName);
      if (t != -1) {
        this.bias = t;
      } else {
        throw new IOException(
            "bias cannot be assigned since term vectors for field "
                + textFieldName
                + " do not exist");
      }
    } else {
      this.bias = bias;
    }

    // TODO : remove this map as soon as we have a writable FST
    SortedMap weights = new ConcurrentSkipListMap<>();

    TermsEnum termsEnum = textTerms.iterator();
    BytesRef textTerm;
    while ((textTerm = termsEnum.next()) != null) {
      weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
    }
    updateFST(weights);

    IndexSearcher indexSearcher = new IndexSearcher(indexReader);

    int batchCount = 0;

    BooleanQuery.Builder q = new BooleanQuery.Builder();
    q.add(
        new BooleanClause(
            new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
    if (query != null) {
      q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
    }
    TermVectors termVectors = indexReader.termVectors();
    StoredFields storedFields = indexReader.storedFields();
    // run the search and use stored field values
    for (ScoreDoc scoreDoc : indexSearcher.search(q.build(), Integer.MAX_VALUE).scoreDocs) {
      Document doc = storedFields.document(scoreDoc.doc);

      IndexableField textField = doc.getField(textFieldName);

      // get the expected result
      IndexableField classField = doc.getField(classFieldName);

      if (textField != null && classField != null) {
        // assign class to the doc
        ClassificationResult classificationResult = assignClass(textField.stringValue());
        Boolean assignedClass = classificationResult.assignedClass();

        Boolean correctClass = Boolean.valueOf(classField.stringValue());
        double modifier = Math.signum(correctClass.compareTo(assignedClass));
        if (modifier != 0D) {
          updateWeights(
              termVectors,
              scoreDoc.doc,
              assignedClass,
              weights,
              modifier,
              batchCount % batchSize == 0);
        }
        batchCount++;
      }
    }
    weights.clear(); // free memory while waiting for GC
  }

  private void updateWeights(
      TermVectors termVectors,
      int docId,
      Boolean assignedClass,
      SortedMap weights,
      double modifier,
      boolean updateFST)
      throws IOException {
    TermsEnum cte = textTerms.iterator();

    // get the doc term vectors
    Terms terms = termVectors.get(docId, textFieldName);

    if (terms == null) {
      throw new IOException("term vectors must be stored for field " + textFieldName);
    }

    TermsEnum termsEnum = terms.iterator();

    BytesRef term;

    while ((term = termsEnum.next()) != null) {
      cte.seekExact(term);
      if (assignedClass != null) {
        long termFreqLocal = termsEnum.totalTermFreq();
        // update weights
        Long previousValue = Util.get(fst, term);
        String termString = term.utf8ToString();
        weights.put(
            termString,
            previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal));
      }
    }
    if (updateFST) {
      updateFST(weights);
    }
  }

  private void updateFST(SortedMap weights) throws IOException {
    PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
    FSTCompiler fstCompiler =
        new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, outputs).build();
    BytesRefBuilder scratchBytes = new BytesRefBuilder();
    IntsRefBuilder scratchInts = new IntsRefBuilder();
    for (Map.Entry entry : weights.entrySet()) {
      scratchBytes.copyChars(entry.getKey());
      fstCompiler.add(
          Util.toIntsRef(scratchBytes.get(), scratchInts), entry.getValue().longValue());
    }
    fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
  }

  @Override
  public ClassificationResult assignClass(String text) throws IOException {
    long output = 0L;
    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
      CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
      tokenStream.reset();
      while (tokenStream.incrementToken()) {
        String s = charTermAttribute.toString();
        Long d = Util.get(fst, new BytesRef(s));
        if (d != null) {
          output += d;
        }
      }
      tokenStream.end();
    }

    double score = 1 - Math.exp(-1 * Math.abs(bias - (double) output) / bias);
    return new ClassificationResult<>(output >= bias, score);
  }

  @Override
  public List> getClasses(String text) {
    return null;
  }

  @Override
  public List> getClasses(String text, int max) {
    return null;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy