org.apache.solr.search.TextLogisticRegressionQParserPlugin (solr-core)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
/**
 * Returns an AnalyticsQuery implementation that performs one gradient descent iteration over a
 * result set to train a logistic regression model.
 *
 * <p>The TextLogitStream provides the parallel iterative framework for this class.
 */
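// Illustrative request, a sketch only (the field names "tv_text" and "out_i" are placeholders,
// not taken from this file): the parser is exposed under the name "tlogit" and reads its
// training parameters from the request parameters rather than from the local params, e.g.
//
//   q={!tlogit}&feature=tv_text&terms=good,bad&idfs=1.2,0.8&outcome=out_i
//     &positiveLabel=1&threshold=0.5&alpha=0.01&iteration=1
//
// TextLogitStream provides the iterative driver for requests of this shape, with the current
// model carried between iterations in the "weights" parameter.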
public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
public static final String NAME = "tlogit";
@Override
public QParser createParser(
String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
return new TextLogisticRegressionQParser(qstr, localParams, params, req);
}
private static class TextLogisticRegressionQParser extends QParser {
TextLogisticRegressionQParser(
String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
}
@Override
public Query parse() {
String fs = params.get("feature");
String[] terms = params.get("terms").split(",");
String ws = params.get("weights");
String dfsStr = params.get("idfs");
int iteration = params.getInt("iteration", 0);
String outcome = params.get("outcome");
int positiveLabel = params.getInt("positiveLabel", 1);
double threshold = params.getDouble("threshold", 0.5);
double alpha = params.getDouble("alpha", 0.01);
double[] idfs = new double[terms.length];
String[] idfsArr = dfsStr.split(",");
for (int i = 0; i < idfsArr.length; i++) {
idfs[i] = Double.parseDouble(idfsArr[i]);
}
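      // One extra weight slot: index 0 multiplies the constant 1.0 intercept feature.
      // On the first iteration no "weights" parameter is supplied, so all weights start at 1.0.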
double[] weights = new double[terms.length + 1];
if (ws != null) {
String[] wa = ws.split(",");
for (int i = 0; i < wa.length; i++) {
weights[i] = Double.parseDouble(wa[i]);
}
} else {
for (int i = 0; i < weights.length; i++) {
weights[i] = 1.0d;
}
}
TrainingParams input =
new TrainingParams(
fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
return new TextLogisticRegressionQuery(input);
}
}
private static class TextLogisticRegressionQuery extends AnalyticsQuery {
private TrainingParams trainingParams;
public TextLogisticRegressionQuery(TrainingParams trainingParams) {
this.trainingParams = trainingParams;
}
@Override
public DelegatingCollector getAnalyticsCollector(
ResponseBuilder rbsp, IndexSearcher indexSearcher) {
return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams);
}
}
private static class TextLogisticRegressionCollector extends DelegatingCollector {
private TrainingParams trainingParams;
private LeafReader leafReader;
private double[] workingDeltas;
private ClassificationEvaluation classificationEvaluation;
private double[] weights;
private ResponseBuilder rbsp;
private NumericDocValues leafOutcomeValue;
private double totalError;
private SparseFixedBitSet positiveDocsSet;
private SparseFixedBitSet docsSet;
private IndexSearcher searcher;
TextLogisticRegressionCollector(
ResponseBuilder rbsp, IndexSearcher searcher, TrainingParams trainingParams) {
this.trainingParams = trainingParams;
this.workingDeltas = new double[trainingParams.weights.length];
this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
this.rbsp = rbsp;
this.classificationEvaluation = new ClassificationEvaluation();
this.searcher = searcher;
positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
}
@Override
public void doSetNextReader(LeafReaderContext context) throws IOException {
super.doSetNextReader(context);
leafReader = context.reader();
leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome);
}
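    // collect() only records which documents are in the result set and which of them carry the
    // positive label; the feature vectors and the gradient step are computed later in
    // complete(), once the full document set is known.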
@Override
public void collect(int doc) throws IOException {
int outcome;
if (leafOutcomeValue.advanceExact(doc)) {
outcome = (int) leafOutcomeValue.longValue();
} else {
outcome = 0;
}
outcome = trainingParams.positiveLabel == outcome ? 1 : 0;
if (outcome == 1) {
positiveDocsSet.set(context.docBase + doc);
}
docsSet.set(context.docBase + doc);
}
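    // Runs once after all hits have been collected: rebuilds each collected document's feature
    // vector from the feature field's postings, then applies a single gradient descent pass
    // over those vectors.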
@Override
@SuppressWarnings({"unchecked"})
public void complete() throws IOException {
      Map<Integer, double[]> docVectors = new HashMap<>();
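      // Read the feature field's postings from the top-level (slow atomic) reader so the
      // postings doc ids are index-wide and line up with the bitsets populated in collect().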
Terms terms =
((SolrIndexSearcher) searcher).getSlowAtomicReader().terms(trainingParams.feature);
TermsEnum termsEnum = terms == null ? TermsEnum.EMPTY : terms.iterator();
PostingsEnum postingsEnum = null;
int termIndex = 0;
for (String termStr : trainingParams.terms) {
BytesRef term = new BytesRef(termStr);
if (termsEnum.seekExact(term)) {
postingsEnum = termsEnum.postings(postingsEnum);
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
int docId = postingsEnum.docID();
if (docsSet.get(docId)) {
double[] vector = docVectors.get(docId);
if (vector == null) {
vector = new double[trainingParams.terms.length + 1];
vector[0] = 1.0;
docVectors.put(docId, vector);
}
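              // Feature value: the term's idf (from the "idfs" parameter) scaled by a
              // log-damped term frequency, idf * (1 + log(tf)).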
vector[termIndex + 1] =
trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq()));
}
}
}
termIndex++;
}
      for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
double[] vector = entry.getValue();
int outcome = 0;
if (positiveDocsSet.get(entry.getKey())) {
outcome = 1;
}
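        // One stochastic gradient descent step per document: error = sigmoid(w . x) - y, and
        // each weight moves by -alpha * error * x[i]. The reported error and the classification
        // evaluation are computed against the weights this iteration started with
        // (trainingParams.weights), while "weights" accumulates the updates.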
double sig = sigmoid(sum(multiply(vector, weights)));
double error = sig - outcome;
double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights)));
totalError += Math.abs(lastSig - outcome);
classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0);
workingDeltas = multiply(error * trainingParams.alpha, vector);
for (int i = 0; i < workingDeltas.length; i++) {
weights[i] -= workingDeltas[i];
}
}
NamedList