All downloads are FREE. The search and download functionality uses the official Maven repository.

org.mlflow.LoaderModule Maven / Gradle / Ivy

There is a newer version: 2.16.2
Show newest version
package org.mlflow;

import java.io.IOException;
import java.util.Optional;
import org.mlflow.models.Model;
import org.mlflow.sagemaker.Predictor;
import org.mlflow.sagemaker.PredictorLoadingException;

/**
 * A generic loader for encapsulating flavor-specific model deserialization logic. By extending
 * {@link LoaderModule}, models of a specific flavor can be loaded as generic {@link Predictor}
 * objects. This allows tools, such as model containers, to use the models for inference.
 *
 * @param <T> The {@link org.mlflow.Flavor} configuration type that this loader module consumes
 */
public abstract class LoaderModule<T extends Flavor> {
  /**
   * Loads an MLflow model as a generic predictor that can be used for inference.
   *
   * @param model The MLflow model to load
   * @return A {@link Predictor} created from the model's root path and flavor configuration
   * @throws PredictorLoadingException if the model does not contain the flavor named by
   *     {@link #getFlavorName()}, if the model's root path is absent, or if
   *     {@link #createPredictor(String, Flavor)} fails
   */
  public Predictor load(Model model) {
    // The flavor is validated before the root path so that a model lacking the flavor is
    // always reported as such, regardless of how the model object was constructed.
    T flavor =
        model
            .getFlavor(getFlavorName(), getFlavorClass())
            .orElseThrow(
                () ->
                    new PredictorLoadingException(
                        String.format(
                            "Attempted to load the %s flavor of the model,"
                                + " but the model does not have this flavor.",
                            getFlavorName())));

    Optional<String> rootPath = model.getRootPath();
    String modelRootPath =
        rootPath.orElseThrow(
            () ->
                new PredictorLoadingException(
                    "An internal error occurred while loading the model:"
                        + " the model's root path could not be found. Please ensure that this"
                        + " model was created using `Model.fromRootPath()` or"
                        + " `Model.fromConfigPath()`"));

    return createPredictor(modelRootPath, flavor);
  }

  /**
   * Loads an MLflow model as a generic predictor that can be used for inference.
   *
   * @param modelRootPath The path to the root directory of the MLflow model
   * @return A {@link Predictor} created from the model found at {@code modelRootPath}
   * @throws PredictorLoadingException for any failure encountered while attempting to load the
   *     model, including an {@link IOException} raised while reading the model configuration
   */
  public Predictor load(String modelRootPath) throws PredictorLoadingException {
    try {
      return load(Model.fromRootPath(modelRootPath));
    } catch (IOException e) {
      // Keep the IOException as the cause so the underlying I/O failure stays diagnosable.
      throw new PredictorLoadingException(
          "Failed to load the model configuration at the specified path. Please ensure that"
              + " this is the path to the root directory of a valid MLflow model", e);
    }
  }

  /**
   * Creates a {@link Predictor} from an MLflow model using the specified flavor configuration.
   *
   * <p>Implementations of this method are expected to throw a {@link PredictorLoadingException}
   * when errors are encountered while loading the model.
   *
   * @param modelRootPath The path to the root directory of the MLflow model
   * @param flavor The flavor configuration to use when creating the {@link Predictor}. This
   *     configuration provides additional metadata that may be necessary for {@link Predictor}
   *     creation.
   * @return The {@link Predictor} created for the model
   * @throws PredictorLoadingException when errors are encountered while loading the model
   */
  protected abstract Predictor createPredictor(String modelRootPath, T flavor)
      throws PredictorLoadingException;

  /**
   * @return The {@link org.mlflow.Flavor} class associated with this loader module. This is
   *     required during the {@link #load(Model)} procedure
   */
  protected abstract Class<T> getFlavorClass();

  /**
   * @return The name of the flavor associated with this loader module. This is required during
   *     the {@link #load(Model)} procedure
   */
  protected abstract String getFlavorName();
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy