All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.vespa.model.ml.OnnxModelProbe Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;

import com.yahoo.json.Jackson;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.api.OnnxMemoryStats;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * Defers to 'vespa-analyze-onnx-model' to determine the output type of an ONNX model
 * given a set of input types. Used in situations with symbolic dimension sizes that
 * can't easily be determined otherwise.
 *
 * Probed output types are cached in a file under
 * {@link ApplicationPackage#MODELS_GENERATED_REPLICATED_DIR}, keyed by output name and
 * input types, so the external binary is only invoked when no cached result exists.
 *
 * @author lesters
 */
public class OnnxModelProbe {

    private static final String binary = "vespa-analyze-onnx-model";

    /**
     * Returns the type of the given model output when the model is fed the given input types.
     *
     * A previously probed (cached) type is used if one exists. Otherwise the external binary
     * is run, provided the model file exists in the application package. This is best effort:
     * failures are not propagated, and {@link TensorType#empty} is returned if no type could
     * be determined.
     *
     * @param app the application package containing the model
     * @param modelPath path of the ONNX model within the application package
     * @param outputName name of the model output to determine the type of
     * @param inputTypes the types of the model inputs, keyed by input name
     * @return the probed output type, or {@link TensorType#empty} if it could not be determined
     */
    static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName,
                                 Map<String, TensorType> inputTypes) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(outputName, inputTypes);

        try {
            // Check if output type has already been probed
            outputType = readProbedOutputType(app, modelPath, contextKey);

            // Otherwise, run vespa-analyze-onnx-model if the model is available
            if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
                var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
                outputType = outputTypeFromJson(jsonOutput, outputName);
                writeMemoryStats(app, modelPath, OnnxMemoryStats.fromJson(jsonOutput));
                if ( ! outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(app, modelPath, contextKey, outputType);
                }
            }
        } catch (InterruptedException e) {
            // Don't swallow the interrupt: restore the flag so callers can still observe it
            Thread.currentThread().interrupt();
        } catch (IllegalArgumentException | IOException ignored) {
            // Best effort: probing is optional (e.g. the binary may not be installed),
            // so fall through and return whatever was determined so far
        }

        return outputType;
    }

    /** Writes the memory statistics reported by the analyzer next to the generated model files. */
    private static void writeMemoryStats(ApplicationPackage app, Path modelPath, OnnxMemoryStats memoryStats) throws IOException {
        String path = app.getFileReference(OnnxMemoryStats.memoryStatsFilePath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
    }

    /**
     * Creates the cache key identifying a probe. The key is the output name followed by the
     * input names and types, with inputs sorted by name so that the key does not depend on
     * map iteration order.
     *
     * Format: {@code output:input1:type1,input2:type2}. With no inputs, the key is just the
     * output name.
     */
    private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(onnxName).append(":");
        inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
                .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
        // Drop the trailing separator (',' after the last input, or ':' when there are no inputs)
        return key.substring(0, key.length() - 1);
    }

    /** Returns the path of the file caching probed output types for the given model. */
    private static Path probedOutputTypesPath(Path path) {
        String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /** Appends a probed output type for the given output and input types to the model's cache file. */
    static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output,
                                      Map<String, TensorType> inputTypes, TensorType type) throws IOException {
        writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
    }

    /** Appends a line of the form {@code contextKey<TAB>type} to the model's cache file. */
    private static void writeProbedOutputType(ApplicationPackage app, Path modelPath,
                                              String contextKey, TensorType type) throws IOException {
        String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
    }

    /**
     * Looks up a previously probed output type in the model's cache file.
     *
     * @return the cached type for the context key, or {@link TensorType#empty} if there is none
     */
    private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath,
                                                   String contextKey) throws IOException {
        ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
        if ( ! file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while (null != (line = reader.readLine())) {
                String[] parts = line.split("\t");
                // Skip malformed lines (blank, or missing the type). Indexing parts[1] would
                // otherwise throw ArrayIndexOutOfBoundsException, which probeModel does not catch.
                if (parts.length < 2) {
                    continue;
                }
                if (parts[0].equals(contextKey)) {
                    return TensorType.fromSpec(parts[1]);
                }
            }
        }
        return TensorType.empty;
    }

    /**
     * Extracts the type of the named output from the analyzer's JSON response. The response is
     * expected to be an object whose "outputs" field maps output names to tensor type specs.
     *
     * @return the output type, or {@link TensorType#empty} if the response doesn't contain it
     */
    private static TensorType outputTypeFromJson(JsonNode root, String outputName) {
        if ( ! root.isObject() || ! root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if ( ! outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    /**
     * Builds the analyzer request {@code {"model": <path>, "inputs": {<name>: <type spec>, ...}}}.
     */
    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("model", modelPath);
            g.writeObjectFieldStart("inputs");
            for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
                g.writeStringField(input.getKey(), input.getValue().toString());
            }
            g.writeEndObject();
            g.writeEndObject();
        }
        // The generator wrote UTF-8 bytes; decode them as UTF-8, not with the platform charset
        return out.toString(StandardCharsets.UTF_8);
    }

    /**
     * Runs the analyzer binary with {@code --probe-types}, feeding it the given JSON on stdin,
     * and parses its stdout as JSON.
     *
     * @throws IllegalArgumentException if the binary exits with a non-zero return code
     * @throws IOException if the process can't be started or its output can't be read or parsed
     * @throws InterruptedException if interrupted while waiting for the process to finish
     */
    private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
        processBuilder.redirectError(ProcessBuilder.Redirect.DISCARD);
        Process process = processBuilder.start();
        try {
            // Write the request to the process' stdin; closing the stream signals end of input
            try (OutputStream os = process.getOutputStream()) {
                os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
            }

            // Read all of stdout and decode it as UTF-8, so multi-byte characters survive intact
            String output;
            try (InputStream inputStream = process.getInputStream()) {
                output = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
            }

            int returnCode = process.waitFor();
            if (returnCode != 0) {
                throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " +
                                                   "Output: '" + output + "'");
            }
            return Jackson.mapper().readTree(output);
        } finally {
            // No-op if the process already exited; otherwise avoids leaving it running on failure
            process.destroy();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy