package hex;
import water.H2O;
import water.Iced;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;
import java.util.Arrays;
import java.util.Comparator;
/**
 * Lightweight keeper of scores, solely intended for display (either directly or as a
 * helper to create the scoring history TwoDimTable).
 * Not intended to store large objects such as AUC or ConfusionMatrix.
 */
public class ScoreKeeper extends Iced {
  public double _r2 = Double.NaN;
  public double _mean_residual_deviance = Double.NaN;
  public double _mse = Double.NaN;
  public double _logloss = Double.NaN;
  public double _AUC = Double.NaN;
  public double _classError = Double.NaN;
  public float[] _hitratio;
  public double _lift = Double.NaN; // lift in top group

  public ScoreKeeper() {}
  /**
   * Keep score of mean squared error only.
   * @param mse mean squared error to keep
   */
  public ScoreKeeper(double mse) { _mse = mse; }
  /**
   * Keep score of a given ModelMetrics.
   * @param mm ModelMetrics to keep track of.
   */
  public ScoreKeeper(ModelMetrics mm) { fillFrom(mm); }
  /**
   * Keep score for a model using its cross-validation metrics if available,
   * otherwise its validation metrics, otherwise its training metrics.
   * @param m model for which we should keep score
   */
  public ScoreKeeper(Model m) {
    if (null == m) throw new H2OIllegalArgumentException("model", "ScoreKeeper(Model model)", null);
    if (null == m._output) throw new H2OIllegalArgumentException("model._output", "ScoreKeeper(Model model)", null);
    if (null != m._output._cross_validation_metrics) {
      fillFrom(m._output._cross_validation_metrics);
    } else if (null != m._output._validation_metrics) {
      fillFrom(m._output._validation_metrics);
    } else {
      fillFrom(m._output._training_metrics);
    }
  }
  public boolean isEmpty() {
    return Double.isNaN(_mse) && Double.isNaN(_logloss); // at least one of them should always be filled
  }
  public void fillFrom(ModelMetrics m) {
    if (m == null) return;
    _mse = m._MSE;
    if (m instanceof ModelMetricsSupervised) {
      _r2 = ((ModelMetricsSupervised)m).r2();
    }
    if (m instanceof ModelMetricsRegression) {
      _mean_residual_deviance = ((ModelMetricsRegression)m)._mean_residual_deviance;
    }
    if (m instanceof ModelMetricsBinomial) {
      _logloss = ((ModelMetricsBinomial)m)._logloss;
      if (((ModelMetricsBinomial)m)._auc != null) {
        _AUC = ((ModelMetricsBinomial)m)._auc._auc;
        _classError = ((ModelMetricsBinomial)m)._auc.defaultErr();
      }
      GainsLift gl = ((ModelMetricsBinomial)m)._gainsLift;
      if (gl != null && gl.response_rates != null && gl.response_rates.length > 0) {
        _lift = gl.response_rates[0] / gl.avg_response_rate; // lift in the top group
      }
    } else if (m instanceof ModelMetricsMultinomial) {
      _logloss = ((ModelMetricsMultinomial)m)._logloss;
      _classError = ((ModelMetricsMultinomial)m)._cm.err();
      _hitratio = ((ModelMetricsMultinomial)m)._hit_ratios;
    }
  }
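
  /*
   * Worked example for the lift computation above (illustrative numbers, not from the
   * original source): lift in the top group is the response rate of the highest-scoring
   * group divided by the overall response rate. If the top group's response rate is 0.30
   * and the average response rate is 0.10, then _lift = 0.30 / 0.10 = 3.0, i.e. the top
   * group responds 3x more often than a random sample would.
   */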
  public enum StoppingMetric { AUTO, deviance, logloss, MSE, AUC, lift_top_group, r2, misclassification }

  public static boolean moreIsBetter(StoppingMetric criterion) {
    return (criterion == StoppingMetric.AUC || criterion == StoppingMetric.r2 || criterion == StoppingMetric.lift_top_group);
  }
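
  /*
   * Illustrative check (not part of the original source): AUC, r2 and lift_top_group are
   * the only "higher is better" metrics; deviance, logloss, MSE and misclassification
   * are "lower is better".
   *
   *   assert ScoreKeeper.moreIsBetter(StoppingMetric.AUC);
   *   assert !ScoreKeeper.moreIsBetter(StoppingMetric.logloss);
   */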
  /**
   * Based on the given array of ScoreKeepers and the stopping criteria, should we stop early?
   * @param sk scoring events in chronological order (oldest first)
   * @param k moving-average window size; 0 disables early stopping
   * @param classification whether this is a classification problem (used to resolve AUTO)
   * @param criterion metric used for the convergence check
   * @param rel_improvement relative improvement of the metric required to count as "still improving"
   * @param what label for log messages
   * @param verbose whether to log the windowed averages
   * @return true if the metric has converged and training should stop
   */
  public static boolean stopEarly(ScoreKeeper[] sk, int k, boolean classification, StoppingMetric criterion, double rel_improvement, String what, boolean verbose) {
    if (k == 0) return false;
    int len = sk.length - 1; // how many "full"/"conservative" scoring events we have (skip the first)
    if (len < 2*k) return false; // need at least k for SMA and another k to tell whether the model got better or not
    if (criterion == StoppingMetric.AUTO) {
      criterion = classification ? StoppingMetric.logloss : StoppingMetric.deviance;
    }
    boolean moreIsBetter = moreIsBetter(criterion);
    double[] movingAvg = new double[k+1]; // need one moving average value for the last k+1 scoring events
    double lastBeforeK = moreIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
    double bestInLastK = moreIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
    for (int i=0; i<movingAvg.length; ++i) {
      movingAvg[i] = 0;
      // compute k+1 simple moving averages of window size k, reaching back over the last 2*k scoring events:
      // movingAvg[0] is the "current" window (most recent k events), movingAvg[1..k] are the older reference windows
      for (int j=0; j<k; ++j) {
        int idx = len - i - j;
        switch (criterion) {
          case AUC:               movingAvg[i] += sk[idx]._AUC; break;
          case MSE:               movingAvg[i] += sk[idx]._mse; break;
          case deviance:          movingAvg[i] += sk[idx]._mean_residual_deviance; break;
          case logloss:           movingAvg[i] += sk[idx]._logloss; break;
          case misclassification: movingAvg[i] += sk[idx]._classError; break;
          case lift_top_group:    movingAvg[i] += sk[idx]._lift; break;
          case r2:                movingAvg[i] += sk[idx]._r2; break;
          default:                throw H2O.unimpl("Undefined stopping criterion.");
        }
      }
      movingAvg[i] /= k;
      if (Double.isNaN(movingAvg[i])) return false;
      if (i == 0)
        bestInLastK = movingAvg[i];
      else
        lastBeforeK = moreIsBetter ? Math.max(lastBeforeK, movingAvg[i]) : Math.min(lastBeforeK, movingAvg[i]);
    }
    // zero-crossing could be for residual deviance or r^2 -> mark it not yet converged, avoid division by 0 or weird relative improvements math below
    if (Math.signum(ArrayUtils.maxValue(movingAvg)) != Math.signum(ArrayUtils.minValue(movingAvg))) return false;
    if (Math.signum(bestInLastK) != Math.signum(lastBeforeK)) return false;
    assert(lastBeforeK != Double.MAX_VALUE);
    assert(bestInLastK != Double.MAX_VALUE);
    if (verbose)
      Log.info("Windowed averages (window size " + k + ") of " + what + " " + (k+1) + " " + criterion.toString() + " metrics: " + Arrays.toString(movingAvg));
    double ratio = bestInLastK / lastBeforeK;
    if (Double.isNaN(ratio)) return false;
    boolean improved = moreIsBetter ? ratio > 1+rel_improvement : ratio < 1-rel_improvement;
    if (verbose)
      Log.info("Checking convergence with " + criterion.toString() + " metric: " + lastBeforeK + " --> " + bestInLastK + (improved ? " (still improving)." : " (converged)."));
    return !improved;
  } // stopEarly
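
  /*
   * Usage sketch for stopEarly() (illustrative, with made-up logloss values): with window
   * size k=2 it averages the last k+1 overlapping windows of k scoring events each, and
   * stops once the newest window is no longer at least rel_improvement better than the
   * best of the k older windows.
   *
   *   ScoreKeeper[] history = new ScoreKeeper[7];
   *   double[] logloss = {0.70, 0.60, 0.55, 0.52, 0.515, 0.514, 0.5135};
   *   for (int i = 0; i < history.length; i++) {
   *     history[i] = new ScoreKeeper();
   *     history[i]._logloss = logloss[i];
   *   }
   *   // recent improvement is well under the required 1% per window -> returns true (converged)
   *   boolean stop = ScoreKeeper.stopEarly(history, 2, true,
   *       StoppingMetric.logloss, 1e-2, "validation", true);
   */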
  /**
   * Compare this ScoreKeeper with another ScoreKeeper.
   * @param that object to compare against
   * @return true if they are equal (up to 1e-6 absolute and relative error, or both contain NaN for the same values)
   */
  @Override public boolean equals(Object that) {
    if (! (that instanceof ScoreKeeper)) return false;
    ScoreKeeper o = (ScoreKeeper)that;
    if (_hitratio == null && o._hitratio != null) return false;
    if (_hitratio != null && o._hitratio == null) return false;
    if (_hitratio != null && o._hitratio != null) {
      if (_hitratio.length != o._hitratio.length) return false;
      for (int i=0; i<_hitratio.length; ++i) {
        if (!MathUtils.compare(_hitratio[i], o._hitratio[i], 1e-6, 1e-6)) return false;
      }
    }
    return MathUtils.compare(_r2, o._r2, 1e-6, 1e-6)
        && MathUtils.compare(_mean_residual_deviance, o._mean_residual_deviance, 1e-6, 1e-6)
        && MathUtils.compare(_mse, o._mse, 1e-6, 1e-6)
        && MathUtils.compare(_logloss, o._logloss, 1e-6, 1e-6)
        && MathUtils.compare(_classError, o._classError, 1e-6, 1e-6)
        && MathUtils.compare(_lift, o._lift, 1e-6, 1e-6);
  }
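
  /*
   * Note on the tolerance (illustrative): MathUtils.compare(a, b, 1e-6, 1e-6) tolerates up
   * to 1e-6 absolute or relative difference, so two ScoreKeepers with _mse of 0.2500001 and
   * 0.2500002 still compare as equal, and two NaN values (unfilled metrics) match.
   */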
  public static Comparator<ScoreKeeper> comparator(StoppingMetric criterion) {
    switch (criterion) {
      case AUC:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o2._AUC - o1._AUC); // moreIsBetter
          }
        };
      case MSE:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o1._mse - o2._mse); // lessIsBetter
          }
        };
      case deviance:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o1._mean_residual_deviance - o2._mean_residual_deviance); // lessIsBetter
          }
        };
      case logloss:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o1._logloss - o2._logloss); // lessIsBetter
          }
        };
      case r2:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o2._r2 - o1._r2); // moreIsBetter
          }
        };
      case misclassification:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o1._classError - o2._classError); // lessIsBetter
          }
        };
      case lift_top_group:
        return new Comparator<ScoreKeeper>() {
          @Override
          public int compare(ScoreKeeper o1, ScoreKeeper o2) {
            return (int)Math.signum(o2._lift - o1._lift); // moreIsBetter
          }
        };
      default:
        throw H2O.unimpl("Undefined stopping criterion.");
    } // switch
  } // comparator
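
  /*
   * Usage sketch (illustrative): each comparator orders ScoreKeepers best-first for its
   * metric, so sorting a set of per-model scores puts the best model at index 0:
   *
   *   ScoreKeeper[] scores = ...; // hypothetical: one ScoreKeeper per model
   *   Arrays.sort(scores, ScoreKeeper.comparator(StoppingMetric.AUC));
   *   ScoreKeeper best = scores[0];
   */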
  @Override
  public String toString() {
    return "ScoreKeeper{" +
        "_r2=" + _r2 +
        ", _mean_residual_deviance=" + _mean_residual_deviance +
        ", _mse=" + _mse +
        ", _logloss=" + _logloss +
        ", _AUC=" + _AUC +
        ", _classError=" + _classError +
        ", _hitratio=" + Arrays.toString(_hitratio) +
        ", _lift=" + _lift +
        '}';
  }
}