All downloads are free. Search and download functionality uses the official Maven repository.

templates.inference.inference.api.rest.py.vm Maven / Gradle / Ivy

"""
Base implementation of a REST endpoint for an inference analytic.

GENERATED CODE - ***DO*** MODIFY

NOTE: Developers may customize the REST API defined in this module to support authentication and
authorization, but should add inference analytic implementation logic to inference_impl.py.

Generated from: ${templateName}
"""

from fastapi import FastAPI
from ...validation.inference_message_definition import RequestBody, ResponseBody
from ...generated.inference.rest.inference_payload_definition import (
    InferenceRequest as RestInferenceRequest,
    BatchInferenceRequest as RestBatchInferenceRequest,
    InferenceResponse as RestInferenceResponse,
    BatchInferenceResponse as RestBatchInferenceResponse,
    RecordRowIdAndInferenceResultPair as RestRecordIdAndInferencePair,
)
from ...inference_impl import *
from aissembleauth.json_web_token_util import JsonWebTokenUtil
from fastapi.security import OAuth2PasswordBearer


# FastAPI application instance; the route decorators in this module register their endpoints on it.
app = FastAPI()
# OAuth2 bearer-token scheme configured with tokenUrl 'token'. In the portion of this module
# visible here it is referenced only by the commented-out authenticate_request function.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
# JSON Web Token helper. In the portion of this module visible here it is referenced only by the
# commented-out authenticate_request function.
jwt_util = JsonWebTokenUtil()

# Later iterations will integrate authentication
# def authenticate_request(token: str = Depends(oauth2_scheme)):
#     """
#     Validates the token in the request.
#     """
#     try:
#         jwt_util.validate_token(token)
#         auth_config = AuthConfig()
#         pdp_client = PDPClient(auth_config.pdp_host_url())
#         decision = pdp_client.authorize(token, None, 'data-access')
#         if "PERMIT" != decision:
#             raise AissembleSecurityException('User is not authorized')
#     except AissembleSecurityException as e:
#         raise HTTPException(
#             status_code=status.HTTP_401_UNAUTHORIZED,
#             detail=str(e),
#             headers={'WWW-Authenticate': 'Bearer'}
#         )

@app.get("/healthcheck")
def healthcheck():
    """
    Liveness probe: answers GET /healthcheck with a fixed status message.

    :return: a constant string stating that the inference service is running
    """
    status_message = 'Inference service for InferencePipeline is running'
    return status_message


@app.post("/infer", response_model=ResponseBody)
def infer(request: RequestBody):
    """
    Handle POST /infer by handing the parsed request body to execute_inference.

    NOTE(review): execute_inference is not defined in this module; presumably it is supplied by
    the wildcard import of inference_impl - confirm.

    :param request: the request payload, parsed as a RequestBody
    :return: the value returned by execute_inference (the declared response model is ResponseBody)
    """
    inference_output = execute_inference(request)
    return inference_output


@app.post("/analyze", response_model=RestInferenceResponse)
def analyze(request: RestInferenceRequest):
    """
    Handle POST /analyze: run inference on one input record and return its result.

    The single record is wrapped in a one-element RequestBody, passed to execute_inference, and
    the first entry of the returned inferences is converted into a RestInferenceResponse.

    :param request: the single input record to run inference on
    :return: a RestInferenceResponse built from the first inference in the implementation response
    """
    single_record_body = RequestBody(data=[request])
    impl_response = execute_inference(single_record_body)
    # Only one record was submitted, so only the first inference is read back.
    first_inference = impl_response.inferences[0]
    return RestInferenceResponse.parse_obj(first_inference.dict())


@app.post("/analyze-batch", response_model=RestBatchInferenceResponse)
def analyze_batch(request: RestBatchInferenceRequest):
    """
    Handle POST /analyze-batch: run inference on a batch of input records.

    All records are passed to execute_inference in one RequestBody. Each returned inference is then
    paired with the row id of the input record at the same position, where the row id is read from
    the record attribute named by request.row_id_key.

    :param request: the batch request, providing the records (data) and the name of the attribute
        holding each record's row id (row_id_key)
    :return: a RestBatchInferenceResponse whose results hold one row-id/inference pair per inference
    """
    batch_body = RequestBody(data=request.data)
    impl_response = execute_inference(batch_body)
    # Inferences are matched to input records purely by position in the two sequences.
    id_result_pairs = [
        RestRecordIdAndInferencePair(
            row_id=getattr(request.data[position], request.row_id_key),
            result=inference,
        )
        for position, inference in enumerate(impl_response.inferences)
    ]
    return RestBatchInferenceResponse(results=id_result_pairs)






© 2015 - 2025 Weber Informatics LLC | Privacy Policy