Abstract
Precision medicine is a framework that adapts treatment strategies to a patient’s individual characteristics and provides helpful clinical decision support. Existing research has extended the paradigm to various settings, but high-dimensional data have not yet been fully incorporated. We propose a new precision medicine approach, deep doubly robust outcome weighted learning (DDROWL), that can handle large and complex data. It is a machine learning tool that directly estimates the optimal decision rule and combines the strengths of three approaches: deep learning, double robustness, and residual weighted learning. Two architectures are implemented in the proposed method: a fully connected feedforward neural network and the Deep Kernel Learning model, a Gaussian process with deep learning-filtered inputs. We compare and discuss the performance and limitations of different methods through a range of simulations. Using longitudinal and brain imaging data from patients with Alzheimer’s disease, we demonstrate the application of the proposed method in real-world clinical practice. With the implementation of deep learning, the proposed method extends the reach of precision medicine to high-dimensional, abundant data with greater flexibility and computational power.
Data availability
The code used to generate simulated data can be shared upon request. The National Alzheimer’s Coordinating Center database needs to be requested at https://naccdata.org/.
Code availability
Code can be shared upon request and will be shared on GitHub.
Notes
The NACC database can be accessed at https://www.alz.washington.edu/WEB/researcher_home.html.
References
Angrist, J. D., Imbens, G. W., & Rubin, D. B. (1996). Identification of causal effects using instrumental variables. Journal of the American Statistical Association, 91(434), 444–455.
Athey, S., & Wager, S. (2021). Policy learning with observational data. Econometrica, 89(1), 133–161.
Barthold, D., Joyce, G., Diaz Brinton, R., Wharton, W., Kehoe, P. G., & Zissimopoulos, J. (2020). Association of combination statin and antihypertensive therapy with reduced Alzheimer’s disease and related dementia risk. PLoS ONE, 15(3), e0229541.
Beekly, D. L., Ramos, E. M., van Belle, G., Deitrich, W., Clark, A. D., Jacka, M. E., Kukull, W. A., et al. (2004). The national Alzheimer’s coordinating center (nacc) database: An Alzheimer disease database. Alzheimer Disease & Associated Disorders, 18(4), 270–277.
Bennett, A., & Kallus, N. (2020). Efficient policy learning from surrogate-loss classification reductions. In International conference on machine learning (pp. 788–798). PMLR.
Bergstra, J., Yamins, D., & Cox, D. D. (2013). Making a science of model search: Hyperparameter optimization in hundreds of dimensions for vision architectures. JMLR.
Bergstra, J. S., Bardenet, R., Bengio, Y., & Kégl, B. (2011). Algorithms for hyper-parameter optimization. In Advances in neural information processing systems (pp. 2546–2554).
Besser, L., Kukull, W., Knopman, D. S., Chui, H., Galasko, D., Weintraub, S., Jicha, G., Carlsson, C., Burns, J., Quinn, J., et al. (2018). Version 3 of the national Alzheimer’s coordinating center’s uniform data set. Alzheimer Disease and Associated Disorders, 32(4), 351.
Dudík, M., Langford, J., & Li, L. (2011). Doubly robust policy evaluation and learning. arXiv:1103.4601.
Duron, E., Rigaud, A. S., Dubail, D., Mehrabian, S., Latour, F., Seux, M. L., & Hanon, O. (2009). Effects of antihypertensive therapy on cognitive decline in Alzheimer’s disease. American Journal of Hypertension, 22(9), 1020–1024.
Friedman, J., Hastie, T., & Tibshirani, R. (2010). Regularization paths for generalized linear models via coordinate descent. Journal of Statistical Software, 33(1), 1.
Gardner, J., Pleiss, G., Weinberger, K. Q., Bindel, D., & Wilson, A. G. (2018). GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Advances in neural information processing systems (pp. 7587–7597).
Gorgolewski, K., Burns, C. D., Madison, C., Clark, D., Halchenko, Y. O., Waskom, M. L., & Ghosh, S. S. (2011). Nipype: A flexible, lightweight and extensible neuroimaging data processing framework in Python. Frontiers in Neuroinformatics. https://doi.org/10.3389/fninf.2011.00013
Greenland, S., Pearl, J., & Robins, J. M. (1999). Confounding and collapsibility in causal inference. Statistical Science, 14(1), 29–46.
Guo, W., Zhou, X. H., & Ma, S. (2021). Estimation of optimal individualized treatment rules using a covariate-specific treatment effect curve with high-dimensional covariates. Journal of the American Statistical Association, 116(533), 309–321.
Hajjar, I., Hart, M., Chen, Y. L., Mack, W., Milberg, W., Chui, H., & Lipsitz, L. (2012). Effect of antihypertensive therapy on cognitive function in early executive cognitive impairment: A double-blind randomized clinical trial. Archives of Internal Medicine, 172(5), 442–444.
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770–778).
Hernán, M. A., & Robins, J. M. (2020). Causal inference: What if. Chapman & Hall/CRC.
Holloway, S. T., Laber, E. B., Linn, K. A., Zhang, B., Davidian, M., & Tsiatis, A. A. (2018). DynTxRegime: Methods for estimating optimal dynamic treatment regimes. R package version 3.2.
Janocha, K., & Czarnecki, W. M. (2017). On loss functions for deep neural networks in classification. arXiv:1702.05659.
Jiang, X., Nelson, A. E., Cleveland, R. J., Beavers, D. P., Schwartz, T. A., Arbeeva, L., Alvarez, C., Callahan, L. F., Messier, S., Loeser, R., et al. (2021). Precision medicine approach to develop and internally validate optimal exercise and weight-loss treatments for overweight and obese adults with knee osteoarthritis: Data from a single-center randomized trial. Arthritis Care & Research, 73(5), 693–701.
Kallus, N. (2018). Balanced policy evaluation and learning. In Advances in neural information processing systems, pp. 8909–8920.
Kallus, N. (2020). Generalized optimal matching methods for causal inference. Journal of Machine Learning Research, 21(62), 1–54.
Kang, J. D., & Schafer, J. L. (2007). Demystifying double robustness: A comparison of alternative strategies for estimating a population mean from incomplete data. Statistical Science, 22(4), 523–539.
Kosorok, M. R., & Laber, E. B. (2019). Precision medicine. Annual Review of Statistics and Its Application, 6, 263–286.
Leete, O. E., Kallus, N., Hudgens, M. G., Napravnik, S., & Kosorok, M. R. (2019). Balanced policy evaluation and learning for right censored data. arXiv:1911.05728.
Liang, M., Ye, T., & Fu, H. (2018). Estimating individualized optimal combination therapies through outcome weighted deep learning algorithms. Statistics in Medicine, 37(27), 3869–3886.
Liaw, A., Wiener, M., et al. (2002). Classification and regression by randomForest. R News, 2(3), 18–22.
Liaw, R., Liang, E., Nishihara, R., Moritz, P., Gonzalez, J. E., & Stoica, I. (2018). Tune: A research platform for distributed model selection and training. arXiv:1807.05118.
Liu, Y., Wang, Y., Kosorok, M. R., Zhao, Y., & Zeng, D. (2016). Robust hybrid learning for estimating personalized dynamic treatment regimens. arXiv:1611.02314.
Michel, J. (2020, March). ENNUI: An elegant neural network user interface.
Nie, X., Brunskill, E., & Wager, S. (2021). Learning when-to-treat policies. Journal of the American Statistical Association, 116(533), 392–409.
Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., & Lerer, A. (2017). Automatic differentiation in PyTorch. NIPS.
Plis, S. M., Hjelm, D. R., Salakhutdinov, R., Allen, E. A., Bockholt, H. J., Long, J. D., Johnson, H. J., Paulsen, J. S., Turner, J. A., & Calhoun, V. D. (2014). Deep learning for neuroimaging: A validation study. Frontiers in Neuroscience, 8, 229.
Provost, F. (2000). Machine learning from imbalanced data sets 101. In Proceedings of the AAAI’2000 workshop on imbalanced data sets (Vol. 68, pp. 1–3). AAAI Press.
Qian, M., & Murphy, S. A. (2011). Performance guarantees for individualized treatment rules. The Annals of Statistics, 39(2), 1180.
Robins, J. M., Rotnitzky, A., & Zhao, L. P. (1994). Estimation of regression coefficients when some regressors are not always observed. Journal of the American Statistical Association, 89(427), 846–866.
Scharfstein, D. O., Rotnitzky, A., & Robins, J. M. (1999). Adjusting for nonignorable drop-out using semiparametric nonresponse models. Journal of the American Statistical Association, 94(448), 1096–1120.
Shah, K., Qureshi, S. U., Johnson, M., Parikh, N., Schulz, P. E., & Kunik, M. E. (2009). Does use of antihypertensive drugs affect the incidence or progression of dementia? a systematic review. The American Journal of Geriatric Pharmacotherapy, 7(5), 250–261.
Shi, C., Blei, D. M., & Veitch, V. (2019). Adapting neural networks for the estimation of treatment effects. arXiv:1906.02120.
Sverdrup, E., Kanodia, A., Zhou, Z., Athey, S., & Wager, S. (2020). policytree: Policy learning via doubly robust empirical welfare maximization over trees. Journal of Open Source Software, 5(50), 2232.
Talo, M., Baloglu, U. B., Yıldırım, Ö., & Acharya, U. R. (2019). Application of deep transfer learning for automated brain abnormality classification using MR images. Cognitive Systems Research, 54, 176–188.
Tibshirani, J., Athey, S., Sverdrup, E., & Wager, S. (2023). grf: Generalized random forests. R package version 2.3.0.
Wainberg, M., Merico, D., Delong, A., & Frey, B. J. (2018). Deep learning in biomedicine. Nature Biotechnology, 36(9), 829–838.
Wilson, A. G., Hu, Z., Salakhutdinov, R., & Xing, E. P. (2016). Deep kernel learning. In Artificial intelligence and statistics (pp. 370–378).
Zhang, B., Tsiatis, A. A., Laber, E. B., & Davidian, M. (2012). A robust method for estimating optimal treatment regimes. Biometrics, 68(4), 1010–1018.
Zhao, Y., Zeng, D., Rush, A. J., & Kosorok, M. R. (2012). Estimating individualized treatment rules using outcome weighted learning. Journal of the American Statistical Association, 107(499), 1106–1118.
Zhao, Y. Q., Laber, E. B., Ning, Y., Saha, S., & Sands, B. E. (2019). Efficient augmentation and relaxation learning for individualized treatment rules using observational data. Journal of Machine Learning Research, 20(48), 1–23.
Zhou, X. & Kosorok, M. R. (2017). Augmented outcome-weighted learning for optimal treatment regimes. arXiv:1711.10654.
Zhou, X., Mayer-Hamblett, N., Khan, U., & Kosorok, M. R. (2017). Residual weighted learning for estimating individualized treatment rules. Journal of the American Statistical Association, 112(517), 169–187.
Zhou, Z., Athey, S., & Wager, S. (2023). Offline multi-action policy learning: Generalization and optimization. Operations Research, 71(1), 148–183.
Funding
There is no funding source to disclose.
Author information
Authors and Affiliations
Contributions
Concept and design: XJ, MRK. Acquisition, analysis, or interpretation of data: XJ, XZ, MRK. Drafting of the manuscript: XJ. Critical revision of the manuscript: XJ, XZ, MRK.
Corresponding author
Ethics declarations
Conflict of interest
There are no conflicts of interest to disclose.
Additional information
Editor: María Óskarsdóttir.
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Supplementary Information
Below is the link to the electronic supplementary material.
Appendix
1.1 Appendix A: More on data extraction and preprocessing of the clinical data
Different centers collected data under different methods and policies during different time periods, which might not conform to one another. Because such conformity issues could introduce problematic heterogeneity, stringent inclusion and exclusion criteria were applied to keep the multi-center, multi-stage data relatively clean. We found more missing data in earlier form versions and later visits, with the most information collected at the initial visits. Since we were interested in baseline information, only each subject’s first visit under the latest form version was included. For categorical variables, categories such as unknown (e.g., values of 9, 99, 999, 8888, 9999), not applicable (the submitted form did not collect such data, or a skip pattern precluded a response), or left blank were considered missing. For continuous variables, indicators of unknown, not assessed, or not available (e.g., − 4, 888.8) and extreme values outside the normal range (e.g., height greater than 80 inches) were considered missing.
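The sentinel-to-missing recoding described above can be sketched in pandas. The column names, sentinel code lists, and valid range below are illustrative stand-ins, not the actual NACC schema:

```python
import pandas as pd

# Hypothetical sentinel codes taken from the examples in the text.
CATEGORICAL_MISSING = {9, 99, 999, 8888, 9999}
CONTINUOUS_MISSING = {-4, 888.8}

def recode_missing(df, categorical_cols, continuous_cols, valid_ranges=None):
    """Replace sentinel codes (and out-of-range values) with NaN."""
    df = df.copy()
    for col in categorical_cols:
        # Mask any value matching an "unknown"/"not applicable" code.
        df[col] = df[col].where(~df[col].isin(CATEGORICAL_MISSING))
    for col in continuous_cols:
        df[col] = df[col].where(~df[col].isin(CONTINUOUS_MISSING))
        if valid_ranges and col in valid_ranges:
            lo, hi = valid_ranges[col]
            # Values outside the plausible range are also set to missing.
            df[col] = df[col].where(df[col].between(lo, hi))
    return df

# Toy data: one sentinel categorical value, one sentinel and one
# out-of-range continuous value (height > 80 inches).
demo = pd.DataFrame({"smoker": [1, 9, 0], "height": [65.0, 888.8, 90.0]})
clean = recode_missing(demo, ["smoker"], ["height"], {"height": (36, 80)})
```

The same pattern extends to any number of columns by passing longer column lists and a fuller `valid_ranges` mapping.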
Because the UDS dataset has many forms/variables that contribute to the assessment of the subject’s cognitive status, covariates with moderate to high estimated Pearson correlation (\(> 0.5\)) with the outcome variable were excluded to avoid multicollinearity. Covariates with high estimated Pearson correlation (\(> 0.8\)) with other non-outcome covariates were excluded as well (e.g., height/weight versus body mass index, and various CDR® scores). The ID data, which contain the information linking the UDS and MRI data, were processed similarly to the UDS data, including the removal of severe missingness and multicollinearity.
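The correlation screening can be sketched as follows. The thresholds match the text, but the function and toy data are illustrative, not the code used for the UDS data:

```python
import pandas as pd

def correlation_screen(df, outcome, y_thresh=0.5, x_thresh=0.8):
    """Drop covariates too correlated with the outcome or with each other."""
    corr = df.corr(numeric_only=True).abs()
    # Step 1: exclude covariates with |corr| > y_thresh with the outcome.
    keep = [c for c in df.columns
            if c != outcome and corr.loc[c, outcome] <= y_thresh]
    # Step 2: greedily drop one member of each highly correlated pair.
    selected = []
    for c in keep:
        if all(corr.loc[c, s] <= x_thresh for s in selected):
            selected.append(c)
    return selected

demo = pd.DataFrame({
    "y":  [1, 2, 3, 4],
    "x1": [1, 2, 3, 4],   # perfectly correlated with y -> excluded
    "x2": [4, 1, 3, 2],   # |corr| = 0.4 with y -> kept
    "x3": [8, 2, 6, 4],   # perfectly correlated with x2 -> excluded
})
kept = correlation_screen(demo, "y")
```

The greedy pass keeps whichever member of a correlated pair appears first; any deterministic tie-breaking rule would serve the same purpose.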
There were 48 clinical variables included in the analysis (Table 8). They are: demographics (age, gender, smoking status, research center), vitals (blood pressure, resting heart rate), previous and current medication/therapy (antiadrenergic, anxiolytic, anticoagulant, antidepressant, angiotensin converting enzyme inhibitor, antipsychotic, diuretic, estrogen hormone therapy, lipid lowering medication, nonsteroidal anti-inflammatory, vasodilator), previous medical history (anxiety, diabetes, angioplasty, cardiac bypass procedure, vascular brain injury, heart valve replacement, microhemorrhage, incontinence, insomnia, obsessive-compulsive disorder, Parkinson’s disease), number of visits, and time to complete the Trail Making Test.
Variables related to smoking (e.g., consumption of tobacco in the past 30 days, number of years smoking, whether or not the subject quit smoking, etc.) were combined into two variables: an indicator of current smoker (has not quit smoking, or has consumed tobacco in the last 30 days) and an indicator of ever smoked (smoked cigarettes in the last 30 days, more than 100 cigarettes in life, non-zero years of smoking, or at least 1 cigarette smoked per day on average). Diabetes was redefined as a binary variable, with 1 representing Type 1, Type 2, or another type of diabetes (such as diabetes insipidus, latent autoimmune diabetes/Type 1.5, or gestational diabetes) and 0 representing no diabetes reported.
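A sketch of how such derived indicators can be constructed; the raw field names (`quit_smoking`, `tobac30`, `cig100`, `smokyrs`, `packsper`) are hypothetical stand-ins for the actual UDS items:

```python
import pandas as pd

def derive_smoking(df):
    """Collapse raw smoking items into the two indicators described above.

    Field names here are illustrative, not the real UDS variable names.
    """
    out = pd.DataFrame(index=df.index)
    # Current smoker: has not quit, or used tobacco in the last 30 days.
    out["current_smoker"] = (
        (df["quit_smoking"] == 0) | (df["tobac30"] == 1)
    ).astype(int)
    # Ever smoked: any of the lifetime-use criteria from the text.
    out["ever_smoked"] = (
        (df["tobac30"] == 1) | (df["cig100"] == 1)
        | (df["smokyrs"] > 0) | (df["packsper"] >= 1)
    ).astype(int)
    return out

demo = pd.DataFrame({
    "quit_smoking": [1, 0],
    "tobac30":      [0, 1],
    "cig100":       [0, 1],
    "smokyrs":      [0, 20],
    "packsper":     [0.0, 1.0],
})
flags = derive_smoking(demo)
```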
There were a total of 5,616 MRI sessions available. Each session contains multiple DICOM files, with one file representing one MRI slice. The DICOM format was preferred because its headers contain image information such as slice position and sequence type. We extracted the MRI slices from each subject’s MRI session in a compressed folder and removed slices without a series description or image position, because the series description indicates the sequence type of the MRI scan and the image position allows the slices to be sorted in the correct order. Since consecutive slices were acquired only milliseconds apart, we selected every 5th slice from the middle 150 slices to save space and maintain the same reasonable image dimension for every subject. End slices were discarded because they contain less useful information about the brain. We restricted to one MRI per subject because we were only interested in baseline covariates and wanted to keep the input dimension consistent.
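The slice-selection rule above can be sketched in plain Python, with simple dictionaries standing in for parsed DICOM headers (a real pipeline would read these fields from the DICOM files themselves):

```python
def select_slices(slices, middle=150, step=5):
    """Keep every `step`th slice from the `middle` slices of a session."""
    # Discard slices missing the metadata needed for ordering/typing,
    # mirroring the exclusion of slices without series description
    # or image position.
    usable = [s for s in slices
              if s["position"] is not None and s["description"]]
    # Sort by image position so the slices are in anatomical order.
    usable.sort(key=lambda s: s["position"])
    n = len(usable)
    # Take the middle `middle` slices, discarding the end slices.
    start = max((n - middle) // 2, 0)
    mid = usable[start:start + middle]
    # Subsample every `step`th slice to save space.
    return mid[::step]

# Toy session with 200 ordered slices.
demo = [{"position": i, "description": "T1"} for i in range(200)]
kept = select_slices(demo)
```

With 200 input slices, the middle 150 span positions 25–174, and subsampling every 5th slice yields 30 slices per subject.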
The preprocessed ID and UDS data were merged by subject ID, center, and visit year. Only complete cases were used because addressing data incompleteness was not our main research focus here, and imputation for such multi-center observations is often unreliable or requires extremely careful handling. In the preprocessed data (after merging UDS with ID but before merging with the MRI scans), there were 424 observations collected from 12 centers between 2015 and 2019. Among the 424 subjects, 48% had improved or maintained normal cognitive status and 48% were using antihypertensive or blood pressure medication at the initial visit. The preprocessed MRI data were merged with the medical data by file locator information, so only subjects with qualified UDS, ID, and MRI data were included, yielding a sample size of 186. We further excluded 24 subjects whose MRI date was more than 200 days from the initial visit date, giving a final sample size of 162. As mentioned above, we applied relatively strict inclusion criteria to ensure that the input data for the DL models were relatively clean and similar. The MRI data were only lightly processed to preserve the original values but could be piped through more systematic image processing tools.
1.2 Appendix B: More on transfer learning and ResNet34
Transfer learning can be regarded as a feature selection tool because images often contain a large number of nuisance pixels. The outputs of ResNet34 prior to the dense layers have a reduced dimension of 1000, much smaller and more abstract than the original input dimension. If the raw images were fed together with the medical data directly into a DL architecture, the MRI data would dominate the input dimension. An alternative to transfer learning is unsupervised learning, such as an autoencoder (AE), applied to the MRI data. An AE is a good dimension reduction method, but its encoded outputs are not always good representations of the original input.
In addition to the reasons mentioned in the main text, ResNet34 was chosen as the pretrained model because it has lower model complexity and relatively low top-1 and top-5 errors on the ImageNet data compared with other well-known deep learning architectures such as AlexNet or VGG. The top-1 error is the proportion of test images whose true label does not match the prediction class with the highest estimated probability. The top-5 error is the proportion of test images whose true label is not among the 5 prediction classes with the highest estimated probabilities. We applied a pretrained model instead of training our own structure because the lower-level representations extracted from the earlier layers of existing models are generally transferable across images.
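The top-1/top-5 error definitions can be illustrated with a small framework-free function:

```python
def top_k_error(prob_rows, true_labels, k):
    """Fraction of examples whose true label is not in the top-k classes."""
    errors = 0
    for probs, label in zip(prob_rows, true_labels):
        # Indices of the k classes with the highest estimated probability.
        top_k = sorted(range(len(probs)),
                       key=lambda c: probs[c], reverse=True)[:k]
        if label not in top_k:
            errors += 1
    return errors / len(true_labels)

# Toy predictions over 3 classes for 2 images.
probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
labels = [0, 2]
top1 = top_k_error(probs, labels, 1)  # second image's top class is 1, not 2
top2 = top_k_error(probs, labels, 2)  # label 2 is within the top 2 of row 2
```

By construction the top-5 error is never larger than the top-1 error, since widening the candidate set can only turn errors into hits.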
1.3 Appendix C: Definition of Value Functions
We used the same CV definitions of the value function estimator and its variance estimator as in Jiang et al. (2021). Let \(j=1,\ldots , MK\) index the MK tuning folds across the M repetitions and K CV folds, and let \(i = 1, \ldots , n_j\) index the ith observation in the jth fold. Cross validation was applied to a dataset of size \(n_{tr}\), which was split into training and validation sets. The CV estimated value function was used to compare tuning performance:

\[\widehat{V}_{CV} = \frac{1}{MK}\sum_{j=1}^{MK} \frac{\sum_{i=1}^{n_j} U_{ji}}{\sum_{i=1}^{n_j} W_{ji}},\]
where \(W_{ji} = \frac{1\{A_{ji} = \hat{d}_{n_{tr}}^{(-j)} ({\varvec{X}}_{ji}) \}}{\hat{P}^{(-j)}(A_{ji} \vert {\varvec{X}}_{ji})}\), \(U_{ji} = Y_{ji}W_{ji}\), \(\hat{d}_{n_{tr}}^{(-j)}\) is the decision rule estimated from the dataset of size \(n_{tr}\) with the jth fold left out, and \(\hat{P}^{(-j)}(A_{ji} \vert {\varvec{X}}_{ji})\) is the estimated propensity score. Its variance is used to measure the estimation uncertainty:

\[\widehat{\mathrm{Var}}\big(\widehat{V}_{CV}\big) = \frac{1}{(MK)^2}\sum_{j=1}^{MK}\sum_{i=1}^{n_j} R_{ji}^2,\]
where \(R_{ji} = \frac{1}{\bar{W}_j} U_{ji} - \frac{\bar{U}_j}{\bar{W}_j^2} W_{ji}\) is an influence function-inspired form of the value function with \(\bar{U}_j = \sum _{i=1}^{n_j}U_{ji}\) and \(\bar{W}_j = \sum _{i=1}^{n_j} W_{ji}\). By definition, \(\sum _{j=1}^{MK}\sum _{i=1}^{n_j} R_{ji} = 0\).
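As a numerical check of these definitions, the fold-level quantities can be computed directly. The toy inputs below are illustrative; the sketch computes W from the indicator/propensity ratio, U = YW, the fold value \(\bar{U}_j/\bar{W}_j\), and the influence-function terms R, which sum to zero within each fold:

```python
def fold_quantities(y, a, d_hat, p_hat):
    """Compute W, U, the fold value U_bar/W_bar, and the R terms."""
    # W: indicator that the observed action matches the estimated rule,
    # divided by the estimated propensity score.
    w = [float(ai == di) / pi for ai, di, pi in zip(a, d_hat, p_hat)]
    u = [yi * wi for yi, wi in zip(y, w)]      # U = Y * W
    u_bar, w_bar = sum(u), sum(w)              # fold-level sums
    value = u_bar / w_bar                      # fold value estimate
    # Influence-function-inspired terms; they sum to zero by construction.
    r = [ui / w_bar - u_bar * wi / w_bar ** 2 for ui, wi in zip(u, w)]
    return value, r

value, r = fold_quantities(
    y=[2.0, 1.0, 3.0], a=[1, 0, 1], d_hat=[1, 1, 1], p_hat=[0.5, 0.5, 0.5]
)
```

Here only the first and third observations follow the estimated rule, so W = [2, 0, 2], U = [4, 0, 6], and the fold value is 10/4 = 2.5; summing the R terms recovers zero, as stated above.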
For testing results, there is no CV, so \(j = 1, \ldots , M\) indexes the repetitions. The estimated value function becomes

\[\widehat{V}(\hat{d}_{n_{te}}) = \frac{\sum_{i=1}^{n_{te}} U_i}{\sum_{i=1}^{n_{te}} W_i},\]
where \(n_{te}\) is the sample size of the testing set, and \(U_i, W_i\) are defined analogously to \(U_{ji}\) and \(W_{ji}\) but with \(i=1,\ldots , n_{te}\) and decision rule \(\hat{d}_{n_{te}}\). Its variance is given as

\[\widehat{\mathrm{Var}}\big(\widehat{V}(\hat{d}_{n_{te}})\big) = \frac{1}{M-1}\sum_{j=1}^{M}\Big(\widehat{V}(\hat{d}_{n_{te}, j}) - \bar{V}(\hat{d}_{n_{te}, M})\Big)^2,\]
where \(\hat{d}_{n_{te}, j}\) is the single estimated decision rule from the jth repetition and \(\bar{V}(\hat{d}_{n_{te}, M}) = \frac{1}{M}\sum _{j=1}^{M} \widehat{V}(\hat{d}_{n_{te}, j})\) is the average of the estimated value functions over the M single estimated decision rules. The SD is the square root of the variance.
The performance of the model during tuning and testing was determined by higher estimated value functions and lower SDs.
Rights and permissions
Springer Nature or its licensor (e.g. a society or other partner) holds exclusive rights to this article under a publishing agreement with the author(s) or other rightsholder(s); author self-archiving of the accepted manuscript version of this article is solely governed by the terms of such publishing agreement and applicable law.
About this article
Cite this article
Jiang, X., Zhou, X. & Kosorok, M.R. Deep doubly robust outcome weighted learning. Mach Learn 113, 815–842 (2024). https://doi.org/10.1007/s10994-023-06484-w