package hex.coxph;
import hex.*;
import hex.coxph.CoxPHModel.CoxPHOutput;
import hex.coxph.CoxPHModel.CoxPHParameters;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.schemas.CoxPHModelV3;
import water.*;
import water.api.schemas3.ModelSchemaV3;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.ast.prims.mungers.AstGroup;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedInt;
import water.util.Log;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Stream;
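/**
 * Cox Proportional Hazards model. Holds the fitted coefficients, baseline hazard and
 * survival estimates, and scores frames by computing the per-row linear predictor ("lp").
 */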
public class CoxPHModel extends Model<CoxPHModel, CoxPHParameters, CoxPHOutput> {
public static class CoxPHParameters extends Model.Parameters {
public String algoName() { return "CoxPH"; }
public String fullName() { return "Cox Proportional Hazards"; }
public String javaName() { return CoxPHModel.class.getName(); }
@Override public long progressUnits() { return ((_max_iterations + 1) * 2) + 1; }
public String _start_column;
public String _stop_column;
final String _strata_column = "__strata";
public String[] _stratify_by;
public enum CoxPHTies { efron, breslow }
public CoxPHTies _ties = CoxPHTies.efron;
public double _init = 0;
public double _lre_min = 9;
public int _max_iterations = 20;
public boolean _use_all_factor_levels;
public String[] _interactions_only;
public String[] _interactions = null;
public StringPair[] _interaction_pairs = null;
public boolean _calc_cumhaz = true; // support survfit
/**
 * If true, computation is performed with local jobs:
 * {@link MRTask#doAll(Vec, boolean)} and its other overloaded variants are called during the
 * computation with runLocal set to true.
 *
 * This setting affects the main CoxPH computation only. Model metrics computation doesn't honour it -
 * the {@link ModelMetricsRegressionCoxPH#concordance()} computation ignores it.
 */
public boolean _single_node_mode = false;
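// The columns that act as the response for CoxPH: the optional start column, the synthetic
// strata column (when stratified), the stop column and the event/response column.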
String[] responseCols() {
String[] cols = _start_column != null ? new String[]{_start_column} : new String[0];
if (isStratified())
cols = ArrayUtils.append(cols, _strata_column);
return ArrayUtils.append(cols, _stop_column, _response_column);
}
Vec startVec() { return train().vec(_start_column); }
Vec stopVec() { return train().vec(_stop_column); }
InteractionSpec interactionSpec() {
// add "stratify by" columns to "interaction only"
final String[] interOnly;
if (getInteractionsOnly() != null && _stratify_by != null) {
String[] io = getInteractionsOnly().clone();
Arrays.sort(io);
String[] sb = _stratify_by.clone();
Arrays.sort(sb);
interOnly = ArrayUtils.union(io, sb, true);
} else {
interOnly = getInteractionsOnly() != null ? getInteractionsOnly() : _stratify_by;
}
return InteractionSpec.create(_interactions, _interaction_pairs, interOnly, _stratify_by);
}
private String[] getInteractionsOnly() {
// clients sometimes represent empty interactions as [""] - sanitize this
if (_interactions_only != null && _interactions_only.length == 1 && "".equals(_interactions_only[0])) {
return null;
} else {
return _interactions_only;
}
}
boolean isStratified() { return _stratify_by != null && _stratify_by.length > 0; }
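// Renders the model as an R-style survival formula, e.g.
// "Surv(start, stop, event) ~ x1 + strata(x2)" (names here are illustrative).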
String toFormula(Frame f) {
StringBuilder sb = new StringBuilder();
sb.append("Surv(");
if (_start_column != null) {
sb.append(_start_column).append(", ");
}
sb.append(_stop_column).append(", ").append(_response_column);
sb.append(") ~ ");
Set<String> stratifyBy = _stratify_by != null ? new HashSet<>(Arrays.asList(_stratify_by)) : Collections.emptySet();
Set<String> interactionsOnly = _interactions_only != null ? new HashSet<>(Arrays.asList(_interactions_only)) : Collections.emptySet();
Set<String> specialCols = new HashSet<String>() {{
if (_start_column != null)
add(_start_column);
add(_stop_column);
add(_response_column);
add(_strata_column);
if (_weights_column != null)
add(_weights_column);
if (_ignored_columns != null)
addAll(Arrays.asList(_ignored_columns));
}};
String sep = "";
for (String col : f._names) {
if (_offset_column != null && _offset_column.equals(col))
continue;
if (stratifyBy.contains(col) || interactionsOnly.contains(col) || specialCols.contains(col))
continue;
sb.append(sep).append(col);
sep = " + ";
}
if (_offset_column != null)
sb.append(sep).append("offset(").append(_offset_column).append(")");
InteractionSpec interactionSpec = interactionSpec();
if (interactionSpec != null) {
InteractionPair[] interactionPairs = interactionSpec.makeInteractionPairs(f);
for (InteractionPair ip : interactionPairs) {
sb.append(sep);
String v1 = f._names[ip.getV1()];
String v2 = f._names[ip.getV2()];
if (stratifyBy.contains(v1))
sb.append("strata(").append(v1).append(")");
else
sb.append(v1);
sb.append(":");
if (stratifyBy.contains(v2))
sb.append("strata(").append(v2).append(")");
else
sb.append(v2);
sep = " + ";
}
}
if (_stratify_by != null) {
final String tmp = sb.toString();
for (String col : _stratify_by) {
String strataCol = "strata(" + col + ")";
if (! tmp.contains(strataCol)) {
sb.append(sep).append(strataCol);
sep = " + ";
}
}
}
return sb.toString();
}
}
public static class CoxPHOutput extends Model.Output {
public CoxPHOutput(CoxPH coxPH, Frame adaptFr, Frame train, IcedHashMap<AstGroup.G, IcedInt> strataMap) {
super(coxPH, fullFrame(coxPH, adaptFr, train));
_strataOnlyCols = new String[_names.length - adaptFr._names.length];
for (int i = 0; i < _strataOnlyCols.length; i++)
_strataOnlyCols[i] = _names[i];
_ties = coxPH._parms._ties;
_formula = coxPH._parms.toFormula(train);
_interactionSpec = coxPH._parms.interactionSpec();
_strataMap = strataMap;
_hasStartColumn = coxPH.hasStartColumn();
_hasStrataColumn = coxPH._parms.isStratified();
}
@Override
public int nclasses() {
return 1;
}
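// Besides the inherited special columns (weights, offset, fold), CoxPH also treats the stop
// column and, when present, the start and strata columns as special.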
@Override
protected int lastSpecialColumnIdx() {
return super.lastSpecialColumnIdx() - 1 - (_hasStartColumn ? 1 : 0) - (_hasStrataColumn ? 1 : 0);
}
public int weightsIdx() {
if (!_hasWeights)
return -1;
return lastSpecialColumnIdx() - (hasFold() ? 1 : 0);
}
public int offsetIdx() {
if (!_hasOffset)
return -1;
return lastSpecialColumnIdx() - (hasWeights() ? 1 : 0) - (hasFold() ? 1 : 0);
}
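// When stratified, columns used only for stratification may be missing from the adapted frame;
// re-add them from the training frame so they appear in the output column names.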
private static Frame fullFrame(CoxPH coxPH, Frame adaptFr, Frame train) {
if (! coxPH._parms.isStratified())
return adaptFr;
Frame ff = new Frame();
for (String col : coxPH._parms._stratify_by)
if (adaptFr.vec(col) == null)
ff.add(col, train.vec(col));
ff.add(adaptFr);
return ff;
}
@Override
public ModelCategory getModelCategory() { return ModelCategory.CoxPH; }
@Override
public InteractionBuilder interactionBuilder() {
return _interactionSpec != null ? new CoxPHInteractionBuilder() : null;
}
private class CoxPHInteractionBuilder implements InteractionBuilder {
@Override
public Frame makeInteractions(Frame f) {
Model.InteractionPair[] interactions = _interactionSpec.makeInteractionPairs(f);
f.add(Model.makeInteractions(f, false, interactions, data_info._useAllFactorLevels, data_info._skipMissing, data_info._predictor_transform == DataInfo.TransformType.STANDARDIZE));
return f;
}
}
InteractionSpec _interactionSpec;
DataInfo data_info;
IcedHashMap<AstGroup.G, IcedInt> _strataMap;
String[] _strataOnlyCols;
private final boolean _hasStartColumn;
private final boolean _hasStrataColumn;
public String[] _coef_names;
public double[] _coef;
public double[] _exp_coef;
public double[] _exp_neg_coef;
public double[] _se_coef;
public double[] _z_coef;
double[][] _var_coef;
double _null_loglik;
double _loglik;
double _loglik_test;
double _wald_test;
double _score_test;
double _rsq;
double _maxrsq;
double _lre;
int _iter;
double[][] _x_mean_cat;
double[][] _x_mean_num;
double[] _mean_offset;
String[] _offset_names;
long _n;
long _n_missing;
long _total_event;
double[] _time;
double[] _n_risk;
double[] _n_event;
double[] _n_censor;
double[] _cumhaz_0;
double[] _var_cumhaz_1;
FrameMatrix _var_cumhaz_2_matrix;
Key<Frame> _var_cumhaz_2;
Key<Frame> _baseline_hazard;
FrameMatrix _baseline_hazard_matrix;
Key<Frame> _baseline_survival;
FrameMatrix _baseline_survival_matrix;
CoxPHParameters.CoxPHTies _ties;
String _formula;
double _concordance;
}
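// A dense matrix tied to a Frame in the DKV: serialization keeps the frame key, and on
// deserialization the backing Frame is re-installed into the DKV if it is missing.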
public static class FrameMatrix extends Storage.DenseRowMatrix {
Key<Frame> _frame_key;
FrameMatrix(Key<Frame> frame_key, int rows, int cols) {
super(rows, cols);
_frame_key = frame_key;
}
@SuppressWarnings("unused")
public final AutoBuffer write_impl(AutoBuffer ab) {
Key.write_impl(_frame_key, ab);
return ab;
}
@SuppressWarnings({"unused", "unchecked"})
public final FrameMatrix read_impl(AutoBuffer ab) {
_frame_key = (Key<Frame>) Key.read_impl(null, ab);
// install in DKV if not already there
if (DKV.getGet(_frame_key) == null)
toFrame(_frame_key);
return this;
}
}
@Override
public ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH makeMetricBuilder(String[] domain) {
return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH(_parms._start_column, _parms._stop_column, _parms.isStratified(), _parms._stratify_by);
}
public ModelSchemaV3 schema() { return new CoxPHModelV3(); }
public CoxPHModel(final Key<CoxPHModel> destKey, final CoxPHParameters parms, final CoxPHOutput output) {
super(destKey, parms, output);
}
@Override
protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job job, boolean computeMetrics, CFuncRef customMetricFunc) {
int nResponses = 0;
for (String col : _parms.responseCols())
if (adaptFrm.find(col) != -1)
nResponses++;
DataInfo scoringInfo = _output.data_info.scoringInfo(_output._names, adaptFrm, nResponses, false);
CoxPHScore score = new CoxPHScore(scoringInfo, _output, _parms.isStratified(), null != _parms._offset_column);
final Frame scored = score
.doAll(Vec.T_NUM, scoringInfo._adaptedFrame)
.outputFrame(Key.make(destination_key), new String[]{"lp"}, null);
ModelMetrics.MetricBuilder<?> mb = null;
if (computeMetrics) {
mb = makeMetricBuilder(null);
}
return new PredictScoreResult(mb, scored, scored);
}
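// A stratified model expects the strata column in the test frame; if it is missing, a
// placeholder vector is added before adaptation and filled in from the strata map afterwards.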
@Override
public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
boolean createStrataVec = _parms.isStratified() && (test.vec(_parms._strata_column) == null);
if (createStrataVec) {
Vec strataVec = test.anyVec().makeCon(Double.NaN);
_toDelete.put(strataVec._key, "adapted missing strata vector");
test.add(_parms._strata_column, strataVec);
}
String[] msgs = super.adaptTestForTrain(test, expensive, computeMetrics);
if (createStrataVec) {
Vec strataVec = CoxPH.StrataTask.makeStrataVec(test, _parms._stratify_by, _output._strataMap, _parms._single_node_mode);
_toDelete.put(strataVec._key, "adapted missing strata vector");
test.replace(test.find(_parms._strata_column), strataVec);
if (_output._strataOnlyCols != null)
test.remove(_output._strataOnlyCols);
}
return msgs;
}
@Override
protected String[] adaptTestForJavaScoring(Frame test, boolean computeMetrics) {
return super.adaptTestForTrain(test, true, computeMetrics);
}
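// Computes the linear predictor for each row: lp = x'coef minus the per-stratum baseline
// (the inner product of the stratum's training means of the predictors with the coefficients).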
private static class CoxPHScore extends MRTask<CoxPHScore> {
private DataInfo _dinfo;
private double[] _coef;
private double[] _lpBase;
private int _numStart;
private boolean _hasStrata;
private CoxPHScore(DataInfo dinfo, CoxPHOutput o, boolean hasStrata, boolean hasOffsets) {
final int strataCount = o._x_mean_cat.length;
_dinfo = dinfo;
_hasStrata = hasStrata;
_coef = hasOffsets ? ArrayUtils.append(o._coef, 1.0) : o._coef;
_numStart = o._x_mean_cat[0].length;
_lpBase = new double[strataCount];
for (int s = 0; s < strataCount; s++) {
for (int i = 0; i < o._x_mean_cat[s].length; i++)
_lpBase[s] += o._x_mean_cat[s][i] * _coef[i];
for (int i = 0; i < o._x_mean_num[s].length; i++)
_lpBase[s] += o._x_mean_num[s][i] * _coef[i + _numStart];
}
}
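// Emits "lp" for every row: NA for rows with invalid predictors or an unknown stratum,
// 0 for rows with zero weight.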
@Override
public void map(Chunk[] chks, NewChunk nc) {
DataInfo.Row r = _dinfo.newDenseRow();
for (int rid = 0; rid < chks[0]._len; ++rid) {
_dinfo.extractDenseRow(chks, rid, r);
if (r.predictors_bad) {
nc.addNA();
continue;
} else if (r.weight == 0) {
nc.addNum(0);
continue;
}
final double s = _hasStrata ? chks[_dinfo.responseChunkId(0)].atd(rid) : 0;
final boolean unknownStrata = Double.isNaN(s);
if (unknownStrata) {
nc.addNA();
} else {
final double lp = r.innerProduct(_coef) - _lpBase[(int) s];
nc.addNum(lp);
}
}
}
}
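// Row-by-row scoring is not supported; scoring is implemented frame-wise by CoxPHScore
// in predictScoreImpl.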
@Override public double[] score0(double[] data, double[] preds) {
throw new UnsupportedOperationException("CoxPHModel.score0 should never be called");
}
@Override
protected Futures remove_impl(Futures fs, boolean cascade) {
remove(fs, _output._var_cumhaz_2);
remove(fs, _output._baseline_hazard);
remove(fs, _output._baseline_survival);
super.remove_impl(fs, cascade);
return fs;
}
private void remove(Futures fs, Key<Frame> key) {
Frame fr = key != null ? key.get() : null;
if (fr != null) {
fr.remove(fs);
}
}
@Override
public CoxPHMojoWriter getMojo() {
return new CoxPHMojoWriter(this);
}
@Override
public boolean haveMojo() {
final boolean enabled = super.haveMojo() && hasOnlyNumericInteractions();
if (! enabled) {
boolean forceEnabled = H2O.getSysBoolProperty("coxph.mojo.forceEnable", false);
if (forceEnabled) {
Log.warn("Model " + this._key + " doesn't technically support MOJO, but MOJO support was force-enabled.");
return true;
}
}
return enabled;
}
boolean hasOnlyNumericInteractions() {
if (_output._interactionSpec == null) {
return true;
}
return Stream.of(_output.data_info._interactions)
.allMatch(InteractionPair::isNumeric);
}
@Override
public ModelDescriptor modelDescriptor() {
return new CoxPHModelDescriptor(extraMojoFeatures());
}
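// Columns used only in interactions are not among the exported feature names, so they are
// handed to the MOJO descriptor as extra features.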
public String[] extraMojoFeatures() {
InteractionSpec interactionSpec = _parms.interactionSpec();
if (interactionSpec == null) {
return new String[0];
}
String[] interactionsOnly = interactionSpec.getInteractionsOnly();
if (interactionsOnly == null) {
return new String[0];
}
Set<String> alreadyExported = new HashSet<>(Arrays.asList(_output._names));
return Stream.of(interactionsOnly)
.filter(((Predicate<String>) alreadyExported::contains).negate())
.toArray(String[]::new);
}
class CoxPHModelDescriptor extends H2OModelDescriptor {
private final String[] _extraMojoFeatures;
private CoxPHModelDescriptor(String[] extraMojoFeatures) {
_extraMojoFeatures = extraMojoFeatures;
}
@Override
public int nfeatures() {
return super.nfeatures() + _extraMojoFeatures.length;
}
@Override
public String[] features() {
return ArrayUtils.append(super.features(), _extraMojoFeatures);
}
@Override
public String[] columnNames() {
return ArrayUtils.insert(super.columnNames(), _extraMojoFeatures, super.nfeatures());
}
}
}