1 Introduction

Deep learning has achieved highly competitive performance on test data drawn from the same distribution as the (large-scale) training data. However, in practice, it is almost impossible to ensure that test data strictly follow the source distributions. Domain generalization (DG) investigates how to generalize a hypothesis learned from source domains to unseen target domains (Blanchard et al., 2011; Muandet et al., 2013).

As a seminal method, the empirical risk minimization (ERM) algorithm (Vapnik, 1999) learns a hypothesis that minimizes the empirical risk over all the source domains (Gulrajani & Lopez-Paz, 2020). Although ERM has achieved promising results on DG (Gulrajani & Lopez-Paz, 2020), previous work has shown, both theoretically and empirically, that the performance of ERM can depend heavily on the number of source domains and the diversity of source samples (Li et al., 2022; Gulrajani & Lopez-Paz, 2020).

Recent DG work explores invariance learning to narrow the prediction gap caused by distributional differences across domains (Li et al., 2018a, b; Zhang et al., 2021). Such an approach obtains an invariant representation by training the feature embedding with discrepancy-based losses, which estimate discrepancy metrics on covariate shifts w.r.t. marginal feature distributions (Albuquerque et al., 2020) or on conditional shifts w.r.t. conditional feature distributions (Zhang et al., 2021; Shui et al., 2022). Further, previous work has shown that invariance of the excess risk across domains is equivalent to invariance of the representation (Zhang et al., 2021). Although invariant feature learning can ensure prediction invariance across source domains, the intrinsic distribution gap between the source and target domains and the risk of overfitting to the source domains can severely degrade generalization performance, as shown in Fig. 1b.

Fig. 1

Illustration of our approach. Compared with the ERM baseline (a), domain invariance learning (b) reduces the discrepancy across source domains and performs well on source-domain classification, but it may still have a large error on the target domain. Our approach (c) uses bilevel meta-learning to further reduce the discrepancy between the target domain and the source domains, so that a hypothesis learned from the source domains can generalize to the target domain.

We improve the out-of-domain robustness of invariance learning with a bilevel meta-learning algorithm that learns more robust invariant representations across domains. In particular, we follow previous work in using an episodic training process (Li et al., 2018c): from all the source domains, we randomly select some meta-source domains for training and one meta-target domain for testing, which together form a meta-task that simulates domain shift.

1.1 Approach

We consider a learning algorithm for the feature embedding with meta-parameters, denoted as \(A^f_{\varvec{\phi }}(\cdot )\). A bilevel meta-learning algorithm (Finn et al., 2017) is then used to learn the parameter initialization \(\varvec{\phi }\), where the inner-loop objective minimizes the discrepancy across meta-source domains and the outer-loop objective minimizes the discrepancy between the meta-target and meta-source domains. Fig. 1c gives an intuitive illustration of the effect of this bilevel meta-learning algorithm.

1.2 Results

We give a geometric understanding of the bilevel meta-learning algorithm and show that it effectively minimizes the intrinsic domain discrepancy, which is formulated as the \({\mathcal {Y}}\)-discrepancy (Zhang et al., 2012) between the target domain and the convex hull of the source domains. Empirically, we follow the training and evaluation protocol of Gulrajani and Lopez-Paz (2020) and conduct experiments on five datasets. Results show that our approach can effectively learn domain invariance and achieves the best performance compared with a range of ERM, invariance learning and meta-learning algorithms. The code is released at https://github.com/jiachenwestlake/MLIR.

2 Related work

Domain generalization (DG) has become a popular field and achieved promising results in recent years. We review the most related DG work as follows.

2.1 Domain-invariance learning

Early DG work uses kernel-based approaches to learn an invariant feature mapping into a reproducing kernel Hilbert space (RKHS) (Muandet et al., 2013). Neural methods have achieved promising results in recent years, and invariant representation learning has become a strong approach for DG. Roughly speaking, such approaches add a loss based on a discrepancy measure across source domains, such as maximum mean discrepancy (Li et al., 2018a), \({\mathcal {H}}\)-divergence (Li et al., 2018b; Albuquerque et al., 2020), KL-divergence (Xiao et al., 2021), \({\mathcal {Y}}\)-discrepancy (Zhang et al., 2021) and total variation distance (Shui et al., 2022). Furthermore, DMG (Chattopadhyay et al., 2020) learns a balance between invariant and domain-specific representations, and REG (Shui et al., 2022) uses regularization to improve the smoothness of the representation. In contrast to these works, we aim to improve the robustness of invariance learning via meta-learning. Our work can be seen as a meta-learning extension of the line of work by Zhang et al. (2021), which showed the equivalence between transferability and \({\mathcal {Y}}\)-discrepancy across domains. Other invariance learning approaches, such as IRM (Arjovsky et al., 2019), learn labeling invariance across domains, which is orthogonal to this work.

2.2 Meta-learning

Meta-learning provides a framework for gaining experience over multiple training episodes that transfers to future tasks, and it has been introduced to DG by simulating domain shift (Li et al., 2018c; Balaji et al., 2018; Dou et al., 2019). An early approach is MLDG (Li et al., 2018c), which uses bilevel meta-learning (Finn et al., 2017) to train a model on source domains such that it generalizes to the target domain. MetaReg (Balaji et al., 2018) learns a regularizer on the classifier such that a classifier trained on source domains can generalize to the target domain. These works share a common limitation: they use task objectives directly as the inner-loop and outer-loop objectives, which can be suboptimal because task objectives are highly abstracted from the feature representation. To address this problem, we focus on a meta-learning approach that reduces the discrepancy between the target domain and the source domains. In particular, we build the bilevel meta-learning procedure on the first-order MAML framework (Finn et al., 2017), which is computationally efficient while preserving accuracy. To our knowledge, we are the first to use meta-learning for invariance learning.

2.3 Convex domain combination

A closely related problem is multiple-source domain adaptation, where the target domain is assumed to be a convex combination of source domains with possibly unknown weights. Previous work (Mansour et al., 2008; Hoffman et al., 2018; Shao et al., 2021) assumes that a pretrained hypothesis exists for each source domain and has studied in depth how to combine the source hypotheses into a target hypothesis. Such work also indicates that simple linear combinations face difficulties due to the discrepancy across source domains. In contrast, DG usually assumes that source-domain data are available for training, which can be used to learn an invariant representation that overcomes the limitation that domain discrepancy imposes on convex combination (Shao et al., 2021). Furthermore, we study a more general setting in which the target domain can lie outside the convex hull of the source domains. Accordingly, we propose a meta-learning approach to reduce the discrepancy between the target domain and the convex hull of source domains.

3 Preliminaries

3.1 Notations

Let \({\mathcal {X}}\) be the input space and \({\mathcal {Y}}\) be the output space. Following previous work (Blanchard et al., 2011; Muandet et al., 2013), we define a domain as a joint distribution on the Cartesian product of the input and output spaces \({\mathcal {Z}} = {\mathcal {X}} \times {\mathcal {Y}}\) and let \({\mathfrak {P}}\) denote the set of all domains. We denote the set of N source domains as \({\mathcal {S}} = \{{\mathbb {S}}^{i}\}_{1\le i \le N}\). The corresponding set of training samples is denoted as \(\hat{{\mathcal {S}}} = \{\hat{{\mathbb {S}}}^i\}_{1 \le i\le N}\), where the training sample of the i-th domain is \(\hat{{\mathbb {S}}}^{i}=\{ (x^i_k, y^i_k) \}_{1 \le k \le n_i}\) with cardinality \(n_i\), and we assume that \((x^i_k, y^i_k) \overset{i.i.d.}{\sim } {\mathbb {S}}^i\). For brevity, we assume that all domains have the same sample size, i.e., \(n_1 = \ldots = n_N = n\).

A hypothesis \(h \in {\mathcal {H}}:{\mathcal {X}} \rightarrow {\mathcal {Y}}\) is defined as a mapping from the input space to the output space. The error of a hypothesis h at a data point (x, y) is defined as \(\ell (h(x),y)\). Given a domain \({\mathbb {S}}\) and its corresponding sample \(\hat{{\mathbb {S}}}=\{(x_i,y_i)\}_{1 \le i \le n}\), the expected error and the empirical error are defined as \(\epsilon _{{\mathbb {S}}}(h) = {\mathbb {E}}_{(x,y)\sim {\mathbb {S}}} \ell (h(x),y)\) and \({\hat{\epsilon }}_{\hat{{\mathbb {S}}}}(h) = \frac{1}{n}\sum _{i=1}^n \ell (h(x_i),y_i)\), respectively. In this work, we consider h to be a neural network and decompose h into a feature embedding \(f_{\varvec{\psi }} \in {\mathcal {F}}: {\mathcal {X}} \rightarrow {\mathbb {R}}^d\), parameterized by \(\varvec{\psi }\) (or f for brevity), and a task classifier \(g_{\varvec{\theta }} \in {\mathcal {G}}: {\mathbb {R}}^d \rightarrow {\mathcal {Y}}\), parameterized by \(\varvec{\theta }\) (or g for brevity), i.e., \(h = g_{\varvec{\theta }} \circ f_{\varvec{\psi }}\). Furthermore, this work considers a learning algorithm for the feature embedding \(A^f_{\varvec{\phi }}: \bigcup _{N=1}^{\infty } {\mathcal {Z}}^{N \times n} \rightarrow {\mathcal {F}}\) with meta-parameter \(\varvec{\phi } \in \varvec{\Phi }\), which maps source-domain training samples to a feature embedding. Given source-domain training samples \(\hat{{\mathcal {S}}}\), the hypothesis can therefore be represented as \(g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\).
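As a small illustration of these definitions, the sketch below composes a toy feature embedding f and a toy classifier g into \(h = g \circ f\) and computes the empirical error \({\hat{\epsilon }}_{\hat{{\mathbb {S}}}}(h)\) on a sample. The one-dimensional inputs, the 0-1 loss and the helpers `compose`, `zero_one_loss` and `empirical_error` are illustrative assumptions, not part of the method.

```python
from typing import Callable, Sequence, Tuple

def compose(g: Callable[[float], int], f: Callable[[float], float]) -> Callable[[float], int]:
    """Build the hypothesis h = g o f from a feature embedding f and a classifier g."""
    return lambda x: g(f(x))

def zero_one_loss(y_pred: int, y: int) -> float:
    """Toy choice of the loss l(h(x), y); the text does not fix a specific loss here."""
    return 0.0 if y_pred == y else 1.0

def empirical_error(h: Callable[[float], int], sample: Sequence[Tuple[float, int]]) -> float:
    """hat-epsilon_S(h) = (1/n) * sum_i l(h(x_i), y_i)."""
    return sum(zero_one_loss(h(x), y) for x, y in sample) / len(sample)

f = lambda x: 2.0 * x        # toy feature embedding into R^1 (d = 1)
g = lambda z: int(z > 1.0)   # toy threshold classifier on the feature
h = compose(g, f)
sample = [(0.1, 0), (0.4, 0), (0.6, 1), (0.9, 0)]
print(empirical_error(h, sample))  # 0.25: only the last point is misclassified
```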

3.2 Meta-learning for domain generalization

The main idea is to use a sequence of M pairs of meta-training and meta-test samples \(\{(\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\}_{1 \le i \le M}\) to improve the ability of an algorithm to tackle domain shift. To connect with standard meta-learning formulations (Baxter, 2000; Chen et al., 2020), each meta-sample \((\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\) can be seen as the support/query pair of a DG task, where for each \(i \in [M]\), \(\hat{{\mathcal {D}}}^{tr}_i\) denotes the meta-training samples from a set of meta-source domains and \(\hat{{\mathcal {D}}}^{te}_i\) denotes the meta-test sample from a meta-target domain, which must not be one of the meta-source domains. In practice, an episodic training process is used to construct the meta-sample from the training samples of the N source domains. In each training iteration, each domain in turn serves as the meta-target domain and the remaining domains serve as the meta-source domains. Thus, with M = N, the meta-sample \(\{(\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\}_{1 \le i \le M}\) is defined as:

$$\begin{aligned} \left \{(\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\right \}_{1 \le i \le M} := \left \{\left (\{ \hat{{\mathbb {S}}}^j \}_{j \ne i}, \hat{{\mathbb {S}}}^i\right ): 1 \le i \le N\right \} \end{aligned}$$
(1)
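Eq. 1 amounts to a leave-one-domain-out split of the source samples. A minimal sketch, assuming each domain's sample is passed as one list element (strings stand in for the samples \(\hat{{\mathbb {S}}}^1, \ldots , \hat{{\mathbb {S}}}^N\)); `build_meta_sample` is a hypothetical helper name:

```python
from typing import List, Tuple, TypeVar

T = TypeVar("T")

def build_meta_sample(source_samples: List[T]) -> List[Tuple[List[T], T]]:
    """Eq. (1): pair each source sample (meta-test) with all the others (meta-training)."""
    meta_sample = []
    for i, target in enumerate(source_samples):
        meta_train = [s for j, s in enumerate(source_samples) if j != i]
        meta_sample.append((meta_train, target))
    return meta_sample  # M = N meta-tasks

# Strings stand in for the per-domain training samples.
for tr, te in build_meta_sample(["S1", "S2", "S3", "S4"]):
    print(tr, "->", te)  # first line: ['S2', 'S3', 'S4'] -> S1
```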

3.3 Domain discrepancy

\({\mathcal {Y}}\)-discrepancy has been used for domain invariance learning (Zhang et al., 2012, 2021). For convenience in presentation, we extend the hypothesis in the original definition (Zhang et al., 2012) to a learning algorithm for feature embedding.

Definition 1

(\({\mathcal {Y}}\)-discrepancy) Let \(g \in {\mathcal {G}}\) be the classifier and \(A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\) be the feature embedding learned from source samples \(\hat{{\mathcal {S}}}\). Then the \({\mathcal {Y}}\)-discrepancy \({\text {disc}}({\mathbb {S}}, {\mathbb {T}})\) between two domains \({\mathbb {S}}\) and \({\mathbb {T}}\) and its empirical version \(\hat{{\text {disc}}}(\hat{{\mathbb {S}}}, \hat{{\mathbb {T}}})\) w.r.t. the corresponding samples \(\hat{{\mathbb {S}}}\) and \(\hat{{\mathbb {T}}}\) are defined as:

$$\begin{aligned} {\text {disc}}\big ({\mathbb {S}}, {\mathbb {T}}\big ) :=&\sup _{g \in {\mathcal {G}}} \big \vert \epsilon _{{\mathbb {S}}}\big (g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\big ) - \epsilon _{{\mathbb {T}}}\big (g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}}) \big ) \big \vert ; \\ \hat{{\text {disc}}}\big (\hat{{\mathbb {S}}}, \hat{{\mathbb {T}}}\big ) :=&\sup _{g \in {\mathcal {G}}} \big \vert {\hat{\epsilon }}_{\hat{{\mathbb {S}}}}\big (g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\big ) - {\hat{\epsilon }}_{\hat{{\mathbb {T}}}}\big (g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\big ) \big \vert . \end{aligned}$$
(2)

\({\mathcal {Y}}\)-discrepancy defines a pseudo-distance between domains: it is symmetric and satisfies the triangle inequality, but it does not satisfy the identity of indiscernibles, since \({\text {disc}}({\mathbb {S}}, {\mathbb {T}})=0 \nRightarrow {\mathbb {S}} = {\mathbb {T}}\). It can measure not only covariate shift but also conditional shift between domains (Zhang et al., 2012). Therefore, we choose \({\mathcal {Y}}\)-discrepancy as the measure of domain discrepancy in the proposed algorithm.
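To make Def. 1 concrete, the sketch below computes the empirical \({\mathcal {Y}}\)-discrepancy when the classifier class \({\mathcal {G}}\) is a small finite set, so the supremum becomes a maximum. It assumes a fixed feature embedding f (standing in for \(A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\)), the 0-1 loss, one-dimensional inputs and threshold classifiers; all of these are illustrative assumptions. In the method itself the supremum is estimated adversarially by training a classifier (Section 4.2). The last line illustrates that a zero discrepancy does not imply identical samples.

```python
from typing import Callable, List, Sequence, Tuple

Sample = Sequence[Tuple[float, int]]
Classifier = Callable[[float], int]

def emp_error(g: Classifier, f: Callable[[float], float], sample: Sample) -> float:
    # 0-1 empirical error of g o f on one sample
    return sum(g(f(x)) != y for x, y in sample) / len(sample)

def emp_y_disc(G: List[Classifier], f: Callable[[float], float], s_hat: Sample, t_hat: Sample) -> float:
    # Empirical Y-discrepancy (Def. 1) with f held fixed: sup over a finite G is a max
    return max(abs(emp_error(g, f, s_hat) - emp_error(g, f, t_hat)) for g in G)

f = lambda x: x                                              # identity embedding (toy)
G = [lambda z, c=c: int(z > c) for c in (0.25, 0.5, 0.75)]   # threshold classifiers (toy)
S = [(0.1, 0), (0.3, 0), (0.6, 1), (0.8, 1)]
T = [(0.2, 0), (0.4, 1), (0.7, 1), (0.9, 1)]
S_shift = [(0.15, 0), (0.35, 0), (0.65, 1), (0.85, 1)]       # differs from S

print(emp_y_disc(G, f, S, T))        # 0.25
print(emp_y_disc(G, f, S, S_shift))  # 0.0, although S_shift != S
```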

4 Approach

The goal of our algorithm is to reduce the \({\mathcal {Y}}\)-discrepancy between the source domains and the target domain. To this end, we present a specific meta-learning algorithm.

4.1 Meta-learning via bilevel optimization

We focus on a bilevel meta-learning framework (Finn et al., 2017), which uses the meta-sample to learn a meta-parameter \(\varvec{\phi }^* \in \varvec{\Phi }\) for a learning algorithm \(A^f_{\varvec{\phi }^*}(\cdot )\). This learning algorithm uses the source samples \(\hat{{\mathcal {S}}}\) to optimize the feature embedding, represented as \(A^f_{\varvec{\phi }^*}(\cdot ): \hat{{\mathcal {S}}} \mapsto f_{\varvec{\psi }^*}\), where \(f_{\varvec{\psi }^*}\) denotes the feature embedding parameterized by the optimized parameter \(\varvec{\psi }^*\). For notational convenience, we sometimes use \(f_{\varvec{\psi }^*}\) and \(\varvec{\psi }^*\) interchangeably to denote the learned feature embedding.

In this work, the meta-learner optimizes the meta-parameter \(\varvec{\phi }\) to minimize the \({\mathcal {Y}}\)-discrepancy between the meta-target domain and the meta-source domains (defined in Eq. 4), such that the learned algorithm optimizes the parameter of the feature embedding to minimize the \({\mathcal {Y}}\)-discrepancy across meta-source domains (defined in Eq. 5). We formally define the bilevel optimization problem as follows.

Definition 2

(Bilevel Optimization) We denote the outer-loop and inner-loop objectives w.r.t. the feature embedding as \(\hat{{\mathcal {L}}}_{out}\) and \(\hat{{\mathcal {L}}}_{in}\), respectively. Let \(A^f_{\varvec{\phi }}(\cdot )\) be a learning algorithm parameterized by \(\varvec{\phi }\) for the inner-loop optimization. Given a meta-sample \(\{(\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\}_{1 \le i \le M}\) defined in Eq. 1, the bilevel optimization problem is defined as:

$$\begin{aligned} {\text {Outer-loop:}}&\quad \varvec{\phi }^* \in \mathop {\arg \min }_{\varvec{\phi } \in \varvec{\Phi }} \sum _{i \in [M]} \hat{{\mathcal {L}}}_{out} \big ( A^f_{\varvec{\phi }}(\hat{{\mathcal {D}}}^{tr}_i), (\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i) \big ); \\ {\text {Inner-loop:}}&\quad A^f_{\varvec{\phi }}(\hat{{\mathcal {D}}}^{tr}_i) \in \mathop {\arg \min }_{\varvec{\psi } \in {\mathcal {C}} (\varvec{\phi })} \hat{{\mathcal {L}}}_{in}\big (\varvec{\psi }; \hat{{\mathcal {D}}}^{tr}_i \big ), \end{aligned}$$
(3)

where \({\mathcal {C}} (\varvec{\phi })\) denotes the parameter space of \(\varvec{\psi }\) constrained by \(\varvec{\phi }\), which will be specified in the next section. Letting \(\varvec{\psi }^*_i:=A^f_{\varvec{\phi }}(\hat{{\mathcal {D}}}^{tr}_i)\), the empirical objectives of the outer loop and the inner loop are defined as follows:

$$\begin{aligned} \hat{{\mathcal {L}}}_{out}\left (\varvec{\psi }^*_i; (\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\right ) := \sum _{\hat{{\mathbb {S}}}^k_i \in \hat{{\mathcal {D}}}^{tr}_i} \hat{{\text {disc}}}_{{\mathcal {Y}}}\left (f_{\varvec{\psi }^*_i}\left (\hat{{\mathcal {D}}}^{te}_i, \hat{{\mathbb {S}}}^k_i\right ) \right ); \end{aligned}$$
(4)
$$\begin{aligned} \hat{{\mathcal {L}}}_{in}\left (\varvec{\psi }; \hat{{\mathcal {D}}}^{tr}_i\right ) := \sum \limits _{\hat{{\mathbb {S}}}^k_i, \hat{{\mathbb {S}}}^t_i \in \hat{{\mathcal {D}}}^{tr}_i} \hat{{\text {disc}}}_{{\mathcal {Y}}}\left (f_{\varvec{\psi }}\left (\hat{{\mathbb {S}}}^k_i, \hat{{\mathbb {S}}}^t_i\right )\right ). \end{aligned}$$
(5)
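A minimal sketch of how Eqs. (4) and (5) aggregate pairwise discrepancies within one meta-task. The discrepancy is passed in as a function; the toy `mean_gap` (the gap between sample means of one-dimensional embedded features) is only a stand-in for the empirical \({\mathcal {Y}}\)-discrepancy, and reading the sum in Eq. (5) as running over unordered pairs of distinct meta-source domains is an assumption of this sketch.

```python
from itertools import combinations
from typing import Callable, List, Sequence

Disc = Callable[[Sequence[float], Sequence[float]], float]

def inner_loss(meta_train: List[Sequence[float]], disc: Disc) -> float:
    # Eq. (5): discrepancy summed over pairs of meta-source samples
    # (unordered pairs of distinct domains is an assumption of this sketch)
    return sum(disc(a, b) for a, b in combinations(meta_train, 2))

def outer_loss(meta_train: List[Sequence[float]], meta_test: Sequence[float], disc: Disc) -> float:
    # Eq. (4): discrepancy between the meta-test sample and each meta-source sample
    return sum(disc(meta_test, s) for s in meta_train)

def mean_gap(a: Sequence[float], b: Sequence[float]) -> float:
    # Toy stand-in for the empirical Y-discrepancy of two embedded samples
    return abs(sum(a) / len(a) - sum(b) / len(b))

meta_train = [[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]  # embedded meta-source samples (toy)
meta_test = [4.0, 5.0]                              # embedded meta-test sample (toy)
print(inner_loss(meta_train, mean_gap), outer_loss(meta_train, meta_test, mean_gap))  # 4.0 9.0
```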

4.2 Gradient-based meta-learning algorithm


In practice, we instantiate the above bilevel meta-learning algorithm with first-order MAML (Finn et al., 2017). In particular, the meta-parameter \(\varvec{\phi }\) is defined as the parameter initialization of the inner-loop learning algorithm \(A^f_{\varvec{\phi }}(\cdot )\), which performs one or more steps of gradient descent on the inner-loop objective, starting from \(\varvec{\phi }\), in the parameter space \({\mathcal {C}}(\varvec{\phi })\) constrained by this initialization. Given a batch of training samples \({\mathcal {B}} = \{({\mathcal {B}}^{tr}_i, {\mathcal {B}}^{te}_i)\}_{1 \le i \le M}\) that contains M pairs of meta-training and meta-test domains, where each meta-training domain and the meta-test domain contribute b samples to \({\mathcal {B}}\), the update in one iteration with m inner-loop steps is computed as:

$$\begin{aligned}&\varvec{\phi } = \varvec{\phi } - \gamma \eta \sum _{i=1}^M \nabla _{\varvec{\psi }} \hat{{\mathcal {L}}}_{out} \left (\varvec{\psi }; ({\mathcal {B}}^{tr}_i, {\mathcal {B}}^{te}_i) \right ) \big \vert _{\varvec{\psi }=A^f_{\varvec{\phi }}({\mathcal {B}}^{tr}_i)} \end{aligned}$$
(6)
$$\begin{aligned}&s.t.\ \underbrace{\varvec{\psi }^{(0)}_i = \varvec{\phi };\ \varvec{\psi }^{(j)}_i = \varvec{\psi }^{(j-1)}_i - \alpha \eta \nabla _{\varvec{\psi }^{(j-1)}_i}\hat{{\mathcal {L}}}_{in} \left (\varvec{\psi }^{(j-1)}_i; {\mathcal {B}}^{tr}_i\right ),\ j = 1, \ldots , m}_{A^f_{\varvec{\phi }}\left ({\mathcal {B}}^{tr}_i\right ) = \varvec{\psi }^{(m)}_i}, \end{aligned}$$
(7)

where \(\eta\) denotes the learning rate and \(\alpha , \gamma\) denote the adversarial factors of the inner loop and the outer loop, respectively. Besides first-order MAML, other gradient-based meta-learning frameworks have been used in prior work (Li et al., 2018c; Balaji et al., 2018). We analyze the differences from these works and propose two variants of our approach in Appendix 1.
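The update in Eqs. (6) and (7) can be sketched with scalar parameters and closed-form gradients. In this sketch the inner-loop and outer-loop losses of each meta-task are toy quadratics \((\psi - a)^2\), standing in for the adversarially estimated discrepancies, and \(\alpha , \gamma\) are treated as plain positive step-size factors; all function names are hypothetical. The first-order approximation shows up in `fomaml_step`: the outer gradient is taken w.r.t. the adapted parameters \(\varvec{\psi }^{(m)}_i\) and applied to \(\varvec{\phi }\) directly, without differentiating through the inner-loop steps.

```python
from typing import Callable, List, Tuple

Grad = Callable[[float], float]

def inner_adapt(phi: float, grad_in: Grad, alpha: float, eta: float, m: int) -> float:
    # Eq. (7): start from psi^(0) = phi and take m gradient steps on the inner-loop loss
    psi = phi
    for _ in range(m):
        psi = psi - alpha * eta * grad_in(psi)
    return psi

def fomaml_step(phi: float, tasks: List[Tuple[Grad, Grad]],
                alpha: float, gamma: float, eta: float, m: int) -> float:
    # Eq. (6), first-order: sum the outer-loop gradients w.r.t. psi at the adapted
    # parameters psi_i^(m) and apply them to phi without backpropagating through Eq. (7)
    total = 0.0
    for grad_in, grad_out in tasks:
        total += grad_out(inner_adapt(phi, grad_in, alpha, eta, m))
    return phi - gamma * eta * total

def quad_grad(center: float) -> Grad:
    # Gradient of the toy loss (psi - center)^2
    return lambda psi: 2.0 * (psi - center)

# Two toy meta-tasks, each given as (inner-loop gradient, outer-loop gradient).
tasks = [(quad_grad(1.0), quad_grad(2.0)), (quad_grad(-1.0), quad_grad(0.0))]
phi = fomaml_step(0.0, tasks, alpha=1.0, gamma=1.0, eta=0.25, m=1)
print(phi)  # 1.0
```

With \(\varvec{\phi } = 0\), \(\alpha = \gamma = 1\), \(\eta = 0.25\) and m = 1, the two tasks adapt to 0.5 and -0.5, their outer gradients are -3 and -1, and \(\varvec{\phi }\) moves to 1.0.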

We use an adversarial training strategy (Goodfellow et al., 2014; Zhang et al., 2021) to optimize the inner-loop and outer-loop objectives (\(\hat{{\mathcal {L}}}_{in}\) and \(\hat{{\mathcal {L}}}_{out}\) in Def. 2). Following previous work (Zhang et al., 2021), the \({\mathcal {Y}}\)-discrepancy is estimated by a classifier trained with gradient ascent updates, while the \({\mathcal {Y}}\)-discrepancy is minimized via gradient descent w.r.t. the parameters of the feature embedding. The whole meta-learning procedure is shown in Algorithms 1 & 2 and described as follows.

4.3 Meta-training

As shown in Algorithm 1, lines 3–7 describe an adversarial training process that optimizes the inner-loop objective \(\hat{{\mathcal {L}}}_{in}\), which can be seen as a two-player minimax game between adversarial classifiers and the feature embedding. Lines 8–12 optimize the outer-loop objective \(\hat{{\mathcal {L}}}_{out}\) via adversarial training in the same way. In addition, lines 13–14 train the task classifier and the feature embedding on the classification task with the source samples.

4.4 Meta-test

As shown in Algorithm 2, the learned feature embedding is further trained on all N source domains with the inner-loop objective (lines 4–8), while the task classifier and the feature embedding are simultaneously trained on the classification task with the source samples (lines 9–10).

4.5 Computational complexity

Following the convergence analysis of bilevel meta-learning by Ji et al. (2022), we assume that \(\nabla \hat{{\mathcal {L}}}_{in}(\cdot )\) and \(\nabla \hat{{\mathcal {L}}}_{out}(\cdot )\) are Lipschitz continuous, that \(\nabla \hat{{\mathcal {L}}}_{out}(\cdot )\) has bounded variance, and that the batch size is large enough. Then, achieving \({\mathbb {E}}[\Vert \nabla \hat{{\mathcal {L}}}_{out}(\varvec{\phi })\Vert ] \le \varepsilon\) requires \({\mathcal {O}}(\varepsilon ^{-2})\) iterations. Combined with the per-iteration computational cost analyzed in Appendix 1, this amounts to a total of \({\mathcal {O}}(mbN^3\varepsilon ^{-2})\) gradient computations.
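The total follows from multiplying the iteration count by the per-iteration cost. The decomposition of the \(N^3\) factor shown below is only a plausible reading, assumed here for illustration (the actual accounting is in Appendix 1): each iteration processes M = N meta-tasks, the inner-loop objective in Eq. (5) sums over \({\mathcal {O}}(N^2)\) pairs of meta-source domains, and each pair costs \({\mathcal {O}}(b)\) gradient computations in each of the m inner-loop steps.

$$\begin{aligned} \underbrace{{\mathcal {O}}(\varepsilon ^{-2})}_{\text {iterations}} \times \underbrace{{\mathcal {O}}(mbN^3)}_{\text {cost per iteration}} = {\mathcal {O}}(mbN^3\varepsilon ^{-2}), \qquad N^3 \sim \underbrace{N}_{M = N \text { meta-tasks}} \cdot \underbrace{\tbinom{N-1}{2}}_{\text {pairs in } \hat{{\mathcal {L}}}_{in}} = {\mathcal {O}}(N^3). \end{aligned}$$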

5 Theoretical analysis

We analyze the learned feature distribution from a geometric perspective. For convenience in presentation, we regard the feature embedding \(A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})\) as a mapping from a domain \({\mathbb {D}}\) on \({\mathcal {X}}\times {\mathcal {Y}}\) to a domain \({\mathbb {D}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\) on the Cartesian product \({\mathbb {R}}^d \times {\mathcal {Y}}\) of the feature space and the output space. To see that this definition is reasonable, we can regard the feature embedding as a random transformation \(\Phi (x'\vert x)\), where \(x \in {\mathcal {X}}\) and \(x' \in {\mathbb {R}}^d\). In particular, a deterministic representation function is the special case in which \(\Phi (x'\vert x)\) is the Dirac delta function \(\delta _{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})(x)}\). Therefore, we can define the domain on \({\mathbb {R}}^d \times {\mathcal {Y}}\) as \({\mathbb {D}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}(x', y) = \int {\Phi (x' \vert x){\mathbb {D}}(x, y)}dx\) for any \(y \in {\mathcal {Y}}\). We denote the set of all domains on \({\mathbb {R}}^d \times {\mathcal {Y}}\) induced by \(A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})\) as \({\mathfrak {P}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\). The associated \({\mathcal {Y}}\)-discrepancy, equivalent to Def. 1, is defined as follows.

Definition 3

Let \(g \in {\mathcal {G}}\) be the classifier and \(A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})\) be the feature embedding, then, the \({\mathcal {Y}}\)-discrepancy between two domains \({\mathbb {S}}\) and \({\mathbb {T}}\) is defined as:

$$\begin{aligned} {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\left ({\mathbb {S}}, {\mathbb {T}}\right ) := \sup _{g \in {\mathcal {G}}} \big \vert \epsilon _{{\mathbb {T}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}}(g) - \epsilon _{{\mathbb {S}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}}(g) \big \vert . \end{aligned}$$
(8)
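The induced domain \({\mathbb {D}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\) used in Def. 3 has a direct discrete analogue: replace the integral by a sum and the transformation \(\Phi (x'\vert x)\) by a table of probabilities. The sketch below assumes finite input and feature spaces; `kernel` plays the role of \(\Phi\), and `dirac` recovers a deterministic embedding as the Dirac special case. All names and the toy map are hypothetical.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, Tuple

Domain = Dict[Tuple[Hashable, int], float]            # (x, y) -> probability mass
Kernel = Callable[[Hashable], Dict[Hashable, float]]  # x -> {x': Phi(x'|x)}

def pushforward(domain: Domain, kernel: Kernel) -> Domain:
    # D_Phi(x', y) = sum_x Phi(x'|x) D(x, y), the discrete analogue of the integral
    out: Dict[Tuple[Hashable, int], float] = defaultdict(float)
    for (x, y), p in domain.items():
        for x_feat, q in kernel(x).items():
            out[(x_feat, y)] += q * p
    return dict(out)

def dirac(f: Callable[[Hashable], Hashable]) -> Kernel:
    # Deterministic embedding as the special case Phi(.|x) = delta_{f(x)}
    return lambda x: {f(x): 1.0}

D = {(0, 0): 0.25, (1, 0): 0.25, (2, 1): 0.5}
D_feat = pushforward(D, dirac(lambda x: int(x >= 1)))  # toy embedding merging x = 1 and x = 2
print(D_feat)  # {(0, 0): 0.25, (1, 0): 0.25, (1, 1): 0.5}
```

Labels are carried over unchanged, so the embedding can merge inputs with different labels (here x = 1 and x = 2 both map to 1), which is exactly the kind of change the \({\mathcal {Y}}\)-discrepancy in the feature space can register.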

Definition 4

(Intrinsic domain discrepancy) Given a feature embedding \(A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})\), we define the intrinsic domain discrepancy as the \({\mathcal {Y}}\)-discrepancy between the target domain \({\mathbb {T}}\) and the convex hull of source domains \({\text {conv}}({\mathcal {S}})\):

$$\begin{aligned} {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\left ({\mathbb {T}}, {\text {conv}}({\mathcal {S}})\right ) = {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\left (\overline{{\mathbb {T}}}^*, {\mathbb {T}}\right ), \end{aligned}$$
(9)

where \(\overline{{\mathbb {T}}}^*\) denotes the nearest point to the target domain in \({\text {conv}}({\mathcal {S}})\),

$$\begin{aligned} \overline{{\mathbb {T}}}^* := \mathop {\arg \min }_{\overline{{\mathbb {T}}} \in {\text {conv}}({\mathcal {S}})} {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})} (\overline{{\mathbb {T}}}, {\mathbb {T}}). \end{aligned}$$
(10)
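For two source domains, Eqs. (9) and (10) can be evaluated numerically once the embedding is fixed and \({\mathcal {G}}\) is finite. The sketch relies on the fact that the expected error is linear in the domain, so a mixture \(\lambda {\mathbb {S}}^1 + (1-\lambda ){\mathbb {S}}^2\) has error \(\lambda \epsilon _{{\mathbb {S}}^1}(g) + (1-\lambda )\epsilon _{{\mathbb {S}}^2}(g)\); it then approximates the nearest hull point \(\overline{{\mathbb {T}}}^*\) by a grid search over \(\lambda\). The per-classifier error lists (one entry per \(g \in {\mathcal {G}}\), computed on the embedded domains) are hypothetical inputs.

```python
from typing import List, Tuple

def mixture_disc(lam: float, err_s1: List[float], err_s2: List[float], err_t: List[float]) -> float:
    # Errors are linear in the domain, so the mixture lam*S1 + (1 - lam)*S2 has
    # error lam*err_s1[g] + (1 - lam)*err_s2[g]; take the max over the finite class G.
    return max(abs(lam * a + (1 - lam) * b - c) for a, b, c in zip(err_s1, err_s2, err_t))

def intrinsic_disc(err_s1: List[float], err_s2: List[float], err_t: List[float],
                   grid: int = 101) -> Tuple[float, float]:
    # Eqs. (9)-(10) for N = 2: approximate the nearest hull point by a grid over lam in [0, 1]
    best_lam, best = 0.0, float("inf")
    for k in range(grid):
        lam = k / (grid - 1)
        d = mixture_disc(lam, err_s1, err_s2, err_t)
        if d < best:
            best_lam, best = lam, d
    return best_lam, best

# One entry per classifier g in G (hypothetical error values on the embedded domains).
err_s1, err_s2 = [0.1, 0.4], [0.3, 0.2]
print(intrinsic_disc(err_s1, err_s2, [0.2, 0.3]))  # lam = 0.5, discrepancy ~ 0
print(intrinsic_disc(err_s1, err_s2, [0.5, 0.5]))  # lam = 0.25, discrepancy ~ 0.25
```

In the first call the target cannot be distinguished from the equal-weight mixture, so the intrinsic discrepancy vanishes; in the second the target lies outside the hull in this pseudo-metric and a positive discrepancy remains, which is the case the meta-learning objective targets.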

Proposition 1

(Geometric understanding) Given a feature embedding \(A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})\), we consider a pseudo-metric space \(\left ({\mathcal {M}}({\mathfrak {P}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}), {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}(\cdot , \cdot )\right )\), defined as the space of all domains \({\mathfrak {P}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\) equipped with the pseudo-metric \({\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}(\cdot , \cdot )\). Let \(\widetilde{{\mathcal {S}}}\) denote the average of the source domains, \(\widetilde{{\mathcal {S}}} = \frac{1}{N}\sum _{i \in [N]}{\mathbb {S}}^i\), and let \(\overline{{\mathbb {T}}}^*\) be defined as in Def. 4. By the triangle inequality w.r.t. the pseudo-metric, we first have:

$$\begin{aligned} {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\left ({\mathbb {T}}, {\text {conv}}({\mathcal {S}})\right ) \le&{\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})} \left ({\mathbb {T}}, \widetilde{{\mathcal {S}}}\right ) + {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\left (\overline{{\mathbb {T}}}^*, \widetilde{{{\mathcal {S}}}}\right ). \end{aligned}$$
(11)

Then, we assume that there exists a meta-distribution over the set of all domains, denoted by \({\mathscr {P}}\). We also assume that the classifier class \({\mathcal {G}}\) has a finite VC-dimension d. Given the training set of N source domains \(\hat{{\mathcal {S}}}\) and the associated meta-sample \(\{(\hat{{\mathcal {D}}}^{tr}_i, \hat{{\mathcal {D}}}^{te}_i)\}_{1 \le i \le M}\) defined in Eq. 1, for any \(\delta > 0\), the following holds with probability at least \(1 - 5\delta\):

$$\begin{aligned} &{\mathbb {E}}_{{\mathscr {P}}}\big [ {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\big ({\mathbb {T}}, {\text {conv}}({\mathcal {S}})\big ) \big ] \\ \lessapprox&\underbrace{\frac{1}{M} \sum _{i=1}^M \frac{1}{\vert \hat{{\mathcal {D}}}^{tr}_i\vert } \sum _{\hat{{\mathbb {S}}}^j_i \in \hat{{\mathcal {D}}}^{tr}_i} \hat{{\text {disc}}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {D}}}^{tr}_i)}\big (\hat{{\mathcal {D}}}^{te}_i, \hat{{\mathbb {S}}}^j_i\big )}_{{\text {meta-training\ objective}}} + \underbrace{\frac{2}{N} \sum _{i <j}^{N} \hat{{\text {disc}}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})}\big (\hat{{\mathbb {S}}}^i, \hat{{\mathbb {S}}}^j\big )}_{{\text {meta-test\ objective}}} \\&+4\sqrt{\frac{8d{\text {log}}(2en/d)+8{\text {log}}(4/\delta )}{n}} + \sqrt{\frac{{\text {log}} \delta ^{-1}}{2N}}. \end{aligned}$$
(12)

Proof

See Appendix 1. \(\square\)
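To see how the non-empirical terms in Eq. (12) scale, the helper below simply evaluates them for a given VC-dimension d, per-domain sample size n, number of source domains N and confidence \(\delta\). Taking log as the natural logarithm is an assumption, and the function name and argument values are illustrative.

```python
import math
from typing import Tuple

def eq12_complexity_terms(d: int, n: int, N: int, delta: float) -> Tuple[float, float]:
    """Return the two non-empirical terms of Eq. (12), with log as the natural logarithm."""
    per_domain = 4.0 * math.sqrt(
        (8 * d * math.log(2 * math.e * n / d) + 8 * math.log(4 / delta)) / n
    )  # shrinks as the per-domain sample size n grows
    across_domains = math.sqrt(math.log(1 / delta) / (2 * N))  # depends only on N
    return per_domain, across_domains

for n in (10**3, 10**4, 10**5):
    print(n, eq12_complexity_terms(d=10, n=n, N=4, delta=0.05))
```

The first term vanishes as n grows, whereas the second depends only on N: with N = 4 source domains (as in PACS) and \(\delta = 0.05\) it equals \(\sqrt{\ln 20 / 8} \approx 0.61\) regardless of n.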

Fig. 2

Geometric understanding. Independently performing the inner-loop optimization based on the initial parameter \(\varvec{\phi }\) can only reduce the discrepancy across sources, as shown by (I)\(\rightarrow\)(II). The bilevel meta-training approach optimizes the meta-parameter \(\varvec{\phi }^*\) such that performing the inner-loop optimization in meta-test can reduce not only the discrepancy across sources but also the discrepancy between the target and sources, resulting in an optimization on the intrinsic domain discrepancy (Def. 4), as shown by (III)

Remark 1

Proposition 1 shows that the expectation of intrinsic domain discrepancy can be approximately upper-bounded by (i) the empirical objective of meta-training and (ii) the empirical objective of meta-test. Thus, the meta-training procedure directly optimizes the first empirical term (i), where the optimized meta-parameter is denoted as \(\varvec{\phi }^*\). Then, the second empirical term (ii) can also be minimized, since \(A^f_{\varvec{\phi }^*}(\hat{{\mathcal {S}}})\) is defined as an algorithm for optimizing the discrepancy across source domains. Therefore, the proposed meta-learning approach can approximately minimize the upper bound of intrinsic domain discrepancy. An intuitive illustration of the meta-learning procedure is shown in Fig. 2.

To show the effectiveness of optimizing the intrinsic domain discrepancy for DG, we give a generalization bound as follows.

Proposition 2

(Upper bound; Albuquerque et al., 2020) Let \(h = g \circ A_{\varvec{\phi }}^f(\hat{{\mathcal {S}}})\) be the hypothesis. We assume that there exists a meta-distribution \({\mathscr {P}}\) over the set of domains. Then,

$$\begin{aligned} {\mathbb {E}}_{{\mathscr {P}}} \left[ \epsilon _{{\mathbb {T}}}(h) \right] \le {\mathbb {E}}_{{\mathscr {P}}} \left [ {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})} \left ({\mathbb {T}}, {\text {conv}}({\mathcal {S}})\right ) \right ] + \frac{1}{N} \sum _{i=1}^N \epsilon _{{\mathbb {S}}^i}(h) + \frac{2}{N} \sum _{i <j}^{N} {\text {disc}}_{A^f_{\varvec{\phi }}(\hat{{\mathcal {S}}})} \left ({\mathbb {S}}^i, {\mathbb {S}}^j\right ). \end{aligned}$$

Proof

The proof largely follows Albuquerque et al. (2020) with only slight modification of replacing the \({\mathcal {H}}\)-divergence with \({\mathcal {Y}}\)-discrepancy and taking expectation on the target domain. \(\square\)

Remark 2

Proposition 2 gives an upper bound for DG, which consists of (i) the intrinsic domain discrepancy, (ii) the average of the source-domain errors and (iii) the discrepancy across source domains. Compared with the invariance learning approach, which can be seen as performing only the inner-loop or only the outer-loop optimization of our approach, the proposed bilevel meta-learning algorithm can further minimize the intrinsic domain discrepancy while also reducing the discrepancy across source domains in meta-test.

6 Experiments

6.1 Experimental settings

6.1.1 Datasets and evaluation metrics

Following Gulrajani and Lopez-Paz (2020), we evaluate the proposed algorithm on five real-world datasets: PACS (Li et al., 2017) (9,991 images, 7 classes and 4 domains), VLCS (Fang et al., 2013) (10,729 images, 5 classes and 4 domains), OfficeHome (Venkateswara et al., 2017) (15,588 images, 65 classes and 4 domains), TerraIncognita (Beery et al., 2018) (24,788 images, 10 classes and 4 domains) and DomainNet (Peng et al., 2019) (586,575 images, 345 classes and 6 domains).

We report the out-of-domain accuracy on each dataset and the average over all datasets. For each target domain, we train a model on the training sets of the source domains and use the union of the source-domain validation sets for model selection. Each reported result is the average of three independent repetitions with different hyperparameters, initializations and dataset splits.

Table 1 Hyperparameters specific to our method, with their default values and search distributions for random search

Optimization protocol For a fair comparison, we follow the training and evaluation protocol of Gulrajani and Lopez-Paz (2020) for our method and all baselines. In particular, we use an ImageNet-pretrained \({\text {ResNet-50}}\) (Gulrajani & Lopez-Paz, 2020) as the feature embedding and \({\text {Adam}}\) as the optimizer in all experiments. For hyperparameter search, each hyperparameter is assigned a default value and a search range around it, and all hyperparameters are tuned jointly via random search (Gulrajani & Lopez-Paz, 2020) according to their search distributions, with at most 20 trials. The hyperparameter search settings are the same for our method and the baselines, except for the hyperparameters specific to our method, which are listed in Table 1.
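The joint random search in this protocol can be sketched as follows. The hyperparameter names and search distributions below are hypothetical placeholders, not the values in Table 1, and `evaluate` stands for training a model with a sampled configuration and returning its accuracy on the aggregated source-domain validation sets.

```python
import math
import random
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical search space: names and distributions are placeholders, not Table 1.
SEARCH_SPACE: Dict[str, Callable[[random.Random], Any]] = {
    "lr": lambda r: 10 ** r.uniform(-5.0, -3.5),      # log-uniform learning rate
    "alpha": lambda r: 10 ** r.uniform(-2.0, 0.0),    # inner-loop factor
    "gamma": lambda r: 10 ** r.uniform(-2.0, 0.0),    # outer-loop factor
    "inner_steps": lambda r: r.choice([1, 2, 3]),     # m, number of inner-loop steps
}

def random_search(evaluate: Callable[[Dict[str, Any]], float],
                  n_trials: int = 20, seed: int = 0) -> Tuple[Optional[Dict[str, Any]], float]:
    """Jointly sample all hyperparameters per trial and keep the best-scoring configuration."""
    rng = random.Random(seed)
    best_cfg, best_score = None, -math.inf
    for _ in range(n_trials):
        cfg = {name: sample(rng) for name, sample in SEARCH_SPACE.items()}
        score = evaluate(cfg)  # e.g., accuracy on the aggregated source-domain validation sets
        if score > best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score

# Toy objective standing in for validation accuracy: prefers lr close to 1e-4.
best_cfg, best_score = random_search(lambda cfg: -abs(math.log10(cfg["lr"]) + 4.0))
print(best_cfg, best_score)
```

Sampling all hyperparameters jointly per trial, rather than tuning them one at a time, matches the "tuned jointly" wording above; the fixed seed makes the search reproducible.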

6.2 Results

Table 2 Accuracy (\(\%\)) on five DG datasets using a pretrained \({\text {ResNet-50}}\) backbone. \(^\dag\) denotes baseline results reproduced under the same training and evaluation protocol as ours (Gulrajani & Lopez-Paz, 2020). Results of the other three baselines are taken from the original papers (Dou et al., 2019; Chattopadhyay et al., 2020; Xiao et al., 2021)
Table 3 Ablation study on inner-loop and outer-loop objectives
Table 4 Ablation study on bilevel meta-learning

Table 2 reports the main results, and Tables 3 and 4 report the ablation studies.

6.2.1 Methods

Table 2 compares our method with several related approaches. These are ERM (Vapnik, 1999), domain-invariance learning (Chattopadhyay et al., 2020; Ganin et al., 2016; Sun & Saenko, 2016; Li et al., 2018a, b; Nam et al., 2019; Arjovsky et al., 2019; Xiao et al., 2021), and meta-learning (Li et al., 2018c; Balaji et al., 2018; Dou et al., 2019). Our algorithm achieves the best results among these baselines on all five datasets, which demonstrates the effectiveness of the proposed bilevel optimization algorithm for DG.

6.2.2 Ablation study on inner-loop and outer-loop objectives

As shown in Table 3, we compare variants that use either the task objective or the \({\mathcal {Y}}\)-discrepancy as the inner-loop and outer-loop objectives. The first row is similar to the invariance learning approach of Zhang et al. (2021), which minimizes the \({\mathcal {Y}}\)-discrepancy across source domains. Compared with this baseline, our approach (bottom row) achieves better results on both datasets. This shows that the proposed bilevel optimization algorithm improves invariant representation learning for DG. In addition, our approach outperforms the other meta-learning variants, which shows the potential of optimizing the domain discrepancy to reduce domain shift for DG.
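A single bilevel update can be sketched in the style of MLDG (Li et al., 2018c). The sketch assumes that the outer-loop objective is evaluated at the parameters obtained after one inner-loop gradient step, and that the gradient is taken through that step. It illustrates the bilevel structure only and is not the authors' implementation. `inner_loss` and `outer_loss` are hypothetical callables. They stand, respectively, for the \({\mathcal {Y}}\)-discrepancy across meta-train source domains and the \({\mathcal {Y}}\)-discrepancy between the held-out meta-test domain and the meta-train domains. Central finite differences replace automatic differentiation so the example depends only on NumPy.

```python
import numpy as np


def numerical_grad(f, theta, eps=1e-5):
    """Central-difference gradient of a scalar function; a stand-in for autodiff."""
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return grad


def bilevel_objective(theta, inner_loss, outer_loss, inner_lr, outer_weight):
    """Inner loss at theta plus the weighted outer loss after one inner gradient step."""
    adapted = theta - inner_lr * numerical_grad(inner_loss, theta)  # inner-loop update
    return inner_loss(theta) + outer_weight * outer_loss(adapted)


def bilevel_step(theta, inner_loss, outer_loss, inner_lr=0.1, outer_weight=1.0, lr=0.05):
    """One gradient-descent step on the bilevel objective, differentiating through the inner step."""
    f = lambda t: bilevel_objective(t, inner_loss, outer_loss, inner_lr, outer_weight)
    return theta - lr * numerical_grad(f, theta)
```

Swapping which callable plays the inner and outer role, either the task objective or a discrepancy, gives variants of the kind compared in Table 3. In a deep-learning framework, automatic differentiation would replace `numerical_grad`.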

6.2.3 Ablation study on bilevel meta-learning

As shown in Table 4, we compare with two prior meta-learning algorithms, MLDG (Li et al., 2018c) and MetaReg (Balaji et al., 2018). To connect our method with these approaches, we equip their frameworks with the same inner-loop and outer-loop objectives as ours. This yields two baselines, Ours-MLDG and Ours-MetaReg, which allow the bilevel meta-learning frameworks themselves to be compared. The results show that our approach is more effective than these variants of the meta-learning framework. Moreover, Ours-MLDG and Ours-MetaReg outperform the original MLDG (Li et al., 2018c) and MetaReg (Balaji et al., 2018), respectively, which shows the benefit of meta-learning invariant representations for DG.

6.3 Analysis

Fig. 3

Effectiveness of the bilevel optimization algorithm in reducing the \({\mathcal {Y}}\)-discrepancy

6.3.1 Domain discrepancy

In Fig. 3b, we show how the adversarial factors \(\alpha\) and \(\gamma\) affect two quantities: the \({\mathcal {Y}}\)-discrepancy across source domains (top left) and the \({\mathcal {Y}}\)-discrepancy between the held-out domain and the source domains (bottom left), respectively. As the adversarial factors increase from 0.01 to 2.00, both discrepancies first decrease, with only small fluctuations, and then plateau or increase slightly. This shows that minimizing the \({\mathcal {Y}}\)-discrepancy in both the inner-loop and the outer-loop optimization is sensitive to the choice of the adversarial factors.

In Fig. 3d, we compare the \({\mathcal {Y}}\)-discrepancy (Zhang et al., 2012) achieved by our approach with that of the ERM algorithm and of an invariant representation learning algorithm (the same as the first row of Table 3) on five datasets. The top-right panel shows that both our approach and invariance learning reduce the \({\mathcal {Y}}\)-discrepancy between source domains more than ERM does. This is because both approaches explicitly minimize the \({\mathcal {Y}}\)-discrepancy across source domains during training. In addition, the bottom-right panel shows that our approach yields a lower \({\mathcal {Y}}\)-discrepancy between the held-out domain and the source domains than both ERM and invariance learning. This demonstrates that meta-learning achieves more robust domain invariance.
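The \({\mathcal {Y}}\)-discrepancy (Zhang et al., 2012) between two domains \(P\) and \(Q\) is the largest risk gap that any hypothesis in a class \(H\) can exhibit, \(\sup_{h\in H} |R_P(h) - R_Q(h)|\). For a finite hypothesis class and the zero-one loss, the supremum becomes a maximum that can be computed exactly, as in the sketch below. For neural networks, it has to be approximated, for example adversarially. The 1-D threshold classifiers and the averaging over source domains in `discrepancy_to_sources` are assumptions made for illustration. They are not the estimator used to produce the figure.

```python
import numpy as np


def empirical_risk(h, x, y):
    """Zero-one risk of hypothesis h on a labeled sample (x, y)."""
    return float(np.mean(h(x) != y))


def y_discrepancy(hypotheses, sample_p, sample_q):
    """max over a finite class of |R_P(h) - R_Q(h)|, i.e. the Y-discrepancy restricted to that class."""
    (xp, yp), (xq, yq) = sample_p, sample_q
    return max(abs(empirical_risk(h, xp, yp) - empirical_risk(h, xq, yq)) for h in hypotheses)


def discrepancy_to_sources(hypotheses, target, sources):
    """Assumed aggregation: mean Y-discrepancy between a held-out domain and each source domain."""
    return float(np.mean([y_discrepancy(hypotheses, target, s) for s in sources]))


# Hypothetical finite class: 1-D threshold classifiers h_t(x) = 1[x > t].
THRESHOLDS = [lambda x, t=t: (x > t).astype(int) for t in np.linspace(-2.0, 2.0, 41)]
```

Because the labels enter through the risks, two domains with identical inputs but different labeling functions can still have a large \({\mathcal {Y}}\)-discrepancy. Distance measures between marginal distributions would miss this.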

Fig. 4

t-SNE visualization of feature representations on PACS when the target domain is photo. Each class is represented by a specific marker and each domain by a specific color, with the target domain shown in gray

6.3.2 Visualization

Mirroring the illustration in Fig. 1, Fig. 4 visualizes the learned feature representations, using 250 randomly selected test examples from each domain. Compared with ERM, both domain-invariant learning and our method match the feature distributions of the source domains. Compared with domain-invariant learning, our method also better matches the feature distribution of the target domain to those of the source domains. We attribute this to the outer-loop objective of the bilevel optimization, which improves robustness to domain shift.
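The per-domain sampling used for the visualization can be sketched as follows. The sketch assumes that the learned features, class labels, and domain indices of the test examples are stored in aligned NumPy arrays. The function name and the fixed seed are hypothetical.

```python
import numpy as np


def sample_per_domain(features, labels, domains, n_per_domain=250, seed=0):
    """Randomly pick up to n_per_domain examples from every domain, without replacement."""
    rng = np.random.default_rng(seed)
    picked = []
    for d in np.unique(domains):
        idx = np.flatnonzero(domains == d)      # indices of examples from domain d
        take = min(n_per_domain, idx.size)      # small domains contribute all their examples
        picked.append(rng.choice(idx, size=take, replace=False))
    picked = np.concatenate(picked)
    return features[picked], labels[picked], domains[picked]
```

The returned features can then be embedded in two dimensions with a t-SNE implementation, such as scikit-learn's `sklearn.manifold.TSNE(n_components=2).fit_transform(...)`. The returned labels determine the markers and the domain indices determine the colors.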

7 Conclusion

We investigated a meta-learning approach to invariant representation learning for domain generalization. In particular, we learn a more robust domain invariance via a bilevel optimization algorithm, in which the inner loop minimizes the \({\mathcal {Y}}\)-discrepancy across source domains and the outer loop minimizes the \({\mathcal {Y}}\)-discrepancy between the target and source domains. Theoretically, we show from a geometric perspective that the meta-learning approach minimizes the \({\mathcal {Y}}\)-discrepancy between the target domain and the convex hull of the source domains. Empirically, our approach outperforms a range of strong baselines on five domain generalization datasets.