All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.automl.ModelSelectionStrategies Maven / Gradle / Ivy

The newest version!
package ai.h2o.automl;

import hex.Model;
import hex.leaderboard.Leaderboard;
import org.apache.log4j.Logger;
import water.Key;
import water.util.ArrayUtils;

import java.util.Arrays;
import java.util.function.Predicate;
import java.util.function.Supplier;

public final class ModelSelectionStrategies {
    
    private static final Logger LOG = Logger.getLogger(ModelSelectionStrategies.class);

    /**
     * Base class for selection strategies that obtain a {@link LeaderboardHolder}
     * from a supplier in order to rank candidate models.
     *
     * NOTE(review): the generic type parameters of this hierarchy appear to have been
     * lost in this copy of the file (see the unused hex.Model import) — confirm against upstream.
     */
    public static abstract class LeaderboardBasedSelectionStrategy implements ModelSelectionStrategy {

        /** Supplier queried by {@link #makeSelectionLeaderboard()} on every call. */
        final Supplier<LeaderboardHolder> _leaderboardSupplier;

        /**
         * @param leaderboardSupplier supplier of the {@link LeaderboardHolder} used for selection.
         */
        public LeaderboardBasedSelectionStrategy(Supplier<LeaderboardHolder> leaderboardSupplier) {
            _leaderboardSupplier = leaderboardSupplier;
        }

        /**
         * @return the holder returned by the supplier.
         */
        LeaderboardHolder makeSelectionLeaderboard() {
            // The supplier must be parameterized: a raw Supplier's get() yields Object,
            // which is not assignable to the LeaderboardHolder return type.
            return _leaderboardSupplier.get();
        }
    }

    /**
     * Strategy keeping at most N models: the N leading entries of a leaderboard
     * filled with both the original and the new models.
     */
    public static class KeepBestN extends LeaderboardBasedSelectionStrategy{

        /** Maximum number of models retained by {@link #select(Key[], Key[])}. */
        private final int _N;

        /**
         * @param N                   maximum number of models to keep.
         * @param leaderboardSupplier supplier of the {@link LeaderboardHolder} used for ranking.
         */
        public KeepBestN(int N, Supplier<LeaderboardHolder> leaderboardSupplier) {
            super(leaderboardSupplier);
            _N = N;
        }

        /**
         * Adds both sets of models to the selection leaderboard and keeps the first N keys it reports.
         *
         * @param originalModels keys of the models currently retained.
         * @param newModels      keys of the candidate models.
         * @return a selection whose first argument holds the retained keys that were not in
         *         {@code originalModels}, and whose second argument holds the original keys
         *         that are no longer retained.
         */
        @Override
        @SuppressWarnings("unchecked")
        public Selection select(Key[] originalModels, Key[] newModels) {
            LeaderboardHolder lbHolder = makeSelectionLeaderboard();
            try {
                Leaderboard tmpLeaderboard = lbHolder.get();
                tmpLeaderboard.addModels((Key[]) originalModels);
                tmpLeaderboard.addModels((Key[]) newModels);
                if (LOG.isDebugEnabled()) LOG.debug(tmpLeaderboard.toLogString());
                // NOTE(review): assumes getModelKeys() returns keys ranked best-first — confirm against Leaderboard.
                Key[] sortedKeys = tmpLeaderboard.getModelKeys();
                Key[] bestN = ArrayUtils.subarray(sortedKeys, 0, Math.min(sortedKeys.length, _N));
                // retained keys that are new to the original set
                Key[] toAdd = Arrays.stream(bestN).filter(k -> !ArrayUtils.contains(originalModels, k)).toArray(Key[]::new);
                // original keys that did not make it into the retained set
                Key[] toRemove = Arrays.stream(originalModels).filter(k -> !ArrayUtils.contains(bestN, k)).toArray(Key[]::new);
                Selection selection = new Selection<>(toAdd, toRemove);
                return selection;
            } finally {
                // Always release the holder, even when populating or reading the leaderboard throws.
                lbHolder.cleanup();
            }
        }
    }

    /**
     * Strategy that keeps as many models as there were originally,
     * by delegating to {@link KeepBestN} with N set to the number of original models.
     */
    public static class KeepBestConstantSize extends LeaderboardBasedSelectionStrategy {

        /**
         * @param leaderboardSupplier supplier handed over to the superclass and to the delegate strategy.
         */
        public KeepBestConstantSize(Supplier leaderboardSupplier) {
            super(leaderboardSupplier);
        }

        /**
         * @param originalModels keys of the models currently retained; its length is the number of models to keep.
         * @param newModels      keys of the candidate models.
         * @return the selection computed by a {@link KeepBestN} sized to {@code originalModels.length}.
         */
        @Override
        public Selection select(Key[] originalModels, Key[] newModels) {
            final int targetSize = originalModels.length;
            KeepBestN delegate = new KeepBestN(targetSize, _leaderboardSupplier);
            return delegate.select(originalModels, newModels);
        }
    }

    public static class KeepBestNFromSubgroup extends LeaderboardBasedSelectionStrategy {

        private final Predicate> _criterion;
        private final int _N;

        public KeepBestNFromSubgroup(int N, Predicate> criterion, Supplier leaderboardSupplier) {
            super(leaderboardSupplier);
            _criterion = criterion;
            _N = N;
        }

        @Override
        public Selection select(Key[] originalModels, Key[] newModels) {
            Key[] originalModelsSubgroup = Arrays.stream(originalModels).filter(_criterion).toArray(Key[]::new);
            Key[] newModelsSubGroup = Arrays.stream(newModels).filter(_criterion).toArray(Key[]::new);
            return new KeepBestN(_N, _leaderboardSupplier).select(originalModelsSubgroup, newModelsSubGroup);
        }
    }

    /**
     * Gives access to a {@link Leaderboard} together with an optional cleanup hook.
     */
    public interface LeaderboardHolder {

        /**
         * @return the leaderboard held by this instance.
         */
        Leaderboard get();

        /**
         * Cleanup hook; the default implementation does nothing.
         */
        default void cleanup() {
            // no-op unless overridden
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy