com.yahoo.vespa.model.VespaModelFactory Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.component.Version;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.ConfigModelRegistry;
import com.yahoo.config.model.MapConfigModelRegistry;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.ConfigModelPlugin;
import com.yahoo.config.model.api.Model;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.api.ModelCreateResult;
import com.yahoo.config.model.api.ModelFactory;
import com.yahoo.config.model.api.ValidationParameters;
import com.yahoo.config.model.application.provider.ApplicationPackageXmlFilesValidator;
import com.yahoo.config.model.builder.xml.ConfigModelBuilder;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.provision.QuotaExceededException;
import com.yahoo.config.provision.TransientException;
import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.config.VespaVersion;
import com.yahoo.vespa.model.application.validation.Validation;
import com.yahoo.vespa.model.application.validation.Validator;
import org.xml.sax.SAXException;
import java.io.IOException;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Factory for creating {@link VespaModel} instances.
 *
 * @author Ulf Lilleengen
 */
public class VespaModelFactory implements ModelFactory {

    private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName());

    private final ConfigModelRegistry configModelRegistry;
    private final Collection<MlModelImporter> modelImporters;
    private final Zone zone;
    private final Clock clock;
    private final Version version;
    private final List<Validator> additionalValidators;

    /**
     * Creates a factory for Vespa models for this version of the source.
     *
     * @param pluginRegistry       registered config model plugins; only those which are
     *                             {@link ConfigModelBuilder}s are used to build config models
     * @param additionalValidators extra validators handed to {@link Validation} when validating created models
     * @param zone                 the zone set on every {@link DeployState} this factory creates
     */
    @Inject
    public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
                             ComponentRegistry<Validator> additionalValidators,
                             Zone zone) {
        this.version = currentVersion();
        List<ConfigModelBuilder<?>> modelBuilders = new ArrayList<>();
        for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) {
            if (plugin instanceof ConfigModelBuilder<?> builder) {
                modelBuilders.add(builder);
            }
        }
        this.configModelRegistry = new MapConfigModelRegistry(modelBuilders);
        this.modelImporters = List.of(
                new VespaImporter(),
                new OnnxImporter(),
                new TensorFlowImporter(),
                new XGBoostImporter(),
                new LightGBMImporter());
        this.zone = zone;
        this.additionalValidators = List.copyOf(additionalValidators.allComponents());
        this.clock = Clock.systemUTC();
    }

    /**
     * For testing only: uses the current version, the system UTC clock and the default zone.
     * No ML model importers or additional validators are registered.
     */
    protected VespaModelFactory(ConfigModelRegistry configModelRegistry) {
        this(currentVersion(), configModelRegistry, Clock.systemUTC(), Zone.defaultZone());
    }

    /**
     * Creates a factory with no ML model importers and no additional validators.
     * A null {@code configModelRegistry} is replaced by a {@link NullConfigModelRegistry}.
     */
    private VespaModelFactory(Version version, ConfigModelRegistry configModelRegistry, Clock clock, Zone zone) {
        this.version = version;
        if (configModelRegistry == null) {
            this.configModelRegistry = new NullConfigModelRegistry();
            log.info("Will not load config models from plugins, as no registry is available");
        } else {
            this.configModelRegistry = configModelRegistry;
        }
        this.modelImporters = List.of();
        this.additionalValidators = List.of();
        this.zone = zone;
        this.clock = clock;
    }

    /** Creates a test factory with an empty config model registry and the system UTC clock. */
    public static VespaModelFactory createTestFactory() {
        return createTestFactory(new NullConfigModelRegistry(), Clock.systemUTC());
    }

    /** Creates a test factory for the current version and the default zone. */
    public static VespaModelFactory createTestFactory(ConfigModelRegistry configModelRegistry, Clock clock) {
        return createTestFactory(currentVersion(), configModelRegistry, clock, Zone.defaultZone());
    }

    /** Creates a test factory with all settings given explicitly. */
    public static VespaModelFactory createTestFactory(Version version, ConfigModelRegistry configModelRegistry, Clock clock, Zone zone) {
        return new VespaModelFactory(version, configModelRegistry, clock, zone);
    }

    /** Returns the version given by the {@link VespaVersion} constants. */
    private static Version currentVersion() {
        return new Version(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
    }

    /** Returns the version models created by this factory are built for */
    @Override
    public Version version() { return version; }

    /**
     * Builds a model without validating the application package XML or the built model.
     * The deploy state is created with validation errors ignored.
     */
    @Override
    public Model createModel(ModelContext modelContext) {
        return buildModel(createDeployState(modelContext, new ValidationParameters(ValidationParameters.IgnoreValidationErrors.TRUE)));
    }

    /**
     * Logs one INFO line per {@link ConfigChangeAction.Type#REINDEX} action. Each line names the version
     * and application package metadata of both the currently active model and the next model.
     *
     * @param changeActions      change actions returned from validating {@code nextModel}
     * @param nextModel          the model being created
     * @param currentActiveModel the previous model, if any; nothing is logged when it is absent or
     *                           is not a {@link VespaModel}
     */
    private void logReindexingReasons(List<ConfigChangeAction> changeActions,
                                      VespaModel nextModel,
                                      Optional<Model> currentActiveModel) {
        // A previous model that is not a VespaModel cannot be inspected here. Skip it rather than
        // fail the deployment with a ClassCastException from what is only a logging step.
        if (currentActiveModel.isEmpty() || !(currentActiveModel.get() instanceof VespaModel currentModel)) {
            return;
        }
        List<ConfigChangeAction> reindexActions = changeActions.stream()
                .filter(action -> action.getType().equals(ConfigChangeAction.Type.REINDEX))
                .toList();
        if (reindexActions.isEmpty()) {
            return;
        }
        // Read the metadata once, and only when there is something to log.
        var currentMeta = currentModel.applicationPackage().getMetaData();
        var nextMeta = nextModel.applicationPackage().getMetaData();
        for (ConfigChangeAction action : reindexActions) {
            log.log(Level.INFO, String.format("Model [%s/%s] -> [%s/%s] triggers reindexing: %s",
                                              currentModel.version(), currentMeta,
                                              nextModel.version(), nextMeta,
                                              action));
        }
    }

    /**
     * Validates the application package XML, builds the model and validates it. Also logs any reindexing
     * the resulting change actions trigger relative to the previous model.
     *
     * @return the created model together with the change actions returned by validation; the list is empty
     *         when validation failed and validation errors are ignored
     */
    @Override
    public ModelCreateResult createAndValidateModel(ModelContext modelContext, ValidationParameters validationParameters) {
        validateXml(modelContext, validationParameters.ignoreValidationErrors());
        DeployState deployState = createDeployState(modelContext, validationParameters);
        VespaModel model = buildModel(deployState);
        List<ConfigChangeAction> changeActions = validateModel(model, deployState, validationParameters);
        logReindexingReasons(changeActions, model, deployState.getPreviousModel());
        return new ModelCreateResult(model, changeActions);
    }

    /**
     * Validates the XML of the application, using the application directory when the context has one.
     * Otherwise the application package validates its own XML.
     * An {@link IllegalArgumentException} is rethrown unless validation errors are ignored.
     * Any other exception from the checks is wrapped in a {@link RuntimeException}.
     */
    private void validateXml(ModelContext modelContext, boolean ignoreValidationErrors) {
        if (modelContext.appDir().isPresent()) {
            // Created outside the try so failures here propagate unchanged, as before.
            ApplicationPackageXmlFilesValidator validator =
                    ApplicationPackageXmlFilesValidator.create(modelContext.appDir().get(),
                                                               modelContext.modelVespaVersion());
            try {
                validator.checkApplication();
                validator.checkIncludedDirs(modelContext.applicationPackage());
            } catch (IllegalArgumentException e) {
                rethrowUnlessIgnoreErrors(e, ignoreValidationErrors);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else {
            validateApplicationPackageXml(modelContext.applicationPackage(), ignoreValidationErrors);
        }
    }

    /**
     * Builds a model from the given deploy state.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} or {@link SAXException} from model building
     */
    private VespaModel buildModel(DeployState deployState) {
        try {
            return new VespaModel(configModelRegistry, deployState);
        } catch (IOException | SAXException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Creates the deploy state for a model context. It combines the context's data with this factory's
     * version, model importers, zone and clock.
     */
    private DeployState createDeployState(ModelContext modelContext, ValidationParameters validationParameters) {
        DeployState.Builder builder = new DeployState.Builder()
                .applicationPackage(modelContext.applicationPackage())
                .deployLogger(modelContext.deployLogger())
                .configDefinitionRepo(modelContext.configDefinitionRepo())
                .fileRegistry(modelContext.getFileRegistry())
                .executor(modelContext.getExecutor())
                .properties(modelContext.properties())
                .vespaVersion(version())
                .modelHostProvisioner(modelContext.getHostProvisioner())
                .provisioned(modelContext.provisioned())
                .endpoints(modelContext.properties().endpoints())
                .modelImporters(modelImporters)
                .zone(zone)
                .now(clock.instant())
                .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion())
                .wantedDockerImageRepo(modelContext.wantedDockerImageRepo())
                .onnxModelCost(modelContext.onnxModelCost());
        modelContext.previousModel().ifPresent(builder::previousModel);
        modelContext.reindexing().ifPresent(builder::reindexing);
        return builder.build(validationParameters);
    }

    /**
     * Lets the application package validate its own XML.
     * An {@link IllegalArgumentException} is rethrown unless validation errors are ignored.
     * Any other exception is wrapped in a {@link RuntimeException}.
     */
    private void validateApplicationPackageXml(ApplicationPackage applicationPackage, boolean ignoreValidationErrors) {
        try {
            applicationPackage.validateXML();
        } catch (IllegalArgumentException e) {
            rethrowUnlessIgnoreErrors(e, ignoreValidationErrors);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs model validation, including the additional validators.
     *
     * @return the change actions from validation; an empty list when validation failed with an
     *         {@link IllegalArgumentException}, {@link TransientException} or {@link QuotaExceededException}
     *         and validation errors are ignored
     */
    private List<ConfigChangeAction> validateModel(VespaModel model, DeployState deployState, ValidationParameters validationParameters) {
        try {
            return new Validation(additionalValidators).validate(model, validationParameters, deployState);
        } catch (IllegalArgumentException | TransientException | QuotaExceededException e) {
            rethrowUnlessIgnoreErrors(e, validationParameters.ignoreValidationErrors());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return new ArrayList<>();
    }

    /** Rethrows {@code e} unless validation errors are to be ignored, in which case it is dropped. */
    private static void rethrowUnlessIgnoreErrors(RuntimeException e, boolean ignoreValidationErrors) {
        if (!ignoreValidationErrors) {
            throw e;
        }
    }

}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy