
org.elasticsearch.search.vectors.KnnVectorQueryBuilder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of elasticsearch Show documentation
Show all versions of elasticsearch Show documentation
Elasticsearch - Open Source, Distributed, RESTful Search Engine
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.search.vectors;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.NestedHelper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * A query that performs kNN search using Lucene's {@link org.apache.lucene.search.KnnFloatVectorQuery} or
 * {@link org.apache.lucene.search.KnnByteVectorQuery}.
 *
 * <p>The builder carries the target field, the query vector, an optional candidate count, an optional
 * similarity threshold and a list of filter queries. The concrete Lucene query is created by
 * {@code DenseVectorFieldType#createKnnQuery} in {@code doToQuery}.
 */
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
    /** Name of this query; used as the XContent object name and as the writeable name. */
    public static final String NAME = "knn";
    /** Largest accepted {@code num_candidates}; also caps the value derived from the request size. */
    private static final int NUM_CANDS_LIMIT = 10_000;
    /** Factor applied to the request size when {@code num_candidates} was not provided. */
    private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
    /** Name of the vector field to search. */
    public static final ParseField FIELD_FIELD = new ParseField("field");
    /** Optional number of candidates. */
    public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
    /** The query vector. */
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    /** Optional similarity value handed to the kNN query. */
    public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
    /** Optional array of filter queries. */
    public static final ParseField FILTER_FIELD = new ParseField("filter");
/**
 * Parser for the body of a {@code knn} query.
 *
 * <p>The constructor arguments arrive positionally: {@code args[0]} is the field name, {@code args[1]} the
 * query vector as boxed floats, {@code args[2]} the optional candidate count and {@code args[3]} the optional
 * similarity. No parser context is used, hence the {@code Void} context type.
 */
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
    // args is an Object[], so the element type of the list cannot be checked at compile time.
    List<Float> vector = (List<Float>) args[1];
    final float[] vectorArray;
    if (vector != null) {
        // Unbox into a primitive array, which is what the builder stores.
        vectorArray = new float[vector.size()];
        for (int i = 0; i < vector.size(); i++) {
            vectorArray[i] = vector.get(i);
        }
    } else {
        // Left as null so that the constructor reports the missing query vector.
        vectorArray = null;
    }
    return new KnnVectorQueryBuilder((String) args[0], vectorArray, (Integer) args[2], (Float) args[3]);
});
// Registers the fields understood by PARSER. The order of the constructor-argument declarations
// must not change: it lines up with the args[0]..args[3] indexes read by the PARSER lambda
// (field name, query vector, num_candidates, similarity).
static {
PARSER.declareString(constructorArg(), FIELD_FIELD);
PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
// Filters are not constructor arguments; each parsed query is added to the built instance
// through addFilterQueries. The parser context argument (c) is ignored.
PARSER.declareFieldArray(
KnnVectorQueryBuilder::addFilterQueries,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
FILTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
// Presumably registers the fields common to all query builders (written back by
// boostAndQueryNameToXContent in doXContent) -- confirm against AbstractQueryBuilder.
declareStandardFields(PARSER);
}
/**
 * Builds a {@link KnnVectorQueryBuilder} from its XContent representation by delegating to {@link #PARSER}.
 *
 * @param parser the parser to read the query body from
 * @return the parsed query builder
 */
public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
    // The parser needs no context object, so null is passed for it.
    final KnnVectorQueryBuilder parsed = PARSER.apply(parser, null);
    return parsed;
}
/** Name of the field the kNN search runs against. */
private final String fieldName;
/** The query vector; the public constructor rejects {@code null}. */
private final float[] queryVector;
/** Requested number of candidates, or {@code null} to derive it from the request size at query time. */
private final Integer numCands;
/** Filter queries; grown through {@code addFilterQuery} / {@code addFilterQueries}. */
private final List<QueryBuilder> filterQueries = new ArrayList<>();
/** Optional similarity value passed through to the kNN query; may be {@code null}. */
private final Float vectorSimilarity;
public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (queryVector == null) {
throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided");
}
this.fieldName = fieldName;
this.queryVector = queryVector;
this.numCands = numCands;
this.vectorSimilarity = vectorSimilarity;
}
/**
 * Reads a builder from the transport stream. The read order mirrors the write order in
 * {@code doWriteTo} and depends on the transport version of the stream.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public KnnVectorQueryBuilder(StreamInput in) throws IOException {
// State owned by the base class is read first.
super(in);
this.fieldName = in.readString();
// Newer senders write num_candidates as an optional value; older ones always write it.
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM)) {
this.numCands = in.readOptionalVInt();
} else {
this.numCands = in.readVInt();
}
// Versions from V_8_7_0 up to (not including) V_8_12_0 wrap the float vector in two extra
// booleans, which are consumed and discarded here (doWriteTo emits true before and false after).
if (in.getTransportVersion().before(TransportVersions.V_8_7_0) || in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
this.queryVector = in.readFloatArray();
} else {
in.readBoolean();
this.queryVector = in.readFloatArray();
in.readBoolean(); // used for byteQueryVector, which was always null
}
// Filter queries are only on the wire from V_8_2_0 onwards.
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
this.filterQueries.addAll(readQueries(in));
}
// The similarity is only on the wire from V_8_8_0 onwards; otherwise it is absent.
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
this.vectorSimilarity = in.readOptionalFloat();
} else {
this.vectorSimilarity = null;
}
}
/**
 * @return the name of the field this query searches
 */
public String getFieldName() {
    return this.fieldName;
}
/**
 * Returns the query vector. The array is the builder's own instance, not a copy.
 * The public constructor rejects a {@code null} vector.
 *
 * @return the query vector
 */
@Nullable
public float[] queryVector() {
    return this.queryVector;
}
/**
 * @return the similarity value given for this query, or {@code null} if none was provided
 */
@Nullable
public Float getVectorSimilarity() {
    return this.vectorSimilarity;
}
/**
 * Returns the requested number of candidates.
 *
 * @return the candidate count, or {@code null} when it was not provided; in that case
 *         {@code doToQuery} derives a value from the request size
 */
@Nullable
public Integer numCands() {
    return numCands;
}
/**
 * Returns the filter queries attached to this kNN query. The returned list is the builder's
 * own list, not a copy.
 *
 * @return the filter queries, possibly empty
 */
public List<QueryBuilder> filterQueries() {
    return filterQueries;
}
/**
 * Appends a single filter query.
 *
 * @param filterQuery the filter to add, must not be {@code null}
 * @return this builder, for chaining
 * @throws NullPointerException if {@code filterQuery} is {@code null}
 */
public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) {
    // requireNonNull returns its argument, so the null check and the add fold into one statement.
    this.filterQueries.add(Objects.requireNonNull(filterQuery));
    return this;
}
/**
 * Appends all given filter queries, in iteration order.
 *
 * @param filterQueries the filters to add, must not be {@code null}
 * @return this builder, for chaining
 * @throws NullPointerException if {@code filterQueries} is {@code null}
 */
public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries) {
    Objects.requireNonNull(filterQueries);
    this.filterQueries.addAll(filterQueries);
    return this;
}
/**
 * Writes this builder to the transport stream. The write order mirrors the read order of the
 * {@code StreamInput} constructor and depends on the transport version of the receiver.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 * @throws IllegalArgumentException if {@code num_candidates} is unset and the receiver predates
 *         {@code KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM}
 */
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_NUMCANDS_AS_OPTIONAL_PARAM)) {
out.writeOptionalVInt(numCands);
} else {
// Older receivers read a mandatory vint, so a missing value cannot be represented.
if (numCands == null) {
throw new IllegalArgumentException(
"["
+ NUM_CANDS_FIELD.getPreferredName()
+ "] field was mandatory in previous releases "
+ "and is required to be non-null by some nodes. "
+ "Please make sure to provide the parameter as part of the request."
);
} else {
out.writeVInt(numCands);
}
}
// Receivers from V_8_7_0 up to (not including) V_8_12_0 expect a boolean before and after
// the float vector; every other version gets the bare array.
if (out.getTransportVersion().before(TransportVersions.V_8_7_0)
|| out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
out.writeFloatArray(queryVector);
} else {
out.writeBoolean(true);
out.writeFloatArray(queryVector);
out.writeBoolean(false); // used for byteQueryVector, which was always null
}
// Filter queries are only sent to receivers on V_8_2_0 or later.
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
writeQueries(out, filterQueries);
}
// The similarity is only sent to receivers on V_8_8_0 or later.
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
out.writeOptionalFloat(vectorSimilarity);
}
}
/**
 * Renders this query as an object named {@link #NAME}. The field name and the query vector are
 * always written; the candidate count, similarity and filters only when present.
 *
 * @param builder the builder to write to
 * @param params  rendering parameters, forwarded to each filter query
 * @throws IOException if writing fails
 */
@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject(NAME)
        .field(FIELD_FIELD.getPreferredName(), fieldName)
        .field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);

    if (numCands != null) {
        builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
    }
    if (vectorSimilarity != null) {
        builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), vectorSimilarity);
    }

    final boolean hasFilters = filterQueries.isEmpty() == false;
    if (hasFilters) {
        builder.startArray(FILTER_FIELD.getPreferredName());
        for (QueryBuilder filter : filterQueries) {
            filter.toXContent(builder, params);
        }
        builder.endArray();
    }

    boostAndQueryNameToXContent(builder);
    builder.endObject();
}
/**
 * @return the name identifying this query, {@link #NAME}
 */
@Override
public String getWriteableName() {
    return KnnVectorQueryBuilder.NAME;
}
/**
 * Rewrites this query.
 *
 * <ul>
 * <li>If the context converts to an inner-hits rewrite context, an {@code ExactKnnQueryBuilder} over the
 * same vector and field is returned, carrying this query's boost and name.</li>
 * <li>If any filter rewrites to a {@link MatchNoneQueryBuilder}, that rewritten filter is returned as is.</li>
 * <li>If at least one filter changed, a new builder with the rewritten filters is returned.</li>
 * <li>Otherwise {@code this} is returned.</li>
 * </ul>
 *
 * @param ctx the rewrite context
 * @return the rewritten query, or {@code this} if nothing changed
 * @throws IOException if rewriting a filter fails
 */
@Override
protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
    if (ctx.convertToInnerHitsRewriteContext() != null) {
        return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName);
    }
    boolean changed = false;
    List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
    for (QueryBuilder query : filterQueries) {
        QueryBuilder rewrittenQuery = query.rewrite(ctx);
        if (rewrittenQuery instanceof MatchNoneQueryBuilder) {
            // One filter that matches nothing is enough to short-circuit the whole query.
            return rewrittenQuery;
        }
        // Identity comparison on purpose: a different instance means the filter was rewritten.
        if (rewrittenQuery != query) {
            changed = true;
        }
        rewrittenQueries.add(rewrittenQuery);
    }
    if (changed) {
        return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).boost(boost)
            .queryName(queryName)
            .addFilterQueries(rewrittenQueries);
    }
    return this;
}
/**
 * Builds the Lucene kNN query for this builder.
 *
 * <p>The target field must exist and be a dense vector field. The filter queries and the alias filter
 * (if any) are combined into one filter. When the field lives under a nested mapping, the filter is
 * joined down to the child documents and the parent bit set is passed to the kNN query.
 *
 * @param context the search execution context
 * @return the kNN query created by the field type
 * @throws IOException if building a filter query fails
 * @throws IllegalArgumentException if the field is unmapped or is not a dense vector field
 */
@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
    MappedFieldType fieldType = context.getFieldType(fieldName);
    int adjustedNumCands = resolveNumCands(context);
    if (fieldType == null) {
        throw new IllegalArgumentException("field [" + fieldName + "] does not exist in the mapping");
    }
    if (fieldType instanceof DenseVectorFieldType == false) {
        throw new IllegalArgumentException(
            "[" + NAME + "] queries are only supported on [" + DenseVectorFieldMapper.CONTENT_TYPE + "] fields"
        );
    }
    DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
    Query filterQuery = buildFilterQuery(context);

    String parentPath = context.nestedLookup().getNestedParent(fieldName);
    if (parentPath == null) {
        // Not a nested field: no parent bit set is needed.
        return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, null);
    }

    final Query parentFilter = resolveParentFilter(context);
    final BitSetProducer parentBitSet = context.bitsetFilter(parentFilter);
    if (filterQuery != null) {
        NestedHelper nestedHelper = new NestedHelper(context.nestedLookup(), context::isFieldMapped);
        // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents
        // we need to adjust it.
        if (nestedHelper.mightMatchNestedDocs(filterQuery)) {
            // Ensure that the query only returns parent documents matching `filterQuery`
            filterQuery = Queries.filtered(filterQuery, parentFilter);
        }
        // Now join the filterQuery & parentFilter to provide the matching blocks of children
        filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
    }
    return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
}

/** Returns the explicit candidate count, or one derived from the request size and capped at the limit. */
private int resolveNumCands(SearchExecutionContext context) {
    if (numCands != null) {
        return numCands;
    }
    Integer size = context.requestSize();
    // A missing or negative request size falls back to the default search size.
    int requestSize = size == null || size < 0 ? DEFAULT_SIZE : size;
    return Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * requestSize, NUM_CANDS_LIMIT));
}

/** Combines the filter queries and the alias filter into one FILTER-only query, or null if there are none. */
private Query buildFilterQuery(SearchExecutionContext context) throws IOException {
    BooleanQuery.Builder builder = new BooleanQuery.Builder();
    for (QueryBuilder query : this.filterQueries) {
        builder.add(query.toQuery(context), BooleanClause.Occur.FILTER);
    }
    if (context.getAliasFilter() != null) {
        builder.add(context.getAliasFilter().toQuery(context), BooleanClause.Occur.FILTER);
    }
    BooleanQuery booleanQuery = builder.build();
    return booleanQuery.clauses().isEmpty() ? null : booleanQuery;
}

/** Determines the query matching the parent documents, restoring the nested scope before returning. */
private static Query resolveParentFilter(SearchExecutionContext context) throws IOException {
    NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
    if (originalObjectMapper == null) {
        // we are NOT in a nested context, coming from the top level knn search
        return Queries.newNonNestedFilter(context.indexVersionCreated());
    }
    try {
        // we are in a nested context, to get the parent filter we need to go up one level
        context.nestedScope().previousLevel();
        NestedObjectMapper objectMapper = context.nestedScope().getObjectMapper();
        return objectMapper == null
            ? Queries.newNonNestedFilter(context.indexVersionCreated())
            : objectMapper.nestedTypeFilter();
    } finally {
        // Always step back down, so the scope is unchanged for whoever called us.
        context.nestedScope().nextLevel(originalObjectMapper);
    }
}
/**
 * Hashes the same state that {@code doEquals} compares. The query vector is hashed by content.
 *
 * @return the hash code of this builder's own state
 */
@Override
protected int doHashCode() {
    final int vectorHash = Arrays.hashCode(queryVector);
    return Objects.hash(fieldName, vectorHash, numCands, filterQueries, vectorSimilarity);
}
/**
 * Compares field name, query vector (by content), candidate count, filter queries and similarity.
 *
 * @param other the builder to compare with
 * @return {@code true} if all compared state is equal
 */
@Override
protected boolean doEquals(KnnVectorQueryBuilder other) {
    if (Objects.equals(fieldName, other.fieldName) == false) {
        return false;
    }
    if (Arrays.equals(queryVector, other.queryVector) == false) {
        return false;
    }
    if (Objects.equals(numCands, other.numCands) == false) {
        return false;
    }
    if (Objects.equals(filterQueries, other.filterQueries) == false) {
        return false;
    }
    return Objects.equals(vectorSimilarity, other.vectorSimilarity);
}
/**
 * @return the minimal transport version reported for this query, {@code V_8_0_0}
 */
@Override
public TransportVersion getMinimalSupportedVersion() {
    final TransportVersion minimal = TransportVersions.V_8_0_0;
    return minimal;
}
}