All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.boozallen.aiops.mda.metamodel.element.training.OnnxModelConversionPostAction Maven / Gradle / Ivy

There is a newer version: 1.12.1
Show newest version
package com.boozallen.aiops.mda.metamodel.element.training;

/*-
 * #%L
 * AIOps Foundation::AIOps MDA
 * %%
 * Copyright (C) 2021 Booz Allen
 * %%
 * This software package is licensed under the Booz Allen Public License. All Rights Reserved.
 * #L%
 */

import java.util.HashMap;
import java.util.Map;

import org.technologybrewery.fermenter.mda.generator.GenerationException;

import com.boozallen.aiops.mda.metamodel.element.PostAction;

/**
 * Decorates {@link PostAction} with model-conversion functionality specific to
 * ONNX model conversion.
 */
public class OnnxModelConversionPostAction extends ModelConversionPostAction {

    // supported onnx conversions
    public static final String SKLEARN = "sklearn";
    public static final String KERAS = "keras";

    // onnx conversion parameters
    private static final String NAME_PARAM = "name";
    private static final String INITIAL_TYPES_PARAM = "initial_types";
    private static final String DOC_STRING_PARAM = "doc_string";
    private static final String TARGET_OPSET_PARAM = "target_opset";
    private static final String CUSTOM_CONVERSION_FUNCTIONS_PARAM = "custom_conversion_functions";
    private static final String CUSTOM_SHAPE_CALCULATORS_PARAM = "custom_shape_calculators";
    private static final String CHANNEL_FIRST_INPUTS_PARAM = "channel_first_inputs";
    private static final String DEFAULT_BATCH_SIZE_PARAM = "default_batch_size";

    // common default values
    private static final String NONE = "None";

    /**
     * Default value for each ONNX conversion parameter, keyed by parameter name.
     * NOTE(review): the values look like Python literal text ({@code None},
     * {@code ''}, {@code 1}), presumably emitted verbatim into generated code —
     * confirm against the templates that consume them.
     */
    private static final Map<String, String> defaultParameterValues;
    static {
        defaultParameterValues = new HashMap<>();
        defaultParameterValues.put(NAME_PARAM, NONE);
        defaultParameterValues.put(INITIAL_TYPES_PARAM, NONE);
        defaultParameterValues.put(DOC_STRING_PARAM, "''");
        defaultParameterValues.put(TARGET_OPSET_PARAM, NONE);
        defaultParameterValues.put(CUSTOM_CONVERSION_FUNCTIONS_PARAM, NONE);
        defaultParameterValues.put(CUSTOM_SHAPE_CALCULATORS_PARAM, NONE);
        defaultParameterValues.put(CHANNEL_FIRST_INPUTS_PARAM, NONE);
        defaultParameterValues.put(DEFAULT_BATCH_SIZE_PARAM, "1");
    }

    /**
     * {@inheritDoc}
     */
    public OnnxModelConversionPostAction(PostAction postActionToDecorate) {
        super(postActionToDecorate);
    }

    /**
     * Returns the ONNX conversion parameters accepted for the configured model
     * source, each mapped to whether that parameter is required.
     *
     * @return a newly created map of parameter name to required flag
     * @throws GenerationException if the model source is neither
     *         {@value #SKLEARN} nor {@value #KERAS}
     */
    public Map<String, Boolean> getOnnxConversionParameters() {
        String modelSource = getModelSource();
        // SKLEARN/KERAS on the left keeps the comparison null-safe
        if (SKLEARN.equals(modelSource)) {
            return getSklearnParameters();
        }
        if (KERAS.equals(modelSource)) {
            return getKerasParameters();
        }
        throw new GenerationException(
                "ONNX conversion for modelSource '" + modelSource + "' is currently not supported!");
    }

    /**
     * Returns the default value for the given parameter.
     *
     * @param parameter
     *            the ONNX conversion parameter name (e.g. {@code "doc_string"})
     * @return the default parameter value, or {@code null} if the name is not a
     *         known ONNX conversion parameter
     */
    public String getDefaultParameterValue(String parameter) {
        return defaultParameterValues.get(parameter);
    }

    /**
     * Returns the onnx conversion method name, built as {@code "convert_"}
     * followed by the model source (e.g. {@code convert_sklearn}).
     * The model source is not validated here.
     *
     * @return onnx conversion method name
     */
    public String getOnnxConversionMethod() {
        return "convert_" + getModelSource();
    }

    /**
     * Builds the parameter/required map for sklearn conversion; only
     * {@code initial_types} is required.
     */
    private Map<String, Boolean> getSklearnParameters() {
        Map<String, Boolean> sklearnParameters = new HashMap<>();

        sklearnParameters.put(INITIAL_TYPES_PARAM, true);
        sklearnParameters.put(NAME_PARAM, false);
        sklearnParameters.put(DOC_STRING_PARAM, false);
        sklearnParameters.put(TARGET_OPSET_PARAM, false);
        sklearnParameters.put(CUSTOM_CONVERSION_FUNCTIONS_PARAM, false);
        sklearnParameters.put(CUSTOM_SHAPE_CALCULATORS_PARAM, false);

        return sklearnParameters;
    }

    /**
     * Builds the parameter/required map for keras conversion; all parameters
     * are optional.
     */
    private Map<String, Boolean> getKerasParameters() {
        Map<String, Boolean> kerasParameters = new HashMap<>();

        kerasParameters.put(INITIAL_TYPES_PARAM, false);
        kerasParameters.put(NAME_PARAM, false);
        kerasParameters.put(DOC_STRING_PARAM, false);
        kerasParameters.put(TARGET_OPSET_PARAM, false);
        kerasParameters.put(CUSTOM_CONVERSION_FUNCTIONS_PARAM, false);
        kerasParameters.put(CUSTOM_SHAPE_CALCULATORS_PARAM, false);
        kerasParameters.put(CHANNEL_FIRST_INPUTS_PARAM, false);
        kerasParameters.put(DEFAULT_BATCH_SIZE_PARAM, false);

        return kerasParameters;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy