/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * ai.h2o.automl.modeling.CompletionStepsProvider (Maven / Gradle / Ivy)
 *
 * There is a newer version: 3.46.0.6
 * Show newest version
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.*;
import ai.h2o.automl.ModelingStep.DynamicStep;
import hex.Model;
import hex.grid.Grid;
import hex.grid.HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria;
import hex.leaderboard.Leaderboard;
import water.Job;
import water.Key;

import java.util.*;
import java.util.stream.Collectors;

public class CompletionStepsProvider implements ModelingStepsProvider {

    public static class CompletionSteps extends ModelingSteps {

        /**
         * Name of this steps provider: returned by {@link #getProvider()} and by
         * {@code CompletionStepsProvider.getName()}, and passed as the first constructor
         * argument of the steps defined in this class.
         */
        static final String NAME = "completion";
        
        /**
         * Grid step that resumes the grid search of another {@link GridStep}.
         * <p>
         * Model parameters and search parameters are delegated to the wrapped step. The search
         * criteria are the superclass ones with the stopping rounds forced to 0, and the job is
         * started on the first grid key that the AutoML session reports as resumable for the
         * wrapped step.
         * <p>
         * NOTE(review): the type arguments below were restored from usage (the {@code Grid} import
         * and the unchecked suppression on {@link #startJob()}); confirm them against the
         * {@code ModelingStep.GridStep} declaration.
         *
         * @param <M> type of the models built by the wrapped grid step.
         */
        static class ResumingGridStep<M extends Model> extends ModelingStep.GridStep<M> {

            /** The step whose grid is resumed. Declared transient; {@link #canRun()} guards against it being {@code null}. */
            private transient GridStep<M> _step;

            /**
             * @param step          the grid step to resume; its algo is reused and its provider and id
             *                      are joined with {@code "_"} to form the id of this step.
             * @param priorityGroup priority group passed to the superclass constructor.
             * @param weight        weight passed to the superclass constructor.
             * @param aml           AutoML instance passed to the superclass constructor.
             */
            public ResumingGridStep(GridStep<M> step, int priorityGroup, int weight, AutoML aml) {
                super(NAME, step.getAlgo(), step.getProvider()+"_"+step.getId(), priorityGroup, weight, aml);
                // NOTE(review): the work is rebuilt right after the super constructor, presumably so that
                // it reflects the values given above -- confirm against ModelingStep.makeWork().
                _work = makeWork();
                _step = step;
            }

            /**
             * @return {@code true} only if a wrapped step is available and this step has a positive weight.
             */
            @Override
            public boolean canRun() {
                return _step != null && _weight > 0;
            }

            /**
             * @return the model parameters prepared by the wrapped step.
             */
            @Override
            public Model.Parameters prepareModelParameters() {
                return _step.prepareModelParameters();
            }

            /**
             * @return the search parameters prepared by the wrapped step.
             */
            @Override
            public Map<String, Object[]> prepareSearchParameters() {
                return _step.prepareSearchParameters();
            }

            /**
             * Applies the superclass search criteria, then overrides the stopping rounds with 0.
             * NOTE(review): presumably this disables metric-based early stopping for the resumed
             * search -- confirm against the search criteria documentation.
             */
            @Override
            protected void setSearchCriteria(RandomDiscreteValueSearchCriteria searchCriteria, Model.Parameters baseParms) {
                super.setSearchCriteria(searchCriteria, baseParms);
                searchCriteria.set_stopping_rounds(0);
            }

            /**
             * Starts a hyperparameter search on the first resumable grid key that the session holds
             * for the wrapped step.
             *
             * @return the search job, or {@code null} when the session has no resumable key for the wrapped step.
             */
            @Override
            @SuppressWarnings("unchecked")
            protected Job<Grid> startJob() {
                Key<Grid>[] resumableGrids = aml().session().getResumableKeys(_step.getProvider(), _step.getId());
                if (resumableGrids.length == 0) {
                    return null;
                }
                return hyperparameterSearch(resumableGrids[0], prepareModelParameters(), prepareSearchParameters());
            }
        }
        
        /**
         * Dynamic step that creates {@link ResumingGridStep}s for the best performing grid steps.
         * <p>
         * Modeling steps are ranked by the average leaderboard sort metric of the trained models
         * that the session associates with them. The first {@code nGrids} steps of that ranking
         * that are resumable {@link GridStep}s are wrapped for resuming, each one receiving the
         * integer share {@code _weight / nGrids} of this step's weight.
         * <p>
         * NOTE(review): the {@code <Model>} type argument was restored from usage; confirm it
         * against the {@code ModelingStep.DynamicStep} declaration.
         */
        static class ResumeBestNGridsStep extends DynamicStep<Model> {

            /** Maximum number of grids to resume; also the divisor used to split this step's weight. */
            private final int _nGrids;

            /**
             * @param id     identifier of this step, passed to the superclass constructor.
             * @param nGrids maximum number of grids to resume.
             * @param autoML AutoML instance passed to the superclass constructor.
             */
            public ResumeBestNGridsStep(String id, int nGrids, AutoML autoML) {
                super(NAME, id, autoML);
                _nGrids = nGrids;
            }

            /**
             * Ranks the modeling steps behind the trained models by their average sort metric.
             * <p>
             * The order is ascending when the leaderboard sort metric is a loss function and
             * descending otherwise. Steps whose average is negative or NaN are left out.
             *
             * @return the ranked steps, or an empty list if the leaderboard has no sort metric values.
             */
            private List<ModelingStep> sortModelingStepByPerf() {
                Model[] models = getTrainedModels();
                double[] metrics = aml().leaderboard().getSortMetricValues();
                if (metrics == null) {
                    return Collections.emptyList();
                }

                // Group the metric of each model under the step that the session gives as its source.
                Map<ModelingStep, List<Double>> scoresBySource = new HashMap<>();
                for (int i = 0; i < models.length; i++) {
                    ModelingStep source = aml().session().getModelingStep(models[i]._key);
                    if (source == null) {
                        // NOTE(review): assumes the session can return null for a model it has no step for.
                        // Such a model is skipped: a null entry in the result would make
                        // prepareModelingSteps() fail on ModelingStep::isResumable.
                        continue;
                    }
                    scoresBySource.computeIfAbsent(source, k -> new ArrayList<>()).add(metrics[i]);
                }

                Comparator<Map.Entry<ModelingStep, Double>> metricsComparator = Map.Entry.comparingByValue();
                if (!Leaderboard.isLossFunction(aml().leaderboard().getSortMetric())) {
                    metricsComparator = metricsComparator.reversed();
                }

                // -1 is only a fallback for an empty score list; it is removed by the filter below.
                Map<ModelingStep, Double> meanScoreBySource = scoresBySource.entrySet().stream()
                        .collect(Collectors.toMap(
                                Map.Entry::getKey,
                                e -> e.getValue().stream().mapToDouble(Double::doubleValue).average().orElse(-1)
                        ));

                // Filtering first yields the same ranking as sorting first, because the sort is stable.
                return meanScoreBySource.entrySet().stream()
                        .filter(e -> e.getValue() >= 0)
                        .sorted(metricsComparator)
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
            }

            /**
             * @return one {@link ResumingGridStep} for each of the (at most {@code nGrids}) best ranked
             *         steps that are resumable grid steps.
             */
            @Override
            @SuppressWarnings({"rawtypes", "unchecked"})
            protected Collection<ModelingStep> prepareModelingSteps() {
                List<ModelingStep> bestSteps = sortModelingStepByPerf();
                return bestSteps.stream()
                        .filter(ModelingStep::isResumable)
                        .filter(GridStep.class::isInstance)
                        .limit(_nGrids)
                        // The model type of each grid step is unknown here, hence the raw types.
                        .map(s -> new ResumingGridStep((GridStep) s, _priorityGroup, _weight / _nGrids, aml()))
                        .collect(Collectors.toList());
            }
        }
        
        /**
         * Steps returned by {@link #getOptionals()}: a single dynamic step, with id
         * {@code "resume_best_grids"}, configured to resume 2 grids.
         * The initializer calls {@code aml()}; as an instance initializer it runs after the
         * superclass constructor.
         */
        private final ModelingStep[] optionals = {
                new ResumeBestNGridsStep("resume_best_grids", 2, aml())
        };
        
        /**
         * @param autoML the AutoML instance, passed as is to the superclass constructor.
         */
        public CompletionSteps(AutoML autoML) {
            super(autoML);
        }

        /**
         * @return the provider name, {@link #NAME}.
         */
        @Override
        public String getProvider() {
            return NAME;
        }

        /**
         * @return the {@code optionals} array of this instance (the array itself, not a copy).
         */
        @Override
        protected ModelingStep[] getOptionals() {
            return optionals;
        }
    }
    
    /**
     * @return the name of this provider, i.e. {@code CompletionSteps.NAME}.
     */
    @Override
    public String getName() {
        return CompletionSteps.NAME;
    }

    /**
     * @param aml the AutoML instance given to the new steps object.
     * @return a new {@link CompletionSteps} created for {@code aml}.
     */
    @Override
    public CompletionSteps newInstance(AutoML aml) {
        return new CompletionSteps(aml);
    }

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy