package hex.coxph;

import Jama.Matrix;
import hex.*;
import hex.DataInfo.Row;
import hex.DataInfo.TransformType;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.ast.prims.mungers.AstGroup;
import water.util.*;
import water.util.Timer;

import static java.util.stream.Collectors.toList;
import static water.util.ArrayUtils.constAry;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Cox Proportional Hazards Model
 *
 * Builds a Cox proportional hazards regression model for (possibly right-censored)
 * survival data. Supports counting-process style input (an optional start column
 * plus a stop column), stratification by categorical columns, case weights, and
 * both the Efron and Breslow approximations for handling tied event times.
 */
public class CoxPH extends ModelBuilder<CoxPHModel, CoxPHModel.CoxPHParameters, CoxPHModel.CoxPHOutput> {

  private static final int MAX_TIME_BINS = 100000;

  @Override public ModelCategory[] can_build() { return new ModelCategory[] { ModelCategory.CoxPH }; }
  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }
  @Override public boolean isSupervised() { return true; }

  public CoxPH(boolean startup_once) {
    super(new CoxPHModel.CoxPHParameters(), startup_once);
  }

  public CoxPH( CoxPHModel.CoxPHParameters parms ) { super(parms); init(false); }
  @Override protected CoxPHDriver trainModelImpl() { return new CoxPHDriver(); }

  @Override
  public boolean haveMojo() {
    return true;
  }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
  */
  @Override public void init(boolean expensive) {
    super.init(expensive);

    if (_parms._train != null && _parms.train() == null) {
      error("train", "Invalid training frame (Frame key = " + _parms._train + " not found)");
    }

    if (_parms._train != null && _parms.train() != null) {
      if (_parms._start_column != null) {
        Vec startVec  = _parms.startVec();
        if (startVec == null) {
          error("start_column", "start_column " + _parms._start_column + " not found in the training frame");
        } else if (!startVec.isNumeric()) {
          error("start_column", "start time must be undefined or of type numeric");
        }
      }

      if (_parms._stop_column != null) {
        Vec stopVec  = _parms.stopVec();
        if (stopVec == null) {
          error("stop_column", "stop_column " + _parms._stop_column + " not found in the training frame");
        } else if (!stopVec.isNumeric()) {
          error("stop_column", "stop time must be of type numeric");
        } else if (expensive) {
          try {
            CollectTimes.collect(_parms.stopVec(), _parms._single_node_mode);
          } catch (CollectTimesException e) {
            error("stop_column", e.getMessage());
          }
        }
      }

      if ((_parms._response_column != null) && ! _response.isInt() && (! _response.isCategorical()))
        error("response_column", "response/event column must be of type integer or factor");

      if (_parms.startVec() != null && _parms.stopVec() != null) {
        if (_parms.startVec().min() >= _parms.stopVec().max())
          error("start_column", "start times must be strictly less than stop times");
      }

      if (_parms._interactions != null) {
        for (String col : _parms._interactions) {
          if (col != null && !col.isEmpty() && _train.vec(col) == null) {
            error("interactions", col + " not found in the training frame");
          }
        }
      }

      if (_parms._interactions_only != null) {
        for (String col : _parms._interactions_only) {
          if (col != null && !col.isEmpty() && _train.vec(col) == null) {
            error("interactions_only", col + " not found in the training frame");
          }
        }
      }

      if (_parms._interaction_pairs != null) {
        for (StringPair pair : _parms._interaction_pairs) {
          if (pair._a != null && !pair._a.isEmpty() && _train.vec(pair._a) == null) {
            error("interaction_pairs", pair._a + " not found in the training frame with columns "
                    + Arrays.toString(_train.names()));
          }
          if (pair._b != null && !pair._b.isEmpty() && _train.vec(pair._b) == null) {
            error("interaction_pairs", pair._b + " not found in the training frame with columns "
                    + Arrays.toString(_train.names()));
          }
        }
      }

      if (_train != null) {
        int nonFeatureColCount = (_parms._start_column != null ? 1 : 0) + (_parms._stop_column != null ? 1 : 0);
        if (_train.numCols() < (2 + nonFeatureColCount))
          error("_train", "Training data must have at least 2 features (incl. response).");
        if (null != _parms._stratify_by) {
          int stratifyColCount = _parms._stratify_by.length;
          if (_train.numCols() < (2 + nonFeatureColCount + stratifyColCount))
            error("_train", "Training data must have at least 1 feature that is not a response and is not used for stratification.");
        }
      }

      if (_parms.isStratified()) {
        for (String col : _parms._stratify_by) {
          Vec v = _parms.train().vec(col);
          if (v == null) {
            error("stratify_by", "column '" + col + "' not found");
          } else if (v.get_type() != Vec.T_CAT) {
            error("stratify_by", "non-categorical column '" + col + "' cannot be used for stratification");
          }
          if (_parms._interactions != null) {
            for (String inter : _parms._interactions) {
              if (col.equals(inter)) {
                // Makes implementation simpler and should not have an actual impact anyway
                error("stratify_by", "stratification column '" + col + "' cannot be used in an implicit interaction. " +
                        "Use explicit (pair-wise) interactions instead");
                break;
              }
            }
          }
        }
      }
    }

    if (Double.isNaN(_parms._lre_min) || _parms._lre_min <= 0)
      error("lre_min", "lre_min must be a positive number");

    if (_parms._max_iterations < 1)
      error("max_iterations", "max_iterations must be a positive integer");
  }

  @Override
  protected int init_getNClass() {
    return 1;
  }

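  /**
   * Maps raw event times onto indices of the sorted array of distinct stop times
   * ("time bins"). Stop times are expected to match a bin exactly (the bins were
   * collected from the same column); a start time is mapped to the first bin
   * strictly after it, since a subject becomes at risk only after entry.
   */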
  static class DiscretizeTimeTask extends MRTask<DiscretizeTimeTask> {
    final double[] _time;
    final boolean _has_start_column;

    private DiscretizeTimeTask(double[] time, boolean has_start_column) {
      _time = time;
      _has_start_column = has_start_column;
    }

    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
      assert cs.length == (_has_start_column ? 2 : 1);
      for (int i = 0; i < cs[0].len(); i++)
        discretizeTime(i, cs, ncs, 0);
    }

    void discretizeTime(int i, Chunk[] cs, NewChunk[] ncs, int offset) {
      final double stopTime = cs[cs.length - 1].atd(i);
      final int t2 = Arrays.binarySearch(_time, stopTime);
      if (t2 < 0)
        throw new IllegalStateException("Encountered unexpected stop time");
      ncs[ncs.length - 1].addNum(t2 + offset);
      if (_has_start_column) {
        final double startTime = cs[0].atd(i);
        if (startTime >= stopTime)
          throw new IllegalArgumentException("start times must be strictly less than stop times");
        final int t1c = Arrays.binarySearch(_time, startTime);
        final int t1 = t1c >= 0 ? t1c + 1 : -t1c - 1;
        ncs[0].addNum(t1 + offset);
      }
    }

    static Frame discretizeTime(double[] time, Vec startVec, Vec stopVec, boolean runLocal) {
      final boolean hasStartColumn = startVec != null;
      final Frame f = new Frame();
      if (hasStartColumn)
        f.add("__startCol", startVec);
      f.add("__stopCol", stopVec);
      byte[] outputTypes = hasStartColumn ? new byte[]{Vec.T_NUM, Vec.T_NUM} : new byte[]{Vec.T_NUM}; 
      return new DiscretizeTimeTask(time, startVec != null)
              .doAll(outputTypes, f, runLocal).outputFrame();
    }

  }

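  /**
   * DiscretizeTimeTask for stratified models: each row's combination of
   * stratification-column levels is looked up in a precomputed mapping to obtain a
   * dense strata id (rows with an unmapped combination get NAs), and discretized
   * times are offset by strataId * _time.length so that every stratum occupies its
   * own contiguous block of time bins.
   */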
  static class StrataTask extends DiscretizeTimeTask {
    private final IcedHashMap<AstGroup.G, IcedInt> _strataMap;

    private StrataTask(IcedHashMap<AstGroup.G, IcedInt> strata) {
      this(strata, new double[0], false);
    }

    private StrataTask(IcedHashMap<AstGroup.G, IcedInt> strata, double[] time, boolean has_start_column) {
      super(time, has_start_column);
      _strataMap = strata;
    }

    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
      Chunk[] scs; // strata chunks
      Chunk[] tcs; // time chunks
      NewChunk[] tncs; // time new chunks

      if (ncs.length > 1) {
        // split chunks into 2 groups: strata chunks and time chunks
        scs = new Chunk[cs.length - ncs.length + 1];
        System.arraycopy(cs, 0, scs, 0, scs.length);
        tcs = new Chunk[ncs.length - 1];
        System.arraycopy(cs, scs.length, tcs, 0, tcs.length);
        tncs = new NewChunk[ncs.length - 1];
        System.arraycopy(ncs, 1, tncs, 0, tncs.length);
      } else {
        scs = cs;
        tcs = null;
        tncs = null;
      }

      AstGroup.G g = new AstGroup.G(scs.length, null);
      for (int i = 0; i < cs[0].len(); i++) {
        g.fill(i, scs);
        IcedInt strataId = _strataMap.get(g);
        if (strataId == null) {
          for (NewChunk nc : ncs)
            nc.addNA();
        } else {
          ncs[0].addNum(strataId._val);
          if (tcs != null) {
            final int strataOffset = _time.length * strataId._val;
            discretizeTime(i, tcs, tncs, strataOffset);
          }
        }
      }
    }

    static Vec makeStrataVec(Frame f, String[] stratifyBy, IcedHashMap<AstGroup.G, IcedInt> mapping, boolean runLocal) {
      final Frame sf = f.subframe(stratifyBy);
      return new StrataTask(mapping).doAll(new byte[]{Vec.T_NUM}, sf, runLocal).outputFrame().anyVec();
    }

    static Frame stratifyTime(Frame f, double[] time, String[] stratifyBy, IcedHashMap<AstGroup.G, IcedInt> mapping,
                              Vec startVec, Vec stopVec, boolean runLocal) {
      final Frame sf = f.subframe(stratifyBy);
      final boolean hasStartColumn = startVec != null;
      if (hasStartColumn)
        sf.add("__startVec", startVec);
      sf.add("__stopVec", stopVec);
      return new StrataTask(mapping, time, hasStartColumn)
              .doAll(constAry(hasStartColumn ? 3 : 2, Vec.T_NUM), sf, runLocal).outputFrame();
    }

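    /**
     * Enumerates all observed combinations of the stratification-column levels
     * (groups containing NAs are skipped) and assigns each combination a dense
     * integer id in iteration order.
     */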
    static void setupStrataMapping(Frame f, String[] stratifyBy, IcedHashMap<AstGroup.G, IcedInt> outMapping) {
      final Frame sf = f.subframe(stratifyBy);
      int[] idxs = MemoryManager.malloc4(stratifyBy.length);
      for (int i = 0; i < idxs.length; i++)
        idxs[i] = i;
      Collection<AstGroup.G> groups = AstGroup.doGroups(sf, idxs, AstGroup.aggNRows());
      groups: for (AstGroup.G g : groups) {
        for (double val : g._gs)
          if (Double.isNaN(val))
            continue groups;
        outMapping.put(g, new IcedInt(outMapping.size()));
      }
    }

  }

  public boolean hasStartColumn() {
    return _parms._start_column != null;
  }

  @Override
  protected boolean validateBinaryResponse() {
    return false; // CoxPH can handle numerical 0-1 response, no warnings needed
  }

  public class CoxPHDriver extends Driver {

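    /**
     * Rebuilds the training frame with the feature columns first, followed by the
     * special columns in the fixed order expected downstream: weights, strata,
     * start, stop, response. The time columns are swapped for their discretized
     * versions and, for stratified models, a strata-id column is derived from the
     * stratification columns via the strata mapping.
     */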
    private Frame reorderTrainFrameColumns(IcedHashMap<AstGroup.G, IcedInt> outStrataMap, double[] time) {
      Frame f = new Frame();

      Vec weightVec = null;
      Vec startVec = null;
      Vec stopVec = null;
      Vec eventVec = null;

      Vec[] vecs = train().vecs();
      String[] names = train().names();

      for (int i = 0; i < names.length; i++) {
        if (names[i].equals(_parms._weights_column))
          weightVec = vecs[i];
        else if (names[i].equals(_parms._start_column))
          startVec = vecs[i];
        else if (names[i].equals(_parms._stop_column))
          stopVec = vecs[i];
        else if (names[i].equals(_parms._response_column))
          eventVec = vecs[i];
        else
          f.add(names[i], vecs[i]);
      }

      Vec strataVec = null;
      Frame discretizedFr;
      if (_parms.isStratified()) {
        StrataTask.setupStrataMapping(f, _parms._stratify_by, outStrataMap);
        discretizedFr = Scope.track(
                StrataTask.stratifyTime(f, time, _parms._stratify_by, outStrataMap, startVec, stopVec, _parms._single_node_mode)
        );
        strataVec = discretizedFr.remove(0);
        if (_parms.interactionSpec() == null) {
          // no interactions => we can drop the columns earlier
          f.remove(_parms._stratify_by);
        }
      } else {
        discretizedFr = Scope.track(DiscretizeTimeTask.discretizeTime(time, startVec, stopVec, _parms._single_node_mode));
      }
      // swap time columns for their discretized versions
      if (startVec != null) {
        startVec = discretizedFr.vec(0);
        stopVec = discretizedFr.vec(1);
      } else
        stopVec = discretizedFr.vec(0);

      if (weightVec != null)
        f.add(_parms._weights_column, weightVec);
      if (strataVec != null)
        f.add(_parms._strata_column, strataVec);
      if (startVec != null)
        f.add(_parms._start_column, startVec);
      if (stopVec != null)
        f.add(_parms._stop_column, stopVec);
      if (eventVec != null)
        f.add(_parms._response_column, eventVec);

      return f;
    }

    protected void initStats(final CoxPHModel model, final DataInfo dinfo, final double[] time) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;

      o._n = p.stopVec().length();
      o.data_info = dinfo;
      final int n_offsets = _offset == null ? 0 : 1;
      final int n_coef    = o.data_info.fullN() - n_offsets;
      final String[] coefNames = o.data_info.coefNames();
      o._coef_names = new String[n_coef];
      System.arraycopy(coefNames, 0, o._coef_names, 0, n_coef);
      o._coef = MemoryManager.malloc8d(n_coef);
      o._exp_coef = MemoryManager.malloc8d(n_coef);
      o._exp_neg_coef = MemoryManager.malloc8d(n_coef);
      o._se_coef = MemoryManager.malloc8d(n_coef);
      o._z_coef = MemoryManager.malloc8d(n_coef);
      o._var_coef = MemoryManager.malloc8d(n_coef, n_coef);
      o._mean_offset = MemoryManager.malloc8d(n_offsets);
      o._offset_names = new String[n_offsets];
      System.arraycopy(coefNames, n_coef, o._offset_names, 0, n_offsets);

      final int n_time = (int) dinfo._adaptedFrame.vec(p._stop_column).max() + 1;
      o._time = time;
      o._n_risk = MemoryManager.malloc8d(n_time);
      o._n_event = MemoryManager.malloc8d(n_time);
      o._n_censor = MemoryManager.malloc8d(n_time);
    }

    protected void calcCounts(CoxPHModel model, final CoxPHTask coxMR) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;

      o._n_missing = o._n - coxMR.n;
      o._n = coxMR.n;
      o._x_mean_cat = MemoryManager.malloc8d(coxMR.sumWeights.length, o.data_info.numCats());
      o._x_mean_num = MemoryManager.malloc8d(coxMR.sumWeights.length, o.data_info.numNums() - o._mean_offset.length);
      for (int s = 0; s < coxMR.sumWeights.length; s++) {
        System.arraycopy(coxMR.sumWeightedCatX[s], 0, o._x_mean_cat[s], 0, o._x_mean_cat[s].length);
        for (int j = 0; j < o._x_mean_cat[s].length; j++)
          o._x_mean_cat[s][j] /= coxMR.sumWeights[s];
        System.arraycopy(coxMR.sumWeightedNumX[s], 0, o._x_mean_num[s], 0, o._x_mean_num[s].length);
        for (int j = 0; j < o._x_mean_num[s].length; j++)
          o._x_mean_num[s][j] = o.data_info._normSub[j] + o._x_mean_num[s][j] / coxMR.sumWeights[s];
      }
      System.arraycopy(o.data_info._normSub, o.data_info.numNums() - o._mean_offset.length, o._mean_offset, 0, o._mean_offset.length);
      for (int t = 0; t < coxMR.countEvents.length; ++t) {
        o._total_event += coxMR.countEvents[t];
        if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
          o._n_risk[t]   = coxMR.sizeRiskSet[t];
          o._n_event[t]  = coxMR.sizeEvents[t];
          o._n_censor[t] = coxMR.sizeCensored[t];
        }
      }
      if (p._start_column == null)
        for (int t = o._n_risk.length - 2; t >= 0; --t)
          o._n_risk[t] += o._n_risk[t + 1];
    }

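    /**
     * Evaluates the log partial likelihood, gradient and Hessian at the current
     * coefficients. Efron ties are delegated to EfronMethod; for Breslow ties each
     * event time t with total event weight d_t contributes
     *
     *   sum(w * log(risk)) over events at t  -  d_t * log(rcumsumRisk[t])
     *
     * where rcumsumRisk[t] is the total weighted risk of subjects still at risk at
     * t, and the gradient/Hessian updates below are the corresponding first and
     * second derivatives with respect to the coefficients.
     */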
    protected ComputationState calcLoglik(DataInfo dinfo, ComputationState cs, CoxPHModel.CoxPHParameters p, CoxPHTask coxMR) {

      cs.reset();
      switch (p._ties) {
        case efron:
          return EfronMethod.calcLoglik(dinfo, coxMR, cs, _parms._single_node_mode);
        case breslow:
          final int n_coef = cs._n_coef;
          final int n_time = coxMR.sizeEvents.length;
          double newLoglik = 0;
          for (int i = 0; i < n_coef; i++)
            cs._gradient[i] = coxMR.sumXEvents[i];
          for (int t = n_time - 1; t >= 0; --t) {
            final double sizeEvents_t = coxMR.sizeEvents[t];
            if (sizeEvents_t > 0) {
              final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
              final double rcumsumRisk_t      = coxMR.rcumsumRisk[t];
              newLoglik += sumLogRiskEvents_t;
              newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t);
              for (int j = 0; j < n_coef; ++j) {
                final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t;
                cs._gradient[j] -= sizeEvents_t * dlogTerm;
                for (int k = 0; k < n_coef; ++k)
                  cs._hessian[j][k] -= sizeEvents_t *
                          (((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) -
                                  (dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t))));
              }
            }
          }
          cs._logLik =  newLoglik;
          return cs;
        default:
          throw new IllegalArgumentException("_ties method must be either efron or breslow");
      }
    }

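    /**
     * Derives the reported statistics from the current computation state: var(coef)
     * as the negated inverse Hessian, per-coefficient standard errors and z-scores,
     * the likelihood ratio test 2 * (loglik - null loglik), the Wald test
     * (coef - init)' * (-Hessian) * (coef - init), and, on iteration 0 (where the
     * coefficients still equal their initial values), the score test
     * gradient' * var(coef) * gradient.
     */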
    protected void calcModelStats(CoxPHModel model, final double[] newCoef, final ComputationState cs) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;

      final int n_coef = o._coef.length;
      final Matrix inv_hessian = new Matrix(cs._hessian).inverse();
      for (int j = 0; j < n_coef; ++j) {
        for (int k = 0; k <= j; ++k) {
          final double elem = -inv_hessian.get(j, k);
          o._var_coef[j][k] = elem;
          o._var_coef[k][j] = elem;
        }
      }
      for (int j = 0; j < n_coef; ++j) {
        o._coef[j]         = newCoef[j];
        o._exp_coef[j]     = Math.exp(o._coef[j]);
        o._exp_neg_coef[j] = Math.exp(- o._coef[j]);
        o._se_coef[j]      = Math.sqrt(o._var_coef[j][j]);
        o._z_coef[j]       = o._coef[j] / o._se_coef[j];
      }
      if (o._iter == 0) {
        o._null_loglik = cs._logLik;
        o._maxrsq = 1 - Math.exp(2 * o._null_loglik / o._n);
        o._score_test = 0;
        for (int j = 0; j < n_coef; ++j) {
          double sum = 0;
          for (int k = 0; k < n_coef; ++k)
            sum += o._var_coef[j][k] * cs._gradient[k];
          o._score_test += cs._gradient[j] * sum;
        }
      }
      o._loglik = cs._logLik;
      o._loglik_test = - 2 * (o._null_loglik - o._loglik);
      o._rsq = 1 - Math.exp(- o._loglik_test / o._n);
      o._wald_test = 0;
      for (int j = 0; j < n_coef; ++j) {
        double sum = 0;
        for (int k = 0; k < n_coef; ++k)
          sum -= cs._hessian[j][k] * (o._coef[k] - p._init);
        o._wald_test += (o._coef[j] - p._init) * sum;
      }
    }

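    /**
     * Computes the baseline cumulative hazard and its variance components per time
     * bin: the Breslow increment is sizeEvents[t] / rcumsumRisk[t], while the Efron
     * variant discounts the risk of the tied events in countEvents[t] fractional
     * steps. Also fills the per-stratum baseline hazard and baseline survival
     * (exp of minus the running hazard) matrices. The per-bin increments are turned
     * into cumulative sums at the end.
     */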
    protected void calcCumhaz_0(CoxPHModel model, final CoxPHTask coxMR) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;

      final int n_time = coxMR.sizeEvents.length;

      o._cumhaz_0 = MemoryManager.malloc8d(n_time);
      o._var_cumhaz_1 = MemoryManager.malloc8d(n_time);
      o._var_cumhaz_2 = Key.make(model._key + "_var_cumhaz_2");
      o._var_cumhaz_2_matrix = new CoxPHModel.FrameMatrix(o._var_cumhaz_2, n_time, o._coef.length);

      final int num_strata = coxMR._num_strata;
      o._baseline_hazard = Key.make(model._key + "_baseline_hazard");
      o._baseline_hazard_matrix = new CoxPHModel.FrameMatrix(o._baseline_hazard, n_time / num_strata, num_strata + 1);
      o._baseline_survival = Key.make(model._key + "_baseline_survival");
      o._baseline_survival_matrix = new CoxPHModel.FrameMatrix(o._baseline_survival, coxMR.sizeEvents.length / num_strata, num_strata + 1);

      final int n_coef = o._coef.length;
      int nz = 0;
      switch (p._ties) {
        case efron:
          for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
            final double sizeEvents_t   = coxMR.sizeEvents[t];
            final double sizeCensored_t = coxMR.sizeCensored[t];
            if (sizeEvents_t > 0 || sizeCensored_t > 0) {
              final long   countEvents_t   = coxMR.countEvents[t];
              final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
              final double rcumsumRisk_t   = coxMR.rcumsumRisk[t];
              final double avgSize = sizeEvents_t / countEvents_t;
              o._cumhaz_0[nz]     = 0;
              o._var_cumhaz_1[nz] = 0;
              for (int j = 0; j < n_coef; ++j)
                o._var_cumhaz_2_matrix.set(nz, j, 0);
              for (long e = 0; e < countEvents_t; ++e) {
                final double frac   = ((double) e) / ((double) countEvents_t);
                final double haz    = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t);
                final double haz_sq = haz * haz;
                o._cumhaz_0[nz]     += avgSize * haz;
                o._var_cumhaz_1[nz] += avgSize * haz_sq;
                for (int j = 0; j < n_coef; ++j)
                  o._var_cumhaz_2_matrix.add(nz, j, avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq));
              }
              nz++;
            }
          }
          break;
        case breslow:
          for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
            final double sizeEvents_t   = coxMR.sizeEvents[t];
            final double sizeCensored_t = coxMR.sizeCensored[t];
            if (sizeEvents_t > 0 || sizeCensored_t > 0) {
              final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
              final double cumhaz_0_nz   = sizeEvents_t / rcumsumRisk_t;
              o._cumhaz_0[nz]     = cumhaz_0_nz;
              o._var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t);
              for (int j = 0; j < n_coef; ++j)
                o._var_cumhaz_2_matrix.set(nz, j, (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz);
              nz++;
            }
          }
          break;
        default:
          throw new IllegalArgumentException("_ties method must be either efron or breslow");
      }

      double[] totalRisks = coxMR.totalRisk.clone();
      double[] sumHaz = new double[totalRisks.length]; // Java arrays are zero-initialized

      for (int t = 0; t < coxMR._time.length; ++t) {
        o._baseline_hazard_matrix.set(t,0, coxMR._time[t]);
        o._baseline_survival_matrix.set(t,0, coxMR._time[t]);

        for (int strata = 0; strata < num_strata; strata++) {
          final double weightEvent = coxMR.sizeEvents[t + coxMR._time.length * strata];
          final double sumRiskEvent = coxMR.sumRiskAllEvents[t + coxMR._time.length * strata];

          final double eventRisk = weightEvent / totalRisks[strata];

          totalRisks[strata] -= sumRiskEvent;
          sumHaz[strata] += eventRisk;
          
          o._baseline_hazard_matrix.set(t, strata + 1, eventRisk);
          o._baseline_survival_matrix.set(t, strata + 1, Math.exp(-sumHaz[strata]));
        }
      }
      
      for (int t = 1; t < o._cumhaz_0.length; ++t) {
        o._cumhaz_0[t]     = o._cumhaz_0[t - 1]     + o._cumhaz_0[t];
        o._var_cumhaz_1[t] = o._var_cumhaz_1[t - 1] + o._var_cumhaz_1[t];
        for (int j = 0; j < n_coef; ++j)
          o._var_cumhaz_2_matrix.set(t, j, o._var_cumhaz_2_matrix.get(t - 1, j) + o._var_cumhaz_2_matrix.get(t, j));
      }

      // install the FrameMatrix objects into the DKV
      o._var_cumhaz_2_matrix.toFrame(o._var_cumhaz_2);
      final Frame baselineHazardAsFrame = o._baseline_hazard_matrix.toFrame(o._baseline_hazard);
      final Frame baselineSurvivalAsFrame = o._baseline_survival_matrix.toFrame(o._baseline_survival);

      if (null == o._strataMap || 0 == o._strataMap.size()) {
        baselineHazardAsFrame.setNames(new String[]{"t", "baseline hazard"});
        baselineSurvivalAsFrame.setNames(new String[]{"t", "baseline survival"});
      } else {
        final Vec[] strataCols = train().vecs(_input_parms._stratify_by);

        List<String> names = o._strataMap.entrySet().stream()
                .sorted(Comparator.comparingInt(e -> e.getValue()._val))
                .map(Map.Entry::getKey)
                .map(i -> i._gs)
                .map(a -> IntStream.range(0, strataCols.length)
                                   .mapToObj(i -> strataCols[i].factor((int) a[i]))
                )
                .map(s -> s.collect(Collectors.joining(", ", "(", ")")))
                .collect(toList());
        names.add(0, "t");
        baselineHazardAsFrame.setNames(names.toArray(new String[0]));
        baselineSurvivalAsFrame.setNames(names.toArray(new String[0]));
      }
    }

    @Override
    public void computeImpl() {
      CoxPHModel model = null;
      try {
        init(true);

        final double[] time = CollectTimes.collect(_parms.stopVec(), _parms._single_node_mode);

        _job.update(0, "Initializing model training");

        IcedHashMap<AstGroup.G, IcedInt> strataMap = new IcedHashMap<>();
        Frame f = reorderTrainFrameColumns(strataMap, time);

        int nResponses = (_parms.startVec() == null ? 2 : 3) + (_parms.isStratified() ? 1 : 0);
        final DataInfo dinfo = new DataInfo(f, null, nResponses, _parms._use_all_factor_levels, 
                TransformType.DEMEAN, TransformType.NONE, true, false, false, 
                hasWeightCol(), false, false, _parms.interactionSpec()).disableIntercept();
        Scope.track_generic(dinfo);
        DKV.put(dinfo);

        // The model to be built
        CoxPHModel.CoxPHOutput output = new CoxPHModel.CoxPHOutput(CoxPH.this, dinfo._adaptedFrame, train(), strataMap);
        model = new CoxPHModel(_result, _parms, output);
        model.delete_and_lock(_job);

        initStats(model, dinfo, time);
        ScoringHistory sc = new ScoringHistory(_parms._max_iterations + 1);

        final int n_offsets = (_offset == null) ? 0 : 1;
        final int n_coef = dinfo.fullN() - n_offsets;
        final double[] step = MemoryManager.malloc8d(n_coef);
        final double[] oldCoef = MemoryManager.malloc8d(n_coef);
        final double[] newCoef = MemoryManager.malloc8d(n_coef);
        Arrays.fill(step, Double.NaN);
        Arrays.fill(oldCoef, Double.NaN);
        for (int j = 0; j < n_coef; ++j)
          newCoef[j] = model._parms._init;
        double logLik = -Double.MAX_VALUE;
        final boolean has_start_column = (model._parms.startVec() != null);
        final boolean has_weights_column = (_weights != null);
        final ComputationState cs = new ComputationState(n_coef);
        Timer iterTimer = null;
        CoxPHTask coxMR = null;
        _job.update(1, "Running iteration 0");
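        // Newton's method with step-halving: every iteration runs one CoxPHTask pass
        // to accumulate the statistics for the log partial likelihood. If the
        // likelihood improved, a full Newton step is prepared
        // (newCoef = oldCoef - inv(Hessian) * gradient, via _var_coef = -inv(Hessian));
        // otherwise the previous step is halved and retried. The loop exits early
        // once the log relative error of the log-likelihood reaches _lre_min.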
        for (int i = 0; i <= model._parms._max_iterations; ++i) {
          iterTimer = new Timer();
          model._output._iter = i;

          Timer aggregTimer = new Timer();
          coxMR = new CoxPHTask(dinfo, newCoef, time, (long) response().min() /* min event */,
                  n_offsets, has_start_column, dinfo._adaptedFrame.vec(_parms._strata_column), has_weights_column,
                  _parms._ties).doAll(dinfo._adaptedFrame, _parms._single_node_mode);
          Log.info("CoxPHTask: iter=" + i + ", time=" + aggregTimer.toString());
          _job.update(1);

          Timer loglikTimer = new Timer();
          final double newLoglik = calcLoglik(dinfo, cs, _parms, coxMR)._logLik;
          Log.info("LogLik: iter=" + i + ", time=" + loglikTimer.toString() + ", logLik=" + newLoglik);
          model._output._scoring_history = sc.addIterationScore(i, newLoglik).to2dTable(i+1);
          
          if (newLoglik > logLik) {
            if (i == 0)
              calcCounts(model, coxMR);

            calcModelStats(model, newCoef, cs);

            if (newLoglik == 0)
              model._output._lre = -Math.log10(Math.abs(logLik - newLoglik));
            else
              model._output._lre = -Math.log10(Math.abs((logLik - newLoglik) / newLoglik));
            if (model._output._lre >= model._parms._lre_min)
              break;

            Arrays.fill(step, 0);
            for (int j = 0; j < n_coef; ++j)
              for (int k = 0; k < n_coef; ++k)
                step[j] -= model._output._var_coef[j][k] * cs._gradient[k];
            for (int j = 0; j < n_coef; ++j)
              if (Double.isNaN(step[j]) || Double.isInfinite(step[j]))
                break;

            logLik = newLoglik;
            System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
          } else {
            for (int j = 0; j < n_coef; ++j)
              step[j] /= 2;
          }

          for (int j = 0; j < n_coef; ++j)
            newCoef[j] = oldCoef[j] - step[j];

          model.update(_job);
          _job.update(1, "Iteration = " + i + "/" + model._parms._max_iterations + ", logLik = " + logLik);
          if (i != model._parms._max_iterations)
            Log.info("CoxPH Iteration: iter=" + i + ", " + iterTimer.toString());
        }

        if (_parms._calc_cumhaz && coxMR != null) {
          calcCumhaz_0(model, coxMR);
        }

        if (iterTimer != null) {
          Log.info("CoxPH Last Iteration: " + iterTimer.toString());
        }
        
        final boolean _skip_scoring = H2O.getSysBoolProperty("debug.skipScoring", false); 
        
        if (!_skip_scoring) {
          model.update(_job);
          model.score(_parms.train()).delete();
          model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train());
          model._output._concordance = ((ModelMetricsRegressionCoxPH) model._output._training_metrics).concordance();
        }
        
        model._output._model_summary = generateSummary(model._output);
        Log.info(model._output._model_summary);

        model.update(_job);
      } finally {
        if (model != null) model.unlock(_job);
      }
    }

  }

  private TwoDimTable generateSummary(CoxPHModel.CoxPHOutput output) {
    String[] names = new String[]{"Formula", "Likelihood ratio test", "Concordance", "Number of Observations", "Number of Events"};
    String[] types = new String[]{"string", "double", "double", "long", "long"};
    String[] formats = new String[]{"%s", "%.5f", "%.5f", "%d", "%d"};
    TwoDimTable summary = new TwoDimTable("CoxPH Model", "summary", new String[]{""}, names, types, formats, "");
    summary.set(0, 0, output._formula);
    summary.set(0, 1, output._loglik_test);
    summary.set(0, 2, output._concordance);
    summary.set(0, 3, output._n);
    summary.set(0, 4, output._total_event);
    return summary;
  }

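  /**
   * Single pass over the (reordered, discretized) training data accumulating the
   * sufficient statistics of the partial likelihood per time bin and stratum:
   * event/censoring weights, risk-set sizes, weighted risks exp(x*beta), and their
   * x-weighted (plus, for Breslow ties, xx-weighted) sums used to form reverse
   * cumulative sums in postGlobal().
   */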
  protected static class CoxPHTask extends CPHBaseTask<CoxPHTask> {
    final double[] _beta;
    final double[] _time;
    final int      _n_offsets;
    final boolean  _has_start_column;
    final boolean  _has_strata_column;
    final boolean  _has_weights_column;
    final long     _min_event;
    final int      _num_strata; // = 1 if the model is not stratified
    final boolean  _isBreslow;

    // OUT
    long         n;
    double[]     sumWeights;
    double[][]   sumWeightedCatX;
    double[][]   sumWeightedNumX;
    double[]     sizeRiskSet;
    double[]     sizeCensored;
    double[]     sizeEvents;
    long[]       countEvents;
    double[]     sumXEvents;
    double[]     sumRiskEvents;
    double[]     sumRiskAllEvents;
    double[][]   sumXRiskEvents;
    double[]     sumLogRiskEvents;
    double[]     rcumsumRisk;
    double[][]   rcumsumXRisk;
    double[]     totalRisk;

    // Breslow only
    double[][][] rcumsumXXRisk;

    CoxPHTask(DataInfo dinfo, final double[] beta, final double[] time, final long min_event,
              final int n_offsets, final boolean has_start_column, Vec strata_column, final boolean has_weights_column,
              final CoxPHModel.CoxPHParameters.CoxPHTies ties) {
      super(dinfo);
      _beta               = beta;
      _time = time;
      _min_event          = min_event;
      _n_offsets          = n_offsets;
      _has_start_column   = has_start_column;
      _has_strata_column  = strata_column != null;
      _has_weights_column = has_weights_column;
      _num_strata         = _has_strata_column ? 1 + (int) strata_column.max() : 1;
      _isBreslow          = CoxPHModel.CoxPHParameters.CoxPHTies.breslow.equals(ties);
    }

    @Override
    protected void chunkInit(){
      final int n_time = _time.length * _num_strata;
      final int n_coef = _beta.length;

      sumWeights       = MemoryManager.malloc8d(_num_strata);
      sumWeightedCatX  = MemoryManager.malloc8d(_num_strata, _dinfo.numCats());
      sumWeightedNumX  = MemoryManager.malloc8d(_num_strata, _dinfo.numNums());
      sizeRiskSet      = MemoryManager.malloc8d(n_time);
      sizeCensored     = MemoryManager.malloc8d(n_time);
      sizeEvents       = MemoryManager.malloc8d(n_time);
      countEvents      = MemoryManager.malloc8(n_time);
      sumRiskEvents    = MemoryManager.malloc8d(n_time);
      sumRiskAllEvents = MemoryManager.malloc8d(n_time);
      sumLogRiskEvents = MemoryManager.malloc8d(n_time);
      rcumsumRisk      = MemoryManager.malloc8d(n_time);
      sumXEvents       = MemoryManager.malloc8d(n_coef);
      sumXRiskEvents   = MemoryManager.malloc8d(n_time, n_coef);
      rcumsumXRisk     = MemoryManager.malloc8d(n_time, n_coef);
      totalRisk        = MemoryManager.malloc8d(_num_strata);

      if (_isBreslow) { // Breslow only
        rcumsumXXRisk = MemoryManager.malloc8d(n_time, n_coef, n_coef);
      }
    }

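    // Response columns are read from the end of the row in the fixed order set up
    // by reorderTrainFrameColumns: event, stop time, optional start time, optional
    // strata id. Rows with an NA strata id are skipped.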
    @Override
    protected void processRow(Row row) {
      n++;
      double [] response = row.response;
      int ncats = row.nBins;
      int [] cats = row.binIds;
      double [] nums = row.numVals;
      final double weight = _has_weights_column ? row.weight : 1.0;
      if (weight <= 0) {
        throw new IllegalArgumentException("weights must be positive values");
      }
      int respIdx = response.length - 1;
      final long event = (long) (response[respIdx--] - _min_event);
      final int t2 = (int) response[respIdx--];
      final int t1 = _has_start_column ? (int) response[respIdx--] : -1;
      final double strata = _has_strata_column ? response[respIdx--] : 0;
      assert respIdx == -1 : "expected to use all response data";
      if (Double.isNaN(strata))
        return; // skip this row

      final int strataId = (int) strata;
      final int numStart = _dinfo.numStart();
      sumWeights[strataId] += weight;
      for (int j = 0; j < ncats; ++j) {
        sumWeightedCatX[strataId][cats[j]] += weight;
      }
      for (int j = 0; j < nums.length; ++j) {
        sumWeightedNumX[strataId][j] += weight * nums[j];
      }
      double logRisk = 0;
      for (int j = 0; j < ncats; ++j)
        logRisk += _beta[cats[j]];
      for (int j = 0; j < nums.length - _n_offsets; ++j)
        logRisk += nums[j] * _beta[numStart + j];
      for (int j = nums.length - _n_offsets; j < nums.length; ++j)
        logRisk += nums[j];
      final double risk = weight * Math.exp(logRisk);
      logRisk *= weight;
      totalRisk[strataId] += risk;
      sumRiskAllEvents[t2] += risk;
      if (event > 0) {
        countEvents[t2]++;
        sizeEvents[t2]       += weight;
        sumLogRiskEvents[t2] += logRisk;
        sumRiskEvents[t2]    += risk;
      } else
        sizeCensored[t2] += weight;
      if (_has_start_column) {
        for (int t = t1; t <= t2; ++t)
          sizeRiskSet[t] += weight;
        for (int t = t1; t <= t2; ++t)
          rcumsumRisk[t] += risk;
      } else {
        sizeRiskSet[t2]  += weight;
        rcumsumRisk[t2]  += risk;
      }

      final int ntotal = ncats + (nums.length - _n_offsets);
      final int numStartIter = numStart - ncats;
      for (int jit = 0; jit < ntotal; ++jit) {
        final boolean jIsCat = jit < ncats;
        final int j = jIsCat ? cats[jit] : numStartIter + jit;
        final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
        final double xRisk = x1 * risk;
        if (event > 0) {
          sumXEvents[j] += weight * x1;
          sumXRiskEvents[t2][j] += xRisk;
        }
        rcumsumXRisk[t2][j] += xRisk;
        if (_has_start_column && (t1 % _time.length > 0)) {
          rcumsumXRisk[t1 - 1][j] -= xRisk;
        }
        if (_isBreslow) { // Breslow only
          for (int kit = 0; kit < ntotal; ++kit) {
            final boolean kIsCat = kit < ncats;
            final int k = kIsCat ? cats[kit] : numStartIter + kit;
            final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
            final double xxRisk = x2 * xRisk;
            if (_has_start_column) {
              for (int t = t1; t <= t2; ++t)
                rcumsumXXRisk[t][j][k] += xxRisk;
            } else {
              rcumsumXXRisk[t2][j][k] += xxRisk;
            }
          }
        }
      }
    }

    @Override
    public void reduce(CoxPHTask that) {
      n += that.n;
      ArrayUtils.add(sumWeights,       that.sumWeights);
      ArrayUtils.add(sumWeightedCatX,  that.sumWeightedCatX);
      ArrayUtils.add(sumWeightedNumX,  that.sumWeightedNumX);
      ArrayUtils.add(sizeRiskSet,      that.sizeRiskSet);
      ArrayUtils.add(sizeCensored,     that.sizeCensored);
      ArrayUtils.add(sizeEvents,       that.sizeEvents);
      ArrayUtils.add(countEvents,      that.countEvents);
      ArrayUtils.add(sumXEvents,       that.sumXEvents);
      ArrayUtils.add(sumRiskEvents,    that.sumRiskEvents);
      ArrayUtils.add(sumRiskAllEvents, that.sumRiskAllEvents);
      ArrayUtils.add(sumXRiskEvents,   that.sumXRiskEvents);
      ArrayUtils.add(sumLogRiskEvents, that.sumLogRiskEvents);
      ArrayUtils.add(rcumsumRisk,      that.rcumsumRisk);
      ArrayUtils.add(rcumsumXRisk,     that.rcumsumXRisk);
      ArrayUtils.add(totalRisk,        that.totalRisk);
      if (_isBreslow) { // Breslow only
        ArrayUtils.add(rcumsumXXRisk,    that.rcumsumXXRisk);
      }
    }

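    // Turns per-bin sums into reverse cumulative sums (rcumsum*). The
    // ((t + 1) % _time.length) == 0 test stops accumulation at stratum boundaries,
    // because each stratum owns a contiguous block of _time.length bins. With a
    // start column, rcumsumRisk and rcumsumXXRisk were already accumulated over the
    // whole [t1, t2] range in processRow and need no post-processing here.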
    @Override
    protected void postGlobal() {
      for (int t = rcumsumXRisk.length - 2; t >= 0; --t)
        for (int j = 0; j < rcumsumXRisk[t].length; ++j)
          rcumsumXRisk[t][j] += ((t + 1) % _time.length) == 0 ? 0 : rcumsumXRisk[t + 1][j];

      if (! _has_start_column) {
        for (int t = rcumsumRisk.length - 2; t >= 0; --t)
          rcumsumRisk[t] += ((t + 1) % _time.length) == 0 ? 0 : rcumsumRisk[t + 1];

        if (_isBreslow) { // Breslow only
          for (int t = rcumsumXXRisk.length - 2; t >= 0; --t)
            for (int j = 0; j < rcumsumXXRisk[t].length; ++j)
              for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k)
                rcumsumXXRisk[t][j][k] += ((t + 1) % _time.length) == 0 ? 0 : rcumsumXXRisk[t + 1][j][k];
        }
      }
    }
  }

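  /**
   * Collects the sorted domain of distinct stop-time values; these define the time
   * bins used throughout training. Throws CollectTimesException once the number of
   * distinct times exceeds MAX_TIME_BINS.
   */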
  private static class CollectTimes extends VecUtils.CollectDoubleDomain {
    private CollectTimes() {
      super(new double[0], MAX_TIME_BINS);
    }
    static double[] collect(Vec timeVec, boolean runLocal) {
      return new CollectTimes().doAll(timeVec, runLocal).domain();
    }
    @Override
    protected void onMaxDomainExceeded(int maxDomainSize, int currentSize) {
      throw new CollectTimesException("number of distinct stop times is at least " + currentSize + "; maximum number allowed is " + maxDomainSize);
    }
  }

  private static class CollectTimesException extends RuntimeException {
    private CollectTimesException(String message) {
      super(message);
    }
  }

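  /**
   * Mutable holder for one likelihood evaluation: the log partial likelihood, its
   * gradient, and its Hessian, sized for n_coef coefficients.
   */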
  static class ComputationState {
    final int _n_coef;
    double _logLik;
    double[] _gradient;
    double[][] _hessian;

    ComputationState(int n_coef) {
      _n_coef = n_coef;
      _logLik = 0;
      _gradient = MemoryManager.malloc8d(n_coef);
      _hessian = MemoryManager.malloc8d(n_coef, n_coef);
    }

    void reset() {
      _logLik = 0;
      for (int j = 0; j < _n_coef; ++j)
        _gradient[j] = 0;
      for (int j = 0; j < _n_coef; ++j)
        for (int k = 0; k < _n_coef; ++k)
          _hessian[j][k] = 0;
    }

  }

  private static class ScoringHistory {
    private long[] _scoringTimes;
    private double[] _logLiks;

    public ScoringHistory(int iterCnt) {
      _scoringTimes = new long[iterCnt];
      _logLiks = new double[iterCnt];
    }

    public ScoringHistory addIterationScore(int iter, double logLik) {
      _scoringTimes[iter] = System.currentTimeMillis();
      _logLiks[iter] = logLik;
      return this;
    }

    public TwoDimTable to2dTable(int iterCnt) {
      String[] cnames = new String[]{"timestamp", "duration", "iterations", "logLik"};
      String[] ctypes = new String[]{"string", "string", "int", "double"};
      String[] cformats = new String[]{"%s", "%s", "%d", "%.5f"};
      TwoDimTable res = new TwoDimTable("Scoring History", "", new String[iterCnt], cnames, ctypes, cformats, "");
      DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
      for (int i = 0; i < iterCnt; i++) {
        int col = 0;
        res.set(i, col++, fmt.print(_scoringTimes[i]));
        res.set(i, col++, PrettyPrint.msecs(_scoringTimes[i] - _scoringTimes[0], true));
        res.set(i, col++, i);
        res.set(i, col++, _logLiks[i]);
      }
      return res;
    }
  }

}