All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.search.TextLogisticRegressionQParserPlugin Maven / Gradle / Ivy

There is a newer version: 9.7.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;

/**
 * Returns an AnalyticsQuery implementation that performs one Gradient Descent iteration of a result
 * set to train a logistic regression model
 *
 * 

The TextLogitStream provides the parallel iterative framework for this class. */ public class TextLogisticRegressionQParserPlugin extends QParserPlugin { public static final String NAME = "tlogit"; @Override public QParser createParser( String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { return new TextLogisticRegressionQParser(qstr, localParams, params, req); } private static class TextLogisticRegressionQParser extends QParser { TextLogisticRegressionQParser( String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { super(qstr, localParams, params, req); } @Override public Query parse() { String fs = params.get("feature"); String[] terms = params.get("terms").split(","); String ws = params.get("weights"); String dfsStr = params.get("idfs"); int iteration = params.getInt("iteration", 0); String outcome = params.get("outcome"); int positiveLabel = params.getInt("positiveLabel", 1); double threshold = params.getDouble("threshold", 0.5); double alpha = params.getDouble("alpha", 0.01); double[] idfs = new double[terms.length]; String[] idfsArr = dfsStr.split(","); for (int i = 0; i < idfsArr.length; i++) { idfs[i] = Double.parseDouble(idfsArr[i]); } double[] weights = new double[terms.length + 1]; if (ws != null) { String[] wa = ws.split(","); for (int i = 0; i < wa.length; i++) { weights[i] = Double.parseDouble(wa[i]); } } else { for (int i = 0; i < weights.length; i++) { weights[i] = 1.0d; } } TrainingParams input = new TrainingParams( fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold); return new TextLogisticRegressionQuery(input); } } private static class TextLogisticRegressionQuery extends AnalyticsQuery { private TrainingParams trainingParams; public TextLogisticRegressionQuery(TrainingParams trainingParams) { this.trainingParams = trainingParams; } @Override public DelegatingCollector getAnalyticsCollector( ResponseBuilder rbsp, IndexSearcher indexSearcher) { return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams); } } private static class TextLogisticRegressionCollector extends DelegatingCollector { private TrainingParams trainingParams; private LeafReader leafReader; private double[] workingDeltas; private ClassificationEvaluation classificationEvaluation; private double[] weights; private ResponseBuilder rbsp; private NumericDocValues leafOutcomeValue; private double totalError; private SparseFixedBitSet positiveDocsSet; private SparseFixedBitSet docsSet; private IndexSearcher searcher; TextLogisticRegressionCollector( ResponseBuilder rbsp, IndexSearcher searcher, TrainingParams trainingParams) { this.trainingParams = trainingParams; this.workingDeltas = new double[trainingParams.weights.length]; this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length); this.rbsp = rbsp; this.classificationEvaluation = new ClassificationEvaluation(); this.searcher = searcher; positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs()); docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs()); } @Override public void doSetNextReader(LeafReaderContext context) throws IOException { super.doSetNextReader(context); leafReader = context.reader(); leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome); } @Override public void collect(int doc) throws IOException { int outcome; if (leafOutcomeValue.advanceExact(doc)) { outcome = (int) leafOutcomeValue.longValue(); } else { outcome = 0; } outcome = trainingParams.positiveLabel == outcome ? 1 : 0; if (outcome == 1) { positiveDocsSet.set(context.docBase + doc); } docsSet.set(context.docBase + doc); } @Override @SuppressWarnings({"unchecked"}) public void complete() throws IOException { Map docVectors = new HashMap<>(); Terms terms = ((SolrIndexSearcher) searcher).getSlowAtomicReader().terms(trainingParams.feature); TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator(); PostingsEnum postingsEnum = null; int termIndex = 0; for (String termStr : trainingParams.terms) { BytesRef term = new BytesRef(termStr); if (termsEnum.seekExact(term)) { postingsEnum = termsEnum.postings(postingsEnum); while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { int docId = postingsEnum.docID(); if (docsSet.get(docId)) { double[] vector = docVectors.get(docId); if (vector == null) { vector = new double[trainingParams.terms.length + 1]; vector[0] = 1.0; docVectors.put(docId, vector); } vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq())); } } } termIndex++; } for (Map.Entry entry : docVectors.entrySet()) { double[] vector = entry.getValue(); int outcome = 0; if (positiveDocsSet.get(entry.getKey())) { outcome = 1; } double sig = sigmoid(sum(multiply(vector, weights))); double error = sig - outcome; double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights))); totalError += Math.abs(lastSig - outcome); classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0); workingDeltas = multiply(error * trainingParams.alpha, vector); for (int i = 0; i < workingDeltas.length; i++) { weights[i] -= workingDeltas[i]; } } NamedList analytics = new NamedList<>(); rbsp.rsp.add("logit", analytics); List outWeights = new ArrayList<>(); for (Double d : weights) { outWeights.add(d); } analytics.add("weights", outWeights); analytics.add("error", totalError); analytics.add("evaluation", classificationEvaluation.toMap()); analytics.add("feature", trainingParams.feature); analytics.add("positiveLabel", trainingParams.positiveLabel); if (this.delegate instanceof DelegatingCollector) { ((DelegatingCollector) this.delegate).complete(); } } private double sigmoid(double in) { double d = 1.0 / (1 + Math.exp(-in)); return d; } private double[] multiply(double[] vals, double[] weights) { for (int i = 0; i < vals.length; ++i) { workingDeltas[i] = vals[i] * weights[i]; } return workingDeltas; } private double[] multiply(double d, double[] vals) { for (int i = 0; i < vals.length; ++i) { workingDeltas[i] = vals[i] * d; } return workingDeltas; } private double sum(double[] vals) { double d = 0.0d; for (double val : vals) { d += val; } return d; } } private static class TrainingParams { public final String feature; public final String[] terms; public final double[] idfs; public final String outcome; public final double[] weights; public final int interation; public final int positiveLabel; public final double threshold; public final double alpha; public TrainingParams( String feature, String[] terms, double[] idfs, String outcome, double[] weights, int interation, double alpha, int positiveLabel, double threshold) { this.feature = feature; this.terms = terms; this.idfs = idfs; this.outcome = outcome; this.weights = weights; this.alpha = alpha; this.interation = interation; this.positiveLabel = positiveLabel; this.threshold = threshold; } } }