All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.update.processor.ClassificationUpdateProcessorFactory Maven / Gradle / Ivy

There is a newer version: 9.7.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.update.processor;

import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;

import java.util.Locale;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.search.LuceneQParser;
import org.apache.solr.search.SyntaxError;

/**
 * This class implements an UpdateProcessorFactory for the Classification Update Processor. It takes
 * as input a series of parameters that are needed to instantiate and use the Classifier.
 *
 * @since 6.1.0
 */
public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessorFactory {

  // Update Processor Config params
  private static final String INPUT_FIELDS_PARAM = "inputFields";
  private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
  private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
  private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
  private static final String ALGORITHM_PARAM = "algorithm";
  private static final String KNN_MIN_TF_PARAM = "knn.minTf";
  private static final String KNN_MIN_DF_PARAM = "knn.minDf";
  private static final String KNN_K_PARAM = "knn.k";
  private static final String KNN_FILTER_QUERY = "knn.filterQuery";

  /** Classification algorithms selectable through the {@code algorithm} init parameter. */
  public enum Algorithm {
    KNN,
    BAYES
  }

  // Update Processor Defaults
  private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
  private static final int DEFAULT_MIN_TF = 1;
  private static final int DEFAULT_MIN_DF = 1;
  private static final int DEFAULT_K = 10;
  private static final Algorithm DEFAULT_ALGORITHM = KNN;

  // Raw init args as SolrParams; stays null when init(...) received null args.
  private SolrParams params;
  // Parsed configuration handed to every ClassificationUpdateProcessor created by this factory.
  private ClassificationUpdateProcessorParams classificationParams;

  /**
   * Reads the processor configuration from the init args and stores it in the classification
   * params. Does nothing when {@code args} is {@code null}.
   *
   * <p>{@code inputFields} and {@code classField} are mandatory; {@code predictedClassField}
   * falls back to the class field; the remaining parameters fall back to their defaults.
   *
   * @param args the init args of this factory, may be {@code null}
   * @throws SolrException if a mandatory parameter is missing, the algorithm is not one of {@link
   *     Algorithm}, or an integer parameter can not be parsed
   */
  @Override
  public void init(final NamedList args) {
    if (args == null) {
      return;
    }
    params = args.toSolrParams();
    classificationParams = new ClassificationUpdateProcessorParams();

    String fieldNames = params.get(INPUT_FIELDS_PARAM); // must be a comma separated list of fields
    checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
    String[] inputFieldNames = fieldNames.split(",");
    // Tolerate whitespace around the separators ("a, b") so no field name carries stray blanks.
    for (int i = 0; i < inputFieldNames.length; i++) {
      inputFieldNames[i] = inputFieldNames[i].trim();
    }
    classificationParams.setInputFieldNames(inputFieldNames);

    String trainingClassField = params.get(TRAINING_CLASS_FIELD_PARAM);
    checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
    classificationParams.setTrainingClassField(trainingClassField);

    String predictedClassField = params.get(PREDICTED_CLASS_FIELD_PARAM);
    if (predictedClassField == null || predictedClassField.isEmpty()) {
      // No dedicated output field configured: write the prediction into the class field itself.
      predictedClassField = trainingClassField;
    }
    classificationParams.setPredictedClassField(predictedClassField);

    classificationParams.setMaxPredictedClasses(
        getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));

    classificationParams.setAlgorithm(parseAlgorithm(params.get(ALGORITHM_PARAM)));

    classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
    classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
    classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
  }

  /**
   * Resolves the configured algorithm name, ignoring case.
   *
   * @param algorithmString the raw parameter value, may be {@code null}
   * @return the matching {@link Algorithm}, or the default when the value is {@code null}
   * @throws SolrException if the value does not name a supported algorithm
   */
  private static Algorithm parseAlgorithm(String algorithmString) {
    if (algorithmString == null) {
      return DEFAULT_ALGORITHM;
    }
    try {
      // valueOf never returns null: it either yields a constant or throws.
      return Algorithm.valueOf(algorithmString.toUpperCase(Locale.ROOT));
    } catch (IllegalArgumentException e) {
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported",
          e);
    }
  }

  /**
   * Returns an int parsed param or a default if the param is null or empty.
   *
   * @param params Solr params in input
   * @param name the param name
   * @param defaultValue the param default
   * @return the int value for the param
   * @throws SolrException if the param is present but is not a valid integer
   */
  private int getIntParam(SolrParams params, String name, int defaultValue) {
    String paramString = params.get(name);
    if (paramString == null || paramString.isEmpty()) {
      return defaultValue;
    }
    try {
      return Integer.parseInt(paramString);
    } catch (NumberFormatException e) {
      // Report the offending parameter by name, in line with the other validation errors.
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "Classification UpdateProcessor '"
              + name
              + "' must be an integer but was: '"
              + paramString
              + "'",
          e);
    }
  }

  /**
   * Fails with a {@link SolrException} naming the parameter when its value is {@code null}.
   *
   * @param paramName the name used in the error message
   * @param param the value to check
   */
  private void checkNotNull(String paramName, Object param) {
    if (param == null) {
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "Classification UpdateProcessor '" + paramName + "' can not be null");
    }
  }

  /**
   * Creates a {@link ClassificationUpdateProcessor} for the given request, using the request's
   * schema and the index reader of its searcher. When a {@code knn.filterQuery} is configured it
   * is parsed against the request and stored in the classification params first.
   *
   * @param req the current request
   * @param rsp the current response (not used here)
   * @param next the next processor in the chain
   * @return the new processor
   * @throws SolrException if the classification params are not initialized or the training filter
   *     query can not be parsed
   */
  @Override
  public UpdateRequestProcessor getInstance(
      SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
    if (classificationParams == null) {
      // Happens when init(...) received null args; fail with a clear message, not a bare NPE.
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "Classification UpdateProcessor parameters are not initialized");
    }
    // params is null when the configuration was injected via setClassificationParams only.
    String trainingFilterQueryString = params == null ? null : params.get(KNN_FILTER_QUERY);
    try {
      if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
        Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
        // NOTE(review): this writes to the factory-level params object on every call; if
        // getInstance can run concurrently for one factory, confirm this is safe.
        classificationParams.setTrainingFilterQuery(trainingFilterQuery);
      }
    } catch (SyntaxError | RuntimeException syntaxError) {
      throw new SolrException(
          SolrException.ErrorCode.SERVER_ERROR,
          "Classification UpdateProcessor Training Filter Query: '"
              + trainingFilterQueryString
              + "' is not supported",
          syntaxError);
    }

    IndexSchema schema = req.getSchema();
    IndexReader indexReader = req.getSearcher().getIndexReader();

    return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
  }

  /**
   * Parses the training filter query with a {@link LuceneQParser}.
   *
   * @param trainingFilterQueryString the query string to parse
   * @param params the params handed to the parser
   * @param req the request handed to the parser
   * @return the parsed query
   * @throws SyntaxError if the parser rejects the query string
   */
  private Query parseFilterQuery(
      String trainingFilterQueryString, SolrParams params, SolrQueryRequest req)
      throws SyntaxError {
    LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
    return parser.parse();
  }

  /**
   * @return the classification params currently held by this factory, or {@code null} if none
   *     were set
   */
  public ClassificationUpdateProcessorParams getClassificationParams() {
    return classificationParams;
  }

  /**
   * Replaces the classification params held by this factory.
   *
   * @param classificationParams the params to use for processors created from now on
   */
  public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
    this.classificationParams = classificationParams;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy