com.yahoo.vespa.model.container.ml.ModelsEvaluatorTester Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.ml;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.config.model.application.provider.MockFileRegistry;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.io.IOUtils;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.VespaModel;
import net.jpountz.lz4.LZ4FrameOutputStream;
import org.xml.sax.SAXException;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * A ModelsEvaluator object is usually injected automatically in a component if
 * requested. This class is for creating a ModelsEvaluator so that the component
 * can be properly unit tested. Pass a directory containing model files, such
 * as the application's "models" directory, and it will return a ModelsEvaluator
 * for the imported models.
 *
 * For use in testing only.
 *
 * @author lesters
 */
public class ModelsEvaluatorTester {

    /** Importers for the model formats this tester can load, passed to the deploy state. */
    private static final List<MlModelImporter> importers = List.of(new TensorFlowImporter(),
                                                                   new OnnxImporter(),
                                                                   new LightGBMImporter(),
                                                                   new XGBoostImporter(),
                                                                   new VespaImporter());

    /** A minimal services.xml: a single container with model evaluation enabled. */
    private static final String modelEvaluationServices = "" +
            "<services version='1.0'>" +
            "  <container version='1.0'>" +
            "    <model-evaluation/>" +
            "  </container>" +
            "</services>";

    /**
     * Create a ModelsEvaluator from the models found in the modelsPath. Does
     * not need to be in an application package.
     *
     * The models are copied into a temporary application package, which is built into a
     * VespaModel to derive the ranking configuration. The temporary directory is deleted
     * again before this method returns.
     * NOTE(review): this relies on the ModelsEvaluator constructor reading every file it
     * needs through the FileAcquirer eagerly, since the files are gone afterwards — confirm.
     *
     * @param modelsPath Path to a directory containing models to import
     * @return a ModelsEvaluator containing the imported models
     * @throws IllegalArgumentException if modelsPath is not a directory, or if the models
     *                                  could not be imported
     */
    public static ModelsEvaluator create(String modelsPath) {
        // Validate before creating anything on disk, so a bad path leaves no temp directory behind
        if ( ! new File(modelsPath).isDirectory())
            throw new IllegalArgumentException("Models path '" + modelsPath + "' is not a directory");

        File temporaryApplicationDir = null;
        try {
            temporaryApplicationDir = createTemporaryApplicationDir(modelsPath);
            MockFileRegistry fileRegistry = new MockFileBlobRegistry(temporaryApplicationDir);
            RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir, fileRegistry);
            RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList);
            RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList);
            RankingExpressionsConfig rankingExpressionsConfig = getRankingExpressionsConfig(rankProfileList);
            OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList);
            FileAcquirer files = createFileAcquirer(fileRegistry, temporaryApplicationDir);
            return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig, files);
        } catch (IOException | SAXException e) {
            throw new IllegalArgumentException("Could not create a models evaluator from '" + modelsPath + "'", e);
        } finally {
            if (temporaryApplicationDir != null) {
                IOUtils.recursiveDeleteDir(temporaryApplicationDir);
            }
        }
    }

    /**
     * Creates a temporary application directory and copies the contents of modelsPath into
     * its models directory.
     *
     * The directory is created under ./target when that directory exists, otherwise under the
     * current working directory. If copying fails, the directory is deleted before the
     * exception propagates, because the caller never receives it and cannot clean it up.
     */
    private static File createTemporaryApplicationDir(String modelsPath) throws IOException {
        Path target = Path.of("target");
        Path parent = Files.isDirectory(target) ? target : Path.of("");
        File temporaryApplicationDir = Files.createTempDirectory(parent, "tmp_").toFile();
        try {
            File modelsDir = relativePath(temporaryApplicationDir, ApplicationPackage.MODELS_DIR.toString());
            IOUtils.copyDirectory(new File(modelsPath), modelsDir);
            return temporaryApplicationDir;
        } catch (IOException | RuntimeException e) {
            IOUtils.recursiveDeleteDir(temporaryApplicationDir);
            throw e;
        }
    }

    /**
     * Builds a VespaModel for a mock application rooted at appDir, using the model-evaluation
     * services definition, and returns its rank profiles.
     */
    private static RankProfileList createRankProfileList(File appDir, FileRegistry registry) throws IOException, SAXException {
        ApplicationPackage app = new MockApplicationPackage.Builder()
                .withEmptyHosts()
                .withServices(modelEvaluationServices)
                .withRoot(appDir).build();
        DeployState deployState = new DeployState.Builder()
                .applicationPackage(app)
                .fileRegistry(registry)
                .modelImporters(importers).build();
        VespaModel vespaModel = new VespaModel(deployState);
        return vespaModel.rankProfileList();
    }

    /** Wraps the rank profile config entries of rankProfileList in a RankProfilesConfig. */
    private static RankProfilesConfig getRankProfilesConfig(RankProfileList rankProfileList) {
        return new RankProfilesConfig.Builder()
                .rankprofile(rankProfileList.getRankProfilesConfig())
                .build();
    }

    /** Wraps the constant config entries of rankProfileList in a RankingConstantsConfig. */
    private static RankingConstantsConfig getRankingConstantConfig(RankProfileList rankProfileList) {
        return new RankingConstantsConfig.Builder()
                .constant(rankProfileList.getConstantsConfig())
                .build();
    }

    /** Wraps the expression config entries of rankProfileList in a RankingExpressionsConfig. */
    private static RankingExpressionsConfig getRankingExpressionsConfig(RankProfileList rankProfileList) {
        return new RankingExpressionsConfig.Builder()
                .expression(rankProfileList.getExpressionsConfig())
                .build();
    }

    /** Wraps the ONNX model config entries of rankProfileList in an OnnxModelsConfig. */
    private static OnnxModelsConfig getOnnxModelsConfig(RankProfileList rankProfileList) {
        return new OnnxModelsConfig.Builder()
                .model(rankProfileList.getOnnxConfig())
                .build();
    }

    /**
     * Creates a FileAcquirer which resolves each file reference exported by fileRegistry to
     * the file with the same relative path under appDir.
     */
    private static FileAcquirer createFileAcquirer(MockFileRegistry fileRegistry, File appDir) {
        Map<String, File> fileMap = new HashMap<>();
        for (FileRegistry.Entry entry : fileRegistry.export()) {
            fileMap.put(entry.reference.value(), relativePath(appDir, entry.reference.value()));
        }
        return MockFileAcquirer.returnFiles(fileMap);
    }

    /** Returns the file at subpath below the absolute path of root. */
    private static File relativePath(File root, String subpath) {
        return new File(root.getAbsolutePath() + File.separator + subpath);
    }

    /**
     * A MockFileRegistry which also writes every added blob to a file of the same name in the
     * application directory, so that blob references resolve to real files on disk.
     * Blobs whose name ends in ".lz4" are written LZ4 frame compressed.
     */
    private static class MockFileBlobRegistry extends MockFileRegistry {

        private final File appDir;

        MockFileBlobRegistry(File appdir) {
            this.appDir = appdir;
        }

        @Override
        public FileReference addBlob(String name, ByteBuffer blob) {
            writeBlob(blob, name);
            return addFile(name);
        }

        /**
         * Writes the remaining bytes of blob (from its position to its limit) to relativePath
         * under the application directory, creating parent directories as needed.
         * The blob's position is not modified.
         *
         * @throws IllegalArgumentException if the file cannot be written
         */
        private void writeBlob(ByteBuffer blob, String relativePath) {
            // Read through a duplicate: this honours a non-zero position, works for direct and
            // read-only buffers (which have no accessible array), and leaves the caller's buffer untouched
            byte[] bytes = new byte[blob.remaining()];
            blob.duplicate().get(bytes);

            File file = new File(appDir, relativePath);
            try {
                Files.createDirectories(file.getParentFile().toPath());
                try (FileOutputStream fos = new FileOutputStream(file)) {
                    if (relativePath.endsWith(".lz4")) {
                        // Closing the LZ4 stream writes the frame end mark; do it even if write fails
                        try (LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(fos)) {
                            lz4.write(bytes);
                        }
                    } else {
                        fos.write(bytes);
                    }
                }
            } catch (IOException e) {
                throw new IllegalArgumentException("Failed writing temp file", e);
            }
        }
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy