1 Introduction

Machine learning has been one of the driving forces behind the enormous progress in medical image analysis in recent years. Of key importance for learning-based techniques is the training dataset used to estimate the model parameters. Including all available data in a training set is becoming increasingly impractical, since processing the data to build training models can be very time-consuming on huge datasets. In addition, much of this processing may be unnecessary because it does not help the model estimation for a given task. In this work, we propose a method to select the subset of the data that is most relevant for training on a specific task. Foreshadowing some of our results, such a guided selection of a training subset can lead to higher performance than using all the available data, while requiring only a fraction of the processing time.

Selecting a subset of the data for training is challenging because, at the time of making the decision, we have not yet processed the data and therefore do not know how the inclusion of a sample would affect the prediction. On the other hand, in many scenarios each image is assigned metadata about the subject (sex, diagnosis, age, etc.) or the image acquisition (dataset of origin, location, imaging device, etc.). We hypothesize that some of this information can be useful to guide the selection of samples, but it is a priori not clear which information is most relevant and how it should be used to guide the selection process. To address this, we formulate the selection of the samples to be included in a training set as a reinforcement learning problem, where a trade-off must be reached between the exploration of new sources of data and the exploitation of sources that have been shown to lead to informative data points in the past. More specifically, we model this as a multi-armed bandit problem solved with Thompson sampling, where each arm of the bandit corresponds to a cluster of samples generated using meta information.

In this paper, we apply our sample selection method to brain age estimation [7] from T1-weighted MRI. The estimated age serves as a proxy for biological age, whose difference from the chronological age can be used as an indicator of disease [6]. Age estimation is a well-suited application for testing our algorithm, as it allows us to work with a large number of datasets, since the subject's age is one of the few variables that is included in every neuroimaging dataset.

1.1 Related Work

Our work is most closely related to active learning, whose aim is to select samples to be labeled out of a pool of unlabeled data. Examples of active learning applied to medical imaging tasks include the work by Hoi et al. [9], where a batch-mode active learning approach was presented for selecting medical images for manual labeling of the image category. Another active learning approach was proposed for the selection of histopathological slices for manual annotation in [21]. The problem was formulated as a constrained submodular optimization problem and solved with a greedy algorithm. To select a diverse set of slices, the patient identity was used as meta information. From a methodological point of view, our work relates to that of Bouneffouf et al. [1], where an active learning strategy based on contextual multi-armed bandits is proposed. The main difference between all these active learning approaches and our method is that image features are not available a priori in our application and therefore cannot be used in the sample selection process. Our work also relates to domain adaptation [15, 20]. In instance weighting, training samples are assigned weights according to the distribution of the labels (class imbalance) [10] and the distribution of the observations (covariate shift) [16]. Again, these methods are not directly applicable in our scenario because the distribution of the metadata is not always defined on the target dataset.

2 Method

2.1 Incremental Sample Selection

In supervised learning, we model a predictive function \({f}:(\mathbf{x}, {\mathbf {p}})\mapsto y\) depending on a parameter vector \({\mathbf {p}}\), relating an observation \(\mathbf {x}\) to its label y. In our application, \(\mathbf{x}\in \mathbb {R}^{m}\) is a vector of m quantitative brain measurements from the image and \(y \in \mathbb {R}\) is the age of the subject. The parameters \({\mathbf {p}}\) are estimated using a training set \({S^T}=\{{s}_1,{s}_2, \ldots , {s}_{N_{train}}\}\), where each sample \({s}= ({\mathbf {x}}, {y})\) is a pair of a feature vector and its associated true label. Once the parameters are estimated, we can predict the label \(\tilde{y}\) for a new observation \(\tilde{\mathbf{x}}\) with \(\tilde{y} = f(\tilde{\mathbf{x}},{\mathbf {p}}^*)\), where the prediction depends on the estimated parameters and therefore on the training dataset. In our scenario, the samples to be included in the training set \({S^T}\) are selected from a large source set \({S}= \{h_1,h_2,\ldots ,h_{N_{total}}\}\) containing hidden samples of the form \(h= (\hat{{\mathbf {x}}}, \hat{{y}}, \mathbf {m})\). Each h contains hidden features \(\hat{{\mathbf {x}}}\) and a hidden label \(\hat{{y}}\) that can only be revealed by processing the sample. In addition, each hidden sample possesses a d-dimensional vector of metadata \(\mathbf {m}\in \mathbb {Z}^d\) that encodes characteristics of the patient or the image, such as sex, diagnosis, and dataset of origin. In contrast to \(\hat{{\mathbf {x}}}\) and \(\hat{{y}}\), \(\mathbf {m}\) is known a priori and can be observed at no cost. To include a sample h from the set \({S}\) in \({S^T}\), its features and label first have to be revealed, which comes at a high cost. Consequently, we would like to find a sampling strategy that minimizes this cost by selecting only the most relevant samples according to the metadata \(\mathbf {m}\).
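The distinction between costly hidden information and freely observable metadata can be sketched as a small data structure (a toy illustration with invented names; the paper does not prescribe an implementation):

```python
from dataclasses import dataclass

@dataclass
class HiddenSample:
    """A source-set sample h: the metadata m can be read at no cost,
    while the features x and the label y only become available after a
    costly processing step (e.g. image reconstruction)."""
    meta: dict            # m, e.g. {"sex": "F", "dataset": "IXI"}
    _features: tuple      # hidden feature vector x
    _label: float         # hidden label y (the subject's age)
    revealed: bool = False

    def reveal(self):
        # Stands in for the expensive processing that uncovers (x, y).
        self.revealed = True
        return self._features, self._label
```

A sampling strategy may inspect `meta` freely, but each call to `reveal` represents the processing cost the method seeks to minimize.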

2.2 Multiple Partitions of the Source Data

In order to guide our sample selection algorithm, we create multiple partitions of the source dataset, where each one considers different information from the metadata \(\mathbf {m}\). Considering the j-th meta information (\(1 \le j \le d\)), we create the j-th partition \({S}= \cup _{i=1}^{\eta _j} C_i^j\), with \(\eta _j\) a predefined number of bins for \(\mathbf {m}[j]\). As a concrete example, sex could be used for partitioning the data, so that \({S}= C_{\text {female}}^\text {sex} \cup C_{\text {male}}^\text {sex}\) and \(\eta _\text {sex} = 2\). In the case of continuous variables such as age, partitions can be created by quantizing the variable into bins. All the clusters generated using the different meta information are merged into a set of clusters \(\mathcal {C} = \{ C_i^j \}\). Since the partitions are created using different elements of \(\mathbf {m}\), a sample can be assigned to more than one cluster.
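As a sketch, the overlapping partitions could be built as follows (pure Python with illustrative key names; continuous variables are quantized into equally spaced bins):

```python
from collections import defaultdict

def build_clusters(samples, meta_keys, n_bins=None):
    """Partition the source set once per metadata key, so a sample can
    belong to several clusters.  `samples` holds only the (free)
    metadata dicts; `n_bins` maps a continuous key (e.g. "age") to its
    number of quantization bins eta_j."""
    n_bins = n_bins or {}
    # Value ranges for the continuous keys, used for quantization.
    ranges = {k: (min(s[k] for s in samples), max(s[k] for s in samples))
              for k in n_bins}
    clusters = defaultdict(list)          # (key, bin) -> sample indices
    for idx, m in enumerate(samples):
        for key in meta_keys:
            v = m[key]
            if key in n_bins:             # quantize into equal-width bins
                lo, hi = ranges[key]
                b = min(int((v - lo) / (hi - lo + 1e-9) * n_bins[key]),
                        n_bins[key] - 1)
                clusters[(key, b)].append(idx)
            else:                         # categorical: one bin per value
                clusters[(key, v)].append(idx)
    return dict(clusters)
```

Each resulting cluster later becomes one arm of the bandit; note that every sample index appears once per metadata key, reflecting the overlapping partitions.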

We hypothesize that given this partitioning, there exist clusters \(C_i \in \mathcal {C}\) that contain more relevant samples than others for a specific task. Intuitively, we would like to draw samples h from clusters with a higher probability of returning a relevant sample. However, since the relationship between the metadata and the task is uncertain, the utility of each cluster for a specific task is unknown beforehand. We will now describe a strategy that simultaneously explores the clusters to find out which ones contain more relevant information and exploits them by extracting as many samples from relevant clusters as possible.

2.3 Sample Selection as a Multi-armed Bandit Problem

We model the task of sequential sample selection as a multi-armed bandit problem. At each iteration \({t}\), a new sample is added to the training dataset \(S^T\). For adding a sample, the algorithm decides which cluster \(C_i \in \mathcal {C}\) to exploit and randomly draws a training sample \({s}_{t}\) from cluster \( C_i\). The corresponding feature vector \({\mathbf {x}}_{t}\) and label \({y}_{t}\) are revealed and the usefulness of the sample \({s}_t\) for the given task is evaluated, yielding a reward \({r}_{t}\in \{-1,1\}\). A reward \(r_t=1\) is given if adding the sample improves the prediction accuracy of the model and \(r_t =-1\) otherwise.

At \(t=0\), we do not possess knowledge about the utility of any cluster. This knowledge is incrementally built as more and more samples are drawn and their rewards are revealed. To this end, each cluster is assigned a distribution of rewards \({\varPi }_i\). With every sample the distribution better approximates the true expected reward of the cluster, but every new sample also incurs a cost. Therefore, a strategy needs to be designed that explores the distribution for each of the clusters, while at the same time exploiting as often as possible the most rewarding sources.

To solve the problem of selecting which \(C_i\) to sample from at every iteration t, we follow a strategy based on Thompson sampling [17] with binary rewards. In this setting, the expected reward of each cluster is modeled by a probability \(P_i\) following a Bernoulli distribution with parameter \(\pi _i \in [0,1]\). We maintain an estimate of the likelihood of each \(\pi _i\) given the number of successes \(\alpha _i\) and failures \(\beta _i\) observed for cluster \({C}_i\) so far, where successes (\(r=1\)) and failures (\(r=-1\)) are counted from the rewards of past iterations. Since the Beta distribution is the conjugate prior of the Bernoulli likelihood, this estimate follows a Beta distribution \(Beta(\alpha _i ,\beta _i)\), so that

$$\begin{aligned} P(\pi _i | \alpha _i, \beta _i) = \frac{\varGamma (\alpha _i + \beta _i)}{\varGamma (\alpha _i) \varGamma (\beta _i)}(1-\pi _i)^{\beta _i - 1} \pi _i^{\alpha _i - 1}, \end{aligned}$$
(1)

with the gamma function \(\varGamma \). At each iteration, \(\hat{\pi }_i\) is drawn from each cluster distribution \(P_i\) and the cluster with the maximum \(\hat{\pi }_i\) is chosen. The procedure is summarized in Algorithm 1.
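Equation (1) can be checked numerically with the standard library's gamma function (a quick sanity check, not part of the method):

```python
import math

def beta_pdf(p, a, b):
    """Eq. (1): the Beta(a, b) density evaluated at p, with the
    normalization constant Gamma(a+b) / (Gamma(a) * Gamma(b))."""
    norm = math.gamma(a + b) / (math.gamma(a) * math.gamma(b))
    return norm * (1 - p) ** (b - 1) * p ** (a - 1)

# Midpoint-rule check that the density integrates to 1 over [0, 1].
n = 10000
integral = sum(beta_pdf((k + 0.5) / n, 3, 5) for k in range(n)) / n
```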

[Algorithm 1: sequential sample selection with MABS (pseudocode figure)]
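The selection and update steps of Algorithm 1 can be sketched in a few lines (function names are ours; Python's standard library provides Beta sampling via `random.betavariate`):

```python
import random

def thompson_select(alpha, beta):
    """Draw pi_hat_i ~ Beta(alpha_i, beta_i) for every cluster and
    return the index of the arm with the largest draw."""
    draws = [random.betavariate(a, b) for a, b in zip(alpha, beta)]
    return max(range(len(draws)), key=draws.__getitem__)

def update_posterior(alpha, beta, arm, reward):
    """Binary reward r in {-1, +1}: a success (r = 1) increments
    alpha_arm, a failure (r = -1) increments beta_arm."""
    if reward == 1:
        alpha[arm] += 1
    else:
        beta[arm] += 1
```

Starting from the uninformative prior \(Beta(1, 1)\) for every cluster, arms that repeatedly yield rewarding samples accumulate a large \(\alpha _i\) and are selected more often, while unrewarding arms are still drawn occasionally, which realizes the exploration-exploitation trade-off.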

3 Results

In order to showcase the advantages of the multi-armed bandit sampling algorithm (MABS), we evaluate our method on estimating the biological age of a subject given a set of volume and thickness features of the brain. We choose this task because of the large number of brain scans available in public databases and the relevance of age estimation as a diagnostic tool for neurodegenerative diseases [18]. For predicting the age, we reconstruct brain scans with FreeSurfer [5] and extract volume and thickness measurements to create the feature vectors \({\mathbf {x}}\). Based on these features, we train a regression model for predicting the age of previously unseen subjects.

3.1 Data

We work on MRI T1 brain scans from 10 large-scale public datasets: ABIDE [3], ADHD200 [14], AIBL [4], COBRE [13], IXI, GSP [2], HCP [19], MCIC [8], PPMI [12], and OASIS [11]. From all of these datasets, we obtain a total of 7,250 images, which is, to the best of our knowledge, the largest dataset ever used for brain age prediction. Since each of these datasets is targeted towards different applications, the selected population is heterogeneous in terms of age, sex, and health status. For the extraction of thickness and volume measurements, we process the images with FreeSurfer. Even though this is a fully automatic tool, the feature extraction is computationally intensive and is by far the bottleneck of our age prediction pipeline.

3.2 Age Estimation

We perform age estimation on two different testing scenarios. In the first, we create a testing dataset by randomly selecting subsets from all the datasets. The aim of this experiment is to show that our method is capable of selecting samples that will create a model that can generalize well to a heterogeneous population. In the second scenario, the testing dataset corresponds to a single dataset. In this scenario, we show that the sample selection permits tailoring the training dataset to a specific target dataset.

Experiment 1. For the first experiment, we take all the images in the dataset and divide them randomly into three sets: (1) a small validation set of 2% of all samples to compute the rewards given to MABS, (2) a large testing set of 48% to measure the performance of our age regression task, and (3) a large hidden training set of 50%, from which samples are drawn sequentially using MABS. We perform the sequential sample selection described in Algorithm 1 using the following metadata to construct the clusters \(\mathcal {C}\): age, dataset, diagnosis, and sex. We experiment both with each type of metadata separately, to investigate its individual importance, and with the joint model considering all partitions at once. We opted for ridge regression as our learning algorithm because of its fast training and good performance on our task, but other regression models can easily be plugged into our method. Rewards \({r}\) are given to each bandit by observing whether the \(r^2\) score of the prediction on the validation set increases. It is important to emphasize that the testing set is not observed by the bandits in the process of giving rewards. Every experiment is repeated 20 times using different random splits and the mean results are reported. We compare against two baselines: the first (RANDOM) obtains samples at random from the hidden set and adds them sequentially to the training set; the second (AGE PRIOR) adds samples sequentially by following the age distribution of the testing set. The results of this first experiment are shown in Fig. 1 (top left). In almost all cases, using MABS as a selection strategy performed better than the baselines.
Notably, an increase in performance is obtained not only when the relationship between the metadata and the task is direct, as in the case of the clusters constructed by age, but also when this relationship is not clear, as when clustering the images using only dataset or diagnosis information. Another important aspect is that even when the meta information is not informative, as in the case of the clusters generated by sex, the prediction using MABS is not adversely affected.
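The reward computation described above can be sketched as follows (a minimal stand-in using closed-form ridge regression with NumPy; the paper does not specify the exact learner settings, and for brevity this sketch regularizes the intercept as well):

```python
import numpy as np

def ridge_fit(X, y, lam=1.0):
    """Closed-form ridge regression: w = (X'X + lam*I)^-1 X'y,
    with an appended intercept column."""
    X1 = np.hstack([X, np.ones((len(X), 1))])
    d = X1.shape[1]
    return np.linalg.solve(X1.T @ X1 + lam * np.eye(d), X1.T @ y)

def r2_score(w, X, y):
    """Coefficient of determination of the predictions X1 @ w."""
    X1 = np.hstack([X, np.ones((len(X), 1))])
    resid = y - X1 @ w
    return 1.0 - (resid @ resid) / ((y - y.mean()) @ (y - y.mean()))

def reward(train_X, train_y, val_X, val_y, prev_r2):
    """+1 if adding the latest sample improved the validation r^2,
    -1 otherwise; also returns the new score for the next iteration."""
    w = ridge_fit(np.asarray(train_X), np.asarray(train_y))
    new_r2 = r2_score(w, np.asarray(val_X), np.asarray(val_y))
    return (1 if new_r2 > prev_r2 else -1), new_r2
```

After each revealed sample, the returned binary reward updates the chosen arm's Beta posterior, and `new_r2` becomes `prev_r2` for the next iteration.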

Fig. 1. Results of our age prediction experiments in terms of \(r^2\) score. A comparison is made between MABS using different strategies to build the clusters \(\mathcal {C}\), a random selection of samples, and a random selection based on the age distribution of the test data. To improve the presentation of the results, we limit the plot to 4,000 samples.

Experiment 2. For our second experiment, we perform age estimation with the test data being a specific dataset. This experiment follows the same methodology as the previous one, with the important difference of how the datasets are split. This time, the split is done by choosing: (1) a small validation set, taken only from the target dataset, (2) a testing set, which corresponds to the remaining samples of the target dataset not included in the validation set, and (3) a hidden set containing all the samples from the remaining datasets. The goal of this experiment is to show that our approach can select samples tailored to a specific population and prediction task. Figure 1 shows the results for three different target datasets. We observe that bandits operating on a single type of metadata, like diagnosis or dataset, can perform very well for the sample selection; however, the best metadata differs for each of the presented datasets. We also observe that MABS using all available metadata extracts informative samples more efficiently than the baselines and always performs close to the best single-metadata MABS. This strengthens our hypothesis that it is difficult to define an a priori relationship between the metadata and the task. Consequently, it is a better strategy to pass all the metadata from multiple sources to MABS and let it select the most relevant information.

4 Conclusion

We have proposed a method for efficiently and intelligently sampling a training dataset from a large pool of data. The problem was formulated as a reinforcement learning problem, in which the training dataset is built sequentially by evaluating a reward function at every step. Concretely, we used a multi-armed bandit model solved with Thompson sampling. The intelligent selection uses metadata of the scan to construct a distribution over the expected reward of a training sample. Our results showed that the selective sampling approach leads to higher accuracy than using all the data, while requiring less processing time. We demonstrated that our technique can be used either to build a general model or to adapt to a specific target dataset, depending on the composition of the test dataset. Since our method does not require observing the information contained in the images, it could also be applied to predict useful samples even before the images are acquired, guiding the recruitment of subjects.