All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.glm.GlmMojoModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.glm;

import hex.genmodel.GenModel;

import java.io.Serializable;

/**
 * MOJO scorer for a GLM model.
 *
 * <p>Scoring accumulates a linear predictor {@code eta} from the coefficient
 * vector {@code _beta} (categorical levels first, then numeric columns, with
 * the intercept stored in the last slot) and passes it through the function
 * selected by {@link #_link}.
 *
 * <p>The fields {@code _beta}, {@code _cats}, {@code _catOffsets},
 * {@code _useAllFactorLevels}, {@code _meanImputation}, {@code _family} and
 * {@code _defaultThreshold} are not declared here; presumably they are
 * inherited from {@code GlmMojoModelBase}.
 */
public class GlmMojoModel extends GlmMojoModelBase {

  /** Name of the link; one of the values recognised by {@link #createLinkFunction()}. */
  String _link;
  /** Power parameter handed to the tweedie link; only read when {@code _link} is "tweedie". */
  double _tweedieLinkPower;

  // set by init()
  private Function1 _linkFn;
  private boolean _binomial;

  GlmMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }

  /**
   * Resolves the link function from {@code _link} and caches whether the
   * family produces a two-class prediction.
   *
   * @throws UnsupportedOperationException if {@code _link} is not a supported link name
   */
  @Override
  void init() {
    _linkFn = createLinkFunction();
    _binomial = "binomial".equals(_family) || "fractionalbinomial".equals(_family) || "quasibinomial".equals(_family);
  }

  /**
   * Scores a single row, imputing missing values with means first when
   * {@code _meanImputation} is enabled.
   *
   * @param data input row; may be modified in place by the imputation step
   * @param offset value added to the linear predictor
   * @param preds output array, see {@link #glmScore0(double[], double, double[])}
   * @return {@code preds}
   */
  public final double[] score0(double[] data, double offset, double[] preds) {
    if (_meanImputation)
      super.imputeMissingWithMeans(data);

    return glmScore0(data, offset, preds);
  }

  /**
   * Computes the linear predictor for a row and writes the prediction.
   *
   * <p>For the binomial-like families, {@code preds[0]} is the predicted label
   * (1 if the evaluated link output reaches {@code _defaultThreshold}, else 0),
   * {@code preds[1]} is {@code 1 - mu} and {@code preds[2]} is {@code mu}.
   * Otherwise only {@code preds[0]} is written, with {@code mu}.
   *
   * @param data input row: categorical columns first, numeric columns after
   * @param offset value added to the linear predictor
   * @param preds output array, filled as described above
   * @return {@code preds}
   * @throws IllegalArgumentException if a categorical value is not an integer
   *         representable as {@code int} (this includes NaN)
   */
  double[] glmScore0(double[] data, double offset, double[] preds) {
    double eta = 0.0;

    if (!_useAllFactorLevels) { // skip level 0 of all factors
      for(int i = 0; i < _catOffsets.length-1; ++i) {
        if(data[i] != 0) {
          int ival = (int) data[i] - 1;
          if (ival != data[i] - 1) {
            throw new IllegalArgumentException("categorical value out of range");
          }
          ival += _catOffsets[i];
          // levels beyond this factor's coefficient range contribute nothing
          if (ival < _catOffsets[i + 1]) {
            eta += _beta[ival];
          }
        }
      }
    } else { // do not skip any levels
      for(int i = 0; i < _catOffsets.length-1; ++i) {
        int ival = (int) data[i];
        if (ival != data[i]) {
          throw new IllegalArgumentException("categorical value out of range");
        }
        ival += _catOffsets[i];
        // levels beyond this factor's coefficient range contribute nothing
        if (ival < _catOffsets[i + 1]) {
          eta += _beta[ival];
        }
      }
    }

    // noff shifts a numeric column index in data to its coefficient index in _beta
    int noff = _catOffsets[_cats] - _cats;
    for(int i = _cats; i < _beta.length - 1 - noff; ++i)
      eta += _beta[noff + i] * data[i];
    eta += _beta[_beta.length - 1]; // add the intercept (last coefficient)
    eta += offset;

    double mu = _linkFn.eval(eta);

    if (_binomial) {
      preds[0] = (mu >= _defaultThreshold) ? 1 : 0; // threshold given by ROC
      preds[1] = 1.0 - mu; // class 0
      preds[2] =       mu; // class 1
    } else {
      preds[0] = mu;
    }

    return preds;
  }

  /**
   * Applies GLM coefficients to a given row of data to calculate
   * feature contributions.
   *
   * <p>Categorical column {@code i} writes its level coefficient to
   * {@code output[destPos + i]}; numeric contributions
   * ({@code coefficient * value}) follow, starting at
   * {@code output[destPos + _catOffsets.length - 1]}. A categorical slot is
   * left untouched when the level carries no coefficient (the skipped level 0
   * when {@code _useAllFactorLevels} is false, or a level outside the factor's
   * coefficient range), so callers should pass a zero-initialised region.
   *
   * Note: for internal purposes only (k-LIME)
   *
   * @param data input row of data (same input as to glmScore0)
   * @param output target output array
   * @param destPos index to the output array where the result should start
   * @return feature contributions, prediction = linkFunction(sum(output) + intercept)
   */
  public double[] applyCoefficients(double[] data, double[] output, int destPos) {
    // When level 0 of each factor is skipped, level k maps to coefficient k - 1,
    // exactly as in glmScore0. The shift therefore has to be ADDED: subtracting
    // -1 would move the index up by one and make the ival < 0 guard unreachable.
    final int offset = _useAllFactorLevels ? 0 : -1;
    for (int i = 0; i < _catOffsets.length - 1; i++) {
      int ival = (int) data[i] + offset;
      if (ival < 0) continue; // skipped level 0 (or a negative level): no coefficient
      ival += _catOffsets[i];
      if (ival < _catOffsets[i + 1])
        output[i + destPos] = _beta[ival];
    }
    int p = destPos + _catOffsets.length - 1;
    // noff shifts a numeric column index in data to its coefficient index in _beta
    int noff = _catOffsets[_cats] - _cats;
    for (int i = _cats; i < _beta.length - 1 - noff; i++)
      output[p++] = _beta[noff + i] * data[i];
    return output;
  }

  /**
   * @return the intercept, stored as the last element of {@code _beta}
   */
  public double getIntercept() {
    return _beta[_beta.length - 1];
  }

  /** Serializable function of one double argument, used to hold the selected link. */
  private interface Function1 extends Serializable {
    double eval(double x);
  }

  /**
   * Maps the link name in {@code _link} to its implementation.
   *
   * @return the function for "identity", "logit", "log", "inverse" or "tweedie"
   * @throws UnsupportedOperationException for any other link name (including null)
   */
  private Function1 createLinkFunction() {
    if ("identity".equals(_link))
      return new GLM_identityInv();
    else if ("logit".equals(_link))
      return new GLM_logitInv();
    else if ("log".equals(_link))
      return new GLM_logInv();
    else if ("inverse".equals(_link))
      return new GLM_inverseInv();
    else if ("tweedie".equals(_link))
      return new GLM_tweedieInv(_tweedieLinkPower);
    else
      throw new UnsupportedOperationException("Unexpected link function " + _link);
  }

  // Each class below delegates to the GenModel static helper of the same name.

  private static class GLM_identityInv implements Function1 {
    @Override public double eval(double x) { return GenModel.GLM_identityInv(x); }
  }
  private static class GLM_logitInv implements Function1 {
    @Override public double eval(double x) { return GenModel.GLM_logitInv(x); }
  }
  private static class GLM_logInv implements Function1 {
    @Override public double eval(double x) { return GenModel.GLM_logInv(x); }
  }
  private static class GLM_inverseInv implements Function1 {
    @Override
    public double eval(double x) {
      return GenModel.GLM_inverseInv(x);
    }
  }

  // Not instantiated by createLinkFunction() in this class.
  private static class GLM_ologitInv implements Function1 {
    @Override public double eval(double x) { return GenModel.GLM_ologitInv(x); }
  }
  private static class GLM_tweedieInv implements Function1 {
    private final double _tweedie_link_power;
    GLM_tweedieInv(double tweedie_link_power) { this._tweedie_link_power = tweedie_link_power; }
    @Override public double eval(double x) { return GenModel.GLM_tweedieInv(x, _tweedie_link_power); }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy