1 Introduction

Deep neural networks (DNNs) have achieved outstanding performance on prediction tasks like visual object and speech recognition (Krizhevsky et al. 2012; Szegedy et al. 2015; He et al. 2015). Issues can arise when the learned representations rely on dependencies that vanish in test distributions (see, for example, Quionero-Candela et al. (2009), Torralba and Efros (2011), Csurka (2017) and references therein). Such domain shifts can be caused by changing conditions such as color, background or location changes, and predictive performance is then likely to degrade. For example, consider the analysis presented in Kuehlkamp et al. (2017), which is concerned with the problem of predicting a person's gender based on images of their iris. The results indicate that this problem is more difficult than previous studies have suggested, due to the remaining effect of cosmetics after segmenting the iris from the whole image. Previous analyses obtained good predictive performance on certain datasets, but accuracy dropped when testing on a dataset that only includes images without cosmetics. In other words, the high predictive performance previously reported relied to a significant extent on exploiting the confounding effect of mascara on the iris segmentation, which is highly predictive for gender. Rather than discriminating based on the iris' texture, as desired, the systems would mostly learn to detect the presence of cosmetics.

More generally, existing biases in datasets used for training machine learning algorithms tend to be replicated in the estimated models (Bolukbasi et al. 2016). For an example involving Google's photo app, see Crawford (2016) and Emspak (2016). In Sect. 5 we show many examples where unwanted biases in the training data are picked up by the trained model. As any bias in the training data is in general used to discriminate between classes, these biases will persist in future classifications, also raising considerations of fairness and discrimination (Barocas and Selbst 2016).

Addressing the issues outlined above, we propose conditional variance regularization (CoRe) to give differential weight to different latent features. Conceptually, we take a causal view of the data generating process and categorize the latent data generating factors into 'conditionally invariant' (core) and 'orthogonal' (style) features, as in Gong et al. (2016). The core and style features are unobserved and can in general be highly nonlinear transformations of the observed input data. It is desirable that a classifier only extracts the latent core features from the input data, as they pertain to the target of interest in a stable and coherent fashion. Basing a prediction on the core features alone yields stable predictive accuracy even if the style features are altered. Under suitable assumptions, CoRe yields an estimator which is approximately invariant under changes in the conditional distribution of the style features (conditional on the class labels), and it is asymptotically robust with respect to domain shifts arising through interventions on the style features. CoRe relies on the fact that for certain datasets we can observe grouped observations in the sense that we observe the same object under different conditions. For instance, such grouping information is available:

  (i) in natural image data when several pictures of the same person are taken;

  (ii) in medical imaging when several images belonging to the same patient are made;

  (iii) in speech recognition when multiple recordings from the same speaker are available;

  (iv) in video data where nearby frames showing the same objects can be exploited to group observations;

  (v) in data augmentation where a transformed data point can be grouped together with the original one.

We will show examples for the first and last category. For the last category, we will show that pairing augmented data points with the original images they were generated from helps to improve accuracy and robustness with respect to the chosen transformation.

Rather than pooling over all examples, CoRe exploits knowledge about this grouping, i.e., that a number of instances relate to the same object. By penalizing between-object variation of the prediction less than variation of the prediction for the same object, we can steer the prediction to be based more on the latent core features and less on the latent style features. While the proposed methodology can be motivated from the desire to achieve representational invariance with respect to the style features, the causal framework we use throughout this work allows us to precisely formulate the distribution shifts we aim to protect against.

The remainder of this manuscript is structured as follows: Sect. 1.1 starts with a few motivating examples, showing simple settings where the style features change in the test distribution such that standard empirical risk minimization approaches would fail. In Sect. 1.2 we review related work, in Sect. 2 we introduce notation and in Sect. 3 we formally introduce conditional variance regularization (CoRe). In Sect. 4, CoRe is shown to be asymptotically equivalent to minimizing the risk under a suitable class of strong interventions in a partially linear classification setting, provided one chooses sufficiently strong CoRe penalties. We also show that the population CoRe penalty induces domain shift robustness for general loss functions to first order in the intervention strength. The size of the conditional variance penalty can be shown to determine the size of the distribution class over which we can expect distributional robustness. In Sect. 5 we evaluate the performance of CoRe in a variety of experiments.

To summarize, our contributions are the following:

  (i) Causal framework and distributional robustness. We build on the causal framework from Gong et al. (2016) to define distributional shifts for style variables. This allows us to formulate the objective of interest in terms of distributionally robust inference. Specifically, the distribution class on which the estimator should achieve a guaranteed performance bound consists of those distributions that are generated by interventions on the latent style variables in a causal model. Our framework allows the domain variable itself to be latent.

  (ii) Conditional variance penalties. We introduce conditional variance penalties and show two robustness properties in Theorems 1 and 2.

  (iii) Software. We illustrate our ideas using synthetic and real-data experiments. A TensorFlow implementation of CoRe as well as code to reproduce some of the experimental results are available at https://github.com/christinaheinze/core.

1.1 Motivating examples

To motivate the methodology we propose, consider the examples shown in Figs. 1 and 2. Example 1 shows a setting where a nonlinear decision boundary is required. Here, the core feature corresponds to the distance from the origin while the style feature corresponds to the angle between the \(x_1\)-axis and the vector from the origin to \((x_1, x_2)\). Panel (a) shows a subsample of the training data where class 1 is associated with red points and class 0 with dark blue points. Panel (b) additionally shows a subsample of the test data where the style, i.e. the distribution of the angle, is intervened upon: class 1 is associated with orange squares and class 0 with cyan squares. Clearly, a circular decision boundary yields optimal performance on both training and test set but is unlikely to be found by a standard classification algorithm when only using the training set for the estimation. We will return to these examples in Sect. 3.4.
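For concreteness, the following sketch generates data in the spirit of Example 1; the exact radii, angle distributions and sample sizes are illustrative assumptions on our part and not the simulation settings used for Fig. 1. A classifier that uses only the radius (the core feature) is unaffected by the shift of the angle (the style feature), while a classifier that exploits the angle degrades.

```python
# Toy data in the spirit of Example 1; all constants below are assumptions.
import numpy as np

rng = np.random.default_rng(0)

def sample_example1(n, style_shift=0.0):
    """Core feature: radius (determined by the class). Style feature: polar angle."""
    y = rng.integers(0, 2, size=n)                              # class label in {0, 1}
    radius = np.where(y == 1, 1.5, 0.5) + 0.1 * rng.standard_normal(n)
    # In the training distribution the angle is associated with the class;
    # a style intervention shifts the conditional distribution of the angle.
    angle = 0.5 * np.pi * y + 0.3 * rng.standard_normal(n) + style_shift
    x = np.stack([radius * np.cos(angle), radius * np.sin(angle)], axis=1)
    return x, y

x_train, y_train = sample_example1(20000)                       # training distribution
x_test, y_test = sample_example1(2000, style_shift=np.pi)       # intervened style
```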

Fig. 1 Motivating example 1: The distributions are shifted in the test data by style interventions where style is the polar angle. Standard estimators achieve error rates of \(0\%\) on the training data and on test data drawn from the same distribution as the training data (panel a). On the shown test set, where the distribution of the style conditional on Y has changed, the error rates are \(> 50\%\) (panel b)

Fig. 2 Motivating example 2: The goal is to predict whether a person is wearing glasses. The distributions are shifted in the test data by style interventions where style is the image quality. A 5-layer CNN achieves 0% training error and 2% test error for images that are sampled from the same distribution as the training images (panel a), but a 65% error rate on images where the confounding between image quality and glasses is changed (panel b). See Sect. 5.3 for more details

Second, we introduce a strong dependence between the class label and the style feature "image quality" by manipulating the face images from the CelebA dataset (Liu et al. 2015): in the training set, images of class "wearing glasses" are associated with a lower image quality than images of class "not wearing glasses". Examples are shown in Fig. 2a. In the test set, this relation is reversed, i.e. images showing persons wearing glasses are of higher quality than images of persons without glasses, with examples in Fig. 2b. We will return to this example in Sect. 5.3 and show that training a convolutional neural network to distinguish between people wearing glasses or not works well on test data drawn from the same distribution as the training data (with error rates below 2%) but fails entirely on the shown test data, with error rates worse than 65%.

1.2 Related work

For general distributional robustness, the aim is to learn

$$\begin{aligned} \mathop {\text {argmin}}\limits _\theta \; \sup _{F\in {\mathcal {F}}} E_F(\ell (Y,f_\theta (X)))\end{aligned}$$
(1)

for a given set \({\mathcal {F}}\) of distributions, twice differentiable and convex loss \(\ell\), and prediction \(f_\theta (x)\). The set \({\mathcal {F}}\) is the set of distributions on which one would like the estimator to achieve a guaranteed performance bound.

Causal inference can be seen as a specific instance of distributional robustness, where we take \({\mathcal {F}}\) to be the class of all distributions generated under do-interventions on the predictors X (Meinshausen 2018; Rothenhäusler et al. 2018). Causal models thus have the defining advantage that the predictions will be valid even under arbitrarily large interventions on all predictor variables (Haavelmo 1944; Aldrich 1989; Pearl 2009; Schölkopf et al. 2012; Peters et al. 2016; Zhang et al. 2013, 2015; Yu et al. 2017; Rojas-Carulla et al. 2018; Magliacane et al. 2018). There are two difficulties in transferring these results to the setting of domain shifts in image classification. The first hurdle is that the classification task is typically anti-causal since the image we use as a predictor is a descendant of the true class of the object we are interested in, rather than the other way around. The second challenge is that the input data consist of pixel intensities and we do not want to (and could not) guard against arbitrary interventions on any or all variables, but would only like to guard against a shift of the unobserved style features. It is hence not immediately obvious how standard causal inference can be used to guard against large domain shifts.

Another line of work uses a class of distributions of the form \({\mathcal {F}}= {\mathcal {F}}_\epsilon (F_0)\) with

$$\begin{aligned} {\mathcal {F}}_\epsilon (F_0):= \{ \text {distributions } F \text { such that } D(F,F_0) \le \epsilon \},\end{aligned}$$
(2)

with \(\epsilon >0\) a small constant and \(D(F,F_0)\) being, for example, a \(\phi\)-divergence (Namkoong and Duchi 2017; Ben-Tal et al. 2013; Bagnell 2005; Volpi et al. 2018) or a Wasserstein distance (Shafieezadeh-Abadeh et al. 2017; Sinha et al. 2018; Gao et al. 2017). The distribution \(F_0\) can be the true (but generally unknown) population distribution P from which the data were drawn or its empirical counterpart \(P_n\). The distributionally robust targets in Eq. (2) can often be expressed in penalized form (Gao et al. 2017; Sinha et al. 2018; Xu et al. 2009). A Wasserstein-ball is a suitable class of distributions for example in the context of adversarial examples (Sinha et al. 2018; Szegedy et al. 2014; Goodfellow et al. 2015).

In this work, we do not try to achieve robustness with respect to a set of distributions that is pre-defined by a Kullback-Leibler divergence or a Wasserstein metric as in Eq. (2). Instead, we try to achieve robustness against a set of distributions that are generated by interventions on latent style variables in a causal model (we will make this precise in Sect. 2). We will formulate the class of distributions over which we try to achieve robustness as in Eq. (1), but with the class of distributions in Eq. (2) now replaced with the class of distributions \({\mathcal {F}}_\xi\) defined as \(\{F: D_{\text {style}} (F,F_0) \le \xi \},\) where \(F_0\) is again the distribution the training data are drawn from. The difference to the standard distributional robustness approaches listed after Eq. (2) is that the metric \(D_{\text {style}}\) measures the shift of the orthogonal style features. We do not know a priori which features are prone to distributional shifts and which features have a stable (conditional) distribution. The metric is hence not known a priori and needs to be inferred in a suitable sense from the data.

Similar to this work in terms of their goals are the approach of Gong et al. (2016) and the Domain-Adversarial Neural Networks (DANN) proposed in Ganin et al. (2016), an approach motivated by the work of Ben-David et al. (2007). The main idea of Ganin et al. (2016) is to learn a representation that contains no discriminative information about the origin of the input (source or target domain). This is achieved by an adversarial training procedure: the loss on domain classification is maximized while the loss of the target prediction task is minimized simultaneously. The data generating process assumed in Gong et al. (2016) is similar to our model, introduced in Sect. 2.1, where we detail the similarities and differences between the models (cf. Fig. 3). Gong et al. (2016) identify the conditionally independent features by adjusting a transformation of the variables to minimize the squared MMD distance between distributions in different domains. The fundamental difference between these very promising methods and our approach is that we use a different data basis: the domain identifier is explicitly observable in Gong et al. (2016) and Ganin et al. (2016), while it is latent in our approach. Instead, we exploit the presence of an identifier variable \(\mathrm {ID}\) that relates to the identity of an object (for example identifying a person). In other words, we do not assume that we have data from different domains but just different realizations of the same object under different interventions. This also differentiates this work from latent domain adaptation papers from the computer vision literature (Hoffman et al. 2012; Gong et al. 2013). Further related work is discussed in Sect. 6.

2 Setting

We introduce the assumed underlying causal graph and some notation before discussing notions of domain shift robustness.

2.1 Causal graph

Let \(Y\in {\mathcal {Y}}\) be a target of interest. Typically \({\mathcal {Y}}=\mathbb {R}\) for regression or \({\mathcal {Y}}=\{1,\ldots ,K\}\) in classification with K classes. Let \(X \in \mathbb {R}^p\) be predictor variables, for example the p pixels of an image. The causal structural model for all variables is shown in panel (b) of Fig. 3. The domain variable D is latent, in contrast to Gong et al. (2016) whose model is shown in panel (a) of Fig. 3. We add the \(\mathrm {ID}\) variable to the graph. In Fig. 3, \(Y \rightarrow \mathrm {ID}\) but in some settings it might be more plausible to consider \(\mathrm {ID}\rightarrow Y\). For the proposed method both options are possible. Together with Y, the \(\mathrm {ID}\) variable is used to group observations. It is typically discrete and relates to the identity of the underlying object. The variable can be assumed to be latent in the setting of Gong et al. (2016).

Fig. 3 Observed quantities are shown as shaded nodes; nodes of latent quantities are transparent. Left: data generating process for the considered model as in Gong et al. (2016), where the effect of the domain on the orthogonal features \(S\) is mediated via unobserved noise \(\varDelta\). The style interventions and all their descendants are shown as nodes with dashed borders to highlight variables that are affected by style interventions. Right: our setting. The domain itself is unobserved but we can now observe the (typically discrete) \(\mathrm {ID}\) variable we use for grouping. The arrow between \(\mathrm {ID}\) and Y can be reversed, depending on the sampling scheme

The rest of the graph is in analogy to Gong et al. (2016). The prediction is anti-causal, that is, the predictor variables X that we use for \({\hat{Y}}\) are non-ancestral to Y. In other words, the class label is here seen to be causal for the image and not the other way around. The causal effect from the class label Y on the image X is mediated via two types of latent variables: the so-called core or 'conditionally invariant' features \(C\) and the orthogonal or style features \(S\). The distinguishing factor between the two is that external interventions \(\varDelta\) are possible on the style features but not on the core features. If the interventions \(\varDelta\) have different distributions in different domains, then the conditional distributions \(C|Y=y,\mathrm {ID}=\mathrm {id}\) are invariant for all \((y,\mathrm {id})\) while \(S|Y=y,\mathrm {ID}=\mathrm {id}\) can change. The style variable can include point of view, image quality, resolution, rotations, color changes, body posture, movement etc. and will in general be context-dependent. The style intervention variable \(\varDelta\) influences the latent style \(S\) and hence also the image X. In potential outcome notation, we let \(S(\varDelta =\delta )\) be the style under intervention \(\varDelta =\delta\) and \(X(Y,\mathrm {ID},\varDelta =\delta )\) the image for class Y, identity \(\mathrm {ID}\) and style intervention \(\varDelta\). The latter is sometimes abbreviated as \(X(\varDelta =\delta )\) for notational simplicity. Finally, \(f_\theta (X(\varDelta =\delta ))\) is the prediction under the style intervention \(\varDelta =\delta\). For a formal justification of using a causal graph and potential outcome notation simultaneously see Richardson and Robins (2013).

To be specific, if not mentioned otherwise we will assume a causal graph as follows. For independent \(\varepsilon _Y, \varepsilon _\mathrm {ID},\varepsilon _{\text {style}}\) in \(\mathbb {R},\mathbb {R},\mathbb {R}^q\) respectively with positive density on their support and continuously differentiable functions \(k_y, k_\mathrm {id}\), and \(k_{\text {style}},k_{\text {core}},k_x\),

$$\begin{aligned}&Y \leftarrow k_y(D,\varepsilon _Y) \nonumber \\ \text {identifier }&\mathrm {ID}\leftarrow k_\mathrm {id}(Y,\varepsilon _\mathrm {ID}) \nonumber \\ \text {core or conditionally invariant features }&C\leftarrow k_{\text {core}}(Y,\mathrm {ID}) \nonumber \\ \text {style or orthogonal features }&S\leftarrow k_{\text {style}}(Y,\mathrm {ID},\varepsilon _{\text {style}}) + \varDelta \nonumber \\ \text {image }&X \leftarrow k_x(C, S) . \end{aligned}$$
(3)

The distribution of \(\varepsilon _{\text {style}}\) is assumed to be identical across domains, while \(\varDelta\) can change. In more generality, one could discard \(\varDelta\) and instead allow the distribution of \(\varepsilon _{\text {style}}\) to change. Here the assumption is slightly more restrictive because of additivity of \(\varDelta\) in the structural equation for the style variable. The core features are here assumed to be a deterministic function of Y and \(\mathrm {ID}\) to allow for theoretical analysis. In more generality (and as indicated in the graph), these would also be non-deterministic relations. The theoretical results will also require positive density for the style features in an \(\epsilon\)-ball around the origin, as made precise in assumption (A1) later.
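To make the structural model (3) concrete, the following is a minimal simulation sketch. The specific functional forms chosen for \(k_y, k_\mathrm {id}, k_{\text {core}}, k_{\text {style}}\) and \(k_x\), the dimensions and all constants are illustrative assumptions of ours and are not prescribed by the model.

```python
# Hedged toy instance of the structural model (3); every functional form and
# constant below is an illustrative assumption, not part of the general model.
import numpy as np

rng = np.random.default_rng(1)

def sample_model(n, delta_scale=0.0):
    """Draw n observations (x, y, id) from a toy version of model (3)."""
    d = rng.integers(0, 2, size=n)                           # latent domain D
    y = (rng.random(n) < 0.4 + 0.2 * d).astype(int)          # Y  <- k_y(D, eps_Y)
    ident = 10 * y + rng.integers(0, 10, size=n)             # ID <- k_id(Y, eps_ID)
    core = 2.0 * y + 0.1 * ident                             # C  <- k_core(Y, ID), deterministic
    delta = delta_scale * rng.standard_normal(n)             # style intervention Delta
    style = 0.5 * y + 0.05 * ident + 0.2 * rng.standard_normal(n) + delta
    x = np.stack([core + style, core - style, style], axis=1)  # X <- k_x(C, S)
    return x, y, ident

x_tr, y_tr, id_tr = sample_model(5000)                       # training domain (Delta = 0)
x_te, y_te, id_te = sample_model(5000, delta_scale=3.0)      # shifted style distribution
```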

The prediction \({\hat{y}}\) for y, given \(X=x\), is of the form \(f_\theta (x)\) for a suitable function \(f_\theta\) with parameters \(\theta \in \mathbb {R}^d\), where the parameters \(\theta\) correspond to the weights in a DNN, for example.

We would like to stress that the above model is fairly general and subsumes many simpler ones as special cases. To give a concrete example, consider the task of classifying a health condition Y from medical images X. Style features could be, for example, technical noise, orientation or resolution. The unobserved domain \(D\) could correspond to different hospitals or doctors. Due to the usage of different measuring devices in each of these locations, the conditional distribution \(S\vert Y\) will change substantially across different domains. In contrast, the core features \(C\), i.e. those image features that carry the actual signal, will remain invariant conditional on the underlying health condition Y.

2.2 Data

We assume we have \(n\) data points \((x_i,y_i,\mathrm {id}_i)\) for \(i=1,\ldots ,n\), where the observations \(\mathrm {id}_i\) with \(i=1,\ldots ,n\) of the variable \(\mathrm {ID}\) can also contain unobserved values. Let \(m\le n\) be the number of unique realizations of \((Y,\mathrm {ID})\) and let \(G_1,\ldots ,G_m\) be a partition of \(\{1,\ldots ,n\}\) such that, for each \(j\in \{1,\ldots ,m\}\), the realizations \((y_i,\mathrm {id}_i)\) are identical for all \(i\in G_j\). While our prime application is classification, regression settings with continuous Y can be approximated in this framework by slicing the range of the response variable into distinct bins, in analogy to the approach in sliced inverse regression (Li 1991). The cardinality of \(G_j\) is denoted by \(n_j:=|G_j| \ge 1\). Then \(n= \sum _{j=1}^m n_j\) is again the total number of samples and \(c= n-m\) is the total number of grouped observations in the following sense: if we count all samples in a group except the first one we have, summing over all groups, a total of \(c= n-m= \sum _{j=1}^m (n_j - 1)\) observations that are 'grouped' with the first example in their corresponding group.

Typically \(n_j=1\) for most groups and occasionally \(n_j \ge 2\), but one can also envisage scenarios with larger groups of observations sharing the same identifier \((y,\mathrm {id})\).
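As a small bookkeeping sketch, the partition \(G_1,\ldots ,G_m\) and the counts \(n\), \(m\) and \(c\) can be derived from the observed pairs \((y_i,\mathrm {id}_i)\) as follows; the function name is ours, and treating observations with a missing identifier as singleton groups is one possible convention.

```python
# Sketch of the grouping used by CoRe; observations with a missing id are
# treated as their own singleton group (one possible convention).
from collections import defaultdict

def build_groups(y, ids):
    """Return index groups G_1, ..., G_m with identical (y_i, id_i)."""
    groups = defaultdict(list)
    for i, (yi, idi) in enumerate(zip(y, ids)):
        key = (yi, idi) if idi is not None else ("singleton", i)
        groups[key].append(i)
    return list(groups.values())

y = [0, 0, 1, 1, 1, 0]
ids = ["a", "a", "b", "b", None, "c"]
G = build_groups(y, ids)                  # here: [[0, 1], [2, 3], [4], [5]]
n, m = len(y), len(G)                     # n = 6 samples, m = 4 unique (y, id)
c = n - m                                 # c = 2 observations are 'grouped'
sizes = [len(g) for g in G]               # the group sizes n_j
```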

2.3 Domain shift robustness

In this section, we clarify against which classes of distributions we hope to achieve robustness. Let \(\ell\) be a suitable loss that maps y and \({\hat{y}}=f_\theta (x)\) to \(\mathbb {R}^+\). The risk under distribution F and parameter \(\theta\) is given by

$$\begin{aligned} E_F\Big [ \ell ( Y, f_\theta (X)) \Big ] .\end{aligned}$$

Let \(F_0\) be the joint distribution of \((\mathrm {ID},Y,S)\) in the training distribution. A new domain and explicit interventions on the style features can now shift the distribution of \((\mathrm {ID},Y,\tilde{S})\) to F. We can measure the distance between distributions \(F_0\) and F in different ways. Below we will define the distance considered in this work and denote it by \(D_{\text {style}}(F,F_0)\). Once defined, we get a class of distributions

$$\begin{aligned} {\mathcal {F}}_\xi =\{F: D_{\text {style}}(F_0,F) \le \xi \} \end{aligned}$$
(4)

and the goal will be to optimize a worst-case loss over this distribution class in the sense of Eq. (1), where larger values of \(\xi\) afford protection against larger distributional changes. The relevant loss for distribution class \({\mathcal {F}}_\xi\) is then

$$\begin{aligned} L_\xi (\theta ) = \sup _{F\in {\mathcal {F}}_\xi } E_{F}\Big [ \ell \big (Y, f_\theta \big (X\big )\big )\Big ] . \end{aligned}$$
(5)

In the limit of arbitrarily strong interventions on the style features \(S\), the loss is given by

$$\begin{aligned} L_\infty (\theta ) = \lim _{\xi \rightarrow \infty } \sup _{F\in {\mathcal {F}}_\xi } E_{F}\Big [ \ell \big (Y, f_\theta \big (X\big )\big )\Big ] . \end{aligned}$$
(6)

Minimizing the loss \(L_\infty (\theta )\) with respect to \(\theta\) thus yields a prediction whose accuracy remains stable under arbitrarily large shifts in the conditional distribution of the style features.

A natural choice to define \(D_\text {style}\) is to use a Wasserstein-type distance (see e.g. Villani 2003). We will first define a distance \(D_{y,\mathrm {id}}\) for the conditional distributions

$$\begin{aligned} S|Y=y,\mathrm {ID}=\mathrm {id}\quad \text { and } \quad \tilde{S}| Y=y,\mathrm {ID}=\mathrm {id},\end{aligned}$$

and then set \(D_{\text {style}}(F_0,F) = E(D_{Y,\mathrm {ID}})\), where the expectation is with respect to the random \(\mathrm {ID}\) and labels Y. The distance \(D_{y,\mathrm {id}}\) between the two conditional distributions of \(S\) will be defined as a Wasserstein \(W_2^2(F_0,F)\)-distance for a suitable cost function \(c(x,\tilde{x})\). Specifically, let \(\Pi _{y,\mathrm {id}}\) be the set of couplings between the conditional distributions of \(S\) and \(\tilde{S}\), meaning measures supported on \(\mathbb {R}^q \times \mathbb {R}^q\) such that the marginal distribution over the first q components is equal to the distribution of \(S\) and the marginal distribution over the remaining q components is equal to the distribution of \(\tilde{S}\). Then the distance between the conditional distributions is defined as

$$\begin{aligned} D_{y,\mathrm {id}} = \min _{M\in \Pi _{y,\mathrm {id}}} \; E\big [ c(x,\tilde{x})\big ],\end{aligned}$$

where \(c :\mathbb {R}^q\times \mathbb {R}^q \rightarrow \mathbb {R}^+\) is a nonnegative, lower semi-continuous cost function. Here, we focus on a Mahalanobis distance as cost

$$\begin{aligned} c^2(x,\tilde{x} ) = (x-\tilde{x})^t \Sigma _{y,\mathrm {id}}^{-1} (x-\tilde{x}) .\end{aligned}$$

The cost of a shift is hence measured against the variability under the distribution \(F_0\), where \(\Sigma _{y,\mathrm {id}} =\text {Var}(S|Y,\mathrm {ID})\).
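As a small numerical illustration of this cost, the sketch below evaluates the Mahalanobis distance of a style shift against a conditional covariance \(\Sigma _{y,\mathrm {id}}\); the covariance matrix and the style vectors are made-up numbers.

```python
# Illustrative Mahalanobis cost of a style shift; all numbers are made up.
import numpy as np

def mahalanobis_cost(s, s_tilde, sigma):
    """c^2(s, s~) = (s - s~)^T Sigma^{-1} (s - s~)."""
    diff = s - s_tilde
    return float(diff @ np.linalg.solve(sigma, diff))

sigma = np.array([[1.0, 0.2],
                  [0.2, 0.5]])            # Var(S | Y, ID) under the training distribution
shift = mahalanobis_cost(np.zeros(2), np.array([1.0, 1.0]), sigma)
# A shift along a direction with large conditional variance is 'cheap',
# the same shift along a low-variance direction is 'expensive'.
```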

Clearly, since the core and style features are unobserved, we cannot directly optimize the loss (5) but need to infer the metric \(D_{\text {style}}\) from the input data X. In the next section, we will show how this can be achieved. Intuitively, the model in Fig. 3 implies that the variance conditional on \(Y,\mathrm {ID}\) stems from the difference in the style features. Hence, we would like to minimize the conditional variance of the prediction or loss when we condition on \(Y, \mathrm {ID}\). This enforces the desired invariance with respect to the style features.

3 Conditional variance regularization

3.1 Pooled estimator

Let \((x_i,y_i)\) for \(i=1,\ldots ,n\) be the observations that constitute the training data and \(\hat{y}_i=f_\theta (x_i)\) the prediction for \(y_i\). The standard approach is to simply pool over all available observations, ignoring any grouping information that might be available. The pooled estimator thus treats all examples identically by summing over the empirical loss as

$$\begin{aligned} {\hat{\theta }}^{pool} = \mathop { \text{ argmin }}\limits _\theta \; {\hat{E}}\Big [\ell (Y, f_\theta (X))\Big ] + \gamma \cdot \text {pen}(\theta ),\end{aligned}$$
(7)

where the first part is simply the empirical loss over the training data,

$$\begin{aligned} {\hat{E}}\Big [\ell (Y, f_\theta (X))\Big ] = \frac{1}{n} \sum _{i=1}^n\ell \big ( y_i , f_\theta (x_{i})\big ) .\end{aligned}$$

In the second part, \(\text {pen}(\theta )\) is a complexity penalty, for example a ridge penalty in the form of the squared \(\ell _2\)-norm of the weights \(\theta\) of a convolutional neural network.

3.2 CoRe estimator

The CoRe estimator is defined in Lagrangian form for penalty \(\lambda \ge 0\) as

$$\begin{aligned} {\hat{\theta }}^{core}(\lambda ) = \mathop { \text{ argmin }}\limits _\theta \; {\hat{E}}\Big [\ell (Y, f_\theta (X))\Big ] \; + \lambda \cdot {\hat{C}}_{\theta }.\end{aligned}$$
(8)

The penalty \({\hat{C}}_{\theta }\) is a conditional variance penalty of the form

$$\begin{aligned}&\text {conditional-variance-of-prediction:}\qquad {\hat{C}}_{f,\nu ,\theta }:={\hat{E}}\big [ \widehat{\text {Var}}( f_\theta (X)|Y,\mathrm {ID})^\nu \big ] \end{aligned}$$
(9)
$$\begin{aligned}&\text {conditional-variance-of-loss:}\qquad {\hat{C}}_{\ell ,\nu ,\theta }:={\hat{E}}\big [ \widehat{\text {Var}}( \ell (Y, f_\theta (X))|Y,\mathrm {ID}) ^\nu \big ] , \end{aligned}$$
(10)

where typically \(\nu \in \{1/2,1\}\). For \(\nu =1/2\), we also refer to the respective penalties as "conditional-standard-deviation" penalties. In practice, in the context of classification with DNNs, we apply the penalty (9) to the predicted logits. The conditional-variance-of-loss penalty (10) takes a form similar to the penalty of Namkoong and Duchi (2017). The crucial difference of our approach to Namkoong and Duchi (2017) is that we penalize with the expected conditional variance or standard deviation. The fact that we take a conditional variance is important here as we try to achieve distributional robustness with respect to interventions on the style variables. Conditioning on \(\mathrm {ID}\) allows us to guard specifically against these interventions. An unconditional variance penalty, in contrast, can achieve robustness against a pre-defined class of distributions such as a ball of distributions defined in a Kullback-Leibler or Wasserstein metric. The population CoRe estimator is defined as in Eq. (8) where empirical estimates are replaced by their respective population quantities.

Before showing numerical examples, we discuss the estimation of the expected conditional variance in Sect. 3.3 and return to the simple examples of Sect. 1.1 in Sect. 3.4. Domain shift robustness in a classification setting for a partially linear version of the structural equation model (3) is shown in Sect. 4.1. Furthermore, we discuss the population limit of \({\hat{\theta }}^{core}(\lambda )\) in Sect. 4.2, where we show that the regularization parameter \(\lambda \ge 0\) is proportional to the size of the future style interventions that we want to guard against for future test data.

3.3 Estimating the expected conditional variance

Recall that \(G_j\subseteq \{1,\ldots ,n\}\) contains samples with identical realizations of \((Y,\mathrm {ID})\) for \(j\in \{1,\ldots ,m\}\). For each \(j\in \{1,\ldots ,m\}\), define \({\hat{\mu }}_{\theta ,j}\) as the arithmetic mean across all \(f_\theta (x_{i}), i\in G_j\). The canonical estimator of the conditional variance \({\hat{C}}_{f,1,\theta }\) is then

$$\begin{aligned} {\hat{C}}_{f,1,\theta }&:= \frac{1}{m} \sum _{j =1}^m\frac{1}{|G_j|} \sum _{i\in G_j} (f_\theta (x_{i})-{\hat{\mu }}_{\theta ,j})^2,\quad \text{ where }\;\; {\hat{\mu }}_{\theta ,j} = \frac{1}{|G_j|} \sum _{i\in G_j} f_\theta (x_{i}) \end{aligned}$$

and analogously for the conditional-variance-of-loss penalty defined in Eq. (10). If there are no groups of samples that share the same identifier \((y,\mathrm {id})\), we define \({\hat{C}}_{f,1,\theta }\) to vanish; the CoRe estimator is then identical to the pooled estimator in this special case.
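The following NumPy sketch computes this canonical estimator from the predicted logits and the groups \(G_j\); it is an illustrative version, and the released TensorFlow implementation may differ in its details (e.g. how groups are handled within mini-batches). Passing per-example loss values instead of logits gives the conditional-variance-of-loss analogue, and \(\nu =1/2\) yields the conditional-standard-deviation penalties.

```python
# Illustrative NumPy version of the conditional variance penalty; the released
# TensorFlow implementation may differ in details such as mini-batch handling.
import numpy as np

def core_penalty(values, groups, nu=1.0):
    """Estimate E[ Var(values | Y, ID)^nu ].

    values: array of shape (n,), e.g. predicted logits f_theta(x_i)
            (or per-example losses for the conditional-variance-of-loss penalty)
    groups: list of index lists G_1, ..., G_m with identical (y_i, id_i)
    """
    if not groups:
        return 0.0                                   # no grouping information: penalty vanishes
    per_group = []
    for g in groups:
        v = values[np.asarray(g)]
        per_group.append(np.mean((v - v.mean()) ** 2) ** nu)
    return float(np.mean(per_group))

# CoRe objective (8), schematically: empirical loss plus lambda times the penalty,
# with `empirical_loss` and `lam` standing in for the model's loss and the weight.
# objective = empirical_loss + lam * core_penalty(logits, groups)
```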

3.4 Motivating examples (continued)

We revisit the first example from Sect. 1.1. Figure 4 shows subsamples of the training and test set with the estimated decision boundaries for different values of the penalty parameter \(\lambda\) when using a 2-layer fully connected neural network. Here, \(n=20{,}000\) and \(c=500\). Additionally, grouped examples that share the same \((y,\mathrm {id})\) are visualized: two grouped observations are connected by a curve. Ten such groups are shown. Panel (a) shows the decision boundaries for \(\lambda \in \{0, 0.05, 0.1, 1\}\), where \(\lambda =0\) is equivalent to the pooled estimator. The pooled estimator misclassifies a large number of test points as can be seen in panel (b), suffering from a test error of \(\approx 58\%\). In contrast, the decision boundary of the CoRe estimator with \(\lambda =1\) aligns with the direction along which the grouped observations vary, classifying the test set with almost perfect accuracy (test error is \(\approx 0\%\)).

Fig. 4 The decision boundary as a function of the penalty parameter \(\lambda\) for Example 1 from Fig. 1. Ten pairs of samples that share the same identifier \((y,\mathrm {id})\) are visualized; these are connected by a curve in the figures. The decision boundary drawn as a solid line corresponds to \(\lambda =0\), the standard pooled estimator that ignores the groupings. The broken lines are decision boundaries for increasingly strong penalties, taking the groupings in the data into account. Here, we only show a subsample of the data to avoid overplotting

4 Domain shift robustness for the CoRe estimator

We show two properties of the CoRe estimator. First, consistency is shown under the risk definition (6) for an infinitely large conditional variance penalty and the logistic loss in a partially linear structural equation model. Second, the population CoRe estimator is shown to achieve distributional robustness against shift interventions in a first order expansion.

4.1 Asymptotic domain shift robustness under strong interventions

We analyze the loss under strong domain shifts, as given in Eq. (6), for the pooled and the CoRe estimator in a one-layer network for binary classification (logistic regression) in an asymptotic setting of large sample size and strong interventions.

Assume the structural equation for the image \(X\in \mathbb {R}^p\) is linear in the style features \(S\in \mathbb {R}^q\) (with generally \(p\gg q\)) and we use logistic regression to predict the class label \(Y\in \{-1,1\}\). Let the interventions \(\varDelta \in \mathbb {R}^q\) act additively on the style features \(S\) (this is only for notational convenience) and let the style features \(S\) act in a linear way on the image X via a matrix \({W}\in \mathbb {R}^{p\times q}\) (this is an important assumption without which results are more involved). The core or ‘conditionally invariant’ features are \(C\in \mathbb {R}^r\), where in general \(r\le p\) but this is not important for the following. For independent \(\varepsilon _Y, \varepsilon _\mathrm {ID},\varepsilon _{\text {style}}\) in \(\mathbb {R},\mathbb {R},\mathbb {R}^q\) respectively with positive density on their support and continuously differentiable functions \(k_y, k_\mathrm {id},k_{\text {style}},k_{\text {core}},k_x\),

$$\begin{aligned} \text {class }&Y \leftarrow k_y(D,\varepsilon _Y) \nonumber \\ \text {identifier }&\mathrm {ID}\leftarrow k_\mathrm {id}(Y,\varepsilon _\mathrm {ID}) \nonumber \\ \text {core or conditionally invariant features }&C\leftarrow k_{\text {core}}(Y,\mathrm {ID}) \nonumber \\ \text {style or orthogonal features }&S\leftarrow k_{\text {style}}(Y,\mathrm {ID},\varepsilon _{\text {style}}) + \varDelta \nonumber \\ \text {image }&X \leftarrow k_x(C) + {W}S. \end{aligned}$$
(11)

We assume a logistic regression as a prediction of Y from the image data X:

$$\begin{aligned} f_\theta (x) := \frac{ \exp ( x^t \theta )}{1+\exp (x^t \theta )} .\end{aligned}$$

Given training data with \(n\) samples, we estimate \(\theta\) with \({\hat{\theta }}\) and use here a logistic loss \(\ell _{\theta }(y_i,x_i) = \log (1+ \exp ( - y_i ( x_i^t \theta )))\).
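The following sketch illustrates the setting of Theorem 1 on simulated data: a pooled logistic regression picks up the style direction \(W\) and its error grows under strong style interventions, while a classifier restricted to the core coordinate is unaffected. All functional forms, dimensions and constants are our own illustrative assumptions, and the random-noise interventions merely stand in for the arbitrary shifts considered in the theorem.

```python
# Hedged illustration of the setting of Theorem 1; all constants are assumptions.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(2)

def sample(n, delta_scale=0.0):
    y = rng.integers(0, 2, size=n) * 2 - 1            # Y in {-1, 1}
    core = 1.0 * y + 0.5 * rng.standard_normal(n)     # C  <- k_core(Y, ID)
    style = 2.0 * y + 0.5 * rng.standard_normal(n) + delta_scale * rng.standard_normal(n)
    x = np.stack([core, style], axis=1)               # X = k_x(C) + W S with W = (0, 1)^T
    return x, y

x_tr, y_tr = sample(5000)
pooled = LogisticRegression().fit(x_tr, y_tr)         # uses both coordinates
core_only = LogisticRegression().fit(x_tr[:, :1], y_tr)

for delta in [0.0, 5.0, 50.0]:                        # increasing intervention strength
    x_te, y_te = sample(5000, delta_scale=delta)
    print(delta,
          round(1 - pooled.score(x_te, y_te), 3),     # error grows with the style shift
          round(1 - core_only.score(x_te[:, :1], y_te), 3))  # stays roughly constant
```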

The formulation of Theorem 1 relies on the following assumptions.

Assumption 1

We require the following conditions:

  (A1) Assume the conditional distribution \(S|Y=y,\mathrm {ID}=\mathrm {id}\) under the training distribution \(F_0\) has positive density (with respect to the Lebesgue measure) in an \(\epsilon\)-ball in \(\ell _2\)-norm around the origin for some \(\epsilon >0\) for all \(y\in {\mathcal {Y}}\) and \(\mathrm {id}\in {\mathcal {I}}\).

  (A2) Assume the matrix \({W}\) has full rank q.

  (A3) Let \(M\le n\) be the number of unique realizations among the n iid samples of \((Y,\mathrm {ID})\) and let \(p_n:=P(M\le n-q)\). Assume that \(p_n\rightarrow 1\) for \(n\rightarrow \infty\).

Assumption (A1) is a key assumption about the style variations we observe in the training set. It requires that we observe some variance in those directions that we expect to be subject to domain shifts in the future. If, on the other hand, the conditional variance in a particular direction is vanishing, we also expect it to vanish in the future. A violation of this assumption would imply that the guarantee of the CoRe regularization no longer holds. Assumption (A3) guarantees that the number \(c=n-m\) of grouped examples is at least as large as the dimension of the style variables. If we have too few or no grouped examples (small c), we cannot estimate the conditional variance accurately. Under these assumptions we can prove domain shift robustness.

Theorem 1

(Asymptotic domain shift robustness under strong interventions) Under model (11) and Assumption 1, with probability 1, the pooled estimator (7) has infinite loss (6) under arbitrarily large shifts in the distribution of the style features,

$$\begin{aligned} L_\infty ({\hat{\theta }}^{pool}) \;=\; \infty .\end{aligned}$$

The CoRe estimator (8) \({\hat{\theta }}^{core}\) with \(\lambda \rightarrow \infty\) is domain shift robust under strong interventions in the sense that for \(n\rightarrow \infty\),

$$\begin{aligned} L_\infty ({\hat{\theta }}^{core}) \; \rightarrow _p \; \inf _\theta L_\infty (\theta ) .\end{aligned}$$

A proof is given in “Appendix A”. The respective ridge penalties in both estimators (7) and (8) are assumed to be zero for the proof, but the proof can easily be generalized to include ridge penalties that vanish sufficiently fast for large sample sizes. The Lagrangian regularizer \(\lambda\) is assumed to be infinite for the CoRe estimator to achieve domain shift robustness under these strong interventions. The next section considers the population CoRe estimator in a setting with weak interventions and finite values of the penalty parameter.

4.2 Population domain shift robustness under weak interventions

The previous theorem states that the CoRe estimator can achieve domain shift robustness under strong interventions for an infinitely strong penalty in an asymptotic setting. An open question is how the loss (5),

$$\begin{aligned} L_\xi (\theta ) = \sup _{F\in {\mathcal {F}}_\xi } E_{F}\Big [ \ell \big (Y,f_\theta (X)\big )\Big ] \end{aligned}$$

behaves under interventions of small to medium size and correspondingly smaller values of the penalty. Here, we aim to minimize this loss for a given value of \(\xi\) and show that domain shift robustness can be achieved to first order with the population CoRe estimator using the conditional-standard-deviation-of-loss penalty, i.e., Eq. (10) with \(\nu = 1/2\), by choosing an appropriate value of the penalty weight \(\lambda\). Below we will show that this appropriate choice is \(\lambda =\sqrt{\xi }\).

Assumption 2

We require the following conditions:

  (B1) Define the loss under a deterministic shift \(\delta\) as

    $$\begin{aligned} h_\theta (\delta ) := E_{F_\delta }[\ell (Y,f_\theta (X))],\end{aligned}$$

    where the expectation is with respect to random \((\mathrm {ID},Y,\tilde{S})\sim F_\delta\), with \(F_\delta\) defined by the deterministic shift intervention \(\tilde{S}=S+\delta\) and \((\mathrm {ID},Y,S)\sim F_0\). Assume that for all \(\theta \in \Theta\), \(h_\theta (\delta )\) is twice continuously differentiable with bounded second derivative for a deterministic shift \(\delta \in \mathbb {R}^q\).

  (B2) The spectral norm of the conditional variance \(\Sigma _{y,\mathrm {id}}\) of \(S|Y,\mathrm {ID}\) under \(F_0\) is assumed to be smaller than or equal to some \(\zeta \in \mathbb {R}\) for all \(y\in {\mathcal {Y}}\) and \(\mathrm {id}\in {\mathcal {I}}\).

The first assumption (B1) ensures that the loss is well behaved under interventions on the style variables. The second assumption (B2) allows us to take the limit of small conditional variances in the style variables.

Setting \(\lambda =\sqrt{\xi }\) and using the conditional-standard-deviation-of-loss penalty, the CoRe estimator optimizes according to

$$\begin{aligned} {\hat{\theta }}^{core}(\sqrt{\xi }) = \mathop { \text{ argmin }}\limits _\theta \; {\hat{E}}_{F_0}\big [\ell (Y, f_\theta (X))\big ] \; + \sqrt{\xi } \cdot {\hat{C}}_{\ell ,1/2,\theta }.\end{aligned}$$

The next theorem shows that this is to first order equivalent to minimizing the worst-case loss over the distribution class \({\mathcal {F}}_\xi\). The following result holds for the population CoRe estimator, see below for a discussion about consistency.

Theorem 2

The supremum of the loss over the class of distributions \({\mathcal {F}}_\xi\) is to first order given by the expected loss under the distribution \(F_0\) plus an additional conditional-standard-deviation-of-loss penalty \({C}_{\ell ,1/2,\theta }\):

$$\begin{aligned} \sup _{F\in {\mathcal {F}}_\xi } E_{F}\big [ \ell \big (Y,f_\theta (X)\big )\big ] =E_{F_0}\big [ \ell \big ( Y ,f_\theta (X)\big )\big ] + \sqrt{\xi }\cdot {C}_{\ell ,1/2,\theta }+O(\max \{\xi ,\zeta \}) .\end{aligned}$$
(12)

A proof is given in "Appendix B". The objective of the population CoRe estimator thus matches to first order the loss under domain shifts if we set the penalty weight \(\lambda =\sqrt{\xi }\). Larger anticipated domain shifts thus naturally require a larger penalty \(\lambda\) in the CoRe estimation. The result is possible as we have chosen the Mahalanobis distance to measure shifts in the style variables and to define \({\mathcal {F}}_\xi\), ensuring that the strength of shifts in the style variables is measured against their natural variance under the training distribution \(F_0\).

In practice, the choice of \(\lambda\) involves a somewhat subjective decision about the strength of the distributional robustness guarantee: a stronger distributional robustness property is traded off against a loss in predictive accuracy if the distribution does not change in the future. One option for choosing \(\lambda\) is to take the largest penalty weight before the validation loss increases considerably. This approach provides the best distributional robustness guarantee that keeps the loss of predictive accuracy in the training distribution within a pre-specified bound.
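A minimal sketch of this selection heuristic follows; `train_and_validate` is a hypothetical helper that trains a model with a given CoRe penalty weight and returns its validation loss, and the candidate grid and tolerance are arbitrary choices.

```python
# Sketch of the penalty-weight heuristic described above; `train_and_validate`
# is a hypothetical helper (train with CoRe weight lam, return validation loss).
def choose_core_weight(train_and_validate, lambdas, tolerance=0.05):
    baseline = train_and_validate(0.0)                  # pooled estimator (lambda = 0)
    best = 0.0
    for lam in sorted(lambdas):
        if train_and_validate(lam) <= baseline * (1.0 + tolerance):
            best = lam                                  # still within the accepted increase
    return best

# Example grid: a larger selected weight corresponds to a stronger robustness
# guarantee at a bounded cost in in-distribution accuracy.
# lam_star = choose_core_weight(train_and_validate, [0.01, 0.1, 1.0, 10.0])
```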

As a caveat, the result takes the limit of small conditional variance of \(S\) in the training distribution and of small additional interventions. Under larger interventions, higher-order terms could start to dominate, depending on the geometry of the loss function and \(f_\theta\). A further caveat is that the result concerns the population CoRe estimator. For finite sample sizes, we would optimize a noisy version of the rhs of (12). To show domain shift robustness in an asymptotic sense, we would need additional uniform convergence (in \(\theta\)) of both the empirical loss and the conditional variance, in the sense that for \(n\rightarrow \infty\),

$$\begin{aligned} \sup _\theta | {\hat{E}}_{F_0}\big [\ell (Y, f_\theta (X))\big ] - E_{F_0}\big [ \ell \big ( Y ,f_\theta (X)\big )\big ] |&\rightarrow _p 0, \quad \text {and} \\ \sup _\theta | {\hat{C}}_{\ell ,1/2,\theta } -C_{\ell ,1/2,\theta } |&\rightarrow _p 0. \end{aligned}$$

While this is in general a reasonable assumption to make, the validity of the assumption will depend on the specific function class and on the chosen estimator of the conditional variance.

5 Experiments

We perform an array of different experiments, showing the applicability and advantage of the conditional variance penalty for two broad settings:

  1. Settings where we do not know what the style variables correspond to but still want to protect against a change in their distribution in the future. In the examples we show cases where the style variable corresponds to fashion (Sect. 5.2), image quality (Sect. 5.3), movement (Sect. 5.4) and brightness ("Appendix D.1"), all of which are not known explicitly to the method. We also include genuinely unknown style variables in Sect. 5.1 (in the sense that they are unknown not only to the methods but also to us as we did not explicitly create the style interventions).

  2. Settings where we do know what type of style interventions we would like to protect against. This is usually dealt with by data augmentation (adding images which are, say, rotated or shifted compared to the training data if we want to protect against rotations or translations in the test data; see for example Schölkopf et al. (1996)). The conditional variance penalty here exploits the fact that some augmented samples were generated from the same original sample, and we use as \(\mathrm {ID}\) variable the index of the original image. We show that this approach generalizes better than simply pooling the augmented data, in the sense that we need fewer augmented samples to achieve the same test error. This setting is shown in Sect. 5.5.

We compare against the pooled estimator, which has the same architecture as the network to which we add the CoRe penalty. For both the pooled and the CoRe estimator we apply an \(\ell _2\) penalty as regularization. We would like to stress that the related work discussed in Sects. 1.2 and 6 cannot be directly compared to the CoRe estimator as these approaches cannot exploit the \(\mathrm {ID}\) information but rely on having data from different domains available at training time instead. Since this is a different problem setting, we can only compare against the pooled estimator, which is a standard approach to classification. As a downside, our approach requires an \(\mathrm {ID}\) variable, which might not always be available. To further understand the behavior of the CoRe penalty, we perform a number of analyses and ablation studies to show

  (i) how sensitive the performance of CoRe is to the value of the penalty weight \(\lambda\) (Sects. 5.1.1, 5.2);

  (ii) how the CoRe penalty differs from a standard \(\ell _2\) penalty (Sect. 5.1.1);

  (iii) how the value of the CoRe penalty can be used as a qualitative measure for the presence of sample bias (Sects. 5.1.1, 5.2);

  (iv) how sensitive the performance of both the CoRe and the pooled estimator is to label shift in the grouped observations (Sect. 5.2.1);

  (v) how the relative performance of both estimators is affected when using pre-trained Inception V3 features (Sect. 5.2.2);

  (vi) how sensitive the performance is to different grouping strategies (Sect. 5.3, "Appendices D.1, D.3.1, D.4");

  (vii) how sensitive the performance is as a function of the strength of the domain shift and the number of grouped observations ("Appendices D.1, D.4, D.5, D.6").

Details of the network architectures can be found in "Appendix C". All reported error rates are averaged over five runs of the respective method. A TensorFlow (Abadi et al. 2015) implementation of CoRe can be found at https://github.com/christinaheinze/core.

5.1 Eyeglasses detection with small sample size

In this example, we explore a setting where training and test data are drawn from the same distribution, so we might not expect a distributional shift between the two. However, we consider a small training sample size which gives rise to statistical fluctuations between training and test data. We assess to what extent the conditional variance penalty can help to improve test accuracies in this setting.

Specifically, we use a subsample of the CelebA dataset (Liu et al. 2015) and try to classify images according to whether or not the person in the image wears glasses. For the construction of the \(\mathrm {ID}\) variable, we exploit the fact that several photos of the same person are available and set \(\mathrm {ID}\) to be the identifier of the person in the dataset. Figure 5 shows examples from both the training and the test dataset. The conditional variance penalty is estimated across groups of observations that share a common \((Y,\mathrm {ID})\). Here, this corresponds to pictures of the same person where all pictures show the person either with glasses (if \(Y=1\)) or without glasses (if \(Y=0\)). Statistical fluctuations between training and test set could for instance arise if, by chance, the background of eyeglass wearers is darker in the training sample than in the test samples, or if eyeglass wearers happen to be outdoors more often or are more often female than male, etc.

Fig. 5 Eyeglass detection for CelebA dataset with small sample size. The goal is to predict whether a person wears glasses or not. Random samples from training and test data are shown. Groups of observations in the training data that have common \((Y,\mathrm {ID})\) here correspond to pictures of the same person with either glasses on or off. These are labelled by red boxes in the training data and the conditional variance penalty is calculated across these groups of pictures

Below, we present the following analyses. First, we look at five different datasets and analyze the effect of adding the CoRe penalty (using conditional-variance-of-prediction) to the cross-entropy loss. Second, we focus on one dataset and compare the four different variants of the CoRe penalty in Eqs. (9) and (10) with \(\nu \in \{ 1/2, 1 \}.\)

5.1.1 CoRe penalty using the conditional variance of the predicted logits

We consider five different training sets which are created as follows. For each person in the standard CelebA training data, we count the number of available images and select the 50 identities for which the most images are available. We partition these 50 identities into 5 disjoint subsets of size 10 and consider the resulting 5 datasets, containing the images of 10 unique identities each. The resulting 5 datasets have sizes \(\{289, 296, 292, 287, 287\}\). For the validation and the test set, we consider the usual CelebA validation and test split but balance these with respect to the target variable "Eyeglasses". The balanced validation set consists of 2766 observations; the balanced test set contains 2578 images. The identities in the validation and test sets are disjoint from the identities in the training sets.

Given a training dataset, the standard approach would be to pool all examples. The only additional information we exploit is that some observations can be grouped. If using a 5-layer convolutional neural network with a standard ridge penalty (details can be found in Table 5) and pooling all data, the test error on unseen images ranges from 18.08 to 25.97%. Exploiting the group structure with the CoRe penalty (in addition to a ridge penalty) results in test errors ranging from 14.79 to 21.49%, see Table 1. The relative improvements when using the CoRe penalty range from 9 to 28.6%.

Table 1 Eyeglass detection, trained on small subsets (DS1–DS5) of the CelebA dataset with disjoint identities

The test error is not very sensitive to the weight of the CoRe penalty as shown in Fig. 6a: for a large range of penalty weights, adding the CoRe penalty decreases the test error compared to the pooled estimator (identical to a CoRe penalty weight of 0). This holds true for various ridge penalty weights.

While the test error rates shown in Fig. 6 already suggest that the CoRe penalty differentiates itself clearly from a standard ridge penalty, we next examine the differential effect of the CoRe penalty on the between- and within-group variances. Concretely, the variance of the predictions can be decomposed as

$$\begin{aligned} \text {Var}(f_\theta (X)) = {E}\big [ {\text {Var}}(f_\theta (X)|Y,\mathrm {ID}) \big ] + {\text {Var}}\big [ {E}(f_\theta (X)|Y,\mathrm {ID}) \big ] ,\end{aligned}$$

where the first term on the rhs is the within-group variance that CoRe penalizes, while a ridge penalty would penalize both the within- and also the between-group variance (the second term on the rhs above). In Fig. 6b we show the ratio between the CoRe penalty and the between-group variance where groups are defined by conditioning on \((Y,\mathrm {ID})\). Specifically, the ratio is computed as

$$\begin{aligned} {\hat{E}}\big [ \widehat{\text {Var}}(f_\theta (X)|Y,\mathrm {ID}) \big ] / \widehat{\text {Var}}\big [ {\hat{E}}(f_\theta (X)|Y,\mathrm {ID}) \big ]. \end{aligned}$$
(13)

The results shown in Fig. 6b are computed on dataset 1 (DS1). While increasing ridge penalty weights do lead to a smaller value of the CoRe penalty, the between-group variance is also reduced, such that the ratio between the two terms does not decrease with larger weights of the ridge penalty. With increasing weight of the CoRe penalty, the variance ratio decreases, showing that the CoRe penalty indeed penalizes the within-group variance more than the between-group variance.
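A small sketch of ratio (13) follows, computed from predicted logits and the groups defined by conditioning on \((Y,\mathrm {ID})\) as in Sect. 3.3; function and variable names are ours.

```python
# Sketch of the variance ratio (13): average within-group variance of the
# predictions divided by the variance of the group means.
import numpy as np

def variance_ratio(logits, groups):
    within = np.mean([np.var(logits[np.asarray(g)]) for g in groups])   # E[ Var(f | Y, ID) ]
    between = np.var([np.mean(logits[np.asarray(g)]) for g in groups])  # Var[ E(f | Y, ID) ]
    return within / between

# A large CoRe penalty weight shrinks the numerator selectively, while a ridge
# penalty tends to shrink numerator and denominator together (cf. Fig. 6b).
```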

Table 1 also reports the value of the CoRe penalty after training when evaluated for the pooled and the CoRe estimator on the training and the test set. As a qualitative measure to assess the presence of sample bias in the data (provided the model assumptions hold), we can compare the value the CoRe penalty takes after training when evaluated for the pooled estimator and for the CoRe estimator. The difference yields a measure of the extent to which the respective estimators are functions of \(\varDelta\). If the respective hold-out values are both small, this would indicate that the style features are not very predictive for the target variable. If, on the other hand, the CoRe penalty evaluated for the pooled estimator takes a much larger value than for the CoRe estimator (as in this case), this would indicate the presence of sample bias.

Fig. 6 Eyeglass detection, trained on a small subset (DS1) of the CelebA dataset with disjoint identities. a Average test error as a function of both the CoRe penalty weight on the x-axis and various levels of the ridge penalty. The results can be seen to be fairly insensitive to the ridge penalty. b The variance ratio (13) on test data as a function of both the CoRe and ridge penalty weights. The CoRe penalty can be seen to penalize the within-group variance selectively, whereas a strong ridge penalty decreases both the within- and between-group variance

5.1.2 Other CoRe penalty types

We now compare all CoRe penalty types, i.e., penalizing with (i) the conditional variance of the predicted logits \({\hat{C}}_{f,1,\theta }\), (ii) the conditional standard deviation of the predicted logits \({\hat{C}}_{f,1/2,\theta }\), (iii) the conditional variance of the loss \({\hat{C}}_{\ell ,1,\theta }\) and (iv) the conditional standard deviation of the loss \({\hat{C}}_{\ell ,1/2,\theta }\). For this comparison, we use training dataset 1 (DS1) from above. Table 2 contains the test error (training error was \(0\%\) for all methods) as well as the value the respective CoRe penalty took after training, evaluated on the training set and the test set. The performance differences between the four CoRe penalty variants are not statistically significant. Hence, we mostly focus on the conditional variance of the predicted logits \({\hat{C}}_{f,1,\theta }\) in the other experiments.

Table 2 Eyeglass detection, trained on a small subset (DS1) of the CelebA dataset with disjoint identities

5.1.3 Discussion

While the distributional shift in this example arises due to statistical fluctuations which will diminish as the sample size grows, the following examples are more concerned with biases that will persist even if the number of training and test samples is very large. A second difference to the subsequent examples is the grouping structure: in this example, we consider only a few identities, namely \(m=10\), with a relatively large number \(n_j\) of associated observations (about thirty observations per individual). In the following examples, \(m\) is much larger while \(n_j\) is typically smaller than five.

5.2 Gender classification with unknown confounding

In the following set of experiments, we work again with the CelebA dataset and the 5-layer convolutional neural network architecture described in Table 5. This time we consider the problem of classifying whether the person shown in the image is male or female. We create a confounding in the training set and test set 1 by including mostly images of men wearing glasses and women not wearing glasses. In test set 2 the association between gender and glasses is flipped: women always wear glasses while men never wear glasses. Examples from the training set and test sets 1 and 2 are shown in Fig. 7. The training set and test sets 1 and 2 are subsampled such that they are balanced with respect to Y, resulting in 16,982, 4224 and 1120 observations, respectively.

Fig. 7

Classification for \(Y \in \{\text {woman},\text {man}\}\). There is an unknown confounding here as men are very likely to wear glasses in training and test set 1 data, while it is women that are likely to wear glasses in test set 2. Estimators that pool all observations are making use of this confounding and hence fail for test set 2. The conditional variance penalty for the CoRe estimator is computed over groups of images of the same person (and consequently same class label), such as the images in the red box on the left. The number of grouped examples \(c\) is 500. We vary the proportion of males in the grouped examples between 50 and 100% (cf. Sect. 5.2.1)

To compute the conditional variance penalty, we again use images of the same person. The \(\mathrm {ID}\) variable is, in other words, the identity of the person, and gender Y is constant across all examples with the same \(\mathrm {ID}\); conditioning on \((Y,\mathrm {ID})\) is hence identical to conditioning on \(\mathrm {ID}\) alone. Another difference from the other experiments is that we consider a binary style feature here.
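As an illustration of how this grouping could be encoded, the following sketch (hypothetical helper, not taken from the paper's code) maps each observation to a group index; since Y is constant within an identity, the \((Y,\mathrm {ID})\) key reduces to the identity alone, and observations without an ID annotation fall into singleton groups that do not contribute to the penalty.

    import numpy as np

    def make_group_ids(person_ids, labels):
        """Integer group index per observation; observations without an
        identity annotation (person_id is None) become singleton groups."""
        keys = np.array([f"{pid}_{y}" if pid is not None else f"single_{i}"
                         for i, (pid, y) in enumerate(zip(person_ids, labels))])
        _, group_ids = np.unique(keys, return_inverse=True)
        return group_ids

    # Example: three photos of one woman (label 0) and one ungrouped photo.
    gids = make_group_ids(["id_017", "id_017", "id_017", None], [0, 0, 0, 1])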

5.2.1 Label shift in grouped observations

We compare six different datasets that vary with respect to the distribution of Y in the grouped observations. In all training datasets, the total number of observations is 16,982 and the total number of grouped observations is 500. In the first dataset, 50% of the grouped observations correspond to males and 50% correspond to females. In the remaining five datasets, we increase the proportion of grouped observations with \(Y=\text {``man''}\), denoted by \(\kappa\), to 75%, 90%, 95%, 99% and 100%, respectively. Table 3 shows the performance obtained for these datasets when using the pooled estimator compared to the CoRe estimator with \({\hat{C}}_{f,1,\theta }\). The results show that both the pooled and the CoRe estimator perform better if the distribution of Y in the grouped observations is more balanced. The CoRe estimator improves the error rate of the pooled estimator by roughly 28–39% on a relative scale. Figure 8 shows the performance for \(\kappa =50\%\) as a function of the CoRe penalty weight. Significant improvements can be obtained across a large range of values for the CoRe penalty and the ridge penalty. Test errors become more sensitive to the chosen value of the CoRe penalty for very large values of the ridge penalty weight, as the overall amount of regularization is then already large.
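A rough sketch of how the grouped subset could be drawn for a given \(\kappa\) is given below; the helper name and the index arrays are assumptions made for illustration only.

    import numpy as np

    rng = np.random.default_rng(0)

    def sample_grouped_indices(male_idx, female_idx, c=500, kappa=0.5):
        """Choose the c observations that receive ID (grouping) information,
        with a fraction kappa of them drawn from male identities."""
        n_male = int(round(kappa * c))
        chosen_male = rng.choice(male_idx, size=n_male, replace=False)
        chosen_female = rng.choice(female_idx, size=c - n_male, replace=False)
        return np.concatenate([chosen_male, chosen_female])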

Fig. 8

Classification for \(Y \in \{\text {woman},\text {man}\}\) with \(\kappa =0.5\). Panels a and b show the test error on test data sets 1 and 2, respectively, as a function of the CoRe and ridge penalty. Panels c and d show the variance ratio (13) (comparing within- and between-group variances) for females and males separately

Table 3 Classification for \(Y \in \{\text {woman},\text {man}\}\)

5.2.2 Using pre-trained Inception V3 features

To verify that the above conclusions do not change when using more powerful features, we compare here \(\ell _2\)-regularized logistic regression using pre-trained Inception V3 featuresFootnote 11 with and without the CoRe penalty. Table 4 shows the results for \(\kappa =0.5\). While both the pooled and the CoRe estimator perform better using pre-trained Inception features, the relative improvement from the CoRe penalty on test set 2 is still 28%.
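For concreteness, a hedged sketch of the penalized objective on fixed Inception V3 features is shown below; the function and variable names are assumptions, and core_penalty refers to the sketch in Sect. 5.1.2, not to the released code.

    import numpy as np

    def objective(theta, b, X, y, group_ids, lam_ridge, lam_core):
        """l2-regularized logistic regression on precomputed features X with
        the CoRe term; y takes values in {0, 1}. Reuses core_penalty() from
        the sketch in Sect. 5.1.2 (an assumption, not the paper's code)."""
        logits = X @ theta + b
        log_loss = np.mean(np.log1p(np.exp(-(2 * y - 1) * logits)))
        ridge = lam_ridge * np.sum(theta ** 2)
        core = lam_core * core_penalty(logits, group_ids, power=1.0)
        return log_loss + ridge + core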

Table 4 Classification for \(Y \in \{\text {woman},\text {man}\}\) with \(\kappa =0.5\). Here, we compare \(\ell _2\)-regularized logistic regression based on Inception V3 features with and without the CoRe penalty

5.2.3 Ablation experiments

In Sect. D.3.1, we report results for the following two additional baselines: (i) we group all examples sharing the same class label and penalize with the conditional variance of the predicted logits, computed over these two groups; (ii) we penalize the overall variance of the predicted logits, i.e., a form of unconditional variance regularization.
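For clarity, the two baselines can be written in terms of the core_penalty sketch from Sect. 5.1.2 (again an illustrative assumption rather than the code used for Sect. D.3.1):

    import numpy as np

    def ablation_penalties(logits, labels):
        """(i) conditional variance of the logits given only the class label,
        (ii) unconditional variance of the logits over all observations."""
        by_label = core_penalty(np.asarray(logits), np.asarray(labels), power=1.0)
        unconditional = float(np.var(np.asarray(logits), ddof=1))
        return by_label, unconditional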

5.3 Eyeglasses detection with known and unknown image quality intervention

We now revisit the second example from Sect. 1.1. We again use the CelebA dataset and consider the problem of classifying whether the person in the image is wearing eyeglasses. Here, we modify the images in the following way: in the training set and in test set 1, we sample the image qualityFootnote 12 for all samples \(\{i:y_i =1\}\) (all samples that show glasses) from a Gaussian distribution with mean \(\mu =30\) and standard deviation \(\sigma =10\). Samples with \(y_i=0\) (no glasses) are unmodified. In other words, if the image shows a person wearing glasses, the image quality tends to be lower. In test set 2, the quality is reduced in the same way for \(y_i =0\) samples (no glasses), while images with \(y_i =1\) are not changed. Figure 9 shows examples from the training set and test sets 1 and 2. For the CoRe penalty, we calculate the conditional variance across images that share the same \(\mathrm {ID}\) if \(Y=1\), that is, across images showing the same person wearing glasses in all of them. Observations with \(Y=0\) (not wearing glasses) are not grouped. Two examples are shown in the red box of Fig. 9. Here, we have \(c=5000\) grouped observations among a total sample size of \(n=20{,}000\).
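Based on the description above and the JPEG compression mentioned in the caption of Fig. 9, the intervention could be implemented roughly as follows; the exact quality mapping of Footnote 12 is an assumption here.

    import io
    import numpy as np
    from PIL import Image

    rng = np.random.default_rng(0)

    def degrade_quality(img):
        """Re-encode a PIL image as JPEG with a quality level drawn from a
        Gaussian with mean 30 and standard deviation 10 (clipped to PIL's
        valid range), then decode it again."""
        quality = int(np.clip(rng.normal(loc=30, scale=10), 1, 95))
        buf = io.BytesIO()
        img.convert("RGB").save(buf, format="JPEG", quality=quality)
        buf.seek(0)
        return Image.open(buf).convert("RGB")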

Fig. 9

Eyeglass detection for CelebA dataset with image quality interventions (which are unknown to any procedure used). The JPEG compression level is lowered for \(Y=1\) (glasses) samples on training data and test set 1 and lowered for \(Y=0\) (no glasses) samples for test set 2. To the human eye, these interventions are barely visible but the CNN that uses pooled data without CoRe penalty has exploited the correlation between image quality and outcome Y to achieve an (arguably spurious) low test error of 2% on test set 1. However, if the correlation between image quality and Y breaks down, as in test set 2, the CNN that uses pooled data without a CoRe penalty has a 65% misclassification rate. The training data on the left show paired observations in two red boxes: these observations share the same label Y and show the same person \(\mathrm {ID}\). They are used to compute the conditional variance penalty for the CoRe estimator, which does not suffer from the same degradation in performance on test set 2

Figure 9 shows misclassification rates for CoRe and the pooled estimator on test sets 1 and 2. The pooled estimator (penalized only with an \(\ell _2\) penalty) achieves a low error rate of 2% on test set 1, but suffers from a 65% misclassification error on test set 2, as the relation between Y and the implicit style variable \(S\) (image quality) has now been flipped. The CoRe estimator has a larger error of 13% on test set 1, as CoRe implicitly penalizes the use of image quality as a feature and the signal is weaker once image quality has been removed as a dimension. On test set 2, however, the CoRe estimator achieves an error of 28% and thus improves substantially on the 65% error of the pooled estimator. The reason is again the same: the CoRe penalty ensures that image quality is not used as a feature to the same extent as for the pooled estimator. This increases the test error slightly if the samples are generated from the same distribution as the training data (as for test set 1) but substantially improves the test error if the distribution of image quality, conditional on the class label, changes in the test data (as for test set 2).

Eyeglasses detection with known image quality intervention

To compare with the above results, we repeat the experiment with a different choice of grouped observations. Above, we grouped images that had the same person \(\mathrm {ID}\) when \(Y=1\); we refer to this scheme of grouping observations with the same \((Y, \mathrm {ID})\) as ‘Grouping setting 2’. Here, we instead use an explicit augmentation scheme and augment \(c=5000\) images with \(Y=1\) in the following way: each image is paired with a copy of itself whose image quality is adjusted as described above. In other words, the two images differ only in their image quality, where the quality level drawn from the Gaussian distribution with mean \(\mu =30\) and standard deviation \(\sigma =10\) determines the strength of the intervention. Both the original and the copy get the same value of the identifier variable \(\mathrm {ID}\). We call this grouping scheme ‘Grouping setting 1’. Compare the left panels of Figs. 9 and 10 for examples.
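A sketch of how such pairs could be constructed (hypothetical helper names; degrade_quality refers to the sketch above, not to the paper's code):

    def make_quality_pairs(images, first_id=0):
        """Pair each selected Y=1 image with a quality-degraded copy of itself;
        both members of a pair share the same ID, so only image quality varies
        within a group."""
        paired_images, ids = [], []
        for offset, img in enumerate(images):
            pair_id = first_id + offset
            paired_images.extend([img, degrade_quality(img)])
            ids.extend([pair_id, pair_id])
        return paired_images, ids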

Fig. 10

Eyeglass detection for CelebA dataset with image quality interventions. The only difference to Fig. 9 is in the training data where the paired images now use the same underlying image in two different JPEG compressions. The compression level is drawn from the same distribution. The CoRe penalty performs better than for the experiment in Fig. 9 since we could explicitly control that only \(S\equiv \textit{image quality}\) varies between grouped examples. On the other hand, the performance of the pooled estimator is not changed in a noticeable way if we add augmented images as the (spurious) correlation between image quality and outcome Y still persists in the presence of the extra augmented images. Thus, the pooled estimator continues to be susceptible to image quality interventions

While image quality was explicitly changed in both settings, we refer to grouping setting 2 as ‘unknown image quality interventions’ because the training sample, as shown in the left panel of Fig. 9, does not immediately reveal that image quality is the relevant style variable. In contrast, the augmented samples used here (grouping setting 1) differ only in their image quality for a constant \((Y,\mathrm {ID})\).

Figure 10 shows examples and results. The pooled estimator performs more or less identically to the previous dataset: the explicit augmentation did not help, as the association between image quality and whether eyeglasses are worn is unchanged in the pooled data after including the augmented samples. The misclassification error of the CoRe estimator is substantially lower than that of the pooled estimator. Its error rate of 13% on test set 2 also improves on the 28% achieved by the CoRe estimator in grouping setting 2. Grouping setting 1 works best because we could explicitly control that only \(S\equiv \textit{image quality}\) varies between grouped examples. In grouping setting 2, different images of the same person can vary in many factors, making it more challenging to isolate image quality as the factor to be invariant against.

A similar example, where \(S\equiv \textit{brightness}\), is summarized in “Appendix D.1”.

5.4 Stickmen image-based age classification with unknown movement interventions

In this example we consider synthetically generated stickmen images; see Fig. 11 for some examples. The target of interest is \(Y \in \{\text {adult}, \text {child}\}\). The core feature \(C\) is here the height of each person: the class Y is causal for height, and height cannot easily be intervened on or change across domains. Height is thus a robust predictor for differentiating between children and adults. As style feature we have the movement of a person (the distribution of angles between body, arms and legs). For the training data we created a dependence between age and the style feature ‘movement’, which can be thought of as arising through a hidden common cause \(D\), namely the place of observation. For instance, the images of children might mostly show children playing, while the images of adults typically show them in more “static” postures. The left panel of Fig. 11 shows examples from the training set where large movements are associated with children and small movements with adults. Test set 1 follows the same distribution, as shown in the middle panel. A standard CNN will exploit this relationship between movement and the label Y of interest, whereas the conditional variance penalty of CoRe discourages it by pairing images of the same person in slightly different movements, as shown by the red boxes in the leftmost panel of Fig. 11. If the learned model exploits the dependence between movement and age for predicting Y, it will fail when presented with images of, say, dancing adults. The right panel of Fig. 11 shows such examples (test set 2). The standard CNN suffers from a 41% misclassification rate in this case, as opposed to 3% on test set 1 data. With as few as \(c=50\) paired observations, the network with an added CoRe penalty, in contrast, achieves 4% on test set 1 data and a 9% error on test set 2, whereas the pooled estimator fails on this dataset with a test error of 41%.

Fig. 11

Classification into \(\{\text {adult},\text {child}\}\) based on stickmen images, where children tend to be smaller and adults taller. In training and test set 1 data, children tend to have stronger movement whereas adults tend to stand still. In test set 2 data, adults show stronger movement. The two red boxes in the panel with the training data show two out of the \(c=50\) pairs of examples over which the conditional variance is calculated. The CoRe penalty leads to a network that generalizes better for test set 2 data, where the spurious correlation between age and movement is reversed, if compared to the training data

These results suggest that the learned representation of the pooled estimator uses movement as a predictor for age while CoRe does not use this feature due to the conditional variance regularization. Importantly, including more grouped examples would not improve the performance of the pooled estimator as these would be subject to the same bias and hence also predominantly have examples of heavily moving children and “static” adults (also see Fig. 23 which shows results for \(c\in \{ 20, 500, 2000\}\)).

5.5 MNIST: more sample efficient data augmentation

The goal of using CoRe in this example is to make data augmentation more efficient in terms of the required samples. In data augmentation, one creates additional samples by modifying the original inputs, e.g. by rotating, translating, or flipping the images (Schölkopf et al. 1996). In other words, additional samples are generated by interventions on style features. Using this augmented data set for training results in invariance of the estimator with respect to the transformations (style features) of interest. For CoRe we can use the grouping information that the original and the augmented samples belong to the same object. This enforces the invariance with respect to the style features more strongly than standard data augmentation, which just pools all samples. We assess this for the style feature ‘rotation’ on MNIST (LeCun et al. 1998) and only include \(c=200\) augmented training examples for \(m=10{,}000\) original samples, resulting in a total sample size of \(n=10{,}200\). The rotation angle, in degrees, is sampled uniformly at random from [35, 70]. Figure 12 shows examples from the training set. By using CoRe, the average test error on rotated examples is reduced from 22% to 10%. Very few augmented samples are thus sufficient to induce a stronger rotational invariance. The standard approach of creating augmented data and pooling all images requires, in contrast, many more samples to achieve the same effect. Additional results for \(m\in \{1000, 10{,}000\}\) and \(c\) ranging from 100 to 5000 can be found in Fig. 22 in Appendix Sect. D.5.
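A minimal sketch of the augmentation step is given below; the array conventions (28x28 uint8 digits) and helper name are assumptions, not the paper's pipeline.

    import numpy as np
    from PIL import Image

    rng = np.random.default_rng(0)

    def rotated_copy(digit):
        """Return a rotated copy of a 28x28 uint8 MNIST digit; the rotation
        angle is drawn uniformly from [35, 70] degrees. For CoRe, the original
        and the copy are assigned the same group id."""
        angle = float(rng.uniform(35, 70))
        img = Image.fromarray(digit).rotate(angle, resample=Image.BILINEAR, fillcolor=0)
        return np.asarray(img)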

Fig. 12

Data augmentation for MNIST images. The left panel shows training data with a few rotated images. Evaluating on only rotated images from the test set, a standard network has a test error of 22%. We can add the CoRe penalty by computing the conditional variance over images that were generated from the same original image. The test error is then lowered to 10% on the test data of rotated images

5.6 Elmer the Elephant

In this example, we want to assess whether invariance with respect to the style feature ‘color’ can be achieved. In the children’s book ‘Elmer the Elephant’Footnote 13, a single colorful elephant is still recognized as an elephant, so that the color ‘gray’ is no longer an integral part of the concept ‘elephant’. Motivated by this process of concept formation, we would like to assess whether CoRe can exclude ‘color’ from its learned representation by penalizing conditional variance appropriately.

We work with the ‘Animals with attributes 2’ (AwA2) dataset (Xian et al. 2017) and consider classifying images of horses and elephants. We include additional examples by adding grayscale copies of \(c=250\) elephant images. These added examples do not differ strongly from the original training data, as the elephant images are already close to grayscale. The total training sample size is 1850.
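The grayscale augmentation itself is straightforward; a minimal sketch (assumed, not the released code) is:

    from PIL import Image

    def grayscale_copy(img):
        """Convert an elephant image to grayscale while keeping three channels,
        so the network input shape is unchanged; the copy is assigned the same
        ID as its colored original."""
        return img.convert("L").convert("RGB")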

Figure 13 shows examples from the training set together with misclassification rates for CoRe and the pooled estimator on different test sets. Examples from these and more test sets can be found in Fig. 24. Test set 1 contains original, colored images only. In test set 2, images of horses are in grayscale and the colorspace of elephant images is modified, effectively changing the color gray to red-brown. We observe that the pooled estimator does not perform well on test set 2 as its learned representation seems to exploit the fact that ‘gray’ is predictive for ‘elephant’ in the training set. This association is no longer valid for test set 2. In contrast, the predictive performance of CoRe is hardly affected by the changing color distributions. More details can be found in “Appendix D.7”.

Fig. 13

Elmer-the-Elephant dataset. The left panel shows training data with a few additional grayscale elephants. The pooled estimator learns that color is predictive for the animal class and achieves a test error of 24% on test set 1, where this association still holds, but suffers a misclassification error of 53% on test set 2, where this association breaks down. By adding the CoRe penalty, the test error is consistently around 30%, irrespective of the color distribution of horses and elephants

It is noteworthy that adding a few grayscale examples to the (already lightly colored) pictures of natural elephants allows a colored elephant to be recognized as an elephant. If we just pool over these examples, a strong bias remains that elephants are gray. The CoRe estimator, in contrast, demands invariance of the prediction across instances of the same elephant, so color invariance can be learned from a few added grayscale images.

6 Further related work

Encoding certain invariances in estimators is a well-studied area in computer vision and machine learning with an extensive body of literature. While a large part of this work assumes the desired invariance to be known, fewer approaches aim to learn the required invariances from data, and the focus often lies on geometric transformations of the input data or on explicitly creating augmented observations (Sohn and Lee 2012; Khasanova and Frossard 2017; Hashimoto et al. 2017; Devries and Taylor 2017). The main difference between this line of work and CoRe is that we do not require the style feature to be known explicitly, the set of possible style features is not restricted to a particular class of transformations, and we do not aim to create augmented observations in a generative framework.

Recently, various approaches have been proposed that leverage causal motivations for deep learning or use deep learning for causal inference, related to e.g. the problems of cause-effect inference and generative adversarial networks (Chalupka et al. 2014; Lopez-Paz et al. 2017; Lopez-Paz and Oquab 2017; Goudet et al. 2017; Bahadori et al. 2017; Besserve et al. 2018; Kocaoglu et al. 2018).

Kilbertus et al. (2017) exploit causal reasoning to characterize fairness considerations in machine learning. Distinguishing between the protected attribute and its proxies, they derive causal non-discrimination criteria. The resulting algorithms avoiding proxy discrimination require classifiers to be constant as a function of the proxy variables in the causal graph, thereby bearing some structural similarity to our style features.

Distinguishing between core and style features can be seen as a form of disentangling factors of variation. Estimating disentangled factors of variation has gathered a lot of interest in the context of generative modeling. As in CoRe, Bouchacourt et al. (2018) exploit grouped observations. In a variational autoencoder framework, they aim to separate style and content, assuming that samples within a group share a common but unknown value for one of the factors of variation while the style can differ. Denton and Birodkar (2017) propose an autoencoder framework to disentangle style and content in videos using an adversarial loss term where the grouping structure induced by clip identity is exploited. Here we solve a classification task directly without estimating the latent factors explicitly as in a generative framework.

In the computer vision literature, various works have used identity information to achieve pose invariance in the context of face recognition (Bartlett and Sejnowski 1996; Tran et al. 2017). More generally, the idea of exploiting various observations of the same underlying object is related to multi-view learning (Xu et al. 2013). In the context of adversarial examples, Kannan et al. (2018) recently proposed the defense “Adversarial logit pairing”, which is methodologically equivalent to the CoRe penalty \({C}_{f,1,\theta }\) when using the squared error loss. Several empirical studies have shown mixed results regarding the performance on \(\ell _\infty\) perturbations (Engstrom et al. 2018; Mosbach et al. 2018). So far, this setting has not been analyzed theoretically, and it is hence an open question whether a CoRe-type penalty constitutes an effective defense against adversarial examples.
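To make this connection concrete (a short check added here under the unbiased-variance convention, not part of the original text): for a group consisting of a clean example \(x\) and its perturbed copy \(x'\) with the same label, the sample conditional variance of the predicted logits is
\[ \widehat{\mathrm{Var}}\bigl(f_\theta \mid (Y,\mathrm {ID})\bigr) = \tfrac{1}{2}\bigl(f_\theta (x)-f_\theta (x')\bigr)^2 , \]
so penalizing \({C}_{f,1,\theta }\) on such pairs coincides, up to a constant factor, with the squared logit-pairing term \(\bigl\Vert f_\theta (x)-f_\theta (x')\bigr\Vert _2^2\) of Kannan et al. (2018).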

7 Conclusion

Separating the latent features in an image into core and style features, we have proposed conditional variance regularization (CoRe) to achieve robustness with respect to interventions on the style or “orthogonal” features. The main idea of the CoRe estimator is to exploit the fact that we often have instances of the same object in the training data. By demanding invariance of the classifier among a group of instances that relate to the same object, we can achieve invariance of the classification performance with respect to interventions on style features such as image quality, fashion type, color, or body posture. The training also works despite sampling biases in the data.

There are two main application areas:

  1. If the style features are known explicitly, we can achieve the same classification performance as standard data augmentation approaches with substantially fewer augmented samples, as shown for example in Sect. 5.5.

  2. Perhaps more interesting are settings in which it is unknown what the style features are, with examples in Sects. 5.1, 5.2, 5.3, 5.4 and “Appendix D.1”. CoRe regularization forces predictions to be based on features that do not vary strongly between instances of the same object. We could show in the examples and in Theorems 1 and 2 that this regularization achieves distributional robustness with respect to changes in the distribution of the (unknown) style variables.

An interesting line of work would be to use larger models such as Inception or large ResNet architectures (Szegedy et al. 2015; He et al. 2016). These models have been trained to be invariant to an array of explicitly defined style features. In Sect. 5.2.2 we include results showing that using Inception V3 features does not guard against interventions on more implicit style features. We would thus like to assess what benefits CoRe can bring when training Inception-style models end-to-end, both in terms of sample efficiency and generalization performance. While we showed some examples where the necessary grouping information is available, an interesting future direction would be to use video data, since objects display temporal constancy and the temporal information can hence be used for grouping and conditional variance regularization. Beyond that, our results show that it can be worthwhile to collect \(\mathrm {ID}\) information when new datasets are created. As CoRe only requires a subset of the observations to have \(\mathrm {ID}\) annotations, this information might in many cases be cheap to collect, while it can improve performance substantially when future test data are subject to domain shifts.