All downloads are free. Search and download functionality uses the official Maven repository.

python.mlsql_model.py Maven / Gradle / Ivy

The newest version!
import tensorflow as tf
import os
import shutil
import pickle
import mlsql


def save_model(path, session, input_tensor, output_tensor, overwrite=False):
    """Export the graph and variables held by ``session`` as a TensorFlow SavedModel.

    Uses the TF 1.x ``SavedModelBuilder`` API. One signature is registered
    under the default serving signature key and tagged ``SERVING``. It has a
    single input named ``'input'`` (bound to ``input_tensor``) and a single
    output named ``'output'`` (bound to ``output_tensor``).

    Args:
        path: Export directory handed to ``SavedModelBuilder``.
        session: Session whose graph and variable values are exported.
        input_tensor: Tensor exposed as the signature's ``'input'``.
        output_tensor: Tensor exposed as the signature's ``'output'``.
        overwrite: When true and ``path/saved_model.pb`` already exists, the
            whole ``path`` directory is deleted before exporting. A directory
            at ``path`` without ``saved_model.pb`` is left untouched.
            NOTE(review): SavedModelBuilder presumably refuses to export into
            an existing directory in that case — confirm.
    """
    if overwrite:
        # Only delete `path` when it looks like a previous SavedModel export,
        # so an unrelated directory is never removed by mistake.
        previous_export = os.path.join(path, "saved_model.pb")
        if os.path.exists(previous_export):
            shutil.rmtree(path)

    saved_model_api = tf.saved_model
    input_info = saved_model_api.utils.build_tensor_info(input_tensor)
    output_info = saved_model_api.utils.build_tensor_info(output_tensor)
    serving_signature = saved_model_api.signature_def_utils.build_signature_def(
        inputs={'input': input_info},
        outputs={'output': output_info},
    )

    builder = saved_model_api.builder.SavedModelBuilder(path)
    serving_key = saved_model_api.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    builder.add_meta_graph_and_variables(
        session,
        [saved_model_api.tag_constants.SERVING],
        signature_def_map={serving_key: serving_signature},
    )
    builder.save()


# def keras_save_model(path, model, overwrite=False):
#     newpath = os.path.join(path, "saved_model.h5")
#     if overwrite and os.path.exists(newpath):
#         shutil.rmtree(path)
#     model.save(newpath)


def sk_save_model(model):
    """Pickle ``model`` to ``<tempModelLocalPath>/model.pickle``.

    The target directory is read from
    ``mlsql.params()["internalSystemParam"]["tempModelLocalPath"]``.

    * If that key is present, the configured directory is deleted (if it
      exists) and recreated, so it holds only the freshly written pickle.
    * If the key is absent, the function falls back to ``/tmp/``. That
      directory is shared with the rest of the system, so it is only created
      when missing and is never deleted. Only ``/tmp/model.pickle`` is
      (over)written.

    Args:
        model: Any picklable object (e.g. a fitted scikit-learn estimator).
    """
    isp = mlsql.params()["internalSystemParam"]
    if "tempModelLocalPath" in isp:
        dir_name = isp["tempModelLocalPath"]
        # Dedicated, caller-configured directory: start from a clean slate.
        if os.path.exists(dir_name):
            shutil.rmtree(dir_name)
        os.makedirs(dir_name)
    else:
        dir_name = "/tmp/"
        # Never rmtree the shared fallback: doing so would delete every file
        # this process is allowed to remove under /tmp.
        if not os.path.isdir(dir_name):
            os.makedirs(dir_name)
    with open(os.path.join(dir_name, "model.pickle"), "wb") as f:
        # Protocol 2 is readable by Python 2.3+ and all Python 3 versions;
        # presumably chosen so the loading interpreter version doesn't matter
        # — confirm before raising it.
        pickle.dump(model, f, protocol=2)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy