templates.machine-learning-mlflow.inference.api.py.vm

"""
Implementation of the inference API.

GENERATED STUB CODE - PLEASE ***DO*** MODIFY

Originally generated from: ${templateName} 
"""

import logging
import os
import sys
from typing import List

import mlflow
from fastapi import FastAPI, BackgroundTasks, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

from aiopsauth.json_web_token_util import JsonWebTokenUtil, AiopsSecurityException
from config.inference_config import InferenceConfig
from validation.request import RequestBody
from validation.response import ResponseBody, Inference
try:
    from modzy.modzy import ModzyRequest, ModzyResponse, ModzyRunner
except ImportError:
    # Modzy is not available
    logging.exception('Modzy is unavailable. You may need to remove the supporting endpoints (i.e., /status, /run, and /shutdown).')

app = FastAPI()
config = InferenceConfig()
model = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
jwt_util = JsonWebTokenUtil()

@app.on_event("startup")
def load_model():
    """
    Loads the trained model from the configured model directory.
    """
    global model
    try:
        model = mlflow.sklearn.load_model(config.model_directory())
    except Exception:
        logging.exception('Failed to load model')
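# Note: mlflow.sklearn.load_model() accepts any MLflow model URI, so
# config.model_directory() may resolve to a local path such as
# "/models/my-model" or a URI such as "runs:/<run_id>/model"
# (illustrative values; the real location comes from InferenceConfig).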


def authenticate_request(token: str = Depends(oauth2_scheme)):
    """
    Validates the token in the request.
    """
    try:
        jwt_util.validate_token(token)
    except AiopsSecurityException as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
            headers={'WWW-Authenticate': 'Bearer'}
        )
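# Illustrative client call (hypothetical host, port, and token values):
#   curl -H "Authorization: Bearer <jwt>" \
#        -H "Content-Type: application/json" \
#        -d '{...}' http://localhost:8000/infer
# OAuth2PasswordBearer pulls the token from the Authorization header before
# it reaches authenticate_request(); requests without a valid token get a 401.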


@app.get("/healthcheck")
def healthcheck():
    """
    REST endpoint to check that the inference service is running.
    """
    return 'Inference service for ${pipeline.name} is running'


@app.post("/infer", response_model=ResponseBody)
def infer(request: RequestBody, authenticate=Depends(authenticate_request)):
    """
    REST endpoint for inference requests.
    """
    if model is None:
        load_model()

    if model is not None:
        # Prep the data from the inference request
        prepped_data = request.prep_data()

        # Use the model to predict using the prepped data
        predictions = model.predict(prepped_data)

        # TODO: Map predictions to desired inference format
        inferences: List[Inference] = []
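        # One possible mapping sketch, assuming Inference exposes a single
        # prediction field (adjust to the fields actually declared in
        # validation/response.py):
        #   inferences = [Inference(prediction=p) for p in predictions]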
    else:
        raise HTTPException(
            status_code=status.HTTP_412_PRECONDITION_FAILED,
            detail='Model not available. Unable to process inference. Please verify the model is trained and available.'
        )

    return ResponseBody(inferences=inferences)
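
# Illustrative request/response outcomes for /infer (field names depend on
# the generated RequestBody/ResponseBody models under validation/):
#   valid token, model loaded   -> 200 with {"inferences": [...]}
#   missing or invalid token    -> 401 Unauthorized
#   model could not be loaded   -> 412 Precondition Failed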

# Needed for Modzy
@app.get("/status")
def get_status():
    if model is None:
        load_model()

    modzy_runner = ModzyRunner(model)

    return modzy_runner.get_status()

# Needed for Modzy
@app.post("/run")
def post_run(request: ModzyRequest):
    if model is None:
        load_model()

    modzy_runner = ModzyRunner(model)

    return modzy_runner.run(request)

# Needed for Modzy
@app.get("/shutdown")
def read_shutdown(background_tasks: BackgroundTasks):
    background_tasks.add_task(terminate)
    return ModzyResponse(message='Shutting down', status='OK', statusCode=200)

def terminate():
    """
    Exits the process cleanly so the hosting container can shut down.
    """
    sys.exit(os.EX_OK)  # os.EX_OK (0) is POSIX-only
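
# To serve this API locally (assuming uvicorn is installed and this module is
# importable, e.g., as inference.api; adjust the module path as needed):
#   uvicorn inference.api:app --host 0.0.0.0 --port 8000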



