All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.glrm.GlrmMojoReader Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.algos.glrm;

import hex.genmodel.ModelMojoReader;

import java.io.IOException;
import java.nio.ByteBuffer;

/**
 * Reads the contents of a Generalized Low Rank Model (GLRM) MOJO into a {@link GlrmMojoModel}.
 *
 * <p>Scalar settings are fetched with {@code readkv}, the per-column loss names come from the
 * {@code losses} text entry (one name per line), and the archetype matrix is decoded from the
 * {@code archetypes} blob. Fields that were added after MOJO version 1.00 are optional; when they
 * are missing the model is populated with legacy defaults instead.
 */
public class GlrmMojoReader extends ModelMojoReader<GlrmMojoModel> {

  /** Returns the human-readable name of the algorithm handled by this reader. */
  @Override
  public String getModelName() {
    return "Generalized Low Rank Model";
  }

  /**
   * Populates {@code _model} from the MOJO: dimensions, regularization and initialization
   * settings, normalization vectors, column permutation, loss functions, archetypes, and finally
   * the optional fields introduced after version 1.00.
   *
   * @throws IOException if reading the text or blob entries fails
   */
  @Override
  protected void readModelData() throws IOException {
    _model._ncolA = readkv("ncolA");
    _model._ncolY = readkv("ncolY");
    _model._nrowY = readkv("nrowY");
    _model._ncolX = readkv("ncolX");
    _model._regx = GlrmRegularizer.valueOf((String) readkv("regularizationX"));
    _model._gammax = readkv("gammaX");
    _model._init = GlrmInitialization.valueOf((String) readkv("initialization"));

    _model._ncats = readkv("num_categories");
    _model._nnums = readkv("num_numeric");
    _model._normSub = readkv("norm_sub");
    _model._normMul = readkv("norm_mul");
    _model._permutation = readkv("cols_permutation");

    // loss functions: one enum name per line, stored in file order
    _model._losses = new GlrmLoss[_model._ncolA];
    int li = 0;
    for (String line : readtext("losses")) {
      _model._losses[li++] = GlrmLoss.valueOf(line);
    }

    // archetypes: nrowY x ncolY doubles, consumed row by row from the blob
    _model._numLevels = readkv("num_levels_per_category");
    _model._archetypes = new double[_model._nrowY][];
    ByteBuffer bb = ByteBuffer.wrap(readblob("archetypes"));
    for (int i = 0; i < _model._nrowY; i++) {
      double[] row = new double[_model._ncolY];
      _model._archetypes[i] = row;
      for (int j = 0; j < _model._ncolY; j++) {
        row[j] = bb.getDouble();
      }
    }

    readFieldsAddedAfterV100();
  }

  /**
   * Loads the fields that were added after MOJO version 1.00.
   *
   * <p>Older MOJOs do not carry {@code reverse_transform} / {@code transposed}. Their absence is
   * detected with explicit null checks (rather than by catching the {@link NullPointerException}
   * that unboxing a missing value would raise), so unrelated null-pointer bugs are not silently
   * converted into "legacy MOJO" defaults.
   */
  private void readFieldsAddedAfterV100() {
    Boolean reverseTransform = readkv("reverse_transform");
    Boolean transposed = readkv("transposed");

    if (reverseTransform == null || transposed == null) {
      // legacy MOJO: fall back to the defaults used before these fields existed
      _model._seed = System.currentTimeMillis();  // randomly initialize seed
      _model._reverse_transform = true;
      _model._transposed = true;
      _model._catOffsets = null;
      _model._archetypes_raw = null;
      return;
    }

    _model._seed = readkv("seed", 0L);
    _model._reverse_transform = reverseTransform;
    _model._transposed = transposed;
    _model._catOffsets = readkv("catOffsets");

    if (!transposed) {
      _model._archetypes_raw = _model._archetypes;
      return;
    }

    // Raw archetypes are the transpose of the stored matrix. The dimensions come from the model
    // (the same values _archetypes was allocated with) so an empty matrix does not index row 0.
    int nrows = _model._nrowY;
    int ncols = _model._ncolY;
    double[][] raw = new double[ncols][nrows];
    for (int row = 0; row < nrows; row++) {
      for (int col = 0; col < ncols; col++) {
        raw[col][row] = _model._archetypes[row][col];
      }
    }
    _model._archetypes_raw = raw;
  }

  /**
   * Creates the model instance to be populated and pre-computes its {@code _allAlphas} array from
   * the model's {@code _numAlphaFactors}.
   *
   * @param columns names of the model columns
   * @param domains categorical domains, one entry per column
   * @param responseColumn name of the response column
   * @return a new {@link GlrmMojoModel} with {@code _allAlphas} initialized
   */
  @Override
  protected GlrmMojoModel makeModel(String[] columns, String[][] domains, String responseColumn) {
    GlrmMojoModel glrmModel = new GlrmMojoModel(columns, domains, responseColumn);
    glrmModel._allAlphas = GlrmMojoModel.initializeAlphas(glrmModel._numAlphaFactors);  // set _allAlphas array
    return glrmModel;
  }

  /** Returns the MOJO format version reported by this reader. */
  @Override public String mojoVersion() {
    return "1.10";
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy