com.yahoo.vespa.model.ml.OnnxModelProbe Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
import com.yahoo.json.Jackson;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.api.OnnxMemoryStats;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
/**
* Defers to 'vespa-analyze-onnx-model' to determine the type of a model output
* given a set of input types. This is used in situations where symbolic dimension
* sizes can't easily be resolved otherwise.
*
* @author lesters
*/
public class OnnxModelProbe {

    private static final String binary = "vespa-analyze-onnx-model";

    /**
     * Determines the type of one output of an ONNX model, given the types of its inputs.
     * <p>
     * A previously probed result for the same output name and input types is looked up first.
     * If none is found and the model file exists in the application package, 'vespa-analyze-onnx-model'
     * is run to probe the type. The memory statistics it reports are written to the application
     * package, and a non-empty result is recorded so that later calls can reuse it.
     * <p>
     * Probing is best-effort: on failure, the type determined so far is returned. That is
     * {@link TensorType#empty} unless a type was obtained before the failure. If the thread is
     * interrupted while waiting for the analyzer, its interrupt status is restored.
     *
     * @param app        the application package containing the model
     * @param modelPath  path to the model file within the application package
     * @param outputName name of the model output whose type is requested
     * @param inputTypes the input types, keyed by input name
     * @return the probed output type, or {@link TensorType#empty} if it could not be determined
     */
    static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName,
                                 Map<String, TensorType> inputTypes) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(outputName, inputTypes);
        try {
            // Check if output type has already been probed
            outputType = readProbedOutputType(app, modelPath, contextKey);

            // Otherwise, run vespa-analyze-onnx-model if the model is available
            if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
                var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
                outputType = outputTypeFromJson(jsonOutput, outputName);
                writeMemoryStats(app, modelPath, OnnxMemoryStats.fromJson(jsonOutput));
                if ( ! outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(app, modelPath, contextKey, outputType);
                }
            }
        } catch (InterruptedException e) {
            // Probing is best-effort, but the interrupt must not be lost for the calling thread.
            Thread.currentThread().interrupt();
        } catch (IllegalArgumentException | IOException ignored) {
            // Best-effort: e.g. the binary is missing, fails, or produces unusable output.
        }
        return outputType;
    }

    /** Writes the memory statistics reported by the analyzer to the model's memory-stats file. */
    private static void writeMemoryStats(ApplicationPackage app, Path modelPath, OnnxMemoryStats memoryStats) throws IOException {
        String path = app.getFileReference(OnnxMemoryStats.memoryStatsFilePath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
    }

    /**
     * Builds the lookup key for a probed output type.
     * <p>
     * The key has the form {@code name:in1:type1,in2:type2,...}, with inputs sorted by name so the
     * key does not depend on map iteration order. With no inputs, the key is just the output name.
     */
    private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(onnxName).append(":");
        inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
                  .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
        // Drop the trailing "," (or the ":" when there are no inputs)
        return key.substring(0, key.length() - 1);
    }

    /** Returns the path, relative to the application package, of the probed-output-types file for a model. */
    private static Path probedOutputTypesPath(Path path) {
        String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /** Records a probed output type for the given output name and input types. */
    static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output,
                                      Map<String, TensorType> inputTypes, TensorType type) throws IOException {
        writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
    }

    /**
     * Writes a {@code contextKey<TAB>type} line to the model's probed-output-types file.
     * The {@code true} argument to IOUtils.writeFile is presumably the append flag; confirm.
     */
    private static void writeProbedOutputType(ApplicationPackage app, Path modelPath,
                                              String contextKey, TensorType type) throws IOException {
        String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
    }

    /**
     * Looks up a previously probed output type by context key.
     * Lines without a tab separator are malformed and are skipped.
     *
     * @return the recorded type, or {@link TensorType#empty} if the file or key is absent
     */
    private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath,
                                                   String contextKey) throws IOException {
        ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
        if ( ! file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while (null != (line = reader.readLine())) {
                String[] parts = line.split("\t");
                if (parts.length < 2) continue; // malformed (e.g. blank) line; indexing parts[1] would throw
                if (parts[0].equals(contextKey)) {
                    return TensorType.fromSpec(parts[1]);
                }
            }
        }
        return TensorType.empty;
    }

    /**
     * Extracts the type of the named output from the analyzer's JSON response.
     *
     * @return the output type, or {@link TensorType#empty} if the response has no such output
     * @throws IllegalArgumentException if the type spec cannot be parsed by {@link TensorType#fromSpec}
     */
    private static TensorType outputTypeFromJson(JsonNode root, String outputName) {
        if ( ! root.isObject() || ! root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if ( ! outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    /**
     * Builds the JSON request for the analyzer: {@code {"model": path, "inputs": {name: type, ...}}}.
     * The generator writes UTF-8, so the bytes are decoded as UTF-8 rather than the platform charset.
     */
    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("model", modelPath);
            g.writeObjectFieldStart("inputs");
            for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
                g.writeStringField(input.getKey(), input.getValue().toString());
            }
            g.writeEndObject();
            g.writeEndObject();
        }
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Runs 'vespa-analyze-onnx-model --probe-types', feeding it the request on stdin,
     * and parses its stdout as JSON. Stderr is discarded.
     *
     * @throws IllegalArgumentException if the process exits with a non-zero code
     * @throws IOException if the process cannot be started or communicated with
     * @throws InterruptedException if interrupted while waiting for the process to exit
     */
    private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
        processBuilder.redirectError(ProcessBuilder.Redirect.DISCARD);
        Process process = processBuilder.start();
        try {
            // Write json request to process stdin; closing signals end of input
            try (OutputStream os = process.getOutputStream()) {
                os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
            }

            // Read all of stdout and decode it as UTF-8 (a per-byte char cast would corrupt non-ASCII text)
            String output;
            try (InputStream inputStream = process.getInputStream()) {
                output = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
            }

            int returnCode = process.waitFor();
            if (returnCode != 0) {
                throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " +
                                                   "Output: '" + output + "'");
            }
            return Jackson.mapper().readTree(output);
        } finally {
            // Don't leave the analyzer running if we bailed out early (I/O error, interrupt)
            if (process.isAlive()) {
                process.destroyForcibly();
            }
        }
    }

}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy