All downloads are free. Search and download functionality uses the official Maven repository.

djl_python.test_model.py Maven / Gradle / Ivy

There is a newer version: 0.28.0
Show newest version
#!/usr/bin/env python3
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
import os
import sys

from .arg_parser import ArgParser
from .inputs import Input
from .outputs import Output
from .np_util import to_nd_list
from .service_loader import load_model_service


def _content_type_for(path):
    """Return the content type implied by ``path``'s extension, or ``None``.

    Matching is case-insensitive, so ``photo.PNG`` maps like ``photo.png``.
    """
    lowered = path.lower()
    for suffixes, content_type in (
        ((".json",), "application/json"),
        ((".txt",), "text/plain"),
        ((".gif",), "image/gif"),
        ((".png",), "image/png"),
        ((".jpeg", ".jpg"), "image/jpeg"),
        ((".ndlist",), "tensor/ndlist"),
        ((".npz",), "tensor/npz"),
    ):
        if lowered.endswith(suffixes):
            return content_type
    return None


def create_request(input_files, parameters):
    """Build an ``Input`` request from command-line style arguments.

    :param input_files: iterable of input specs, each either ``value`` or
        ``key=value``. Only the first ``=`` separates the key, so values may
        themselves contain ``=``. If ``value`` names an existing path, the
        file's bytes become the content. Otherwise the value string itself is
        sent, UTF-8 encoded, with a ``content-type`` of ``application/json``
        when it is wrapped in ``{...}`` and ``text/plain`` otherwise.
    :param parameters: optional iterable of ``name=value`` strings copied into
        ``request.properties``. Only the first ``=`` separates the name.
    :return: the populated ``Input`` (``device_id`` is always ``"-1"``).
    :raises ValueError: if a parameter contains no ``=``.
    """
    request = Input()
    request.properties["device_id"] = "-1"

    for parameter in parameters or ():
        # partition() splits on the first "=" only, so values may contain "=".
        name, sep, value = parameter.partition("=")
        if not sep:
            raise ValueError(f"Invalid model parameter: {parameter}")
        request.properties[name] = value

    # The "primary" input whose extension decides the final content-type.
    data_file = None
    for spec in input_files:
        key, sep, val = spec.partition("=")
        if not sep:
            key, val = None, spec

        # An input keyed "data" takes precedence (the last one, if repeated).
        # Otherwise the first input is the primary one.
        if data_file is None or key == "data":
            data_file = val

        if os.path.exists(val):
            with open(val, "rb") as f:
                request.content.add(key=key, value=f.read())
        else:
            # Not an existing path: the spec itself is the literal payload.
            request.content.add(key=key, value=val.encode("utf-8"))
            if val.startswith("{") and val.endswith("}"):
                request.properties["content-type"] = "application/json"
            else:
                request.properties["content-type"] = "text/plain"

    # A recognised extension on the primary input overrides the type above.
    # Guard against an empty input_files, which leaves data_file as None.
    if data_file is not None:
        content_type = _content_type_for(data_file)
        if content_type is not None:
            request.properties["content-type"] = content_type

    return request


def create_text_request(text: str, key: str = None) -> Input:
    """Wrap a string in a ``text/plain`` request for device ``-1``.

    :param text: payload string; it is stored UTF-8 encoded.
    :param key: optional key under which the payload is added to the content.
    :return: the new ``Input``.
    """
    req = Input()
    props = req.properties
    props["device_id"] = "-1"
    props["content-type"] = "text/plain"
    encoded = text.encode("utf-8")
    req.content.add(key=key, value=encoded)
    return req


def create_numpy_request(list, key: str = None) -> Input:
    """Wrap a sequence of arrays in a ``tensor/ndlist`` request for device ``-1``.

    :param list: arrays to serialise with ``to_nd_list``. The name shadows
        the builtin but is kept because callers may pass it by keyword.
    :param key: optional key under which the payload is added to the content.
    :return: the new ``Input``.
    """
    req = Input()
    props = req.properties
    props["device_id"] = "-1"
    props["content-type"] = "tensor/ndlist"
    serialized = to_nd_list(list)
    req.content.add(key=key, value=serialized)
    return req


def create_npz_request(list, key: str = None) -> Input:
    """Wrap a sequence of arrays in a ``tensor/npz`` request for device ``-1``.

    The arrays are written positionally with ``numpy.savez`` into an
    in-memory buffer, and the resulting archive bytes become the payload.

    :param list: arrays to archive. The name shadows the builtin but is kept
        because callers may pass it by keyword.
    :param key: optional key under which the payload is added to the content.
    :return: the new ``Input``.
    """
    import io
    import numpy as np

    req = Input()
    props = req.properties
    props["device_id"] = "-1"
    props["content-type"] = "tensor/npz"
    with io.BytesIO() as buffer:
        np.savez(buffer, *list)
        # getvalue() returns the whole buffer regardless of stream position.
        archive = buffer.getvalue()
    req.content.add(key=key, value=archive)
    return req


def _extract_output(outputs: Output) -> Input:
    """Re-wrap an ``Output`` as an ``Input`` so the ``get_as_*`` readers apply.

    The new object shares (aliases, not copies) ``outputs.properties`` and
    ``outputs.content``.
    """
    wrapper = Input()
    wrapper.properties, wrapper.content = outputs.properties, outputs.content
    return wrapper


def extract_output_as_bytes(outputs: Output, key=None):
    """Return the content of ``outputs`` (optionally under ``key``) via ``get_as_bytes``."""
    as_input = _extract_output(outputs)
    return as_input.get_as_bytes(key)


def extract_output_as_numpy(outputs: Output, key=None):
    """Return the content of ``outputs`` (optionally under ``key``) via ``get_as_numpy``."""
    as_input = _extract_output(outputs)
    return as_input.get_as_numpy(key)


def extract_output_as_npz(outputs: Output, key=None):
    """Return the content of ``outputs`` (optionally under ``key``) via ``get_as_npz``."""
    as_input = _extract_output(outputs)
    return as_input.get_as_npz(key)


def extract_output_as_string(outputs: Output, key=None):
    """Return the content of ``outputs`` (optionally under ``key``) via ``get_as_string``."""
    as_input = _extract_output(outputs)
    return as_input.get_as_string(key)


def run():
    """Command-line entry point: load a model service and invoke one handler.

    This parses the arguments from ``ArgParser.test_model_args()`` and builds
    the request from the parsed ``input`` and ``parameters``. It then changes
    into ``model_dir``, loads the service for ``entry_point`` on device
    ``"-1"``, and prints the handler's result.
    """
    logging.basicConfig(
        stream=sys.stdout,
        format="%(message)s",
        level=logging.INFO,
    )
    parsed = ArgParser.test_model_args().parse_args()

    # Build the request before chdir so relative input paths resolve against
    # the caller's working directory.
    request = create_request(parsed.input, parsed.parameters)
    request.function_name = parsed.handler

    os.chdir(parsed.model_dir)
    model_dir = os.getcwd()
    # Put the model directory on the import path (presumably so the entry
    # point and its helper modules are importable).
    sys.path.append(model_dir)

    service = load_model_service(model_dir, parsed.entry_point, "-1")
    result = service.invoke_handler(request.get_function_name(), request)
    print("output: " + str(result))


# Script entry point. This module uses relative imports, so it must be run
# as part of its package (e.g. with ``python -m``), not as a loose file.
if __name__ == "__main__":
    run()




© 2015 - 2024 Weber Informatics LLC | Privacy Policy