ai.h2o.automl.ModelSelectionStrategies Maven / Gradle / Ivy
package ai.h2o.automl;
import ai.h2o.automl.leaderboard.Leaderboard;
import hex.Model;
import org.apache.log4j.Logger;
import water.Key;
import water.util.ArrayUtils;
import java.util.Arrays;
import java.util.function.Predicate;
import java.util.function.Supplier;
public final class ModelSelectionStrategies {
private static final Logger LOG = Logger.getLogger(ModelSelectionStrategies.class);
public static abstract class LeaderboardBasedSelectionStrategy implements ModelSelectionStrategy {
final Supplier _leaderboardSupplier;
public LeaderboardBasedSelectionStrategy(Supplier leaderboardSupplier) {
_leaderboardSupplier = leaderboardSupplier;
}
LeaderboardHolder makeSelectionLeaderboard() {
return _leaderboardSupplier.get();
}
}
public static class KeepBestN extends LeaderboardBasedSelectionStrategy{
private final int _N;
public KeepBestN(int N, Supplier leaderboardSupplier) {
super(leaderboardSupplier);
_N = N;
}
@Override
@SuppressWarnings("unchecked")
public Selection select(Key[] originalModels, Key[] newModels) {
LeaderboardHolder lbHolder = makeSelectionLeaderboard();
Leaderboard tmpLeaderboard = lbHolder.get();
tmpLeaderboard.addModels((Key[]) originalModels);
tmpLeaderboard.addModels((Key[]) newModels);
if (LOG.isDebugEnabled()) LOG.debug(tmpLeaderboard.toLogString());
Key[] sortedKeys = tmpLeaderboard.getModelKeys();
Key[] bestN = ArrayUtils.subarray(sortedKeys, 0, Math.min(sortedKeys.length, _N));
Key[] toAdd = Arrays.stream(bestN).filter(k -> !ArrayUtils.contains(originalModels, k)).toArray(Key[]::new);
Key[] toRemove = Arrays.stream(originalModels).filter(k -> !ArrayUtils.contains(bestN, k)).toArray(Key[]::new);
Selection selection = new Selection<>(toAdd, toRemove);
lbHolder.cleanup();
return selection;
}
}
public static class KeepBestConstantSize extends LeaderboardBasedSelectionStrategy {
public KeepBestConstantSize(Supplier leaderboardSupplier) {
super(leaderboardSupplier);
}
@Override
public Selection select(Key[] originalModels, Key[] newModels) {
return new KeepBestN(originalModels.length, _leaderboardSupplier).select(originalModels, newModels);
}
}
public static class KeepBestNFromSubgroup extends LeaderboardBasedSelectionStrategy {
private final Predicate> _criterion;
private final int _N;
public KeepBestNFromSubgroup(int N, Predicate> criterion, Supplier leaderboardSupplier) {
super(leaderboardSupplier);
_criterion = criterion;
_N = N;
}
@Override
public Selection select(Key[] originalModels, Key[] newModels) {
Key[] originalModelsSubgroup = Arrays.stream(originalModels).filter(_criterion).toArray(Key[]::new);
Key[] newModelsSubGroup = Arrays.stream(newModels).filter(_criterion).toArray(Key[]::new);
return new KeepBestN(_N, _leaderboardSupplier).select(originalModelsSubgroup, newModelsSubGroup);
}
}
public interface LeaderboardHolder {
Leaderboard get();
default void cleanup() {};
}
}