All downloads are free. Search and download functionality uses the official Maven repository.

hex.coxph.CoxPHModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.coxph;

import hex.*;
import hex.FrameTask.DataInfo;
import hex.schemas.CoxPHModelV2;
import water.*;
import water.api.ModelSchema;
import water.fvec.Frame;
import water.fvec.Vec;

import java.util.Arrays;

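/**
 * The Cox proportional hazards model.
 * It holds the fit parameters and the fitted output (coefficients with their statistics,
 * log-likelihoods and test statistics, and a per-event-time cumulative hazard table), and
 * scores a row by predicting the cumulative hazard and its standard error at each event time.
 */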

public class CoxPHModel extends SupervisedModel<CoxPHModel, CoxPHModel.CoxPHParameters, CoxPHModel.CoxPHOutput> {

  /** Parameters controlling a Cox proportional hazards fit. */
  public static class CoxPHParameters extends SupervisedModel.SupervisedParameters {
    // Inherited from SupervisedModel.SupervisedParameters (and Model.Parameters):
    // destination_key, training_frame, validation_frame, response_column, ignored_columns.

    // Survival-data columns; their exact roles and optionality are decided by the CoxPH builder.
    public Vec start_column;
    public Vec stop_column;
    public Vec event_column;
    public Vec weights_column;
    public Vec[] offset_columns;

    /** Approximation used for tied event times; defaults to Efron. */
    public static enum CoxPHTies { efron, breslow }
    public CoxPHTies ties = CoxPHTies.efron;
    /** Initial value for the fit (presumably of the coefficients -- confirm in the CoxPH builder). */
    public double init = 0;
    /** Minimum log-relative error; presumably the convergence criterion of the fit. */
    public double lre_min = 9;
    /** Maximum number of fitting iterations. */
    public int iter_max = 20;
  }

  /** Fitted state of a Cox proportional hazards model, constructed from a {@link CoxPH} builder. */
  public static class CoxPHOutput extends SupervisedModel.SupervisedOutput {
    public CoxPHOutput(CoxPH b) { super(b); }

    // Column expansion of the training data; score0 uses it to rebuild the design row.
    DataInfo data_info;
    // Coefficients and their statistics (var_coef is used by score0 for the standard error).
    String[] coef_names;
    double[] coef;
    double[] exp_coef;
    double[] exp_neg_coef;
    double[] se_coef;
    double[] z_coef;
    double[][] var_coef;
    // Log-likelihoods and test statistics.
    double null_loglik;
    double loglik;
    double loglik_test;
    double wald_test;
    double score_test;
    double rsq;
    double maxrsq;
    // Optimizer state.
    double[] gradient;
    double[][] hessian;
    double lre;
    int iter;
    // Column means; score0 imputes all-NaN categorical rows from x_mean_cat.
    double[] x_mean_cat;
    double[] x_mean_num;
    double[] mean_offset;
    String[] offset_names;
    // Row counts and observed time range.
    long n;
    long n_missing;
    long total_event;
    long min_time;
    long max_time;
    // Per-event-time table. score0 scales cumhaz_0 by the row's risk and combines
    // var_cumhaz_1 / var_cumhaz_2 with var_coef for the standard error.
    long[] time;
    double[] n_risk;
    double[] n_event;
    double[] n_censor;
    double[] cumhaz_0;
    double[] var_cumhaz_1;
    double[][] var_cumhaz_2;
  }

  // Default publicly visible Schema is V2
  public ModelSchema schema() { return new CoxPHModelV2(); }

  /** Returns this model's parameters. */
  public final CoxPHParameters get_params() { return _parms; }

  public CoxPHModel(final Key destKey, final CoxPHParameters parms, final CoxPHOutput output) {
    super(destKey, parms, output);
  }

  @Override public String toString() {
    return "CoxPHModel toString() UNIMPLEMENTED";
  }

  public String toStringAll() {
    return "CoxPHModel toStringAll() UNIMPLEMENTED";
  }

  /**
   * Predicts, for one row, the cumulative hazard at every time in {@code _output.time} together
   * with its standard error.
   *
   * @param data  raw row: the categorical columns first, each holding a level index (0 sets no
   *              indicator, NaN means missing), followed by the numeric columns; may contain NaNs
   * @param preds output buffer of length at least {@code 1 + 2 * n_time}, where
   *              {@code n_time = _output.time.length}. {@code preds[0]} is set to NaN,
   *              {@code preds[1 .. n_time]} receive the cumulative hazard and
   *              {@code preds[n_time + 1 .. 2 * n_time]} its standard error. All of those slots are
   *              NaN when any numeric value is NaN or when only some categorical values are NaN;
   *              if every categorical value is NaN they are imputed from {@code x_mean_cat}.
   * @return {@code preds}
   */
  @Override public float[] score0(double[] data, float[] preds) {
    final int n_time = _output.time.length;
    if (hasUnscorableNA(data)) {
      Arrays.fill(preds, 1, 2 * n_time + 1, Float.NaN);
    } else {
      final double[] full_data = expandRow(data);
      final double risk = Math.exp(linearPredictor(full_data));
      // Scratch buffer reused across time points by hazardVarianceTerm.
      final double[] diff = MemoryManager.malloc8d(_output.coef.length);
      for (int t = 0; t < n_time; ++t) {
        preds[t + 1] = (float) (risk * _output.cumhaz_0[t]);
        final double variance = _output.var_cumhaz_1[t] + hazardVarianceTerm(full_data, t, diff);
        preds[t + 1 + n_time] = (float) (risk * Math.sqrt(variance));
      }
    }
    preds[0] = Float.NaN;
    return preds;
  }

  /**
   * Returns true when the row cannot be scored: any numeric value is NaN, or some (but not all)
   * categorical values are NaN.
   */
  private boolean hasUnscorableNA(double[] data) {
    final int n_cats = _output.data_info._cats;
    final int n_data = n_cats + _output.data_info._nums;
    boolean catsAllNA = true;
    boolean catsHasNA = false;
    for (int j = 0; j < n_cats; ++j) {
      final boolean isNA = Double.isNaN(data[j]);
      catsAllNA &= isNA;
      catsHasNA |= isNA;
    }
    boolean numsHasNA = false;
    for (int j = n_cats; j < n_data; ++j)
      numsHasNA |= Double.isNaN(data[j]);
    return numsHasNA || (catsHasNA && !catsAllNA);
  }

  /**
   * Expands a raw row into a design row of length {@code coef.length + n_offsets}.
   * Categorical level k (k > 0) sets indicator slot {@code _catOffsets[j] + k - 1}; a NaN
   * categorical copies its slots from {@code x_mean_cat}. Numeric values are written from
   * {@code numStart()} onward, minus {@code data_info._normSub}.
   */
  private double[] expandRow(double[] data) {
    final DataInfo dinfo = _output.data_info;
    final int n_offsets = (_parms.offset_columns == null) ? 0 : _parms.offset_columns.length;
    final double[] full_data = MemoryManager.malloc8d(_output.coef.length + n_offsets);
    for (int j = 0; j < dinfo._cats; ++j) {
      final int kst = dinfo._catOffsets[j];
      if (Double.isNaN(data[j])) {
        // score0 only gets here with NaN categoricals when all of them are NaN.
        System.arraycopy(_output.x_mean_cat, kst, full_data, kst, dinfo._catOffsets[j + 1] - kst);
      } else if (data[j] != 0) {
        full_data[kst + (int) (data[j] - 1)] = 1;
      }
    }
    final int numStart = dinfo.numStart();
    for (int j = 0; j < dinfo._nums; ++j)
      full_data[numStart + j] = data[dinfo._cats + j] - dinfo._normSub[j];
    return full_data;
  }

  /**
   * Linear predictor (log relative risk). It is the coefficient-weighted sum of the first
   * {@code coef.length} design values plus the trailing (offset) slots, each taken with weight 1.
   */
  private double linearPredictor(double[] full_data) {
    final double[] coef = _output.coef;
    double logRisk = 0;
    for (int j = 0; j < coef.length; ++j)
      logRisk += full_data[j] * coef[j];
    for (int j = coef.length; j < full_data.length; ++j)
      logRisk += full_data[j];
    return logRisk;
  }

  /**
   * Row-specific part of the cumulative-hazard variance at time index {@code t}.
   * It is the quadratic form {@code d' V d}, where
   * {@code d[k] = full_data[k] * cumhaz_0[t] - var_cumhaz_2[t][k]} and {@code V = var_coef}.
   * {@code diff} is scratch space of length {@code coef.length}.
   */
  private double hazardVarianceTerm(double[] full_data, int t, double[] diff) {
    final int n_coef = _output.coef.length;
    final double cumhaz_0_t = _output.cumhaz_0[t];
    // d is computed once per time point rather than once per (j, k) pair.
    for (int k = 0; k < n_coef; ++k)
      diff[k] = full_data[k] * cumhaz_0_t - _output.var_cumhaz_2[t][k];
    double quad = 0;
    for (int j = 0; j < n_coef; ++j) {
      final double[] var_coef_j = _output.var_coef[j];
      double sum = 0;
      for (int k = 0; k < n_coef; ++k)
        sum += var_coef_j[k] * diff[k];
      quad += diff[j] * sum;
    }
    return quad;
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy