// com.yahoo.schema.derived.RankProfileList Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.derived;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.schema.RankingExpressionBody;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.schema.LargeRankingExpressions;
import com.yahoo.schema.OnnxModel;
import com.yahoo.schema.RankProfileRegistry;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.schema.RankProfile;
import com.yahoo.schema.Schema;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
/**
* The derived rank profiles of a schema
*
* @author bratseth
*/
public class RankProfileList extends Derived {
private final Map rankProfiles;
private final FileDistributedConstants constants;
private final LargeRankingExpressions largeRankingExpressions;
private final FileDistributedOnnxModels onnxModels;
public static final RankProfileList empty = new RankProfileList();
private RankProfileList() {
constants = new FileDistributedConstants(null, List.of());
largeRankingExpressions = new LargeRankingExpressions(null);
onnxModels = new FileDistributedOnnxModels(null, List.of());
rankProfiles = Map.of();
}
/**
* Creates a rank profile list
*
* @param schema the schema this is a rank profile from
* @param attributeFields the attribute fields to create a ranking for
*/
public RankProfileList(Schema schema,
LargeRankingExpressions largeRankingExpressions,
AttributeFields attributeFields,
DeployState deployState) {
setName(schema == null ? "default" : schema.getName());
this.largeRankingExpressions = largeRankingExpressions;
this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState);
this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState);
this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState);
}
private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set processedProfiles) {
return rank.inheritedNames().isEmpty() ||
processedProfiles.containsAll(rank.inheritedNames()) ||
(rank.schema() != null && rank.inheritedNames().stream().allMatch(name -> registry.resolve(rank.schema().getDocument(), name) != null));
}
private Map deriveRankProfiles(Schema schema,
AttributeFields attributeFields,
DeployState deployState) {
Map rawRankProfiles = new LinkedHashMap<>();
if (schema != null) { // profiles belonging to a schema have a default profile
RawRankProfile rawRank = new RawRankProfile(deployState.rankProfileRegistry().get(schema, "default"),
largeRankingExpressions,
deployState.getQueryProfiles().getRegistry(),
deployState.getImportedModels(),
attributeFields,
deployState.getProperties());
rawRankProfiles.put(rawRank.getName(), rawRank);
}
Map remaining = new LinkedHashMap<>();
deployState.rankProfileRegistry().rankProfilesOf(schema).forEach(rank -> remaining.put(rank.name(), rank));
remaining.remove("default");
while (!remaining.isEmpty()) {
List ready = new ArrayList<>();
remaining.forEach((name, profile) -> {
if (areDependenciesReady(profile, deployState.rankProfileRegistry(), rawRankProfiles.keySet()))
ready.add(profile);
});
rawRankProfiles.putAll(processRankProfiles(ready,
deployState.getQueryProfiles().getRegistry(),
deployState.getImportedModels(),
attributeFields,
deployState.getProperties(),
deployState.getExecutor()));
ready.forEach(rank -> remaining.remove(rank.name()));
}
return rawRankProfiles;
}
private Map processRankProfiles(List profiles,
QueryProfileRegistry queryProfiles,
ImportedMlModels importedModels,
AttributeFields attributeFields,
ModelContext.Properties deployProperties,
ExecutorService executor) {
Map> futureRawRankProfiles = new LinkedHashMap<>();
for (RankProfile profile : profiles) {
futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankingExpressions, queryProfiles, importedModels,
attributeFields, deployProperties)));
}
try {
Map rawRankProfiles = new LinkedHashMap<>();
for (Future rawFuture : futureRawRankProfiles.values()) {
RawRankProfile rawRank = rawFuture.get();
rawRankProfiles.put(rawRank.getName(), rawRank);
}
return rawRankProfiles;
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
catch (ExecutionException e) {
if (e.getCause() instanceof IllegalArgumentException iArg) throw iArg;
if (e.getCause() instanceof IllegalStateException iState) throw iState;
throw new IllegalStateException(e);
}
}
private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
Collection rankProfiles,
DeployState deployState) {
Map allFileConstants = new HashMap<>();
addFileConstants(schema != null ? schema.constants().values() : List.of(),
allFileConstants,
schema != null ? schema.toString() : "[global]");
for (var profile : rankProfiles)
addFileConstants(profile.constants(), allFileConstants, profile.toString());
return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
}
private static void addFileConstants(Collection source,
Map destination,
String sourceName) {
for (var constant : source) {
if (constant.valuePath().isEmpty()) continue;
var existing = destination.get(constant.name());
if ( existing != null && ! constant.equals(existing)) {
throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
", but we already have " + existing +
": Value reference constants must be unique across all rank profiles/models");
}
destination.put(constant.name(), constant);
}
}
private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema,
Collection rankProfiles,
DeployState deployState) {
Map allModels = new LinkedHashMap<>();
addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(),
allModels,
schema != null ? schema.toString() : "[global]");
for (var profile : rankProfiles)
addOnnxModels(profile.onnxModels(), allModels, profile.toString());
return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values());
}
private static void addOnnxModels(Collection source,
Map destination,
String sourceName) {
for (var model : source) {
var existing = destination.get(model.getName());
if ( existing != null && ! model.equals(existing)) {
throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model +
", but we already have " + existing +
": Onnx models must be unique across all rank profiles/models");
}
destination.put(model.getName(), model);
}
}
public Map getRankProfiles() { return rankProfiles; }
public FileDistributedConstants constants() { return constants; }
public FileDistributedOnnxModels getOnnxModels() { return onnxModels; }
@Override public String getDerivedName() { return "rank-profiles"; }
public void export(String toDirectory) throws IOException {
export(toDirectory, new RankProfilesConfig.Builder().rankprofile(getRankProfilesConfig()).build());
}
public List getRankProfilesConfig() {
return rankProfiles.values().stream().map(RawRankProfile::getConfig).toList();
}
private static RankingExpressionsConfig.Expression.Builder toConfig(RankingExpressionBody expr) {
return new RankingExpressionsConfig.Expression.Builder()
.name(expr.getName())
.fileref(expr.getFileReference());
}
public List getExpressionsConfig() {
return largeRankingExpressions.expressions().stream().map(RankProfileList::toConfig).toList();
}
public List getConstantsConfig() {
return constants.getConfig();
}
public List getOnnxConfig() {
return onnxModels.getConfig();
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy