All downloads are free. The search and download functionality uses the official Maven repository.

com.boozallen.aiops.mda.generator.MlTrainingPyProjectGenerator Maven / Gradle / Ivy

package com.boozallen.aiops.mda.generator;

/*-
 * #%L
 * AIOps Foundation::AIOps MDA
 * %%
 * Copyright (C) 2021 Booz Allen
 * %%
 * This software package is licensed under the Booz Allen Public License. All Rights Reserved.
 * #L%
 */

import com.boozallen.aiops.mda.generator.common.VelocityProperty;
import com.boozallen.aiops.mda.generator.util.PipelineUtils;
import com.boozallen.aiops.mda.metamodel.element.Pipeline;
import com.boozallen.aiops.mda.metamodel.element.PostAction;
import com.boozallen.aiops.mda.metamodel.element.Step;
import com.boozallen.aiops.mda.metamodel.element.python.MachineLearningPipeline;
import com.boozallen.aiops.mda.metamodel.element.training.OnnxModelConversionPostAction;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.velocity.VelocityContext;
import org.technologybrewery.fermenter.mda.generator.GenerationContext;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * A {@link TargetedPipelinePyProjectGenerator} that generates a {@code pyproject.toml} specifically for the training
 * step modules of machine learning pipelines.
 * <p>
 * In addition to generating the file, this generator inspects the post actions declared on the training step and
 * exposes any extra Python dependencies they require to the template and to the manual action notification service.
 */
public class MlTrainingPyProjectGenerator extends TargetedPipelinePyProjectGenerator {
    /*--~-~-~~
     * Usages:
     * | Target                     | Template                          | Generated File  |
     * |----------------------------|-----------------------------------|-----------------|
     * | trainingPipelinePyProject  | general-mlflow/pyproject.toml.vm  | pyproject.toml  |
     */


    // additional requirements needed if there are onnx model conversion post
    // actions added to the training step
    private static final String ONNX_ML_TOOLS_DEPENDENCY = "onnxmltools = \"^1.11.1\"";
    private static final String ONNX_KERAS_DEPENDENCY = "tf2onnx = \"^1.12.1\"";

    /**
     * Populates the given {@link VelocityContext} with any necessary machine learning training pipeline specific
     * attributes, generates the project's {@code pyproject.toml}, and provides a notification to users if manual
     * modification to the {@code pyproject.toml} is needed.
     * <p>
     * The {@link VelocityProperty#POST_ACTION_REQUIREMENTS} attribute is only placed in the context when the
     * training step declares at least one post action; the notification is only raised when those post actions
     * actually require additional dependencies.
     *
     * @param generationContext Fermenter generation context.
     * @param velocityContext   pre-populated context that contains commonly used attributes.
     * @param pipeline          targeted pipeline for which this generator is being applied.
     */
    @Override
    protected void doGenerateFile(GenerationContext generationContext, VelocityContext velocityContext, Pipeline pipeline) {
        MachineLearningPipeline mlPipeline = new MachineLearningPipeline(pipeline);
        Step trainingStep = mlPipeline.getTrainingStep();

        velocityContext.put(VelocityProperty.PIPELINE, mlPipeline);

        // Stays null when the training step has no post actions; the null-safe
        // CollectionUtils check below treats that the same as "nothing to report".
        Set<String> postActionDependencies = null;
        List<PostAction> postActions = trainingStep.getPostActions();
        if (CollectionUtils.isNotEmpty(postActions)) {
            postActionDependencies = getPostActionDependencies(postActions);
            velocityContext.put(VelocityProperty.POST_ACTION_REQUIREMENTS, postActionDependencies);
        }

        generateFile(generationContext, velocityContext);

        // Notify only after the file has been generated, and only if there is something to add.
        if (CollectionUtils.isNotEmpty(postActionDependencies)) {
            manualActionNotificationService.addNoticeToAddPythonDependencies(generationContext, postActionDependencies, "post action support");
        }
    }

    /**
     * Retrieves any additional Python package dependencies that are required to support the provided
     * {@link PostAction}s.
     * <p>
     * Every ONNX model conversion post action contributes the {@code onnxmltools} dependency; those whose model
     * source is {@link OnnxModelConversionPostAction#KERAS} additionally contribute {@code tf2onnx}. Duplicates are
     * collapsed and first-insertion order is retained.
     *
     * @param postActions post actions for the targeted pipeline for which to retrieve any additional
     *                    needed dependencies.
     * @return additional Python package dependencies needed to support the provided {@link PostAction}s,
     * formatted as Poetry {@code pyproject.toml} dependency specifications; empty if none are needed.
     */
    protected Set<String> getPostActionDependencies(List<PostAction> postActions) {
        Set<String> postActionRequirements = new LinkedHashSet<>();
        for (PostAction postAction : postActions) {
            if (PipelineUtils.forOnnxModelConversion(postAction)) {
                postActionRequirements.add(ONNX_ML_TOOLS_DEPENDENCY);

                String modelSource = postAction.getModelSource();
                if (OnnxModelConversionPostAction.KERAS.equals(modelSource)) {
                    postActionRequirements.add(ONNX_KERAS_DEPENDENCY);
                }
            }
        }

        return postActionRequirements;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy