1 Introduction

Segmentation of brain MR images is a critical step in many diagnostic and surgical applications. Accordingly, several approaches have been proposed for this problem, such as atlas-based segmentation [1] and machine learning methods such as CNNs [2], among many others detailed in a recent survey [3]. An important challenge in many MRI analysis tasks, including segmentation, is robustness to differences in the statistical characteristics of image intensities. Such differences can arise when different scanners are used, where factors such as drift in scanner SNR over time [4] and gradient non-linearities [5] play an important role. Intensity variations may arise even when scanning protocol parameters (flip angle, echo or repetition time, etc.) are changed slightly on the same scanner. Figure 1(a, b) shows 2D slices from two T1-weighted MRI datasets acquired on different scanners, along with their intensity histograms, which exhibit these variations. Segmentation algorithms are often very sensitive to such changes. Furthermore, images acquired with different MR modalities, such as T1- and T2-weighted images, may share considerable high-level similarities in image content (see Fig. 1). Humans analyzing these images can easily leverage such commonalities, and it would be highly desirable for learning-based algorithms to mimic this trait.

Fig. 1. Image slices (top) and corresponding histograms (bottom) of normalized T1w (a, b) and T2w (c, d) MRIs from different scanners. Despite similar high-level information, there are considerable differences in intensity and contrast, to which segmentation algorithms are often sensitive.

In the parlance of transfer learning, images acquired with different scanners, protocols or similar MR modalities may be viewed as data points sampled from different domains, with the degree of domain shift potentially reflected in the differences between their intensity statistics. This perspective motivates us to apply ideas from the literature on domain adaptation [6], multi-domain learning [7] and lifelong learning [8] to brain segmentation across scanners and protocols. Domain adaptation (or transfer learning) refers to the setting where a learner trained on a source domain is able to perform well on a target domain for which only a few labelled examples are available. In this case, however, performance on the source domain is not necessarily maintained after adaptation. Multi-domain learning aims to train a learner that performs well on multiple domains simultaneously. Finally, in lifelong learning, a multi-domain learner can incorporate new domains with only a few labelled examples while preserving its performance on previous domains.

Variants of image intensity standardization [9, 10] and atlas intensity renormalization [11] have been proposed as pre-processing steps to make conventional segmentation methods robust to inter-scanner differences. Among learning methods based on hand-crafted features, transfer learning approaches have been employed for multi-site segmentation [12] and classification [13]. While the adaptive support vector machines used in [12] could be adapted to new scanners in a lifelong learning sense, they are likely to be limited by the quality of the hand-crafted features. Using CNNs, [14] propose to handle inter-protocol differences by learning domain-invariant representations. This approach may be restricted to the least common denominator of the domains, whereas, as shown in [15], providing a few separate parameters for each domain allows the network to learn domain-specific nuances. Further, it is unclear how [14] could be extended to new domains encountered after the initial training. In the computer vision literature, several adaptations of batch normalization (BN) [16] have been proposed for domain adaptation [17, 18] and multi-domain learning [15, 19] in CNN-based object recognition. Broadly, these works use BN for domain-specific scaling to account for domain shifts, while sharing the bulk of the CNN parameters to exploit the similarity between the domains.

In this work, we extend approaches based on adaptive BN layers to segmentation across scanning protocols in a lifelong learning setting. In particular, we train a CNN with convolutional filters shared by all protocols/scanners and separate BN parameters for each protocol/scanner. The network is initially trained with images from a few scanners to learn suitable convolutional filters. It can then be adapted to a new protocol or scanner by fine-tuning only the BN parameters with a few labelled images. Crucially, this is achieved without performance degradation on the older scanners, whose training data are no longer available after the initial training.

2 Method

Batch normalization (BN) was introduced in [16] to enable faster training of deep neural networks by preventing saturated gradients through normalization of the inputs to each non-linear activation layer. In a BN layer, each batch \(x_B\) is normalized as shown in Eq. 1. During training, \(\mu _B\) and \(\sigma ^2_B\) are the mean and variance of \(x_B\), whereas at test time they are the population mean and variance, estimated by a moving average over the training batches. \(\gamma \) and \(\beta \) are learnable parameters that allow the network to undo the normalization if required, and \(\epsilon \) is a small constant for numerical stability. Inspired by [15], we propose to use separate batch normalization for each protocol/scanner.

$$BN(x_B) = \gamma \times \dfrac{x_B - \mu _B}{\sqrt{\sigma ^2_B+\epsilon }} + \beta \tag{1}$$
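To make the two modes of Eq. 1 concrete, the following is a minimal pure-Python sketch of a single-channel BN layer acting on a flat list of activations, plus a hypothetical `DomainSpecificBN` wrapper that keeps one independent BN layer per domain. The exponential moving average with momentum 0.1 is an assumption (a common convention), not something stated in the text; a real implementation would operate on tensors with one \(\gamma \), \(\beta \) pair per feature channel.

```python
import math
from dataclasses import dataclass
from typing import Dict, Hashable, Iterable, List


@dataclass
class BatchNorm1Ch:
    """Single-channel batch normalization following Eq. 1 (illustrative sketch)."""
    gamma: float = 1.0          # learnable scale
    beta: float = 0.0           # learnable shift
    eps: float = 1e-5           # numerical-stability constant
    momentum: float = 0.1       # assumed moving-average rate
    running_mean: float = 0.0   # population-mean estimate used at test time
    running_var: float = 1.0    # population-variance estimate used at test time

    def __call__(self, x_b: List[float], training: bool) -> List[float]:
        if training:
            # Batch statistics mu_B and sigma^2_B of the current batch.
            mu = sum(x_b) / len(x_b)
            var = sum((x - mu) ** 2 for x in x_b) / len(x_b)
            # Moving averages approximate the population statistics for test time.
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mu, var = self.running_mean, self.running_var
        scale = self.gamma / math.sqrt(var + self.eps)
        return [scale * (x - mu) + self.beta for x in x_b]


class DomainSpecificBN:
    """Hypothetical wrapper: one BN layer (gamma, beta, running stats) per domain."""

    def __init__(self, domains: Iterable[Hashable]):
        self.bn: Dict[Hashable, BatchNorm1Ch] = {k: BatchNorm1Ch() for k in domains}

    def __call__(self, x_b: List[float], domain: Hashable, training: bool) -> List[float]:
        # Only the selected domain's parameters and statistics are used or updated.
        return self.bn[domain](x_b, training)
```

Because each domain owns its own parameters and running statistics, a training batch from one domain leaves the other domains' normalization untouched.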

Notwithstanding variations in image statistics due to inter-scanner differences, a segmentation network is confronted with images of the same organ, acquired with the same modality (MR). Hence, it is reasonable to postulate common characteristics between the domains and, thus, a shared support in an appropriate representation space. Following [15], we hypothesize that such a representation space can be found by using domain-agnostic convolutional filters, and that the inter-domain differences can be handled by appropriate normalization via domain-specific BN modules. This approach is not only in line with previous domain adaptation work [18], but also embodies the normalization idea of conventional proposals for dealing with inter-scanner variations [9,10,11]. Further, as in [19], once suitable shared convolutional filters have been learned, we adapt the domain-specific BN layers to new related domains.

The training procedure in our framework is as follows. We use the superscript \(^{bn}\) to denote a network with domain-specific BN layers. We initially train a network \(N_{12\cdots d}^{bn}\) on \(d\) domains, with shared convolutional filters and separate BN parameters \(bn_{k}\) for each domain \(D_k\). During training, each batch contains images from only one domain, and all domains are covered in turn. In a training iteration whose batch comes from domain \(D_k\), the parameters \(bn_{k'}\) for all \(k' \ne k\) are frozen. Now consider a new domain \(D_{d+1}\) with a few labelled images \(I_{D_{d+1}}\). We split this small dataset into two halves, using one for training, \(I_{D_{d+1}}^{tr}\), and the other for validation, \(I_{D_{d+1}}^{vl}\). We evaluate the performance of \(N_{12\cdots d}^{bn}\) on \(I_{D_{d+1}}^{tr}\) using each learned \(bn_{k}\), \(k=1,2,\ldots ,d\). If \(bn_{k^*}\) yields the best accuracy, we infer that, among the already learned domains, \(D_{k^*}\) is the closest to \(D_{d+1}\). Then, keeping the convolutional filter weights fixed, an additional set of BN parameters \(bn_{d+1}\) is initialized with \(bn_{k^*}\) and fine-tuned on \(I_{D_{d+1}}^{tr}\) using standard stochastic gradient descent. The optimization is stopped when the performance on \(I_{D_{d+1}}^{vl}\) stops improving. The network can then segment all domains \(D_k\), \(k=1,2,\ldots ,d+1\), each with its respective \(bn_k\).
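The adaptation steps for a new domain \(D_{d+1}\) can be sketched framework-agnostically as below. The network itself is abstracted away: `score_train`, `score_val` and `train_step` are hypothetical callables supplied by the caller (evaluating the frozen-filter network with a given BN state on \(I_{D_{d+1}}^{tr}\) or \(I_{D_{d+1}}^{vl}\), and running one SGD pass that updates only the BN state). The patience-based early-stopping rule and the epoch limit are our assumptions for "stops improving".

```python
import copy
from typing import Callable, Dict, Tuple

BNState = dict  # hypothetical container for one domain's BN parameters and statistics


def add_new_domain(bn_sets: Dict[int, BNState], new_domain: int,
                   score_train: Callable[[BNState], float],
                   score_val: Callable[[BNState], float],
                   train_step: Callable[[BNState], BNState],
                   max_epochs: int = 100, patience: int = 5) -> Tuple[int, float]:
    """Extend a network with domain-specific BN to a new domain (conv filters frozen).

    score_train / score_val: accuracy (e.g. mean Dice) on I^tr / I^vl with a given BN state.
    train_step: one SGD pass on I^tr that updates only the BN state and returns it.
    """
    # 1. Closest known domain k*: the bn_k that scores best on the new training split.
    k_star = max(bn_sets, key=lambda k: score_train(bn_sets[k]))
    # 2. Initialize bn_{d+1} as a copy of bn_{k*}; the stored bn_{k*} stays untouched.
    bn_new = copy.deepcopy(bn_sets[k_star])
    best_score, best_state, stale = score_val(bn_new), copy.deepcopy(bn_new), 0
    # 3. Fine-tune only the BN parameters; stop once validation stops improving.
    for _ in range(max_epochs):
        bn_new = train_step(bn_new)
        score = score_val(bn_new)
        if score > best_score:
            best_score, best_state, stale = score, copy.deepcopy(bn_new), 0
        else:
            stale += 1
            if stale >= patience:
                break
    # 4. Register bn_{d+1}; bn_1..bn_d are unchanged, so old domains are unaffected.
    bn_sets[new_domain] = best_state
    return k_star, best_score
```

Keeping the best validation state (rather than the last one) is a design choice of this sketch; if fine-tuning never helps, \(bn_{d+1}\) simply remains a copy of \(bn_{k^*}\).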

In the spirit of lifelong learning, this approach allows learning on new domains with only a few labelled examples. This is enabled by reusing the knowledge obtained from learning on the old domains, in the form of the trained domain-agnostic parameters. The small number of domain-specific parameters has two advantages. First, they can be tuned quickly for a new domain with a few labelled images and with minimal risk of overfitting. Second, they can be stored for each domain without a significant memory footprint. Moreover, catastrophic forgetting [20], i.e., performance degradation on previous domains, cannot occur in this approach by construction, because shared and domain-specific parameters are modeled explicitly and separately.
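To give a rough sense of how small the domain-specific part is, the following back-of-the-envelope count applies to the reduced U-Net described in the training details. Several details are assumptions not stated in the text: 2D 3×3 convolutions with bias, two conv+BN layers per block, a single-channel input, skip connections by concatenation, and the final 1×1 classifier (not followed by BN) excluded from the count.

```python
from typing import List, Tuple


def unet_conv_bn_layers() -> List[Tuple[int, int]]:
    """(in_channels, out_channels) of each 3x3 conv+BN layer in the assumed reduced U-Net."""
    down = [32, 64, 128, 256]      # contracting path (three max-pooling layers in between)
    up = [128, 64, 32]             # upscaling path
    layers, in_ch = [], 1          # assumed single-channel MR input
    for c in down:
        layers += [(in_ch, c), (c, c)]        # assumed two conv+BN layers per block
        in_ch = c
    for c, skip in zip(up, down[-2::-1]):     # skip channels: 128, 64, 32
        # Bilinear upsampling keeps in_ch channels; skip features are concatenated.
        layers += [(in_ch + skip, c), (c, c)]
        in_ch = c
    return layers


def conv_params(layers: List[Tuple[int, int]], k: int = 3) -> int:
    """Shared weights: k*k*in*out kernel weights plus one bias per output channel."""
    return sum(k * k * cin * cout + cout for cin, cout in layers)


def bn_learnable_params(layers: List[Tuple[int, int]]) -> int:
    """Domain-specific learnable BN parameters: one gamma and one beta per output channel."""
    return sum(2 * cout for _, cout in layers)


layers = unet_conv_bn_layers()
shared, per_domain = conv_params(layers), bn_learnable_params(layers)
print(f"shared conv: {shared}, BN per domain: {per_domain} "
      f"({100 * per_domain / shared:.2f}% of shared)")
```

Under these assumptions the script reports 1,946,272 shared convolutional parameters against 2,816 learnable BN parameters per domain (about 0.14%); storing the running mean and variance adds another 2,816 values per domain.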

3 Experiments and Results

Datasets: Brain MR datasets from several scanners, hospitals, or acquisition protocols are required to test the applicability of the proposed method for lifelong multi-domain learning. To the best of our knowledge, there are only a few publicly available brain MRI datasets with ground truth segmentation labels from human experts. Therefore, we use FreeSurfer [1] to generate pseudo ground truth annotations. While annotations from human experts would be ideal, we believe that FreeSurfer annotations can serve as a reasonable proxy to test our approach to lifelong multi-scanner learning.

We use images from four publicly available datasets: Human Connectome Project (HCP) [21], Alzheimer’s Disease Neuroimaging Initiative (ADNI), Autism Brain Imaging Data Exchange (ABIDE) [22] and Information eXtraction from Images (IXI). The datasets are split into different domains, as shown in Table 1. Domains \(D_1\), \(D_2\), \(D_3\) are treated as initially available, and \(D_4\), \(D_5\) as new. The numbers of training and test images per domain listed in the table are explained in the description of the experiments below.

Training Details: While the domain-specific BN layers can be incorporated into any standard CNN, we work with the widely used U-Net [2] architecture with minor alterations. Namely, our network has a reduced depth, with three max-pooling layers, and a reduced number of kernels: 32, 64, 128 and 256 in the convolutional blocks of the contracting path and 128, 64 and 32 on the upscaling path. Also, bilinear interpolation is used instead of deconvolutional layers for upscaling, to avoid potential checkerboard artifacts [23]. The network is trained to minimize the Dice loss introduced in [24], which reduces sensitivity to class imbalance. The intensities of each image volume are normalized by dividing by their 98th percentile. The initial network trains in about 6 h, while the domain-specific BN modules can be updated for a new domain in about 1 h, on an NVIDIA Titan Xp GPU.
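The two preprocessing and loss choices above can be sketched in pure Python on flattened voxel lists. The percentile uses linear interpolation between order statistics (an assumption; other interpolation rules give slightly different values). For the loss, we use a common soft-Dice form with squared terms in the denominator, averaged over classes; that this matches [24] exactly, and the smoothing constant `eps`, are assumptions.

```python
import math
from typing import List, Sequence


def percentile(values: Sequence[float], q: float) -> float:
    """q-th percentile (0-100) with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def normalize_volume(voxels: Sequence[float], q: float = 98.0) -> List[float]:
    """Divide every intensity of one volume by that volume's q-th percentile."""
    p = percentile(voxels, q)
    if p <= 0:
        raise ValueError("percentile must be positive to normalize")
    return [v / p for v in voxels]


def soft_dice_loss(probs: Sequence[Sequence[float]],
                   onehot: Sequence[Sequence[float]], eps: float = 1e-6) -> float:
    """Mean over classes of 1 - soft Dice; probs[c][i] and onehot[c][i] index class c, voxel i."""
    losses = []
    for p_c, g_c in zip(probs, onehot):
        inter = sum(p * g for p, g in zip(p_c, g_c))
        denom = sum(p * p for p in p_c) + sum(g * g for g in g_c)
        losses.append(1.0 - (2.0 * inter + eps) / (denom + eps))
    return sum(losses) / len(losses)
```

Averaging the per-class losses gives small structures the same weight as large ones, which is what makes a Dice-type loss less sensitive to class imbalance than voxel-wise cross-entropy.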

Table 1. Details of the datasets used for our experiments.
Table 2. Segmentation Dice scores for different domains for the three types of networks, trained as explained in the experiments section.

Experiments: We train three types of networks, as described below.

  • Individual networks \(N_d\): Trained for each domain \(d\) with \(n_{train}^{scratch}\) training images (see Table 1). For the known domains (\(D_1\), \(D_2\), \(D_3\)), the accuracy of \(N_d\) serves as a baseline that the networks with shared parameters must preserve. For the new domains (\(D_4\), \(D_5\)), the performance of \(N_d\) is the benchmark that we seek to match while training on far fewer examples (\(n_{train}\)) and using the knowledge of the previously learned domains.

  • A shared network \(N_{123}\): Trained on \(D_1\), \(D_2\), \(D_3\) with \(n_{train}\) images, with all parameters shared, including the BN layers, bn\(_s\). In contrast to the training regime of \(N_{12\cdots d}^{bn}\) described in Sect. 2, each batch used to train \(N_{123}\) randomly contains images from all domains, so that the shared BN parameters are tuned for all domains. Histogram equalization [25] is applied to a new domain \(D_d\) before it is tested with \(N_{123}\). To adapt \(N_{123}\) to \(D_d\), all parameters are fine-tuned with \(n_{train}\) images of the new domain; the resulting network is referred to as \(N_{123\rightarrow d}\).

  • A lifelong multi-domain learning network \(N_{123}^{bn}\): Trained on \(D_1\), \(D_2\), \(D_3\), with shared convolutional layers and domain-specific BN layers. The network obtained by extending \(N_{123}^{bn}\) to a new domain \(D_d\) according to the procedure described in Sect. 2 is called \(N_{123,k^*\rightarrow d}^{bn}\), where \(D_{k^*}\) is the known domain closest to \(D_d\).
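The key difference between the two shared-parameter training regimes listed above is how batches are composed. The sketch below contrasts them with two hypothetical batch generators over a dictionary mapping domain index to images; the cyclic domain order and sampling without replacement within a batch are our assumptions.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple


def single_domain_batches(data: Dict[int, Sequence], batch_size: int,
                          rng: random.Random) -> Iterator[Tuple[int, List]]:
    """N^bn regime: every batch comes from one domain k, and domains are visited in turn.

    In a training step on such a batch, only the shared convolutions and bn_k
    would be updated; bn_k' for k' != k stay frozen.
    """
    domains = sorted(data)
    while True:
        for k in domains:
            yield k, rng.sample(list(data[k]), batch_size)


def mixed_batches(data: Dict[int, Sequence], batch_size: int,
                  rng: random.Random) -> Iterator[List[Tuple[int, object]]]:
    """N_123 regime: batches are drawn at random from the pooled images of all domains."""
    pool = [(k, x) for k, xs in data.items() for x in xs]
    while True:
        yield rng.sample(pool, batch_size)
```

Yielding the domain index with each single-domain batch is what lets the training loop route the batch through the matching \(bn_k\); with mixed batches there is only one set of BN statistics, which must then cover all domains at once.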

Results: All networks are evaluated based on their mean Dice score for \(n_{test}\) images from the appropriate domain (see Table 1). Quantitative results of our experiments are shown in Table 2. The findings can be summarized as follows:

  • \(N_{123}\) preserves the performance of \(N_{1}\), \(N_{2}\), \(N_{3}\). Thus, a single network can learn to segment multiple domains, provided sufficient training data are available from all the domains at once. However, its performance degrades severely on the unseen domains \(D_4\) and \(D_5\). Histogram equalization of a new domain to the closest known domain (denoted by \(D_{d,HistEq}\)) does not improve performance significantly, while fine-tuning the network on the new domains causes catastrophic forgetting [20], that is, degraded performance on the old domains.

  • \(N_{123}^{bn}\) also preserves the performance of \(N_{1}\), \(N_{2}\), \(N_{3}\). For the new domain \(D_4\), using the \(bn_3\) parameters of the trained \(N_{123}^{bn}\) leads to the best performance. Thus, we infer that \(D_3\) is the closest to \(D_4\) among \(D_1\), \(D_2\), \(D_3\). After fine-tuning \(bn_3\) to obtain \(bn_4\), the Dice scores for all structures improve dramatically and become comparable to those of \(N_4\). Crucially, because the original \(bn_k\) for \(k = 1, 2, 3\) are kept, the performance on \(D_1\), \(D_2\), \(D_3\) of the updated network \(N_{123,3\rightarrow 4}^{bn}\) is exactly the same as that of \(N_{123}^{bn}\). Similar results are observed for the other new domain, \(D_5\). The improvement in the segmentations of the new domains after fine-tuning the BN parameters can also be observed qualitatively in Fig. 2.
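The text cites [25] for the histogram equalization baseline without giving details. One standard way to map a new domain's intensities onto those of the closest known domain is quantile-based histogram matching, sketched below in pure Python on flattened voxel lists; treating this as the baseline's exact procedure is an assumption.

```python
import bisect
import math
from typing import List, Sequence


def match_histogram(source: Sequence[float], reference: Sequence[float]) -> List[float]:
    """Map each source intensity to the reference intensity at the same empirical quantile."""
    src_sorted = sorted(source)
    ref_sorted = sorted(reference)
    n_src, n_ref = len(src_sorted), len(ref_sorted)
    out = []
    for v in source:
        # Empirical CDF of v within the source volume (fraction of voxels <= v).
        q = bisect.bisect_right(src_sorted, v) / n_src
        # Smallest reference order statistic whose CDF reaches q.
        idx = min(max(math.ceil(q * n_ref) - 1, 0), n_ref - 1)
        out.append(ref_sorted[idx])
    return out
```

Such a monotone intensity mapping aligns first-order intensity statistics only; it cannot correct differences in contrast between tissues that are not captured by the global histogram, which is consistent with the limited benefit reported above.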

Fig. 2. Qualitative results: (a) images from domains \(D_d\); segmentations predicted by (b) \(N_{123}^{bn}\) with \(bn_{k^*}\), (c) \(N_{123,k^*\rightarrow d}^{bn}\) with \(bn_d\), (d) \(N_{d}\); and (e) ground-truth annotations, with \(\{d, k^*\}\) = \(\{4, 3\}\) (top) and \(\{5, 2\}\) (bottom).

4 Conclusion

In this article, we presented a lifelong multi-domain learning approach for training a segmentation CNN that can be used for related MR modalities and across scanners/protocols. Furthermore, the network can be adapted to new scanners or protocols with only a few labelled images and without degrading performance on the previous scanners. This was achieved by learning separate batch normalization parameters for each scanner while sharing the convolutional filters among all scanners. In future work, we intend to investigate extending this approach to MR modalities that were not present during the initial training.

To the best of our knowledge, this is the first work to tackle the lifelong machine learning problem for CNNs in the context of medical image analysis. We believe that it may set an important precedent for further research on handling the data distribution changes that are ubiquitous in clinical data.