Deep Survival: The Deep Cox Proportional Hazards Network

Different neural network architectures have been effective in modeling survival datasets in which patients’ death times are unknown, i.e. right-censored. However, these neural networks have rarely been shown to outperform their linear counterparts, such as the Cox proportional hazards model. In this paper, we run simulated experiments and use real survival data to build upon the risk-regression architecture proposed by Faraggi and Simon. We demonstrate that a Deep Cox proportional hazards model not only works as well, but also outperforms the standard linear Cox proportional hazards model in predictive ability on survival datasets with non-linear proportional risks. We then show that the neural network can also function as a recommender system by including a categorical variable representing a treatment group. The results of our study suggest that, with modern techniques, neural networks can successfully model Cox regression and, furthermore, explain the interactions of a patient’s factors with their calculated risk.


Introduction
Medical researchers use survival models to evaluate the significance of prognostic variables in outcomes such as death or cancer recurrence and subsequently inform patients of their treatment options [1,2,3,4]. One standard survival model is the Cox proportional hazards model (CPH) [5], a semiparametric model that calculates the effects of observed covariates on the risk of an event occurring (henceforth defined as 'death'). The CPH assumes that a patient's risk of death is a linear combination of their covariates. This assumption is referred to as the linear proportional hazards condition. In many real world datasets, the assumption that the risk function is linear may be too simplistic. As such, a richer family of survival models is needed to better fit survival data with nonlinear risk functions. Since neural networks (NNs) can learn highly complex and nonlinear functions, researchers have attempted to use NNs to model the nonlinear proportional hazards of real survival datasets. However, studies have demonstrated mixed results; see, for example, [6] and [7]. To the best of our knowledge, NNs have not outperformed standard methods for survival analysis such as the CPH.
There are three main approaches in the field of neural networks and survival analysis. These include variants of: (i) classification methods [see details in 8,9], (ii) time-encoded methods [see details in 10,11], and (iii) the Faraggi-Simon network [12], which implements a feed-forward neural network that estimates an individual's risk of death. The Faraggi-Simon network is seen as a nonlinear extension of the Cox proportional hazards model. Researchers have attempted to apply the Faraggi-Simon network with various extensions. However, perhaps because the practice of NNs was not as developed as it is today, they failed to demonstrate improvements beyond the linear Cox model; see [7] and [13].
An advantage of the Faraggi-Simon network is its ability to provide prognosis based on multiple prognostic features without prior selection. However, Schwarzer et al. [14] and others have raised concerns about using NNs in prognostic applications due to their tendency to overfit implausible biological functions. Therefore, further validation is needed to evaluate the prognostic abilities of the Faraggi-Simon network.
The goals of this paper are: (i) to show that the application of deep learning to survival analysis often outperforms the standard CPH; and (ii) to demonstrate how the deep neural network can be viewed as a personalized treatment recommender system and a useful framework for medical applications.
We propose a modern deep learning generalization of the Faraggi-Simon network, henceforth referred to as DeepSurv. We make the following contributions. First, we show that DeepSurv outperforms the CPH on survival data with both linear and nonlinear risk functions. Second, we include an additional categorical variable representing a patient's treatment group to illustrate how to view the network as a treatment recommender system. This, in turn, provides personalized treatments tailored to a patient's observed features. Our experimental results demonstrate that the network accurately models the risk function of the population. We validate our results on real survival data, which further demonstrates the power of the DeepSurv model. Additionally, we show that the recommender system can guide us in making decisions on personalized treatment recommendations and can potentially increase the median survival time for a set of breast cancer patients.
The organization of the manuscript is as follows: in Section 2, we provide a brief background on survival analysis. In Section 3, we present our contributions and explain the implementation of DeepSurv and our proposed recommender system. In Section 4, we describe the experimental design and results. Section 5 concludes the manuscript.

Background
In this section, we define survival data and the approaches for modeling a population's survival and death rates. Additionally, we discuss linear and nonlinear survival models and their limitations.

Survival Data
Survival data comprises three elements: baseline data x, an event time T, and an event indicator E. If an event (e.g. death) is observed, the time interval T corresponds to the time elapsed between the time at which the baseline data was collected and the time of the event occurring, and the event indicator is E = 1. If an event is not observed, the time interval T corresponds to the time elapsed between the collection of the baseline data and the last contact with the patient (e.g. end of study), and the event indicator is E = 0. In this case, the patient is said to be right-censored. Modeling right-censored survival data requires special consideration; if one opts to use standard regression methods, the right-censored data must be discarded.
Survival and hazard functions are the two fundamental functions in survival analysis. The survival function is denoted by S(t) = Pr(T > t), which signifies the probability that an individual has 'survived' up to time t. The hazard function corresponds to the probability that an individual dies at time t given that he or she has survived up to that point. The hazard function λ(t) is defined as:
$$\lambda(t) = \lim_{\delta \to 0} \frac{\Pr(t \le T < t + \delta \mid T \ge t)}{\delta} \tag{1}$$
The hazard function is a measure of risk at time t. A greater hazard signifies a greater risk of failure.
A proportional hazards model is a common method for modeling an individual's survival given their baseline data x. The model assumes that the hazard function is composed of two functions: a baseline hazard function, $\lambda_0(t)$, and a risk function, h(x), denoting the effects of an individual's covariates. The hazard function is assumed to have the form $\lambda(t \mid x) = \lambda_0(t) \cdot e^{h(x)}$.
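As a brief worked consequence of this form (a restatement under the stated assumption, not an additional result), the ratio of the hazards of two individuals does not depend on time, because the baseline hazard cancels. The symbols $x_a$ and $x_b$ are illustrative and denote the baseline data of two hypothetical individuals:

```latex
\frac{\lambda(t \mid x_a)}{\lambda(t \mid x_b)}
  = \frac{\lambda_0(t) \cdot e^{h(x_a)}}{\lambda_0(t) \cdot e^{h(x_b)}}
  = e^{h(x_a) - h(x_b)}
```

The hazards of any two individuals are therefore proportional at all times, with a constant of proportionality determined only by the difference of their risk functions.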

Linear Survival Models
The CPH is a proportional hazards model that estimates the risk function h(x) by a linear function $\hat{h}_\beta(x) = \beta^T x$. To perform Cox regression, one tunes the weights β to optimize the Cox partial likelihood. The partial likelihood is the product of the probability at each event time $T_i$ that the event has occurred to individual i, given the set of individuals still at risk at time $T_i$. The Cox partial likelihood is parameterized by β and defined as
$$L_c(\beta) = \prod_{i : E_i = 1} \frac{\exp(\hat{h}_\beta(x_i))}{\sum_{j \in \Re(T_i)} \exp(\hat{h}_\beta(x_j))} \tag{2}$$
where the values $T_i$, $E_i$, and $x_i$ are the respective event time, event indicator, and baseline data for the i-th observation. The risk set $\Re(t) = \{i : T_i \ge t\}$ is the set of patients still at risk of death at time t.
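To make the partial likelihood concrete, the following is a minimal pure-Python sketch of its negative logarithm, computed from a list of predicted risks. The function name `neg_log_partial_likelihood` and its argument layout are assumptions made for illustration; the sketch is not the implementation used in the experiments, and it treats tied event times naively by placing every patient with time at least $T_i$ in the risk set.

```python
import math


def neg_log_partial_likelihood(risks, times, events):
    """Negative log Cox partial likelihood.

    risks  : predicted risk h(x_i) for each patient
    times  : event or censoring time T_i for each patient
    events : 1 if the event was observed, 0 if right-censored
    """
    total = 0.0
    for h_i, t_i, e_i in zip(risks, times, events):
        if e_i != 1:
            # Censored patients contribute only through other patients' risk sets.
            continue
        # Risk set: every patient whose recorded time is at least T_i.
        risk_set = [h_j for h_j, t_j in zip(risks, times) if t_j >= t_i]
        # Log-sum-exp with the maximum subtracted for numerical stability.
        m = max(risk_set)
        log_denominator = m + math.log(sum(math.exp(h - m) for h in risk_set))
        total -= h_i - log_denominator
    return total


# Three patients with equal risk; the third is censored.
# The two observed events contribute log(3) and log(2), so the sum is log(6).
example = neg_log_partial_likelihood([0.0, 0.0, 0.0], [1, 2, 3], [1, 1, 0])
```

Assigning higher predicted risk to patients who die earlier lowers this quantity, which is why minimizing it yields a risk ranking consistent with the observed event order.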
The CPH model may be too simplistic for fitting real world datasets since they commonly do not satisfy the linear proportional hazards condition. In this case, a more complex set of features that includes higher-level interaction terms between covariates is needed.

Nonlinear Survival Models
The Faraggi-Simon method [12] is a feed-forward neural network that provides the basis for a nonlinear proportional hazards model. They experimented with a single hidden layer network with two or three nodes. Their model requires no prior assumption of the risk function h(x) other than continuity. Instead, the NN computes nonlinear features from the training data and calculates their linear combination to estimate the risk function. Similar to Cox regression, Faraggi and Simon optimize a modified Cox partial likelihood. They replace the linear combination of features $\hat{h}_\beta(x)$ in Equation 2 with the output of the network $\hat{h}_\theta(x)$.
As previous research suggests, the Faraggi-Simon network has not been shown to outperform the linear CPH [12,7,13]. This could be due to the fact that neural network practice at the time was not mature. Furthermore, to the best of our knowledge, there have been no attempts at applying modern deep learning techniques to the Faraggi-Simon network.

Deep Survival
In this section, we describe our contributions: first, we leverage recent deep learning techniques unavailable to Faraggi and Simon. We improve upon their network by adding additional hidden layers to the network such that the covariates of the uppermost hidden layer of the deep network are used as an input to a CPH model. Second, we introduce an open source Python module that provides an interface to train DeepSurv. Third, we demonstrate how to view the network as a prognostic model and how the network's predicted risk function can provide personalized treatment recommendations.

DeepSurv
DeepSurv is a multi-layer perceptron similar to the Faraggi-Simon network. However, we allow a deep architecture (i.e., more than one hidden layer) and apply modern techniques such as weight decay regularization, Rectified Linear Units (ReLU) [15], Batch Normalization [16], dropout [17], stochastic gradient descent with Nesterov momentum [18], gradient clipping [19], and learning rate scheduling [20]. The output of the network is a single node, which estimates the risk function $\hat{h}_\theta(x)$ parameterized by the weights of the network θ. We set the loss function to be the negative log partial likelihood of Equation 2:
$$l(\theta) := -\sum_{i : E_i = 1} \left( \hat{h}_\theta(x_i) - \log \sum_{j \in \Re(T_i)} e^{\hat{h}_\theta(x_j)} \right) \tag{3}$$
In addition, we perform a random hyper-parameter optimization search [21]; see Section 4.4 for more details.

Treatment Recommender System
It is a common practice in medical applications to determine the relationship between a patient's observable covariates and his or her risk of an event [13,3,22]. Survival models based on NNs are rarely used in clinical research because NNs tend to overfit implausible biological functions [14]. However, our results show that DeepSurv is able to accurately generalize biologically significant relationships between a patient's covariates and his or her risk of death. As a result, the network is able to provide guidance to physicians in terms of personalized treatment recommendations.
Let all patients in a given study be assigned to one of n treatment groups τ ∈ {0, 1, ..., n − 1}. We assume each treatment i to have an independent risk function $h_i(x)$. Collectively, the hazard function becomes:
$$\lambda(t \mid x, \tau = i) = \lambda_0(t) \cdot e^{h_i(x)} \tag{4}$$
For any patient, the network should be able to accurately predict the risk of being prescribed a given treatment. In addition, the network can compare the risk of undergoing any two treatments. For example, if we pass a patient through the network once in treatment group i and again in treatment group j, we can take the log of their hazards ratio to calculate the personal risk of prescribing one treatment option over another:
$$\text{rec}_{ij}(x) = \log\left(\frac{\lambda(t \mid x, \tau = i)}{\lambda(t \mid x, \tau = j)}\right) = \log\left(\frac{\lambda_0(t) \cdot e^{h_i(x)}}{\lambda_0(t) \cdot e^{h_j(x)}}\right) = h_i(x) - h_j(x) \tag{5}$$
We define this difference of log hazards as the recommender function, or $\text{rec}_{ij}(x)$. In practice, when a patient receives a positive recommendation $\text{rec}_{ij}(x)$, treatment j is more effective than treatment i and leads to a lower risk of death. Conversely, a negative recommendation indicates that treatment j leads to a higher risk of death than treatment i. Hence, the patient should be prescribed treatment i.
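The recommender function can be sketched in a few lines of Python. Here `risk_fn` stands in for a trained network that maps a patient's covariates and a treatment group to a predicted risk; `risk_fn`, `recommend`, and the toy risk function `toy_risk` are hypothetical names and values introduced only for illustration. The handling of an exactly zero recommendation (defaulting to treatment i) is likewise an assumption, since the text only covers the positive and negative cases.

```python
def recommend(risk_fn, x, i, j):
    """Return (rec_ij(x), recommended treatment) for treatments i and j.

    risk_fn(x, treatment) is assumed to return the predicted risk of
    patient x under the given treatment group.
    """
    rec = risk_fn(x, i) - risk_fn(x, j)
    # Positive rec: treatment j has the lower predicted risk, so prefer j.
    # Negative rec: treatment i has the lower predicted risk, so prefer i.
    # A tie (rec == 0) falls back to treatment i; this is an assumption.
    return rec, (j if rec > 0 else i)


def toy_risk(x, treatment):
    # Hypothetical risk model: treatment 0 has a constant risk,
    # treatment 1 has a risk that grows with the first covariate.
    return 0.5 if treatment == 0 else x[0]


rec_low, choice_low = recommend(toy_risk, [0.2], 0, 1)    # treatment 1 is safer here
rec_high, choice_high = recommend(toy_risk, [0.9], 0, 1)  # treatment 0 is safer here
```

Because the baseline hazard cancels in the log hazards ratio, only the two predicted risks are needed to compare treatments for a given patient.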
DeepSurv is a tool for researchers to easily investigate more complex interactions between covariates and treatment success. DeepSurv's architecture has an advantage over the CPH because it does not require an a priori specification of treatment interaction terms, but learns these adaptively. In contrast, the CPH model is only able to provide personalized treatment recommendations when treatment interaction terms are added to the model. This requires extensive experimentation or prior biological knowledge of treatment outcomes.

Experiments
We perform three sets of experiments on: (i) simulated survival data, (ii) real survival data, and (iii) clinical treatment data. For the first set of experiments, we simulate both a linear and nonlinear risk function and show DeepSurv's superior modeling capabilities. Then, we train DeepSurv on real survival data and demonstrate the network's improved predictive ability. In addition, we verify that the network can model multiple risk functions within a population. Lastly, we demonstrate how DeepSurv's treatment recommendations can improve a population's survival rate.
To evaluate the predictive accuracy of DeepSurv, we measure the concordance-index (C-index) c as outlined by Harrell et al. [23]. The C-index is computed over all pairs of patients with comparable event times; a pair is not comparable if, for example, both patients are censored or one patient is censored before the other's death time. A comparable pair is concordant with the true outcomes if the patient with the higher predicted risk dies first. The C-index is the ratio of the number of concordant pairs to the number of comparable pairs. For context, c = 0.5 is the average C-index of a random model, whereas c = 1 is a perfect ranking of event times. We perform bootstrapping [24], sampling the test set with replacement, to obtain confidence intervals. We report the confidence intervals (CI) of the C-indices of the bootstrapped samples for each model.
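The pair-counting definition above can be sketched directly in pure Python. This is an illustrative quadratic-time version, not the Lifelines routine used in the experiments; the function name `concordance_index_sketch` is hypothetical. Two conventions in it are assumptions rather than statements from the text: pairs with identical recorded times are skipped, and pairs with identical predicted risks count as half concordant.

```python
def concordance_index_sketch(times, events, risks):
    """Fraction of comparable pairs ranked correctly by predicted risk."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for a in range(n):
        for b in range(a + 1, n):
            if times[a] == times[b]:
                continue  # assumption: skip pairs with tied recorded times
            # `first` is the patient with the shorter recorded time.
            first, second = (a, b) if times[a] < times[b] else (b, a)
            if events[first] != 1:
                # The earlier patient is censored, so the true order of
                # death times is unknown and the pair is not comparable.
                continue
            comparable += 1
            if risks[first] > risks[second]:
                concordant += 1.0   # higher predicted risk died first
            elif risks[first] == risks[second]:
                concordant += 0.5   # assumption: tied risks count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


times = [1, 2, 3, 4]
events = [1, 1, 0, 1]
perfect = concordance_index_sketch(times, events, [4, 3, 2, 1])   # 1.0
reversed_ = concordance_index_sketch(times, events, [1, 2, 3, 4])  # 0.0
constant = concordance_index_sketch(times, events, [0, 0, 0, 0])   # 0.5
```

In this example the pair formed by the third and fourth patients is excluded, because the third patient is censored before the fourth patient's death time.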

Simulated Survival Data
In this section, we perform two experiments with simulated survival data: one with a linear risk function and one with a nonlinear (Gaussian) risk function. In addition to training DeepSurv on each dataset, we run a linear CPH regression for a baseline comparison. The advantage of the simulated datasets is that we can ascertain whether DeepSurv or CPH successfully models the true risk function.
For each experiment, we generate a training, validation, and testing set of N = 10000 observations, such that an observation represents a patient vector with d = 10 covariates, each drawn from a uniform distribution on [−1, 1). We generate the death time T according to an exponential Cox model [25]:
$$T \sim \text{Exp}(\lambda(t \mid x)) = \text{Exp}(\lambda_0 \cdot e^{h(x)}) \tag{6}$$
In both experiments, the risk function h(x) only depends on two of the ten covariates, and we demonstrate that DeepSurv is able to discern the relevant covariates from the noise. We then choose a censoring time to represent the 'end of study,' such that an average of 30-40 percent of the patients have an observed event in the dataset.
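The data-generating procedure described above can be sketched as follows. This is a minimal illustration using only the Python standard library, not the code used to produce the reported results. The function names, the risk function `linear_risk`, the baseline hazard of 0.1, and the end-of-study time of 5.0 are all hypothetical values chosen for the example, so the fraction of observed events it produces is not meant to match the fraction reported in the text.

```python
import math
import random


def linear_risk(x):
    # Hypothetical risk function for illustration: it depends on only
    # two of the covariates, and the remaining ones act as noise.
    return x[0] + 2.0 * x[1]


def simulate_survival_data(n, risk_fn=linear_risk, d=10,
                           baseline_hazard=0.1, end_of_study=5.0, seed=0):
    """Return a list of (x, T, E) tuples with right-censoring at end of study."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        # Each covariate is drawn uniformly from [-1, 1).
        x = [-1.0 + 2.0 * rng.random() for _ in range(d)]
        # Exponential death time with rate lambda_0 * exp(h(x)).
        rate = baseline_hazard * math.exp(risk_fn(x))
        death_time = rng.expovariate(rate)
        if death_time <= end_of_study:
            data.append((x, death_time, 1))    # event observed, E = 1
        else:
            data.append((x, end_of_study, 0))  # right-censored, E = 0
    return data


train = simulate_survival_data(1000, seed=1)
observed_fraction = sum(e for _, _, e in train) / len(train)
```

In such a sketch, changing `end_of_study` or `baseline_hazard` is how one would tune the fraction of patients with an observed event.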

Linear Risk Experiment
We first simulate patients to have a linear risk function (Equation 7) for $x \in \mathbb{R}^d$, so that the linear proportional hazards assumption holds true. Because the linear proportional hazards assumption holds true, we expect the linear CPH to accurately model the risk function in Equation 7.
Our results demonstrate that DeepSurv performs as well as the standard linear Cox regression in predictive ability. However, DeepSurv reconstructs the true risk function for all patients more accurately than the linear CPH. Figure 1(a) plots the true risk function h(x) for all patients in the test set. As shown in Figure 1(b), the CPH's estimated risk function $\hat{h}_\beta(x)$ does not perfectly model the true risk for a patient. In contrast, as shown in Figure 1(c), DeepSurv estimates the true risk function. As depicted in Figures 1(d) and 1(e), the CPH's estimated risk has a significantly larger error than that of DeepSurv, especially for patients with a high positive risk. To quantify these differences, we calculate the

Nonlinear Risk Experiment
We set the risk function to be a Gaussian (Equation 8) with $\lambda_{max} = 5.0$ and a scale factor of r = 0.5. The surface of the risk function is depicted in Figure 2(a). Because this risk function is nonlinear, we do not expect the CPH to predict the risk function properly without adding quadratic terms of the covariates to the model. We expect DeepSurv to be successful in reconstructing the Gaussian risk function and predicting a patient's risk.
As shown in Figure 2, DeepSurv is more successful than the linear CPH in modeling the true risk function. Figure 2(b) demonstrates that the linear CPH regression fails to determine the first two covariates as significant. The CPH has a C-index of 0.490 (95% CI: 0.490-0.491), which is equivalent to the performance of randomly ranking death times. DeepSurv has a higher predictive accuracy, with a C-index of 0.612 (95% CI: 0.611-0.612). Furthermore, Figure 2(c) shows that DeepSurv reconstructs the Gaussian relationship between the first two covariates and a patient's risk. DeepSurv clearly outperforms the linear CPH in predictive ability and is able to learn nonlinear relationships between a patient's covariates and their risk.

Real Survival Data Experiments
We compare the performance of the CPH and DeepSurv on two datasets from real clinical studies: the Worcester Heart Attack Study (WHAS) [26] and the Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) [27]. Our goal is to demonstrate that DeepSurv has superior predictive ability in medical application and practice compared to the linear CPH.

Worcester Heart Attack Study (WHAS)
The Worcester Heart Attack Study (WHAS) investigates the effects of a patient's factors on acute myocardial infarction (MI) survival [28]. To further explore the advantages of DeepSurv, we rerun the network on a reduced dataset consisting of the four factors (age, BMI, CHF, and MIORD) that the CPH found significant (p-value < $10^{-6}$). We find that eliminating sex as an input feature decreases the C-index of DeepSurv to 0.748 (95% CI: 0.745-0.750). This signifies that DeepSurv found sex to be significant in the calculation of a patient's risk. This result is expected, as Vaccarino et al. [29] have shown strong evidence that the interaction between age and sex affects MI survival.

Molecular Taxonomy of Breast Cancer International Consortium (METABRIC)
The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) uses gene and protein expression profiles to determine new breast cancer subgroups in order to help physicians provide better treatment recommendations.
The METABRIC dataset consists of gene expression data and clinical features for 1,981 patients, and 43.85 percent have an observed death due to breast cancer with a median survival time of 1,907 days [27]. We partition 20 percent of the observations for both the validation and test set.
Each gene expression profile includes 49,576 probes, representing the genes in the transcriptome. We reduce the dimension of the dataset to 14 through manual feature selection by a human expert. The first four features (ERBB2, MKI67, PGR, ESR1) are probes corresponding to genes that are common indicators of breast cancer and are known to influence treatment outcomes. The other ten features are inspired by the winners of the Sage Bionetworks-DREAM Breast Cancer Prognosis Challenge (BCC), which was a competition to assess the accuracy of computational models trained on the METABRIC dataset. The competition winners found four metagene factors (CIN, MES, LYM, FGD3-SUSD3) to be strong predictors of survival rates [30,4]. A metagene is the average of a set of probes representing a particular biological pathway. We also supplement the gene expression data with the clinical variables (age at diagnosis, number of positive nodes, tumor size, ER status, HER2 status, and treatment) that the winning model showed to improve predictive performance [4].
DeepSurv outperforms the CPH model on predicting a patient's risk. DeepSurv has a C-index of 0.695 (95% CI: 0.693-0.697). The linear CPH has a C-index of 0.688 (95% CI: 0.686-0.690). Although the improvement in DeepSurv's C-index does not seem large in absolute terms, these results have important medical implications. Studies have shown evidence that while different commercial assays provide equivalent prognostic information at the population level, these tests differ in risk stratification on the individual level, which has a direct impact on the quality of patient care [31]. As is evident from the greater C-index, DeepSurv is able to model nonlinear interactions between a patient's covariates and his or her predicted risk. Therefore, while DeepSurv and CPH have similar prognostic abilities, we expect the two methods to differ in risk stratification. Thus, DeepSurv has significant implications for clinical application.

Treatment Recommender System Experiments
In this section, we perform two experiments to demonstrate the effectiveness of DeepSurv's treatment recommender system. First, we simulate clinical treatment data by adding an additional covariate to the simulated data from Section 4.1.2. After demonstrating DeepSurv's modeling capabilities, we apply the recommender system to datasets from real clinical trials that study the effects of hormone treatment on breast cancer patients. We show that if all patients follow the network's recommended treatment option, we gain a significant increase in patient lifespan.

Simulated Treatment Data
We uniformly assign a treatment group τ ∈ {0, 1} to each simulated patient in the dataset. All of the patients in group τ = 0 were 'unaffected' by the treatment (e.g. given a placebo) and have a constant risk function $h_0(x)$. The other group τ = 1 is prescribed a treatment with Gaussian effects (Equation 8) and has a risk function $h_1(x)$ with $\lambda_{max} = 10$ and r = 0.5. Figure 3 illustrates the network's success in predicting the risk function for patients in the test set. Figure 3(a) plots the true risk distribution $h_1(x)$. As expected, Figure 3(b) shows that the network models a constant risk for a patient in treatment 0, independent of a patient's covariates. Figure 3(c) shows how DeepSurv models the Gaussian effects of a patient's covariates on their treatment risk. Because the network accurately reconstructs the risk function, we expect it will provide accurate treatment recommendations for new patients.

Hormone Treatment Recommendations for Breast Cancer
We first train DeepSurv on breast cancer data from the Rotterdam tumor bank [32] and construct a recommender system to provide treatment recommendations to patients from a study by the German Breast Cancer Study Group (GBSG) [33]. We then plot the two survival curves: the survival times of those who followed the recommended treatment and those who did not. If the recommender system is effective, we expect the population with the recommended treatments to survive longer than those who did not take the recommended treatment.
The Rotterdam tumor bank dataset contains records for 1,546 patients with node-positive breast cancer, and nearly 90 percent of the patients have an observed death time. The testing data from the GBSG contains complete data for 686 patients (56 percent are censored) in a randomized clinical trial that studied the effects of chemotherapy and hormone treatment on survival rate. We preprocess the data as outlined by Royston and Altman [34].
We then validate and compare the network against a linear CPH regression baseline. The C-indices of DeepSurv and linear CPH are 0.668 (95% CI: 0.667-0.669) and 0.655 (95% CI: 0.654-0.656), respectively. Thus, DeepSurv provides an improvement relative to the CPH.
Next, we calculate the recommender function (Equation 5) for all patients in the GBSG test set and determine the recommended treatment for each patient. We then identify two subsets of patients: those whose treatment group aligns with the network's recommended treatment (Recommendation) and those who did not undergo the recommended treatment (Anti-Recommendation). In Figure 4, we plot the Kaplan-Meier survival curves for both the Recommendation subset and the Anti-Recommendation subset. The survival curve for the Recommendation subset is shifted to the right, which signifies an increase in survival time for the population following DeepSurv's recommendations. The median death times of the Recommendation and Anti-Recommendation populations are 40.099 and 31.770 months, respectively. The majority of patients were recommended to undergo hormone therapy (tamoxifen), which is in alignment with standard medical practice [35].
Figure 4: Kaplan-Meier estimated survival curves with confidence intervals (α = .05) for the patients who were given the treatment concordant with DeepSurv's recommended treatment (Recommendation) and the subset of patients who were not (Anti-Recommendation). We perform a log-rank test to assess the significance of the difference between the two curves (p = 0.003427).
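The split into Recommendation and Anti-Recommendation subsets, followed by a Kaplan-Meier estimate for each, can be sketched as below. This is a pure-Python stand-in written for illustration; the experiments themselves use the Lifelines package, and the function names `split_by_recommendation` and `kaplan_meier` are hypothetical. The sketch omits confidence intervals and the log-rank test.

```python
def split_by_recommendation(assigned, recommended):
    """Return (recommendation_indices, anti_recommendation_indices).

    assigned    : treatment group each patient actually received
    recommended : treatment group suggested for each patient
    """
    rec_idx, anti_idx = [], []
    for k, (a, r) in enumerate(zip(assigned, recommended)):
        (rec_idx if a == r else anti_idx).append(k)
    return rec_idx, anti_idx


def kaplan_meier(times, events):
    """Return [(t, S(t))] at each distinct time with an observed event."""
    at_risk = len(times)
    survival = 1.0
    curve = []
    for t in sorted(set(times)):
        deaths = sum(1 for ti, ei in zip(times, events) if ti == t and ei == 1)
        leaving = sum(1 for ti in times if ti == t)  # deaths plus censored at t
        if deaths:
            # Product-limit update: survive t with probability 1 - d / n.
            survival *= 1.0 - deaths / at_risk
            curve.append((t, survival))
        at_risk -= leaving
    return curve


rec_idx, anti_idx = split_by_recommendation([0, 1, 1, 0], [0, 0, 1, 1])
curve = kaplan_meier([1, 2, 3, 4], [1, 0, 1, 1])
```

Applying `kaplan_meier` separately to the times and event indicators of each subset gives the two curves that are compared in the figure.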

Experimental Details
We run all linear CPH regression, Kaplan-Meier estimations, concordance-index statistics, and log-rank tests using the Lifelines Python package [36]. DeepSurv is implemented in Theano [37] with the Python package Lasagne [38].
The hyper-parameters of the network include: the $\ell_2$ regularization coefficient, learning rate, learning rate decay constant, dropout rate, momentum, and the size and depth of the network. We run the random hyper-parameter optimization search as proposed in [21] using the Python package Optunity [39]. We perform random sampling on each hyper-parameter from a predefined range and evaluate the performance of the configuration on a validation set. We then choose the configuration with the largest validation C-index and with the smallest difference between validation C-index and training C-index to avoid models that overfit.
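A random search of this kind can be sketched with the Python standard library alone. The sketch below does not use Optunity and is not the search code used in the experiments. The sampling ranges are hypothetical, and so is the way the two selection criteria are combined into a single score (validation C-index minus the absolute gap to the training C-index); the text does not state how they are weighted. The `evaluate` argument is an assumed callback that trains a model for a configuration and returns its training and validation C-indices.

```python
import random


def sample_config(rng):
    # Hypothetical search ranges, chosen only for illustration.
    return {
        "l2_reg": rng.uniform(0.0, 10.0),
        "learning_rate": 10 ** rng.uniform(-5, -2),  # log-uniform sample
        "lr_decay": rng.uniform(0.0, 1e-3),
        "dropout": rng.uniform(0.0, 0.5),
        "momentum": rng.uniform(0.8, 0.95),
        "n_layers": rng.randint(1, 4),
        "n_nodes": rng.randint(2, 50),
    }


def random_search(evaluate, n_trials=20, seed=0):
    """Return (best_config, best_score) over n_trials random configurations.

    evaluate(config) is assumed to return (train_c_index, valid_c_index).
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_trials):
        config = sample_config(rng)
        train_c, valid_c = evaluate(config)
        # Assumed scoring rule: reward a high validation C-index and
        # penalize a large gap between training and validation.
        score = valid_c - abs(valid_c - train_c)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


def fake_evaluate(config):
    # Stand-in for model training: pretend dropout alone drives the gap.
    return 0.8, 0.8 - config["dropout"]


best_config, best_score = random_search(fake_evaluate, n_trials=10, seed=1)
```

With a real training routine in place of `fake_evaluate`, the same loop selects the configuration that balances validation performance against overfitting under the assumed scoring rule.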

Summary
In conclusion, we demonstrated how deep learning can be applied to survival analysis and showed that DeepSurv is superior to the linear Cox proportional hazards model in predictive ability on survival datasets with linear and nonlinear risk functions. We illustrated that the network can provide personalized treatment recommendations for patients and can be used by physicians to guide their treatment decisions in order to improve patient lifespan. We also released a Python module that implements DeepSurv, see https://github.com/jaredleekatzman/DeepSurv for more details. With future research and development, this approach has the potential to replace traditional linear Cox regression and become a standard practice in biomedical applications.