All downloads are free. The search and download features use the official Maven repository.

hex.pca.PCAModel Maven / Gradle / Ivy

package hex.pca;

import hex.*;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.JCodeGen;
import water.util.SB;
import water.util.TwoDimTable;

public class PCAModel extends Model {

  /**
   * Parameters for a PCA model. Every field is public and carries its
   * default value in its initializer.
   */
  public static class PCAParameters extends Model.Parameters {
    public DataInfo.TransformType _transform = DataInfo.TransformType.NONE; // Data transformation; defaults to NONE (no transformation)
    public int _k = 1;                // Number of principal components; also the number of "PC" output columns and the length of the preds array
    public int _max_iterations = 1000;     // Max iterations
    public long _seed = System.nanoTime(); // RNG seed; defaults to the current nanoTime, so it differs per instance unless set explicitly
    public Key _loading_key;               // NOTE(review): presumably the requested key for the loading frame (compare PCAOutput._loading_key) -- confirm against the builder
    public boolean _keep_loading = true;   // NOTE(review): presumably controls whether the loading frame is retained after the build -- confirm against the builder
    public boolean _useAllFactorLevels = false;   // When expanding categoricals, should last level be dropped?  NOTE(review): confirm which level is dropped when false (see DataInfo)
  }

  /**
   * Result of a PCA build: the eigenvector matrix, the per-component
   * statistics, and the column bookkeeping (categorical offsets, numeric
   * normalization, column permutation) that scoring and Java code
   * generation in PCAModel read back.
   */
  public static class PCAOutput extends Model.Output {
    // Principal components (eigenvectors)
    // Raw matrix indexed [expanded feature row][component]; scoring sums entries of column i into preds[i]
    public double[/*feature*/][/*k*/] _eigenvectors_raw;
    // Table form of the eigenvectors
    public TwoDimTable _eigenvectors;

    // Standard deviation of each principal component
    public double[] _std_deviation;

    // Importance of principal components
    // Standard deviation, proportion of variance explained, and cumulative proportion of variance explained
    public TwoDimTable _pc_importance;

    // Number of categorical and numeric columns
    public int _ncats;
    public int _nnums;

    // Categorical offset vector
    // Scoring uses _catOffsets[j] as the first eigenvector row of categorical column j,
    // _catOffsets[j+1]-_catOffsets[j] as its number of levels, and the last entry as the
    // first eigenvector row of the numeric columns
    public int[] _catOffsets;

    // If standardized, mean of each numeric data column
    public double[] _normSub;

    // If standardized, one over standard deviation of each numeric data column
    public double[] _normMul;

    // Permutation matrix mapping training col indices to adaptedFrame
    // (stored as an index vector: scoring reads data[_permutation[col]])
    public int[] _permutation;

    // Frame key for projection into principal component space
    public Key _loading_key;

    /** Builds an output bound to the given PCA builder; all work is delegated to the superclass constructor. */
    public PCAOutput(PCA b) { super(b); }

    /** Override because base class implements ncols-1 for features with the
     *  last column as a response variable; for PCA all the columns are
     *  features. */
    @Override public int nfeatures() { return _names.length; }

    /** PCA models always report the dimensionality-reduction category. */
    @Override public ModelCategory getModelCategory() {
      return ModelCategory.DimReduction;
    }
  }

  /**
   * Creates a PCA model by handing the key, parameters and output straight
   * to the superclass constructor.
   *
   * @param selfKey key identifying this model
   * @param parms   parameters the model was built with
   * @param output  output holding the computed decomposition
   */
  public PCAModel(Key selfKey, PCAParameters parms, PCAOutput output) {
    super(selfKey, parms, output);
  }

  /**
   * Creates the metric builder for this model, sized by the number of
   * principal components.
   *
   * @param domain not used by this implementation
   * @return a new PCA metric builder constructed with {@code _parms._k}
   */
  @Override
  public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    final int componentCount = _parms._k;
    return new ModelMetricsPCA.PCAModelMetrics(componentCount);
  }

  @Override
  protected Frame scoreImpl(Frame orig, Frame adaptedFr, String destination_key) {
    Frame adaptFrm = new Frame(adaptedFr);
    for(int i = 0; i < _parms._k; i++)
      adaptFrm.add("PC"+String.valueOf(i+1),adaptFrm.anyVec().makeZero());

    new MRTask() {
      @Override public void map( Chunk chks[] ) {
        double tmp [] = new double[_output._names.length];
        double preds[] = new double[_parms._k];
        for( int row = 0; row < chks[0]._len; row++) {
          double p[] = score0(chks, row, tmp, preds);
          for( int c=0; c= 0 && level < clen : "Categorical level x = " + level + " must be in 0 <= x < " + clen;
        if (level > _output._catOffsets[j+1]-_output._catOffsets[j]-1) continue;  // Skip categorical level in test set but not in train
        preds[i] += _output._eigenvectors_raw[_output._catOffsets[j]+level][i];
      }

      int dcol = _output._ncats;
      int vcol = numStart;
      for (int j = 0; j < _output._nnums; j++) {
        preds[i] += (data[_output._permutation[dcol]] - _output._normSub[j]) * _output._normMul[j] * _output._eigenvectors_raw[vcol][i];
        dcol++; vcol++;
      }
    }
    return preds;
  }

  /**
   * Scores a frame with this model: copies the frame, adapts the copy to the
   * training layout, and delegates to {@code scoreImpl}.
   *
   * <p>The cleanup of the adapted copy runs in a {@code finally} block so it
   * is not skipped when scoring throws.
   *
   * @param fr              frame to score; passed to {@code scoreImpl} as the original frame
   * @param destination_key key name forwarded to {@code scoreImpl} for the result frame
   * @return the frame produced by {@code scoreImpl}
   */
  @Override
  public Frame score(Frame fr, String destination_key) {
    Frame adaptFr = new Frame(fr);
    adaptTestForTrain(adaptFr, true);   // Adapt
    try {
      return scoreImpl(fr, adaptFr, destination_key); // Score
    } finally {
      // Always release whatever adaptation created, even on a scoring failure
      cleanup_adapt( adaptFr, fr );
    }
  }

  /**
   * Emits the class-level part of the generated Java source: three accessor
   * methods followed by the static arrays the generated predict body reads.
   *
   * <p>NORMMUL and NORMSUB are only emitted when the model has at least one
   * numeric column; CATOFFS, PERMUTE and EIGVECS are always emitted.
   *
   * @param sb            buffer passed to the superclass implementation
   * @param fileContextSB file-level buffer, forwarded to the superclass implementation
   * @return the buffer returned by the superclass, with the PCA members appended
   */
  @Override protected SB toJavaInit(SB sb, SB fileContextSB) {
    final SB out = super.toJavaInit(sb, fileContextSB);

    // Accessor methods
    final String supervisedDecl = "public boolean isSupervised() { return " + isSupervised() + "; }";
    out.ip(supervisedDecl).nl();
    final String nfeaturesDecl = "public int nfeatures() { return "+_output.nfeatures()+"; }";
    out.ip(nfeaturesDecl).nl();
    final String nclassesDecl = "public int nclasses() { return "+_parms._k+"; }";
    out.ip(nclassesDecl).nl();

    // Normalization arrays exist only for models with numeric columns
    final boolean hasNumericColumns = _output._nnums > 0;
    if (hasNumericColumns) {
      JCodeGen.toStaticVar(out, "NORMMUL", _output._normMul, "Standardization/Normalization scaling factor for numerical variables.");
      JCodeGen.toStaticVar(out, "NORMSUB", _output._normSub, "Standardization/Normalization offset for numerical variables.");
    }

    // Arrays needed regardless of the column mix
    JCodeGen.toStaticVar(out, "CATOFFS", _output._catOffsets, "Categorical column offsets.");
    JCodeGen.toStaticVar(out, "PERMUTE", _output._permutation, "Permutation index vector.");
    JCodeGen.toStaticVar(out, "EIGVECS", _output._eigenvectors_raw, "Eigenvector matrix.");
    return out;
  }

  /**
   * Emits the body of the generated predict method. The generated code
   * zeroes {@code preds} and then, for each of the {@code _parms._k}
   * components, adds the eigenvector entry selected by every categorical
   * column and the normalized, eigenvector-weighted value of every numeric
   * column.
   *
   * <p>The numeric-column loop is emitted only when the model has numeric
   * columns. It references NORMSUB and NORMMUL, which {@code toJavaInit}
   * declares only when {@code _output._nnums > 0}; emitting the loop without
   * them would produce generated source that refers to undeclared names.
   *
   * @param bodySb     buffer receiving the predict method body
   * @param classCtxSb class-level buffer; not written by this method
   * @param fileCtxSb  file-level buffer; receives the (empty) local model buffer
   */
  @Override protected void toJavaPredictBody( final SB bodySb, final SB classCtxSb, final SB fileCtxSb) {
    SB model = new SB();
    bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
    final int cats = _output._ncats;
    final int nums = _output._nnums;
    bodySb.i().p("final int nstart = CATOFFS[CATOFFS.length-1];").nl();
    bodySb.i().p("for(int i = 0; i < ").p(_parms._k).p("; i++) {").nl();
    // Categorical columns
    bodySb.i(1).p("for(int j = 0; j < ").p(cats).p("; j++) {").nl();
    bodySb.i(2).p("int c = (int) data[PERMUTE[j]];").nl();
    // Generated guard: skip levels beyond those CATOFFS reserves for column j
    bodySb.i(2).p("if(c > CATOFFS[j+1]-CATOFFS[j]-1) continue;").nl();
    bodySb.i(2).p("preds[i] += EIGVECS[CATOFFS[j]+c][i];").nl();
    bodySb.i(1).p("}").nl();

    // Numeric columns -- only when NORMSUB/NORMMUL are declared by toJavaInit
    if (nums > 0) {
      bodySb.i(1).p("for(int j = 0; j < ").p(nums).p("; j++) {").nl();
      bodySb.i(2).p("preds[i] += (data[PERMUTE[j" + (cats > 0 ? "+" + cats : "") + "]]-NORMSUB[j])*NORMMUL[j]*EIGVECS[j" + (cats > 0 ? "+ nstart" : "") +"][i];").nl();
      bodySb.i(1).p("}").nl();
    }
    bodySb.i().p("}").nl();
    // model is never written to in this method, so this appends an empty buffer
    fileCtxSb.p(model);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy