hex.glrm.GLRMModel
package hex.glrm;

import hex.*;
import hex.svd.SVDModel.SVDParameters;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.*;
import water.util.TwoDimTable;

import java.util.Random;

import static hex.glrm.GLRMModel.GLRMParameters.Loss.*;

public class GLRMModel extends Model<GLRMModel, GLRMModel.GLRMParameters, GLRMModel.GLRMOutput> implements Model.GLRMArchetypes {

  public static class GLRMParameters extends Model.Parameters {
    public String algoName() { return "GLRM"; }
    public String fullName() { return "Generalized Low Rank Modeling"; }
    public String javaName() { return GLRMModel.class.getName(); }
    @Override public long progressUnits() { return 2 + _max_iterations; }
    public DataInfo.TransformType _transform = DataInfo.TransformType.NONE; // Data transformation (demean to compare with PCA)
    public int _k = 1;                            // Rank of resulting XY matrix
    public GLRM.Initialization _init = GLRM.Initialization.PlusPlus;  // Initialization of Y matrix
    public SVDParameters.Method _svd_method = SVDParameters.Method.Randomized;  // SVD initialization method (for _init = SVD)
    public Key _user_y;               // User-specified Y matrix (for _init = User)
    public Key _user_x;               // User-specified X matrix (for _init = User)
    public boolean _expand_user_y = true;    // Should categorical columns in _user_y be expanded via one-hot encoding? (for _init = User)

    // Loss functions
    public Loss _loss = Loss.Quadratic;                  // Default loss function for numeric cols
    public Loss _multi_loss = Loss.Categorical;  // Default loss function for categorical cols
    public int _period = 1;                       // Length of the period when _loss = Periodic
    public Loss[] _loss_by_col;                // Override default loss function for specific columns
    public int[] _loss_by_col_idx;

    // Regularization functions
    public Regularizer _regularization_x = Regularizer.None;   // Regularization function for X matrix
    public Regularizer _regularization_y = Regularizer.None;   // Regularization function for Y matrix
    public double _gamma_x = 0;                   // Regularization weight on X matrix
    public double _gamma_y = 0;                   // Regularization weight on Y matrix

    // Optional parameters
    public int _max_iterations = 1000;            // Max iterations
    public int _max_updates = 2*_max_iterations;  // Max number of updates (X or Y)
    public double _init_step_size = 1.0;          // Initial step size (decrease until we hit min_step_size)
    public double _min_step_size = 1e-4;          // Min step size
    public long _seed = System.nanoTime();        // RNG seed
    @Override protected long nFoldSeed() { return _seed; }

    // public Key _representation_key;     // Key to save X matrix
    public String _representation_name;
    public boolean _recover_svd = false;          // Recover singular values and eigenvectors of XY at the end?
    public boolean _impute_original = false;      // Reconstruct original training data by reversing _transform?
    public boolean _verbose = true;               // Log when objective increases each iteration?

    // Quadratic -> Gaussian distribution ~ exp(-(a-u)^2)
    // Absolute -> Laplace distribution ~ exp(-|a-u|)
    public enum Loss {
      Quadratic(true), Absolute(true), Huber(true), Poisson(true), Periodic(true),  // One-dimensional loss (numeric)
      Logistic(true, true), Hinge(true, true),  // Boolean loss (categorical)
      Categorical(false), Ordinal(false);    // Multi-dimensional loss (categorical)

      private boolean forNumeric;
      private boolean forBinary;
      Loss(boolean forNumeric) { this(forNumeric, false); }
      Loss(boolean forNumeric, boolean forBinary) {
        this.forNumeric = forNumeric;
        this.forBinary = forBinary;
      }

      public boolean isForNumeric() { return forNumeric; }
      public boolean isForCategorical() { return !forNumeric; }
      public boolean isForBinary() { return forBinary; }
    }

    // Non-negative matrix factorization (NNMF): r_x = r_y = NonNegative
    // Orthogonal NNMF: r_x = OneSparse, r_y = NonNegative
    // K-means clustering: r_x = UnitOneSparse, r_y = 0 (\gamma_y = 0)
    // Quadratic mixture: r_x = Simplex, r_y = 0 (\gamma_y = 0)
    public enum Regularizer {
      None, Quadratic, L2, L1, NonNegative, OneSparse, UnitOneSparse, Simplex
    }
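
    // Illustrative configuration (not part of the original source): the k-means
    // recipe above corresponds to constraining each row of X to be a unit
    // indicator vector while leaving Y unpenalized:
    //   GLRMParameters parms = new GLRMParameters();
    //   parms._k = 4;                                        // number of clusters
    //   parms._loss = Loss.Quadratic;
    //   parms._regularization_x = Regularizer.UnitOneSparse;
    //   parms._gamma_y = 0;                                  // \gamma_y = 0 disables the Y penalty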

    // Check if all elements of _loss_by_col are equal to a specific loss function
    private final boolean allLossEquals(Loss loss) {
      if (null == _loss_by_col) return false;

      boolean res = true;
      for(int i = 0; i < _loss_by_col.length; i++) {
        if(_loss_by_col[i] != loss) {
          res = false;
          break;
        }
      }
      return res;
    }

    // Closed form solution only if quadratic loss, no regularization or quadratic regularization (same for X and Y), and no missing values
    public final boolean hasClosedForm() {
      long na_cnt = 0;
      Frame train = _train.get();
      for(int i = 0; i < train.numCols(); i++)
        na_cnt += train.vec(i).naCnt();
      return hasClosedForm(na_cnt);
    }

    public final boolean hasClosedForm(long na_cnt) {
      boolean loss_quad = (null == _loss_by_col && _loss == Quadratic) ||
              (null != _loss_by_col && allLossEquals(Quadratic) && (_loss_by_col.length == _train.get().numCols() || _loss == Quadratic));

      return na_cnt == 0 && ((loss_quad && (_gamma_x == 0 || _regularization_x == Regularizer.None || _regularization_x == GLRMParameters.Regularizer.Quadratic)
              && (_gamma_y == 0 || _regularization_y == Regularizer.None || _regularization_y == GLRMParameters.Regularizer.Quadratic)));
    }
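
    // Example (illustrative): with the defaults above (_loss = Quadratic,
    // _loss_by_col = null, _gamma_x = _gamma_y = 0), hasClosedForm(0) returns
    // true, and the minimizer can be computed directly from an SVD rather than
    // by alternating updates.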

    // L(u,a): Loss function
    public final double loss(double u, double a) {
      return loss(u, a, _loss);
    }
    public final double loss(double u, double a, Loss loss) {
      assert loss.isForNumeric() : "Loss function " + loss + " not applicable to numerics";
      switch(loss) {
        case Quadratic:
          return (u-a)*(u-a);
        case Absolute:
          return Math.abs(u-a);
        case Huber:
          return Math.abs(u-a) <= 1 ? 0.5*(u-a)*(u-a) : Math.abs(u-a)-0.5;
        case Poisson:
          assert a >= 0 : "Poisson loss L(u,a) requires variable a >= 0";
          return Math.exp(u) + (a == 0 ? 0 : -a*u + a*Math.log(a) - a);   // Since \lim_{a->0} a*log(a) = 0
        case Hinge:
          // return Math.max(1-a*u,0);
          return Math.max(1 - (a == 0 ? -u : u), 0);   // Booleans are coded {0,1} instead of {-1,1}
        case Logistic:
          // return Math.log(1 + Math.exp(-a * u));
          return Math.log(1 + Math.exp(a == 0 ? u : -u));    // Booleans are coded {0,1} instead of {-1,1}
        case Periodic:
          return 1-Math.cos((a-u)*(2*Math.PI)/_period);
        default:
          throw new RuntimeException("Unknown loss function " + loss);
      }
    }
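
    // Worked examples (illustrative): for u = 0.5, a = 1:
    //   Quadratic = (0.5-1)^2 = 0.25, Absolute = |0.5-1| = 0.5, Huber = 0.5*(0.5)^2 = 0.125.
    // For u = 3, a = 1, Huber is past the unit band: |3-1| - 0.5 = 1.5, so it
    // grows linearly (like Absolute) for large residuals but stays smooth near 0.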

    // \grad_u L(u,a): Gradient of loss function with respect to u
    public final double lgrad(double u, double a) {
      return lgrad(u, a, _loss);
    }
    public final double lgrad(double u, double a, Loss loss) {
      assert loss.isForNumeric() : "Loss function " + loss + " not applicable to numerics";
      switch(loss) {
        case Quadratic:
          return 2*(u-a);
        case Absolute:
          return Math.signum(u - a);
        case Huber:
          return Math.abs(u-a) <= 1 ? u-a : Math.signum(u-a);
        case Poisson:
          assert a >= 0 : "Poisson loss L(u,a) requires variable a >= 0";
          return Math.exp(u)-a;
        case Hinge:
          // return a*u <= 1 ? -a : 0;
          return a == 0 ? (-u <= 1 ? 1 : 0) : (u <= 1 ? -1 : 0);  // Booleans are coded as {0,1} instead of {-1,1}
        case Logistic:
          // return -a/(1+Math.exp(a*u));
          return a == 0 ? 1/(1+Math.exp(-u)) : -1/(1+Math.exp(u));    // Booleans are coded as {0,1} instead of {-1,1}
        case Periodic:
          return ((2*Math.PI)/_period) * Math.sin((a-u) * (2*Math.PI)/_period);
        default:
          throw new RuntimeException("Unknown loss function " + loss);
      }
    }
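
    // Sanity check (illustrative): lgrad is the derivative of loss with respect
    // to u. For Quadratic at u = 0.5, a = 1: lgrad = 2*(0.5-1) = -1. For Huber
    // the two branches (u-a inside the unit band, signum(u-a) outside) agree at
    // |u-a| = 1, so the gradient is continuous there.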

    // L(u,a): Multidimensional loss function
    public final double mloss(double[] u, int a) {
      return mloss(u, a, _multi_loss);
    }
    public static double mloss(double[] u, int a, Loss multi_loss) {
      assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";
      if(a < 0 || a > u.length-1)
        throw new IllegalArgumentException("Index must be between 0 and " + String.valueOf(u.length-1));

      double sum = 0;
      switch(multi_loss) {
        case Categorical:
          for (int i = 0; i < u.length; i++) sum += Math.max(1 + u[i], 0);
          sum += Math.max(1 - u[a], 0) - Math.max(1 + u[a], 0);
          return sum;
        case Ordinal:
          for (int i = 0; i < u.length-1; i++) sum += Math.max(a>i ? 1-u[i]:1, 0);
          return sum;
        default:
          throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
      }
    }
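
    // Worked example (illustrative): Categorical loss for u = {0.5, -0.2}, a = 0:
    //   sum over levels: max(1+0.5,0) + max(1-0.2,0) = 2.3
    //   swap the true level's hinge: 2.3 + max(1-0.5,0) - max(1+0.5,0) = 1.3
    // The loss reaches 0 only when u[a] >= 1 and u[i] <= -1 for every other level.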

    // \grad_u L(u,a): Gradient of multidimensional loss function with respect to u
    public final double[] mlgrad(double[] u, int a) {
      return mlgrad(u, a, _multi_loss);
    }
    public static double[] mlgrad(double[] u, int a, Loss multi_loss) {
      assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";
      if(a < 0 || a > u.length-1)
        throw new IllegalArgumentException("Index must be between 0 and " + String.valueOf(u.length-1));

      double[] grad = new double[u.length];
      switch(multi_loss) {
        case Categorical:
          for (int i = 0; i < u.length; i++) grad[i] = (1+u[i] > 0) ? 1:0;
          grad[a] = (1-u[a] > 0) ? -1:0;
          return grad;
        case Ordinal:
          for (int i = 0; i < u.length-1; i++) grad[i] = (a>i && 1-u[i] > 0) ? -1:0;
          return grad;
        default:
          throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
      }
    }
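
    // Continuing the example above (illustrative): for u = {0.5, -0.2}, a = 0,
    // the Categorical gradient is {-1, 1}; a descent step raises u[0] (the true
    // level) and lowers u[1].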

    // r_i(x_i), r_j(y_j): Regularization function for single row x_i or column y_j
    public final double regularize_x(double[] u) { return regularize(u, _regularization_x); }
    public final double regularize_y(double[] u) { return regularize(u, _regularization_y); }
    public final double regularize(double[] u, Regularizer regularization) {
      if(u == null) return 0;
      double ureg = 0;

      switch(regularization) {
        case None:
          return 0;
        case Quadratic:
          for(int i = 0; i < u.length; i++)
            ureg += u[i] * u[i];
          return ureg;
        case L2:
          for(int i = 0; i < u.length; i++)
            ureg += u[i] * u[i];
          return Math.sqrt(ureg);
        case L1:
          for(int i = 0; i < u.length; i++)
            ureg += Math.abs(u[i]);
          return ureg;
        case NonNegative:
          for(int i = 0; i < u.length; i++) {
            if(u[i] < 0) return Double.POSITIVE_INFINITY;
          }
          return 0;
        case OneSparse:
          int card = 0;
          for(int i = 0; i < u.length; i++) {
            if(u[i] < 0) return Double.POSITIVE_INFINITY;
            else if(u[i] > 0) card++;
          }
          return card == 1 ? 0 : Double.POSITIVE_INFINITY;
        case UnitOneSparse:
          int ones = 0, zeros = 0;
          for(int i = 0; i < u.length; i++) {
            if(u[i] == 1) ones++;
            else if(u[i] == 0) zeros++;
            else return Double.POSITIVE_INFINITY;
          }
          return ones == 1 && zeros == u.length-1 ? 0 : Double.POSITIVE_INFINITY;
        case Simplex:
          double sum = 0, absum = 0;
          for(int i = 0; i < u.length; i++) {
            if(u[i] < 0) return Double.POSITIVE_INFINITY;
            else {
              sum += u[i];
              absum += Math.abs(u[i]);
            }
          }
          return MathUtils.equalsWithinRecSumErr(sum, 1.0, u.length, absum) ? 0 : Double.POSITIVE_INFINITY;
        default:
          throw new RuntimeException("Unknown regularization function " + regularization);
      }
    }
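
    // Worked example (illustrative): for u = {0.6, 0.4}:
    //   Quadratic = 0.52, L2 = sqrt(0.52) ~ 0.72, L1 = 1.0, NonNegative = 0,
    //   Simplex = 0 (nonnegative and sums to 1), OneSparse = +Inf (two nonzeros),
    //   UnitOneSparse = +Inf (entries not all in {0,1}).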

    // \sum_i r_i(x_i): Sum of regularization function for all entries of X
    public final double regularize_x(double[][] u) { return regularize(u, _regularization_x); }
    public final double regularize_y(double[][] u) { return regularize(u, _regularization_y); }
    public final double regularize(double[][] u, Regularizer regularization) {
      if(u == null || regularization == Regularizer.None) return 0;

      double ureg = 0;
      for(int i = 0; i < u.length; i++) {
        ureg += regularize(u[i], regularization);
        if(Double.isInfinite(ureg)) return ureg;
      }
      return ureg;
    }

    // \prox_{\alpha_k*r}(u): Proximal gradient of (step size) * (regularization function) evaluated at vector u
    public final double[] rproxgrad_x(double[] u, double alpha, Random rand) { return rproxgrad(u, alpha, _gamma_x, _regularization_x, rand); }
    public final double[] rproxgrad_y(double[] u, double alpha, Random rand) { return rproxgrad(u, alpha, _gamma_y, _regularization_y, rand); }
    // public final double[] rproxgrad_x(double[] u, double alpha) { return rproxgrad(u, alpha, _gamma_x, _regularization_x, RandomUtils.getRNG(_seed)); }
    // public final double[] rproxgrad_y(double[] u, double alpha) { return rproxgrad(u, alpha, _gamma_y, _regularization_y, RandomUtils.getRNG(_seed)); }
    static double[] rproxgrad(double[] u, double alpha, double gamma, Regularizer regularization, Random rand) {
      if(u == null || alpha == 0 || gamma == 0) return u;
      double[] v = new double[u.length];

      switch(regularization) {
        case None:
          return u;
        case Quadratic:
          for(int i = 0; i < u.length; i++)
            v[i] = u[i]/(1+2*alpha*gamma);
          return v;
        case L2:
          // Proof uses Moreau decomposition; see section 6.5.1 of Parikh and Boyd https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf
          double weight = 1 - alpha*gamma/ArrayUtils.l2norm(u);
          if(weight < 0) return v;   // Zero vector
          for(int i = 0; i < u.length; i++)
            v[i] = weight * u[i];
          return v;
        case L1:
          for(int i = 0; i < u.length; i++)
            v[i] = Math.max(u[i]-alpha*gamma,0) + Math.min(u[i]+alpha*gamma,0);
          return v;
        case NonNegative:
          for(int i = 0; i < u.length; i++)
            v[i] = Math.max(u[i],0);
          return v;
        case OneSparse:
          int idx = ArrayUtils.maxIndex(u, rand);
          v[idx] = u[idx] > 0 ? u[idx] : 1e-6;
          return v;
        case UnitOneSparse:
          idx = ArrayUtils.maxIndex(u, rand);
          v[idx] = 1;
          return v;
        case Simplex:
          // Proximal gradient algorithm by Chen and Ye in http://arxiv.org/pdf/1101.6081v2.pdf
          // 1) Sort input vector u in ascending order: u[1] <= ... <= u[n]
          int n = u.length;
          int[] idxs = new int[n];
          for(int i = 0; i < n; i++) idxs[i] = i;
          ArrayUtils.sort(idxs, u);

          // 2) Calculate cumulative sum of u in descending order
          // cumsum(u) = (..., u[n-2]+u[n-1]+u[n], u[n-1]+u[n], u[n])
          double[] ucsum = new double[n];
          ucsum[n-1] = u[idxs[n-1]];
          for(int i = n-2; i >= 0; i--)
            ucsum[i] = ucsum[i+1] + u[idxs[i]];

          // 3) Let t_i = (\sum_{j=i+1}^n u[j] - 1)/(n - i)
          // For i = n-1,...,1, set optimal t* to first t_i >= u[i]
          double t = (ucsum[0] - 1)/n;    // Default t* = (\sum_{j=1}^n u[j] - 1)/n
          for(int i = n-1; i >= 1; i--) {
            double tmp = (ucsum[i] - 1)/(n - i);
            if(tmp >= u[idxs[i-1]]) {
              t = tmp; break;
            }
          }

          // 4) Return max(u - t*, 0) as projection of u onto simplex
          double[] x = new double[u.length];
          for(int i = 0; i < u.length; i++)
            x[i] = Math.max(u[i] - t, 0);
          return x;
        default:
          throw new RuntimeException("Unknown regularization function " + regularization);
      }
    }
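
    // Worked examples (illustrative): with alpha*gamma = 1, L1 soft-thresholds
    // u = {0.5, -2.0} to {0, -1.0}. For Simplex with u = {0.4, 0.8}: the default
    // t = (0.4+0.8-1)/2 = 0.1 survives the scan, so max(u - t, 0) = {0.3, 0.7},
    // which is nonnegative and sums to 1.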

    // Project X,Y matrices into appropriate subspace so regularizer is finite. Used during initialization.
    public final double[] project_x(double[] u, Random rand) { return project(u, _regularization_x, rand); }
    public final double[] project_y(double[] u, Random rand) { return project(u, _regularization_y, rand); }
    public final double[] project(double[] u, Regularizer regularization, Random rand) {
      if(u == null) return u;

      switch(regularization) {
        // Domain is all real numbers
        case None:
        case Quadratic:
        case L2:
        case L1:
          return u;
        // Proximal operator of indicator function for a set C is (Euclidean) projection onto C
        case NonNegative:
        case OneSparse:
        case UnitOneSparse:
          return rproxgrad(u, 1, 1, regularization, rand);
        case Simplex:
          double reg = regularize(u, regularization);   // Check if inside simplex before projecting since algo is complicated
          if (reg == 0) return u;
          return rproxgrad(u, 1, 1, regularization, rand);
        default:
          throw new RuntimeException("Unknown regularization function " + regularization);
      }
    }

    // \hat A_{i,j} = \argmin_a L_{i,j}(x_iy_j, a): Data imputation for real numeric values
    public final double impute(double u) {
      return impute(u, _loss);
    }
    public static double impute(double u, Loss loss) {
      assert loss.isForNumeric() : "Loss function " + loss + " not applicable to numerics";
      switch(loss) {
        case Quadratic:
        case Absolute:
        case Huber:
        case Periodic:
          return u;
        case Poisson:
          return Math.exp(u)-1;
        case Hinge:
        case Logistic:
          return u > 0 ? 1 : 0;   // Booleans are coded as {0,1} instead of {-1,1}
        default:
          throw new RuntimeException("Unknown loss function " + loss);
      }
    }
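
    // Example (illustrative): Quadratic/Absolute/Huber/Periodic impute the
    // reconstruction u = x_i*y_j directly; Poisson imputes exp(u)-1; Hinge and
    // Logistic threshold at 0, so u = 0.3 imputes 1 and u = -0.3 imputes 0.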

    // \hat A_{i,j} = \argmin_a L_{i,j}(x_iy_j, a): Data imputation for categorical values {0,1,2,...}
    // TODO: Is there a faster way to find the loss minimizer?
    public final int mimpute(double[] u) {
      return mimpute(u, _multi_loss);
    }
    public static int mimpute(double[] u, Loss multi_loss) {
      assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";
      switch(multi_loss) {
        case Categorical:
        case Ordinal:
          double[] cand = new double[u.length];
          for (int a = 0; a < cand.length; a++)
            cand[a] = mloss(u, a, multi_loss);
          return ArrayUtils.minIndex(cand);
        default:
          throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
      }
    }
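
    // Example (illustrative): for u = {0.5, -0.2} under Categorical loss, the
    // candidate losses are mloss(u,0) = 1.3 and mloss(u,1) = 2.7, so mimpute
    // returns level 0, the loss-minimizing category.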
  }

  public static class GLRMOutput extends Model.Output {
    // Iterations executed
    public int _iterations;

    // Updates executed
    public int _updates;

    // Current value of objective function
    public double _objective;

    // Average change in objective function this iteration
    public double _avg_change_obj;
    public double[/*iterations*/] _history_objective = new double[0];

    // Mapping from lower dimensional k-space to training features (Y)
    public TwoDimTable _archetypes;
    public GLRM.Archetypes _archetypes_raw;   // Needed for indexing into Y for scoring

    // Step size each iteration
    public double[/*iterations*/] _history_step_size = new double[0];

    // SVD of output XY
    public double[/*feature*/][/*k*/] _eigenvectors_raw;
    public TwoDimTable _eigenvectors;
    public double[] _singular_vals;

    // Frame key of X matrix
    public String _representation_name;
    public Key _representation_key;
    public Key _init_key;

    // Number of categorical and numeric columns
    public int _ncats;
    public int _nnums;

    // Number of good rows in training frame (not skipped)
    public long _nobs;

    // Categorical offset vector
    public int[] _catOffsets;

    // Standardization arrays for numeric data columns
    public double[] _normSub;   // Mean of each numeric column
    public double[] _normMul;   // One over standard deviation of each numeric column

    // Permutation array mapping adapted to original training col indices
    public int[] _permutation;  // _permutation[i] = j means col i in _adaptedFrame maps to col j of _train

    // Expanded column names of adapted training frame
    public String[] _names_expanded;

    // Loss function for every column in adapted training frame
    public GLRMParameters.Loss[] _lossFunc;

    // Training time
    public long[/*iterations*/] _training_time_ms = new long[0];

    public GLRMOutput(GLRM b) { super(b); }

    /** Override because base class implements ncols-1 for features with the
     *  last column as a response variable; for GLRM all the columns are features. */
    @Override public int nfeatures() { return _names.length; }

    @Override public ModelCategory getModelCategory() {
      return ModelCategory.DimReduction;
    }
  }

  public GLRMModel(Key selfKey, GLRMParameters parms, GLRMOutput output) { super(selfKey, parms, output); }

  @Override protected Futures remove_impl( Futures fs ) {
    if(null != _output._init_key)
      _output._init_key.remove(fs);
    if(null != _output._representation_key)
      _output._representation_key.remove(fs);
    return super.remove_impl(fs);
  }

  /** Write out K/V pairs */
  @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) { 
    ab.putKey(_output._init_key);
    ab.putKey(_output._representation_key);
    return super.writeAll_impl(ab);
  }
  @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) { 
    ab.getKey(_output._init_key,fs);
    ab.getKey(_output._representation_key,fs);
    return super.readAll_impl(ab,fs);
  }

  // GLRM scoring is data imputation based on feature domains using reconstructed XY (see Udell (2015), Section 5.3)
  private Frame reconstruct(Frame orig, Frame adaptedFr, Key destination_key, boolean save_imputed, boolean reverse_transform) {
    final int ncols = _output._names.length;
    assert ncols == adaptedFr.numCols();
    String prefix = "reconstr_";

    // Need [A,X,P] where A = adaptedFr, X = loading frame, P = imputed frame
    // Note: A is adapted to original training frame, P has columns shuffled so cats come before nums!
    Frame fullFrm = new Frame(adaptedFr);
    Frame loadingFrm = DKV.get(_output._representation_key).get();
    fullFrm.add(loadingFrm);
    String[][] adaptedDomme = adaptedFr.domains();
    for(int i = 0; i < ncols; i++) {
      Vec v = fullFrm.anyVec().makeZero();
      v.setDomain(adaptedDomme[i]);
      fullFrm.add(prefix + _output._names[i], v);
    }
    GLRMScore gs = new GLRMScore(ncols, _parms._k, save_imputed, reverse_transform).doAll(fullFrm);

    // Return the imputed training frame
    int x = ncols + _parms._k, y = fullFrm.numCols();
    Frame f = fullFrm.extractFrame(x, y);  // this will call vec_impl() and we cannot call the delete() below just yet

    f = new Frame((null == destination_key ? Key.make() : destination_key), f.names(), f.vecs());
    DKV.put(f);
    gs._mb.makeModelMetrics(GLRMModel.this, orig, null, null);   // save error metrics based on imputed data
    return f;
  }

  @Override protected Frame predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, Job j) {
    return reconstruct(orig, adaptedFr, (null == destination_key ? Key.make() : Key.make(destination_key)), true, _parms._impute_original);
  }

  public Frame scoreReconstruction(Frame frame, Key destination_key, boolean reverse_transform) {
    Frame adaptedFr = new Frame(frame);
    adaptTestForTrain(adaptedFr, true, false);
    return reconstruct(frame, adaptedFr, destination_key, true, reverse_transform);
  }
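
  // Usage sketch (illustrative, not part of the original source; "fr" and
  // "modelKey" are hypothetical): reconstruct a frame in the original feature
  // space using a trained model:
  //   GLRMModel model = DKV.get(modelKey).get();
  //   Frame recon = model.scoreReconstruction(fr, Key.make("reconstruction"), true);
  // Each training column gets an imputed counterpart, and reconstruction error
  // metrics are saved against fr as a side effect.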

  /**
   * Project each archetype into original feature space
   * @param frame Original training data with m rows and n columns
   * @param destination_key Frame Id for output
   * @return Frame containing k rows and n columns, where each row corresponds to an archetype
   */
  public Frame scoreArchetypes(Frame frame, Key destination_key, boolean reverse_transform) {
    final int ncols = _output._names.length;
    Frame adaptedFr = new Frame(frame);
    adaptTestForTrain(adaptedFr, true, false);
    assert ncols == adaptedFr.numCols();
    String[][] adaptedDomme = adaptedFr.domains();
    double[][] proj = new double[_parms._k][_output._nnums + _output._ncats];

    // Categorical columns
    for (int d = 0; d < _output._ncats; d++) {
      double[][] block = _output._archetypes_raw.getCatBlock(d);
      for (int k = 0; k < _parms._k; k++)
        proj[k][_output._permutation[d]] = _parms.mimpute(block[k], _output._lossFunc[d]);
    }

    // Numeric columns
    for (int d = _output._ncats; d < (_output._ncats + _output._nnums); d++) {
      int ds = d - _output._ncats;
      for (int k = 0; k < _parms._k; k++) {
        double num = _output._archetypes_raw.getNum(ds, k);
        proj[k][_output._permutation[d]] = _parms.impute(num, _output._lossFunc[d]);
        if (reverse_transform)
          proj[k][_output._permutation[d]] = proj[k][_output._permutation[d]] / _output._normMul[ds] + _output._normSub[ds];
      }
    }

    // Convert projection of archetypes into a frame with correct domains
    Frame f = ArrayUtils.frame((null == destination_key ? Key.make() : destination_key), adaptedFr.names(), proj);
    for(int i = 0; i < ncols; i++) f.vec(i).setDomain(adaptedDomme[i]);
    return f;
  }

  private class GLRMScore extends MRTask<GLRMScore> {
    final int _ncolA;   // Number of cols in original data A
    final int _ncolX;   // Number of cols in X (rank k)
    final boolean _save_imputed;  // Save imputed data into new vecs?
    final boolean _reverse_transform;   // Reconstruct original training data by reversing transform?
    ModelMetrics.MetricBuilder _mb;

    GLRMScore( int ncolA, int ncolX, boolean save_imputed ) {
      this(ncolA, ncolX, save_imputed, _parms._impute_original);
    }

    GLRMScore( int ncolA, int ncolX, boolean save_imputed, boolean reverse_transform ) {
      _ncolA = ncolA; _ncolX = ncolX;
      _save_imputed = save_imputed;
      _reverse_transform = reverse_transform;
    }

    @Override public void map( Chunk chks[] ) {
      float atmp [] = new float[_ncolA];
      double xtmp [] = new double[_ncolX];
      double preds[] = new double[_ncolA];
      _mb = GLRMModel.this.makeMetricBuilder(null);

      if (_save_imputed) {
        for (int row = 0; row < chks[0]._len; row++) {
          double p[] = impute_data(chks, row, xtmp, preds);
          compute_metrics(chks, row, atmp, p);
          for (int c = 0; c < preds.length; c++)
            chks[_ncolA + _ncolX + c].set(row, p[c]);
        }
      } else {
        for (int row = 0; row < chks[0]._len; row++) {
          double p[] = impute_data(chks, row, xtmp, preds);
          compute_metrics(chks, row, atmp, p);
        }
      }
    }

    @Override public void reduce( GLRMScore other ) { if(_mb != null) _mb.reduce(other._mb); }

    @Override protected void postGlobal() { if(_mb != null) _mb.postGlobal(); }

    private float[] compute_metrics(Chunk[] chks, int row_in_chunk, float[] tmp, double[] preds) {
      for( int i = 0; i < tmp.length; i++ )
        tmp[i] = (float)chks[i].atd(row_in_chunk);   // Original data columns A
      _mb.perRow(preds, tmp, GLRMModel.this);
      return tmp;
    }

    private double[] impute_data(Chunk[] chks, int row_in_chunk, double[] tmp, double[] preds) {
      for( int i = 0; i < tmp.length; i++ )
        tmp[i] = chks[_ncolA + i].atd(row_in_chunk);   // Row of X from the appended loading columns
      return impute_data(tmp, preds);
    }

    // Per-row imputation: multiply the row of X into Y via the raw archetypes and
    // impute each column under its loss function, mirroring scoreArchetypes above
    private double[] impute_data(double[] tmp, double[] preds) {
      assert preds.length == _output._nnums + _output._ncats;

      // Categorical columns
      for (int d = 0; d < _output._ncats; d++) {
        double[] xyblock = _output._archetypes_raw.lmulCatBlock(tmp, d);
        preds[_output._permutation[d]] = _parms.mimpute(xyblock, _output._lossFunc[d]);
      }

      // Numeric columns
      for (int d = _output._ncats; d < preds.length; d++) {
        int ds = d - _output._ncats;
        double xy = _output._archetypes_raw.lmulNumCol(tmp, ds);
        preds[_output._permutation[d]] = _parms.impute(xy, _output._lossFunc[d]);
        if (_reverse_transform)
          preds[_output._permutation[d]] = preds[_output._permutation[d]] / _output._normMul[ds] + _output._normSub[ds];
      }
      return preds;
    }
  }
}