All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.algos.ensemble.StackedEnsembleMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.ensemble;

import hex.genmodel.MojoModel;
import hex.genmodel.MultiModelMojoReader;

public class StackedEnsembleMojoReader extends MultiModelMojoReader {

    @Override
    public String getModelName() {
        return "StackedEnsemble";
    }

    @Override
    protected void readParentModelData() {
        int baseModelNum = readkv("base_models_num", 0);
        _model._baseModelNum = baseModelNum;
        _model._metaLearner = getModel((String) readkv("metalearner"));

        final String metaLearnerTransform = readkv("metalearner_transform", "NONE");
        if (!metaLearnerTransform.equals("NONE") && !metaLearnerTransform.equals("Logit"))
            throw new UnsupportedOperationException("Metalearner Transform \"" + metaLearnerTransform + "\" is not supported!");
        _model._useLogitMetaLearnerTransform = metaLearnerTransform.equals("Logit");

        _model._baseModels = new StackedEnsembleMojoModel.StackedEnsembleMojoSubModel[baseModelNum];
        final String[] columnNames = readkv("[columns]");
        for (int i = 0; i < baseModelNum; i++) {
            String modelKey = readkv("base_model" + i);
            if (modelKey == null)
                continue;
            final MojoModel model = getModel(modelKey);
            _model._baseModels[i] = new StackedEnsembleMojoModel.StackedEnsembleMojoSubModel(model,
                    createMapping(model, columnNames, modelKey));
        }
    }

    /**
     * Creates an array of integers with mapping of referential column name space into model-specific column name space.
     *
     * @param model     Model to create column mapping for
     * @param reference Column mapping serving as a reference
     * @param modelName Name of the model for various error reports
     * @return An array of integers with representing the mapping.
     */
    private static int[] createMapping(final MojoModel model, final String[] reference, final String modelName) {
        String[] features = model.features();
        int[] mapping = new int[features.length];
        for (int i = 0; i < mapping.length; i++) {
            String feature = features[i];
            mapping[i] = findColumnIndex(reference, feature);
            if (mapping[i] < 0) {
                throw new IllegalStateException(String.format("Model '%s' does not have input column '%s'",
                        modelName, feature));
            }
        }
        return mapping;
    }

    private static int findColumnIndex(String[] arr, String searchedColname) {
        for (int i = 0; i < arr.length; i++) {
            if (arr[i].equals(searchedColname)) return i;
        }
        return -1;
    }

    @Override
    protected StackedEnsembleMojoModel makeModel(String[] columns, String[][] domains, String responseColumn) {
        return new StackedEnsembleMojoModel(columns, domains, responseColumn);
    }

    @Override public String mojoVersion() {
        return "1.01";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy