Introduction

Survival analysis under competing risks describes the time of occurrence of the first of several possible outcomes. This can be done by predicting the cause-specific hazards from a set of explanatory variables, also called covariates. Competing risks have countless applications in modelling a system's failure time, including client churn and the probability of a borrower defaulting on a loan [1, 2]. In medicine, modelling competing events can be used to measure the time-to-event for several possible outcomes, such as treatment effects on a patient or the prediction of the time of death after a colon cancer diagnosis [3, 4].

Previous work has addressed the prediction of cause-specific hazards under competing risks. Firstly, the semi-parametric Cox proportional hazards (CoxPH) model was introduced for survival analysis under the assumption of proportional hazards, namely a linear relationship between the log-hazard ratio and the covariates [5]. Because the original CoxPH model failed in the context of variable collinearity when applied to high-dimensional data, the Regularized CoxPH (RCoxPH) model was introduced. This model minimizes CoxPH’s partial likelihood function with an additional elastic net penalty [6]. It has had numerous uses, such as the identification of breast cancer prognosis markers [7]. Secondly, a collapsed log-likelihood approach was developed and applied to colon cancer data [4]. This method does not rely on the proportional hazards assumption of the CoxPH model, which improves its applicability to real-world data. It was recently implemented as a Python package, PyDTS [8]. Lastly, several studies used deep learning models that minimize a loss function adapted to datasets with censored data [9]. Multi-layer perceptron models outperformed previous models in both continuous time (DeepSurv) and discrete time (DeepHit) [10, 11]. These deep learning models are able to learn without strong assumptions on the predicted hazard rates; however, they were not initially designed to handle temporal covariates or produce temporal predictions, which limits their performance in survival analysis on longitudinal cohorts.

Additionally, several studies reported on the failure of the proportional hazard assumption in survival analysis, notably for treatment response and oncology [12,13,14,15]. This highlights the need for modelling competing risks with non-proportional hazards.

In various tasks involving sequential data, such as natural language processing and time series forecasting, Transformer-based models demonstrated excellent performance in learning complex dynamics from sequential data [16, 17]. Transformer models are particularly suited for sequence generation, which motivated their use in time series predictions of discrete time cause-specific hazards. A Transformer model was recently used for survival analysis with a single event [18]. In this study, we introduce a Transformer-based deep learning model for the prediction of the cause-specific hazards in discrete time under competing risks.

Because the true data-generating mechanisms that entail targeted cause-specific hazards are unknown in practice, we used synthetic data to compare our model against three state-of-the-art models [19]. We followed the ADEMP guidelines (Aims, Data-generating mechanisms, Estimands, Methods, and Performance Measures) for simulation and reporting of results [20]. We then validated our model on the English Longitudinal Study of Ageing (ELSA) dataset for the prediction of death, dementia and psychiatric conditions [21]. To our knowledge, this is the first study to use a Transformer-based model for the prediction of the cause-specific hazards in discrete time under competing risks.

This article is organized as follows: the “Methodology” section describes our Transformer-based model, the benchmark models, and the simulated and ELSA datasets; the “Results” section presents the predictive performance of each model on the synthetic and ELSA datasets; finally, the “Discussion” and “Conclusions” sections discuss the findings of this study.

Our code is openly available at https://github.com/USM-CHU-FGuyon/cause_specific_hazard_transformer.

Methodology

Notations

Competing risks analysis considers a patient, described by a vector of covariates X, who may experience one of E distinct events over a period of time [0, T]. A patient may be censored at \(t_0 \le T\), in which case it is only known that no event occurred before \(t_0\). For convenience, the competing events are denoted \(\{1, \dots , E\}\). If event e occurred at time t, the outcome is written \((e, t)\) with \(e \in \{0, 1, \dots , E\}\), \(t \in [0, T]\), and \(e=0\) indicating censoring.

The cause-specific hazard \(\lambda _{e, X}(t)\), for \(e\ge 1\), defined by (1) is the instantaneous rate of occurrence of event e at time t, given that the patient remained event-free until t. A model of cause-specific hazard explores the relation between covariates X and the cause-specific hazard \(\lambda _{e, X}\) for each event e [22].

$$\begin{aligned} \lambda _{e, X}(t) = \lim _{\delta \rightarrow 0}\frac{P_{e, X}(t \le T < t+ \delta \mid T > t)}{\delta } \end{aligned}$$
(1)

Note that in discrete-time competing risks, the cause-specific hazard is defined as a probability rather than as an unbounded positive rate [23]. We also introduce the cumulative incidence function (2), a function of the cause-specific hazards that describes the proportion of patients who experienced event e up until time t.

$$\begin{aligned} I_{e, X}(t) = \sum _{\tau = 0}^{t} i_{e, X}(\tau ) \end{aligned}$$
(2)

where \(i_{e, X}\) is the incidence function defined by:

$$\begin{aligned} i_{e, X}(\tau ) = \lambda _{e,X}(\tau ) \prod _{k = 0}^{\tau - 1} \left( 1 - \sum _{j = 1}^{E} \lambda _{j,X}(k)\right) \end{aligned}$$
(3)

The goal of this study is to build a prediction model for the cause-specific hazards \((\lambda _{e, X})_{e \in \{1, \ldots , E\}}\) from a set of covariates X. This study focused on the cause-specific hazard but did not explore the prediction of the sub-distribution hazard. In the following, X may be constant or longitudinal data.
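As an illustration of (2) and (3), the following Python sketch computes the incidence and cumulative incidence functions from a matrix of discrete-time cause-specific hazards; the hazard values used here are arbitrary toy numbers.

```python
import numpy as np

def cumulative_incidence(hazards):
    """hazards: array of shape (E, T) holding the discrete-time cause-specific
    hazards lambda_{e,X}(t). Returns the cumulative incidences I_{e,X}(t) of
    shape (E, T), following (2) and (3)."""
    overall = hazards.sum(axis=0)                        # sum over events at each time step
    # probability of remaining event-free strictly before each time step
    event_free = np.concatenate(([1.0], np.cumprod(1.0 - overall)[:-1]))
    incidence = hazards * event_free                     # i_{e,X}(t), Eq. (3)
    return np.cumsum(incidence, axis=1)                  # I_{e,X}(t), Eq. (2)

# toy hazards for E = 2 events over T = 4 time steps
lam = np.array([[0.10, 0.10, 0.10, 0.10],
                [0.05, 0.10, 0.20, 0.05]])
print(cumulative_incidence(lam))
```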

A transformer-based model for cause-specific hazard prediction in discrete time

We used a Transformer-based deep learning model to predict the cause-specific hazard \(\lambda _{e, X}\) of each event e from covariates X. This section describes the input and output data, the loss function that was minimized and the model architecture.

Input and output data

In real-world applications, the cause-specific hazards are unknown. The available data are the covariates X and outcomes (et) where e is the experienced event—or censoring—and t the time-to-event. Our model predicts the cause-specific hazards \(\lambda _{e, X}\) of events e from the covariates X as a time series of length T. The output of the model may be written as matrix (4).

$$\begin{aligned} \lambda _X = \begin{bmatrix} \lambda _{1,1}&{}\quad \ldots &{}\quad \lambda _{1,T}\\ \vdots &{}\quad \ddots &{}\quad \vdots \\ \lambda _{E,1}&{}\quad \ldots &{}\quad \lambda _{E,T}\\ \end{bmatrix}_{E \times T} \end{aligned}$$
(4)

Loss function

The collapsed log-likelihood (5) from the PyDTS package was used to define the loss function [8]. This function evaluates the consistency between the predicted cause-specific hazards \(\lambda _{X=x}\) and the observed outcome \((e_x, t_x)\).

$$\begin{aligned} L(\lambda _{X=x}, e_x, t_x) = \sum _{j = 1}^E\sum _{k = 0}^{t_x} \left[ \delta _{jk}^{e_x t_x} \log {\lambda _{j,k}(x)} +\left( 1- \delta _{jk}^{e_x t_x}\right) \log {(1-\lambda _{j,k}(x))} \right] \end{aligned}$$
(5)

where

$$\begin{aligned} \delta _{jk}^{e_x t_x} = 1 \text{ if } (j,k) = (e_x, t_x) \text{ else } 0 \end{aligned}$$

Maximizing this log-likelihood (i.e., minimizing its negative as the training loss) encourages:

  • A high value of \(\lambda _{e_x,t_x}(x)\), the predicted hazard for the observed outcome \((e_x, t_x)\)

  • Low values of \(\lambda _{j,k}(x)\) for \((j,k) \ne (e_x,t_x)\), the predicted hazards for outcomes that were not observed

Note that a patient censored at \(t_{x}\) will contribute to low values of \(\lambda _{j,k}(x)\) for each event j and each time \(k < t_{x}\).
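A minimal PyTorch-style sketch of this objective is given below. It assumes the model outputs hazards as a (batch, E, T) tensor with values in (0, 1) and events encoded as 0 for censoring and 1..E otherwise, and it returns the negative of (5) so that it can be minimized. This is an illustration of the formula above, not the PyDTS implementation.

```python
import torch

def collapsed_nll(hazards, event, time):
    """Negative collapsed log-likelihood, cf. (5).
    hazards: (batch, E, T) predicted cause-specific hazards in (0, 1)
    event:   (batch,) observed event, with 0 denoting censoring
    time:    (batch,) event or censoring time, as an index in 0..T-1
    """
    batch, E, T = hazards.shape
    # delta[b, j, k] = 1 iff (j, k) matches the observed outcome of patient b
    delta = torch.zeros_like(hazards)
    idx = (event > 0).nonzero(as_tuple=True)[0]          # non-censored patients
    delta[idx, event[idx] - 1, time[idx]] = 1.0
    # only time steps k <= time_b contribute for patient b
    mask = (torch.arange(T, device=hazards.device).unsqueeze(0) <= time.unsqueeze(1))
    mask = mask.unsqueeze(1).expand(-1, E, -1).float()   # (batch, E, T)
    log_lik = delta * torch.log(hazards) + (1 - delta) * torch.log(1 - hazards)
    return -(mask * log_lik).sum() / batch
```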

Transformer-based model architecture

The Transformer is a sequence-to-sequence architecture that was introduced in response to the vanishing-gradient problem faced by long short-term memory (LSTM) networks and other recurrent neural networks [24]. It uses the self-attention mechanism in an encoder–decoder architecture to learn complex temporal features of the input and/or output data. Transformers are especially suited to producing meaningful sequential output, which initially motivated their use for natural language processing tasks; the architecture has since also proved effective for time series prediction from sequential or constant input data. A gentle introduction to the Transformer architecture is provided in Appendix 1.

Our model architecture is presented in Fig. 1. It is based on a Transformer encoder and a linear decoder that predict the cause-specific hazards as a time series for each event. An input vector of covariates X is encoded by a linear layer and concatenated with an embedding of time. A positional encoding is added to the resulting tensor, which is fed to the Transformer encoder; the encoder outputs a single time series of length \(E \times T\). This time series is then decoded into a matrix of shape (E, T) by a single linear layer. The loss function (5) ensures that the model learns to predict cause-specific hazards. This model was implemented using the PyTorch framework.
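The following PyTorch sketch illustrates one possible reading of this architecture; the layer sizes, the learned positional encoding, and the final sigmoid layer are illustrative assumptions rather than the exact configuration, which is available in our code repository.

```python
import torch
import torch.nn as nn

class HazardTransformer(nn.Module):
    """Sketch: maps a covariate vector to an (E, T) matrix of cause-specific
    hazards. Layer sizes, the learned positional encoding and the sigmoid
    output are illustrative assumptions."""
    def __init__(self, n_covariates, n_events, n_times,
                 d_model=32, n_heads=4, n_layers=2):
        super().__init__()
        self.n_events, self.n_times = n_events, n_times
        seq_len = n_events * n_times
        self.covariate_embed = nn.Linear(n_covariates, d_model)     # encode covariates
        self.time_embed = nn.Embedding(seq_len, d_model)            # embedding of time
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, 2 * d_model) * 0.02)
        layer = nn.TransformerEncoderLayer(d_model=2 * d_model, nhead=n_heads,
                                           dim_feedforward=64, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.decoder = nn.Linear(2 * d_model, 1)                    # single linear decoding layer

    def forward(self, x):                                           # x: (batch, n_covariates)
        seq_len = self.n_events * self.n_times
        positions = torch.arange(seq_len, device=x.device)
        cov = self.covariate_embed(x).unsqueeze(1).expand(-1, seq_len, -1)
        tim = self.time_embed(positions).unsqueeze(0).expand(x.shape[0], -1, -1)
        h = torch.cat([cov, tim], dim=-1) + self.pos_encoding       # concatenate, add positions
        h = self.encoder(h)                                         # (batch, seq_len, 2*d_model)
        hazards = torch.sigmoid(self.decoder(h)).squeeze(-1)        # one hazard per position
        return hazards.reshape(-1, self.n_events, self.n_times)     # (batch, E, T)
```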

Fig. 1 Architecture of our transformer-based model. Each part of the architecture is described in detail in Appendix 1

Performance evaluation

Benchmark models

The performance of our Transformer-based model in predicting cause-specific hazards was compared to three existing models.

Firstly, we used the semi-parametric RCoxPH model from the lifelines package in Python [25]. Secondly, we used the PyDTS model, which implements the collapsed log-likelihood approach of Lee et al. [4, 8]. Finally, we implemented a model equivalent to the original DeepHit model using the PyTorch framework [11]; it contains one feed-forward subnetwork with a single hidden linear layer per competing event and minimizes the loss function (5). All models predicted a time-discretized cause-specific hazard for each competing event in the form of an \(E\times T\) matrix, as presented in (4).
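The sketch below illustrates the structure of this DeepHit-like benchmark; the sigmoid output is an assumption for illustration, while the 64-neuron hidden layer matches the setup described in the next section.

```python
import torch
import torch.nn as nn

class DeepHitLike(nn.Module):
    """Sketch of the benchmark: one feed-forward subnetwork with a single
    hidden linear layer per competing event."""
    def __init__(self, n_covariates, n_events, n_times, hidden=64):
        super().__init__()
        self.subnetworks = nn.ModuleList([
            nn.Sequential(nn.Linear(n_covariates, hidden), nn.ReLU(),
                          nn.Linear(hidden, n_times))
            for _ in range(n_events)])

    def forward(self, x):                                 # x: (batch, n_covariates)
        # one hazard trajectory of length T per event, stacked into (batch, E, T)
        out = torch.stack([net(x) for net in self.subnetworks], dim=1)
        return torch.sigmoid(out)                         # hazards in (0, 1)
```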

Benchmark designs

We evaluated all models using the same experimental setup for both the synthetic and ELSA data. The data were split into 80% for training and 20% for validation. As described in “Loss function” section, models learned to predict patients’ cause-specific hazards for each competing event from the observed events in the training data. Both deep learning models used 64-neuron hidden layers and no dropout.

Additional implementation details are available in our code repository.

Synthetic data benchmark

We simulated populations of 2000 to 50,000 patients, described by five covariates and susceptible to experiencing one of three competing events. The covariates were independent and uniformly distributed between 0 and 1. Events were drawn using the cause-specific hazard functions defined in Table 5 from Appendix. The cumulative incidence of each event and the number of patients at risk at each time step are illustrated in Fig. 2b. Note that one of the simulated events had a proportional hazard while the other two had non-proportional hazards. Departure from the proportional hazards hypothesis is common in clinical data but represents a strong limitation for most survival analysis models [12].

Fig. 2 Description of this study’s data. a and b respectively illustrate the underlying cause-specific hazards and the cumulative incidence of each simulated event; c illustrates the cumulative incidence function of events in the ELSA cohort

Finally, censoring times were drawn uniformly between 1 and 49. A patient was censored if the drawn censoring time preceded the drawn event time. Events (and censoring times) were drawn 10 times independently, and training and evaluation were performed on each drawn dataset to measure performance variability.
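The sketch below illustrates this drawing procedure; for brevity it uses two placeholder hazard functions rather than the three defined in Table 5 from Appendix.

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate_cohort(n, T=30):
    """Draw competing events from cause-specific hazards, then apply uniform
    censoring. The two hazard functions below are placeholders."""
    X = rng.uniform(0, 1, size=(n, 5))                    # five covariates in [0, 1]
    t = np.arange(T)[None, :]
    lam1 = np.repeat(0.02 * np.exp(X[:, [0]]), T, axis=1) # proportional: constant in time
    lam2 = 0.01 * (1 + t / T) * (1 + X[:, [1]])           # non-proportional: increasing
    hazards = np.stack([lam1, lam2], axis=1)              # (n, E, T)

    events = np.zeros(n, dtype=int)                       # 0 denotes censoring
    times = np.full(n, T - 1)
    censoring = rng.integers(1, 50, size=n)               # uniform on 1..49
    for i in range(n):
        for k in range(T):
            p = hazards[i, :, k]
            u = rng.random()
            if u < p.sum():                                # some event occurs at step k
                events[i] = 1 + np.searchsorted(np.cumsum(p), u)
                times[i] = k
                break
        if events[i] == 0 or censoring[i] < times[i]:      # censoring precedes the event
            events[i], times[i] = 0, min(censoring[i], times[i])
    return X, events, times

X, events, times = simulate_cohort(2000)
```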

In this synthetic experiment, the ground-truth cause-specific hazards are known. Model predictions were therefore evaluated using the mean absolute error between predicted and true cause-specific hazards. We also evaluated the models’ predictive performance over the simulated time and with varying training sample sizes.

ELSA data benchmark

The ELSA dataset is a representative cohort of the English population older than 50. It features economic, social, psychological, cognitive, health, biological and genetic data [21]. This longitudinal study currently comprises 9 waves of data acquired over 18 years and includes various diagnoses of cardiovascular, ocular, and psychiatric diseases.

We used this longitudinal cohort to evaluate the models’ prediction of dementia and psychiatric conditions. In the ELSA dataset, a psychiatric condition refers to any of the following disorders: hallucinations, anxiety, depression, emotional problems, schizophrenia, psychosis, mood swings, and manic depression. Our study population was the cohort from wave 2, which started in 2004. Patients already diagnosed with a psychiatric condition or dementia were excluded. Because mortality data were last updated in 2012, the study period was 2004–2012. We evaluated the models on the following competing events:

  • Dementia: new diagnosis of dementia

  • Psychiatric condition: new diagnosis of a psychiatric condition

  • Death

In contrast to the synthetic dataset, the ground-truth cause-specific hazards are unknown; hence, models were evaluated using the Integrated Brier Score and the Time-dependent Concordance Index for each event [26, 27]. The Brier Score is a generalization of the mean squared error to the comparison of predicted probabilities with observed events. The Concordance Index is a generalization of the area under the receiver operating characteristic curve (AUROC); it evaluates the ranking of failure times obtained from the predicted probabilities [28]. The Integrated Brier Score and Time-dependent Concordance Index are variants of the Brier score and concordance index adapted to time series predictions. The mean error and \(95\%\) confidence intervals were computed by bootstrapping on the test dataset. Finally, the proportional hazards assumption was assessed by computing the p values of the Schoenfeld residuals from the RCoxPH model [29].
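As an illustration, the sketch below computes a simplified Brier score at a fixed horizon for one competing event; the inverse-probability-of-censoring weighting used in the full metric is omitted for brevity.

```python
import numpy as np

def brier_score_at(cif, events, times, event_of_interest, horizon):
    """Simplified Brier score at a fixed horizon for one competing event:
    mean squared difference between the predicted cumulative incidence and
    the observed event indicator (censoring weights omitted).
    cif:    (n, T) predicted cumulative incidence for the event of interest
    events: (n,) observed events (0 = censored); times: (n,) event/censoring times
    """
    observed = ((events == event_of_interest) & (times <= horizon)).astype(float)
    return np.mean((cif[:, horizon] - observed) ** 2)
```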

We used the Integrated Gradients method on both deep learning models to provide an importance score for each input feature [30]. This method provides importance scores at a lower computational cost than Shapley values when applied to a large number of input variables and a time series output. In this work, we present the total importance scores over the whole ELSA dataset; however, these scores are also available for each individual prediction. Such importance scores were shown to improve the usability of artificial intelligence in clinical practice [31].
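The sketch below shows how such scores can be obtained with the captum implementation of Integrated Gradients; the stand-in model, the scalar summary (cumulative hazard of a single event) and the all-zero baseline are illustrative choices rather than the exact setup used in this study.

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Stand-in model: covariates (batch, 74) -> hazards (batch, E=3, T=8).
model = nn.Sequential(nn.Linear(74, 32), nn.ReLU(), nn.Linear(32, 3 * 8), nn.Sigmoid())

def cumulative_hazard_of_event(x):
    """Scalar summary per patient: cumulative predicted hazard of one event."""
    hazards = model(x).view(-1, 3, 8)
    return hazards[:, 0, :].sum(dim=1)                  # (batch,)

ig = IntegratedGradients(cumulative_hazard_of_event)
x = torch.rand(16, 74)                                  # toy batch of covariates
baseline = torch.zeros_like(x)                          # all-zero reference input
attributions = ig.attribute(x, baselines=baseline)      # (batch, 74) feature scores
mean_importance = attributions.abs().mean(dim=0)        # average importance per feature
```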

Results

Evaluation on synthetic data

Simulated data

We simulated datasets with sample sizes of 2000, 5000, 10,000, 20,000, and 50,000 patients, each patient being described by 5 covariates and susceptible to experiencing one of 3 competing events over a period of 30 time steps. In total, approximately \(40\%\) of patients were censored.

A sample of simulated cause-specific hazards for each event is shown in Fig. 2a. We introduced three simulated events: a Proportional hazard event, whose hazard was constant in time, and two non-proportional hazard events, denoted the Increasing hazard and Non-monotonic hazard events, whose hazards evolved over time with a non-linear dependence on the covariates. The Non-monotonic hazard event had a bell-shaped hazard whose mean and standard deviation parameters depended on the patients’ covariates (see Table 5 from Appendix).

Figure 2b shows the cumulative incidence of each of the three events over the simulated time. Fewer events were observed at the later time steps of the simulation because of the smaller number of patients still at risk.

Performance comparison

The mean absolute error of the cause-specific hazard prediction for several sizes of synthetic datasets is presented in Table 1. The Transformer-based model outperformed or equalled the other models on non-proportional hazard events for all dataset sizes, and was better than or equivalent to the other models on the Proportional hazard event with training data of more than 5000 patients. These results highlight a strong performance improvement when using deep learning models on non-proportional events; moreover, the benefit of the Transformer over the DeepHit model was more pronounced for smaller dataset sizes. Additionally, Fig. 3 shows the mean absolute error of the cause-specific hazard predictions as a function of time. Our Transformer model performed better on the Proportional hazard event overall, despite lower precision at the early time steps of this hazard’s prediction. Our Transformer-based model consistently showed a large advantage towards the end of the simulated time frame, which indicates a better ability to extrapolate cause-specific hazards from the set of observed events. We also noted that the PyDTS and RCoxPH models performed extremely poorly over the later part of the simulated time, where fewer events were observed. This was true for the Proportional hazard event, and even more pronounced for the non-proportional hazard events.

Table 1 Mean absolute error of the cause-specific hazard prediction for datasets of 2000–50,000 patients
Fig. 3 Time-dependence of the models’ performance. Performance was computed as the mean absolute error of the cause-specific hazard prediction for each simulated event. The Transformer model surpassed the other models by a large margin on non-proportional hazard events, owing especially to a major performance gap over the second half of the simulated time, and was also better than the DeepHit model at every time step. This error was computed with each model trained on a dataset of 10,000 simulated patients

Evaluation on the ELSA dataset

Collected data

The cohort size was 3564 patients. We selected 74 variables of which 54 were binary. Over the 8-year study period, there were 542 diagnoses of psychiatric conditions, 150 diagnoses of dementia, and 499 recordings of death. Cumulative incidences of each event are illustrated in Fig. 2c. The list of selected variables is shown in Table 6 from Appendix. Some variables had a large number of missing values—up to 45%—and 22 variables had more than 10% missing values. The missing values were imputed using the median value for the continuous variables, and the most frequent value for binary variables. Because evaluated models other than the Transformer and RCoxPH models do not inherently support sequential input data, we used singleton-length input data to provide a fair comparison between all models. All models learnt from input singleton-length sequences and produced cause-specific hazard predictions as a fixed-length time series.
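For illustration, the imputation step described above can be reproduced with scikit-learn as follows; the toy matrix and the column indices are placeholders.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Toy covariate matrix with missing values; column indices are hypothetical
# (columns 0-1 continuous, column 2 binary).
X = np.array([[63.0, np.nan, 1.0],
              [71.0, 28.4, np.nan],
              [np.nan, 31.2, 0.0]])
continuous_cols, binary_cols = [0, 1], [2]

# median imputation for continuous variables, mode imputation for binary ones
X[:, continuous_cols] = SimpleImputer(strategy="median").fit_transform(X[:, continuous_cols])
X[:, binary_cols] = SimpleImputer(strategy="most_frequent").fit_transform(X[:, binary_cols])
```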

Performance comparison

Integrated Brier Scores and Time-dependent Concordance Indices for each model are presented in Table 2. The mean value and 95% confidence interval were obtained by bootstrapping on the test dataset. Our Transformer-based model had the best Integrated Brier Score and Time-dependent Concordance Index. The PyDTS model was slightly better than the RCoxPH model, but in comparison, the Transformer model provided a major improvement on both metrics. Finally, despite a strong Integrated Brier Score, the DeepHit model showed a poor Concordance Index on the ELSA dataset.

Table 2 Integrated Brier score and time-dependent concordance index (\(C_{td}\) index) for the prediction of three competing events on the English Longitudinal Study of Ageing dataset

Feature importance

The most important features, on average, for the prediction of each event by the DeepHit and Transformer models are shown in Fig. 4. See Table 6 from Appendix for details on each feature. Age was the most important feature for the Transformer model’s predictions. For the prediction of death, the Transformer model notably used the binary features limiting illness and cancer, which state, respectively, "Whether limited by longtime illness" and "Ever diagnosed with cancer". In the Transformer model’s predictions, happy mood appeared among the important features only for the psychiatric condition and dementia predictions.

Fig. 4 Seven most important features obtained from the mean integrated gradients of the DeepHit (a) and Transformer (b) models using the ELSA dataset

Proportional hazard assumption

Variables that broke the proportional hazard assumption are shown in Table 3. This table lists the variables of each dataset for which the Schoenfeld residuals of the fitted RCoxPH model had p values lower than 0.05. In the synthetic dataset, none of the five variables broke the proportional hazard assumption for the Proportional hazard event, whereas the Increasing hazard and Non-monotonic hazard events had five and four such variables, respectively. Events from the ELSA dataset had four to six Schoenfeld residuals with p values lower than 0.05, indicating that the Death, Psychiatric condition, and Dementia events all had non-proportional hazard rates.

Table 3 Variables from the English Longitudinal Study of Ageing and synthetic datasets for which the p value of the Schoenfeld residual from the RCoxPH model was lower than 0.05

Discussion

We introduced a Transformer-based deep learning model for the prediction of cause-specific hazards in the context of discrete-time competing risks. This model provides state-of-the-art hazard prediction without strong assumptions on the relation between covariates and cause-specific hazards. It strongly outperformed current models even with relatively small training datasets, and was especially successful on events with highly non-proportional hazards or few observed outcomes. We noted that simpler models could perform better in the simplistic setting of a time-independent proportional hazard with a small training sample; however, our Transformer model generally remained the best for proportional hazards as well.

Our Transformer-based model had the best predictive performance for the cause-specific hazard on simulated datasets with sizes ranging from 5000 to 50,000 patients. It also had the best Integrated Brier Score and Time-dependent Concordance Index for the prediction of three competing events from the ELSA dataset. The experiment on simulated data showed that our model notably outperformed the other models in predicting the cause-specific hazards at later time steps, where fewer outcomes were observed. This resulted in improved hazard prediction for rare events, a key benefit of our model. Such behaviour could be expected given the ability of the Transformer architecture to learn and extrapolate complex temporal features from input data and to generate coherent time series.

The analysis of the proportional hazard assumption on the synthetic data showed that only the Proportional hazard event had a proportional hazard rate. This was consistent with the definition of each event. The same analysis on the ELSA dataset indicated that all three events had non-proportional hazards, which is consistent with other findings of departure from the proportional hazard assumption in clinical data [12,13,14,15]. As a result, in both the synthetic and ELSA datasets, our model strongly outperformed current models on all events featuring non-proportional hazard rates.

Moreover, our model outperformed the DeepHit model on non-proportional hazards by a larger margin for synthetic datasets with sample sizes of 2000–10,000. This indicates that the Transformer model generalizes better from limited data. Such results greatly increase the usability of our model on relatively small datasets such as ELSA and most longitudinal cohorts. Additionally, the interpretability provided by Integrated Gradients identifies the main features that affected each prediction. This can be used by clinicians to build trust in the model’s predictions and to focus their attention on the features that the model deemed most relevant. This is critical for clinical use of any machine learning model, as no decision-making ought to be based on a non-explainable prediction.

Some limitations remain in our study. Firstly, our model has a large number of parameters, unlike the RCoxPH and PyDTS models. While the non-optimized model already outperformed the other models, fine-tuning the network size and training hyperparameters could further improve performance. Secondly, our Transformer-based model was consistently better than the simpler architecture of the DeepHit model, but this gain in performance came with a higher computational cost. This was not limiting in our study, as training times did not exceed several minutes. Finally, to provide a fair comparison between models, only singleton-length input sequences were used in the data examples, as models other than the RCoxPH and Transformer were not designed to handle sequential input. This experiment demonstrated the ability of the Transformer model to generate meaningful sequences, but did not take advantage of its ability to capture complex dynamics of input sequences.

Conclusions

This study introduces a Transformer-based deep learning model with state-of-the-art performance for cause-specific hazard prediction in the context of discrete-time competing risks. Our model outperformed current models in cause-specific hazard prediction, especially for non-proportional hazard rates and events with few observed outcomes. Its benefit over current models was greatest for datasets of 2000–50,000 patients. The settings where our model shows the greatest benefits encompass those of most clinical survival analysis studies on longitudinal cohorts. Our Transformer-based model is ready to be used to improve current hazard predictions on longitudinal cohorts with complex covariate-to-outcome dynamics.