All downloads are free. Search and download functionality uses the official Maven repository.

hex.gam.GAMModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.gam;

import hex.*;
import hex.deeplearning.DeepLearningModel;
import hex.gam.MatrixFrameUtils.AddCSGamColumns;
import hex.gam.MatrixFrameUtils.AddISGamColumns;
import hex.gam.MatrixFrameUtils.AddTPKnotsGamColumns;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters.Family;
import hex.glm.GLMModel.GLMParameters.Link;
import hex.glm.GLMModel.GLMParameters.Solver;
import hex.util.EffectiveParametersUtils;
import water.*;
import water.exceptions.H2OColumnNotFoundArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.udf.CFuncRef;
import water.util.*;

import java.io.Serializable;
import java.util.Arrays;

import static hex.gam.MatrixFrameUtils.GamUtils.*;
import static hex.glm.GLMModel.GLMParameters.MissingValuesHandling;
import static hex.util.DistributionUtils.distributionToFamily;
import static hex.util.DistributionUtils.familyToDistribution;

public class GAMModel extends Model {
  // Class labels used as the response domain when the family is fractionalbinomial.
  private static final String[] BINOMIAL_CLASS_NAMES = new String[]{"0", "1"};
  public String[][] _gamColNamesNoCentering; // store column names only for GAM columns
  public String[][] _gamColNames; // store column names only for GAM columns after decentering
  public int[] _gamPredSize;  // store size of predictors for gam smoother
  public int[] _m;  // parameter related to gamPredSize;
  public int[] _M;  // size of polynomial basis for thin plate regression smoothers
  public int _cubicSplineNum; // passed as the CS smoother count when a frame is gamified for scoring
  public int _iSplineNum; // presumably the number of I-spline smoothers -- TODO confirm
  public int _thinPlateSmoothersWithKnotsNum; // presumably the number of thin plate smoothers -- TODO confirm
  public Key[] _gamFrameKeysCenter;
  public double[] _gamColMeans;
  public int _nclass; // 2 for binomial, > 2 for multinomial and ordinal
  public double[] _ymu; // handed to MetricBuilderGAM when a metric builder is created
  public long _nobs;
  public long _nullDOF;
  public int _rank; // handed to MetricBuilderGAM when a metric builder is created
  public IcedHashSet> _validKeys = null;

  /**
   * Returns the names of the scoring output columns: the names produced by the superclass,
   * followed by a "StdErr" column when the model output carries a GLM variance-covariance matrix.
   *
   * @return scoring column names
   */
  @Override public String[] makeScoringNames() {
    final String[] baseNames = super.makeScoringNames();
    if (_output._glm_vcov == null) {
      return baseNames; // no covariance matrix, no standard error column
    }
    return ArrayUtils.append(baseNames, "StdErr");
  }
  
  /**
   * Creates the GAM metric builder used while scoring.
   * When no domain is supplied, one is chosen from the family: the fixed {"0", "1"} labels for
   * fractionalbinomial, and the stored response domain for binomial, quasibinomial and
   * negativebinomial. For every other family a null domain is passed through unchanged.
   *
   * @param domain response domain, may be null
   * @return a new {@code MetricBuilderGAM}
   */
  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    final Family family = _parms._family;
    if (domain == null) {
      if (family == Family.fractionalbinomial) {
        domain = BINOMIAL_CLASS_NAMES;
      } else if (family == Family.binomial || family == Family.quasibinomial
              || family == Family.negativebinomial) {
        domain = _output._responseDomains;
      }
    }
    final GLMModel.GLMWeightsFun weightsFun = new GLMModel.GLMWeightsFun(family, _parms._link,
            _parms._tweedie_variance_power, _parms._tweedie_link_power, _parms._theta);
    return new MetricBuilderGAM(domain, _ymu, weightsFun, _rank, true, _parms._intercept, _nclass,
            _parms._auc_type);
  }

  /**
   * Creates a GAM model stored under the given key.
   *
   * @param selfKey key of this model
   * @param parms   parameters the model was built with
   * @param output  model output container
   */
  public GAMModel(Key selfKey, GAMParameters parms, GAMModelOutput output) {
    super(selfKey, parms, output);
    // sanity check: the key kept by the superclass has the same bytes as the one passed in
    assert Arrays.equals(_key._kb, selfKey._kb);
  }
  
  /**
   * Initializes the fold assignment of the model parameters by delegating to
   * {@code EffectiveParametersUtils.initFoldAssignment}.
   */
  public void initActualParamValuesAfterGlmCreation() {
    EffectiveParametersUtils.initFoldAssignment(_parms);
  }

  /**
   * Builds the coefficient-magnitude table for a model with one coefficient vector per class.
   * For every coefficient whose name is not "Intercept", the magnitude is the sum over all classes
   * of the absolute coefficient value; the sign column is always "POS". Rows are ordered by the
   * index permutation returned from {@code sortCoeffMags}.
   *
   * @param colHeaders       column headers of the table
   * @param coefficients     coefficients indexed as [class][coefficient]
   * @param coefficientNames names of the coefficients; exactly one entry is expected to be
   *                         "Intercept" because the result arrays are sized {@code nCoeff - 1}
   * @param tableHeader      header of the table
   * @return the filled table
   */
  public TwoDimTable genCoefficientMagTableMultinomial(String[] colHeaders, double[][] coefficients,
                                                       String[] coefficientNames, String tableHeader) {
    final int numCoeff = coefficients[0].length;
    final int numPredictors = numCoeff - 1;  // every coefficient except the intercept
    final String[] names = new String[numPredictors];
    final double[] magnitudes = new double[numPredictors];
    final String[] signs = new String[numPredictors];

    Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames length: %d.  coeffMags " +
            "length: %d, coeffSigns length: %d", names.length, magnitudes.length, signs.length));

    int next = 0;
    for (int coefInd = 0; coefInd < numCoeff; coefInd++) {
      if (coefficientNames[coefInd].equals("Intercept"))
        continue;
      double summed = 0;
      for (double[] classCoeffs : coefficients)  // accumulate abs(coefficient) across classes
        summed += Math.abs(classCoeffs[coefInd]);
      magnitudes[next] = summed;
      names[next] = coefficientNames[coefInd];
      signs[next] = "POS";  // multinomial magnitudes are sums of absolute values, hence always positive
      next++;
    }

    // permutation that orders the magnitudes (descending, per the original comment)
    final Integer[] order = sortCoeffMags(magnitudes.length, magnitudes);
    final String[] sortedNames = new String[numPredictors];
    final double[] sortedMagnitudes = new double[numPredictors];
    for (int pos = 0; pos < numPredictors; pos++) {
      final int src = order[pos];
      sortedMagnitudes[pos] = magnitudes[src];
      sortedNames[pos] = names[src];
    }

    Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames2 length: %d.  coeffMags2 " +
            "length: %d, coeffSigns length: %d", sortedNames.length, sortedMagnitudes.length, signs.length));

    final String[] colTypes = new String[]{"double", "string"};
    final String[] colFormat = new String[]{"%5f", ""};
    // NOTE(review): the description string contains a typo ("Magnitutes"); it is kept verbatim
    // here because it is visible to clients.
    final TwoDimTable table = new TwoDimTable(tableHeader, "Standardized Coefficient Magnitutes", sortedNames,
            colHeaders, colTypes, colFormat, "names");
    fillUpCoeffsMag(sortedMagnitudes, signs, table, 0);  // signs are all "POS", so no reordering is needed
    return table;
  }

  /**
   * Builds the coefficient-magnitude table for a model with a single coefficient vector.
   * Every coefficient whose name is not "Intercept" contributes one row holding its absolute value
   * and its sign ("POS" when strictly positive, "NEG" otherwise). Rows are ordered by the index
   * permutation returned from {@code sortCoeffMags}.
   *
   * @param colHeaders       column headers of the table
   * @param coefficients     model coefficients
   * @param coefficientNames names of the coefficients; exactly one entry is expected to be
   *                         "Intercept" because the result arrays are sized {@code nCoeff - 1}
   * @param tableHeader      header of the table
   * @return the filled table
   */
  public TwoDimTable genCoefficientMagTable(String[] colHeaders, double[] coefficients,
                                            String[] coefficientNames, String tableHeader) {
    String[] colTypes = new String[]{ "double", "string"};
    String[] colFormat = new String[]{"%5f", ""};
    int nCoeff = coefficients.length;
    String[] coeffNames = new String[nCoeff-1];
    double[] coeffMags = new double[nCoeff-1]; // skip over intercepts
    String[] coeffSigns = new String[nCoeff-1];
    int countMagIndex = 0;
    for (int index = 0; index < nCoeff; index++) {
      if (!coefficientNames[index].equals("Intercept")) {
        coeffMags[countMagIndex] = Math.abs(coefficients[index]);
        coeffSigns[countMagIndex] = coefficients[index] > 0 ? "POS" : "NEG";
        coeffNames[countMagIndex++] = coefficientNames[index];
      }
    }
    Integer[] indices = sortCoeffMags(coeffMags.length, coeffMags); // sort magnitude indices in decreasing magnitude
    String[] names2 = new String[coeffNames.length];
    double[] mag2 = new double[coeffNames.length];
    String[] sign2 = new String[coeffNames.length];
    for (int i = 0; i < coeffNames.length; ++i) {
      names2[i] = coeffNames[indices[i]];
      mag2[i] = coeffMags[indices[i]];
      sign2[i] = coeffSigns[indices[i]];
    }
    // The log tag names this method; it previously carried the multinomial method's name,
    // which attributed these log lines to the wrong code path.
    Log.info("genCoefficientMagTable", String.format("coeffNames length: %d.  coeffMags " +
            "length: %d, coeffSigns length: %d", coeffNames.length, coeffMags.length, coeffSigns.length));

    TwoDimTable table = new TwoDimTable(tableHeader, "", names2, colHeaders, colTypes, colFormat,
            "names");
    fillUpCoeffsMag( mag2, sign2, table, 0);
    return table;
  }
  
  private void fillUpCoeffsMag(double[] coeffMags, String[] coeffSigns, TwoDimTable tdt, int rowStart) {
    int arrLength = coeffMags.length+rowStart;
    int arrCounter=0;
    for (int i=rowStart; i _plug_values = null;
    // internal parameter, handle with care. GLM will stop when there is more than this number of active predictors (after strong rule screening)
    public int _max_active_predictors = -1; // not used in GAM, copied over to GLM params
    public boolean _generate_scoring_history = false; // if true, will generate GLM scoring history but will slow algo down
    
    // the following parameters are for GAM
    public int[] _num_knots; // array storing number of knots per smoother
    public int[] _spline_orders;  // storing I-spline orders for each predictor
    public int[] _spline_orders_sorted; // presumably _spline_orders in the order of _gam_columns_sorted -- TODO confirm
    public int[] _num_knots_sorted; // presumably _num_knots in the order of _gam_columns_sorted -- TODO confirm
    public int[] _num_knots_tp; // store num_knots for thin plate regression
    public String[] _knot_ids;  // store frame keys that contain knots location for each smoother in gam_X;
    public String[][] _gam_columns; // array storing which predictor columns are specified
    public String[][] _gam_columns_sorted;  // move CS spline to the front and tp to the back in gam_columns
    public int[] _gamPredSize;  // store size of predictors for gam smoother
    public int[] _m;  // parameter related to gamPredSize;
    public int[] _M;  // size of polynomial basis for thin plate regression smoothers
    public int[] _bs; // choose spline function for gam column, 0 = cr, 1 = thin plate regression with knots, 2 = monotone I-spline
    public int[] _bs_sorted; // same encoding as _bs: 0 = cr, 1 = thin plate regression with knots, 2 = monotone I-spline
    public double[] _scale;  // array storing scaling values to control wiggliness of fit
    public double[] _scale_sorted; // presumably _scale in the order of _gam_columns_sorted -- TODO confirm
    public boolean _saveZMatrix = false;  // if true will save Z matrix
    public boolean _keep_gam_cols = false;  // if true will save the keys to gam Columns only
    public boolean _savePenaltyMat = false; // if true will save penalty matrices as triple array
    // algorithm short name, display name and the Java class name of the model
    public String algoName() { return "GAM"; }
    public String fullName() { return "Generalized Additive Model"; }
    public String javaName() { return GAMModel.class.getName(); }
    public double _prior = -1;
    public boolean _cold_start = false; // start building GLM model from scratch if true
    public int _nlambdas = -1;
    public boolean _non_negative = false;
    public boolean _remove_collinear_columns = false;
    public double _gradient_epsilon = -1;
    public boolean _early_stopping = true;  // internal GLM early stopping.
    public Key _beta_constraints = null;
    public double _lambda_min_ratio = -1;
    public boolean _betaConstraintsOff = false; // used for cross-validations
    // internal parameters added to support client mode
    int _glmNFolds = 0;
    Model.Parameters.FoldAssignmentScheme _glmFoldAssignment = null;
    String _glmFoldColumn = null;
    boolean _glmCvOn = false;

    /**
     * Reports the amount of work for progress tracking.
     *
     * @return always 1
     */
    @Override
    public long progressUnits() {
      return 1L;
    }

    public long _seed=-1;

    /**
     * Builds the interaction specification from the {@code _interactions} and
     * {@code _interaction_pairs} parameters.
     *
     * @return the value returned by {@code InteractionSpec.create}
     */
    public InteractionSpec interactionSpec() {
      final InteractionSpec spec = InteractionSpec.create(_interactions, _interaction_pairs);
      return spec;
    }
    
    /**
     * Returns the missing-values handling as a GLM {@code MissingValuesHandling} value.
     * A value that already is of the GLM enum type is returned as is. Otherwise it is expected to be
     * the DeepLearning enum, of which only MeanImputation and Skip are translated.
     *
     * @return the GLM missing-values handling
     * @throws IllegalStateException for a DeepLearning value other than MeanImputation or Skip
     */
    public MissingValuesHandling missingValuesHandling() {
      if (!(_missing_values_handling instanceof MissingValuesHandling)) {
        assert _missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
        final DeepLearningModel.DeepLearningParameters.MissingValuesHandling dlHandling =
                (DeepLearningModel.DeepLearningParameters.MissingValuesHandling) _missing_values_handling;
        switch (dlHandling) {
          case Skip:
            return MissingValuesHandling.Skip;
          case MeanImputation:
            return MissingValuesHandling.MeanImputation;
          default:
            throw new IllegalStateException("Unsupported missing values handling value: " + _missing_values_handling);
        }
      }
      return (MissingValuesHandling) _missing_values_handling;
    }

    /**
     * Creates the imputer matching the configured missing-values handling.
     *
     * @return a plug-values imputer for PlugValues, a mean imputer for everything else
     * @throws IllegalStateException when PlugValues is requested but no plug values frame is available
     */
    public DataInfo.Imputer makeImputer() {
      if (missingValuesHandling() != MissingValuesHandling.PlugValues) {
        // mean/mode imputation and skip (even skip needs an imputer right now! PUBDEV-6809)
        return new DataInfo.MeanImputer();
      }
      final boolean plugValuesMissing = _plug_values == null || _plug_values.get() == null;
      if (plugValuesMissing) {
        throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
      }
      return new GLM.PlugValuesImputer(_plug_values.get());
    }

    /**
     * Applies the inverse of the configured link function.
     *
     * @param x value of the linear predictor
     * @return the inverse link evaluated at {@code x}
     * @throws RuntimeException when the link is not one of the links handled below
     */
    public double linkInv(double x) {
      switch(_link) {
        case identity:
          return x;
        case log:
          return Math.exp(x);
        case logit:
        case ologit:
          return 1.0 / (1.0 + Math.exp(-x));
        case ologlog:
          return 1.0 - Math.exp(-1.0 * Math.exp(x));
        case inverse: {
          // keep the denominator at least 1e-5 away from zero, preserving its sign
          final double bounded = (x < 0) ? Math.min(-1e-5, x) : Math.max(1e-5, x);
          return 1.0 / bounded;
        }
        case tweedie:
          if (_tweedie_link_power == 0) {
            return Math.max(2e-16, Math.exp(x));  // a zero link power is handled like a floored log link
          }
          return Math.pow(x, 1 / _tweedie_link_power);
        default:
          throw new RuntimeException("unexpected link function  " + _link.toString());
      }
    }

    /**
     * Sets the family from a generic distribution family and resets the link to the family default.
     *
     * @param distributionFamily distribution to translate into a GLM family
     */
    @Override
    public void setDistributionFamily(DistributionFamily distributionFamily) {
      _link = Link.family_default;
      _family = distributionToFamily(distributionFamily);
    }

    /**
     * Translates the configured GLM family into a generic distribution family.
     *
     * @return the value returned by {@code familyToDistribution} for the current family
     */
    @Override
    public DistributionFamily getDistributionFamily() {
      final DistributionFamily distribution = familyToDistribution(_family);
      return distribution;
    }
  }

  /**
   * Returns the column domains to use for scoring.
   * For binomial, quasibinomial and fractionalbinomial models whose response column has no domain,
   * a copy of the output domains is returned with the response entry filled in: {"0", "1"} for
   * fractionalbinomial, the stored response domain otherwise. In all other cases the output domains
   * are returned unchanged.
   *
   * @return domains per column
   */
  @Override
  protected String[][] scoringDomains(){
    final int respIdx = _output._dinfo.responseChunkId(0);
    final String[][] outputDomains = _output._domains;
    final Family family = _parms._family;
    final boolean binomialLike = family == Family.binomial || family == Family.quasibinomial
            || family == Family.fractionalbinomial;
    if (!binomialLike || outputDomains[respIdx] != null) {
      return outputDomains;
    }
    final String[][] patched = outputDomains.clone();  // shallow copy, the model output stays untouched
    patched[respIdx] = (family == Family.fractionalbinomial) ? BINOMIAL_CLASS_NAMES : _output._responseDomains;
    return patched;
  }

  /**
   * Output of a GAM model: results copied from the underlying GLM model (fields prefixed with
   * {@code _glm_}), coefficients with and without centering, and the matrices and knots kept for
   * each gam smoother.
   */
  public static class GAMModelOutput extends Model.Output {
    public String[] _coefficient_names_no_centering;
    public String[] _coefficient_names;    
    public TwoDimTable _glm_model_summary;
    public ModelMetrics _glm_training_metrics;
    public ModelMetrics _glm_validation_metrics;
    public double _glm_dispersion;
    public double[] _glm_zvalues;
    public double[] _glm_pvalues;
    public double[][] _glm_vcov;
    public double[] _glm_stdErr;
    public double _glm_best_lamda_value;
    public TwoDimTable _glm_scoring_history;
    public TwoDimTable[] _glm_cv_scoring_history;
    public TwoDimTable _coefficients_table;
    public TwoDimTable _coefficients_table_no_centering;
    public TwoDimTable _standardized_coefficient_magnitudes;
    public TwoDimTable _variable_importances;
    public VarImp _varimp;  // should contain the same content as standardized coefficients
    public double[] _model_beta_no_centering; // coefficients generated during model training
    public double[] _standardized_model_beta_no_centering; // standardized coefficients generated during model training
    public double[] _model_beta; // coefficients generated during model training
    public double[] _standardized_model_beta; // standardized coefficients generated during model training
    public double[][] _model_beta_multinomial_no_centering;  // store multinomial coefficients during model training
    public double[][] _standardized_model_beta_multinomial_no_centering;  // store standardized multinomial coefficients during model training
    public double[][] _model_beta_multinomial;  // store multinomial coefficients during model training
    public double[][] _standardized_model_beta_multinomial;  // store standardized multinomial coefficients during model training
    public double _best_alpha;
    public double _best_lambda;
    public double _devianceValid = Double.NaN;
    public double _devianceTrain = Double.NaN;
    private double[] _zvalues;
    private double _dispersion;
    private boolean _dispersionEstimated;
    public String[][] _gamColNames; // store gam column names after transformation and centering
    public double[][][] _zTranspose; // Z matrix for centralization, can be null
    public double[][][] _penaltyMatricesCenter; // stores t(Z)*t(D)*Binv*D*Z and can be null
    public double[][][] _penaltyMatrices;          // store t(D)*Binv*D and can be null
    public double[][][] _binvD; // store BinvD for each gam column specified for scoring
    public double[][][] _knots; // store knots location for each gam smoother
    int[][][] _allPolyBasisList; // store polynomial basis function for all tp smoothers
    double[][][] _penaltyMatCS; // penalty matrix after removing optimization constraint, only for thin plate
    double[][][] _zTransposeCS; // store for each thin plate smoother for removing optimization constraint
    public int[] _numKnots;  // store number of knots per gam smoother
    public double[][][] _starT;
    public double[][] _gamColMeansRaw;
    public double[][] _oneOGamColStd;
    public double[] _penaltyScale;
    public Key _gamTransformedTrainCenter;  // contain key of predictors, all gamified columns centered
    public DataInfo _dinfo;
    public String[] _responseDomains;
    public String _gam_transformed_center_key;
    final Family _family;

    /**
     * Number of response classes: the superclass value for multinomial and ordinal, 2 for
     * binomial, quasibinomial and fractionalbinomial, and 1 for every other family.
     */
    @Override
    public int nclasses() {
      final boolean multiClass = _family == Family.multinomial || _family == Family.ordinal;
      if (multiClass) {
        return super.nclasses();
      }
      final boolean twoClass = Family.binomial == _family || Family.quasibinomial == _family
              || Family.fractionalbinomial == _family;
      return twoClass ? 2 : 1;
    }

    /** Names of levels for a categorical response column. */
    @Override
    public String[] classNames() {
      if (_family == Family.fractionalbinomial) {
        return BINOMIAL_CLASS_NAMES;
      }
      if (_family == Family.binomial || _family == Family.quasibinomial) {
        return _responseDomains;
      }
      return super.classNames();
    }

    /**
     * Creates the output for a GAM builder and records the data info, the column domains of the
     * adapted frame, the family and the response domain.
     *
     * @param b     GAM model builder
     * @param dinfo data info whose adapted frame supplies domains and the response column
     */
    public GAMModelOutput(GAM b, DataInfo dinfo) {
      super(b, dinfo._adaptedFrame);
      _dinfo = dinfo;
      _domains = dinfo._adaptedFrame.domains(); // get domain of dataset predictors
      _family = b._parms._family;
      if (_family.equals(Family.quasibinomial)) {
        // quasibinomial: collect the distinct response values and turn them into a string domain
        final Vec response = dinfo._adaptedFrame.vec(b._parms._response_column);
        _responseDomains = new VecUtils.CollectDoubleDomain(null, 2).doAll(response).stringDomain(response.isInt());
      } else {
        _responseDomains = dinfo._adaptedFrame.lastVec().domain();
      }
    }

    /** Maps the family to the model category; families without a dedicated category are Regression. */
    @Override public ModelCategory getModelCategory() {
      switch (_family) {
        case multinomial:
          return ModelCategory.Multinomial;
        case ordinal:
          return ModelCategory.Ordinal;
        case binomial:
        case quasibinomial:
        case fractionalbinomial:
          return ModelCategory.Binomial;
        default:
          return ModelCategory.Regression;
      }
    }

    /**
     * Clones GLM metrics for the GAM model and stores them as its training or validation metrics.
     *
     * @param gamModel   model receiving the metrics
     * @param train      frame the cloned metrics are associated with
     * @param forTrain   true to store as training metrics, false to store as validation metrics
     * @param glmMetrics metrics to clone
     */
    public void copyMetrics(GAMModel gamModel, Frame train, boolean forTrain, ModelMetrics glmMetrics) {
      final ModelMetrics cloned = glmMetrics.deepCloneWithDifferentModelAndFrame(gamModel, train);
      if (!forTrain) {
        gamModel._output._validation_metrics = cloned;
        return;
      }
      gamModel._output._training_metrics = cloned;
    }
  }

  /**
   * Massages the input test frame so that it can be used for scoring with a GAM model.
   * When the column names of the test frame differ from those of the adapted training frame, the
   * gamified columns are generated first and the content of {@code test} is replaced in place;
   * afterwards the frame is handed to the superclass implementation.
   *
   * @param test Testing Frame, updated in-place
   * @param expensive Try hard to adapt; this might involve the creation of
   *  whole Vecs and thus get expensive.  If {@code false}, then only adapt if
   *  no warnings and errors; otherwise just the messages are produced.
   *  Created Vecs have to be deleted by the caller (e.g. Scope.enter/exit).
   * @param computeMetrics passed through unchanged to the superclass implementation
   * @return whatever the superclass implementation returns (presumably the adaptation messages)
   */
  @Override
  public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
    final String[] trainNames = _output._dinfo._adaptedFrame.names();
    // shallow check: column number, column names only
    if (!equalColNames(test.names(), trainNames, _parms._response_column)) {
      // column names here need to be in same sequence of dinfo._adaptedFrame
      final Frame gamified = cleanUpInputFrame(test);
      // empty the test frame, then refill it with the adapted columns
      for (int remaining = test.numCols(); remaining > 0; remaining--) {
        test.remove(0);
      }
      final int numAdapted = gamified.numCols();
      for (int col = 0; col < numAdapted; col++) {
        test.add(gamified.name(col), gamified.vec(col));
      }
    }
    return super.adaptTestForTrain(test, expensive, computeMetrics);
  }

  /**
   * Creates a gamified copy of a test frame using the parameters and matrices stored in this model.
   * The copy shares the Vecs of {@code test}; the Vec array itself is cloned.
   *
   * @param test frame to be scored
   * @return new frame with gam columns added by the static {@code cleanUpInputFrame}
   */
  public Frame cleanUpInputFrame(Frame test) {
    // clone test dataset
    final Frame testCopy = new Frame(Key.make(), test.names(), test.vecs().clone());
    return cleanUpInputFrame(testCopy, _parms, _gamColNames, _output._binvD, _output._zTranspose,
            _output._knots, _output._zTransposeCS, _output._allPolyBasisList, _output._gamColMeansRaw,
            _output._oneOGamColStd, _cubicSplineNum);
  }

  /**
   * Adds the gamified columns to a frame and arranges its columns for scoring: ignored columns are
   * removed, the gam columns are appended, and weights, offset and response (when present in the
   * incoming frame) are moved to the end in that order.
   *
   * @param adptedF        frame containing predictors, gam columns and extras; modified and returned
   * @param parms          GAM parameters
   * @param gamColNames    names of the gamified columns per smoother
   * @param binvD          passed on to {@code addSingleVariableGamColumns}
   * @param zTranspose     passed on to both column generators
   * @param knots          knots per smoother, passed on to both column generators
   * @param zTransposeCS   passed on to {@code addTPGamColumns}
   * @param polyBasisList  passed on to {@code addTPGamColumns}
   * @param gamColMeansRaw passed on to {@code addTPGamColumns}
   * @param oneOGamColStd  passed on to {@code addTPGamColumns}
   * @param numCSGamCols   number of CS smoothers
   * @return {@code adptedF} after the rearrangement
   */
  public static Frame cleanUpInputFrame(Frame adptedF, GAMParameters parms, String[][] gamColNames, double[][][] binvD,
                                        double[][][] zTranspose, double[][][] knots,
                                        double[][][] zTransposeCS, int[][][] polyBasisList, double[][] gamColMeansRaw, 
                                        double[][] oneOGamColStd, int numCSGamCols) {
    // names of the incoming frame, captured before any column is added or removed
    final String[] incomingNames = adptedF.names();

    // gam columns for single-variable smoothers, then for TP smoothers
    final Frame singleVarCols = addSingleVariableGamColumns(adptedF, parms, gamColNames, binvD, zTranspose, knots,
            numCSGamCols);
    final Frame tpCols = addTPGamColumns(adptedF, parms, zTransposeCS, zTranspose, polyBasisList,
            knots, gamColMeansRaw, oneOGamColStd);

    // merge both sets into one frame, single-variable smoothers first
    final Frame gamCols;
    if (singleVarCols == null) {
      gamCols = tpCols;
    } else {
      if (tpCols != null) {
        singleVarCols.add(tpCols.names(), tpCols.removeAll());
      }
      gamCols = singleVarCols;
    }

    // remove ignored columns
    if (parms._ignored_columns != null) {
      for (String ignored : parms._ignored_columns) {
        if (ArrayUtils.contains(incomingNames, ignored))
          adptedF.remove(ignored);
      }
    }

    // detach the special columns so that they can be re-appended after the gam columns
    final boolean hasWeights = parms._weights_column != null && ArrayUtils.contains(incomingNames, parms._weights_column);
    final boolean hasOffset = parms._offset_column != null && ArrayUtils.contains(incomingNames, parms._offset_column);
    final boolean hasResponse = ArrayUtils.contains(incomingNames, parms._response_column);
    final Vec weightV = hasWeights ? adptedF.remove(parms._weights_column) : null;
    final Vec offsetV = hasOffset ? adptedF.remove(parms._offset_column) : null;
    final Vec respV = hasResponse ? adptedF.remove(parms._response_column) : null;

    adptedF.add(gamCols.names(), gamCols.removeAll());
    Scope.track(gamCols);

    // final order: ..., gam columns, weights, offset, response
    if (weightV != null) {
      adptedF.add(parms._weights_column, weightV);
    }
    if (offsetV != null) {
      adptedF.add(parms._offset_column, offsetV);
    }
    if (respV != null) {
      adptedF.add(parms._response_column, respV);
    }
    return adptedF;
  }

  /**
   * Generates the gamified columns from {@code adptedF} and appends them to {@code valid}, moving
   * weights, offset and response (when present in {@code valid}) to the end in that order.
   *
   * @param adptedF        frame the gam columns are generated from
   * @param valid          validation frame that receives the gam columns; modified and returned
   * @param parms          GAM parameters
   * @param gamColNames    names of the gamified columns per smoother
   * @param binvD          passed on to {@code addSingleVariableGamColumns}
   * @param zTranspose     passed on to both column generators
   * @param knots          knots per smoother, passed on to both column generators
   * @param zTransposeCS   passed on to {@code addTPGamColumns}
   * @param polyBasisList  passed on to {@code addTPGamColumns}
   * @param gamColMeansRaw passed on to {@code addTPGamColumns}
   * @param oneOGamColStd  passed on to {@code addTPGamColumns}
   * @param numCSGam       number of CS smoothers
   * @return {@code valid} after the rearrangement
   */
  public static Frame adaptValidFrame(Frame adptedF, Frame valid, GAMParameters parms, String[][] gamColNames, double[][][] binvD,
                                        double[][][] zTranspose, double[][][] knots,
                                        double[][][] zTransposeCS, int[][][] polyBasisList, double[][] gamColMeansRaw,
                                        double[][] oneOGamColStd, int numCSGam) {
    // gam columns for single-variable smoothers, then for TP smoothers
    final Frame singleVarCols = addSingleVariableGamColumns(adptedF, parms, gamColNames, binvD, zTranspose, knots,
            numCSGam);
    final Frame tpCols = addTPGamColumns(adptedF, parms, zTransposeCS, zTranspose, polyBasisList,
            knots, gamColMeansRaw, oneOGamColStd);

    // merge both sets into one frame, single-variable smoothers first
    final Frame gamCols;
    if (singleVarCols == null) {
      gamCols = tpCols;
    } else {
      if (tpCols != null) {
        singleVarCols.add(tpCols.names(), tpCols.removeAll());
      }
      gamCols = singleVarCols;
    }

    // detach the special columns one after another; each presence check reads the current names
    final Vec weightV = (parms._weights_column != null && ArrayUtils.contains(valid.names(), parms._weights_column))
            ? valid.remove(parms._weights_column) : null;
    final Vec offsetV = (parms._offset_column != null && ArrayUtils.contains(valid.names(), parms._offset_column))
            ? valid.remove(parms._offset_column) : null;
    final Vec respV = ArrayUtils.contains(valid.names(), parms._response_column)
            ? valid.remove(parms._response_column) : null;

    valid.add(gamCols.names(), gamCols.removeAll());
    Scope.track(gamCols);

    // final order: ..., gam columns, weights, offset, response
    if (weightV != null) {
      valid.add(parms._weights_column, weightV);
    }
    if (offsetV != null) {
      valid.add(parms._offset_column, offsetV);
    }
    if (respV != null) {
      valid.add(parms._response_column, respV);
    }
    return valid;
  }
  
  /**
   * Generates the gamified columns of the thin plate regression smoothers.
   *
   * @param adaptedF       frame containing the predictor columns
   * @param parms          GAM parameters; {@code parms._M} determines whether TP smoothers exist
   * @param zTransposeCS   passed on to {@code AddTPKnotsGamColumns}
   * @param zTranspose     passed on to {@code AddTPKnotsGamColumns}
   * @param polyBasisList  passed on to {@code AddTPKnotsGamColumns}
   * @param knots          passed on to {@code AddTPKnotsGamColumns}
   * @param gamColMeansRaw passed on to {@code addTPGamCols}
   * @param oneOColStd     passed on to {@code addTPGamCols}
   * @return frame of the generated columns, or null when {@code parms._M} is null or empty
   */
  public static Frame addTPGamColumns(Frame adaptedF, GAMParameters parms, double[][][] zTransposeCS, 
                                      double[][][] zTranspose, int[][][] polyBasisList, double[][][] knots,
                                      double[][] gamColMeansRaw, double[][] oneOColStd) {
    if (parms._M == null || parms._M.length == 0) {
      return null;  // no thin plate smoothers
    }
    final AddTPKnotsGamColumns tpGenerator = new AddTPKnotsGamColumns(parms, zTransposeCS, zTranspose,
            polyBasisList, knots, adaptedF);
    tpGenerator.addTPGamCols(gamColMeansRaw, oneOColStd); // generate thin plate regression smoothers
    return concateGamVecs(tpGenerator._gamFrameKeysCenter);
  }
  
  public static Frame addSingleVariableGamColumns(Frame adptedF, GAMParameters parms, String[][] gamColNames,
                                                  double[][][] binvD, double[][][] zTranspose, double[][][] knots, 
                                                  int numCSGamCol) {
    int numGamCols = parms._gam_columns.length;
    int numISplineGamCol = numGamCols - numCSGamCol - (parms._M==null?0:parms._M.length);
    int numSingleVariableGamCols = numISplineGamCol + numCSGamCol;
    if (numSingleVariableGamCols == 0)  // no single variable GAM columns
      return null;
    
    Vec[] gamColCSSplines = new Vec[numCSGamCol];
    Vec[] gamColISplines = new Vec[numISplineGamCol];
    String[] gamColCSNames = new String[numCSGamCol];
    String[] gamColISplineNames = new String[numISplineGamCol];
    int countCS = 0;
    int countIS = 0;
    for (int vind=0; vind 0) {
      onlyGamColsCS = new Frame(gamColCSNames, gamColCSSplines);
      genCSGamCols = new AddCSGamColumns(binvD, zTranspose, knots, parms._num_knots_sorted, onlyGamColsCS, 
              parms._bs_sorted);
      genCSGamCols.doAll(genCSGamCols._gamCols2Add, Vec.T_NUM, onlyGamColsCS);
      String[] gamColsNamesCS = new String[genCSGamCols._gamCols2Add];
      int offset = 0;
      for (int ind = 0; ind < numGamCols; ind++) {
        if (parms._bs_sorted[ind]==0) {
          System.arraycopy(gamColNames[ind], 0, gamColsNamesCS, offset, gamColNames[ind].length);
          offset += gamColNames[ind].length;
        }
      }
      gamifiedCSCols = genCSGamCols.outputFrame(Key.make(), gamColsNamesCS, null);
      //removeVec(gamColCSSplines);
    }
    
    if (numISplineGamCol > 0) {
      onlyGamColsIS = new Frame(gamColISplineNames, gamColISplines);
      genISGamCols = new AddISGamColumns(knots, parms._num_knots_sorted, parms._bs_sorted, parms._spline_orders_sorted,
              onlyGamColsIS);
      genISGamCols.doAll(genISGamCols._totGamifiedColCentered, Vec.T_NUM, onlyGamColsIS);
      String[] gamColsNamesIS = new String[genISGamCols._totGamifiedColCentered];
      int offset = 0;
      for (int index=0; index mb = null;
    Frame rawFrame = null;
    if (gs._computeMetrics) {
      mb = gs._mb;
      rawFrame = gs.outputFrame();
    }
    domains[0] = gs._predDomains;
    Frame outputFrame = gs.outputFrame(Key.make(destination_key), predictNames, domains);
    return new PredictScoreResult(mb, rawFrame, outputFrame);
  }
  
  /**
   * Creates the scoring task for an adapted frame.
   * A response column that is present but bad is dropped from a shallow copy of the frame. Metrics
   * are only computed when requested and a usable response column exists.
   *
   * @param adaptFrm        adapted frame to score
   * @param makePredictions whether the task writes predictions
   * @param j               job to report progress to, may be null
   * @param computeMetrics  whether metrics are requested
   * @return the configured scoring task
   */
  private GAMScore makeScoringTask(Frame adaptFrm, boolean makePredictions, Job j, boolean computeMetrics) {
    final String responseName = _output.responseName();
    final int responseIdx = adaptFrm.find(responseName);
    if (responseIdx > -1 && adaptFrm.vec(responseIdx).isBad()) { // remove inserted invalid response
      adaptFrm = new Frame(adaptFrm.names(), adaptFrm.vecs());
      adaptFrm.remove(responseIdx);
    }

    boolean detectedComputeMetrics = false;
    if (computeMetrics) {
      final Vec responseVec = adaptFrm.vec(responseName);
      detectedComputeMetrics = responseVec != null && !responseVec.isBad();
    }

    String[] domain;
    if (_output.nclasses() <= 1) {
      domain = null;
    } else if (detectedComputeMetrics) {
      domain = adaptFrm.lastVec().domain();
    } else {
      domain = _output._domains[_output._domains.length - 1];
    }
    if (_parms._family.equals(Family.quasibinomial)) {
      domain = _output._responseDomains;
    }

    final DataInfo scoringInfo = _output._dinfo.scoringInfo(_output._names, adaptFrm);
    return new GAMScore(j, this, scoringInfo, domain, detectedComputeMetrics, makePredictions);
  }

  private class GAMScore extends MRTask {
    private DataInfo _dinfo; // row extractor; reduced to the non-zero coefficient columns when some coefficients are zero
    private double[] _coeffs; // coefficients for families other than multinomial/ordinal, null otherwise
    private double[][] _coeffs_multinomial; // per-class coefficients for multinomial/ordinal, null otherwise
    private int _nclass; // value of the model output's nclasses()
    private boolean _computeMetrics; // when true, map creates a metric builder and feeds it per row
    private final Job _j; // job used for stop requests and progress updates, may be null
    private Family _family;
    private transient double[] _eta;  // store eta calculation
    private String[] _predDomains; // domain passed to makeMetricBuilder
    private final GAMModel _m;
    private final double _defaultThreshold; // cut-off on mu for the predicted label of two-class models
    private int _lastClass; // _nclass - 1
    private ModelMetrics.MetricBuilder _mb;
    final boolean _generatePredictions; // when true, predictions are written to the output chunks
    private transient double[][] _vcov; // GLM variance-covariance matrix from the model output, may be null
    private transient double[] _tmp; // scratch array of length _vcov.length for the StdErr computation
    private boolean _classifier2class; // true for binomial, quasibinomial and fractionalbinomial


    /**
     * Sets up a scoring task.
     * For multinomial and ordinal models the per-class coefficients are used as they are. For all
     * other families the coefficients that are exactly zero (the last element excluded) are dropped
     * and {@code dinfo} is filtered to the remaining columns.
     *
     * @param j               job for progress and stop requests, may be null
     * @param m               model to score with
     * @param dinfo           data info describing the frame to score
     * @param domain          domain used when creating the metric builder
     * @param computeMetrics  whether metrics are computed
     * @param makePredictions whether predictions are written out
     */
    private GAMScore(final Job j, final GAMModel m, DataInfo dinfo, final String[] domain, final boolean computeMetrics, 
                     final boolean makePredictions) {
      final Family family = m._parms._family;
      _j = j;
      _m = m;
      _family = family;
      _computeMetrics = computeMetrics;
      _generatePredictions = makePredictions;
      _predDomains = domain;
      _nclass = m._output.nclasses();
      _lastClass = _nclass - 1;
      _defaultThreshold = m.defaultThreshold();
      _classifier2class = family == Family.binomial || family == Family.quasibinomial
              || family == Family.fractionalbinomial;

      if (family == Family.multinomial || family == Family.ordinal) {
        _coeffs = null;
        _coeffs_multinomial = m._output._model_beta_multinomial;
      } else {
        final double[] fullBeta = m._output._model_beta;
        final int lastIdx = fullBeta.length - 1;  // the last element is always kept
        int[] activeIds = new int[lastIdx];
        int numActive = 0;
        for (int i = 0; i < lastIdx; i++) {  // pick out beta that is not zero in ids
          if (fullBeta[i] != 0)
            activeIds[numActive++] = i;
        }
        double[] beta = fullBeta;
        if (numActive < lastIdx) {  // at least one zero coefficient: compact beta and dinfo
          activeIds = Arrays.copyOf(activeIds, numActive);
          dinfo = dinfo.filterExpandedColumns(activeIds);
          beta = MemoryManager.malloc8d(numActive + 1);
          for (int pos = 0; pos < numActive; pos++) {
            beta[pos] = fullBeta[activeIds[pos]];
          }
          beta[numActive] = fullBeta[lastIdx];
        }
        _coeffs_multinomial = null;
        _coeffs = beta;
      }

      _dinfo = dinfo;
      _dinfo._valid = true; // marking dinfo as validation data set disables an assert on unseen levels (which should not happen in train)
    }

    /**
     * Scores every row of the supplied chunks. Each row is extracted through {@code _dinfo},
     * handed to {@code processRow}, and - when metrics are being computed and the row's
     * response is usable - passed on to the metric builder.
     *
     * @param chks input chunks; the number of rows is taken from {@code chks[0]._len}
     * @param nc   output chunks forwarded to {@code processRow}
     */
    @Override
    public void map(Chunk[]chks, NewChunk[] nc) {
      // nothing to do once the task is cancelled or the owning job asked to stop
      if (isCancelled())
        return;
      if (_j != null && _j.stop_requested())
        return;

      // per-class linear predictors are only needed by the ordinal and multinomial scorers
      final boolean multiClass = _family.equals(Family.ordinal) || _family.equals(Family.multinomial);
      if (multiClass) {
        _eta = MemoryManager.malloc8d(_nclass);
      }

      _vcov = _m._output._glm_vcov;
      if (_vcov != null) {
        _tmp = MemoryManager.malloc8d(_vcov.length);  // scratch buffer sized like _vcov
      }

      // number of prediction values per row: one value, or label plus one value per class
      final int predWidth;
      if (_nclass <= 1) {
        predWidth = 1;
      } else {
        predWidth = _nclass + 1;
      }
      final double[] rowPreds = MemoryManager.malloc8d(predWidth);

      float[] actual = null;
      if (_computeMetrics) {
        _mb = _m.makeMetricBuilder(_predDomains);
        actual = new float[1];
      }

      final DataInfo.Row row = _dinfo.newDenseRow();
      final int rowCount = chks[0]._len;
      int rowIdx = 0;
      while (rowIdx < rowCount) {
        _dinfo.extractDenseRow(chks, rowIdx, row);
        processRow(row, rowPreds, nc, predWidth);
        if (_computeMetrics) {
          if (!row.response_bad) {  // rows with a bad response do not contribute to metrics
            actual[0] = (float) row.response[0];
            _mb.perRow(rowPreds, actual, row.weight, row.offset, _m);
          }
        }
        rowIdx++;
      }

      if (_j != null) {
        _j.update(1);
      }
    }

    /**
     * Computes the prediction values for a single row and, when prediction output is enabled,
     * appends them to the output chunks.
     *
     * <p>Rows flagged with bad predictors produce {@code NaN} in every slot of {@code ps};
     * rows with zero weight produce {@code 0}. Only the remaining rows are scored with the
     * family-specific scorer. Previously the scorer ran unconditionally after the fills and
     * overwrote them, so the NaN/0 markers never reached the output.
     *
     * @param r     the extracted row
     * @param ps    buffer receiving the prediction values for this row
     * @param preds output chunks; columns {@code 0..ncols-1} receive {@code ps}, and column
     *              {@code ncols} receives an extra value when {@code _vcov} is available
     * @param ncols number of prediction values to copy from {@code ps}
     */
    private void processRow(DataInfo.Row r, double[] ps, NewChunk[] preds, int ncols) {
      if (r.predictors_bad) {
        Arrays.fill(ps, Double.NaN);  // output NaN with bad predictor entries
      } else if (r.weight == 0) {
        Arrays.fill(ps, 0.0); // zero weight entries got 0 too
      } else {
        // score only usable rows so the fills above are not overwritten
        switch (_family) {
          case multinomial: ps = scoreMultinomialRow(r, r.offset, ps); break;
          case ordinal: ps = scoreOrdinalRow(r, r.offset, ps); break;
          default: ps = scoreRow(r, r.offset, ps); break;
        }
      }
      if (_generatePredictions) {
        for (int predCol = 0; predCol < ncols; predCol++) { // write prediction to NewChunk
          preds[predCol].addNum(ps[predCol]);
        }
        if (_vcov != null) {
          // sqrt(r' * vcov * r); presumably the standard error of the linear predictor - confirm
          preds[ncols].addNum(Math.sqrt(r.innerProduct(r.mtrxMul(_vcov, _tmp))));
        }
      }
    }

    /**
     * Scores one row with the single coefficient vector {@code _coeffs}: the inverse link is
     * applied to the row's inner product with the coefficients plus the offset.
     *
     * @param r      the extracted row
     * @param offset offset added to the linear predictor
     * @param preds  output buffer; for two-class classifiers {@code preds[0]} is the thresholded
     *               label, {@code preds[1]} is {@code 1 - mu} and {@code preds[2]} is {@code mu};
     *               otherwise only {@code preds[0]} is set, to {@code mu}
     * @return the same {@code preds} array
     */
    public double[] scoreRow(DataInfo.Row r, double offset, double[] preds) {
      final double linearPredictor = r.innerProduct(_coeffs) + offset;
      final double mu = _m._parms.linkInv(linearPredictor);
      if (!_classifier2class) {
        preds[0] = mu;
        return preds;
      }
      // two-class case: label by threshold, then the two class probabilities
      if (mu >= _defaultThreshold) {
        preds[0] = 1;
      } else {
        preds[0] = 0;
      }
      preds[1] = 1.0 - mu; // class 0
      preds[2] = mu; // class 1
      return preds;
    }

    /**
     * Scores one row for the ordinal family.
     *
     * <p>For each class {@code c < _lastClass} the linear predictor
     * {@code eta_c = r.innerProduct(bm[c]) + offset} is turned into a logistic CDF value
     * {@code 1 / (1 + exp(-eta_c))}; the class probability is the difference between
     * consecutive CDF values, and the last class receives the remaining mass. The predicted
     * label is the first class whose {@code eta_c} is positive, or the last class if there is
     * none.
     *
     * <p>The original code used a second loop for the classes after the predicted one and
     * computed {@code exp(-innerProduct + offset)} there, i.e. the offset entered with the wrong
     * sign. A single loop now applies {@code exp(-(innerProduct + offset))} to every class.
     *
     * @param r      the extracted row
     * @param offset offset added to every class's linear predictor
     * @param preds  output buffer: {@code preds[0]} is the predicted class index and
     *               {@code preds[c + 1]} is the probability of class {@code c}
     * @return the same {@code preds} array
     */
    public double[] scoreOrdinalRow(DataInfo.Row r, double offset, double[] preds) {
      final double[][] bm = _coeffs_multinomial;
      Arrays.fill(preds, 0); // clear any values left over from a previous row
      preds[0] = _lastClass;  // default label when no linear predictor is positive
      double previousCDF = 0.0;
      boolean classFound = false;
      for (int cInd = 0; cInd < _lastClass; cInd++) { // PDF of each class from consecutive CDFs
        final double eta = r.innerProduct(bm[cInd]) + offset;
        final double currCDF = 1.0 / (1 + Math.exp(-eta));
        preds[cInd + 1] = currCDF - previousCDF;
        previousCDF = currCDF;

        if (!classFound && eta > 0) { // first class with positive eta is the predicted class
          preds[0] = cInd;
          classFound = true;
        }
      }
      preds[_nclass] = 1 - previousCDF; // last class gets the remaining probability mass
      return preds;
    }

    /**
     * Scores one row for the multinomial family using a softmax over the per-class linear
     * predictors. The maximum linear predictor is subtracted before exponentiation.
     *
     * @param r      the extracted row
     * @param offset offset added to every class's linear predictor
     * @param preds  output buffer: {@code preds[c + 1]} receives the normalised value for class
     *               {@code c} and {@code preds[0]} the index returned by
     *               {@code ArrayUtils.maxIndex} over the exponentiated values
     * @return the same {@code preds} array
     */
    public double[] scoreMultinomialRow(DataInfo.Row r, double offset, double[] preds) {
      final double[][] classCoeffs = _coeffs_multinomial;
      final double[] eta = _eta;  // scratch buffer, overwritten for every row
      final int classCount = classCoeffs.length;

      // pass 1: linear predictor per class, tracking the largest one
      double largest = Double.NEGATIVE_INFINITY;
      for (int cls = 0; cls < classCount; cls++) {
        final double linear = r.innerProduct(classCoeffs[cls]) + offset;
        eta[cls] = linear;
        if (linear > largest) {
          largest = linear;
        }
      }

      // pass 2: exponentiate the shifted predictors in place and accumulate their sum
      double expTotal = 0;
      for (int cls = 0; cls < classCount; cls++) {
        final double shiftedExp = Math.exp(eta[cls] - largest);
        eta[cls] = shiftedExp;
        expTotal += shiftedExp;
      }

      // pass 3: normalise by multiplying with the reciprocal of the sum
      final double reciprocal = 1.0 / expTotal;
      for (int cls = 0; cls < classCount; cls++) {
        preds[cls + 1] = eta[cls] * reciprocal;
      }
      preds[0] = ArrayUtils.maxIndex(eta);
      return preds;
    }
    
    /**
     * Merges the metric builder of another task into this one; does nothing when this task
     * has no metric builder.
     *
     * @param other the task whose metric builder is merged in
     */
    @Override
    public void reduce(GAMScore other) {
      if (_mb == null) {
        return;
      }
      _mb.reduce(other._mb);
    }
    
    /**
     * Forwards the post-global hook to the metric builder, if one exists.
     */
    @Override
    protected void postGlobal() {
      if (_mb == null) {
        return;
      }
      _mb.postGlobal();
    }
  }

  /**
   * Not supported by this model; calling it always fails.
   *
   * @param data  unused
   * @param preds unused
   * @return never returns normally
   * @throws UnsupportedOperationException always
   */
  @Override
  public double[] score0(double[] data, double[] preds) {
    final String message = "GAMModel.score0 should never be called";
    throw new UnsupportedOperationException(message);
  }

  /**
   * Creates the MOJO writer for this model.
   *
   * @return a new {@code GAMMojoWriter} constructed with this model
   */
  @Override
  public GAMMojoWriter getMojo() {
    final GAMMojoWriter writer = new GAMMojoWriter(this);
    return writer;
  }

  /**
   * Removes this model's auxiliary keyed objects after delegating to the superclass: the
   * centered GAM-transformed training frame, any validation keys, and - depending on the
   * corresponding {@code keep_cross_validation_*} parameters - the cross-validation holdout
   * predictions, fold assignment and models.
   *
   * @param fs      futures passed to every removal
   * @param cascade forwarded to the superclass; the removals issued here always pass {@code true}
   * @return the {@code fs} argument
   */
  @Override
  protected Futures remove_impl(Futures fs, boolean cascade) {
    super.remove_impl(fs, cascade);
    Keyed.remove(_output._gamTransformedTrainCenter, fs, true);

    if (_validKeys != null) {
      for (Key validKey : _validKeys) {
        Keyed.remove(validKey, fs, true);
      }
    }

    if (_parms._keep_cross_validation_predictions) {
      Keyed.remove(_output._cross_validation_holdout_predictions_frame_id, fs, true);
    }

    if (_parms._keep_cross_validation_fold_assignment) {
      Keyed.remove(_output._cross_validation_fold_assignment_frame_id, fs, true);
    }

    if (_parms._keep_cross_validation_models) {
      if (_output._cross_validation_models != null) {
        for (Key cvModelKey : _output._cross_validation_models) {
          Keyed.remove(cvModelKey, fs, true);
        }
      }
    }
    return fs;
  }

  /**
   * Passes the keys of this model's auxiliary objects to {@code ab.putKey} - the centered
   * GAM-transformed training frame when present, and the cross-validation predictions, fold
   * assignment and models when the matching {@code keep_cross_validation_*} parameter is set -
   * and then delegates to the superclass.
   *
   * @param ab the buffer to write to
   * @return the result of {@code super.writeAll_impl(ab)}
   */
  @Override
  protected AutoBuffer writeAll_impl(AutoBuffer ab) {
    if (_output._gamTransformedTrainCenter != null) {
      ab.putKey(_output._gamTransformedTrainCenter);
    }
    if (_parms._keep_cross_validation_predictions) {
      ab.putKey(_output._cross_validation_holdout_predictions_frame_id);
    }
    if (_parms._keep_cross_validation_fold_assignment) {
      ab.putKey(_output._cross_validation_fold_assignment_frame_id);
    }
    if (_parms._keep_cross_validation_models) {
      if (_output._cross_validation_models != null) {
        for (Key cvModelKey : _output._cross_validation_models) {
          ab.putKey(cvModelKey);
        }
      }
    }
    return super.writeAll_impl(ab);
  }

  /**
   * Counterpart of {@code writeAll_impl}: passes the same keys, under the same conditions and
   * in the same order, to {@code ab.getKey} and then delegates to the superclass.
   *
   * @param ab the buffer to read from
   * @param fs futures passed to every {@code getKey} call
   * @return the result of {@code super.readAll_impl(ab, fs)}
   */
  @Override
  protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
    if (_output._gamTransformedTrainCenter != null) {
      ab.getKey(_output._gamTransformedTrainCenter, fs);
    }
    if (_parms._keep_cross_validation_predictions) {
      ab.getKey(_output._cross_validation_holdout_predictions_frame_id, fs);
    }
    if (_parms._keep_cross_validation_fold_assignment) {
      ab.getKey(_output._cross_validation_fold_assignment_frame_id, fs);
    }
    if (_parms._keep_cross_validation_models) {
      if (_output._cross_validation_models != null) {
        for (Key cvModelKey : _output._cross_validation_models) {
          ab.getKey(cvModelKey, fs);
        }
      }
    }
    return super.readAll_impl(ab, fs);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy