All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.gam.GamMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.gam;

import hex.genmodel.ModelMojoReader;
import hex.genmodel.utils.DistributionFamily;

import java.io.IOException;
import java.nio.ByteBuffer;

import static hex.genmodel.algos.gam.GamMojoModel.IS_SPLINE_TYPE;
import static hex.genmodel.algos.gam.GamMojoModel.MS_SPLINE_TYPE;
import static hex.genmodel.utils.ArrayUtils.subtract;
import static hex.genmodel.utils.DistributionFamily.ordinal;

public class GamMojoReader extends ModelMojoReader {

  /**
   * Returns the human-readable name of the algorithm whose MOJO this reader loads.
   *
   * @return the fixed display name {@code "Generalized Additive Model"}
   */
  @Override
  public String getModelName() {
    final String displayName = "Generalized Additive Model";
    return displayName;
  }

  /**
   * Populates {@code _model} from the key/value, text and binary entries of a GAM MOJO.
   *
   * Reading happens in this order:
   *  - generic GLM-style settings: factor handling, family, link, NA fills, coefficient sizes;
   *  - coefficients: per-class matrices for the multinomial/ordinal families, flat arrays otherwise;
   *  - GAM column metadata: knot counts, predictor names, spline orders, I-/M-spline basis sizes;
   *  - TP-column data, when TP columns are present;
   *  - finally, {@code _model.init()} is called.
   *
   * NOTE(review): one line near the end of this method is truncated in this copy of the file
   * (see the note at that line); the method will not compile as shown.
   *
   * @throws IOException propagated from the underlying entry readers
   */
  @Override
  protected void readModelData() throws IOException {
    _model._useAllFactorLevels = readkv("use_all_factor_levels", false);
    _model._numExpandedGamCols = readkv("num_expanded_gam_columns",0);
    _model._numExpandedGamColsCenter = readkv("num_expanded_gam_columns_center",0);
    _model._family = DistributionFamily.valueOf((String)readkv("family"));
    _model._cats = readkv("cats", -1);
    _model._nums = readkv("num");
    _model._numsCenter = readkv("numsCenter");
    _model._catNAFills = readkv("catNAFills", new int[0]);
    _model._numNAFillsCenter = readkv("numNAFillsCenter", new double[0]);
    _model._meanImputation = readkv("mean_imputation", false);
    _model._betaSizePerClass = readkv("beta length per class",0);
    _model._catOffsets = readkv("cat_offsets", new int[0]);
    // NOTE(review): the trailing comment mentions ordinal, but only multinomial is excluded here,
    // so the ordinal family still reads the "link" key. Confirm this is intended.
    if (!_model._family.equals(DistributionFamily.multinomial))  // multinomial or ordinal have specific link functions not included in general link functions
      _model._link_function = readLinkFunction((String) readkv("link"), _model._family);
    _model._tweedieLinkPower = readkv("tweedie_link_power", 0.0);
    _model._betaCenterSizePerClass = readkv("beta center length per class", 0);
    // Multinomial and ordinal models store one coefficient row per class (_nclasses rows,
    // presumably set by the base reader); other families store flat coefficient vectors.
    if (_model._family.equals(DistributionFamily.multinomial) || _model._family.equals(ordinal)) {
      _model._beta_multinomial_no_center = readRectangularDoubleArray("beta_multinomial", _model._nclasses, _model._betaSizePerClass);
      _model._beta_multinomial_center = readRectangularDoubleArray("beta_multinomial_centering", _model._nclasses, 
              _model._betaCenterSizePerClass);
    } else {
      _model._beta_no_center = readkv("beta");
      _model._beta_center = readkv("beta_center");
    }
    // read in GAM specific parameters
    _model._num_knots = readkv("num_knots");
    _model._num_knots_sorted = readkv("num_knots_sorted");
    // Each *_dim array gives the row lengths used to split the matching text entry into rows.
    int[] gamColumnDim = readkv("gam_column_dim");
    _model._gam_columns = read2DStringArrays(gamColumnDim,"gam_columns");
    int[] gamColumnDimSorted = readkv("gam_column_dim_sorted");
    _model._gam_columns_sorted = read2DStringArrays(gamColumnDimSorted,"gam_columns_sorted");
    _model._num_gam_columns = _model._gam_columns.length;
    // Column counts per smoother type (TP, CS, I-spline, M-spline).
    _model._numTPCol = readkv("num_TP_col");
    _model._numCSCol = readkv("num_CS_col");
    _model._numISCol = readkv("num_IS_col");
    _model._numMSCol = readkv("num_MS_col");
    if (_model._numISCol > 0 || _model._numMSCol > 0) {
      _model._spline_orders = readkv("spline_orders");
      _model._spline_orders_sorted = readkv("spline_orders_sorted");
      // The offsets below imply the sorted arrays hold CS columns first, then I-spline columns,
      // then M-spline columns. Basis size per column = number of knots + spline order - 2.
      if (_model._numISCol > 0) {
        _model._numBasisSize = new int[_model._numISCol];
        int isCounter = 0;
        int offset = _model._numCSCol;
        for (int index = 0; index < _model._numISCol; index++) {
          int trueIndex = index+offset;
          _model._numBasisSize[isCounter++] = _model._num_knots_sorted[trueIndex] +
                  _model._spline_orders_sorted[trueIndex] - 2;
        }
      }
      if (_model._numMSCol > 0) {
        _model._numMSBasisSize = new int[_model._numMSCol];
        int msCounter = 0;
        int offset = _model._numISCol+_model._numCSCol;
        for (int index=0; index<_model._numMSCol; index++) {
          int trueIndex = offset + index;
          _model._numMSBasisSize[msCounter++] = _model._num_knots_sorted[trueIndex]+
                  _model._spline_orders_sorted[trueIndex]-2;
        }
      }
    }
    _model._totFeatureSize = readkv("total feature size");
    _model._names_no_centering = readStringArrays(_model._totFeatureSize, "_names_no_centering");
    _model._bs = readkv("bs");
    _model._bs_sorted = readkv("bs_sorted");
    // Only the outer dimension is allocated here. No visible code fills it, so it is presumably
    // populated in the truncated section further down.
    _model._zTranspose = new double[_model._num_gam_columns][][];
    int[] gamColName_dim = readkv("gamColName_dim");
    _model._gamColNames = read2DStringArrays(gamColName_dim, "gamColNames");
    //_model._gamColNames = new String[_model._num_gam_columns][];
    //_model._gamColNamesCenter = new String[_model._num_gam_columns][];
    // NOTE(review): key "_d" is read here into _gamPredSize and again below into _d.
    _model._gamPredSize = readkv("_d");
    if (_model._numTPCol > 0) {
      _model._standardize = readkv("standardize");
      // NOTE(review): this allocation is overwritten by the read3DArray call below.
      _model._zTransposeCS = new double[_model._numTPCol][][];
      _model._num_knots_TP = readkv("num_knots_TP");
      _model._d = readkv("_d");
      _model._m = readkv("_m");
      _model._M = readkv("_M");
      // NOTE(review): predSize.length == _numTPCol, so the arraycopy below copies predSize onto
      // itself starting at offset 0. It is a no-op: predSize stays all zeros, and both arrays
      // below are read with zero-length rows. Probably meant to copy from _model._d; confirm
      // against the MOJO writer.
      int[] predSize = new int[_model._numTPCol];
      System.arraycopy(predSize, predSize.length-_model._numTPCol, predSize, 0, _model._numTPCol);
      _model._gamColMeansRaw = read2DDoubleArrays(predSize, "gamColMeansRaw");
      // NOTE(review): the field name suggests 1/std while the entry is "gamColStdRaw"; confirm
      // whether the stored values are already inverted.
      _model._oneOGamColStd = read2DDoubleArrays(predSize, "gamColStdRaw");
      // Presumably an element-wise num_knots_TP[i] - M[i]; confirm ArrayUtils.subtract semantics.
      int[] numKnotsMM = subtract(_model._num_knots_TP, _model._M);
      _model._zTransposeCS = read3DArray("zTransposeCS", _model._numTPCol, numKnotsMM, _model._num_knots_TP);
      // Predictor counts for the TP columns, taken from _d starting at index _numCSCol.
      int[] predNum = new int[_model._numTPCol];
      System.arraycopy(_model._d, _model._numCSCol, predNum, 0, _model._numTPCol);
      _model._allPolyBasisList = read3DIntArray("polynomialBasisList", _model._numTPCol, _model._M, predNum);
    }
    int[] numKnotsM1 = subtract(_model._num_knots_sorted, 1);
    int numKnotsLen = numKnotsM1.length;
    int isCounter=0;
    int msCounter = 0;
    int[] zSecondDim = new int[numKnotsLen];
    int[] zThirdDim = new int[numKnotsLen];
    // NOTE(review): the next line is corrupted in this copy. Text between a '<' and a later '>'
    // appears to have been stripped, likely by HTML extraction. The loop condition and body are
    // lost, and so is everything up to a condition ending in "> 0". That missing section
    // presumably used isCounter, msCounter, zSecondDim, zThirdDim and the static imports
    // IS_SPLINE_TYPE/MS_SPLINE_TYPE. Restore it from upstream before use.
    for (int index=0; index 0) {
      // Presumably guarded by a _numCSCol > 0 check in the lost text; reads one block per CS column.
      int[] numKnotsM2 = subtract(_model._num_knots_sorted, 2);
      _model._binvD = read3DArray("_binvD", _model._numCSCol, numKnotsM2, _model._num_knots_sorted);
    }
    _model.init();
  }

  /**
   * Reads the text entry {@code title} and distributes its lines, in order, into a jagged array
   * whose row {@code i} has {@code arrayDim[i]} slots. Rows are filled one after another. Rows of
   * length zero are skipped, including several zero-length rows in a row.
   *
   * @param arrayDim number of strings expected in each row
   * @param title name of the text entry to read
   * @return the jagged array; slots not covered by the entry's lines stay {@code null}
   * @throws IOException if the entry cannot be read, or if it contains more lines than
   *     {@code arrayDim} has room for
   */
  String[][] read2DStringArrays(int[] arrayDim, String title) throws IOException {
    final int firstDim = arrayDim.length;
    final String[][] stringArrays = new String[firstDim][];
    for (int index = 0; index < firstDim; index++)
      stringArrays[index] = new String[arrayDim[index]];
    int indexDim1 = 0;
    int indexDim2 = 0;
    for (String line : readtext(title)) {
      // Advance past every row that is already full. This must be a loop, not a single `if`:
      // with one `if`, consecutive zero-length rows would be written into and throw
      // ArrayIndexOutOfBoundsException.
      while (indexDim1 < firstDim && indexDim2 >= stringArrays[indexDim1].length) {
        indexDim1++;
        indexDim2 = 0;
      }
      if (indexDim1 >= firstDim)
        throw new IOException("Entry '" + title + "' contains more lines than the "
                + firstDim + " row(s) described by its dimension array can hold");
      stringArrays[indexDim1][indexDim2++] = line;
    }
    return stringArrays;
  }

  /**
   * Decodes the binary entry {@code title} into a jagged array of doubles. Values are consumed
   * sequentially (big-endian, the {@link ByteBuffer} default): row {@code i} receives the next
   * {@code arrayDim[i]} doubles. A too-short entry surfaces as a
   * {@link java.nio.BufferUnderflowException}.
   *
   * @param arrayDim length of each row
   * @param title name of the binary entry to read
   * @return the decoded jagged array
   * @throws IOException if the entry cannot be read
   */
  double[][] read2DDoubleArrays(int[] arrayDim, String title) throws IOException {
    final int rowCount = arrayDim.length;
    final double[][] result = new double[rowCount][];
    final ByteBuffer buffer = ByteBuffer.wrap(readblob(title));
    int row = 0;
    while (row < rowCount) {
      final double[] values = new double[arrayDim[row]];
      result[row] = values;
      for (int col = 0; col < values.length; col++) {
        values[col] = buffer.getDouble();
      }
      row++;
    }
    return result;
  }
  
  /**
   * Decodes the binary entry {@code title} into a rectangular
   * {@code firstDSize x secondDSize} matrix of doubles, filled row by row from consecutive
   * big-endian doubles.
   *
   * @param title name of the binary entry to read
   * @param firstDSize number of rows
   * @param secondDSize number of columns
   * @return the decoded matrix
   * @throws IOException if the entry cannot be read
   */
  double[][] read2DArray(String title, int firstDSize, int secondDSize) throws IOException {
    final double[][] matrix = new double[firstDSize][secondDSize];
    final ByteBuffer buffer = ByteBuffer.wrap(readblob(title));
    for (double[] rowValues : matrix) {
      for (int col = 0; col < rowValues.length; col++) {
        rowValues[col] = buffer.getDouble();
      }
    }
    return matrix;
  }

  /**
   * Decodes the binary entry {@code title} into a jagged 3-D array of ints. Slice {@code i} is a
   * {@code secondDim[i] x thirdDim[i]} matrix, filled row by row from consecutive big-endian ints.
   *
   * @param title name of the binary entry to read
   * @param firstDimSize number of slices to read
   * @param secondDim row count of each slice
   * @param thirdDim column count of each slice
   * @return the decoded array
   * @throws IOException if the entry cannot be read
   */
  int[][][] read3DIntArray(String title, int firstDimSize, int[] secondDim, int[] thirdDim) throws IOException {
    final int[][][] cube = new int[firstDimSize][][];
    final ByteBuffer buffer = ByteBuffer.wrap(readblob(title));
    for (int slice = 0; slice < firstDimSize; slice++) {
      final int[][] plane = new int[secondDim[slice]][thirdDim[slice]];
      cube[slice] = plane;
      for (int[] line : plane) {
        for (int k = 0; k < line.length; k++) {
          line[k] = buffer.getInt();
        }
      }
    }
    return cube;
  }

  /**
   * Decodes the binary entry {@code title} into a jagged 3-D array of doubles. Slice {@code i} is
   * a {@code secondDim[i] x thirdDim[i]} matrix, filled row by row from consecutive big-endian
   * doubles.
   *
   * @param title name of the binary entry to read
   * @param firstDimSize number of slices to read
   * @param secondDim row count of each slice
   * @param thirdDim column count of each slice
   * @return the decoded array
   * @throws IOException if the entry cannot be read
   */
  double[][][] read3DArray(String title, int firstDimSize, int[] secondDim, int[] thirdDim) throws IOException {
    final double[][][] cube = new double[firstDimSize][][];
    final ByteBuffer buffer = ByteBuffer.wrap(readblob(title));
    for (int slice = 0; slice < firstDimSize; slice++) {
      final double[][] plane = new double[secondDim[slice]][thirdDim[slice]];
      cube[slice] = plane;
      for (double[] line : plane) {
        for (int k = 0; k < line.length; k++) {
          line[k] = buffer.getDouble();
        }
      }
    }
    return cube;
  }
  
  /**
   * Fills the caller-supplied outer array {@code row} in place from the binary entry
   * {@code title}. Entry {@code i} is replaced by a fresh array of {@code num_knots[i]}
   * consecutive big-endian doubles. {@code row} must have at least {@code num_knots.length}
   * slots.
   *
   * @param title name of the binary entry to read
   * @param row destination outer array, modified in place
   * @param num_knots length of each inner array
   * @return {@code row}, the same instance that was passed in
   * @throws IOException if the entry cannot be read
   */
  double[][] read2DArrayDiffLength(String title, double[][] row, int[] num_knots) throws IOException {
    final int columnCount = num_knots.length;
    final ByteBuffer buffer = ByteBuffer.wrap(readblob(title));
    for (int col = 0; col < columnCount; col++) {
      final double[] values = new double[num_knots[col]];
      row[col] = values;
      for (int k = 0; k < values.length; k++) {
        values[k] = buffer.getDouble();
      }
    }
    return row;
  }

  /**
   * Instantiates the concrete GAM MOJO model that matches the stored {@code family} key: the
   * multinomial variant for the "multinomial" and "ordinal" families, the regular variant
   * otherwise.
   *
   * @param columns model column names
   * @param domains categorical domains per column
   * @param responseColumn name of the response column
   * @return a new, not yet populated, model instance
   */
  @Override
  protected GamMojoModelBase makeModel(String[] columns, String[][] domains, String responseColumn) {
    final String family = readkv("family");
    final boolean perClassCoefficients = "multinomial".equals(family) || "ordinal".equals(family);
    if (!perClassCoefficients) {
      return new GamMojoModel(columns, domains, responseColumn);
    }
    return new GamMojoMultinomialModel(columns, domains, responseColumn);
  }

  /**
   * Returns the version of the MOJO format this reader understands.
   *
   * @return the format version string {@code "1.00"}
   */
  @Override
  public String mojoVersion() {
    final String supportedVersion = "1.00";
    return supportedVersion;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy