org.elasticsearch.index.similarity.SimilarityService Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of elasticsearch Show documentation
Show all versions of elasticsearch Show documentation
Elasticsearch subproject :server
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.similarity;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.search.similarity.LegacyBM25Similarity;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.script.ScriptService;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
public final class SimilarityService {
private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SimilarityService.class);
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map>> DEFAULTS;
public static final Map> BUILT_IN;
static {
Map>> defaults = new HashMap<>();
defaults.put(CLASSIC_SIMILARITY, version -> {
if (version.onOrAfter(Version.V_7_0_0)) {
return () -> {
throw new IllegalArgumentException(
"The [classic] similarity may not be used anymore. Please use the [BM25] "
+ "similarity or build a custom [scripted] similarity instead."
);
};
} else {
final ClassicSimilarity similarity = SimilarityProviders.createClassicSimilarity(Settings.EMPTY, version);
return () -> {
deprecationLogger.critical(
DeprecationCategory.QUERIES,
"classic_similarity",
"The [classic] similarity is now deprecated in favour of BM25, which is generally "
+ "accepted as a better alternative. Use the [BM25] similarity or build a custom [scripted] similarity "
+ "instead."
);
return similarity;
};
}
});
defaults.put("BM25", version -> {
final LegacyBM25Similarity similarity = SimilarityProviders.createBM25Similarity(Settings.EMPTY, version);
return () -> similarity;
});
defaults.put("boolean", version -> {
final Similarity similarity = new BooleanSimilarity();
return () -> similarity;
});
Map> builtIn = new HashMap<>();
builtIn.put(CLASSIC_SIMILARITY, (settings, version, script) -> {
if (version.onOrAfter(Version.V_7_0_0)) {
throw new IllegalArgumentException(
"The [classic] similarity may not be used anymore. Please use the [BM25] "
+ "similarity or build a custom [scripted] similarity instead."
);
} else {
deprecationLogger.critical(
DeprecationCategory.QUERIES,
"classic_similarity",
"The [classic] similarity is now deprecated in favour of BM25, which is generally "
+ "accepted as a better alternative. Use the [BM25] similarity or build a custom [scripted] similarity "
+ "instead."
);
return SimilarityProviders.createClassicSimilarity(settings, version);
}
});
builtIn.put("BM25", (settings, version, scriptService) -> SimilarityProviders.createBM25Similarity(settings, version));
builtIn.put("boolean", (settings, version, scriptService) -> SimilarityProviders.createBooleanSimilarity(settings, version));
builtIn.put("DFR", (settings, version, scriptService) -> SimilarityProviders.createDfrSimilarity(settings, version));
builtIn.put("IB", (settings, version, scriptService) -> SimilarityProviders.createIBSimilarity(settings, version));
builtIn.put(
"LMDirichlet",
(settings, version, scriptService) -> SimilarityProviders.createLMDirichletSimilarity(settings, version)
);
builtIn.put(
"LMJelinekMercer",
(settings, version, scriptService) -> SimilarityProviders.createLMJelinekMercerSimilarity(settings, version)
);
builtIn.put("DFI", (settings, version, scriptService) -> SimilarityProviders.createDfiSimilarity(settings, version));
builtIn.put("scripted", new ScriptedSimilarityProvider());
DEFAULTS = Collections.unmodifiableMap(defaults);
BUILT_IN = Collections.unmodifiableMap(builtIn);
}
private final Similarity defaultSimilarity;
private final Map> similarities;
public SimilarityService(
IndexSettings indexSettings,
ScriptService scriptService,
Map> similarities
) {
Map> providers = new HashMap<>(similarities.size());
Map similaritySettings = indexSettings.getSettings().getGroups(IndexModule.SIMILARITY_SETTINGS_PREFIX);
for (Map.Entry entry : similaritySettings.entrySet()) {
String name = entry.getKey();
if (BUILT_IN.containsKey(name)) {
throw new IllegalArgumentException("Cannot redefine built-in Similarity [" + name + "]");
}
Settings providerSettings = entry.getValue();
String typeName = providerSettings.get("type");
if (typeName == null) {
throw new IllegalArgumentException("Similarity [" + name + "] must have an associated type");
} else if ((similarities.containsKey(typeName) || BUILT_IN.containsKey(typeName)) == false) {
throw new IllegalArgumentException("Unknown Similarity type [" + typeName + "] for [" + name + "]");
}
TriFunction defaultFactory = BUILT_IN.get(typeName);
TriFunction factory = similarities.getOrDefault(typeName, defaultFactory);
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
// We don't trust custom similarities
similarity = new NonNegativeScoresSimilarity(similarity);
}
final Similarity similarityF = similarity; // like similarity but final
providers.put(name, () -> similarityF);
}
for (Map.Entry>> entry : DEFAULTS.entrySet()) {
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
}
this.similarities = providers;
defaultSimilarity = (providers.get("default") != null)
? providers.get("default").get()
: providers.get(SimilarityService.DEFAULT_SIMILARITY).get();
if (providers.get("base") != null) {
deprecationLogger.warn(
DeprecationCategory.QUERIES,
"base_similarity_ignored",
"The [base] similarity is ignored since query normalization and coords have been removed"
);
}
}
/**
* The similarity to use in searches, which takes into account per-field configuration.
*/
public Similarity similarity(@Nullable Function fieldTypeLookup) {
return (fieldTypeLookup != null) ? new PerFieldSimilarity(defaultSimilarity, fieldTypeLookup) : defaultSimilarity;
}
public SimilarityProvider getSimilarity(String name) {
Supplier sim = similarities.get(name);
if (sim == null) {
return null;
}
return new SimilarityProvider(name, sim.get());
}
/**
* The default similarity configured in the index settings.
*/
public Similarity getDefaultSimilarity() {
return defaultSimilarity;
}
static class PerFieldSimilarity extends PerFieldSimilarityWrapper {
private final Similarity defaultSimilarity;
private final Function fieldTypeLookup;
PerFieldSimilarity(Similarity defaultSimilarity, Function fieldTypeLookup) {
super();
this.defaultSimilarity = defaultSimilarity;
this.fieldTypeLookup = Objects.requireNonNull(fieldTypeLookup, "fieldTypeLookup cannot be null");
}
@Override
public Similarity get(String name) {
MappedFieldType fieldType = fieldTypeLookup.apply(name);
return (fieldType != null && fieldType.getTextSearchInfo().getSimilarity() != null)
? fieldType.getTextSearchInfo().getSimilarity().get()
: defaultSimilarity;
}
}
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(
indexCreatedVersion.luceneVersion.major,
"some_field",
IndexOptions.DOCS_AND_FREQS,
20,
20,
0,
50,
10,
3
); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < 0) {
fail(
indexCreatedVersion,
"Similarities should not return negative scores:\n" + scorer.explain(Explanation.match(freq, "term freq"), norm)
);
break;
}
}
}
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(
indexCreatedVersion.luceneVersion.major,
"some_field",
IndexOptions.DOCS_AND_FREQS,
20,
20,
0,
50,
10,
3
); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
float previousScore = 0;
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < previousScore) {
fail(
indexCreatedVersion,
"Similarity scores should not decrease when term frequency increases:\n"
+ scorer.explain(Explanation.match(freq - 1, "term freq"), norm)
+ "\n"
+ scorer.explain(Explanation.match(freq, "term freq"), norm)
);
break;
}
previousScore = score;
}
}
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
long previousNorm = 0;
float previousScore = Float.MAX_VALUE;
for (int length = 1; length <= 10; ++length) {
FieldInvertState state = new FieldInvertState(
indexCreatedVersion.luceneVersion.major,
"some_field",
IndexOptions.DOCS_AND_FREQS,
length,
length,
0,
50,
10,
3
); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
if (Long.compareUnsigned(previousNorm, norm) > 0) {
// esoteric similarity, skip this check
break;
}
float score = scorer.score(1, norm);
if (score > previousScore) {
fail(
indexCreatedVersion,
"Similarity scores should not increase when norm increases:\n"
+ scorer.explain(Explanation.match(1, "term freq"), norm - 1)
+ "\n"
+ scorer.explain(Explanation.match(1, "term freq"), norm)
);
break;
}
previousScore = score;
previousNorm = norm;
}
}
private static void fail(Version indexCreatedVersion, String message) {
if (indexCreatedVersion.onOrAfter(Version.V_7_0_0)) {
throw new IllegalArgumentException(message);
} else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
deprecationLogger.critical(DeprecationCategory.QUERIES, "similarity_failure", message);
}
}
}