org.apache.lucene.classification.BooleanPerceptronClassifier Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of lucene-classification Show documentation
Show all versions of lucene-classification Show documentation
Apache Lucene (module: classification)
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentSkipListMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermVectors;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FSTCompiler;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
/**
* A perceptron (see http://en.wikipedia.org/wiki/Perceptron
) based Boolean
*
{@link org.apache.lucene.classification.Classifier}. The weights are calculated using
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field and a per document
* basis and then a corresponding {@link org.apache.lucene.util.fst.FST} is used for class
* assignment.
*
* @lucene.experimental
*/
public class BooleanPerceptronClassifier implements Classifier {
private final Double bias;
private final Terms textTerms;
private final Analyzer analyzer;
private final String textFieldName;
private FST fst;
/**
* Creates a {@link BooleanPerceptronClassifier}
*
* @param indexReader the reader on the index to be used for classification
* @param analyzer an {@link Analyzer} used to analyze unseen text
* @param query a {@link Query} to eventually filter the docs used for training the classifier, or
* {@code null} if all the indexed docs should be used
* @param batchSize the size of the batch of docs to use for updating the perceptron weights
* @param bias the bias used for class separation
* @param classFieldName the name of the field used as the output for the classifier
* @param textFieldName the name of the field used as input for the classifier
* @throws IOException if the building of the underlying {@link FST} fails and / or {@link
* TermsEnum} for the text field cannot be found
*/
public BooleanPerceptronClassifier(
IndexReader indexReader,
Analyzer analyzer,
Query query,
Integer batchSize,
Double bias,
String classFieldName,
String textFieldName)
throws IOException {
this.textTerms = MultiTerms.getTerms(indexReader, textFieldName);
if (textTerms == null) {
throw new IOException("term vectors need to be available for field " + textFieldName);
}
this.analyzer = analyzer;
this.textFieldName = textFieldName;
if (bias == null || bias == 0d) {
// automatic assign the bias to be the average total term freq
double t =
(double) indexReader.getSumTotalTermFreq(textFieldName)
/ (double) indexReader.getDocCount(textFieldName);
if (t != -1) {
this.bias = t;
} else {
throw new IOException(
"bias cannot be assigned since term vectors for field "
+ textFieldName
+ " do not exist");
}
} else {
this.bias = bias;
}
// TODO : remove this map as soon as we have a writable FST
SortedMap weights = new ConcurrentSkipListMap<>();
TermsEnum termsEnum = textTerms.iterator();
BytesRef textTerm;
while ((textTerm = termsEnum.next()) != null) {
weights.put(textTerm.utf8ToString(), (double) termsEnum.totalTermFreq());
}
updateFST(weights);
IndexSearcher indexSearcher = new IndexSearcher(indexReader);
int batchCount = 0;
BooleanQuery.Builder q = new BooleanQuery.Builder();
q.add(
new BooleanClause(
new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
if (query != null) {
q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
}
TermVectors termVectors = indexReader.termVectors();
StoredFields storedFields = indexReader.storedFields();
// run the search and use stored field values
for (ScoreDoc scoreDoc : indexSearcher.search(q.build(), Integer.MAX_VALUE).scoreDocs) {
Document doc = storedFields.document(scoreDoc.doc);
IndexableField textField = doc.getField(textFieldName);
// get the expected result
IndexableField classField = doc.getField(classFieldName);
if (textField != null && classField != null) {
// assign class to the doc
ClassificationResult classificationResult = assignClass(textField.stringValue());
Boolean assignedClass = classificationResult.assignedClass();
Boolean correctClass = Boolean.valueOf(classField.stringValue());
double modifier = Math.signum(correctClass.compareTo(assignedClass));
if (modifier != 0D) {
updateWeights(
termVectors,
scoreDoc.doc,
assignedClass,
weights,
modifier,
batchCount % batchSize == 0);
}
batchCount++;
}
}
weights.clear(); // free memory while waiting for GC
}
private void updateWeights(
TermVectors termVectors,
int docId,
Boolean assignedClass,
SortedMap weights,
double modifier,
boolean updateFST)
throws IOException {
TermsEnum cte = textTerms.iterator();
// get the doc term vectors
Terms terms = termVectors.get(docId, textFieldName);
if (terms == null) {
throw new IOException("term vectors must be stored for field " + textFieldName);
}
TermsEnum termsEnum = terms.iterator();
BytesRef term;
while ((term = termsEnum.next()) != null) {
cte.seekExact(term);
if (assignedClass != null) {
long termFreqLocal = termsEnum.totalTermFreq();
// update weights
Long previousValue = Util.get(fst, term);
String termString = term.utf8ToString();
weights.put(
termString,
previousValue == null ? 0 : Math.max(0, previousValue + modifier * termFreqLocal));
}
}
if (updateFST) {
updateFST(weights);
}
}
private void updateFST(SortedMap weights) throws IOException {
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
FSTCompiler fstCompiler =
new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, outputs).build();
BytesRefBuilder scratchBytes = new BytesRefBuilder();
IntsRefBuilder scratchInts = new IntsRefBuilder();
for (Map.Entry entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstCompiler.add(
Util.toIntsRef(scratchBytes.get(), scratchInts), entry.getValue().longValue());
}
fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
}
@Override
public ClassificationResult assignClass(String text) throws IOException {
long output = 0L;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
double score = 1 - Math.exp(-1 * Math.abs(bias - (double) output) / bias);
return new ClassificationResult<>(output >= bias, score);
}
@Override
public List> getClasses(String text) {
return null;
}
@Override
public List> getClasses(String text, int max) {
return null;
}
}