
templates.inference.inference.api.rest.py.vm
"""
Base implementation of a REST endpoint for an inference analytic.
GENERATED CODE - ***DO*** MODIFY
NOTE: Developers may customize the REST API defined in this module to support authentication and
authorization, but should add inference analytic implementation logic to inference_impl.py.
Generated from: ${templateName}
"""
from fastapi import FastAPI
from ...validation.inference_message_definition import RequestBody, ResponseBody
from ...generated.inference.rest.inference_payload_definition import (
    InferenceRequest as RestInferenceRequest,
    BatchInferenceRequest as RestBatchInferenceRequest,
    InferenceResponse as RestInferenceResponse,
    BatchInferenceResponse as RestBatchInferenceResponse,
    RecordRowIdAndInferenceResultPair as RestRecordIdAndInferencePair,
)
from ...inference_impl import *
from aissembleauth.json_web_token_util import JsonWebTokenUtil
from fastapi.security import OAuth2PasswordBearer
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
jwt_util = JsonWebTokenUtil()
# Later iterations will integrate authentication
# def authenticate_request(token: str = Depends(oauth2_scheme)):
#     """
#     Validates the token in the request.
#     """
#     try:
#         jwt_util.validate_token(token)
#         auth_config = AuthConfig()
#         pdp_client = PDPClient(auth_config.pdp_host_url())
#         decision = pdp_client.authorize(token, None, 'data-access')
#         if "PERMIT" != decision:
#             raise AissembleSecurityException('User is not authorized')
#     except AissembleSecurityException as e:
#         raise HTTPException(
#             status_code=status.HTTP_401_UNAUTHORIZED,
#             detail=str(e),
#             headers={'WWW-Authenticate': 'Bearer'}
#         )
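#
# Illustrative sketch only (not generated code): once authentication is enabled, the dependency
# above could be attached to the protected routes using FastAPI's standard dependency mechanism,
# for example:
#     @app.post("/infer", response_model=ResponseBody, dependencies=[Depends(authenticate_request)])
# Doing so would also require importing Depends, HTTPException, and status from fastapi.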

@app.get("/healthcheck")
def healthcheck():
    """
    REST endpoint to check that the inference service is running.
    """
    return 'Inference service for InferencePipeline is running'
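
# Illustrative health check call (assumes the service is exposed locally on the default uvicorn port):
#     curl http://localhost:8000/healthcheck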

@app.post("/infer", response_model=ResponseBody)
def infer(request: RequestBody):
    """
    REST endpoint for inference requests.
    """
    return execute_inference(request)
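
# Illustrative client call (not generated code): the record fields are defined by the pipeline's
# inference_message_definition module, so the payload below is only a hypothetical shape.
#     import requests
#     response = requests.post(
#         "http://localhost:8000/infer",
#         json={"data": [{"example_feature": 1.0}]},  # hypothetical record fields
#     )
#     print(response.json())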

@app.post("/analyze", response_model=RestInferenceResponse)
def analyze(request: RestInferenceRequest):
    """
    REST endpoint for performing inferencing against a single provided input data record.
    :param request: inference request containing a single input data record
    :return: the inference result for the provided record
    """
    infer_impl_response = execute_inference(RequestBody(data=[request]))
    return RestInferenceResponse.parse_obj(infer_impl_response.inferences[0].dict())

@app.post("/analyze-batch", response_model=RestBatchInferenceResponse)
def analyze_batch(request: RestBatchInferenceRequest):
    """
    REST endpoint for performing inferencing against a batch of multiple provided input data records.
    :param request: batch inference request containing the input data records and the name of the
        record attribute (row_id_key) that uniquely identifies each record
    :return: the inference results, each paired with the row id of the record that produced it
    """
    infer_impl_response = execute_inference(RequestBody(data=request.data))
    record_row_id_and_inference_pairs = []
    for index, inference_result in enumerate(infer_impl_response.inferences):
        record_row_id = getattr(request.data[index], request.row_id_key)
        record_row_id_and_inference_pairs.append(
            RestRecordIdAndInferencePair(row_id=record_row_id, result=inference_result)
        )
    return RestBatchInferenceResponse(results=record_row_id_and_inference_pairs)
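
# Illustrative batch request body (not generated code; field names inside each record are hypothetical,
# but the request must supply both the records in "data" and the "row_id_key" naming the attribute
# that uniquely identifies each record, which is used above to pair results with their source rows):
#     {
#         "row_id_key": "record_id",
#         "data": [
#             {"record_id": "a-1", "example_feature": 0.3},
#             {"record_id": "a-2", "example_feature": 0.7}
#         ]
#     }
# The response then contains one {"row_id": ..., "result": ...} pair per input record.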