package ai.h2o.automl.leaderboard;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry.Stage;
import ai.h2o.automl.utils.DKVUtils;
import hex.*;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.*;
import java.util.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Utility to track all the models built for a given dataset type.
*
* Note that if a new Leaderboard is made for the same project_name it'll
* keep using the old model list, which allows us to run AutoML multiple
* times and keep adding to the leaderboard.
*
* The models are returned sorted by either an appropriate default metric
* for the model category (auc, mean per class error, or mean residual deviance),
* or by a metric that's set via #setMetricAndDirection.
*
* TODO: make this robust against removal of models from the DKV.
*/
public class Leaderboard extends Lockable<Leaderboard> implements ModelContainer<Model> {
/**
* @param project_name
* @return a Leaderboard id for the project name
*/
public static String idForProject(String project_name) { return "Leaderboard_" + project_name; }
/**
* @param metric
* @return true iff the metric is a loss function
*/
public static boolean isLossFunction(String metric) {
return metric != null && !Arrays.asList("auc", "aucpr").contains(metric.toLowerCase());
}
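  // Illustration (not part of the original class): "auc" and "aucpr" are the only metrics
  // treated as "higher is better"; every other metric sorts ascending, e.g.:
  //   isLossFunction("logloss") -> true   (lower is better)
  //   isLossFunction("auc")     -> false  (higher is better)
  //   isLossFunction("AUCPR")   -> false  (the check is case-insensitive)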
/**
 * Retrieves a leaderboard from the DKV or creates a fresh one and adds it to the DKV.
*
* Note that if the leaderboard is reused to add new models, we have to use the same leaderboard frame.
*
* IMPORTANT!
* if the leaderboard is created without leaderboardFrame, the models will be sorted according to their default metrics
* (in order of availability: cross-validation metrics, validation metrics, training metrics).
* Therefore, if some models were trained with/without cross-validation, or with different training or validation frames,
* then we can't guarantee the fairness of the leaderboard ranking.
*
* @param projectName
* @param eventLog
* @param leaderboardFrame
* @param sortMetric
* @return an existing leaderboard if there's already one in DKV for this project, or a new leaderboard added to DKV.
*/
public static Leaderboard getOrMake(String projectName, EventLog eventLog, Frame leaderboardFrame, String sortMetric) {
Leaderboard leaderboard = DKV.getGet(Key.make(idForProject(projectName)));
if (null != leaderboard) {
if (leaderboardFrame != null
&& (!leaderboardFrame._key.equals(leaderboard._leaderboard_frame_key)
|| leaderboardFrame.checksum() != leaderboard._leaderboard_frame_checksum)) {
throw new H2OIllegalArgumentException("Cannot use leaderboard "+projectName+" with a new leaderboard frame"
+" (existing leaderboard frame: "+leaderboard._leaderboard_frame_key+").");
} else {
      eventLog.warn(Stage.Workflow, "New models will be added to existing leaderboard "+projectName
          +" (leaderboard frame="+leaderboard._leaderboard_frame_key+"), which already contains "+leaderboard.getModelKeys().length+" models.");
}
if (sortMetric != null && !sortMetric.equals(leaderboard._sort_metric)) {
leaderboard._sort_metric = sortMetric.toLowerCase();
if (leaderboard.getLeader() != null) leaderboard.setDefaultMetrics(leaderboard.getLeader()); //reinitialize
}
} else {
leaderboard = new Leaderboard(projectName, eventLog, leaderboardFrame, sortMetric);
}
DKV.put(leaderboard);
return leaderboard;
}
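  // Usage sketch (illustrative only; `eventLog`, `lbFrame` and `modelKeys` are assumed to
  // exist in the calling code): a second call with the same project name returns the same
  // DKV-backed leaderboard, so models accumulate across AutoML runs, while passing a
  // different leaderboard frame for an existing project throws H2OIllegalArgumentException.
  //   Leaderboard lb = Leaderboard.getOrMake("my_project", eventLog, lbFrame, "auc");
  //   lb.addModels(modelKeys);                                              // first run
  //   Leaderboard lb2 = Leaderboard.getOrMake("my_project", eventLog, lbFrame, "auc");
  //   assert lb2._key.equals(lb._key);                                      // same DKV entry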
/**
* Identifier for models that should be grouped together in the leaderboard
* (e.g., "airlines" and "iris").
*/
private final String _project_name;
/**
* List of models for this leaderboard, sorted by metric so that the best is first,
* according to the standard metric for the given model type.
*
* Updated inside addModels().
*/
  private Key<Model>[] _model_keys = new Key[0];
/**
* Leaderboard/test set ModelMetrics objects for the models.
*
* Updated inside addModels().
*/
  private final IcedHashMap<Key<ModelMetrics>, ModelMetrics> _leaderboard_model_metrics = new IcedHashMap<>();
/**
 * Map from metric name to the values of that metric for all models, in the same order as the models.
 */
  private IcedHashMap<String, double[]> _metric_values = new IcedHashMap<>();
private LeaderboardExtensionsProvider _extensionsProvider;
/**
 * Cells holding the leaderboard extension values, one cell per model and extension column.
*/
private LeaderboardCell[] _extensions_cells = new LeaderboardCell[0];
/**
* Metric used to sort this leaderboard.
*/
private String _sort_metric;
/**
* Metrics reported in leaderboard
* Regression metrics: mean_residual_deviance, rmse, mse, mae, rmsle
* Binomial metrics: auc, logloss, aucpr, mean_per_class_error, rmse, mse
* Multinomial metrics: logloss, mean_per_class_error, rmse, mse
*/
private String[] _metrics;
/**
 * The EventLog attached to the same AutoML instance as this Leaderboard object.
 */
  private final Key<EventLog> _eventlog_key;
/**
* Frame for which we return the metrics, by default.
*/
  private final Key<Frame> _leaderboard_frame_key;
/**
* Checksum for the Frame for which we return the metrics, by default.
*/
private final long _leaderboard_frame_checksum;
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
/**
* Constructs a new leaderboard (doesn't put it in DKV).
* @param projectName
* @param eventLog
* @param leaderboardFrame
* @param sortMetric
*/
public Leaderboard(String projectName, EventLog eventLog, Frame leaderboardFrame, String sortMetric) {
super(Key.make(idForProject(projectName)));
_project_name = projectName;
_eventlog_key = eventLog._key;
_leaderboard_frame_key = leaderboardFrame == null ? null : leaderboardFrame._key;
_leaderboard_frame_checksum = leaderboardFrame == null ? 0 : leaderboardFrame.checksum();
_sort_metric = sortMetric == null ? null : sortMetric.toLowerCase();
}
/**
* Assign a {@link LeaderboardExtensionsProvider} to this leaderboard instance.
* @param provider the provider used to generate the optional extension columns from the leaderboard.
* @see LeaderboardExtensionsProvider
*/
public void setExtensionsProvider(LeaderboardExtensionsProvider provider) {
_extensionsProvider = provider;
}
public String getProject() {
return _project_name;
}
/**
* If no sort metric is provided when creating the leaderboard,
* then a default sort metric will be automatically chosen based on the problem type:
*
* binomial classification: auc
* multinomial classification: logloss
* regression: mean_residual_deviance
*
* @return the metric used to sort the models in the leaderboard.
*/
public String getSortMetric() {
return _sort_metric;
}
/**
* The sort metric is always the first element in the list of metrics.
*
* @return the full list of metrics available in the leaderboard.
*/
public String[] getMetrics() {
return _metrics == null ? (_sort_metric == null ? new String[0] : new String[]{_sort_metric})
: _metrics;
}
/**
 * Note: if no leaderboard frame was provided, then the models are sorted according to metrics obtained during training
* in the following priority order depending on availability:
*
* cross-validation metrics
* validation metrics
* training metrics
*
* @return the frame (if any) used to score the models in the leaderboard.
*/
public Frame leaderboardFrame() {
return _leaderboard_frame_key == null ? null : _leaderboard_frame_key.get();
}
/**
* @return list of keys of models sorted by the default metric for the model category, fetched from the DKV
*/
@Override
  public Key<Model>[] getModelKeys() {
return _model_keys;
}
/** Return the number of models in this Leaderboard. */
@Override
public int getModelCount() { return getModelKeys() == null ? 0 : getModelKeys().length; }
/**
* @return list of models sorted by the default metric for the model category
*/
@Override
public Model[] getModels() {
if (getModelCount() == 0) return new Model[0];
return getModelsFromKeys(getModelKeys());
}
/**
* @return list of models sorted by the given metric
*/
public Model[] getModelsSortedByMetric(String metric) {
if (getModelCount() == 0) return new Model[0];
return getModelsFromKeys(sortModels(metric));
}
/**
* @return the model with the best sort metric value.
* @see #getSortMetric()
*/
public Model getLeader() {
if (getModelCount() == 0) return null;
return getModelKeys()[0].get();
}
/**
* @param modelKey
* @return the rank for the given model key, according to the sort metric ranking (leader has rank 1).
*/
  public int getModelRank(Key<Model> modelKey) {
return ArrayUtils.find(getModelKeys(), modelKey) + 1;
}
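  // Example (hypothetical model keys): with the leaderboard sorted as [gbm_1, drf_1, glm_1],
  //   getModelRank(Key.make("drf_1"))   -> 2
  //   getModelRank(Key.make("unknown")) -> 0   (ArrayUtils.find returns -1 for a missing key)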
/**
 * @return the values for the sort metric, ordered ascending or descending depending on whether the sort metric is a loss function.
* @see #getSortMetric()
* @see #isLossFunction(String)
*/
public double[] getSortMetricValues() {
return _sort_metric == null ? null : _metric_values.get(_sort_metric);
}
private EventLog eventLog() { return _eventlog_key.get(); }
private void setDefaultMetrics(Model m) {
write_lock();
String[] metrics = defaultMetricsForModel(m);
if (_sort_metric == null) {
_sort_metric = metrics.length > 0 ? metrics[0] : "mse"; // default to a metric "universally" available
}
// ensure metrics is ordered in such a way that sortMetric is the first metric, and without duplicates.
int sortMetricIdx = ArrayUtils.find(metrics, _sort_metric);
if (sortMetricIdx > 0) {
metrics = ArrayUtils.remove(metrics, sortMetricIdx);
metrics = ArrayUtils.prepend(metrics, _sort_metric);
} else if (sortMetricIdx < 0){
metrics = ArrayUtils.append(new String[]{_sort_metric}, metrics);
}
_metrics = metrics;
update();
unlock();
}
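  // Worked example of the reordering above (using defaultMetricsForModel below): for a
  // binomial model with _sort_metric == "logloss", the default list
  //   ["auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse"]
  // becomes
  //   ["logloss", "auc", "aucpr", "mean_per_class_error", "rmse", "mse"];
  // a sort metric absent from the defaults is simply prepended instead.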
  private ModelMetrics getOrCreateModelMetrics(Key<Model> modelKey) {
return getOrCreateModelMetrics(modelKey, getExtensionsAsMap());
}
  private ModelMetrics getOrCreateModelMetrics(Key<Model> modelKey, Map<Key<Model>, LeaderboardCell[]> extensions) {
final Frame leaderboardFrame = leaderboardFrame();
ModelMetrics mm;
Model model = modelKey.get();
if (leaderboardFrame == null) {
// If leaderboardFrame is null, use default model metrics instead
mm = ModelMetrics.defaultModelMetrics(model);
} else {
mm = ModelMetrics.getFromDKV(model, leaderboardFrame);
if (mm == null) { // metrics haven't been computed yet (should occur max once per model)
      // optimization: since we have to score the leaderboard frame anyway, trigger scoring through the scoring-time extension (if provided) so the scoring time is measured at the same time.
LeaderboardCell scoringTimePerRow = getExtension(modelKey, ScoringTimePerRow.COLUMN.getName(), extensions);
if (scoringTimePerRow != null && scoringTimePerRow.getValue() == null) {
scoringTimePerRow.fetch();
mm = ModelMetrics.getFromDKV(model, leaderboardFrame);
}
}
if (mm == null) { // last resort
      //scores and magically stores the metrics where we're looking for them on the next line
model.score(leaderboardFrame).delete();
mm = ModelMetrics.getFromDKV(model, leaderboardFrame);
}
}
return mm;
}
/**
* Add the given models to the leaderboard.
* Note that to make this easier to use from Grid, which returns its models in random order,
* we allow the caller to add the same model multiple times and we eliminate the duplicates here.
* @param modelKeys
*/
  public void addModels(final Key<Model>[] modelKeys) {
if (modelKeys == null || modelKeys.length == 0) return;
if (null == _key)
throw new H2OIllegalArgumentException("Can't add models to a Leaderboard which isn't in the DKV.");
    final Key<Model>[] oldModelKeys = _model_keys;
    final Key<Model> oldLeaderKey = (oldModelKeys == null || 0 == oldModelKeys.length) ? null : oldModelKeys[0];
    // eliminate duplicates
    final Set<Key<Model>> uniques = new HashSet<>(Arrays.asList(ArrayUtils.append(oldModelKeys, modelKeys)));
    final List<Key<Model>> allModelKeys = new ArrayList<>(uniques);
    final Set<Key<Model>> newModelKeys = new HashSet<>(uniques);
    newModelKeys.removeAll(Arrays.asList(oldModelKeys));
// In case we're just re-adding existing models
if (newModelKeys.isEmpty()) return;
allModelKeys.forEach(DKV::prefetch);
    for (Key<Model> k : newModelKeys) {
Model m = k.get();
if (m == null) continue; // warning handled in next loop below
eventLog().debug(Stage.ModelTraining, "Adding model "+k+" to leaderboard "+_key+"."
+ " Training time: model=" + Math.round(m._output._run_time / 1000) + "s,"
+ " total=" + Math.round(m._output._total_run_time / 1000) + "s");
}
    final List<ModelMetrics> modelMetrics = new ArrayList<>();
    final Map<Key<Model>, LeaderboardCell[]> extensions = new HashMap<>();
    final List<Key<Model>> badKeys = new ArrayList<>();
    for (Key<Model> modelKey : allModelKeys) { // fully rebuilding modelMetrics, so we loop through all keys, not only new ones
Model model = modelKey.get();
if (model == null) {
badKeys.add(modelKey);
eventLog().warn(Stage.ModelTraining, "Model `"+modelKey+"` has unexpectedly been deleted from H2O: ignoring the model and/or removing it from the leaderboard.");
continue;
}
if (_extensionsProvider != null) {
extensions.put(modelKey, _extensionsProvider.createExtensions(model));
}
ModelMetrics mm = getOrCreateModelMetrics(modelKey, extensions);
      assert mm != null : "Missing metrics for model "+modelKey;
      if (mm == null) { // defensive fallback for runs where assertions are disabled
badKeys.add(modelKey);
eventLog().warn(Stage.ModelTraining, "Metrics for model `"+modelKey+"` are missing: ignoring the model and/or removing it from the leaderboard.");
continue;
}
modelMetrics.add(mm);
}
if (_metrics == null) {
// lazily set to default for this model category
setDefaultMetrics(modelKeys[0].get());
}
    for (Key<Model> key : badKeys) {
// keep everything clean for the update
allModelKeys.remove(key);
extensions.remove(key);
}
atomicUpdate(() -> {
_leaderboard_model_metrics.clear();
modelMetrics.forEach(this::addModelMetrics);
updateModels(allModelKeys.toArray(new Key[0]));
_extensions_cells = new LeaderboardCell[0];
extensions.forEach(this::addExtensions);
}, null);
if (oldLeaderKey == null || !oldLeaderKey.equals(_model_keys[0])) {
eventLog().info(Stage.ModelTraining,
"New leader: "+_model_keys[0]+", "+ _sort_metric +": "+ _metric_values.get(_sort_metric)[0]);
}
} // addModels
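  // Usage sketch (illustrative only; `k1` and `k2` are assumed keys of trained models):
  // duplicates are tolerated, so grid-search callers can re-add models freely.
  //   leaderboard.addModels(new Key[]{k1, k2, k1});  // k1 is de-duplicated
  //   Model best = leaderboard.getLeader();          // models re-sorted by the sort metric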
/**
* @param modelKeys the keys of the models to be removed from this leaderboard.
* @param cascade if true, the model itself and its dependencies will be completely removed from the backend.
*/
  public void removeModels(final Key<Model>[] modelKeys, boolean cascade) {
if (modelKeys == null
|| modelKeys.length == 0
|| Arrays.stream(modelKeys).noneMatch(k -> ArrayUtils.contains(_model_keys, k))) return;
Arrays.stream(modelKeys).filter(k -> ArrayUtils.contains(_model_keys, k)).forEach(k -> {
eventLog().debug(Stage.ModelTraining, "Removing model "+k+" from leaderboard "+_key);
});
    Key<Model>[] remainingKeys = Arrays.stream(_model_keys).filter(k -> !ArrayUtils.contains(modelKeys, k)).toArray(Key[]::new);
atomicUpdate(() -> {
_model_keys = new Key[0];
addModels(remainingKeys);
}, null);
if (cascade) {
      for (Key<Model> key : modelKeys) {
Keyed.remove(key);
}
}
}
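  // Usage sketch (illustrative only; `staleKey` is an assumed key of a leaderboard model):
  // removal re-sorts the survivors through addModels, so the leader may change; with
  // cascade=true the removed models are also deleted from the backend.
  //   leaderboard.removeModels(new Key[]{staleKey}, false);  // keep the model, drop the row
  //   leaderboard.removeModels(new Key[]{staleKey}, true);   // delete the model entirely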
  private void updateModels(Key<Model>[] modelKeys) {
    final Key<Model>[] sortedModelKeys = sortModelKeys(modelKeys);
    final Model[] sortedModels = getModelsFromKeys(sortedModelKeys);
    final IcedHashMap<String, double[]> metricValues = new IcedHashMap<>();
for (String metric : _metrics) {
metricValues.put(metric, getMetrics(metric, sortedModels));
}
_metric_values = metricValues;
_model_keys = sortedModelKeys;
}
private void atomicUpdate(Runnable update, Key jobKey) {
DKVUtils.atomicUpdate(this, update, jobKey, lock);
}
/**
* @see #addModels(Key[])
*/
@SuppressWarnings("unchecked")
  public void addModel(final Key<Model> key) {
if (key == null) return;
addModels(new Key[] {key});
}
/**
* @param key the key of the model to be removed from the leaderboard.
 * @param cascade if true, the model itself and its dependencies will be completely removed from the backend.
*/
@SuppressWarnings("unchecked")
  public void removeModel(final Key<Model> key, boolean cascade) {
if (key == null) return;
removeModels(new Key[] {key}, cascade);
}
private void addModelMetrics(ModelMetrics modelMetrics) {
if (modelMetrics != null) _leaderboard_model_metrics.put(modelMetrics._key, modelMetrics);
}
  private void addExtensions(final Key<Model> key, LeaderboardCell... extensions) {
if (key == null) return;
assert ArrayUtils.contains(_model_keys, key);
LeaderboardCell[] toAdd = Stream.of(extensions)
.filter(lc -> getExtension(key, lc.getColumn().getName()) == null)
.toArray(LeaderboardCell[]::new);
_extensions_cells = ArrayUtils.append(_extensions_cells, toAdd);
}
  private Map<Key<Model>, LeaderboardCell[]> getExtensionsAsMap() {
return Arrays.stream(_extensions_cells).collect(Collectors.toMap(
c -> c.getModelId(),
c -> new LeaderboardCell[]{c},
(lhs, rhs) -> ArrayUtils.append(lhs, rhs)
));
}
  private LeaderboardCell[] getExtensions(final Key<Model> key) {
return Stream.of(_extensions_cells)
.filter(c -> c.getModelId().equals(key))
.toArray(LeaderboardCell[]::new);
}
  private LeaderboardCell getExtension(final Key<Model> key, String extName) {
    return getExtension(key, extName, Collections.singletonMap(key, getExtensions(key)));
}
  private LeaderboardCell getExtension(final Key<Model> key, String extName, Map<Key<Model>, LeaderboardCell[]> extensions) {
if (extensions != null && extensions.containsKey(key)) {
return Stream.of(extensions.get(key))
.filter(le -> le.getColumn().getName().equals(extName))
.findFirst()
.orElse(null);
}
return null;
}
  private static Model[] getModelsFromKeys(Key<Model>[] modelKeys) {
    Model[] models = new Model[modelKeys.length];
    int i = 0;
    for (Key<Model> modelKey : modelKeys)
models[i++] = DKV.getGet(modelKey);
return models;
}
/**
* @return list of keys of models sorted by the given metric, fetched from the DKV
*/
  private Key<Model>[] sortModels(String metric) {
    Key<Model>[] models = getModelKeys();
    boolean decreasing = !isLossFunction(metric);
    List<Key<Model>> newModelsSorted = ModelMetrics.sortModelsByMetric(metric, decreasing, Arrays.asList(models));
return newModelsSorted.toArray(new Key[0]);
}
/**
* Sort by metric on the leaderboard/test set or default model metrics.
*/
  private Key<Model>[] sortModelKeys(Key<Model>[] modelKeys) {
    final List<Key<Model>> sortedModelKeys;
boolean sortDecreasing = !isLossFunction(_sort_metric);
final Frame leaderboardFrame = leaderboardFrame();
try {
if (leaderboardFrame == null) {
sortedModelKeys = ModelMetrics.sortModelsByMetric(_sort_metric, sortDecreasing, Arrays.asList(modelKeys));
} else {
sortedModelKeys = ModelMetrics.sortModelsByMetric(leaderboardFrame, _sort_metric, sortDecreasing, Arrays.asList(modelKeys));
}
} catch (H2OIllegalArgumentException e) {
Log.warn("ModelMetrics.sortModelsByMetric failed: " + e);
throw e;
}
return sortedModelKeys.toArray(new Key[0]);
}
private double[] getMetrics(String metric, Model[] models) {
double[] metrics = new double[models.length];
int i = 0;
Frame leaderboardFrame = leaderboardFrame();
for (Model m : models) {
// If leaderboard frame exists, get metrics from there
if (leaderboardFrame != null) {
metrics[i++] = ModelMetrics.getMetricFromModelMetric(
_leaderboard_model_metrics.get(ModelMetrics.buildKey(m, leaderboardFrame)),
metric
);
} else {
// otherwise use default model metrics
        Key<Model> model_key = m._key;
long model_checksum = m.checksum();
ModelMetrics mm = ModelMetrics.defaultModelMetrics(m);
metrics[i++] = ModelMetrics.getMetricFromModelMetric(
_leaderboard_model_metrics.get(ModelMetrics.buildKey(model_key, model_checksum, mm.frame()._key, mm.frame().checksum())),
metric
);
}
}
return metrics;
}
/**
* Delete object and its dependencies from DKV, including models.
*/
@Override
protected Futures remove_impl(Futures fs, boolean cascade) {
Log.debug("Cleaning up leaderboard from models "+Arrays.toString(_model_keys));
if (cascade) {
for (Key m : _model_keys) {
Keyed.remove(m, fs, true);
}
}
for (Key k : _leaderboard_model_metrics.keySet())
Keyed.remove(k, fs, true);
return super.remove_impl(fs, cascade);
}
private static String[] defaultMetricsForModel(Model m) {
if (m._output.isBinomialClassifier()) { //binomial
return new String[] {"auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse"};
} else if (m._output.isMultinomialClassifier()) { // multinomial
return new String[] {"mean_per_class_error", "logloss", "rmse", "mse"};
} else if (m._output.isSupervised()) { // regression
return new String[] {"mean_residual_deviance", "rmse", "mse", "mae", "rmsle"};
}
return new String[0];
}
private double[] getModelMetricValues(int rank) {
assert rank >= 0 && rank < getModelKeys().length: "invalid rank";
if (_metrics == null) return new double[0];
final double[] values = new double[_metrics.length];
for (int i=0; i < _metrics.length; i++) {
values[i] = _metric_values.get(_metrics[i])[rank];
}
return values;
}
String rankTsv() {
String lineSeparator = "\n";
StringBuilder sb = new StringBuilder();
sb.append("Error").append(lineSeparator);
for (int i = getModelKeys().length - 1; i >= 0; i--) {
// TODO: allow the metric to be passed in. Note that this assumes the validation (or training) frame is the same.
sb.append(Arrays.toString(getModelMetricValues(i)));
sb.append(lineSeparator);
}
return sb.toString();
}
private TwoDimTable makeTwoDimTable(String tableHeader, int nrows, LeaderboardColumn... columns) {
assert columns.length > 0;
    assert _sort_metric != null || nrows == 0 :
        "_sort_metric must be non-null for a non-empty leaderboard!";
String description = nrows > 0 ? "models sorted in order of "+_sort_metric+", best first"
: "no models in this leaderboard";
String[] rowHeaders = new String[nrows];
for (int i = 0; i < nrows; i++) rowHeaders[i] = ""+i;
String[] colHeaders = Stream.of(columns).map(LeaderboardColumn::getName).toArray(String[]::new);
String[] colTypes = Stream.of(columns).map(LeaderboardColumn::getType).toArray(String[]::new);
String[] colFormats = Stream.of(columns).map(LeaderboardColumn::getFormat).toArray(String[]::new);
String colHeaderForRowHeader = nrows > 0 ? "#" : "-";
return new TwoDimTable(
tableHeader,
description,
rowHeaders,
colHeaders,
colTypes,
colFormats,
colHeaderForRowHeader
);
}
private void addTwoDimTableRow(TwoDimTable table, int row, String modelID, String[] metrics, LeaderboardCell[] extensions) {
int col = 0;
table.set(row, col++, modelID);
for (String metric : metrics) {
double value = _metric_values.get(metric)[row];
table.set(row, col++, value);
}
for (LeaderboardCell extension: extensions) {
if (extension != null) {
Object value = extension.getValue() == null ? extension.fetch() : extension.getValue(); // for costly extensions, only fetch value on-demand
if (!extension.isNA()) {
table.set(row, col, value);
}
}
col++;
}
}
/**
* Creates a {@link TwoDimTable} representation of the leaderboard.
* If no extensions are provided, then the representation will only contain the model ids and the scoring metrics.
* Each extension name will be represented in the table
* if and only if it was also made available to the leaderboard by the {@link LeaderboardExtensionsProvider},
* otherwise it will just be ignored.
* @param extensions optional columns for the leaderboard representation.
* @return a {@link TwoDimTable} representation of the current leaderboard.
* @see LeaderboardExtensionsProvider
* @see LeaderboardColumn
*/
public TwoDimTable toTwoDimTable(String... extensions) {
return toTwoDimTable("Leaderboard for project " + _project_name, false, extensions);
}
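  // Usage sketch (illustrative only; "training_time_ms" is a hypothetical extension column
  // name that an extensions provider might expose):
  //   System.out.println(lb.toTwoDimTable());                                  // metrics only
  //   System.out.println(lb.toTwoDimTable("training_time_ms"));                // one extension
  //   System.out.println(lb.toTwoDimTable(LeaderboardExtensionsProvider.ALL)); // all visible extensions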
private TwoDimTable toTwoDimTable(String tableHeader, boolean leftJustifyModelIds, String... extensions) {
final Lock readLock = lock.readLock();
if (readLock.tryLock()) {
try {
        final Key<Model>[] modelKeys = _model_keys.clone(); // leaderboard can be retrieved when AutoML is still running: freezing current models state.
        final List<LeaderboardColumn> columns = getDefaultTableColumns();
        final List<LeaderboardColumn> extColumns = new ArrayList<>();
        if (getModelCount() > 0) {
          final Key<Model> leader = getModelKeys()[0];
LeaderboardCell[] extCells = (extensions.length > 0 && LeaderboardExtensionsProvider.ALL.equalsIgnoreCase(extensions[0]))
? Stream.of(getExtensions(leader)).filter(cell -> !cell.getColumn().isHidden()).toArray(LeaderboardCell[]::new)
: Stream.of(extensions).map(e -> getExtension(leader, e)).toArray(LeaderboardCell[]::new);
Stream.of(extCells).filter(Objects::nonNull).forEach(e -> extColumns.add(e.getColumn()));
}
columns.addAll(extColumns);
TwoDimTable table = makeTwoDimTable(tableHeader, modelKeys.length, columns.toArray(new LeaderboardColumn[0]));
int maxModelIdLen = Stream.of(modelKeys).mapToInt(k -> k.toString().length()).max().orElse(0);
final String[] modelIDsFormatted = new String[modelKeys.length];
for (int i = 0; i < modelKeys.length; i++) {
          Key<Model> key = modelKeys[i];
if (leftJustifyModelIds) {
// %-s doesn't work in TwoDimTable.toString(), so fake it here:
modelIDsFormatted[i] = org.apache.commons.lang.StringUtils.rightPad(key.toString(), maxModelIdLen);
} else {
modelIDsFormatted[i] = key.toString();
}
addTwoDimTableRow(table, i,
modelIDsFormatted[i],
getMetrics(),
extColumns.stream().map(ext -> getExtension(key, ext.getName())).toArray(LeaderboardCell[]::new)
);
}
return table;
} finally {
readLock.unlock();
}
} else {
return makeTwoDimTable(tableHeader, 0, getDefaultTableColumns().toArray(new LeaderboardColumn[0]));
}
}
  private List<LeaderboardColumn> getDefaultTableColumns() {
    final List<LeaderboardColumn> columns = new ArrayList<>();
columns.add(ModelId.COLUMN);
for (String metric : getMetrics()) {
columns.add(MetricScore.getColumn(metric));
}
return columns;
}
private String toString(String fieldSeparator, String lineSeparator, boolean includeTitle, boolean includeHeader) {
final StringBuilder sb = new StringBuilder();
if (includeTitle) {
sb.append("Leaderboard for project \"")
.append(_project_name)
.append("\": ");
if (_model_keys.length == 0) {
sb.append("");
return sb.toString();
}
sb.append(lineSeparator);
}
boolean printedHeader = false;
for (int i = 0; i < _model_keys.length; i++) {
      final Key<Model> key = _model_keys[i];
if (includeHeader && ! printedHeader) {
sb.append("model_id");
sb.append(fieldSeparator);
String [] metrics = _metrics != null ? _metrics : new String[0];
sb.append(Arrays.toString(metrics));
sb.append(lineSeparator);
printedHeader = true;
}
sb.append(key.toString());
sb.append(fieldSeparator);
double[] values = _metrics != null ? getModelMetricValues(i) : new double[0];
sb.append(Arrays.toString(values));
sb.append(lineSeparator);
}
return sb.toString();
}
@Override
public String toString() {
return toString(" ; ", " | ", true, true);
}
public String toLogString() {
return toTwoDimTable("Leaderboard for project "+_project_name, true).toString();
}
}