All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.glm.GlmMultinomialMojoModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.algos.glm;

/**
 * MOJO scoring model for a multinomial GLM.
 *
 * <p>{@code _beta} holds one block of {@code P} coefficients per response class, stored
 * back-to-back: the block for class {@code c} occupies {@code _beta[c*P .. c*P + P - 1]}.
 * Within a block, categorical-level coefficients come first (located via {@code _catOffsets}),
 * numeric coefficients start at {@code noff}, and the intercept is the last entry.
 */
public class GlmMultinomialMojoModel extends GlmMojoModelBase {

  /** Number of coefficients per class block, including the intercept. */
  private int P;
  /** Index within a class block where the numeric coefficients begin (end of categoricals). */
  private int noff;

  GlmMultinomialMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }

  /**
   * Derives the per-class block size and the numeric-coefficient offset.
   *
   * @throws IllegalStateException if {@code _beta} does not split evenly into
   *     {@code _nclasses} blocks
   */
  @Override
  void init() {
    P = _beta.length / _nclasses;
    if (P * _nclasses != _beta.length)
      throw new IllegalStateException("Incorrect coding of Beta.");
    noff = _catOffsets[_cats];
  }

  /**
   * Scores one row. When mean imputation is enabled, missing values in {@code data} are
   * imputed first (presumably in place, via the base class) before scoring.
   *
   * @see #glmScore0(double[], double, double[])
   */
  public final double[] score0(double[] data, double offset, double[] preds) {
    if (_meanImputation)
      super.imputeMissingWithMeans(data);

    return glmScore0(data, offset, preds);
  }

  /**
   * Computes class probabilities for one row with a numerically stable softmax.
   *
   * <p>Layout of {@code preds} on return: {@code preds[0]} is the index of the most probable
   * class, and {@code preds[1.._nclasses]} are the class probabilities. The {@code offset}
   * argument is not used by this multinomial implementation.
   *
   * @param data the row: {@code _cats} categorical level indices followed by {@code _nums}
   *     numeric values
   * @param offset ignored
   * @param preds output buffer of length at least {@code _nclasses + 1}
   * @return {@code preds}
   * @throws IllegalArgumentException if a categorical value is not a non-negative integer
   */
  double[] glmScore0(double[] data, double offset, double[] preds) {
    for (int c = 1; c <= _nclasses; ++c)
      preds[c] = 0;

    // Categorical contributions. Decoding a level does not depend on the class, so resolve it
    // once per column and then add that coefficient to every class's linear predictor.
    if (_cats > 0) {
      for (int i = 0; i < _catOffsets.length - 1; ++i) {
        int coef = catCoefIndex(data, i);
        if (coef < 0)
          continue; // reference level or a level unseen in training: contributes nothing
        for (int c = 0; c < _nclasses; ++c)
          preds[c + 1] += _beta[coef + c * P];
      }
    }

    // Numeric contributions, then the intercept (last entry of each class block).
    for (int c = 0; c < _nclasses; ++c) {
      final int base = c * P;
      for (int i = 0; i < _nums; ++i)
        preds[c + 1] += _beta[base + noff + i] * data[i + _cats];
      preds[c + 1] += _beta[base + P - 1];
    }

    // Softmax shifted by the true maximum logit. Seeding with -inf (not 0) keeps exp() from
    // underflowing to 0 for every class when all logits are very negative, which would
    // otherwise produce 0/0 = NaN probabilities.
    double maxLogit = Double.NEGATIVE_INFINITY;
    for (int c = 1; c <= _nclasses; ++c)
      if (preds[c] > maxLogit) maxLogit = preds[c];
    double sumExp = 0;
    for (int c = 1; c <= _nclasses; ++c)
      sumExp += (preds[c] = Math.exp(preds[c] - maxLogit));
    final double invSum = 1 / sumExp;

    preds[0] = 0;
    double maxP = 0;
    for (int c = 1; c <= _nclasses; ++c) {
      preds[c] *= invSum;
      if (preds[c] > maxP) {
        maxP = preds[c];
        preds[0] = c - 1;
      }
    }
    return preds;
  }

  /**
   * Maps the categorical value in column {@code i} of {@code data} to the index of its
   * coefficient within a single class block of {@code _beta}.
   *
   * @return the coefficient index, or {@code -1} if the value contributes nothing. That happens
   *     for the skipped level 0 when {@code _useAllFactorLevels} is false, and for any level at
   *     or beyond this column's range in {@code _catOffsets}.
   * @throws IllegalArgumentException if the value is not a non-negative integer
   */
  private int catCoefIndex(double[] data, int i) {
    double level = data[i];
    if (!_useAllFactorLevels) {
      if (level == 0)
        return -1; // level 0 is the reference level and has no coefficient
      level -= 1;
    }
    int ival = (int) level;
    // Rejects non-integers, NaN, and negative levels (which would index outside this column).
    if (ival != level || ival < 0)
      throw new IllegalArgumentException("categorical value out of range");
    // Compare against the column width before adding the offset so huge values cannot overflow.
    if (ival >= _catOffsets[i + 1] - _catOffsets[i])
      return -1;
    return _catOffsets[i] + ival;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy