All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.ModelMetrics Maven / Gradle / Ivy

There is a newer version: 3.8.2.9
Show newest version
package hex;

import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.TwoDimTable;

import java.lang.reflect.Method;
import java.util.*;

/** Container to hold the metric for a model as scored on a specific frame.
 *
 *  The MetricBuilder class is used in a hot inner-loop of a Big Data pass, and
 *  when given a class-distribution, can be used to compute CM's, and AUC's "on
 *  the fly" during ModelBuilding - or after-the-fact with a Model and a new
 *  Frame to be scored.
 */
public class ModelMetrics extends Keyed {
  public String _description;
  final Key _modelKey;
  final Key _frameKey;
  final ModelCategory _model_category;
  final long _model_checksum;
  long _frame_checksum;
  public final long _scoring_time;

  // Cached fields - cached them when needed
  private transient Model _model;
  private transient Frame _frame;

  public final double _MSE;     // Mean Squared Error (Every model is assumed to have this, otherwise leave at NaN)

  public ModelMetrics(Model model, Frame frame, double MSE, String desc) {
    super(buildKey(model, frame));
    _description = desc;
    // Do not cache fields now
    _modelKey = model._key;
    _frameKey = frame._key;
    _model_category = model._output.getModelCategory();
    _model_checksum = model.checksum();
    try { _frame_checksum = frame.checksum(); } catch (Throwable t) { }
    _MSE = MSE;
    _scoring_time = System.currentTimeMillis();
  }

  public long residual_degrees_of_freedom(){throw new UnsupportedOperationException("residual degrees of freedom is not supported for this metric class");}
  @Override public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append("Model Metrics Type: " + this.getClass().getSimpleName().substring(12) + "\n");
    sb.append(" Description: " + (_description == null ? "N/A" : _description) + "\n");
    sb.append(" model id: " + _modelKey + "\n");
    sb.append(" frame id: " + _frameKey + "\n");
    sb.append(" MSE: " + (float)_MSE + "\n");
    return sb.toString();
  }

  public Model model() { return _model==null ? (_model=DKV.getGet(_modelKey)) : _model; }
  public Frame frame() { return _frame==null ? (_frame=DKV.getGet(_frameKey)) : _frame; }

  public double mse() { return _MSE; }
  public ConfusionMatrix cm() { return null; }
  public float[] hr() { return null; }
  public AUC2 auc_obj() { return null; }

  static public double getMetricFromModel(Key key, String criterion) {
    Model model = DKV.getGet(key);
    if (null == model) throw new H2OIllegalArgumentException("Cannot find model " + key);
    if (null == criterion || criterion.equals("")) throw new H2OIllegalArgumentException("Need a valid criterion, but got '" + criterion + "'.");
    ModelMetrics m =
            model._output._cross_validation_metrics != null ?
                    model._output._cross_validation_metrics :
                    model._output._validation_metrics != null ?
                            model._output._validation_metrics :
                            model._output._training_metrics;
    Method method = null;
    ConfusionMatrix cm = m.cm();
    try {
      method = m.getClass().getMethod(criterion);
    }
    catch (Exception e) {
      // fall through
    }

    if (null == method && null != cm) {
      try {
        method = cm.getClass().getMethod(criterion);
      }
      catch (Exception e) {
        // fall through
      }
    }
    if (null == method)
      throw new H2OIllegalArgumentException("Failed to find ModelMetrics for criterion: " + criterion);

    double c;
    try {
      c = (double) method.invoke(m);
    } catch(Exception fallthru) {
      try {
        c = (double)method.invoke(cm);
      } catch (Exception e) {
        throw new H2OIllegalArgumentException(
                "Failed to get metric: " + criterion + " from ModelMetrics object: " + m,
                "Failed to get metric: " + criterion + " from ModelMetrics object: " + m + ", criterion: " + method + ", exception: " + e
        );
      }
    }
    return c;
  }


  private static class MetricsComparator implements Comparator> {
    String _sort_by = null;
    boolean decreasing = false;

    public MetricsComparator(String sort_by, boolean decreasing) {
      this._sort_by = sort_by;
      this.decreasing = decreasing;
    }

    public int compare(Key key1, Key key2) {
      double c1 = getMetricFromModel(key1, _sort_by);
      double c2 = getMetricFromModel(key2, _sort_by);
      return decreasing ? Double.compare(c2, c1) : Double.compare(c1, c2);
    }
  }

  //
  public static Set getAllowedMetrics(Key key) {
    Set res = new HashSet<>();
    Model model = DKV.getGet(key);
    if (null == model) throw new H2OIllegalArgumentException("Cannot find model " + key);
    ModelMetrics m =
            model._output._cross_validation_metrics != null ?
                    model._output._cross_validation_metrics :
                    model._output._validation_metrics != null ?
                            model._output._validation_metrics :
                            model._output._training_metrics;
    ConfusionMatrix cm = m.cm();
    Set excluded = new HashSet<>();
    excluded.add("makeSchema");
    excluded.add("hr");
    excluded.add("cm");
    excluded.add("auc_obj");
    excluded.add("remove");
    if (m!=null) {
      for (Method meth : m.getClass().getMethods()) {
        if (excluded.contains(meth.getName())) continue;
        try {
          double c = (double) meth.invoke(m);
          res.add(meth.getName());
        } catch (Exception e) {
          // fall through
        }
      }
    }
    if (cm!=null) {
      for (Method meth : cm.getClass().getMethods()) {
        if (excluded.contains(meth.getName())) continue;
        try {
          double c = (double) meth.invoke(cm);
          res.add(meth.getName());
        } catch (Exception e) {
          // fall through
        }
      }
    }
    return res;
  }

  /**
   * Return a new list of models sorted by the named criterion, such as "auc", mse", "hr", "err", "errCount",
   * "accuracy", "specificity", "recall", "precision", "mcc", "max_per_class_error", "F1", "F2", "F0point5". . .
   * @param sort_by criterion by which we should sort
   * @param decreasing sort by decreasing metrics or not
   * @param modelKeys keys of models to sortm
   * @return keys of the models, sorted by the criterion
   */
  public static List> sortModelsByMetric(String sort_by, boolean decreasing, List>modelKeys) {
    List> sorted = new ArrayList<>();
    sorted.addAll(modelKeys);

    Comparator> c = new MetricsComparator(sort_by, decreasing);

    Collections.sort(sorted, c);
    return sorted;
  }

  public static TwoDimTable calcVarImp(VarImp vi) {
    if (vi == null) return null;
    double[] dbl_rel_imp = new double[vi._varimp.length];
    for (int i=0; i() {
      public int compare(Integer idx1, Integer idx2) {
        return Double.compare(-rel_imp[idx1], -rel_imp[idx2]);
      }
    });

    double total = 0;
    double max = rel_imp[sorted_idx[0]];
    String[] sorted_names = new String[rel_imp.length];
    double[][] sorted_imp = new double[rel_imp.length][3];

    // First pass to sum up relative importance measures
    int j = 0;
    for(int i : sorted_idx) {
      total += rel_imp[i];
      sorted_names[j] = coef_names[i];
      sorted_imp[j][0] = rel_imp[i];         // Relative importance
      sorted_imp[j++][1] = rel_imp[i] / max;   // Scaled importance
    }
    // Second pass to calculate percentages
    j = 0;
    for(int i : sorted_idx)
      sorted_imp[j++][2] = rel_imp[i] / total; // Percentage

    String [] col_types = new String[3];
    String [] col_formats = new String[3];
    Arrays.fill(col_types, "double");
    Arrays.fill(col_formats, "%5f");
    return new TwoDimTable(table_header, null, sorted_names, col_headers, col_types, col_formats, "Variable",
            new String[rel_imp.length][], sorted_imp);
  }

  private static Key buildKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
    return Key.make("modelmetrics_" + model_key + "@" + model_checksum + "_on_" + frame_key + "@" + frame_checksum);
  }

  public static Key buildKey(Model model, Frame frame) {
    return frame==null ? null : buildKey(model._key, model.checksum(), frame._key, frame.checksum());
  }

  public boolean isForModel(Model m) { return _model_checksum == m.checksum(); }
  public boolean isForFrame(Frame f) { return _frame_checksum == f.checksum(); }

  public static ModelMetrics getFromDKV(Model model, Frame frame) {
    Value v = DKV.get(buildKey(model, frame));
    return null == v ? null : (ModelMetrics)v.get();
  }

  @Override protected long checksum_impl() { return _frame_checksum * 13 + _model_checksum * 17; }

  /** Class used to compute AUCs, CMs & HRs "on the fly" during other passes
   *  over Big Data.  This class is intended to be embedded in other MRTask
   *  objects.  The {@code perRow} method is called once-per-scored-row, and
   *  the {@code reduce} method called once per MRTask.reduce, and the {@code
   *  } called once per MRTask.map.
   */
  public static abstract class MetricBuilder> extends Iced {
    transient public double[] _work;
    public double _sumsqe;      // Sum-squared-error
    public long _count;
    public double _wcount;
    public double _wY; // (Weighted) sum of the response
    public double _wYY; // (Weighted) sum of the squared response

    public  double weightedSigma() {
//      double sampleCorrection = _count/(_count-1); //sample variance -> depends on the number of ACTUAL ROWS (not the weighted count)
      double sampleCorrection = 1; //this will make the result (and R^2) invariant to globally scaling the weights
      return _count <= 1 ? 0 : Math.sqrt(sampleCorrection*(_wYY/_wcount - (_wY*_wY)/(_wcount*_wcount)));
    }
    abstract public double[] perRow(double ds[], float yact[], Model m);
    public double[] perRow(double ds[], float yact[],double weight, double offset,  Model m) {
      assert(weight==1 && offset == 0);
      return perRow(ds, yact, m);
    }
    public void reduce( T mb ) {
      _sumsqe += mb._sumsqe;
      _count += mb._count;
      _wcount += mb._wcount;
      _wY += mb._wY;
      _wYY += mb._wYY;
    }

    public void postGlobal() {}

    /**
     * Having computed a MetricBuilder, this method fills in a ModelMetrics
     * @param m Model
     * @param f Scored Frame
     * @param adaptedFrame Adapted Frame
     *@param preds Predictions of m on f (optional)  @return Filled Model Metrics object
     */
    public abstract ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy