All downloads are free. The search and download functionality uses the official Maven repository.

com.yahoo.vespa.model.VespaModelFactory Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model;

import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.component.Version;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.ConfigModelRegistry;
import com.yahoo.config.model.MapConfigModelRegistry;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.ConfigModelPlugin;
import com.yahoo.config.model.api.Model;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.api.ModelCreateResult;
import com.yahoo.config.model.api.ModelFactory;
import com.yahoo.config.model.api.ValidationParameters;
import com.yahoo.config.model.application.provider.ApplicationPackageXmlFilesValidator;
import com.yahoo.config.model.builder.xml.ConfigModelBuilder;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.provision.QuotaExceededException;
import com.yahoo.config.provision.TransientException;
import com.yahoo.config.provision.Zone;
import com.yahoo.vespa.config.VespaVersion;
import com.yahoo.vespa.model.application.validation.Validation;
import com.yahoo.vespa.model.application.validation.Validator;
import org.xml.sax.SAXException;

import java.io.IOException;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Factory for creating {@link VespaModel} instances.
 *
 * <p>A factory instance is bound to a single {@link Version}, a {@link Zone} and a {@link Clock},
 * and carries the config model registry, model importers and additional validators that are
 * handed to every model it builds.
 *
 * @author Ulf Lilleengen
 */
public class VespaModelFactory implements ModelFactory {

    private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName());

    private final ConfigModelRegistry configModelRegistry;
    private final Collection<MlModelImporter> modelImporters;
    private final Zone zone;
    private final Clock clock;
    private final Version version;
    private final List<Validator> additionalValidators;

    /**
     * Creates a factory for Vespa models for this version of the source.
     *
     * @param pluginRegistry       registry of config model plugins; only the plugins that are
     *                             {@link ConfigModelBuilder}s are used
     * @param additionalValidators extra validators, copied and passed to {@link Validation}
     * @param zone                 the zone passed to the deploy state of every model built
     */
    @Inject
    public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
                             ComponentRegistry<Validator> additionalValidators,
                             Zone zone) {
        this.version = currentVersion();
        List<ConfigModelBuilder<?>> modelBuilders = new ArrayList<>();
        for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) {
            if (plugin instanceof ConfigModelBuilder<?> builder) {
                modelBuilders.add(builder);
            }
        }
        this.configModelRegistry = new MapConfigModelRegistry(modelBuilders);
        this.modelImporters = List.of(
                new VespaImporter(),
                new OnnxImporter(),
                new TensorFlowImporter(),
                new XGBoostImporter(),
                new LightGBMImporter());
        this.zone = zone;
        this.additionalValidators = List.copyOf(additionalValidators.allComponents());
        this.clock = Clock.systemUTC();
    }

    /**
     * For testing only: creates a factory for the current Vespa version using the system UTC clock
     * and the default zone, with no model importers and no additional validators.
     *
     * @param configModelRegistry the registry to use, or null to use a {@link NullConfigModelRegistry}
     */
    protected VespaModelFactory(ConfigModelRegistry configModelRegistry) {
        this(currentVersion(), configModelRegistry, Clock.systemUTC(), Zone.defaultZone());
    }

    /**
     * Creates a factory with explicit version, registry, clock and zone, and with no model importers
     * and no additional validators.
     *
     * @param configModelRegistry the registry to use, or null to use a {@link NullConfigModelRegistry}
     */
    private VespaModelFactory(Version version, ConfigModelRegistry configModelRegistry, Clock clock, Zone zone) {
        this.version = version;
        if (configModelRegistry == null) {
            this.configModelRegistry = new NullConfigModelRegistry();
            log.info("Will not load config models from plugins, as no registry is available");
        } else {
            this.configModelRegistry = configModelRegistry;
        }
        this.modelImporters = List.of();
        this.additionalValidators = List.of();
        this.zone = zone;
        this.clock = clock;
    }

    /** Creates a test factory with a {@link NullConfigModelRegistry} and the system UTC clock. */
    public static VespaModelFactory createTestFactory() {
        return createTestFactory(new NullConfigModelRegistry(), Clock.systemUTC());
    }

    /** Creates a test factory for the current Vespa version in the default zone. */
    public static VespaModelFactory createTestFactory(ConfigModelRegistry configModelRegistry, Clock clock) {
        return createTestFactory(currentVersion(), configModelRegistry, clock, Zone.defaultZone());
    }

    /** Creates a test factory with every setting given explicitly. */
    public static VespaModelFactory createTestFactory(Version version, ConfigModelRegistry configModelRegistry, Clock clock, Zone zone) {
        return new VespaModelFactory(version, configModelRegistry, clock, zone);
    }

    /** Returns the version made up of the major, minor and micro constants in {@link VespaVersion}. */
    private static Version currentVersion() {
        return new Version(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
    }

    /** Returns the version this model is built for */
    @Override
    public Version version() { return version; }

    /**
     * Builds a model from the given context with validation errors ignored.
     * Unlike {@link #createAndValidateModel}, this runs neither the XML validation
     * nor the model validation step of this factory.
     */
    @Override
    public Model createModel(ModelContext modelContext) {
        return buildModel(createDeployState(modelContext, new ValidationParameters(ValidationParameters.IgnoreValidationErrors.TRUE)));
    }

    /**
     * Logs one INFO line for each change action of type {@code REINDEX}, naming the version and
     * application package metadata of both the currently active model and the next model.
     * Does nothing when there is no currently active model.
     */
    private void logReindexingReasons(List<ConfigChangeAction> changeActions,
                                      VespaModel nextModel,
                                      Optional<? extends Model> currentActiveModel)
    {
        if (currentActiveModel.isEmpty()) {
            return;
        }
        for (ConfigChangeAction action : changeActions) {
            if (action.getType() == ConfigChangeAction.Type.REINDEX) {
                // The cast is only attempted once a reindex action is actually found
                VespaModel currentModel = (VespaModel) currentActiveModel.get();
                var currentMeta = currentModel.applicationPackage().getMetaData();
                var nextMeta = nextModel.applicationPackage().getMetaData();
                log.log(Level.INFO, String.format("Model [%s/%s] -> [%s/%s] triggers reindexing: %s",
                                                  currentModel.version(), currentMeta,
                                                  nextModel.version(), nextMeta,
                                                  action));
            }
        }
    }

    /**
     * Validates the application package XML, builds the model, validates the model and logs any
     * reindexing the change triggers.
     *
     * @return the built model together with the change actions found by validation; the action list
     *         is empty if validation failed and validation errors are ignored
     */
    @Override
    public ModelCreateResult createAndValidateModel(ModelContext modelContext, ValidationParameters validationParameters) {
        validateXml(modelContext, validationParameters.ignoreValidationErrors());
        DeployState deployState = createDeployState(modelContext, validationParameters);
        VespaModel model = buildModel(deployState);
        List<ConfigChangeAction> changeActions = validateModel(model, deployState, validationParameters);
        logReindexingReasons(changeActions, model, deployState.getPreviousModel());
        return new ModelCreateResult(model, changeActions);
    }

    /**
     * Validates the XML files of the application. When the context has an application directory,
     * the files are checked with an {@link ApplicationPackageXmlFilesValidator}; otherwise the
     * application package validates itself.
     *
     * @throws IllegalArgumentException if validation fails and errors are not ignored
     * @throws RuntimeException wrapping any other exception thrown during validation
     */
    private void validateXml(ModelContext modelContext, boolean ignoreValidationErrors) {
        if (modelContext.appDir().isPresent()) {
            ApplicationPackageXmlFilesValidator validator =
                    ApplicationPackageXmlFilesValidator.create(modelContext.appDir().get(),
                                                               modelContext.modelVespaVersion());
            try {
                validator.checkApplication();
                validator.checkIncludedDirs(modelContext.applicationPackage());
            } catch (IllegalArgumentException e) {
                rethrowUnlessIgnoreErrors(e, ignoreValidationErrors);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else {
            validateXML(modelContext.applicationPackage(), ignoreValidationErrors);
        }
    }

    /**
     * Builds a model from the given deploy state using this factory's config model registry.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} or {@link SAXException}
     *                                  thrown while building
     */
    private VespaModel buildModel(DeployState deployState) {
        try {
            return new VespaModel(configModelRegistry, deployState);
        } catch (IOException | SAXException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Assembles the deploy state from the model context plus this factory's version, zone,
     * model importers and the current instant of its clock. The previous model and reindexing
     * status are only set when the context provides them.
     */
    private DeployState createDeployState(ModelContext modelContext, ValidationParameters validationParameters) {
        DeployState.Builder builder = new DeployState.Builder()
            .applicationPackage(modelContext.applicationPackage())
            .deployLogger(modelContext.deployLogger())
            .configDefinitionRepo(modelContext.configDefinitionRepo())
            .fileRegistry(modelContext.getFileRegistry())
            .executor(modelContext.getExecutor())
            .properties(modelContext.properties())
            .vespaVersion(version())
            .modelHostProvisioner(modelContext.getHostProvisioner())
            .provisioned(modelContext.provisioned())
            .endpoints(modelContext.properties().endpoints())
            .modelImporters(modelImporters)
            .zone(zone)
            .now(clock.instant())
            .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion())
            .wantedDockerImageRepo(modelContext.wantedDockerImageRepo())
            .onnxModelCost(modelContext.onnxModelCost());
        modelContext.previousModel().ifPresent(builder::previousModel);
        modelContext.reindexing().ifPresent(builder::reindexing);
        return builder.build(validationParameters);
    }

    /**
     * Asks the application package to validate its own XML.
     *
     * @throws IllegalArgumentException if validation fails and errors are not ignored
     * @throws RuntimeException wrapping any other exception thrown during validation
     */
    private void validateXML(ApplicationPackage applicationPackage, boolean ignoreValidationErrors) {
        try {
            applicationPackage.validateXML();
        } catch (IllegalArgumentException e) {
            rethrowUnlessIgnoreErrors(e, ignoreValidationErrors);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs {@link Validation}, including the additional validators of this factory, on the model.
     *
     * @return the change actions produced by validation, or a new empty list if validation threw
     *         an {@link IllegalArgumentException}, {@link TransientException} or
     *         {@link QuotaExceededException} and validation errors are ignored
     * @throws RuntimeException wrapping any other exception thrown during validation
     */
    private List<ConfigChangeAction> validateModel(VespaModel model, DeployState deployState, ValidationParameters validationParameters) {
        try {
            return new Validation(additionalValidators).validate(model, validationParameters, deployState);
        } catch (IllegalArgumentException | TransientException | QuotaExceededException e) {
            rethrowUnlessIgnoreErrors(e, validationParameters.ignoreValidationErrors());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        // Only reached when a validation error was caught and deliberately ignored
        return new ArrayList<>();
    }

    /** Rethrows the given exception unless validation errors are to be ignored. */
    private static void rethrowUnlessIgnoreErrors(RuntimeException e, boolean ignoreValidationErrors) {
        if (!ignoreValidationErrors) {
            throw e;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy