All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.solr.update.processor.ClassificationUpdateProcessor Maven / Gradle / Ivy

There is a newer version: 9.6.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.update.processor;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;

/**
 * Update request processor that classifies incoming documents and stores the predicted
 * class(es) in a dedicated output field.
 * <p>
 * Classification is delegated to the Lucene document classification module (see
 * {@link DocumentClassifier}). The concrete classifier (k-nearest-neighbours or simple naive
 * Bayes) is selected from the configured {@link Algorithm}. Documents that already carry a value
 * in the training class field are passed down the chain without being classified.
 */
class ClassificationUpdateProcessor
    extends UpdateRequestProcessor {

  private final String trainingClassField;
  private final String predictedClassField;
  private final int maxOutputClasses;
  private final DocumentClassifier<BytesRef> classifier;

  /**
   * Sole constructor.
   *
   * @param classificationParams classification settings: training and predicted class fields,
   *                             input field names (each optionally suffixed with
   *                             {@code ^boost}), algorithm and algorithm-specific tuning
   * @param next                 next update processor in the chain
   * @param indexReader          index reader handed to the classifier together with the
   *                             training class field
   * @param schema               schema used to look up the query analyzer of each input field
   * @throws IllegalArgumentException if the configured algorithm is not handled here
   */
  public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams,
      UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
    super(next);
    this.trainingClassField = classificationParams.getTrainingClassField();
    this.predictedClassField = classificationParams.getPredictedClassField();
    this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
    String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
    Algorithm classificationAlgorithm = classificationParams.getAlgorithm();

    // Map each (boost-stripped) input field to the query analyzer its schema type declares;
    // the classifiers receive this map alongside the boosted field names.
    Map<String, Analyzer> field2analyzer = new HashMap<>();
    for (String fieldName : removeBoost(inputFieldNamesWithBoost)) {
      SchemaField fieldFromSolrSchema = schema.getField(fieldName);
      Analyzer queryAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
      field2analyzer.put(fieldName, queryAnalyzer);
    }

    switch (classificationAlgorithm) {
      case KNN:
        classifier = new KNearestNeighborDocumentClassifier(indexReader, null,
            classificationParams.getTrainingFilterQuery(), classificationParams.getK(),
            classificationParams.getMinDf(), classificationParams.getMinTf(),
            trainingClassField, field2analyzer, inputFieldNamesWithBoost);
        break;
      case BAYES:
        classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null,
            trainingClassField, field2analyzer, inputFieldNamesWithBoost);
        break;
      default:
        // Fail fast: without a classifier every later processAdd would hit a null reference.
        throw new IllegalArgumentException(
            "Unsupported classification algorithm: " + classificationAlgorithm);
    }
  }

  /**
   * Strips the optional {@code ^boost} suffix from each configured input field name.
   *
   * @param inputFieldNamesWithBoost field names, each optionally suffixed with {@code ^boost}
   * @return the bare field names, in the same order as the input
   */
  private String[] removeBoost(String[] inputFieldNamesWithBoost) {
    String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
    for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
      String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
      int boostSeparator = singleFieldNameWithBoost.indexOf('^');
      // Keep everything before the first '^' (or the whole name when there is no boost).
      inputFieldNames[i] = boostSeparator < 0
          ? singleFieldNameWithBoost
          : singleFieldNameWithBoost.substring(0, boostSeparator);
    }
    return inputFieldNames;
  }

  /**
   * Classifies the document when it has no value in the training class field, then forwards
   * the command to the next processor.
   * <p>
   * The classifier is asked for classes with {@code maxOutputClasses} as the limit; every
   * returned class is added, in the order returned, as a value of the predicted class field.
   * Documents that already have a training class value are forwarded unchanged.
   *
   * @param cmd the update command whose input document is classified
   * @throws IOException If there is a low-level I/O error
   */
  @Override
  public void processAdd(AddUpdateCommand cmd)
      throws IOException {
    SolrInputDocument doc = cmd.getSolrInputDocument();
    Object documentClass = doc.getFieldValue(trainingClassField);
    if (documentClass == null) {
      // Only build the Lucene document when it is actually needed for classification.
      Document luceneDocument = cmd.getLuceneDocument();
      List<ClassificationResult<BytesRef>> assignedClassifications =
          classifier.getClasses(luceneDocument, maxOutputClasses);
      if (assignedClassifications != null) {
        for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
          String assignedClass = singleClassification.getAssignedClass().utf8ToString();
          doc.addField(predictedClassField, assignedClass);
        }
      }
    }
    super.processAdd(cmd);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy