// hex.genmodel.algos.glm.GlmMultinomialMojoModel (Maven / Gradle / Ivy)
package hex.genmodel.algos.glm;
/**
 * MOJO scorer for a multinomial GLM.
 *
 * <p>The coefficient vector {@code _beta} is read as {@code _nclasses} consecutive blocks of
 * {@code P} values. Inside one block, categorical coefficients are addressed through
 * {@code _catOffsets}, numeric coefficients start at index {@code noff}, and the last slot
 * ({@code P - 1}) holds the intercept.
 */
public class GlmMultinomialMojoModel extends GlmMojoModelBase {
  /** Number of coefficients per class, intercept included ({@code _beta.length / _nclasses}). */
  private int P;
  /** Index, within one class block, of the first numeric coefficient. */
  private int noff;

  GlmMultinomialMojoModel(String[] columns, String[][] domains, String responseColumn) {
    super(columns, domains, responseColumn);
  }

  /**
   * Derives the per-class block layout from the loaded coefficients.
   *
   * @throws IllegalStateException if {@code _beta.length} is not a multiple of {@code _nclasses}
   */
  @Override
  void init() {
    P = _beta.length / _nclasses;
    if (P * _nclasses != _beta.length) {
      throw new IllegalStateException("Incorrect coding of Beta.");
    }
    noff = _catOffsets[_cats];
  }

  /**
   * Scores one row, first imputing missing values in place when mean imputation is enabled.
   *
   * @param data row values, categorical columns first; may be modified by imputation
   * @param offset passed through to {@link #glmScore0}; not used by this scorer
   * @param preds output array: slot 0 receives the predicted class index, slots
   *     {@code 1..} receive the class probabilities
   * @return {@code preds}
   */
  public final double[] score0(double[] data, double offset, double[] preds) {
    if (_meanImputation) {
      super.imputeMissingWithMeans(data);
    }
    return glmScore0(data, offset, preds);
  }

  /**
   * Computes the linear predictor of every class and converts them to probabilities.
   *
   * @param data row values, categorical columns first
   * @param offset accepted for signature compatibility; never read here
   * @param preds output array, see {@link #score0}
   * @return {@code preds}
   * @throws IllegalArgumentException if a categorical value is not a non-negative integer
   *     level code (this includes NaN)
   */
  double[] glmScore0(double[] data, double offset, double[] preds) {
    preds[0] = 0;
    for (int c = 0; c < _nclasses; ++c) {
      preds[c + 1] = linearPredictor(data, c * P);
    }
    softmaxInPlace(preds);
    return preds;
  }

  /** Returns the linear predictor for the class whose coefficient block starts at {@code base}. */
  private double linearPredictor(double[] data, int base) {
    double eta = 0;
    if (_cats > 0) {
      // Without all factor levels, level 0 has no coefficient and every other level
      // is stored one slot lower.
      final int shift = _useAllFactorLevels ? 0 : 1;
      for (int i = 0; i < _catOffsets.length - 1; ++i) {
        if (shift == 1 && data[i] == 0) {
          continue; // skipped level 0 contributes nothing
        }
        final int level = (int) data[i] - shift;
        // Rejects NaN, fractional and saturated values (cast does not round-trip) and
        // negative codes, which would otherwise address another factor's coefficients.
        if (level != data[i] - shift || level < 0) {
          throw new IllegalArgumentException("categorical value out of range");
        }
        // Compare against the slot count rather than adding the offset first, so a huge
        // level cannot overflow int; levels beyond the last slot are ignored.
        if (level < _catOffsets[i + 1] - _catOffsets[i]) {
          eta += _beta[_catOffsets[i] + level + base];
        }
      }
    }
    for (int i = 0; i < _nums; ++i) {
      eta += _beta[noff + i + base] * data[i + _cats];
    }
    return eta + _beta[(P - 1) + base]; // add the class intercept
  }

  /**
   * Replaces {@code preds[1..]} by their softmax and stores the index of the most probable
   * class in {@code preds[0]} (left untouched when no probability is greater than 0).
   */
  private static void softmaxInPlace(double[] preds) {
    // Shift by the true maximum. Starting the search at 0 would leave the shift at 0 when
    // every predictor is negative, letting all exponentials underflow to 0 and the
    // normalisation produce NaN.
    double maxEta = Double.NEGATIVE_INFINITY;
    for (int c = 1; c < preds.length; ++c) {
      if (preds[c] > maxEta) {
        maxEta = preds[c];
      }
    }
    double sumExp = 0;
    for (int c = 1; c < preds.length; ++c) {
      preds[c] = Math.exp(preds[c] - maxEta);
      sumExp += preds[c];
    }
    final double norm = 1 / sumExp;
    double maxP = 0;
    for (int c = 1; c < preds.length; ++c) {
      preds[c] *= norm;
      if (preds[c] > maxP) {
        maxP = preds[c];
        preds[0] = c - 1;
      }
    }
  }
}