// hex.gam.GAMModel Maven / Gradle / Ivy
package hex.gam;
import hex.*;
import hex.deeplearning.DeepLearningModel;
import hex.gam.MatrixFrameUtils.AddCSGamColumns;
import hex.gam.MatrixFrameUtils.AddISGamColumns;
import hex.gam.MatrixFrameUtils.AddMSGamColumns;
import hex.gam.MatrixFrameUtils.AddTPKnotsGamColumns;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters.Family;
import hex.glm.GLMModel.GLMParameters.Link;
import hex.glm.GLMModel.GLMParameters.Solver;
import hex.util.EffectiveParametersUtils;
import water.*;
import water.exceptions.H2OColumnNotFoundArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.*;
import java.io.Serializable;
import java.util.Arrays;
import static hex.gam.MatrixFrameUtils.GamUtils.*;
import static hex.genmodel.algos.gam.GamMojoModel.*;
import static hex.glm.GLMModel.GLMParameters.MissingValuesHandling;
import static hex.util.DistributionUtils.distributionToFamily;
import static hex.util.DistributionUtils.familyToDistribution;
public class GAMModel extends Model {
private static final String[] BINOMIAL_CLASS_NAMES = new String[]{"0", "1"};
private static final int CS_NUM_INDEX = 0;
private static final int IS_NUM_INDEX = 1;
private static final int MS_NUM_INDEX = 2;
private static final int NUM_SINGLE_SPLINE_TYPES = 3;
public String[][] _gamColNamesNoCentering; // store column names only for GAM columns
public String[][] _gamColNames; // store column names only for GAM columns after decentering
public int[] _gamPredSize; // store size of predictors for gam smoother
public int[] _m; // parameter related to gamPredSize;
public int[] _M; // size of polynomial basis for thin plate regression smoothers
public int _cubicSplineNum;
public int _iSplineNum;
public int _mSplineNum;
public int _thinPlateSmoothersWithKnotsNum;
public Key[] _gamFrameKeysCenter;
public double[] _gamColMeans;
public int _nclass; // 2 for binomial, > 2 for multinomial and ordinal
public double[] _ymu;
public long _nobs;
public long _nullDOF;
public int _rank;
public IcedHashSet> _validKeys = null;
/**
 * Returns the names of the prediction columns.
 * The superclass names are used as-is unless the output holds a GLM variance-covariance matrix,
 * in which case an extra "StdErr" column name is appended.
 */
@Override public String[] makeScoringNames() {
  final String[] baseNames = super.makeScoringNames();
  if (_output._glm_vcov == null) {
    return baseNames;
  }
  return ArrayUtils.append(baseNames, "StdErr");
}
/**
 * Creates the metric builder used while scoring.
 *
 * @param domain response domain; when null, a domain is substituted for the binomial, quasibinomial,
 *               negativebinomial and fractionalbinomial families
 * @return a MetricBuilderGAM configured from the model parameters
 */
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
  final Family family = _parms._family;
  if (domain == null) {
    if (family == Family.fractionalbinomial) {
      domain = BINOMIAL_CLASS_NAMES; // fractional binomial always uses the fixed "0"/"1" labels
    } else if (family == Family.binomial || family == Family.quasibinomial
            || family == Family.negativebinomial) {
      domain = _output._responseDomains;
    }
  }
  final GLMModel.GLMWeightsFun weightsFun = new GLMModel.GLMWeightsFun(family, _parms._link,
          _parms._tweedie_variance_power, _parms._tweedie_link_power, _parms._theta);
  return new MetricBuilderGAM(domain, _ymu, weightsFun, _rank, true, _parms._intercept, _nclass,
          _parms._auc_type);
}
/**
 * Creates a GAM model stored under the given key.
 *
 * @param selfKey key of this model
 * @param parms   parameters the model was built with
 * @param output  output object of this model
 */
public GAMModel(Key selfKey, GAMParameters parms, GAMModelOutput output) {
  super(selfKey, parms, output);
  // sanity check: the key kept by the superclass must carry the same bytes as the requested key
  assert Arrays.equals(_key._kb, selfKey._kb);
}
/**
 * Delegates to EffectiveParametersUtils.initFoldAssignment with this model's parameters.
 * NOTE(review): judging by the name this is meant to run once the underlying GLM exists - confirm with callers.
 */
public void initActualParamValuesAfterGlmCreation(){
EffectiveParametersUtils.initFoldAssignment(_parms);
}
/**
 * Builds the coefficient-magnitude table for a multi-class model.
 * For every coefficient not named "Intercept", the magnitude is the sum over classes of the absolute
 * coefficient values; rows are ordered by the index array returned from sortCoeffMags. All signs are
 * reported as "POS".
 *
 * @param colHeaders       column headers of the resulting table
 * @param coefficients     coefficients indexed as [class][coefficient]
 * @param coefficientNames coefficient names, one per coefficient column
 * @param tableHeader      header of the resulting table
 * @return the filled table
 */
public TwoDimTable genCoefficientMagTableMultinomial(String[] colHeaders, double[][] coefficients,
                                                     String[] coefficientNames, String tableHeader) {
  final String[] columnTypes = {"double", "string"};
  final String[] columnFormats = {"%5f", ""};
  final int numCoeffs = coefficients[0].length;
  final int numClasses = coefficients.length;
  final int numPredictors = numCoeffs - 1; // room for every coefficient but one intercept
  final String[] names = new String[numPredictors];
  final double[] magnitudes = new double[numPredictors];
  final String[] signs = new String[numPredictors];
  Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames length: %d. coeffMags " +
          "length: %d, coeffSigns length: %d", names.length, magnitudes.length, signs.length));
  int filled = 0;
  for (int coef = 0; coef < numCoeffs; coef++) {
    if (coefficientNames[coef].equals("Intercept")) {
      continue; // the intercept is not part of the table
    }
    for (int cls = 0; cls < numClasses; cls++) {
      magnitudes[filled] += Math.abs(coefficients[cls][coef]); // accumulate |coefficient| across classes
    }
    names[filled] = coefficientNames[coef];
    signs[filled] = "POS"; // assign all signs to positive for multinomial
    filled++;
  }
  // ordering of the magnitudes (descending, per the original comment)
  final Integer[] order = sortCoeffMags(magnitudes.length, magnitudes);
  final String[] sortedNames = new String[numPredictors];
  final double[] sortedMagnitudes = new double[numPredictors];
  for (int pos = 0; pos < magnitudes.length; pos++) {
    final int src = order[pos];
    sortedMagnitudes[pos] = magnitudes[src];
    sortedNames[pos] = names[src];
  }
  Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames2 length: %d. coeffMags2 " +
          "length: %d, coeffSigns length: %d", sortedNames.length, sortedMagnitudes.length, signs.length));
  final TwoDimTable table = new TwoDimTable(tableHeader, "Standardized Coefficient Magnitutes", sortedNames,
          colHeaders, columnTypes, columnFormats, "names");
  fillUpCoeffsMag(sortedMagnitudes, signs, table, 0);
  return table;
}
/**
 * Builds the coefficient-magnitude table for a single coefficient vector.
 * Every coefficient not named "Intercept" contributes its absolute value and a sign label
 * ("POS" when the coefficient is strictly positive, "NEG" otherwise); rows are ordered by the
 * index array returned from sortCoeffMags.
 *
 * @param colHeaders       column headers of the resulting table
 * @param coefficients     model coefficients
 * @param coefficientNames coefficient names, aligned with {@code coefficients}
 * @param tableHeader      header of the resulting table
 * @return the filled table
 */
public TwoDimTable genCoefficientMagTable(String[] colHeaders, double[] coefficients,
                                          String[] coefficientNames, String tableHeader) {
  String[] colTypes = new String[]{ "double", "string"};
  String[] colFormat = new String[]{"%5f", ""};
  int nCoeff = coefficients.length;
  String[] coeffNames = new String[nCoeff-1];
  double[] coeffMags = new double[nCoeff-1]; // skip over intercepts
  String[] coeffSigns = new String[nCoeff-1];
  int countMagIndex = 0;
  for (int index = 0; index < nCoeff; index++) {
    if (!coefficientNames[index].equals("Intercept")) {
      coeffMags[countMagIndex] = Math.abs(coefficients[index]);
      coeffSigns[countMagIndex] = coefficients[index] > 0 ? "POS" : "NEG";
      coeffNames[countMagIndex++] = coefficientNames[index];
    }
  }
  Integer[] indices = sortCoeffMags(coeffMags.length, coeffMags); // sort magnitude indices in decreasing magnitude
  String[] names2 = new String[coeffNames.length];
  double[] mag2 = new double[coeffNames.length];
  String[] sign2 = new String[coeffNames.length];
  for (int i = 0; i < coeffNames.length; ++i) {
    names2[i] = coeffNames[indices[i]];
    mag2[i] = coeffMags[indices[i]];
    sign2[i] = coeffSigns[indices[i]];
  }
  // Tag the log line with this method's own name (it previously carried the multinomial variant's name).
  Log.info("genCoefficientMagTable", String.format("coeffNames length: %d. coeffMags " +
          "length: %d, coeffSigns length: %d", coeffNames.length, coeffMags.length, coeffSigns.length));
  TwoDimTable table = new TwoDimTable(tableHeader, "", names2, colHeaders, colTypes, colFormat,
          "names");
  fillUpCoeffsMag( mag2, sign2, table, 0);
  return table;
}
private void fillUpCoeffsMag(double[] coeffMags, String[] coeffSigns, TwoDimTable tdt, int rowStart) {
int arrLength = coeffMags.length+rowStart;
int arrCounter=0;
for (int i=rowStart; i _plug_values = null;
// internal parameter, handle with care. GLM will stop when there is more than this number of active predictors (after strong rule screening)
public int _max_active_predictors = -1; // not used in GAM, copied over to GLM params
public boolean _generate_scoring_history = false; // if true, will generate GLM scoring history but will slow algo down
// the following parameters are for GAM
// NOTE(review): fields with a "_sorted" suffix presumably hold the same values reordered like
// _gam_columns_sorted (CS splines first, thin plate last) - confirm against the builder.
public int[] _num_knots; // array storing number of knots per smoother
public int[] _spline_orders; // storing I-spline orders for each predictor
public int[] _spline_orders_sorted;
public int[] _num_knots_sorted;
public int[] _num_knots_tp; // store num_knots for thin plate regression
public String[] _knot_ids; // store frame keys that contain knots location for each smoother in gam_X;
public String[][] _gam_columns; // array storing which predictor columns are specified
public String[][] _gam_columns_sorted; // move CS spline to the front and tp to the back in gam_columns
public int[] _gamPredSize; // store size of predictors for gam smoother
public int[] _m; // parameter related to gamPredSize;
public int[] _M; // size of polynomial basis for thin plate regression smoothers
public int[] _bs; // choose spline function for gam column, 0 = cr, 1 = thin plate regression with knots,
// 2 = monotone I-spline, 3 = NBSplineTypeI M-splines
public int[] _bs_sorted; // choose spline function for gam column, 0 = cr, 1 = thin plate regression with knots,
// 2 = monotone I-spline, 3 = NBSplineTypeI M-splines
public double[] _scale; // array storing scaling values to control wiggliness of fit
public double[] _scale_sorted;
public boolean _saveZMatrix = false; // if asserted will save Z matrix
public boolean _keep_gam_cols = false; // if true will save the keys to gam Columns only
public boolean _savePenaltyMat = false; // if true will save penalty matrices as triple array
// short algorithm name, human-readable algorithm name and fully qualified model class name
public String algoName() { return "GAM"; }
public String fullName() { return "Generalized Additive Model"; }
public String javaName() { return GAMModel.class.getName(); }
// NOTE(review): the following look like GLM settings carried on the GAM parameters; -1 presumably means "not set" - confirm.
public double _prior = -1;
public boolean _cold_start = false; // start building GLM model from scratch if true
public int _nlambdas = -1;
public boolean _non_negative = false;
public boolean _remove_collinear_columns = false;
public double _gradient_epsilon = -1;
public boolean _early_stopping = true; // internal GLM early stopping.
public Key _beta_constraints = null;
public double _lambda_min_ratio = -1;
public boolean _betaConstraintsOff = false; // used for cross-validations
// internal parameters added to support client mode
int _glmNFolds = 0;
Model.Parameters.FoldAssignmentScheme _glmFoldAssignment = null;
String _glmFoldColumn = null;
boolean _glmCvOn = false;
public boolean[] _splines_non_negative;
public boolean[] _splines_non_negative_sorted;
public boolean _store_knot_locations = false;
/** Reports a single unit of progress for this algorithm. */
@Override
public long progressUnits() {
return 1;
}
/** Creates an InteractionSpec from the _interactions and _interaction_pairs parameters. */
public InteractionSpec interactionSpec() {
return InteractionSpec.create(_interactions, _interaction_pairs);
}
/**
 * Returns the missing-values handling as the GLM enum.
 * A value that already is a GLM MissingValuesHandling is returned unchanged; otherwise the value is
 * expected to be the DeepLearning enum and its MeanImputation / Skip constants are translated.
 *
 * @throws IllegalStateException when the DeepLearning value has no GLM counterpart
 */
public MissingValuesHandling missingValuesHandling() {
  if (_missing_values_handling instanceof MissingValuesHandling) {
    return (MissingValuesHandling) _missing_values_handling;
  }
  assert _missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
  final DeepLearningModel.DeepLearningParameters.MissingValuesHandling dlHandling =
          (DeepLearningModel.DeepLearningParameters.MissingValuesHandling) _missing_values_handling;
  switch (dlHandling) {
    case MeanImputation:
      return MissingValuesHandling.MeanImputation;
    case Skip:
      return MissingValuesHandling.Skip;
    default:
      throw new IllegalStateException("Unsupported missing values handling value: " + _missing_values_handling);
  }
}
/**
 * Creates the imputer matching the configured missing-values handling.
 *
 * @return a plug-values imputer for PlugValues handling, otherwise a mean imputer
 * @throws IllegalStateException when PlugValues handling is selected but no plug-values frame is available
 */
public DataInfo.Imputer makeImputer() {
  if (missingValuesHandling() != MissingValuesHandling.PlugValues) {
    // mean/mode imputation and skip (even skip needs an imputer right now! PUBDEV-6809)
    return new DataInfo.MeanImputer();
  }
  final boolean plugFrameMissing = _plug_values == null || _plug_values.get() == null;
  if (plugFrameMissing) {
    throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
  }
  return new GLM.PlugValuesImputer(_plug_values.get());
}
/**
 * Applies the inverse of the configured link function.
 *
 * @param x value on the linear-predictor scale
 * @return the value mapped back through the inverse link
 * @throws RuntimeException when the configured link is not handled here
 */
public double linkInv(double x) {
  switch (_link) {
    case identity:
      return x;
    case ologlog: {
      final double inner = Math.exp(x);
      return 1.0 - Math.exp(-1.0 * inner);
    }
    case ologit:
    case logit: {
      final double denom = Math.exp(-x) + 1.0;
      return 1.0 / denom;
    }
    case log:
      return Math.exp(x);
    case inverse: {
      // keep the denominator at least 1e-5 away from zero, preserving the sign of x
      final double bounded;
      if (x < 0) {
        bounded = Math.min(-1e-5, x);
      } else {
        bounded = Math.max(1e-5, x);
      }
      return 1.0 / bounded;
    }
    case tweedie:
      if (_tweedie_link_power == 0) {
        return Math.max(2e-16, Math.exp(x));
      }
      return Math.pow(x, 1 / _tweedie_link_power);
    default:
      throw new RuntimeException("unexpected link function " + _link.toString());
  }
}
/**
 * Sets the GLM family from a generic distribution family and resets the link to Link.family_default.
 *
 * @param distributionFamily distribution to convert via distributionToFamily
 */
@Override
public void setDistributionFamily(DistributionFamily distributionFamily) {
_family = distributionToFamily(distributionFamily);
_link = Link.family_default;
}
/** Returns the configured family converted to a generic distribution family via familyToDistribution. */
@Override
public DistributionFamily getDistributionFamily() {
return familyToDistribution(_family);
}
}
/**
 * Returns the column domains used for scoring.
 * For the binomial, quasibinomial and fractionalbinomial families, a response column without a domain
 * gets one substituted in a copy of the domains array; otherwise the output's domains are returned as-is.
 */
@Override
protected String[][] scoringDomains(){
  final int responseIdx = _output._dinfo.responseChunkId(0);
  final Family family = _parms._family;
  final boolean twoClassFamily = family == Family.binomial || family == Family.quasibinomial
          || family == Family.fractionalbinomial;
  if (!twoClassFamily || _output._domains[responseIdx] != null) {
    return _output._domains;
  }
  // copy the outer array so the model output itself is not modified
  final String[][] patched = _output._domains.clone();
  if (family == Family.fractionalbinomial) {
    patched[responseIdx] = BINOMIAL_CLASS_NAMES;
  } else {
    patched[responseIdx] = _output._responseDomains;
  }
  return patched;
}
public static class GAMModelOutput extends Model.Output {
// ---- coefficient names (with and without centering) ----
public String[] _coefficient_names_no_centering;
public String[] _coefficient_names;
// ---- values prefixed _glm_; NOTE(review): presumably copied from the underlying GLM model - confirm ----
public TwoDimTable _glm_model_summary;
public ModelMetrics _glm_training_metrics;
public ModelMetrics _glm_validation_metrics;
public double _glm_dispersion;
public double[] _glm_zvalues;
public double[] _glm_pvalues;
public double[][] _glm_vcov; // when non-null, scoring emits an extra "StdErr" column computed from it
public double[] _glm_stdErr;
public double _glm_best_lamda_value;
public TwoDimTable _glm_scoring_history;
public TwoDimTable[] _glm_cv_scoring_history;
// ---- summary tables ----
public TwoDimTable _coefficients_table;
public TwoDimTable _coefficients_table_no_centering;
public TwoDimTable _standardized_coefficient_magnitudes;
public TwoDimTable _variable_importances;
public VarImp _varimp; // should contain the same content as standardized coefficients
// ---- coefficients ----
public double[] _model_beta_no_centering; // coefficients generated during model training
public double[] _standardized_model_beta_no_centering; // standardized coefficients generated during model training
public double[] _model_beta; // coefficients generated during model training; used for scoring non-multinomial/ordinal families
public double[] _standardized_model_beta; // standardized coefficients generated during model training
public double[][] _model_beta_multinomial_no_centering; // store multinomial coefficients during model training
public double[][] _standardized_model_beta_multinomial_no_centering; // store standardized multinomial coefficients during model training
public double[][] _model_beta_multinomial; // store multinomial coefficients during model training; used for scoring multinomial/ordinal families
public double[][] _standardized_model_beta_multinomial; // store standardized multinomial coefficients during model training
public double _best_alpha;
public double _best_lambda;
public double _devianceValid = Double.NaN;
public double _devianceTrain = Double.NaN;
private double[] _zvalues;
private double _dispersion;
private boolean _dispersionEstimated;
// ---- GAM smoother artifacts ----
public String[][] _gamColNames; // store gam column names after transformation and centering
public double[][][] _zTranspose; // Z matrix for centralization, can be null
public double[][][] _penaltyMatricesCenter; // stores t(Z)*t(D)*Binv*D*Z and can be null
public double[][][] _penaltyMatrices; // store t(D)*Binv*D and can be null
public double[][][] _binvD; // store BinvD for each gam column specified for scoring
public double[][][] _knots; // store knots location for each gam smoother
int[][][] _allPolyBasisList; // store polynomial basis function for all tp smoothers
double[][][] _penaltyMatCS; // penalty matrix after removing optimization constraint, only for thin plate
double[][][] _zTransposeCS; // store for each thin plate smoother for removing optimization constraint
public int[] _numKnots; // store number of knots per gam smoother
public double[][][] _starT;
public double[][] _gamColMeansRaw;
public double[][] _oneOGamColStd;
public double[] _penaltyScale;
public Key _gamTransformedTrainCenter; // contain key of predictors, all gamified columns centered; removed with the model
public DataInfo _dinfo; // provides the scoring DataInfo and the response column index
public String[] _responseDomains; // response domain substituted for binomial-type families when none is available
public String _gam_transformed_center_key;
final Family _family;
public String[] _gam_knot_column_names;
public double[][] _knot_locations;
/***
* The function will copy over the knot locations into _knot_locations and the gam column names corresponding to
* the knot locations into _gam_knot_column_names.
*/
public void copyKnots(double[][][] knots, String[][] gam_columns_sorted) {
int numGam = gam_columns_sorted.length;
int trueGamNum = 0;
for (int index=0; index 0)
gamifiedCSCols = gamifiedSinglePredictors(gamColCSNames, gamColCSSplines, binvD, CS_SPLINE_TYPE, zTranspose, knots,
parms, gamColNames);
if (numISGamCol > 0)
gamifiedISCols = gamifiedSinglePredictors(gamColISNames, gamColISplines, null, IS_SPLINE_TYPE, null,
knots, parms, gamColNames);
if (numMSGamCol > 0)
gamifiedMSCols = gamifiedSinglePredictors(gamColMSNames, gamColMSplines, null, MS_SPLINE_TYPE, zTranspose, knots,
parms, gamColNames);
return mergedGamifiedCols(new Frame[]{gamifiedCSCols, gamifiedISCols, gamifiedMSCols});
}
/**
 * Merges the non-null frames into the first non-null one.
 * Columns of every later frame are moved into the first frame, and the emptied frame as well as the
 * merged result are registered with Scope.track.
 *
 * @param allGamifiedCols frames to merge; entries may be null
 * @return the merged frame, or null when every entry is null
 */
private static Frame mergedGamifiedCols(Frame[] allGamifiedCols) {
  Frame combined = null;
  for (Frame gamified : allGamifiedCols) {
    if (gamified == null) {
      continue;
    }
    if (combined == null) {
      combined = gamified; // the first available frame becomes the merge target
      continue;
    }
    combined.add(gamified.names(), gamified.removeAll());
    Scope.track(gamified);
  }
  Scope.track(combined);
  return combined;
}
/**
 * Runs the column-generating task for one single-predictor spline type (CS, IS or MS) over the given
 * predictor columns and returns the task's output frame, named with the GAM column names of that type.
 *
 * @param gamifiedColNames names of the predictor columns
 * @param gamColCSSplines  predictor vecs, aligned with {@code gamifiedColNames}
 * @param binvD            passed to the CS task only
 * @param bsType           spline type to process
 * @param zTranspose       passed to the CS and MS tasks
 * @param knots            knots passed to every task
 * @param parms            model parameters supplying the sorted knot counts, spline types and orders
 * @param gamColNames      output column names per GAM column
 * @return the generated frame, or null for any other spline type
 */
private static Frame gamifiedSinglePredictors(String[] gamifiedColNames, Vec[] gamColCSSplines, double[][][] binvD, int bsType,
                                              double[][][] zTranspose, double[][][] knots, GAMParameters parms, String[][] gamColNames) {
  final Frame predictorsOnly = new Frame(gamifiedColNames, gamColCSSplines);
  if (bsType == CS_SPLINE_TYPE) {
    final AddCSGamColumns csTask = new AddCSGamColumns(binvD, zTranspose, knots, parms._num_knots_sorted,
            predictorsOnly, parms._bs_sorted);
    csTask.doAll(csTask._gamCols2Add, Vec.T_NUM, predictorsOnly);
    final String[] csNames = centeredNamesForSplineType(bsType, csTask._gamCols2Add, parms, gamColNames);
    return csTask.outputFrame(Key.make(), csNames, null);
  }
  if (bsType == IS_SPLINE_TYPE) {
    final AddISGamColumns isTask = new AddISGamColumns(knots, parms._num_knots_sorted, parms._bs_sorted,
            parms._spline_orders_sorted, predictorsOnly);
    isTask.doAll(isTask._totGamifiedColCentered, Vec.T_NUM, predictorsOnly);
    final String[] isNames = centeredNamesForSplineType(bsType, isTask._totGamifiedColCentered, parms, gamColNames);
    return isTask.outputFrame(Key.make(), isNames, null);
  }
  if (bsType == MS_SPLINE_TYPE) {
    final AddMSGamColumns msTask = new AddMSGamColumns(knots, zTranspose, parms._num_knots_sorted, parms._bs_sorted,
            parms._spline_orders_sorted, predictorsOnly);
    msTask.doAll(msTask._totGamifiedColCentered, Vec.T_NUM, predictorsOnly);
    final String[] msNames = centeredNamesForSplineType(bsType, msTask._totGamifiedColCentered, parms, gamColNames);
    return msTask.outputFrame(Key.make(), msNames, null);
  }
  // unsupported type: names are still gathered into an empty array, as before, then nothing is returned
  centeredNamesForSplineType(bsType, 0, parms, gamColNames);
  return null;
}

/** Concatenates the names of every GAM column whose sorted spline type equals {@code bsType}. */
private static String[] centeredNamesForSplineType(int bsType, int totalNames, GAMParameters parms,
                                                   String[][] gamColNames) {
  final String[] collected = new String[totalNames];
  final int gamColCount = parms._gam_columns.length;
  int writePos = 0;
  for (int col = 0; col < gamColCount; col++) {
    if (bsType != parms._bs_sorted[col]) {
      continue;
    }
    final String[] colNames = gamColNames[col];
    System.arraycopy(colNames, 0, collected, writePos, colNames.length);
    writePos += colNames.length;
  }
  return collected;
}
/**
 * Scores the adapted frame with a GAMScore task and packages the result.
 *
 * @param fr               original frame (not used here)
 * @param adaptFrm         frame adapted to the model, used for scoring
 * @param destination_key  key name of the prediction frame
 * @param j                job passed to the scoring task
 * @param computeMetrics   whether metrics were requested; the task decides whether they can be computed
 * @param customMetricFunc custom metric function (not used here)
 * @return metric builder and raw frame (both null when no metrics were computed) plus the prediction frame
 */
@Override
protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j,
                                              boolean computeMetrics, CFuncRef customMetricFunc) {
  String[] predictNames = makeScoringNames();
  String[][] domains = new String[predictNames.length][];
  GAMScore gs = makeScoringTask(adaptFrm, true, j, computeMetrics);
  gs.doAll(predictNames.length, Vec.T_NUM, gs._dinfo._adaptedFrame);
  // wildcard-typed builder; the declaration previously carried a stray '>' left over from a stripped generic
  ModelMetrics.MetricBuilder<?> mb = null;
  Frame rawFrame = null;
  if (gs._computeMetrics) {
    mb = gs._mb;
    rawFrame = gs.outputFrame();
  }
  domains[0] = gs._predDomains; // only the first (prediction) column carries a domain
  Frame outputFrame = gs.outputFrame(Key.make(destination_key), predictNames, domains);
  return new PredictScoreResult(mb, rawFrame, outputFrame);
}
/**
 * Prepares the GAMScore task for a frame.
 * A response column flagged as bad is dropped from a shallow copy of the frame, metrics are enabled only
 * when they were requested and a usable response column is present, and the prediction domain is chosen
 * accordingly.
 *
 * @param adaptFrm        adapted frame to score
 * @param makePredictions whether the task should write predictions
 * @param j               job handed to the task
 * @param computeMetrics  whether metrics were requested
 * @return the configured scoring task
 */
private GAMScore makeScoringTask(Frame adaptFrm, boolean makePredictions, Job j, boolean computeMetrics) {
  final int responseIdx = adaptFrm.find(_output.responseName());
  final boolean badResponse = responseIdx > -1 && adaptFrm.vec(responseIdx).isBad();
  if (badResponse) { // remove inserted invalid response
    adaptFrm = new Frame(adaptFrm.names(), adaptFrm.vecs());
    adaptFrm.remove(responseIdx);
  }
  boolean metricsPossible = false;
  if (computeMetrics) {
    final Vec responseVec = adaptFrm.vec(_output.responseName());
    metricsPossible = responseVec != null && !responseVec.isBad();
  }
  String[] domain;
  if (_output.nclasses() <= 1) {
    domain = null;
  } else if (metricsPossible) {
    domain = adaptFrm.lastVec().domain();
  } else {
    domain = _output._domains[_output._domains.length - 1];
  }
  if (_parms._family.equals(Family.quasibinomial)) {
    domain = _output._responseDomains;
  }
  final DataInfo scoringInfo = _output._dinfo.scoringInfo(_output._names, adaptFrm);
  return new GAMScore(j, this, scoringInfo, domain, metricsPossible, makePredictions);
}
private class GAMScore extends MRTask {
private DataInfo _dinfo; // data info used to extract rows; narrowed in the constructor when some coefficients are zero
private double[] _coeffs; // coefficients for all families except multinomial/ordinal, null otherwise
private double[][] _coeffs_multinomial; // per-class coefficients for multinomial/ordinal, null otherwise
private int _nclass; // number of classes reported by the model output
private boolean _computeMetrics; // true when metrics are accumulated while scoring
private final Job _j; // job polled for stop requests and updated per chunk; may be null
private Family _family;
private transient double[] _eta; // store eta calculation
private String[] _predDomains; // domain handed to the metric builder and reported for the prediction column
private final GAMModel _m;
private final double _defaultThreshold; // threshold applied to mu to pick the predicted class for two-class families
private int _lastClass; // _nclass - 1
private ModelMetrics.MetricBuilder _mb; // created per map call when metrics are computed
final boolean _generatePredictions; // when true, predictions are written to the output chunks
private transient double[][] _vcov; // GLM variance-covariance matrix from the model output, may be null
private transient double[] _tmp; // scratch array sized like _vcov, used for the StdErr computation
private boolean _classifier2class; // true for binomial, quasibinomial and fractionalbinomial
/**
 * Sets up scoring state.
 * Multinomial and ordinal families use the per-class coefficient matrix. Every other family uses the
 * single coefficient vector; when some non-intercept coefficients are zero, the vector and the data
 * info are both reduced to the non-zero columns (the last entry, the intercept, is always kept).
 *
 * @param j               job to poll and update, may be null
 * @param m               model being scored
 * @param dinfo           data info describing the frame to score
 * @param domain          prediction domain
 * @param computeMetrics  whether to accumulate metrics
 * @param makePredictions whether to write predictions
 */
private GAMScore(final Job j, final GAMModel m, DataInfo dinfo, final String[] domain, final boolean computeMetrics,
                 final boolean makePredictions) {
  _j = j;
  _m = m;
  _computeMetrics = computeMetrics;
  _predDomains = domain;
  _nclass = m._output.nclasses();
  _generatePredictions = makePredictions;
  final Family modelFamily = _m._parms._family;
  _classifier2class = modelFamily == GLMModel.GLMParameters.Family.binomial
          || modelFamily == Family.quasibinomial || modelFamily == Family.fractionalbinomial;
  final boolean multiClass = modelFamily == GLMModel.GLMParameters.Family.multinomial
          || modelFamily == GLMModel.GLMParameters.Family.ordinal;
  if (multiClass) {
    _coeffs = null;
    _coeffs_multinomial = m._output._model_beta_multinomial;
  } else {
    final double[] fullBeta = m._output._model_beta;
    final int interceptIdx = fullBeta.length - 1;
    int[] activeCols = new int[interceptIdx];
    int activeCount = 0;
    for (int col = 0; col < interceptIdx; ++col) { // collect indices of the non-zero coefficients
      if (fullBeta[col] != 0) {
        activeCols[activeCount++] = col;
      }
    }
    double[] scoringBeta = fullBeta;
    if (activeCount < interceptIdx) {
      activeCols = Arrays.copyOf(activeCols, activeCount);
      dinfo = dinfo.filterExpandedColumns(activeCols);
      scoringBeta = MemoryManager.malloc8d(activeCols.length + 1);
      int writePos = 0;
      for (int srcCol : activeCols) {
        scoringBeta[writePos++] = fullBeta[srcCol];
      }
      scoringBeta[writePos] = fullBeta[interceptIdx]; // intercept stays last
    }
    _coeffs_multinomial = null;
    _coeffs = scoringBeta;
  }
  _dinfo = dinfo;
  _dinfo._valid = true; // marking dinfo as validation data set disables an assert on unseen levels (which should not happen in train)
  _defaultThreshold = m.defaultThreshold();
  _family = m._parms._family;
  _lastClass = _nclass - 1;
}
/**
 * Scores one chunk row by row.
 * Returns immediately when the task is cancelled or the job asked to stop. Scratch buffers are
 * allocated per call; when metrics are enabled, a fresh metric builder is created and fed every row
 * whose response is usable. The job, if any, is updated by one unit at the end.
 */
@Override
public void map(Chunk[]chks, NewChunk[] nc) {
  final boolean stopRequested = _j != null && _j.stop_requested();
  if (isCancelled() || stopRequested) {
    return;
  }
  if (_family.equals(Family.ordinal) || _family.equals(Family.multinomial)) {
    _eta = MemoryManager.malloc8d(_nclass);
  }
  _vcov = _m._output._glm_vcov;
  if (_vcov != null) {
    _tmp = MemoryManager.malloc8d(_vcov.length);
  }
  final int predWidth = _nclass <= 1 ? 1 : _nclass + 1; // number of predictor values expected.
  final double[] rowPreds = MemoryManager.malloc8d(predWidth);
  float[] actual = null;
  if (_computeMetrics) {
    _mb = _m.makeMetricBuilder(_predDomains);
    actual = new float[1];
  }
  final DataInfo.Row row = _dinfo.newDenseRow();
  final int rowCount = chks[0]._len;
  for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { // extract each row
    _dinfo.extractDenseRow(chks, rowIdx, row);
    processRow(row, rowPreds, nc, predWidth);
    if (!_computeMetrics || row.response_bad) {
      continue;
    }
    actual[0] = (float) row.response[0];
    _mb.perRow(rowPreds, actual, row.weight, row.offset, _m);
  }
  if (_j != null) {
    _j.update(1);
  }
}
/**
 * Computes the predictions for one row and, when predictions are generated, appends them to the
 * output chunks.
 * Rows with bad predictors yield NaN and rows with zero weight yield 0; only the remaining rows are
 * scored. Previously the scoring step ran unconditionally and overwrote both fills.
 *
 * @param r     extracted row
 * @param ps    prediction buffer, overwritten for this row
 * @param preds output chunks, one per prediction column (plus one for StdErr when a vcov matrix exists)
 * @param ncols number of prediction values to write
 */
private void processRow(DataInfo.Row r, double[] ps, NewChunk[] preds, int ncols) {
  if (r.predictors_bad) {
    Arrays.fill(ps, Double.NaN); // output NaN with bad predictor entries
  } else if (r.weight == 0) {
    Arrays.fill(ps, 0.0); // zero weight entries got 0 too
  } else {
    switch (_family) {
      case multinomial: ps = scoreMultinomialRow(r, r.offset, ps); break;
      case ordinal: ps = scoreOrdinalRow(r, r.offset, ps); break;
      default: ps = scoreRow(r, r.offset, ps); break;
    }
  }
  if (_generatePredictions) {
    for (int predCol = 0; predCol < ncols; predCol++) { // write prediction to NewChunk
      preds[predCol].addNum(ps[predCol]);
    }
    if (_vcov != null) // extra StdErr column: sqrt(row' * vcov * row)
      preds[ncols].addNum(Math.sqrt(r.innerProduct(r.mtrxMul(_vcov, _tmp))));
  }
}
/**
 * Scores a row for every family except multinomial and ordinal.
 * The mean is the inverse link of (row . coefficients + offset). Two-class families get the predicted
 * label plus both class probabilities; all other families get the mean only.
 *
 * @param r      row to score
 * @param offset offset added to the linear predictor
 * @param preds  buffer to fill
 * @return {@code preds}
 */
public double[] scoreRow(DataInfo.Row r, double offset, double[] preds) {
  final double linearPredictor = r.innerProduct(_coeffs) + offset;
  final double mu = _m._parms.linkInv(linearPredictor);
  if (!_classifier2class) {
    preds[0] = mu;
    return preds;
  }
  preds[0] = mu >= _defaultThreshold ? 1 : 0; // threshold for prediction
  preds[1] = 1.0 - mu; // class 0
  preds[2] = mu; // class 1
  return preds;
}
/**
 * Scores a row for the ordinal family.
 * For each class except the last, the cumulative probability is the logistic function of
 * (row . class coefficients + offset); class probabilities are successive differences of these values
 * and the last class gets the remainder. The predicted class is the first one whose linear predictor
 * is positive, or the last class when none is.
 *
 * The offset is now applied identically for every class. Previously the classes after the predicted
 * one used exp(-innerProduct + offset), i.e. the offset with the wrong sign.
 *
 * @param r      row to score
 * @param offset offset added to every linear predictor
 * @param preds  buffer to fill: index 0 is the predicted class, indices 1.._nclass the class probabilities
 * @return {@code preds}
 */
public double[] scoreOrdinalRow(DataInfo.Row r, double offset, double[] preds) {
  final double[][] bm = _coeffs_multinomial;
  Arrays.fill(preds, 0);
  preds[0] = _lastClass; // default to the last class until a positive linear predictor is found
  boolean classFound = false;
  double previousCDF = 0.0;
  for (int cInd = 0; cInd < _lastClass; cInd++) {
    final double eta = r.innerProduct(bm[cInd]) + offset;
    final double currCDF = 1.0 / (1 + Math.exp(-eta));
    preds[cInd + 1] = currCDF - previousCDF; // class probability = difference of cumulative probabilities
    previousCDF = currCDF;
    if (!classFound && eta > 0) { // found the correct class
      preds[0] = cInd;
      classFound = true;
    }
  }
  preds[_nclass] = 1 - previousCDF;
  return preds;
}
/**
 * Scores a row for the multinomial family.
 * Per-class linear predictors are shifted by their maximum before exponentiation, then normalised so
 * the class probabilities sum to one; the predicted class is the index of the largest entry.
 *
 * @param r      row to score
 * @param offset offset added to every linear predictor
 * @param preds  buffer to fill: index 0 is the predicted class, following indices the class probabilities
 * @return {@code preds}
 */
public double[] scoreMultinomialRow(DataInfo.Row r, double offset, double[] preds) {
  final double[] scores = _eta;
  final double[][] classBetas = _coeffs_multinomial;
  final int classCount = classBetas.length;
  double largest = Double.NEGATIVE_INFINITY;
  for (int cls = 0; cls < classCount; ++cls) {
    scores[cls] = r.innerProduct(classBetas[cls]) + offset;
    if (scores[cls] > largest) {
      largest = scores[cls];
    }
  }
  double expTotal = 0;
  for (int cls = 0; cls < classCount; ++cls) {
    scores[cls] = Math.exp(scores[cls] - largest); // shifted by the maximum before exponentiating
    expTotal += scores[cls];
  }
  final double normaliser = 1.0 / expTotal;
  for (int cls = 0; cls < classCount; ++cls) {
    preds[cls + 1] = scores[cls] * normaliser;
  }
  preds[0] = ArrayUtils.maxIndex(scores);
  return preds;
}
/** Folds the other task's metric builder into this one; does nothing when this task has no builder. */
@Override
public void reduce(GAMScore other) {
  if (_mb == null) {
    return;
  }
  _mb.reduce(other._mb);
}
/** Invokes the metric builder's postGlobal step when a builder exists. */
@Override
protected void postGlobal() {
  if (_mb == null) {
    return;
  }
  _mb.postGlobal();
}
}
/**
 * Not supported for GAM models.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public double[] score0(double[] data, double[] preds) {
throw new UnsupportedOperationException("GAMModel.score0 should never be called");
}
/** Returns a new GAMMojoWriter for this model. */
@Override
public GAMMojoWriter getMojo() {
return new GAMMojoWriter(this);
}
/**
 * Removes the model together with the keys it owns: the centered GAM-transformed training frame, every
 * key in _validKeys, and the cross-validation predictions, fold assignment and models whose keep flags
 * are set.
 *
 * @param fs      futures collecting the pending removals
 * @param cascade passed to the superclass
 * @return {@code fs}
 */
@Override
protected Futures remove_impl(Futures fs, boolean cascade) {
  super.remove_impl(fs, cascade);
  Keyed.remove(_output._gamTransformedTrainCenter, fs, true);
  if (_validKeys != null) {
    for (Key validKey : _validKeys) {
      Keyed.remove(validKey, fs, true);
    }
  }
  if (_parms._keep_cross_validation_predictions) {
    Keyed.remove(_output._cross_validation_holdout_predictions_frame_id, fs, true);
  }
  if (_parms._keep_cross_validation_fold_assignment) {
    Keyed.remove(_output._cross_validation_fold_assignment_frame_id, fs, true);
  }
  final boolean hasCvModels = _parms._keep_cross_validation_models && _output._cross_validation_models != null;
  if (hasCvModels) {
    for (Key cvModelKey : _output._cross_validation_models) {
      Keyed.remove(cvModelKey, fs, true);
    }
  }
  return fs;
}
/**
 * Writes the keys owned by the model before delegating to the superclass: the centered GAM-transformed
 * training frame when present, then the cross-validation predictions, fold assignment and models whose
 * keep flags are set.
 *
 * @param ab buffer to write to
 * @return the buffer returned by the superclass
 */
@Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
  if (_output._gamTransformedTrainCenter != null) {
    ab.putKey(_output._gamTransformedTrainCenter);
  }
  if (_parms._keep_cross_validation_predictions) {
    ab.putKey(_output._cross_validation_holdout_predictions_frame_id);
  }
  if (_parms._keep_cross_validation_fold_assignment) {
    ab.putKey(_output._cross_validation_fold_assignment_frame_id);
  }
  final boolean hasCvModels = _parms._keep_cross_validation_models && _output._cross_validation_models != null;
  if (hasCvModels) {
    for (Key cvModelKey : _output._cross_validation_models) {
      ab.putKey(cvModelKey);
    }
  }
  return super.writeAll_impl(ab);
}
/**
 * Reads the keys owned by the model, in the order writeAll_impl writes them, before delegating to the
 * superclass.
 *
 * @param ab buffer to read from
 * @param fs futures collecting the pending loads
 * @return the object returned by the superclass
 */
@Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
  if (_output._gamTransformedTrainCenter != null) {
    ab.getKey(_output._gamTransformedTrainCenter, fs);
  }
  if (_parms._keep_cross_validation_predictions) {
    ab.getKey(_output._cross_validation_holdout_predictions_frame_id, fs);
  }
  if (_parms._keep_cross_validation_fold_assignment) {
    ab.getKey(_output._cross_validation_fold_assignment_frame_id, fs);
  }
  final boolean hasCvModels = _parms._keep_cross_validation_models && _output._cross_validation_models != null;
  if (hasCvModels) {
    for (Key cvModelKey : _output._cross_validation_models) {
      ab.getKey(cvModelKey, fs);
    }
  }
  return super.readAll_impl(ab, fs);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy