All downloads are free. Search and download functionality uses the official Maven repository.

hex.glrm.ModelMetricsGLRM Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.glrm;

import hex.CustomMetric;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsUnsupervised;
import water.fvec.Frame;

/**
 * Model metrics for Generalized Low Rank Models (GLRM).
 *
 * <p>Reconstruction error is tracked separately for numeric and categorical columns because the
 * two use different error measures: squared error for numeric entries and a mismatch count for
 * categorical entries.
 */
public class ModelMetricsGLRM extends ModelMetricsUnsupervised {
  /**
   * Numeric reconstruction error. When produced by {@link GlrmModelMetricsBuilder} this is the
   * <em>sum</em> of squared errors over observed numeric entries, not a mean; divide by
   * {@link #_numcnt} to obtain the mean.
   */
  public double _numerr;
  /**
   * Categorical reconstruction error. When produced by {@link GlrmModelMetricsBuilder} this is
   * the <em>count</em> of mismatched observed categorical entries, not a rate; divide by
   * {@link #_catcnt} to obtain the rate.
   */
  public double _caterr;
  /** Number of observed (non-NaN) numeric entries; 0 unless supplied to the constructor. */
  public long   _numcnt;
  /** Number of observed (non-NaN) categorical entries; 0 unless supplied to the constructor. */
  public long   _catcnt;

  /**
   * Creates GLRM metrics without entry counts ({@link #_numcnt} and {@link #_catcnt} stay 0).
   *
   * <p>Passes {@code 0} and {@code Double.NaN} to the superclass constructor (presumably the
   * observation count and MSE; confirm against {@code ModelMetricsUnsupervised}).
   *
   * @param model        the model the metrics describe
   * @param frame        the frame the metrics were computed on
   * @param numerr       numeric reconstruction error
   * @param caterr       categorical reconstruction error
   * @param customMetric optional custom metric, forwarded to the superclass
   */
  public ModelMetricsGLRM(Model model, Frame frame, double numerr, double caterr, CustomMetric customMetric) {
    super(model, frame, 0, Double.NaN, customMetric);
    _numerr = numerr;
    _caterr = caterr;
  }

  /**
   * Creates GLRM metrics including the number of observed numeric and categorical entries.
   *
   * @param model        the model the metrics describe
   * @param frame        the frame the metrics were computed on
   * @param numerr       numeric reconstruction error
   * @param caterr       categorical reconstruction error
   * @param numcnt       number of observed numeric entries
   * @param catcnt       number of observed categorical entries
   * @param customMetric optional custom metric, forwarded to the superclass
   */
  public ModelMetricsGLRM(Model model, Frame frame, double numerr, double caterr, long numcnt, long catcnt, CustomMetric customMetric) {
    this(model, frame, numerr, caterr, customMetric);
    _numcnt = numcnt;
    _catcnt = catcnt;
  }

  /**
   * Accumulates GLRM reconstruction errors row by row.
   *
   * <p>Columns are visited through {@link #_permutation}. The first {@code _ncats} permuted
   * positions are treated as categorical, and the remaining positions are treated as numeric.
   */
  public static class GlrmModelMetricsBuilder
      extends MetricBuilderUnsupervised<GlrmModelMetricsBuilder> {
    public double _miscls;          // Number of mismatched categorical values
    public long _numcnt;            // Number of observed numeric entries
    public long _catcnt;            // Number of observed categorical entries
    public int[] _permutation;      // Permutation array for shuffling cols
    // When true, numeric data is compared to the predictions without applying _normSub/_normMul.
    public boolean _impute_original;

    /** Equivalent to {@code GlrmModelMetricsBuilder(dims, permutation, false)}. */
    public GlrmModelMetricsBuilder(int dims, int[] permutation) { this(dims, permutation, false); }

    /**
     * @param dims            length of the work array allocated for the superclass
     * @param permutation     column order: categorical positions first, then numeric
     * @param impute_original if true, compare numeric data to predictions without normalising it
     */
    public GlrmModelMetricsBuilder(int dims, int[] permutation, boolean impute_original) {
      _work = new double[dims];
      // _miscls, _numcnt and _catcnt start at zero via Java default field initialisation.
      _permutation = permutation;
      _impute_original = impute_original;
    }

    /**
     * Adds one row's reconstruction error to the running totals.
     *
     * @param preds   reconstructed values, indexed like {@code dataRow}
     * @param dataRow original values for the row; NaN entries are treated as missing and skipped
     * @param m       the model; must be a {@link GLRMModel}
     * @return {@code preds}, unchanged
     */
    @Override
    public double[] perRow(double[] preds, float[] dataRow, Model m) {
      assert m instanceof GLRMModel;
      GLRMModel gm = (GLRMModel) m;
      assert gm._output._ncats + gm._output._nnums == dataRow.length;
      int ncats = gm._output._ncats;
      double[] sub = gm._output._normSub;
      double[] mul = gm._output._normMul;

      // Categorical columns come first in the permutation; the error is a mismatch count.
      for (int i = 0; i < ncats; i++) {
        int idx = _permutation[i];
        if (Double.isNaN(dataRow[idx])) {
          continue;
        }
        if (dataRow[idx] != preds[idx]) {
          _miscls++;
        }
        _catcnt++;
      }

      // Numeric columns follow; the error is the squared difference. c indexes the
      // normalisation arrays, so it must advance for every numeric column, including missing ones.
      int c = 0;
      for (int i = ncats; i < dataRow.length; i++, c++) {
        int idx = _permutation[i];
        if (Double.isNaN(dataRow[idx])) {
          continue;
        }
        double actual = _impute_original ? dataRow[idx] : (dataRow[idx] - sub[c]) * mul[c];
        double diff = actual - preds[idx];
        _sumsqe += diff * diff;
        _numcnt++;
      }
      assert c == gm._output._nnums;
      return preds;
    }

    /** Merges another builder's partial totals into this one. */
    @Override
    public void reduce(GlrmModelMetricsBuilder mm) {
      super.reduce(mm);
      _miscls += mm._miscls;
      _numcnt += mm._numcnt;
      _catcnt += mm._catcnt;
    }

    /**
     * Builds a {@link ModelMetricsGLRM} from the accumulated totals and registers it on the model.
     *
     * <p>The raw sum of squared errors and mismatch count are stored together with their entry
     * counts. Callers derive the mean error or rate by dividing by the count; when a count is
     * zero, the corresponding mean is undefined.
     */
    @Override
    public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
      return m.addModelMetrics(new ModelMetricsGLRM(m, f, _sumsqe, _miscls, _numcnt, _catcnt, _customMetric));
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy