templates.general-mlflow.training.base.py.vm (foundation-mda)
Model driven architecture artifacts for aiSSEMBLE

"""
Base implementation of this pipeline.
GENERATED CODE - DO NOT MODIFY (add your customizations in ${pipeline.capitalizedName}).
Generated from: ${templateName}
"""
from abc import ABC, abstractmethod
from typing import List, NamedTuple
from pandas import DataFrame
from ..config.pipeline_config import PipelineConfig
from datetime import datetime
import mlflow
import json
from aissemble_security.pdp_client import PDPClient
from aissembleauth.auth_config import AuthConfig
from pathlib import Path
from krausening.logging import LogManager
#if ($pipeline.trainingStep.postActions)
#foreach ($postAction in $pipeline.trainingStep.postActions)
from ..post_action.${postAction.snakeCaseName} import ${postAction.capitalizedName}
#end
#end
from .pipeline.pipeline_base import PipelineBase
#if ($pipeline.trainingStep.isModelLineageEnabled())
from uuid import uuid4
from aissemble_data_lineage import Run, Job, Emitter, RunEvent, InputDataset, from_open_lineage_facet, LineageUtil, LineageEventData
from aissemble_model_lineage import MLflowRunFacet, LineageBuilder
from aissemble_core_config import MessagingConfig
import os
import attr
from openlineage.client.facet import ErrorMessageRunFacet, NominalTimeRunFacet, ParentRunFacet
from platform import python_version
#end

class DatasetSplits(NamedTuple):
    """
    Class to store the training and test splits of a dataset.
    The splits are of type any to allow custom implementations
    to handle any number of datasets per split.
    """
    train: any
    test: any
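
# Illustrative usage (not part of the generated code): a split_dataset()
# implementation might store one DataFrame per split, e.g.:
#
#   train = dataset.sample(frac=0.8, random_state=0)
#   test = dataset.drop(train.index)
#   splits = DatasetSplits(train=train, test=test)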

class ${pipeline.capitalizedName}Base(ABC):
    """
    Base implementation of the pipeline.
    """

    logger = LogManager.get_instance().get_logger('${pipeline.capitalizedName}')

    def __init__(self, experiment_name):
        """
        Default initializations for the pipeline.
        """
        # set default mlflow configurations
        self.config = PipelineConfig()
        mlflow.set_tracking_uri(self.config.mlflow_tracking_uri())
        mlflow.set_experiment(experiment_name)
#if ($pipeline.trainingStep.isModelLineageEnabled())
        self.messaging_config = MessagingConfig()
        self.emitter = Emitter()
        self.lineage_builder = LineageBuilder()
        self.lineage_util = LineageUtil()
#end

#if ($autoTrain)
    @abstractmethod
    def acknowledge_training_alert(self, alert: any) -> None:
        """
        Method to acknowledge a training alert for auto-training purposes.
        """
        pass
#end

    @abstractmethod
    def load_dataset(self) -> DataFrame:
        """
        Method to load a dataset for training.
        Returns a dataset of type DataFrame.
        """
        pass

    @abstractmethod
    def prep_dataset(self, dataset: DataFrame) -> DataFrame:
        """
        Method to perform last-mile data preparation on the loaded dataset.
        Returns the prepped dataset.
        """
        pass

    @abstractmethod
    def select_features(self, dataset: DataFrame) -> List[str]:
        """
        Method to perform feature selection on the prepped dataset.
        Returns a list of the features (columns) selected from the dataset.
        """
        pass

    @abstractmethod
    def split_dataset(self, dataset: DataFrame) -> DatasetSplits:
        """
        Method to create the train and test splits on the dataset with selected features.
        Returns the splits within a DatasetSplits object.
        """
        pass

    @abstractmethod
    def train_model(self, train_dataset: any) -> any:
        """
        Method to train a model with the training dataset split(s).
        Returns the model that has been trained.
        """
        pass

    @abstractmethod
    def evaluate_model(self, model: any, test_dataset: any) -> float:
        """
        Method to evaluate the trained model with the test dataset split(s).
        Returns the score of the model evaluation.
        """
        pass

    @abstractmethod
    def save_model(self, model: any) -> None:
        """
        Method to save the model to a location.
        """
        pass

    @abstractmethod
    def deploy_model(self, score: float, model: any) -> None:
        """
        Method to deploy the model if needed.
        """
        pass
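
    # Illustrative sketch (hypothetical, not generated): a concrete subclass fills in
    # the abstract hooks above, e.g. with a pandas CSV source:
    #
    #   class MyTrainingPipeline(MyTrainingPipelineBase):
    #       def load_dataset(self) -> DataFrame:
    #           return pandas.read_csv("training_data.csv")  # hypothetical path
    #
    #       def split_dataset(self, dataset: DataFrame) -> DatasetSplits:
    #           train = dataset.sample(frac=0.8, random_state=0)
    #           return DatasetSplits(train=train, test=dataset.drop(train.index))
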
    def run(self):
        """
        Runs the pipeline.
        """
        self.logger.info('Running %s...' % type(self).__name__)
#if ($pipeline.trainingStep.isModelLineageEnabled())
        run_id = uuid4()
        parent_run_facet = PipelineBase().get_pipeline_run_as_parent_run_facet()
        job_name = self.get_job_name()
#end
        try:
            with mlflow.start_run() as run:
                self.training_run_id = run.info.run_uuid
#if ($pipeline.trainingStep.isModelLineageEnabled())
                # pylint: disable-next=assignment-from-none
                event_data = self.create_base_lineage_event_data()
                default_namespace = self.get_default_namespace()
#end
                start = datetime.utcnow()
#if ($pipeline.trainingStep.isModelLineageEnabled())
                self.record_lineage(self.create_lineage_start_event(run_id=run_id, job_name=job_name, default_namespace=default_namespace, parent_run_facet=parent_run_facet, event_data=event_data, start_time=start))
#end
                loaded_dataset = self.load_dataset()
                prepped_dataset = self.prep_dataset(loaded_dataset)
                features = self.select_features(prepped_dataset)
                splits = self.split_dataset(prepped_dataset[features])
                model = self.train_model(splits.train)
                score = self.evaluate_model(model, splits.test)
                self.save_model(model)
                self.deploy_model(score, model)
#if ($pipeline.trainingStep.postActions)
                self.apply_post_actions(self.training_run_id, model)
#end
                end = datetime.utcnow()
                self.log_information(start, end, loaded_dataset, features)
                self.logger.info('Complete')
#if ($pipeline.trainingStep.isModelLineageEnabled())
                self.record_lineage(self.create_lineage_complete_event(run_id=run_id, job_name=job_name, default_namespace=default_namespace, parent_run_facet=parent_run_facet, event_data=event_data, start_time=start, end_time=end))
#end
        except Exception as error:
#if ($pipeline.trainingStep.isModelLineageEnabled())
            self.record_lineage(self.create_lineage_fail_event(run_id=run_id, job_name=job_name, event_data=event_data, default_namespace=default_namespace, parent_run_facet=parent_run_facet, start_time=start, end_time=datetime.utcnow(), error=error))
            PipelineBase().record_pipeline_lineage_fail_event()
#end
            raise
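
    # Typical invocation (illustrative): instantiate the generated subclass with an
    # MLflow experiment name and call run(), e.g.:
    #
    #   MyTrainingPipeline("my-experiment").run()
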
#if ($pipeline.trainingStep.postActions)
    def apply_post_actions(self, training_run_id: str, model: any) -> None:
        """
        Applies the post actions specified for the training.
        """
#foreach ($postAction in $pipeline.trainingStep.postActions)
        postAction${postAction.capitalizedName} = ${postAction.capitalizedName}(training_run_id, model)
        postAction${postAction.capitalizedName}.apply()
#end
#end

#if ($pipeline.trainingStep.isModelLineageEnabled())
    def create_base_lineage_event_data(self) -> LineageEventData:
        """
        Creates the base lineage event data that will be included in all of this step's events.
        Returns:
            LineageEventData
        """
        job_facets = {
            "documentation": from_open_lineage_facet(self.lineage_builder.get_documentation_job_facet()),
            "ownership": from_open_lineage_facet(self.lineage_builder.get_ownership_job_facet()),
            "sourceCodeLocation": from_open_lineage_facet(self.lineage_builder.get_source_code_directory_job_facet())
        }
        run_facets = {
            "hardwareDetails": from_open_lineage_facet(self.lineage_builder.get_hardware_details_run_facet()),
            "hyperparameters": from_open_lineage_facet(self.lineage_builder.get_hyperparameter_run_facet()),
            "mlflowRunId": from_open_lineage_facet(MLflowRunFacet(self.training_run_id)),
            "performanceMetrics": from_open_lineage_facet(self.lineage_builder.get_performance_metric_run_facet())
        }
        dataset_facets = {
            "dataSource": from_open_lineage_facet(self.lineage_builder.get_data_source_dataset_facet()),
            "dataQualityAssertions": from_open_lineage_facet(self.lineage_builder.get_data_quality_assertions_facet()),
            "ownership": from_open_lineage_facet(self.lineage_builder.get_ownership_dataset_facet()),
            "schema": from_open_lineage_facet(self.lineage_builder.get_schema_dataset_facet()),
            "storage": from_open_lineage_facet(self.lineage_builder.get_storage_dataset_facet())
        }
        input_dataset = InputDataset("${pipeline.capitalizedName}Input", dataset_facets)
        return LineageEventData(job_facets=job_facets, run_facets=run_facets, event_inputs=[input_dataset])
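
    # Illustrative customization (hypothetical): a subclass can extend the base event
    # data with additional facets before it is attached to the step events, e.g.:
    #
    #   def create_base_lineage_event_data(self) -> LineageEventData:
    #       event_data = super().create_base_lineage_event_data()
    #       # assumes LineageEventData exposes its run_facets constructor argument
    #       event_data.run_facets["myCustomFacet"] = from_open_lineage_facet(my_facet)
    #       return event_data
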
    def create_lineage_start_event(self, run_id: str = None, job_name: str = "", default_namespace: str = None, parent_run_facet: ParentRunFacet = None, event_data: LineageEventData = None, **kwargs) -> RunEvent:
        """
        Creates the START RunEvent from the given run id, parent run facet, job name, lineage event data, and any other input parameters.
        To customize the event, override the customize_lineage_start_event(...) function to include job facets, run facets,
        or input/output datasets. The customize_run_event(...) function is a deprecated customization point.
        Returns:
            RunEvent created from the input arguments
        """
        event = self.lineage_util.create_start_run_event(
            run_id=run_id,
            parent_run_facet=parent_run_facet,
            job_name=job_name,
            default_namespace=default_namespace,
            event_data=event_data)
        event = self.customize_lineage_start_event(event, **kwargs)
        return self.customize_run_event(event)

    def customize_lineage_start_event(self, event: RunEvent = None, **kwargs) -> RunEvent:
        """
        Customizes the START event with the given input.
        Returns:
            the customized lineage event
        """
        if "start_time" in kwargs:
            run_facets = {
                "nominalTime": from_open_lineage_facet(NominalTimeRunFacet(kwargs["start_time"].isoformat(timespec="milliseconds") + "Z"))
            }
            event.run.facets.update(run_facets)
        return event
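
    # Illustrative override (hypothetical): a subclass can enrich the START event, for
    # example with output datasets, while keeping the base nominal-time facet:
    #
    #   def customize_lineage_start_event(self, event: RunEvent = None, **kwargs) -> RunEvent:
    #       event = super().customize_lineage_start_event(event, **kwargs)
    #       event.outputs = [my_output_dataset]  # assumes RunEvent exposes its outputs
    #       return event
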
    def create_lineage_complete_event(self, run_id: str = None, job_name: str = "", default_namespace: str = None, parent_run_facet: ParentRunFacet = None, event_data: LineageEventData = None, **kwargs) -> RunEvent:
        """
        Creates the COMPLETE RunEvent from the given run id, parent run facet, job name, lineage event data, and any other input parameters.
        To customize the event, override the customize_lineage_complete_event(...) function to include job facets, run facets,
        or input/output datasets. The customize_run_event(...) function is a deprecated customization point.
        Returns:
            RunEvent created from the input arguments
        """
        event = self.lineage_util.create_complete_run_event(
            run_id=run_id,
            parent_run_facet=parent_run_facet,
            job_name=job_name,
            default_namespace=default_namespace,
            event_data=event_data)
        event = self.customize_lineage_complete_event(event, **kwargs)
        return self.customize_run_event(event)

    def customize_lineage_complete_event(self, event: RunEvent = None, **kwargs) -> RunEvent:
        """
        Customizes the COMPLETE event with the given input.
        Returns:
            the customized lineage event
        """
        if "start_time" in kwargs and "end_time" in kwargs:
            event.run.facets.update(self.record_run_end(kwargs["start_time"], kwargs["end_time"]))
        return event

    def create_lineage_fail_event(self, run_id: str = None, job_name: str = "", default_namespace: str = None, parent_run_facet: ParentRunFacet = None, event_data: LineageEventData = None, **kwargs) -> RunEvent:
        """
        Creates the FAIL RunEvent from the given run id, parent run facet, job name, lineage event data, and any other input parameters.
        To customize the event, override the customize_lineage_fail_event(...) function to include job facets, run facets,
        or input/output datasets. The customize_run_event(...) function is a deprecated customization point.
        Returns:
            RunEvent created from the input arguments
        """
        event = self.lineage_util.create_fail_run_event(
            run_id=run_id,
            parent_run_facet=parent_run_facet,
            job_name=job_name,
            default_namespace=default_namespace,
            event_data=event_data)
        event = self.customize_lineage_fail_event(event, **kwargs)
        return self.customize_run_event(event)

    def customize_lineage_fail_event(self, event: RunEvent = None, **kwargs) -> RunEvent:
        """
        Customizes the FAIL event with the given input.
        Returns:
            the customized lineage event
        """
        if "start_time" in kwargs and "end_time" in kwargs and "error" in kwargs:
            event.run.facets.update(self.record_run_end(kwargs["start_time"], kwargs["end_time"], kwargs["error"]))
        return event

    def customize_run_event(self, event: RunEvent) -> RunEvent:
        """
        Deprecated customization point: override this method to modify the created RunEvent,
        for example to add Input or Output Datasets. Prefer the customize_lineage_*_event hooks.
        Returns:
            RunEvent with customizations added.
        """
        return event

    def record_run_end(self, start_time: datetime, end_time: datetime, error: Exception = None) -> dict:
        """
        Records the end of the training run by updating the OpenLineage Run. The end of the run can be due to
        successful completion of the logic or to an error.
        :param start_time: The start time of the training execution.
        :param end_time: The end time of the training execution.
        :param error: The `Exception` that caused the run to fail, if applicable. `None` if the run was successful.
        :return: A dict of run facets holding the nominal time and, if applicable, the error message.
        """
        run_end = {"nominalTime": from_open_lineage_facet(NominalTimeRunFacet(start_time.isoformat(timespec="milliseconds") + "Z", end_time.isoformat(timespec="milliseconds") + "Z"))}
        if error:
            run_end.update({"errorMessage": from_open_lineage_facet(ErrorMessageRunFacet(str(error), "Python " + python_version()))})
        return run_end

    def record_lineage(self, event: RunEvent):
        """
        Records metadata for this step in an OpenLineage format.
        """
        self.lineage_util.record_lineage(self.emitter, event)

    def get_job_name(self) -> str:
        """
        The default job name is the training step name; override this function to change the default job name.
        """
        return "${pipeline.capitalizedName}.${pipeline.trainingStep.name}"

    def get_default_namespace(self) -> str:
        """
        The default namespace is the pipeline name; override this function to change the default namespace.
        """
        return "${pipeline.capitalizedName}"
#end

    def set_dataset_origin(self, origin: str) -> None:
        """
        Sets the origin of the dataset for a training run.
        """
        if not origin:
            self.logger.warning('No value given for dataset origin!')
        self.dataset_origin = origin

    def set_model_information(self, model_type: str, model_architecture: str) -> None:
        """
        Sets the model information for a training run.
        """
        if not model_type:
            self.logger.warning('No value given for model type!')
        if not model_architecture:
            self.logger.warning('No value given for model architecture!')
        self.model_type = model_type
        self.model_architecture = model_architecture
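
    # Illustrative usage (hypothetical values): a subclass typically calls these setters
    # before run() finishes so that log_information() can tag the MLflow run:
    #
    #   self.set_dataset_origin("s3://my-bucket/training-data")
    #   self.set_model_information(model_type="classifier", model_architecture="RandomForest")
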
    def log_information(self, start: datetime, end: datetime, loaded_dataset: DataFrame, selected_features: List[str]) -> None:
        """
        Logs run information to MLflow as tags.
        """
        mlflow.set_tags(
            {
                "architecture": self.model_architecture,
                "dataset_origin": self.dataset_origin,
                "dataset_size": len(loaded_dataset),
                "end_time": end,
                "original_features": list(loaded_dataset),
                "selected_features": selected_features,
                "start_time": start,
                "type": self.model_type,
            }
        )

    def authorize(self, token: str, action: str):
        """
        Calls the Policy Decision Point server to authorize a JWT.
        """
        auth_config = AuthConfig()
        pdp_client = PDPClient(auth_config.pdp_host_url())
        decision = pdp_client.authorize(token, "", action)
        return decision
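
# Illustrative driver (hypothetical, not generated): wiring a pipeline together end
# to end, assuming the subclass sketched above and a running PDP instance:
#
#   pipeline = MyTrainingPipeline("my-experiment")
#   decision = pipeline.authorize(token=jwt_string, action="training")
#   if decision == "PERMIT":  # assumes the PDP returns the string "PERMIT" on success
#       pipeline.run()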