scripts.algorithms.Cox-predict.dml Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
#
# THIS SCRIPT APPLIES THE ESTIMATED PARAMETERS OF A COX PROPORTIONAL HAZARD REGRESSION MODEL TO A NEW (TEST) DATASET
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X String --- Location to read the input matrix containing the survival data with the following schema:
# - X[,1]: timestamps
# - X[,2]: whether an event occurred (1) or data is censored (0)
# - X[,3:]: feature vectors (excluding the baseline columns) used for model fitting
# RT String --- Location to read column matrix RT containing the (order preserving) recoded timestamps from X
# M String --- Location to read matrix M containing the fitted Cox model with the following schema:
# - M[,1]: betas
# - M[,2]: exp(betas)
# - M[,3]: standard error of betas
# - M[,4]: Z
# - M[,5]: p-value
# - M[,6]: lower 100*(1-alpha)% confidence interval of betas
# - M[,7]: upper 100*(1-alpha)% confidence interval of betas
# Y String --- Location to read matrix Y used for prediction
# COV String --- Location to read the variance-covariance matrix of the betas
# MF String --- Location to read column indices of X excluding the baseline factors if available
# P String --- Location to store matrix P containing the results of prediction
# fmt String "text" Matrix output format, usually "text" or "csv" (for matrices only)
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# 1- A matrix P with the following schema:
# P[,1]: linear predictors relative to a baseline which contains the mean values for each feature
# i.e., (Y[3:] - colMeans (X[3:])) %*% b
# P[,2]: standard error of linear predictors
# P[,3]: risk relative to a baseline which contains the mean values for each feature
# i.e., exp ((Y[3:] - colMeans (X[3:])) %*% b)
# P[,4]: standard error of risk
# P[,5]: estimates of cumulative hazard
# P[,6]: standard error of the estimates of cumulative hazard
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemML.jar -f Cox-predict.dml -nvargs X=INPUT_DIR/X RT=INPUT_DIR/RT M=INPUT_DIR/M Y=INPUT_DIR/Y
# COV=INPUT_DIR/COV MF=INPUT_DIR/MF P=OUTPUT_DIR/P fmt=csv
# Output location and format; the format falls back to "text" when $fmt is not supplied
fileP = $P;
fmtO = ifdef ($fmt, "text");

# Training data, its recoded timestamps, and the data to predict on
X_full = read ($X);
RT_X = read ($RT);
Y_full = read ($Y);

# Fitted model: the first column of M holds the betas
model = read ($M);
b = model[,1];
COV = read ($COV);

# Keep only the columns of Y listed in MF: sel is a selection matrix of size
# ncol (Y) x nrow (MF) with an entry at (MF[j], j), so the j-th column of
# Y %*% sel is column MF[j] of Y
keep_cols = read ($MF);
num_keep = nrow (keep_cols);
sel = table (keep_cols, seq (1, num_keep), ncol (Y_full), num_keep);
Y_full = Y_full %*% sel;

# After the projection Y must have the same number of columns as X
if (ncol (X_full) != ncol (Y_full)) {
    stop ("Y has a wrong number of columns!");
}

# Split X into timestamps (column 1), event indicators (column 2) and features (columns 3 onwards)
num_cols = ncol (X_full);
T_X = X_full[,1];
E_X = X_full[,2];
X = X_full[,3:num_cols];
D = ncol (X);
N = nrow (X);

# Sort Y by its first column (the timestamps), then split it the same way as X
Y_full = order (target = Y_full, by = 1);
T_Y = Y_full[,1];
Y = Y_full[,3:num_cols];

# Center the features of Y on the column means of the features of X;
# the column of ones replicates the row of means once per row of Y
col_means = colMeans (X);
Y_rel = Y - (matrix (1, rows = nrow (Y), cols = 1) %*% col_means);
##### compute linear predictors
# LP[i] = (Y[i,] - colMeans (X)) %*% b
LP = Y_rel %*% b;
# compute standard error of linear predictors using the Delta method:
# se_LP[i] = sqrt (Y_rel[i,] %*% COV %*% t(Y_rel[i,]))
# Only the diagonal of Y_rel %*% COV %*% t(Y_rel) is needed, so it is computed row-wise
# with rowSums instead of materializing the full nrow(Y) x nrow(Y) product
se_LP = sqrt (rowSums ((Y_rel %*% COV) * Y_rel));
##### compute risk
# R[i] = exp (LP[i])
R = exp (LP);
# compute standard error of risk using the Delta method
# each row of Y_rel is scaled by the risk of that row; as above, only the diagonal of
# (Y_rel * R) %*% COV %*% t(Y_rel * R) is needed, so it is computed row-wise
Y_rel_R = Y_rel * R;
se_R = sqrt (rowSums ((Y_rel_R %*% COV) * Y_rel_R)) / sqrt (R);
##### compute estimates of cumulative hazard together with their standard errors:
# 1. col contains cumulative hazard estimates
# 2. col contains standard errors for cumulative hazard estimates
# d_r: sum of the event indicators per recoded timestamp
# e_r: number of records per recoded timestamp
d_r = aggregate (target = E_X, groups = RT_X, fn = "sum");
e_r = aggregate (target = RT_X, groups = RT_X, fn = "count");
# Idx[k] is the running record count up to and including the k-th recoded timestamp;
# it is used as a row index into T_X to pick one timestamp per group
# NOTE(review): this presumes the rows of X are ordered by timestamp -- confirm against the caller
Idx = cumsum (e_r);
all_times = table (seq (1, nrow (Idx), 1), Idx) %*% T_X; # distinct timestamps, censored or uncensored
# keep only the timestamps at which at least one event occurred
event_times = removeEmpty (target = (d_r > 0) * all_times, margin = "rows");
num_distinct_event = nrow (event_times);
num_distinct = nrow (all_times); # no. of distinct timestamps censored or uncensored
# I_rev has its entries on the anti-diagonal, so I_rev %*% v reverses the row order of v
I_rev = table (seq (1, num_distinct, 1), seq (num_distinct, 1, -1));
e_r_rev_agg = cumsum (I_rev %*% e_r);
# select marks the positions given by the reverse running record counts; it is used
# further below to pick rows out of a reverse cumulative sum over all records
select = t (colSums (table (seq (1, num_distinct), e_r_rev_agg)));
min_event_time = min (event_times);
max_event_time = max (event_times);
# Clamp the prediction times into [min_event_time, max_event_time] so that every row of the
# comparison below has at least one match. The previous code ADDED min_event_time to times
# below it, which could push such a time past later event times and select the wrong one.
# NOTE(review): times before the first event are mapped to the first event time -- confirm
# this is the intended convention.
T_Y = max (T_Y, min_event_time);
T_Y = min (T_Y, max_event_time);
# Ind[i,k] = 1 where event_times[k] is the event time selected for T_Y[i]
# NOTE(review): presumably the latest event time <= T_Y[i]; this relies on rowIndexMax
# returning the last column index among ties -- confirm
Ind = outer (T_Y, t (event_times), ">=");
Ind = table (seq (1, nrow (T_Y)), rowIndexMax (Ind), nrow (T_Y), num_distinct_event);
# exp (X %*% b) for every training record, summed per recoded timestamp and then
# accumulated from the last timestamp backwards (reverse, cumsum, reverse back)
risk_X = exp (X %*% b);
risk_per_time = aggregate (target = risk_X, groups = RT_X, fn = "sum");
exp_Xb_cum = I_rev %*% cumsum (I_rev %*% risk_per_time);
risk_set_sq = exp_Xb_cum ^ 2;

# Running sums over the timestamps; rows that are entirely zero are dropped first
# H0: running sum of d_r / exp_Xb_cum
# P1: running sum of d_r / exp_Xb_cum ^ 2
H0 = cumsum (removeEmpty (target = d_r / exp_Xb_cum, margin = "rows"));
P1 = cumsum (removeEmpty (target = d_r / risk_set_sq, margin = "rows"));

# Same backwards accumulation for the features weighted by exp (X %*% b): accumulate over
# all N records in reverse order, keep the rows marked by select, and reverse back
weighted_X = X * risk_X;
rev_rows = table (seq (1, N, 1), seq (N, 1, -1));
weighted_X_rev_cum = cumsum (rev_rows %*% weighted_X);
weighted_X_at_times = removeEmpty (target = weighted_X_rev_cum * select, margin = "rows");
weighted_X_cum = I_rev %*% weighted_X_at_times;
# P2: running sum of (weighted_X_cum * d_r) / exp_Xb_cum ^ 2
P2 = cumsum (removeEmpty (target = (weighted_X_cum * d_r) / risk_set_sq, margin = "rows"));
# exp (Y %*% b) for every record to predict (uses the uncentered features Y)
risk_Y = exp (Y %*% b);

# estimates of cumulative hazard: the H0 row selected by Ind, scaled by exp (Y %*% b)
H = risk_Y * (Ind %*% H0);

# first variance term: the P1 row selected by Ind, scaled by exp (Y %*% b) ^ 2
var_term1 = (risk_Y ^ 2) * (Ind %*% P1);

# second variance term: difference of the selected P2 rows scaled by exp (Y %*% b) and
# the features of Y scaled by exp (Y %*% b) times the selected P3 rows
P3 = cumsum (removeEmpty (target = (exp_Xb_cum * d_r) / exp_Xb_cum ^ 2, margin = "rows"));
var_term2 = ((Ind %*% P2) * risk_Y) - ((Y * risk_Y) * (Ind %*% P3));

# standard error of the estimates of cumulative hazard;
# rowSums ((v %*% COV) * v) is the row-wise quadratic form v[i,] %*% COV %*% t(v[i,])
se_H = sqrt (var_term1 + rowSums ((var_term2 %*% COV) * var_term2));

# assemble the six output columns and write them out:
# linear predictor, its standard error, risk, its standard error,
# cumulative hazard, its standard error
P = matrix (0, rows = nrow (Y), cols = 6);
P[,1] = LP;
P[,2] = se_LP;
P[,3] = R;
P[,4] = se_R;
P[,5] = H;
P[,6] = se_H;
write (P, fileP, format = fmtO);