All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.schema.derived.FileDistributedOnnxModels Maven / Gradle / Ivy

There is a newer version: 8.441.21
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.derived;

import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.schema.OnnxModel;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * ONNX models distributed as files.
 *
 * <p>Holds an unmodifiable, insertion-ordered map from model name to {@link OnnxModel},
 * and can render those models as {@link OnnxModelsConfig.Model.Builder} entries.</p>
 *
 * @author bratseth
 */
public class FileDistributedOnnxModels {

    private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());

    /** The models of this, keyed by model name, in insertion order. Unmodifiable. */
    private final Map<String, OnnxModel> models;

    /**
     * Creates an instance from the given models. Each model is validated and then registered
     * with the given file registry before being stored under its name.
     * If several models share a name, the last one in iteration order is the one kept.
     *
     * @param fileRegistry the registry each model is registered with
     * @param models the models to validate, register and hold
     */
    public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
        Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
        for (OnnxModel model : models) {
            model.validate();
            model.register(fileRegistry);
            distributableModels.put(model.getName(), model);
        }
        this.models = Collections.unmodifiableMap(distributableModels);
    }

    /**
     * Copy constructor used by {@link #clone()}: stores a clone of each given model under its name.
     * Unlike the public constructor, this neither validates nor registers the models.
     *
     * @param models the models to clone into the new instance
     */
    private FileDistributedOnnxModels(Collection<OnnxModel> models) {
        Map<String, OnnxModel> distributableModels = models.stream()
                .collect(LinkedHashMap::new, (m, v) -> m.put(v.getName(), v.clone()), LinkedHashMap::putAll);
        this.models = Collections.unmodifiableMap(distributableModels);
    }

    /**
     * Returns a new instance holding a clone of each model in this, in the same order.
     *
     * @return a copy of this whose models are clones of the models in this
     */
    @Override
    public FileDistributedOnnxModels clone() {
        return new FileDistributedOnnxModels(models.values());
    }

    /**
     * Returns the models of this, keyed by model name.
     *
     * @return an unmodifiable map in insertion order
     */
    public Map<String, OnnxModel> asMap() { return models; }

    /**
     * Builds the config entry for a single model: name, file reference, input and output mappings,
     * and whichever optional stateless-execution and GPU settings the model has set.
     * {@code dry_run_on_setup} is always set to true.
     *
     * @param model the model to convert
     * @return a config builder populated from the model
     */
    private static OnnxModelsConfig.Model.Builder toConfig(OnnxModel model) {
        OnnxModelsConfig.Model.Builder builder = new OnnxModelsConfig.Model.Builder();
        builder.dry_run_on_setup(true);
        builder.name(model.getName());
        builder.fileref(model.getFileReference());
        model.getInputMap().forEach((name, source) -> builder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
        model.getOutputMap().forEach((name, as) -> builder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
        // Optional settings are only written when present, leaving the config defaults otherwise
        model.getStatelessExecutionMode().ifPresent(mode -> builder.stateless_execution_mode(mode));
        model.getStatelessInterOpThreads().ifPresent(threads -> builder.stateless_interop_threads(threads));
        model.getStatelessIntraOpThreads().ifPresent(threads -> builder.stateless_intraop_threads(threads));
        model.getGpuDevice().ifPresent(gpuDevice -> {
            builder.gpu_device(gpuDevice.deviceNumber());
            builder.gpu_device_required(gpuDevice.required());
        });
        return builder;
    }

    /**
     * Returns one config builder per model of this, in map order.
     * Models whose file reference is the empty string are left out, and a warning is logged for each.
     *
     * @return a new, mutable list of config builders
     */
    public List<OnnxModelsConfig.Model.Builder> getConfig() {
        List<OnnxModelsConfig.Model.Builder> cfgList = new ArrayList<>();
        for (OnnxModel model : models.values()) {
            if ("".equals(model.getFileReference()))
                log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
            else {
                cfgList.add(toConfig(model));
            }
        }
        return cfgList;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy