
templates.general-mlflow.training.driver.py.vm Maven / Gradle / Ivy
"""
Driver to run this pipeline.
GENERATED STUB CODE - PLEASE ***DO*** MODIFY
Originally generated from: ${templateName}
"""
import os
import sys
#foreach($step in $steps)
from ${packageFolderName}.impl.${step.getLowercaseSnakeCaseName()} import ${step.capitalizedName}
#end
from ${packageFolderName}.impl.${pipeline.getSnakeCaseName()} import ${pipeline.capitalizedName}
#if ($autoTrain)
from ${packageFolderName}.config.pipeline_config import PipelineConfig
from kafka import KafkaConsumer
#end
#if ($pipeline.trainingStep.isModelLineageEnabled())
from ${packageFolderName}.generated.pipeline.pipeline_base import PipelineBase
#end
if __name__ == "__main__":
if os.getenv("MODE") == "no-op":
print("Training job successfully registered.")
sys.exit()
#if ($pipeline.trainingStep.isModelLineageEnabled())
PipelineBase().record_pipeline_lineage_start_event()
#end
#foreach($step in $steps)
${step.capitalizedName}().execute_step_impl()
#end
#if ($autoTrain)
kafka_server = PipelineConfig().kafka_server()
training_alert_consumer = KafkaConsumer('${pipeline.trainingStep.inbound.channelName}', group_id='Train', bootstrap_servers=[kafka_server],
auto_offset_reset='earliest', api_version=(2, 0, 2), max_poll_records=100,
max_poll_interval_ms=500000, enable_auto_commit=False)
pipeline = ${pipeline.capitalizedName}()
print('Waiting for training alert from Kafka (%s)...' % kafka_server)
for alert in training_alert_consumer:
pipeline.acknowledge_training_alert(alert)
pipeline.run()
#else
${pipeline.capitalizedName}().run()
#end
#if ($pipeline.trainingStep.isModelLineageEnabled())
PipelineBase().record_pipeline_lineage_complete_event()
#end
© 2015 - 2025 Weber Informatics LLC | Privacy Policy