// hex.genmodel.easy.error.CountingErrorConsumer (Maven / Gradle / Ivy)
package hex.genmodel.easy.error;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* An implementation of {@link hex.genmodel.easy.EasyPredictModelWrapper.ErrorConsumer}
* that counts how many errors of each kind have been received.
*/
public class CountingErrorConsumer extends EasyPredictModelWrapper.ErrorConsumer {
private Map dataTransformationErrorsCountPerColumn;
private Map unknownCategoricalsPerColumn;
private Map> unseenCategoricalsCollector;
private final boolean collectUnseenCategoricals;
/**
* Creates a consumer for the given model using {@code DEFAULT_CONFIG}.
* Delegates to {@link #CountingErrorConsumer(GenModel, Config)}.
*
* @param model An instance of {@link GenModel}
*/
public CountingErrorConsumer(GenModel model) {
this(model, DEFAULT_CONFIG);
}
/**
 * Creates a consumer for the given model with an explicit configuration.
 * Reads the "collect unseen categoricals" flag from {@code config} and then
 * sets up the per-column error counters for {@code model}.
 *
 * @param model  An instance of {@link GenModel}
 * @param config An instance of {@link Config}
 */
public CountingErrorConsumer(GenModel model, Config config) {
    // The flag has to be assigned first: initializeUnknownCategoricals reads it.
    this.collectUnseenCategoricals = config.isCollectUnseenCategoricals();
    this.initializeDataTransformationErrorsCount(model);
    this.initializeUnknownCategoricals(model);
}
/**
 * Builds the per-column counters of data transformation errors.
 * Every column name reported by the model gets a fresh zeroed counter, except
 * the response column of a supervised model, which is left out. The counters
 * are stored in a {@link ConcurrentHashMap} that is published through an
 * unmodifiable view, so entries cannot be added or removed afterwards.
 *
 * @param model {@link GenModel} the data transformation errors count map is initialized for
 */
private void initializeDataTransformationErrorsCount(GenModel model) {
    final String responseName = model.isSupervised() ? model.getResponseName() : null;
    final Map<String, AtomicLong> counters = new ConcurrentHashMap<>();
    for (String columnName : model.getNames()) {
        // Only a supervised model has a response column to leave out.
        if (model.isSupervised() && columnName.equals(responseName)) {
            continue;
        }
        counters.put(columnName, new AtomicLong());
    }
    dataTransformationErrorsCountPerColumn = Collections.unmodifiableMap(counters);
}
/**
* Initializes the map of unknown categoricals per column with an unmodifiable and thread-safe implementation of {@link Map}.
*
* @param model {@link GenModel} the unknown categorical per column map is initialized for
*/
private void initializeUnknownCategoricals(GenModel model) {
unknownCategoricalsPerColumn = new ConcurrentHashMap<>();
unseenCategoricalsCollector = new ConcurrentHashMap<>();
for (int i = 0; i < model.getNumCols(); i++) {
String[] domainValues = model.getDomainValues(i);
if (domainValues != null) {
unknownCategoricalsPerColumn.put(model.getNames()[i], new AtomicLong());
if (collectUnseenCategoricals)
unseenCategoricalsCollector.put(model.getNames()[i], new ConcurrentHashMap