// hex.coxph.CoxPH
package hex.coxph;
import Jama.Matrix;
import hex.FrameTask;
import hex.FrameTask.DataInfo;
import hex.SupervisedModelBuilder;
// import hex.schemas.CoxPHV2;
import hex.schemas.ModelBuilderSchema;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.*;
import java.util.Arrays;
/**
 * Cox Proportional Hazards model builder, implemented on top of MRTask.
 */
public class CoxPH extends SupervisedModelBuilder {
  // Registers the builder under the job name "CoxPHLearning" and runs the cheap
  // (non-expensive) argument validation immediately.
  public CoxPH( CoxPHModel.CoxPHParameters parms ) { super("CoxPHLearning",parms); init(false); }
  /** REST schema for this builder — not implemented yet (CoxPHV2 is still commented out in the imports). */
  public ModelBuilderSchema schema() {
    H2O.unimpl();
    return null;
    // return new CoxPHV2();
  }
  /** Start the Cox PH training Job on an F/J thread. */
  @Override public Job trainModel() {
    CoxPHDriver cd = new CoxPHDriver();
    cd.setModelBuilderTrain(_train);
    // start() submits the driver to the fork/join pool; iter_max is passed as the job's work estimate.
    CoxPH cph = (CoxPH) start(cd, _parms.iter_max);
    return cph;
  }
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();". This call is made
* by the front-end whenever the GUI is clicked, and needs to be fast;
* heavy-weight prep needs to wait for the trainModel() call.
*/
@Override public void init(boolean expensive) {
super.init(expensive);
if ((_parms.start_column != null) && !_parms.start_column.isInt())
error("start_column", "start time must be null or of type integer");
if (!_parms.stop_column.isInt())
error("stop_column", "stop time must be of type integer");
if (!_parms.event_column.isInt() && !_parms.event_column.isEnum())
error("event_column", "event must be of type integer or factor");
if (Double.isNaN(_parms.lre_min) || _parms.lre_min <= 0)
error("lre_min", "lre_min must be a positive number");
if (_parms.iter_max < 1)
error("iter_max", "iter_max must be a positive integer");
final int MAX_TIME_BINS = 10000;
final long min_time = (_parms.start_column == null) ? (long) _parms.stop_column.min() : (long) _parms.start_column.min() + 1;
final int n_time = (int) (_parms.stop_column.max() - min_time + 1);
if (n_time < 1)
error("start_column", "start times must be strictly less than stop times");
if (n_time > MAX_TIME_BINS)
error("stop_column", "number of distinct stop times is " + n_time + "; maximum number allowed is " + MAX_TIME_BINS);
}
public class CoxPHDriver extends H2O.H2OCountedCompleter {
    // Frame the driver trains on; set by trainModel() before the driver is
    // started and then mutated in place by the side-effect methods below.
    private Frame _modelBuilderTrain = null;
    public void setModelBuilderTrain(Frame v) {
      _modelBuilderTrain = v;
    }
private void applyScoringFrameSideEffects() {
final int offset_ncol = _parms.offset_columns == null ? 0 : _parms.offset_columns.length;
if (offset_ncol == 0) {
return;
}
int numCols = _modelBuilderTrain.numCols();
String responseVecName = _modelBuilderTrain.names()[numCols-1];
Vec responseVec = _modelBuilderTrain.remove(numCols-1);
for (int i = 0; i < offset_ncol; i++) {
Vec offsetVec = _parms.offset_columns[i];
int idxInRawFrame = _train.find(offsetVec);
if (idxInRawFrame < 0) {
throw new RuntimeException("CoxPHDriver failed to find offsetVec");
}
String offsetVecName = _parms.train().names()[idxInRawFrame];
_modelBuilderTrain.add(offsetVecName, offsetVec);
}
_modelBuilderTrain.add(responseVecName, responseVec);
}
private void applyTrainingFrameSideEffects() {
int numCols = _modelBuilderTrain.numCols();
String responseVecName = _modelBuilderTrain.names()[numCols-1];
Vec responseVec = _modelBuilderTrain.remove(numCols-1);
final boolean use_weights_column = (_parms.weights_column != null);
final boolean use_start_column = (_parms.start_column != null);
if (use_weights_column) {
Vec weightsVec = _parms.weights_column;
int idxInRawFrame = _train.find(weightsVec);
if (idxInRawFrame < 0) {
throw new RuntimeException("CoxPHDriver failed to find weightVec");
}
String weightsVecName = _parms.train().names()[idxInRawFrame];
_modelBuilderTrain.add(weightsVecName, weightsVec);
}
if (use_start_column) {
Vec startVec = _parms.start_column;
int idxInRawFrame = _train.find(startVec);
if (idxInRawFrame < 0) {
throw new RuntimeException("CoxPHDriver failed to find startVec");
}
String startVecName = _parms.train().names()[idxInRawFrame];
_modelBuilderTrain.add(startVecName, startVec);
}
{
Vec stopVec = _parms.stop_column;
int idxInRawFrame = _train.find(stopVec);
if (idxInRawFrame < 0) {
throw new RuntimeException("CoxPHDriver failed to find stopVec");
}
String stopVecName = _parms.train().names()[idxInRawFrame];
_modelBuilderTrain.add(stopVecName, stopVec);
}
_modelBuilderTrain.add(responseVecName, responseVec);
}
    /** Sizes and zero-fills every per-coefficient and per-time output array on the model. */
    protected void initStats(final CoxPHModel model, final DataInfo dinfo) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;
      o.n = p.stop_column.length();
      o.data_info = dinfo;
      // Offsets occupy the trailing numeric columns of the expanded design
      // matrix, so the number of fitted coefficients is fullN() minus them.
      final int n_offsets = (p.offset_columns == null) ? 0 : p.offset_columns.length;
      final int n_coef = o.data_info.fullN() - n_offsets;
      final String[] coefNames = o.data_info.coefNames();
      o.coef_names = new String[n_coef];
      System.arraycopy(coefNames, 0, o.coef_names, 0, n_coef);
      // Per-coefficient estimates and derived statistics.
      o.coef = MemoryManager.malloc8d(n_coef);
      o.exp_coef = MemoryManager.malloc8d(n_coef);
      o.exp_neg_coef = MemoryManager.malloc8d(n_coef);
      o.se_coef = MemoryManager.malloc8d(n_coef);
      o.z_coef = MemoryManager.malloc8d(n_coef);
      o.gradient = MemoryManager.malloc8d(n_coef);
      o.hessian = malloc2DArray(n_coef, n_coef);
      o.var_coef = malloc2DArray(n_coef, n_coef);
      // n_coef - (numeric non-offset columns) = number of expanded categorical columns.
      o.x_mean_cat = MemoryManager.malloc8d(n_coef - (o.data_info._nums - n_offsets));
      o.x_mean_num = MemoryManager.malloc8d(o.data_info._nums - n_offsets);
      o.mean_offset = MemoryManager.malloc8d(n_offsets);
      o.offset_names = new String[n_offsets];
      // Offset names follow the coefficient names in coefNames.
      System.arraycopy(coefNames, n_coef, o.offset_names, 0, n_offsets);
      final Vec start_column = p.start_column;
      final Vec stop_column = p.stop_column;
      // min_time mirrors init(): entry times are exclusive, hence start.min()+1.
      o.min_time = p.start_column == null ? (long) stop_column.min():
        (long) start_column.min() + 1;
      o.max_time = (long) stop_column.max();
      // Per-time arrays are sized by the number of DISTINCT stop times, not the
      // full [min_time, max_time] range; calcCounts compacts into them.
      final int n_time = new Vec.CollectDomain().doAll(stop_column).domain().length;
      o.time = MemoryManager.malloc8(n_time);
      o.n_risk = MemoryManager.malloc8d(n_time);
      o.n_event = MemoryManager.malloc8d(n_time);
      o.n_censor = MemoryManager.malloc8d(n_time);
      o.cumhaz_0 = MemoryManager.malloc8d(n_time);
      o.var_cumhaz_1 = MemoryManager.malloc8d(n_time);
      o.var_cumhaz_2 = malloc2DArray(n_time, n_coef);
    }
    /**
     * Derives row counts, covariate means, and the per-time risk/event/censor
     * tables from the first CoxPHTask pass.
     */
    protected void calcCounts(CoxPHModel model, final CoxPHTask coxMR) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;
      // o.n was initialized to the full column length; rows the MRTask did not
      // process (presumably rows with missing values — confirm in FrameTask)
      // are reported as missing.
      o.n_missing = o.n - coxMR.n;
      o.n = coxMR.n;
      // Weighted means: categorical columns collect plain weights, numeric
      // columns were DEMEANed so the stored mean (_normSub) is added back.
      for (int j = 0; j < o.x_mean_cat.length; j++)
        o.x_mean_cat[j] = coxMR.sumWeightedCatX[j] / coxMR.sumWeights;
      for (int j = 0; j < o.x_mean_num.length; j++)
        o.x_mean_num[j] = coxMR.dinfo()._normSub[j] + coxMR.sumWeightedNumX[j] / coxMR.sumWeights;
      // Offset means are the tail of _normSub, after the regular numeric columns.
      System.arraycopy(coxMR.dinfo()._normSub, o.x_mean_num.length, o.mean_offset, 0, o.mean_offset.length);
      // Compact the dense per-time-bin arrays down to only the time bins that
      // actually saw an event or a censoring.
      int nz = 0;
      for (int t = 0; t < coxMR.countEvents.length; ++t) {
        o.total_event += coxMR.countEvents[t];
        if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) {
          o.time[nz] = o.min_time + t;
          o.n_risk[nz] = coxMR.sizeRiskSet[t];
          o.n_event[nz] = coxMR.sizeEvents[t];
          o.n_censor[nz] = coxMR.sizeCensored[t];
          nz++;
        }
      }
      // Without a start column the risk sets are nested, so a suffix sum turns
      // per-time weights into the size of the risk set at each time.
      if (p.start_column == null)
        for (int t = o.n_risk.length - 2; t >= 0; --t)
          o.n_risk[t] += o.n_risk[t + 1];
    }
    /**
     * Computes the Cox partial log-likelihood at the current coefficients, and
     * fills the model's gradient and Hessian as side effects.  Ties are handled
     * by either the Efron or the Breslow approximation, per {@code p.ties}.
     *
     * @return the new partial log-likelihood
     */
    protected double calcLoglik(CoxPHModel model, final CoxPHTask coxMR) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;
      final int n_coef = o.coef.length;
      final int n_time = coxMR.sizeEvents.length;
      double newLoglik = 0;
      // Gradient and Hessian are accumulated from scratch each call.
      for (int j = 0; j < n_coef; ++j)
        o.gradient[j] = 0;
      for (int j = 0; j < n_coef; ++j)
        for (int k = 0; k < n_coef; ++k)
          o.hessian[j][k] = 0;
      switch (p.ties) {
      case efron:
        // Efron's correction sums over each tied event separately; the per-time
        // contributions are independent, so each time bin is computed in its
        // own RecursiveAction and the partial results are reduced afterwards.
        final double[] newLoglik_t = MemoryManager.malloc8d(n_time);
        final double[][] gradient_t = malloc2DArray(n_time, n_coef);
        final double[][][] hessian_t = malloc3DArray(n_time, n_coef, n_coef);
        ForkJoinTask[] fjts = new ForkJoinTask[n_time];
        for (int t = n_time - 1; t >= 0; --t) {
          final int _t = t;
          fjts[t] = new RecursiveAction() {
            @Override protected void compute() {
              final double sizeEvents_t = coxMR.sizeEvents[_t];
              if (sizeEvents_t > 0) {
                final long countEvents_t = coxMR.countEvents[_t];
                final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[_t];
                final double sumRiskEvents_t = coxMR.sumRiskEvents[_t];
                final double rcumsumRisk_t = coxMR.rcumsumRisk[_t];
                // Average event weight; with unit weights this is exactly 1.
                final double avgSize = sizeEvents_t / countEvents_t;
                newLoglik_t[_t] = sumLogRiskEvents_t;
                System.arraycopy(coxMR.sumXEvents[_t], 0, gradient_t[_t], 0, n_coef);
                // Efron: the e-th tied event sees the risk set minus a fraction
                // e/d of the tied events' own risk.
                for (long e = 0; e < countEvents_t; ++e) {
                  final double frac = ((double) e) / ((double) countEvents_t);
                  final double term = rcumsumRisk_t - frac * sumRiskEvents_t;
                  newLoglik_t[_t] -= avgSize * Math.log(term);
                  for (int j = 0; j < n_coef; ++j) {
                    // First derivative of log(term) w.r.t. beta_j ...
                    final double djTerm = coxMR.rcumsumXRisk[_t][j] - frac * coxMR.sumXRiskEvents[_t][j];
                    final double djLogTerm = djTerm / term;
                    gradient_t[_t][j] -= avgSize * djLogTerm;
                    for (int k = 0; k < n_coef; ++k) {
                      // ... and second derivative for the Hessian.
                      final double dkTerm = coxMR.rcumsumXRisk[_t][k] - frac * coxMR.sumXRiskEvents[_t][k];
                      final double djkTerm = coxMR.rcumsumXXRisk[_t][j][k] - frac * coxMR.sumXXRiskEvents[_t][j][k];
                      hessian_t[_t][j][k] -= avgSize * (djkTerm / term - (djLogTerm * (dkTerm / term)));
                    }
                  }
                }
              }
            }
          };
        }
        ForkJoinTask.invokeAll(fjts);
        // Reduce the per-time partials into the model output.
        for (int t = 0; t < n_time; ++t)
          newLoglik += newLoglik_t[t];
        for (int t = 0; t < n_time; ++t)
          for (int j = 0; j < n_coef; ++j)
            o.gradient[j] += gradient_t[t][j];
        for (int t = 0; t < n_time; ++t)
          for (int j = 0; j < n_coef; ++j)
            for (int k = 0; k < n_coef; ++k)
              o.hessian[j][k] += hessian_t[t][j][k];
        break;
      case breslow:
        // Breslow: all tied events at time t share the same (full) risk set,
        // which makes the per-time contribution a single closed-form term.
        for (int t = n_time - 1; t >= 0; --t) {
          final double sizeEvents_t = coxMR.sizeEvents[t];
          if (sizeEvents_t > 0) {
            final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t];
            final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
            newLoglik += sumLogRiskEvents_t;
            newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t);
            for (int j = 0; j < n_coef; ++j) {
              final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t;
              o.gradient[j] += coxMR.sumXEvents[t][j];
              o.gradient[j] -= sizeEvents_t * dlogTerm;
              for (int k = 0; k < n_coef; ++k)
                o.hessian[j][k] -= sizeEvents_t *
                  (((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) -
                    (dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t))));
            }
          }
        }
        break;
      default:
        throw new IllegalArgumentException("ties method must be either efron or breslow");
      }
      return newLoglik;
    }
    /**
     * Updates coefficient estimates and the summary statistics that depend on
     * them (standard errors, z-scores, likelihood-ratio/score/Wald tests).
     * Must be called after calcLoglik has filled the gradient and Hessian.
     */
    protected void calcModelStats(CoxPHModel model, final double[] newCoef, final double newLoglik) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;
      final int n_coef = o.coef.length;
      // Covariance of the estimates is -H^{-1}; copy symmetrically to smooth
      // out numerical asymmetry in the inverse.
      final Matrix inv_hessian = new Matrix(o.hessian).inverse();
      for (int j = 0; j < n_coef; ++j) {
        for (int k = 0; k <= j; ++k) {
          final double elem = -inv_hessian.get(j, k);
          o.var_coef[j][k] = elem;
          o.var_coef[k][j] = elem;
        }
      }
      for (int j = 0; j < n_coef; ++j) {
        o.coef[j] = newCoef[j];
        o.exp_coef[j] = Math.exp(o.coef[j]);         // hazard ratio
        o.exp_neg_coef[j] = Math.exp(- o.coef[j]);   // inverse hazard ratio
        o.se_coef[j] = Math.sqrt(o.var_coef[j][j]);
        o.z_coef[j] = o.coef[j] / o.se_coef[j];
      }
      // Iteration 0 evaluates the model at the initial coefficients, which
      // defines the null model: record its log-likelihood and the score test
      // g' (-H^{-1}) g.
      if (o.iter == 0) {
        o.null_loglik = newLoglik;
        o.maxrsq = 1 - Math.exp(2 * o.null_loglik / o.n);
        o.score_test = 0;
        for (int j = 0; j < n_coef; ++j) {
          double sum = 0;
          for (int k = 0; k < n_coef; ++k)
            sum += o.var_coef[j][k] * o.gradient[k];
          o.score_test += o.gradient[j] * sum;
        }
      }
      o.loglik = newLoglik;
      // Likelihood-ratio test vs. the null model, and Cox-Snell-style R^2.
      o.loglik_test = - 2 * (o.null_loglik - o.loglik);
      o.rsq = 1 - Math.exp(- o.loglik_test / o.n);
      // Wald test of (coef - init) against the null value p.init.
      o.wald_test = 0;
      for (int j = 0; j < n_coef; ++j) {
        double sum = 0;
        for (int k = 0; k < n_coef; ++k)
          sum -= o.hessian[j][k] * (o.coef[k] - p.init);
        o.wald_test += (o.coef[j] - p.init) * sum;
      }
    }
    /**
     * Computes the baseline cumulative hazard (cumhaz_0) and the two pieces of
     * its variance estimate (var_cumhaz_1/2), at every time bin that saw an
     * event or a censoring.  Per-time increments are accumulated first, then
     * turned into cumulative sums by the prefix-sum loop at the end.
     */
    protected void calcCumhaz_0(CoxPHModel model, final CoxPHTask coxMR) {
      CoxPHModel.CoxPHParameters p = model._parms;
      CoxPHModel.CoxPHOutput o = model._output;
      final int n_coef = o.coef.length;
      // nz indexes the compacted per-time output arrays (same compaction rule
      // as calcCounts: keep only bins with events or censorings).
      int nz = 0;
      switch (p.ties) {
      case efron:
        for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
          final double sizeEvents_t = coxMR.sizeEvents[t];
          final double sizeCensored_t = coxMR.sizeCensored[t];
          if (sizeEvents_t > 0 || sizeCensored_t > 0) {
            final long countEvents_t = coxMR.countEvents[t];
            final double sumRiskEvents_t = coxMR.sumRiskEvents[t];
            final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
            final double avgSize = sizeEvents_t / countEvents_t;
            o.cumhaz_0[nz] = 0;
            o.var_cumhaz_1[nz] = 0;
            for (int j = 0; j < n_coef; ++j)
              o.var_cumhaz_2[nz][j] = 0;
            // Efron: each tied event contributes a hazard increment with a
            // fraction e/d of the tied risk removed from the denominator.
            for (long e = 0; e < countEvents_t; ++e) {
              final double frac = ((double) e) / ((double) countEvents_t);
              final double haz = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t);
              final double haz_sq = haz * haz;
              o.cumhaz_0[nz] += avgSize * haz;
              o.var_cumhaz_1[nz] += avgSize * haz_sq;
              for (int j = 0; j < n_coef; ++j)
                o.var_cumhaz_2[nz][j] +=
                  avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq);
            }
            nz++;
          }
        }
        break;
      case breslow:
        // Breslow: single closed-form increment per time bin.
        for (int t = 0; t < coxMR.sizeEvents.length; ++t) {
          final double sizeEvents_t = coxMR.sizeEvents[t];
          final double sizeCensored_t = coxMR.sizeCensored[t];
          if (sizeEvents_t > 0 || sizeCensored_t > 0) {
            final double rcumsumRisk_t = coxMR.rcumsumRisk[t];
            final double cumhaz_0_nz = sizeEvents_t / rcumsumRisk_t;
            o.cumhaz_0[nz] = cumhaz_0_nz;
            o.var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t);
            for (int j = 0; j < n_coef; ++j)
              o.var_cumhaz_2[nz][j] = (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz;
            nz++;
          }
        }
        break;
      default:
        throw new IllegalArgumentException("ties method must be either efron or breslow");
      }
      // Turn the per-time increments into running (cumulative) sums.
      for (int t = 1; t < o.cumhaz_0.length; ++t) {
        o.cumhaz_0[t] = o.cumhaz_0[t - 1] + o.cumhaz_0[t];
        o.var_cumhaz_1[t] = o.var_cumhaz_1[t - 1] + o.var_cumhaz_1[t];
        for (int j = 0; j < n_coef; ++j)
          o.var_cumhaz_2[t][j] = o.var_cumhaz_2[t - 1][j] + o.var_cumhaz_2[t][j];
      }
    }
    /**
     * Driver body: prepares the frame/model, then fits the coefficients by
     * Newton-Raphson with step-halving until the log-relative-error target
     * (lre_min) or iter_max is reached.
     */
    @Override protected void compute2() {
      CoxPHModel model = null;
      try {
        Scope.enter();
        init(true);
        _parms.lock_frames(CoxPH.this);
        applyScoringFrameSideEffects();
        // The model to be built
        model = new CoxPHModel(dest(), _parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
        model.delete_and_lock(_key);
        applyTrainingFrameSideEffects();
        int nResponses = 1;
        boolean useAllFactorLevels = false;
        // DEMEAN centers the numeric columns; the subtracted means are later
        // read back from dinfo._normSub (see calcCounts).
        final DataInfo dinfo = new DataInfo(Key.make(), _modelBuilderTrain, null, nResponses, useAllFactorLevels, DataInfo.TransformType.DEMEAN);
        initStats(model, dinfo);
        final int n_offsets = (model._parms.offset_columns == null) ? 0 : model._parms.offset_columns.length;
        final int n_coef = dinfo.fullN() - n_offsets;
        final double[] step = MemoryManager.malloc8d(n_coef);
        final double[] oldCoef = MemoryManager.malloc8d(n_coef);
        final double[] newCoef = MemoryManager.malloc8d(n_coef);
        Arrays.fill(step, Double.NaN);
        Arrays.fill(oldCoef, Double.NaN);
        // All coefficients start at the user-supplied initial value.
        for (int j = 0; j < n_coef; ++j)
          newCoef[j] = model._parms.init;
        double oldLoglik = - Double.MAX_VALUE;
        final int n_time = (int) (model._output.max_time - model._output.min_time + 1);
        final boolean has_start_column = (model._parms.start_column != null);
        final boolean has_weights_column = (model._parms.weights_column != null);
        for (int i = 0; i <= model._parms.iter_max; ++i) {
          model._output.iter = i;
          // One full MR pass over the data at the candidate coefficients.
          final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model._output.min_time, n_time, n_offsets,
            has_start_column, has_weights_column).doAll(dinfo._adaptedFrame);
          final double newLoglik = calcLoglik(model, coxMR);
          if (newLoglik > oldLoglik) {
            // Improvement: accept the candidate and compute a fresh Newton step.
            if (i == 0)
              calcCounts(model, coxMR);
            calcModelStats(model, newCoef, newLoglik);
            calcCumhaz_0(model, coxMR);
            // Log-relative-error convergence criterion; the absolute form
            // avoids dividing by zero when newLoglik is exactly 0.
            if (newLoglik == 0)
              model._output.lre = - Math.log10(Math.abs(oldLoglik - newLoglik));
            else
              model._output.lre = - Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik));
            if (model._output.lre >= model._parms.lre_min)
              break;
            // Newton step: var_coef holds -H^{-1}, so step = H^{-1} * gradient
            // and the update below is newCoef = oldCoef - step.
            Arrays.fill(step, 0);
            for (int j = 0; j < n_coef; ++j)
              for (int k = 0; k < n_coef; ++k)
                step[j] -= model._output.var_coef[j][k] * model._output.gradient[k];
            // NOTE(review): this `break` only exits the j-scan below, making the
            // whole loop a no-op; presumably it was meant to abort the outer
            // Newton iteration on a NaN/Inf step — confirm intent.
            for (int j = 0; j < n_coef; ++j)
              if (Double.isNaN(step[j]) || Double.isInfinite(step[j]))
                break;
            oldLoglik = newLoglik;
            System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
          } else {
            // No improvement: halve the previous step and retry (backtracking).
            for (int j = 0; j < n_coef; ++j)
              step[j] /= 2;
          }
          for (int j = 0; j < n_coef; ++j)
            newCoef[j] = oldCoef[j] - step[j];
        }
        model.update(_key);
      } catch( Throwable t ) {
        t.printStackTrace();
        cancel2(t);
        throw t;
      } finally {
        _parms.unlock_frames(CoxPH.this);
        Scope.exit();
        done(); // Job done!
      }
      tryComplete();
    }
    /** Job key of the enclosing builder; passed to CoxPHTask as its job key. */
    Key self() { return _key; }
    // TODO: report training progress (e.g. the fraction of iter_max completed).
    // The commented-out progress() override copied from DeepLearning was removed.
}
private static double[][] malloc2DArray(final int d1, final int d2) {
final double[][] array = new double[d1][];
for (int j = 0; j < d1; ++j)
array[j] = MemoryManager.malloc8d(d2);
return array;
}
private static double[][][] malloc3DArray(final int d1, final int d2, final int d3) {
final double[][][] array = new double[d1][d2][];
for (int j = 0; j < d1; ++j)
for (int k = 0; k < d2; ++k)
array[j][k] = MemoryManager.malloc8d(d3);
return array;
}
  /**
   * Single MR pass over the (adapted) training frame: accumulates all the
   * per-time and per-coefficient sufficient statistics needed by calcLoglik,
   * calcCounts, and calcCumhaz_0 at one fixed set of coefficients.
   */
  protected static class CoxPHTask extends FrameTask {
    private final double[] _beta;              // coefficients the pass is evaluated at
    private final int _n_time;                 // number of discrete time bins
    private final long _min_time;              // time mapped to bin 0
    private final int _n_offsets;              // trailing numeric columns treated as fixed offsets
    private final boolean _has_start_column;   // rows carry an entry time before the stop time
    private final boolean _has_weights_column; // response[0] is a case weight
    protected long n;                          // rows actually processed
    // NOTE(review): n_missing is never written or reduced in this task — appears unused.
    protected long n_missing;
    protected double sumWeights;               // total case weight
    protected double[] sumWeightedCatX;        // weighted counts per expanded categorical column
    protected double[] sumWeightedNumX;        // weighted sums per (demeaned) numeric column
    // Per-time-bin totals (all indexed by stop-time bin):
    protected double[] sizeRiskSet;            // weight at risk
    protected double[] sizeCensored;           // weight censored
    protected double[] sizeEvents;             // weight of events
    protected long[] countEvents;              // number of (tied) events
    protected double[][] sumXEvents;           // sum of x over event rows
    protected double[] sumRiskEvents;          // sum of risk over event rows
    protected double[][] sumXRiskEvents;       // sum of x*risk over event rows
    protected double[][][] sumXXRiskEvents;    // sum of x*x'*risk over event rows
    protected double[] sumLogRiskEvents;       // sum of weighted log-risk over event rows
    // Risk-set accumulators; suffix-summed in postGlobal when there is no start column:
    protected double[] rcumsumRisk;            // sum of risk over the risk set
    protected double[][] rcumsumXRisk;         // sum of x*risk over the risk set
    protected double[][][] rcumsumXXRisk;      // sum of x*x'*risk over the risk set
    CoxPHTask(Key jobKey, DataInfo dinfo, final double[] beta, final long min_time, final int n_time,
              final int n_offsets, final boolean has_start_column, final boolean has_weights_column) {
      super(jobKey, dinfo);
      _beta = beta;
      _n_time = n_time;
      _min_time = min_time;
      _n_offsets = n_offsets;
      _has_start_column = has_start_column;
      _has_weights_column = has_weights_column;
    }
    /** Allocates all per-chunk accumulators; sizes follow the coefficient and time-bin counts. */
    @Override
    protected void chunkInit(){
      final int n_coef = _beta.length;
      // n_coef - (numeric non-offset columns) = number of expanded categorical columns.
      sumWeightedCatX = MemoryManager.malloc8d(n_coef - (_dinfo._nums - _n_offsets));
      sumWeightedNumX = MemoryManager.malloc8d(_dinfo._nums);
      sizeRiskSet = MemoryManager.malloc8d(_n_time);
      sizeCensored = MemoryManager.malloc8d(_n_time);
      sizeEvents = MemoryManager.malloc8d(_n_time);
      countEvents = MemoryManager.malloc8(_n_time);
      sumRiskEvents = MemoryManager.malloc8d(_n_time);
      sumLogRiskEvents = MemoryManager.malloc8d(_n_time);
      rcumsumRisk = MemoryManager.malloc8d(_n_time);
      sumXEvents = malloc2DArray(_n_time, n_coef);
      sumXRiskEvents = malloc2DArray(_n_time, n_coef);
      rcumsumXRisk = malloc2DArray(_n_time, n_coef);
      sumXXRiskEvents = malloc3DArray(_n_time, n_coef, n_coef);
      rcumsumXXRisk = malloc3DArray(_n_time, n_coef, n_coef);
    }
    /**
     * Accumulates one row into the sufficient statistics.  The response layout
     * is [weight?, start?, stop, event], matching the column order set up by
     * CoxPHDriver.applyTrainingFrameSideEffects().
     */
    @Override
    protected void processRow(long gid, double [] nums, int ncats, int [] cats, double [] response) {
      n++;
      final double weight = _has_weights_column ? response[0] : 1.0;
      if (weight <= 0)
        throw new IllegalArgumentException("weights must be positive values");
      final long event = (long) response[response.length - 1];
      // Bin indices: entry times are exclusive, hence the +1 on the start time;
      // t1 == t2 is therefore legal (start == stop - 1 in raw time).
      final int t1 = _has_start_column ? (int) (((long) response[response.length - 3] + 1) - _min_time) : -1;
      final int t2 = (int) (((long) response[response.length - 2]) - _min_time);
      if (t1 > t2)
        throw new IllegalArgumentException("start times must be strictly less than stop times");
      final int numStart = _dinfo.numStart();
      sumWeights += weight;
      // cats[] holds already-expanded design-matrix column indices.
      for (int j = 0; j < ncats; ++j)
        sumWeightedCatX[cats[j]] += weight;
      for (int j = 0; j < nums.length; ++j)
        sumWeightedNumX[j] += weight * nums[j];
      // Linear predictor: beta'x over categorical and non-offset numeric
      // columns, plus the raw offset values (coefficient fixed at 1).
      double logRisk = 0;
      for (int j = 0; j < ncats; ++j)
        logRisk += _beta[cats[j]];
      for (int j = 0; j < nums.length - _n_offsets; ++j)
        logRisk += nums[j] * _beta[numStart + j];
      for (int j = nums.length - _n_offsets; j < nums.length; ++j)
        logRisk += nums[j];
      final double risk = weight * Math.exp(logRisk);
      logRisk *= weight;
      // Event/censoring tallies are attributed to the stop-time bin t2.
      if (event > 0) {
        countEvents[t2]++;
        sizeEvents[t2] += weight;
        sumLogRiskEvents[t2] += logRisk;
        sumRiskEvents[t2] += risk;
      } else
        sizeCensored[t2] += weight;
      // With a start column, the row is at risk over [t1, t2]; without one,
      // the contribution is recorded at t2 only and postGlobal suffix-sums it.
      if (_has_start_column) {
        for (int t = t1; t <= t2; ++t)
          sizeRiskSet[t] += weight;
        for (int t = t1; t <= t2; ++t)
          rcumsumRisk[t] += risk;
      } else {
        sizeRiskSet[t2] += weight;
        rcumsumRisk[t2] += risk;
      }
      // First- and second-moment accumulators over all non-offset columns;
      // jit/kit iterate categorical columns first, then numeric ones.
      final int ntotal = ncats + (nums.length - _n_offsets);
      final int numStartIter = numStart - ncats;
      for (int jit = 0; jit < ntotal; ++jit) {
        final boolean jIsCat = jit < ncats;
        final int j = jIsCat ? cats[jit] : numStartIter + jit;
        final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
        final double xRisk = x1 * risk;
        if (event > 0) {
          sumXEvents[t2][j] += weight * x1;
          sumXRiskEvents[t2][j] += xRisk;
        }
        if (_has_start_column) {
          for (int t = t1; t <= t2; ++t)
            rcumsumXRisk[t][j] += xRisk;
        } else {
          rcumsumXRisk[t2][j] += xRisk;
        }
        for (int kit = 0; kit < ntotal; ++kit) {
          final boolean kIsCat = kit < ncats;
          final int k = kIsCat ? cats[kit] : numStartIter + kit;
          final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
          final double xxRisk = x2 * xRisk;
          if (event > 0)
            sumXXRiskEvents[t2][j][k] += xxRisk;
          if (_has_start_column) {
            for (int t = t1; t <= t2; ++t)
              rcumsumXXRisk[t][j][k] += xxRisk;
          } else {
            rcumsumXXRisk[t2][j][k] += xxRisk;
          }
        }
      }
    }
    /** Element-wise merge of another task's accumulators into this one. */
    @Override
    public void reduce(CoxPHTask that) {
      // n_missing is intentionally not combined here: it is never written by
      // this task (the model-level count is derived in calcCounts instead).
      n += that.n;
      sumWeights += that.sumWeights;
      ArrayUtils.add(sumWeightedCatX, that.sumWeightedCatX);
      ArrayUtils.add(sumWeightedNumX, that.sumWeightedNumX);
      ArrayUtils.add(sizeRiskSet, that.sizeRiskSet);
      ArrayUtils.add(sizeCensored, that.sizeCensored);
      ArrayUtils.add(sizeEvents, that.sizeEvents);
      ArrayUtils.add(countEvents, that.countEvents);
      ArrayUtils.add(sumXEvents, that.sumXEvents);
      ArrayUtils.add(sumRiskEvents, that.sumRiskEvents);
      ArrayUtils.add(sumXRiskEvents, that.sumXRiskEvents);
      ArrayUtils.add(sumXXRiskEvents, that.sumXXRiskEvents);
      ArrayUtils.add(sumLogRiskEvents, that.sumLogRiskEvents);
      ArrayUtils.add(rcumsumRisk, that.rcumsumRisk);
      ArrayUtils.add(rcumsumXRisk, that.rcumsumXRisk);
      ArrayUtils.add(rcumsumXXRisk, that.rcumsumXXRisk);
    }
    /**
     * Without a start column the risk sets are nested, so per-row contributions
     * were recorded only at each row's stop time; these reverse (suffix) sums
     * turn them into full risk-set totals at every time bin.  With a start
     * column, processRow already spread contributions over [t1, t2].
     */
    @Override
    protected void postGlobal() {
      if (!_has_start_column) {
        for (int t = rcumsumRisk.length - 2; t >= 0; --t)
          rcumsumRisk[t] += rcumsumRisk[t + 1];
        for (int t = rcumsumXRisk.length - 2; t >= 0; --t)
          for (int j = 0; j < rcumsumXRisk[t].length; ++j)
            rcumsumXRisk[t][j] += rcumsumXRisk[t + 1][j];
        for (int t = rcumsumXXRisk.length - 2; t >= 0; --t)
          for (int j = 0; j < rcumsumXXRisk[t].length; ++j)
            for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k)
              rcumsumXXRisk[t][j][k] += rcumsumXXRisk[t + 1][j][k];
      }
    }
}
}
// (site footer removed)