All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.genmodel.algos.glm.GlmOrdinalMojoModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.glm;

import hex.genmodel.utils.ArrayUtils;

import java.util.Arrays;

/**
 * MOJO scorer for an ordinal GLM.
 *
 * <p>The coefficient vector {@code _beta} is read as {@code _nclasses} consecutive blocks of
 * {@code P} values. Inside a block the categorical coefficients come first, the numeric
 * coefficients start at {@code noff}, and the last entry is the intercept. Only the first
 * {@code _nclasses - 1} blocks are used for scoring: each yields a linear predictor whose
 * logistic transform is treated as a cumulative probability, and the per-class probabilities
 * are the successive differences of those cumulative values.
 */
public class GlmOrdinalMojoModel extends GlmMojoModelBase {

  /** Number of coefficients per class block in {@code _beta} (intercept included, stored last). */
  private int P;
  /** Position of the first numeric coefficient inside a class block. */
  private int noff;
  /** Index of the last class, {@code _nclasses - 1}; no linear predictor is computed for it. */
  private int lastClass;
  /** For each class except the last, the index of its intercept in {@code _beta}. */
  private int[] icptIndices;

  GlmOrdinalMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }

  /**
   * Derives the block layout of {@code _beta} and caches the intercept positions.
   *
   * @throws IllegalStateException if {@code _beta.length} is not a multiple of {@code _nclasses}
   */
  @Override
  void init() {
    P = _beta.length / _nclasses;
    // Validate the layout first so nothing is derived from a malformed coefficient vector.
    if (P * _nclasses != _beta.length)
      throw new IllegalStateException("Incorrect coding of Beta.");
    lastClass = _nclasses - 1;
    icptIndices = new int[lastClass];
    for (int c = 0; c < lastClass; c++) {
      icptIndices[c] = P - 1 + c * P; // intercept is the last entry of block c
    }
    noff = _catOffsets[_cats];
  }

  /**
   * Scores one row, optionally imputing missing values first.
   *
   * @param data   input row; modified in place by the imputation when {@code _meanImputation} is set
   * @param offset value added to every linear predictor before the logistic transform
   * @param preds  output array; slot 0 receives the predicted class, slots 1.._nclasses the
   *               class probabilities
   * @return {@code preds}
   */
  public final double[] score0(double[] data, double offset, double[] preds) {
    if (_meanImputation)
      super.imputeMissingWithMeans(data);

    return glmScore0(data, offset, preds);
  }

  /**
   * Computes the class probabilities and the predicted class for one row.
   *
   * @param data   input row; the categorical columns are read from the front of the array and
   *               the numeric columns from index {@code _cats} onwards
   * @param offset value added to every linear predictor before the logistic transform
   * @param preds  output array, overwritten entirely
   * @return {@code preds}, with the predicted class in slot 0 and probabilities in slots 1.._nclasses
   * @throws IllegalArgumentException if a categorical value is NaN, fractional or negative
   */
  double[] glmScore0(double[] data, double offset, double[] preds) {
    Arrays.fill(preds, 0);

    for (int c = 0; c < lastClass; ++c) { // preds contains the etas for each class
      final int classStart = c * P;
      if (_cats > 0) {
        for (int i = 0; i < _catOffsets.length - 1; ++i) {
          final int coef = categoricalCoefIndex(data[i], i);
          if (coef >= 0)
            preds[c + 1] += _beta[coef + classStart];
        }
      }
      for (int i = 0; i < _nums; ++i) {
        preds[c + 1] += _beta[i + noff + classStart] * data[i + _cats];
      }
      preds[c + 1] += _beta[icptIndices[c]];
    }

    double previousCDF = 0.0;
    for (int cInd = 0; cInd < lastClass; cInd++) { // classify row and calculate PDF of each class
      double eta = preds[cInd + 1] + offset;
      double currCDF = 1.0 / (1 + Math.exp(-eta));
      preds[cInd + 1] = currCDF - previousCDF;
      previousCDF = currCDF;
    }
    preds[_nclasses] = 1 - previousCDF; // the last class takes the remaining probability mass
    // Slot 0 must hold 0 while searching so that only the probability slots compete for the max.
    preds[0] = 0;
    preds[0] = ArrayUtils.maxIndex(preds) - 1;
    return preds;
  }

  /**
   * Maps a categorical value to the position of its coefficient inside a class block.
   *
   * <p>When {@code _useAllFactorLevels} is false, level 0 is skipped (it carries no coefficient)
   * and the remaining levels are shifted down by one.
   *
   * @param value categorical value taken from the input row
   * @param col   index of the categorical column
   * @return the in-block coefficient position, or {@code -1} when the level contributes nothing
   *         (skipped level 0, or a level beyond the ones the column has coefficients for)
   * @throws IllegalArgumentException if {@code value} is NaN, fractional or negative
   */
  private int categoricalCoefIndex(double value, int col) {
    final int skipped = _useAllFactorLevels ? 0 : 1;
    if (skipped == 1 && value == 0)
      return -1;
    final int level = (int) value - skipped;
    // A negative level would otherwise index into the coefficients of a preceding column.
    if (level != value - skipped || level < 0)
      throw new IllegalArgumentException("categorical value out of range");
    // Compare against the column width before adding the offset so a huge level cannot overflow.
    final int levelCount = _catOffsets[col + 1] - _catOffsets[col];
    return level < levelCount ? _catOffsets[col] + level : -1;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy