Abstract
When training a deep neural network for image classification, one can broadly distinguish between two types of latent features of images that will drive the classification. We can divide latent features into (i) ‘core’ or ‘conditionally invariant’ features \(C\) whose distribution \(C\vert Y\), conditional on the class Y, does not change substantially across domains and (ii) ‘style’ features \(S\) whose distribution \(S\vert Y\) can change substantially across domains. Examples for style features include position, rotation, image quality or brightness but also more complex ones like hair color, image quality or posture for images of persons. Our goal is to minimize a loss that is robust under changes in the distribution of these style features. In contrast to previous work, we assume that the domain itself is not observed and hence a latent variable. We do assume that we can sometimes observe a typically discrete identifier or “\(\mathrm {ID}\) variable”. In some applications we know, for example, that two images show the same person, and \(\mathrm {ID}\) then refers to the identity of the person. The proposed method requires only a small fraction of images to have \(\mathrm {ID}\) information. We group observations if they share the same class and identifier \((Y,\mathrm {ID})=(y,\mathrm {id})\) and penalize the conditional variance of the prediction or the loss if we condition on \((Y,\mathrm {ID})\). Using a causal framework, this conditional variance regularization (CoRe) is shown to protect asymptotically against shifts in the distribution of the style variables in a partially linear structural equation model. Empirically, we show that the CoRe penalty improves predictive accuracy substantially in settings where domain changes occur in terms of image quality, brightness and color while we also look at more complex changes such as changes in movement and posture.
1 Introduction
Deep neural networks (DNNs) have achieved outstanding performance on prediction tasks like visual object and speech recognition (Krizhevsky et al. 2012; Szegedy et al. 2015; He et al. 2015). Issues can arise when the learned representations rely on dependencies that vanish in test distributions (see for example Quiñonero-Candela et al. (2009), Torralba and Efros (2011), Csurka (2017) and references therein). Such domain shifts can be caused by changing conditions such as color, background or location changes. Predictive performance is then likely to degrade. For example, consider the analysis presented in Kuehlkamp et al. (2017), which is concerned with the problem of predicting a person’s gender based on images of their iris. The results indicate that this problem is more difficult than previous studies have suggested due to the remaining effect of cosmetics after segmenting the iris from the whole image.Footnote 1 Previous analyses obtained good predictive performance on certain datasets, but accuracy dropped when testing on a dataset that only included images without cosmetics. In other words, the high predictive performance previously reported relied to a significant extent on exploiting the confounding effect of mascara on the iris segmentation, which is highly predictive for gender. Rather than the desired ability of discriminating based on the iris’ texture, the systems would mostly learn to detect the presence of cosmetics.
More generally, existing biases in datasets used for training machine learning algorithms tend to be replicated in the estimated models (Bolukbasi et al. 2016). For an example involving Google’s photo app, see Crawford (2016) and Emspak (2016). In Sect. 5 we show many examples where unwanted biases in the training data are picked up by the trained model. As any bias in the training data is in general used to discriminate between classes, these biases will persist in future classifications, raising also considerations of fairness and discrimination (Barocas and Selbst 2016).
Addressing the issues outlined above, we propose Conditional variance Regularization (CoRe) to give differential weight to different latent features. Conceptually, we take a causal view of the data generating process and categorize the latent data generating factors into ‘conditionally invariant’ (core) and ‘orthogonal’ (style) features, as in Gong et al. (2016). The core and style features are unobserved and can in general be highly nonlinear transformations of the observed input data. It is desirable that a classifier only extracts the latent core features from the input data as they pertain to the target of interest in a stable and coherent fashion. Basing a prediction on the core features alone yields stable predictive accuracy even if the style features are altered. Under suitable assumptions, CoRe yields an estimator which is approximately invariant under changes in the conditional distribution of the style features (conditional on the class labels) and it is asymptotically robust with respect to domain shifts, arising through interventions on the style features. CoRe relies on the fact that for certain datasets we can observe grouped observations in the sense that we observe the same object under different conditions. For instance, such grouping information is available
- (i) in natural image data when several pictures of the same person are taken;
- (ii) in medical imaging when several images belonging to the same patient are made;
- (iii) in speech recognition when multiple recordings from the same speaker are available;
- (iv) in video data where nearby frames showing the same objects can be exploited to group observations;
- (v) in data augmentation where a transformed data point can be grouped together with the original one.
We will show examples for the first and last category. For the last category, we will show that pairing the augmented data with the original image they were generated from helps to improve accuracy and robustness with respect to the chosen transformation.
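The pairing of augmented data with their source images can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function name `augment_with_ids`, the use of the sample index as the identifier, and the generic `transform` argument are all assumptions made for the example.

```python
import numpy as np

def augment_with_ids(x, y, transform, n_aug, seed=0):
    """Append transformed copies of n_aug randomly chosen samples.

    Returns (x_all, y_all, ids_all); each augmented copy shares the
    label and the ID of the sample it was generated from.
    """
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    ids = np.arange(n)  # assumption: each original sample is its own object
    chosen = rng.choice(n, size=n_aug, replace=False)
    x_aug = np.stack([transform(x[i]) for i in chosen])
    x_all = np.concatenate([x, x_aug], axis=0)
    y_all = np.concatenate([y, y[chosen]])
    ids_all = np.concatenate([ids, ids[chosen]])
    return x_all, y_all, ids_all

# Toy usage: "images" are 3-pixel vectors, the transform brightens them.
x = np.arange(12.0).reshape(4, 3)
y = np.array([0, 1, 0, 1])
x_all, y_all, ids_all = augment_with_ids(x, y, lambda v: v + 1.0, n_aug=2)
```

Samples that were not chosen for augmentation keep a unique identifier and therefore form groups of size one.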
Rather than pooling over all examples, CoRe exploits knowledge about this grouping, i.e., that a number of instances relate to the same object. By penalizing between-object variation of the prediction less than variation of the prediction for the same object, we can steer the prediction to be based more on the latent core features and less on the latent style features. While the proposed methodology can be motivated from the desire to achieve representational invariance with respect to the style features, the causal framework we use throughout this work allows us to precisely formulate the distribution shifts we aim to protect against.
The remainder of this manuscript is structured as follows: Sect. 1.1 starts with a few motivating examples, showing simple settings where the style features change in the test distribution such that standard empirical risk minimization approaches would fail. In Sect. 1.2 we review related work, introduce notation in Sect. 2 and in Sect. 3 we formally introduce conditional variance regularization CoRe. In Sect. 4, CoRe is shown to be asymptotically equivalent to minimizing the risk under a suitable class of strong interventions in a partially linear classification setting, provided one chooses sufficiently strong CoRe penalties. We also show that the population CoRe penalty induces domain shift robustness for general loss functions to first order in the intervention strength. The size of the conditional variance penalty can be shown to determine the size of the distribution class over which we can expect distributional robustness. In Sect. 5 we evaluate the performance of CoRe in a variety of experiments.
To summarize, our contributions are the following:
- (i) Causal framework and distributional robustness. We build on the causal framework from Gong et al. (2016) to define distributional shifts for style variables. This allows us to formulate the objective of interest in terms of distributional robust inference. Specifically, the distribution class, on which the estimator should achieve a guaranteed performance bound, consists of those distributions that are generated by interventions on the latent style variables in a causal model. Our framework allows that the domain variable itself is latent.
- (ii) Conditional variance penalties. We introduce conditional variance penalties and show two robustness properties in Theorems 1 and 2.
- (iii) Software. We illustrate our ideas using synthetic and real-data experiments. A TensorFlow implementation of CoRe as well as code to reproduce some of the experimental results are available at https://github.com/christinaheinze/core.
1.1 Motivating examples
To motivate the methodology we propose, consider the examples shown in Figs. 1 and 2. Example 1 shows a setting where a nonlinear decision boundary is required. Here, the core feature corresponds to the distance from the origin while the style feature corresponds to the angle between the \(x_1\)-axis and the vector from the origin to \((x_1, x_2)\). Panel (a) shows a subsample of the training data where class 1 is associated with red points, dark blue points correspond to class 0. Panel (b) additionally shows a subsample of the test data where the style—i.e. the distribution of the angle—is intervened upon: class 1 is associated with orange squares, cyan squares correspond to class 0. Clearly, a circular decision boundary yields optimal performance on both training and test set but is unlikely to be found by a standard classification algorithm when only using the training set for the estimation. We will return to these examples in Sect. 3.4.
Secondly, we introduce a strong dependence between the class label and the style feature “image quality” in the third example by manipulating the face images from the CelebA dataset (Liu et al. 2015): in the training set images of class “wearing glasses” are associated with a lower image quality than images of class “not wearing glasses”. Examples are shown in Fig. 2a. In the test set, this relation is reversed, i.e. images showing persons wearing glasses are of higher quality than images of persons without glasses, with examples in Fig. 2b. We will return to this example in Sect. 5.3 and show that training a convolutional neural network to distinguish between people wearing glasses or not works well on test data that are drawn from the same distribution (with error rates below 2%) but fails entirely on the shown test data, with error rates worse than 65%.
1.2 Related work
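For general distributional robustness, the aim is to learn
\[ \text{argmin}_\theta \; \sup _{F\in {\mathcal {F}}} \; E_F\big (\ell (Y,f_\theta (X))\big ) \qquad (1) \]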
for a given set \({\mathcal {F}}\) of distributions, twice differentiable and convex loss \(\ell\), and prediction \(f_\theta (x)\). The set \({\mathcal {F}}\) is the set of distributions on which one would like the estimator to achieve a guaranteed performance bound.
Causal inference can be seen to be a specific instance of distributional robustness, where we take \({\mathcal {F}}\) to be the class of all distributions generated under do-interventions on the predictors X (Meinshausen 2018; Rothenhäusler et al. 2018). Causal models thus have the defining advantage that the predictions will be valid even under arbitrarily large interventions on all predictor variables (Haavelmo 1944; Aldrich 1989; Pearl 2009; Schölkopf et al. 2012; Peters et al. 2016; Zhang et al. 2013, 2015; Yu et al. 2017; Rojas-Carulla et al. 2018; Magliacane et al. 2018). There are two difficulties in transferring these results to the setting of domain shifts in image classification. The first hurdle is that the classification task is typically anti-causal since the image we use as a predictor is a descendant of the true class of the object we are interested in rather than the other way around. The second challenge is that the input data consists of pixel intensities and we do not want to (and could not) guard against arbitrary interventions on any or all variables but would only like to guard against a shift of the unobserved style features. It is hence not immediately obvious how standard causal inference can be used to guard against large domain shifts.
Another line of work uses a class of distributions of the form \({\mathcal {F}}= {\mathcal {F}}_\epsilon (F_0)\) with
\[ {\mathcal {F}}_\epsilon (F_0) := \{ F : D(F,F_0) \le \epsilon \}, \qquad (2) \]
with \(\epsilon >0\) a small constant and \(D(F,F_0)\) being, for example, a \(\phi\)-divergence (Namkoong and Duchi 2017; Ben-Tal et al. 2013; Bagnell 2005; Volpi et al. 2018) or a Wasserstein distance (Shafieezadeh-Abadeh et al. 2017; Sinha et al. 2018; Gao et al. 2017). The distribution \(F_0\) can be the true (but generally unknown) population distribution P from which the data were drawn or its empirical counterpart \(P_n\). The distributionally robust targets in Eq. (2) can often be expressed in penalized form (Gao et al. 2017; Sinha et al. 2018; Xu et al. 2009). A Wasserstein-ball is a suitable class of distributions for example in the context of adversarial examples (Sinha et al. 2018; Szegedy et al. 2014; Goodfellow et al. 2015).
In this work, we do not try to achieve robustness with respect to a set of distributions that are pre-defined by a Kullback-Leibler divergence or a Wasserstein metric as in Eq. (2). Instead, we try to achieve robustness against a set of distributions that are generated by interventions on latent style variables in a causal model (we will make this precise in Sect. 2). We will formulate the class of distributions over which we try to achieve robustness as in Eq. (1) but with the class of distributions in Eq. (2) now replaced with the class of distributions \({\mathcal {F}}_\xi\) defined as \(\{F: D_{\text {style}} (F,F_0) \le \xi \},\) where \(F_0\) is again the distribution the training data are drawn from. The difference to standard distributional robustness approaches listed below Eq. (2) is now that the metric \(D_{\text {style}}\) measures the shift of the orthogonal style features. We do not know a priori which features are prone to distributional shifts and which features have a stable (conditional) distribution. The metric is hence not known a priori and needs to be inferred in a suitable sense from the data.
Similar to this work in terms of their goals are the work of Gong et al. (2016) and Domain-Adversarial Neural Networks (DANN) proposed in Ganin et al. (2016), an approach motivated by the work of Ben-David et al. (2007). The main idea of Ganin et al. (2016) is to learn a representation that contains no discriminative information about the origin of the input (source or target domain). This is achieved by an adversarial training procedure: the loss on domain classification is maximized while the loss of the target prediction task is minimized simultaneously. The data generating process assumed in Gong et al. (2016) is similar to our model, introduced in Sect. 2.1, where we detail the similarities and differences between the models (cf. Fig. 3). Gong et al. (2016) identify the conditionally independent features by adjusting a transformation of the variables to minimize the squared MMD distance between distributions in different domains.Footnote 2 The fundamental difference between these very promising methods and our approach is that we use a different data basis. The domain identifier is explicitly observable in Gong et al. (2016) and Ganin et al. (2016), while it is latent in our approach. In contrast, we exploit the presence of an identifier variable \(\mathrm {ID}\) that relates to the identity of an object (for example identifying a person). In other words, we do not assume that we have data from different domains but just different realizations of the same object under different interventions. This also differentiates this work from latent domain adaptation papers from the computer vision literature (Hoffman et al. 2012; Gong et al. 2013). Further related work is discussed in Sect. 6.
2 Setting
We introduce the assumed underlying causal graph and some notation before discussing notions of domain shift robustness.
2.1 Causal graph
Let \(Y\in {\mathcal {Y}}\) be a target of interest. Typically \({\mathcal {Y}}=\mathbb {R}\) for regression or \({\mathcal {Y}}=\{1,\ldots ,K\}\) in classification with K classes. Let \(X \in \mathbb {R}^p\) be predictor variables, for example the p pixels of an image. The causal structural model for all variables is shown in panel (b) of Fig. 3. The domain variable D is latent, in contrast to Gong et al. (2016) whose model is shown in panel (a) of Fig. 3. We add the \(\mathrm {ID}\) variable to the graph. In Fig. 3, \(Y \rightarrow \mathrm {ID}\) but in some settings it might be more plausible to consider \(\mathrm {ID}\rightarrow Y\). For the proposed method both options are possible. Together with Y, the \(\mathrm {ID}\) variable is used to group observations. It is typically discrete and relates to the identity of the underlying object. The variable can be assumed to be latent in the setting of Gong et al. (2016).
The rest of the graph is in analogy to Gong et al. (2016). The prediction is anti-causal, that is the predictor variables X that we use for \({\hat{Y}}\) are non-ancestral to Y. In other words, the class label is here seen to be causal for the image and not the other way around.Footnote 3 The causal effect from the class label Y on the image X is mediated via two types of latent variables: the so-called core or ‘conditionally invariant’ features \(C\) and the orthogonal or style features \(S\). The distinguishing factor between the two is that external interventions \(\varDelta\) are possible on the style features but not on the core features. If the interventions \(\varDelta\) have different distributions in different domains, then the conditional distributions \(C|Y=y,\mathrm {ID}=\mathrm {id}\) are invariant for all \((y,\mathrm {id})\) while \(S|Y=y,\mathrm {ID}=\mathrm {id}\) can change. The style variable can include point of view, image quality, resolution, rotations, color changes, body posture, movement etc. and will in general be context-dependent.Footnote 4 The style intervention variable \(\varDelta\) influences both the latent style \(S\), and hence also the image X. In potential outcome notation, we let \(S(\varDelta =\delta )\) be the style under intervention \(\varDelta =\delta\) and \(X(Y,\mathrm {ID},\varDelta =\delta )\) the image for class Y, identity \(\mathrm {ID}\) and style intervention \(\varDelta\). The latter is sometimes abbreviated as \(X(\varDelta =\delta )\) for notational simplicity. Finally, \(f_\theta (X(\varDelta =\delta ))\) is the prediction under the style intervention \(\varDelta =\delta\). For a formal justification of using a causal graph and potential outcome notation simultaneously see Richardson and Robins (2013).
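To be specific, if not mentioned otherwise we will assume a causal graph as follows. For independent \(\varepsilon _Y, \varepsilon _\mathrm {ID},\varepsilon _{\text {style}}\) in \(\mathbb {R},\mathbb {R},\mathbb {R}^q\) respectively with positive density on their support and continuously differentiable functions \(k_y, k_\mathrm {id}\), and \(k_{\text {style}},k_{\text {core}},k_x\),
\[ \begin{aligned} Y &\leftarrow k_y(D, \varepsilon _Y) \\ \mathrm {ID} &\leftarrow k_\mathrm {id}(Y, \varepsilon _\mathrm {ID}) \\ C &\leftarrow k_{\text {core}}(Y, \mathrm {ID}) \\ S &\leftarrow k_{\text {style}}(Y, \mathrm {ID}, \varepsilon _{\text {style}}) + \varDelta \\ X &\leftarrow k_x(C, S). \end{aligned} \qquad (3) \]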
The distribution of \(\varepsilon _{\text {style}}\) is assumed to be identical across domains, while \(\varDelta\) can change. In more generality, one could discard \(\varDelta\) and instead allow the distribution of \(\varepsilon _{\text {style}}\) to change. Here the assumption is slightly more restrictive because of additivity of \(\varDelta\) in the structural equation for the style variable. The core features are here assumed to be a deterministic function of Y and \(\mathrm {ID}\) to allow for theoretical analysis. In more generality (and as indicated in the graph), these would also be non-deterministic relations. The theoretical results will also require positive density for the style features in an \(\epsilon\)-ball around the origin, as made precise in assumption (A1) later.
The prediction \({\hat{y}}\) for y, given \(X=x\), is of the form \(f_\theta (x)\) for a suitable function \(f_\theta\) with parameters \(\theta \in \mathbb {R}^d\), where the parameters \(\theta\) correspond to the weights in a DNN, for example.
We would like to stress that the above model is fairly general and subsumes many simpler ones as special cases. To give a concrete example, consider the task of classifying a health condition Y from medical images X. Style features could be, for example, technical noise, orientation or resolution. The unobserved domain \(D\) could correspond to different hospitals or doctors. Due to the usage of different measuring devices in each of these locations, the conditional distribution \(S\vert Y\) will change substantially across different domains. In contrast, the core features \(C\), i.e. those image features that carry the actual signal, will remain invariant conditional on the underlying health condition Y.
2.2 Data
We assume we have \(n\) data points \((x_i,y_i,\mathrm {id}_i)\) for \(i=1,\ldots ,n\), where the observations \(\mathrm {id}_i\) with \(i=1,\ldots ,n\) of variable \(\mathrm {ID}\) can also contain unobserved values. Let \(m\le n\) be the number of unique realizations of \((Y,\mathrm {ID})\) and let \(G_1,\ldots ,G_m\) be a partition of \(\{1,\ldots ,n\}\) such that, for each \(j\in \{1,\ldots ,m\}\), the realizations \((y_i,\mathrm {id}_i)\) are identicalFootnote 5 for all \(i\in G_j\). While our prime application is classification, regression settings with continuous Y can be approximated in this framework by slicing the range of the response variable into distinct bins in analogy to the approach in sliced inverse regression (Li 1991). The cardinality of \(G_j\) is denoted by \(n_j:=|G_j| \ge 1\). Then \(n= \sum _{j=1}^m n_j\) is again the total number of samples and \(c= n-m\) is the total number of grouped observations in the following sense: if we count all samples in a group except the first one and sum over all groups, we are left with a total of \(c= n-m= \sum _{j=1}^m (n_j - 1)\) observations that are ‘grouped’ with the first example in their corresponding group.
Typically \(n_j=1\) for most groups and occasionally \(n_j \ge 2\), but one can also envisage scenarios with larger groups of observations sharing the same identifier \((y,\mathrm {id})\).
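The bookkeeping of this section can be sketched in a few lines. The helper name `group_by_label_and_id` is hypothetical, and treating a sample with an unobserved identifier (encoded here as `None`) as its own singleton group is an assumption of this sketch.

```python
from collections import defaultdict

def group_by_label_and_id(y, ids):
    """Partition sample indices into groups sharing the same (y, id).

    Assumption: a sample whose id is None (unobserved) is not grouped
    with anything and forms a singleton group.
    """
    groups = defaultdict(list)
    for i, (yi, idi) in enumerate(zip(y, ids)):
        key = (yi, idi) if idi is not None else ("_missing", i)
        groups[key].append(i)
    return list(groups.values())

y = [1, 1, 0, 0, 1, 0]
ids = ["a", "a", "b", None, "c", "b"]

groups = group_by_label_and_id(y, ids)
n = len(y)        # total number of samples
m = len(groups)   # number of unique (y, id) realizations
c = n - m         # number of grouped observations
```

In this toy example, samples 0 and 1 form one group and samples 2 and 5 another, so two observations are grouped with the first example of their group.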
2.3 Domain shift robustness
In this section, we clarify against which classes of distributions we hope to achieve robustness. Let \(\ell\) be a suitable loss that maps y and \({\hat{y}}=f_\theta (x)\) to \(\mathbb {R}^+\). The risk under distribution F and parameter \(\theta\) is given by
\[ E_F\big [\ell (Y,f_\theta (X))\big ]. \]
Let \(F_0\) be the joint distribution of \((\mathrm {ID},Y,S)\) in the training distribution. A new domain and explicit interventions on the style features can now shift the distribution of \((\mathrm {ID},Y,\tilde{S})\) to F. We can measure the distance between distributions \(F_0\) and F in different ways. Below we will define the distance considered in this work and denote it by \(D_{\text {style}}(F,F_0)\). Once defined, we get a class of distributions
\[ {\mathcal {F}}_\xi = \{ F : D_{\text {style}}(F,F_0) \le \xi \} \]
and the goal will be to optimize a worst-case loss over this distribution class in the sense of Eq. (1), where larger values of \(\xi\) afford protection against larger distributional changes. The relevant loss for distribution class \({\mathcal {F}}_\xi\) is then
\[ L_\xi (\theta ) = \sup _{F\in {\mathcal {F}}_\xi } E_F\big [\ell (Y,f_\theta (X))\big ]. \qquad (5) \]
In the limit of arbitrarily strong interventions on the style features \(S\), the loss is given by
\[ L_\infty (\theta ) = \lim _{\xi \rightarrow \infty } \sup _{F\in {\mathcal {F}}_\xi } E_F\big [\ell (Y,f_\theta (X))\big ]. \qquad (6) \]
Minimizing the loss \(L_\infty (\theta )\) with respect to \(\theta\) yields predictions whose accuracy is maintained across arbitrarily large shifts in the conditional distribution of the style features.
A natural choice to define \(D_\text {style}\) is to use a Wasserstein-type distance (see e.g. Villani 2003). We will first define a distance \(D_{y,\mathrm {id}}\) for the conditional distributions
\[ S \,|\, Y=y,\mathrm {ID}=\mathrm {id} \quad \text {and} \quad \tilde{S} \,|\, Y=y,\mathrm {ID}=\mathrm {id}, \]
and then set \(D_{\text {style}}(F_0,F) = E(D_{Y,\mathrm {ID}})\), where the expectation is with respect to random \(\mathrm {ID}\) and labels Y. The distance \(D_{y,\mathrm {id}}\) between the two conditional distributions of \(S\) will be defined as a Wasserstein \(W_2^2(F_0,F)\)-distance for a suitable cost function \(c(x,\tilde{x})\). Specifically, let \(\Pi _{y,\mathrm {id}}\) be the couplings between the conditional distributions of \(S\) and \(\tilde{S}\), meaning measures supported on \(\mathbb {R}^q \times \mathbb {R}^q\) such that the marginal distribution over the first q components is equal to the distribution of \(S\) and the marginal distribution over the remaining q components equal to the distribution of \(\tilde{S}\). Then the distance between the conditional distributions is defined as
\[ D_{y,\mathrm {id}} = \min _{M\in \Pi _{y,\mathrm {id}}} E_M\big [ c(S,\tilde{S}) \big ], \]
where \(c :\mathbb {R}^q\times \mathbb {R}^q \rightarrow \mathbb {R}^+\) is a nonnegative, lower semi-continuous cost function. Here, we focus on a Mahalanobis distance as cost
\[ c(s,\tilde{s}) = (s-\tilde{s})^t \, \Sigma _{y,\mathrm {id}}^{-1} \, (s-\tilde{s}). \]
The cost of a shift is hence measured against the variability under the distribution \(F_0\), \(\Sigma _{y,\mathrm {id}} =\text {Var}(S|Y,\mathrm {ID}).\)Footnote 6
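To make the role of the conditional covariance concrete, the following sketch evaluates a Mahalanobis cost, taken here to be the quadratic form \((s-\tilde{s})^t \Sigma ^{-1} (s-\tilde{s})\); the exact form and the function name `mahalanobis_cost` are assumptions of this illustration. A shift of the same Euclidean length is cheaper in a direction where the style features already vary strongly under \(F_0\).

```python
import numpy as np

def mahalanobis_cost(s, s_tilde, sigma):
    """Squared Mahalanobis distance (s - s_tilde)^T sigma^{-1} (s - s_tilde)."""
    diff = np.asarray(s, dtype=float) - np.asarray(s_tilde, dtype=float)
    # Solve sigma z = diff instead of forming the inverse explicitly.
    return float(diff @ np.linalg.solve(sigma, diff))

# Hypothetical conditional covariance: variance 4 in the first style
# direction, variance 1 in the second.
sigma = np.diag([4.0, 1.0])
origin = np.zeros(2)

cost_high_var = mahalanobis_cost([2.0, 0.0], origin, sigma)  # shift along direction 1
cost_low_var = mahalanobis_cost([0.0, 2.0], origin, sigma)   # shift along direction 2
```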
Clearly, since the core and style features are unobserved, we cannot directly optimize the loss (5) but need to infer the metric \(D_{\text {style}}\) from the input data X. In the next section, we will show how this can be achieved. Intuitively, the model in Fig. 3 implies that the variance conditional on \(Y,\mathrm {ID}\) stems from the difference in the style features. Hence, we would like to minimize the conditional variance of the prediction or loss when we condition on \(Y, \mathrm {ID}\). This enforces the desired invariance with respect to the style features.
3 Conditional variance regularization
3.1 Pooled estimator
Let \((x_i,y_i)\) for \(i=1,\ldots ,n\) be the observations that constitute the training data and \({\hat{y}}_i=f_\theta (x_i)\) the prediction for \(y_i\). The standard approach is to simply pool over all available observations, ignoring any grouping information that might be available. The pooled estimator thus treats all examples identically by summing over the empirical loss as
\[ {\hat{\theta }}^{pool} = \text {argmin}_\theta \; {\hat{E}}\big [\ell (Y,f_\theta (X))\big ] + \gamma \cdot \text {pen}(\theta ), \qquad (7) \]
where the first part is simply the empirical loss over the training data,
\[ {\hat{E}}\big [\ell (Y,f_\theta (X))\big ] = \frac{1}{n} \sum _{i=1}^n \ell \big (y_i, f_\theta (x_i)\big ). \]
In the second part, \(\text {pen}(\theta )\) is a complexity penalty weighted by \(\gamma \ge 0\), for example a squared \(\ell _2\)-norm of the weights \(\theta\) in a convolutional neural network as a ridge penalty.
3.2 CoRe estimator
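The CoRe estimator is defined in Lagrangian form for penalty \(\lambda \ge 0\) as
\[ {\hat{\theta }}^{core}(\lambda ) = \text {argmin}_\theta \; {\hat{E}}\big [\ell (Y,f_\theta (X))\big ] + \lambda \cdot {\hat{C}}_\theta . \qquad (8) \]
The penalty \({\hat{C}}_{\theta }\) is a conditional variance penalty of the form
\[ \text {conditional-variance-of-prediction:} \quad {\hat{C}}_{f,\nu ,\theta } := {\hat{E}}\big [ \widehat{\text {Var}}(f_\theta (X) | Y,\mathrm {ID})^\nu \big ] \qquad (9) \]
\[ \text {conditional-variance-of-loss:} \quad {\hat{C}}_{\ell ,\nu ,\theta } := {\hat{E}}\big [ \widehat{\text {Var}}(\ell (Y,f_\theta (X)) | Y,\mathrm {ID})^\nu \big ] \qquad (10) \]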
where typically \(\nu \in \{1/2,1\}\). For \(\nu =1/2\), we also refer to the respective penalties as “conditional-standard-deviation” penalties. In practice in the context of classification and DNNs, we apply the penalty (9) to the predicted logits. The conditional-variance-of-loss penalty (10) takes a similar form to Namkoong and Duchi (2017). The crucial difference of our approach to Namkoong and Duchi (2017) is that we penalize with the expected conditional variance or standard deviation. The fact that we take a conditional variance is here important as we try to achieve distributional robustness with respect to interventions on the style variables. Conditioning on \(\mathrm {ID}\) allows to guard specifically against these interventions. An unconditional variance penalty, in contrast, can achieve robustness against a pre-defined class of distributions such as a ball of distributions defined in a Kullback-Leibler or Wasserstein metric. The population CoRe estimator is defined as in Eq. (8) where empirical estimates are replaced by their respective population quantities.
Before showing numerical examples, we discuss the estimation of the expected conditional variance in Sect. 3.3 and return to the simple examples of Sect. 1.1 in Sect. 3.4. Domain shift robustness in a classification setting for a partially linear version of the structural equation model (3) is shown in Sect. 4.1. Furthermore, we discuss the population limit of \({\hat{\theta }}^{core}(\lambda )\) in Sect. 4.2, where we show that the regularization parameter \(\lambda \ge 0\) is proportional to the size of the future style interventions that we want to guard against for future test data.
3.3 Estimating the expected conditional variance
Recall that \(G_j\subseteq \{1,\ldots ,n\}\) contains samples with identical realizations of \((Y,\mathrm {ID})\) for \(j\in \{1,\ldots ,m\}\). For each \(j\in \{1,\ldots ,m\}\), define \({\hat{\mu }}_{\theta ,j}\) as the arithmetic mean across all \(f_\theta (x_{i}), i\in G_j\). The canonical estimator of the conditional variance \({\hat{C}}_{f,1,\theta }\) is then
\[ {\hat{C}}_{f,1,\theta } := \frac{1}{m} \sum _{j=1}^m \frac{1}{|G_j|} \sum _{i\in G_j} \big (f_\theta (x_i) - {\hat{\mu }}_{\theta ,j}\big )^2 \]
and analogously for the conditional-variance-of-loss, defined in Eq. (10)Footnote 7. If there are no groups of samples that share the same identifier \((y,\mathrm {id})\), we define \({\hat{C}}_{f,1,\theta }\) to vanish. The CoRe estimator is then identical to pooled estimation in this special case.
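A possible NumPy sketch of this estimator is given below. It assumes that the penalty averages, over all \(m\) groups, the within-group variance of the predictions (normalized by \(|G_j|\)) raised to the power \(\nu\); the function name `conditional_variance_penalty` and the representation of groups as lists of indices are choices made for this illustration, not taken from the released code.

```python
import numpy as np

def conditional_variance_penalty(preds, groups, nu=1.0):
    """Average over groups of (within-group variance of preds) ** nu.

    preds  : 1-D array of predictions f_theta(x_i)
    groups : list of index lists G_1, ..., G_m partitioning the samples
    Singleton groups contribute zero; if no group has two or more
    members, the penalty is defined to vanish.
    """
    preds = np.asarray(preds, dtype=float)
    if all(len(g) < 2 for g in groups):
        return 0.0
    total = 0.0
    for g in groups:
        total += np.var(preds[g]) ** nu  # np.var divides by |G_j|
    return total / len(groups)

preds = [0.0, 2.0, 5.0, 5.0, 1.0]
groups = [[0, 1], [2, 3], [4]]
penalty = conditional_variance_penalty(preds, groups)
```

Only the first group contributes here: its predictions 0 and 2 have variance 1, the second group has identical predictions, and the third is a singleton, so the penalty equals 1/3.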
3.4 Motivating examples (continued)
We revisit the first example from Sect. 1.1. Figure 4 shows subsamples of the training and test set with the estimated decision boundaries for different values of the penalty parameter \(\lambda\) when using a 2-layer fully connected neural network. Here, \(n=20{,}000\) and \(c=500\). Additionally, grouped examples that share the same \((y,\mathrm {id})\) are visualized: two grouped observations are connected by a line or curve, respectively. Ten such groups are shown. Panel (a) shows the decision boundaries for \(\lambda =0\), equivalent to the pooled estimator, and for CoRe with \(\lambda \in \{0.05, 0.1, 1\}\). The pooled estimator misclassifies a large number of test points as can be seen in panel (b), suffering from a test error of \(\approx 58\%\). In contrast, the decision boundary of the CoRe estimator with \(\lambda =1\) aligns with the direction along which the grouped observations vary, classifying the test set with almost perfect accuracy (test error is \(\approx 0\%\)).
4 Domain shift robustness for the CoRe estimator
We show two properties of the CoRe estimator. First, consistency is shown under the risk definition (6) for an infinitely large conditional variance penalty and the logistic loss in a partially linear structural equation model. Second, the population CoRe estimator is shown to achieve distributional robustness against shift interventions in a first order expansion.
4.1 Asymptotic domain shift robustness under strong interventions
We analyze the loss under strong domain shifts, as given in Eq. (6), for the pooled and the CoRe estimator in a one-layer network for binary classification (logistic regression) in an asymptotic setting of large sample size and strong interventions.
Assume the structural equation for the image \(X\in \mathbb {R}^p\) is linear in the style features \(S\in \mathbb {R}^q\) (with generally \(p\gg q\)) and we use logistic regression to predict the class label \(Y\in \{-1,1\}\). Let the interventions \(\varDelta \in \mathbb {R}^q\) act additively on the style features \(S\) (this is only for notational convenience) and let the style features \(S\) act in a linear way on the image X via a matrix \({W}\in \mathbb {R}^{p\times q}\) (this is an important assumption without which results are more involved). The core or ‘conditionally invariant’ features are \(C\in \mathbb {R}^r\), where in general \(r\le p\) but this is not important for the following. For independent \(\varepsilon _Y, \varepsilon _\mathrm {ID},\varepsilon _{\text {style}}\) in \(\mathbb {R},\mathbb {R},\mathbb {R}^q\) respectively with positive density on their support and continuously differentiable functions \(k_y, k_\mathrm {id},k_{\text {style}},k_{\text {core}},k_x\),
\[ \begin{aligned} Y &\leftarrow k_y(D, \varepsilon _Y) \\ \mathrm {ID} &\leftarrow k_\mathrm {id}(Y, \varepsilon _\mathrm {ID}) \\ C &\leftarrow k_{\text {core}}(Y, \mathrm {ID}) \\ S &\leftarrow k_{\text {style}}(Y, \mathrm {ID}, \varepsilon _{\text {style}}) + \varDelta \\ X &\leftarrow k_x(C) + {W} S. \end{aligned} \qquad (11) \]
We assume a logistic regression as a prediction of Y from the image data X:
\[ f_\theta (x) := \frac{\exp (x^t\theta )}{1+\exp (x^t\theta )}. \]
Given training data with \(n\) samples, we estimate \(\theta\) with \({\hat{\theta }}\) and use here a logistic loss \(\ell _{\theta }(y_i,x_i) = \log (1+ \exp ( - y_i ( x_i^t \theta )))\).
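For this logistic setting, the penalized objective can be sketched as the logistic loss above plus \(\lambda\) times a conditional variance penalty. The sketch below applies the penalty to the logits \(x_i^t\theta\), in line with the remark in Sect. 3.2, and sets the ridge penalty to zero; the function names and the toy data are assumptions made for illustration only.

```python
import numpy as np

def logistic_loss(theta, x, y):
    """Mean of log(1 + exp(-y_i * x_i^T theta)) for labels y_i in {-1, +1}."""
    margins = y * (x @ theta)
    # logaddexp(0, -m) = log(1 + exp(-m)), computed stably.
    return float(np.mean(np.logaddexp(0.0, -margins)))

def core_objective(theta, x, y, groups, lam):
    """Logistic loss plus lam times the mean within-group variance of the logits."""
    logits = x @ theta
    penalty = np.mean([np.var(logits[g]) for g in groups]) if groups else 0.0
    return logistic_loss(theta, x, y) + lam * float(penalty)

# Toy data: column 0 plays the role of a core feature, column 1 of a
# style feature. Samples within a group differ only in the style column.
x = np.array([[1.0, 0.0], [1.0, 3.0], [-1.0, 0.0], [-1.0, -2.0]])
y = np.array([1.0, 1.0, -1.0, -1.0])
groups = [[0, 1], [2, 3]]

theta_core = np.array([1.0, 0.0])    # ignores the style column
theta_style = np.array([1.0, 1.0])   # also uses the style column

obj_core = core_objective(theta_core, x, y, groups, lam=10.0)
obj_style = core_objective(theta_style, x, y, groups, lam=10.0)
```

A parameter vector that ignores the style column incurs no penalty regardless of \(\lambda\), whereas one that uses it is penalized in proportion to the within-group variation of its logits.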
The formulation of Theorem 1 relies on the following assumptions.
Assumption 1
We require the following conditions:
- (A1) Assume the conditional distribution \(S|Y=y,\mathrm {ID}=\mathrm {id}\) under the training distribution \(F_0\) has positive density (with respect to the Lebesgue measure) in an \(\epsilon\)-ball in \(\ell _2\)-norm around the origin for some \(\epsilon >0\) for all \(y\in {\mathcal {Y}}\) and \(\mathrm {id}\in {\mathcal {I}}\).
- (A2) Assume the matrix \({W}\) has full rank q.
- (A3) Let \(M\le n\) be the number of unique realizations among n iid samples of \((Y,\mathrm {ID})\) and let \(p_n:=P(M\le n-q)\). Assume that \(p_n\rightarrow 1\) for \(n\rightarrow \infty\).
Assumption (A1) is a key assumption about the style variations we observe in the training set. It requires that we observe some variance in those directions that we expect to be subject to domain shifts in the future. If, on the other hand, the conditional variance in a particular direction is vanishing, we also expect it to vanish in the future. A violation of this assumption would imply that the guarantee of the CoRe regularization no longer holds. Assumption (A3) guarantees that the number \(c=n-m\) of grouped examples is at least as large as the dimension of the style variables. If we have too few or no grouped examples (small c), we cannot estimate the conditional variance accurately. Under these assumptions we can prove domain shift robustness.
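As an illustration of the quantities in (A3), the following sketch counts the unique \((Y,\mathrm {ID})\) realizations in one sample and the resulting number \(c=n-m\) of grouped examples. The function names are hypothetical, and the check concerns a single realized sample only; (A3) itself is a statement about the probability of this event as \(n\rightarrow \infty\).

```python
from collections import Counter


def grouping_summary(labels, ids):
    """Return (n, m, c): sample size, number of unique (Y, ID) pairs, c = n - m."""
    counts = Counter(zip(labels, ids))
    n = len(labels)
    m = len(counts)
    return n, m, n - m


def a3_event_holds(labels, ids, q):
    """Check the event {M <= n - q} for one realized sample, with q = dim(S)."""
    n, m, _ = grouping_summary(labels, ids)
    return m <= n - q
```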
Theorem 1
(Asymptotic domain shift robustness under strong interventions) Under model (11) and Assumption 1, with probability 1, the pooled estimator (7) has infinite loss (6) under arbitrarily large shifts in the distribution of the style features,
The CoRe estimator (8) \({\hat{\theta }}^{core}\) with \(\lambda \rightarrow \infty\) is domain shift robust under strong interventions in the sense that for \(n\rightarrow \infty\),
A proof is given in “Appendix A”. The respective ridge penalties in both estimators (7) and (8) are assumed to be zero for the proof, but the proof can easily be generalized to include ridge penalties that vanish sufficiently fast for large sample sizes. The Lagrangian regularizer \(\lambda\) is assumed to be infinite for the CoRe estimator to achieve domain shift robustness under these strong interventions. The next section considers the population CoRe estimator in a setting with weak interventions and finite values of the penalty parameter.
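The mechanism behind the theorem can be illustrated numerically. Since the style features enter the image linearly through \({W}\) and the intervention is additive, a shift \(\delta\) changes the linear predictor \(x^t\theta\) by \(\delta ^t {W}^t\theta\). The sketch below uses a hypothetical toy setup (\(p=3\), \(q=1\), hand-picked parameter vectors) to contrast a parameter with \({W}^t\theta \ne 0\) against one with \({W}^t\theta =0\); it is not the authors' experiment.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def matvec(W, v):
    """Multiply a p x q matrix (list of rows) with a vector of length q."""
    return [dot(row, v) for row in W]


def logistic_loss_from_score(y, score):
    """Stable evaluation of log(1 + exp(-y * score))."""
    margin = y * score
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


def loss_under_shift(y, x, theta, W, delta):
    """Loss after the additive style shift delta, i.e. at the image x + W delta."""
    shifted = [xi + wi for xi, wi in zip(x, matvec(W, delta))]
    return logistic_loss_from_score(y, dot(shifted, theta))


# Hypothetical toy values, chosen only for illustration.
W = [[1.0], [0.0], [0.0]]           # the style feature moves the first pixel only
x, y = [0.5, 1.0, -0.2], 1
theta_style = [-1.0, 2.0, 0.5]      # W^t theta = -1: prediction depends on style
theta_invariant = [0.0, 2.0, 0.5]   # W^t theta = 0: prediction ignores style

shifts = [0.0, 10.0, 100.0]
style_losses = [loss_under_shift(y, x, theta_style, W, [s]) for s in shifts]
invariant_losses = [loss_under_shift(y, x, theta_invariant, W, [s]) for s in shifts]
```

In this toy example `style_losses` grows roughly linearly in the shift size, while `invariant_losses` stays constant, which mirrors the contrast between unbounded and bounded loss in the theorem.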
4.2 Population domain shift robustness under weak interventions
The previous theorem states that the CoRe estimator can achieve domain shift robustness under strong interventions for an infinitely strong penalty in an asymptotic setting. An open question is how the loss (5),
behaves under interventions of small to medium size and correspondingly smaller values of the penalty. Here, we aim to minimize this loss for a given value of \(\xi\) and show that domain shift robustness can be achieved to first order with the population CoRe estimator using the conditional-standard-deviation-of-loss penalty, i.e., Eq. (10) with \(\nu = 1/2\), by choosing an appropriate value of the penalty \(\lambda\). Below we will show this appropriate choice of the penalty weight is \(\lambda =\sqrt{\xi }\).
Assumption 2
We require the following conditions:
- (B1) Define the loss under a deterministic shift \(\delta\) as
$$\begin{aligned} h_\theta (\delta ) := E_{F_\theta }[\ell (Y,f_\theta (X))],\end{aligned}$$
where the expectation is with respect to random \((\mathrm {ID},Y,\tilde{S})\sim F_\theta\), with \(F_\theta\) defined by the deterministic shift intervention \(\tilde{S}=S+\delta\) and \((\mathrm {ID},Y,S)\sim F_0\). Assume that for all \(\theta \in \Theta\), \(h_\theta (\delta )\) is twice continuously differentiable with bounded second derivative for a deterministic shift \(\delta \in \mathbb {R}^q\).
- (B2) The spectral norm of the conditional variance \(\Sigma _{y,\mathrm {id}}\) of \(S|Y,\mathrm {ID}\) under \(F_0\) is assumed to be smaller than or equal to some \(\zeta \in \mathbb {R}\) for all \(y\in {\mathcal {Y}}\) and \(\mathrm {id}\in {\mathcal {I}}\).
The first assumption (B1) ensures that the loss is well behaved under interventions on the style variables. The second assumption (B2) allows us to take the limit of small conditional variances in the style variables.
When setting \(\lambda =\sqrt{\xi }\) and using the conditional-standard-deviation-of-loss penalty, the CoRe estimator optimizes according to
The next theorem shows that this is to first order equivalent to minimizing the worst-case loss over the distribution class \({\mathcal {F}}_\xi\). The following result holds for the population CoRe estimator, see below for a discussion about consistency.
Theorem 2
The supremum of the loss over the class of distributions \({\mathcal {F}}_\xi\) is to first order given by the expected loss under distribution \(F_0\) with an additional conditional-standard-deviation-of-loss penalty \({C}_{\ell ,1/2,\theta }\)
A proof is given in “Appendix B”. The objective of the population CoRe estimator thus matches, to first order, the loss under domain shifts if we set the penalty weight \(\lambda =\sqrt{\xi }\). Larger anticipated domain shifts thus naturally require a larger penalty \(\lambda\) in the CoRe estimation. The result is possible because we have chosen the Mahalanobis distance to measure shifts in the style variable and to define \({\mathcal {F}}_\xi\), ensuring that the strength of shifts on the style variables is measured against the natural variance on the training distribution \(F_0\).
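As a compact restatement assembled from the surrounding text (the remainder term is deliberately left unspecified here), the first-order equivalence of Theorem 2 can be written as
$$\begin{aligned} \sup _{F\in {\mathcal {F}}_\xi } E_{F}[\ell (Y,f_\theta (X))] \approx E_{F_0}[\ell (Y,f_\theta (X))] + \sqrt{\xi }\, {C}_{\ell ,1/2,\theta },\end{aligned}$$
where \(\approx\) hides higher-order terms in the strength of the shift and in the conditional variance of the style variables. The right-hand side is the population CoRe objective with penalty weight \(\lambda =\sqrt{\xi }\).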
In practice, the choice of \(\lambda\) involves a somewhat subjective choice about the strength of the distributional robustness guarantee. A stronger distributional robustness property is traded off against a loss in predictive accuracy if the distribution is not changing in the future. One option for choosing \(\lambda\) is to choose the largest penalty weight before the validation loss increases considerably. This approach would provide the best distributional robustness guarantee that keeps the loss of predictive accuracy in the training distribution within a pre-specified bound.Footnote 8
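One way to make this selection rule concrete is sketched below. It scans the candidate penalty weights in increasing order and stops before the first weight whose validation loss exceeds the loss at the smallest weight by more than a tolerance. The function name, the tolerance parameter and the example numbers are hypothetical; the text does not prescribe a specific threshold.

```python
def select_penalty_weight(lambdas, val_losses, tolerance):
    """Largest penalty weight reached before the validation loss rises by
    more than `tolerance` above the loss at the smallest candidate weight."""
    pairs = sorted(zip(lambdas, val_losses))
    baseline = pairs[0][1]          # validation loss at the smallest weight
    chosen = pairs[0][0]
    for lam, loss in pairs:
        if loss > baseline + tolerance:
            break                   # first considerable increase: stop here
        chosen = lam
    return chosen


# Hypothetical validation losses for a grid of penalty weights.
grid = [0.0, 0.1, 1.0, 10.0, 100.0]
losses = [0.30, 0.29, 0.31, 0.45, 0.33]
best = select_penalty_weight(grid, losses, tolerance=0.05)
```

In this made-up example the rule returns the weight 1.0, since the loss at 10.0 exceeds the bound; later weights are not considered even if their loss happens to fall again.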
As a caveat, the result takes the limit of small conditional variance of \(S\) in the training distribution and small additional interventions. Under larger interventions higher-order terms could start to dominate, depending on the geometry of the loss function and \(f_\theta\). A further caveat is that the result looks at the population CoRe estimator. For finite sample sizes, we would optimize a noisy version of the rhs of (12). To show domain shift robustness in an asymptotic sense, we would need additional uniform convergence (in \(\theta\)) of both the empirical loss and the conditional variance in that for \(n\rightarrow \infty\),
While this is in general a reasonable assumption to make, the validity of the assumption will depend on the specific function class and on the chosen estimator of the conditional variance.
5 Experiments
We perform an array of different experiments, showing the applicability and advantage of the conditional variance penalty for two broad settings:
1. Settings where we do not know what the style variables correspond to but still want to protect against a change in their distribution in the future. The examples cover style variables such as fashion (Sect. 5.2), image quality (Sect. 5.3), movement (Sect. 5.4) and brightness (“Appendix D.1”), none of which is known explicitly to the method. We also include genuinely unknown style variables in Sect. 5.1 (in the sense that they are unknown not only to the methods but also to us as we did not explicitly create the style interventions).
2. Settings where we do know what type of style interventions we would like to protect against. This is usually dealt with by data augmentation (adding images which are, say, rotated or shifted compared to the training data if we want to protect against rotations or translations in the test data; see for example Schölkopf et al. (1996)). Here, the conditional variance penalty exploits the fact that some augmented samples were generated from the same original sample, and we use the index of the original image as the \(\mathrm {ID}\) variable. We show that this approach generalizes better than simply pooling the augmented data, in the sense that we need fewer augmented samples to achieve the same test error. This setting is shown in Sect. 5.5.
We compare against the pooled estimator, which has the same architecture as the network to which we add the CoRe penalty. For both the pooled and the CoRe estimator we apply an \(\ell _2\) penalty as regularization. We would like to stress that the related work discussed in Sects. 1.2 and 6 cannot be directly compared to the CoRe estimator as these approaches cannot exploit the \(\mathrm {ID}\) information but rely on having data from different domains available at training time instead. Since this is a different problem setting, we can only compare against the pooled estimator, which is a standard approach to classification. As a downside, our approach requires an \(\mathrm {ID}\) variable, which might not always be available.Footnote 9 To further understand the behavior of the CoRe penalty, we perform a number of analyses and ablation studies to show
- (i) How sensitive the performance of CoRe is to the value of the penalty weight \(\lambda\) (Sects. 5.1.1, 5.2);
- (ii) How the CoRe penalty differs from a standard \(\ell _2\) penalty (Sect. 5.1.1);
- (iii) How the value of the CoRe penalty can be used as a qualitative measure for the presence of sample bias (Sects. 5.1.1, 5.2);
- (iv) How sensitive the performance of both the CoRe and the pooled estimator is to label shift in the grouped observations (Sect. 5.2.1);
- (v) How the relative performance of both estimators is affected when using pre-trained InceptionV3 features (Sect. 5.2.2);
- (vi) How sensitive the performance is to different grouping strategies (Sects. 5.3, “Appendix D.1”, D.3.1, D.4);
- (vii) How sensitive the performance is as a function of the strength of the domain shift and the number of grouped observations (“Appendices D.1, D.4, D.5, D.6”).
Details of the network architectures can be found in “Appendix C”. All reported error rates are averaged over five runs of the respective method. A TensorFlow (Abadi et al. 2015) implementation of CoRe can be found at https://github.com/christinaheinze/core.
5.1 Eyeglasses detection with small sample size
In this example, we explore a setting where training and test data are drawn from the same distribution, so we might not expect a distributional shift between the two. However, we consider a small training sample size which gives rise to statistical fluctuations between training and test data. We assess to what extent the conditional variance penalty can help to improve test accuracies in this setting.
Specifically, we use a subsample of the CelebA dataset (Liu et al. 2015) and try to classify images according to whether or not the person in the image wears glasses. For construction of the \(\mathrm {ID}\) variable, we exploit the fact that several photos of the same person are available and set \(\mathrm {ID}\) to be the identifier of the person in the dataset. Figure 5 shows examples from both the training and the test dataset. The conditional variance penalty is estimated across groups of observations that share a common \((Y,\mathrm {ID})\). Here, this corresponds to pictures of the same person where all pictures show the person either with glasses (if \(Y=1\)) or all pictures show the person without glasses (\(Y=0\)). Statistical fluctuations between training and test set could for instance arise if by chance the background of eyeglass wearers is darker in the training sample than in test samples, the eyeglass wearers happen to be outdoors more often or might be more often female than male etc.
Below, we present the following analyses. First, we look at five different datasets and analyze the effect of adding the CoRe penalty (using conditional-variance-of-prediction) to the cross-entropy loss. Second, we focus on one dataset and compare the four different variants of the CoRe penalty in Eqs. (9) and (10) with \(\nu \in \{ 1/2, 1 \}.\)
5.1.1 CoRe penalty using the conditional variance of the predicted logits
We consider five different training sets which are created as follows. For each person in the standard CelebA training data we count the number of available images and select the 50 identities for which most images are available individually. We partition these 50 identities into 5 disjoint subsets of size 10 and consider the resulting 5 datasets, containing the images of 10 unique identities each. The resulting 5 datasets have sizes \(\{289, 296, 292, 287, 287\}\). For the validation and the test set, we consider the usual CelebA validation and test split but balance these with respect to the target variable “Eyeglasses”. The balanced validation set consists of 2766 observations; the balanced test set contains 2578 images. The identities in the validation and test sets are disjoint from the identities in the training sets.
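The construction of the five training sets can be sketched as follows. The input is assumed to be a list holding the identity of each training image; the function name is hypothetical, and how the 50 identities are assigned to the five subsets (here: consecutive blocks in order of image count) is an assumption, since the text only states that the subsets are disjoint and of size 10.

```python
from collections import Counter


def build_identity_subsets(identity_per_image, n_identities=50, n_subsets=5):
    """Select the identities with the most images and split them into
    disjoint, equally sized subsets; return the subsets and, for each
    subset, the indices of the images belonging to it."""
    counts = Counter(identity_per_image)
    top = [ident for ident, _ in counts.most_common(n_identities)]
    size = len(top) // n_subsets
    subsets = [top[k * size:(k + 1) * size] for k in range(n_subsets)]
    datasets = []
    for subset in subsets:
        keep = set(subset)
        datasets.append(
            [i for i, ident in enumerate(identity_per_image) if ident in keep]
        )
    return subsets, datasets
```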
Given a training dataset, the standard approach would be to pool all examples. The only additional information we exploit is that some observations can be grouped. When using a 5-layer convolutional neural network with a standard ridge penalty (details can be found in Table 5) and pooling all data, the test error on unseen images ranges from 18.08 to 25.97%. Exploiting the group structure with the CoRe penalty (in addition to a ridge penalty) results in test errors ranging from 14.79 to 21.49%; see Table 1. The relative improvements when using the CoRe penalty range from 9 to 28.6%.
The test error is not very sensitive to the weight of the CoRe penalty as shown in Fig. 6a: for a large range of penalty weights, adding the CoRe penalty decreases the test error compared to the pooled estimator (identical to a CoRe penalty weight of 0). This holds true for various ridge penalty weights.
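While the test error rates shown in Fig. 6 suggest already that the CoRe penalty differentiates itself clearly from a standard ridge penalty, we examine next the differential effect of the CoRe penalty on the between- and within-group variances. Concretely, the variance of the predictions can be decomposed as
$$\begin{aligned} \text {Var}(f_\theta (X)) = E\big [ \text {Var}(f_\theta (X)\vert Y,\mathrm {ID}) \big ] + \text {Var}\big ( E[ f_\theta (X)\vert Y,\mathrm {ID} ] \big ),\end{aligned}$$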
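where the first term on the rhs is the within-group variance that CoRe penalizes, while a ridge penalty would penalize both the within- and also the between-group variance (the second term on the rhs above). In Fig. 6b we show the ratio between the CoRe penalty and the between-group variance where groups are defined by conditioning on \((Y,\mathrm {ID})\). Specifically, the ratio is computed as
$$\begin{aligned} \frac{{\hat{E}}\big [ \widehat{\text {Var}}(f_\theta (X)\vert Y,\mathrm {ID}) \big ]}{\widehat{\text {Var}}\big ( {\hat{E}}[ f_\theta (X)\vert Y,\mathrm {ID} ] \big )}.\end{aligned}$$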
The results shown in Fig. 6b are computed on dataset 1 (DS 1). While increasing ridge penalty weights do lead to a smaller value of the CoRe penalty, the between-group variance is also reduced such that the ratio between the two terms does not decrease with larger weights of the ridge penalty.Footnote 10 With increasing weight of the CoRe penalty, the variance ratio decreases, showing that the CoRe penalty indeed penalizes the within-group variance more than the between-group variance.
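The within- and between-group variances discussed here can be computed from a vector of predictions and group labels as in the following sketch. It weights each group by its share of the sample so that the two terms add up to the overall variance (law of total variance). The function names are hypothetical, and the exact weighting used for Fig. 6b may differ from this choice.

```python
from collections import defaultdict


def variance_decomposition(predictions, groups):
    """Return (within, between, total) variance of the predictions, where
    groups are defined by the (Y, ID) value of each observation."""
    n = len(predictions)
    overall_mean = sum(predictions) / n
    by_group = defaultdict(list)
    for f, g in zip(predictions, groups):
        by_group[g].append(f)
    within, between = 0.0, 0.0
    for values in by_group.values():
        weight = len(values) / n
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        within += weight * var                        # E[Var(f | Y, ID)]
        between += weight * (mean - overall_mean) ** 2  # Var(E[f | Y, ID])
    total = sum((f - overall_mean) ** 2 for f in predictions) / n
    return within, between, total


def variance_ratio(predictions, groups):
    """Ratio of within-group to between-group variance."""
    within, between, _ = variance_decomposition(predictions, groups)
    return within / between
```

A ratio that decreases with the CoRe penalty weight indicates that the within-group variance shrinks faster than the between-group variance.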
Table 1 also reports the value of the CoRe penalty after training when evaluated for the pooled and the CoRe estimator on the training and the test set. As a qualitative measure to assess the presence of sample bias in the data (provided the model assumptions hold), we can compare the value the CoRe penalty takes after training when evaluated for the pooled estimator and the CoRe estimator. The difference yields a measure for the extent to which the respective estimators are functions of \(\varDelta\). If the respective hold-out values are both small, this would indicate that the style features are not very predictive for the target variable. If, on the other hand, the CoRe penalty evaluated for the pooled estimator takes a much larger value than for the CoRe estimator (as in this case), this would indicate the presence of sample bias.
5.1.2 Other CoRe penalty types
We now compare all CoRe penalty types, i.e., penalizing with (i) the conditional variance of the predicted logits \({\hat{C}}_{f,1,\theta }\), (ii) the conditional standard deviation of the predicted logits \({\hat{C}}_{f,1/2,\theta }\), (iii) the conditional variance of the loss \({\hat{C}}_{l,1,\theta }\) and (iv) the conditional standard deviation of the loss \({\hat{C}}_{l,1/2,\theta }\). For this comparison, we use the training dataset 1 (DS 1) from above. Table 2 contains the test error (training error was \(0\%\) for all methods) as well as the value the respective CoRe penalty took after training on the training set and the test set. The four CoRe penalty variants’ performance differences are not statistically significant. Hence, we mostly focus on the conditional variance of the predicted logits \({\hat{C}}_{f,1,\theta }\) in the other experiments.
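The four variants differ only in what is grouped (predicted logits or per-sample losses) and in the exponent \(\nu\) applied to the group-wise variance. A minimal sketch, assuming the penalty is the average over all \((Y,\mathrm {ID})\) groups of the group-wise variance raised to the power \(\nu\) (the averaging convention and the function name are assumptions, not the authors' code):

```python
from collections import defaultdict


def core_penalty(values, groups, nu=1.0):
    """Average over (Y, ID) groups of the within-group variance of `values`
    raised to the power nu. Pass logits for the prediction-based variants
    and per-sample losses for the loss-based variants; nu=1 gives the
    conditional variance, nu=0.5 the conditional standard deviation."""
    by_group = defaultdict(list)
    for v, g in zip(values, groups):
        by_group[g].append(v)
    per_group = []
    for vals in by_group.values():
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        per_group.append(var ** nu)   # singleton groups contribute zero
    return sum(per_group) / len(per_group)
```

For example, `core_penalty(logits, groups, nu=1.0)` would correspond to \({\hat{C}}_{f,1,\theta }\) and `core_penalty(losses, groups, nu=0.5)` to \({\hat{C}}_{l,1/2,\theta }\) under these assumptions.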
5.1.3 Discussion
While the distributional shift in this example arises due to statistical fluctuations which will diminish as the sample size grows, the following examples are more concerned with biases that will persist even if the number of training and test samples is very large. A second difference to the subsequent examples is the grouping structure—in this example, we consider only a few identities, namely \(m=10\), with a relatively large number \(n_i\) of associated observations (about thirty observations per individual). In the following examples, \(m\) is much larger while \(n_i\) is typically smaller than five.
5.2 Gender classification with unknown confounding
In the following set of experiments, we work again with the CelebA dataset and the 5-layer convolutional neural network architecture described in Table 5. This time we consider the problem of classifying whether the person shown in the image is male or female. We create a confounding in the training set and test set 1 by including mostly images of men wearing glasses and women not wearing glasses. In test set 2 the association between gender and glasses is flipped: women always wear glasses while men never wear glasses. Examples from the training set and test sets 1 and 2 are shown in Fig. 7. The training set and test sets 1 and 2 are subsampled such that they are balanced with respect to Y, resulting in 16,982, 4224 and 1120 observations, respectively.
To compute the conditional variance penalty, we use again images of the same person. The \(\mathrm {ID}\) variable is, in other words, the identity of the person and gender Y is constant across all examples with the same \(\mathrm {ID}\). Conditioning on \((Y,\mathrm {ID})\) is hence identical to conditioning on \(\mathrm {ID}\) alone. Another difference to the other experiments is that we consider a binary style feature here.
5.2.1 Label shift in grouped observations
We compare six different datasets that vary with respect to the distribution of Y in the grouped observations. In all training datasets, the total number of observations is 16,982 and the total number of grouped observations is 500. In the first dataset, 50% of the grouped observations correspond to males and 50% correspond to females. In the remaining 5 datasets, we increase the proportion of grouped observations with \(Y=\text {``man''}\), denoted by \(\kappa\), to 75%, 90%, 95%, 99% and 100%, respectively. Table 3 shows the performance obtained for these datasets when using the pooled estimator compared to the CoRe estimator with \({\hat{C}}_{f,1,\theta }\). The results show that both the pooled estimator and the CoRe estimator perform better if the distribution of Y in the grouped observations is more balanced. The CoRe estimator improves the error rate of the pooled estimator by approximately 28 to 39% on a relative scale. Figure 8 shows the performance for \(\kappa =50\%\) as a function of the CoRe penalty weight. Significant improvements can be obtained across a large range of values for the CoRe penalty and the ridge penalty. Test errors become more sensitive to the chosen value of the CoRe penalty for very large values of the ridge penalty weight as the overall amount of regularization is already large.
5.2.2 Using pre-trained Inception V3 features
To verify that the above conclusions do not change when using more powerful features, we here compare \(\ell _2\)-regularized logistic regression using pre-trained Inception V3 featuresFootnote 11 with and without the CoRe penalty. Table 4 shows the results for \(\kappa =50\%\). While the results show that both the pooled estimator and the CoRe estimator perform better using pre-trained Inception features, the relative improvement with the CoRe penalty is still 28% on test set 2.
5.2.3 Ablation experiments
In Sect. D.3.1, we report results for the following two additional baselines: (i) we group all examples sharing the same class label and penalize with the conditional variance of the predicted logits, computed over these two groups; (ii) we penalize the overall variance of the predicted logits, i.e., a form of unconditional variance regularization.
5.3 Eyeglasses detection with known and unknown image quality intervention
We now revisit the second example from Sect. 1.1. We again use the CelebA dataset and consider the problem of classifying whether the person in the image is wearing eyeglasses. Here, we modify the images in the following way: in the training set and in test set 1, we sample the image qualityFootnote 12 for all samples \(\{i:y_i =1\}\) (all samples that show glasses) from a Gaussian distribution with mean \(\mu =30\) and standard deviation \(\sigma =10\). Samples with \(y_i=0\) (no glasses) are unmodified. In other words, if the image shows a person wearing glasses, the image quality tends to be lower. In test set 2, the quality is reduced in the same way for \(y_i =0\) samples (no glasses), while images with \(y_i =1\) are not changed. Figure 9 shows examples from the training set and test sets 1 and 2. For the CoRe penalty, we calculate the conditional variance across images that share the same \(\mathrm {ID}\) if \(Y=1\), that is across images that show the same person wearing glasses on all images. Observations with \(Y=0\) (not wearing glasses) are not grouped. Two examples are shown in the red box of Fig. 9. Here, we have \(c=5000\) grouped observations among a total sample size of \(n=20{,}000\).
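The label-dependent image quality intervention can be sketched as follows. The sketch only samples one quality value per affected image; how that value is applied to the image (for instance through lossy re-compression) is not shown. The function name, the clipping of sampled values to the range [1, 100] and the use of `None` for unmodified images are assumptions made for illustration.

```python
import random


def sample_image_quality(labels, degrade_label, mean=30.0, sd=10.0, seed=0):
    """For each label equal to `degrade_label`, draw a quality value from a
    Gaussian with the given mean and standard deviation; other images are
    left unmodified (None)."""
    rng = random.Random(seed)
    qualities = []
    for y in labels:
        if y == degrade_label:
            q = rng.gauss(mean, sd)
            # Clipping to a valid quality range is an assumption of this sketch.
            qualities.append(min(100.0, max(1.0, q)))
        else:
            qualities.append(None)
    return qualities


labels = [1, 0, 1, 0]
train_quality = sample_image_quality(labels, degrade_label=1)   # training set, test set 1
test2_quality = sample_image_quality(labels, degrade_label=0)   # test set 2: flipped
```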
Figure 9 shows misclassification rates for CoRe and the pooled estimator on test sets 1 and 2. The pooled estimator (only penalized with an \(\ell _2\) penalty) achieves a low error rate of 2% on test set 1, but suffers from a 65% misclassification error on test set 2, as now the relation between Y and the implicit \(S\) variable (image quality) has been flipped. The CoRe estimator has a larger error of 13% on test set 1 as image quality as a feature is penalized by CoRe implicitly and the signal is less strong if image quality has been removed as a dimension. However, on test set 2 the error of the CoRe estimator is 28%, a substantial improvement on the 65% error of the pooled estimator. The reason is again the same: the CoRe penalty ensures that image quality is not used as a feature to the same extent as for the pooled estimator. This increases the test error slightly if the samples are generated from the same distribution as training data (as here for test set 1) but substantially improves the test error if the distribution of image quality, conditional on the class label, is changed on test data (as here for test set 2).
Eyeglasses detection with known image quality intervention
To compare to the above results, we repeat the experiment by changing the grouped observations as follows. Above, we grouped images that had the same person \(\mathrm {ID}\) when \(Y=1\). We refer to this scheme of grouping observations with the same \((Y, \mathrm {ID})\) as ‘Grouping setting 2’. Here, we use an explicit augmentation scheme and augment \(c=5000\) images with \(Y=1\) in the following way: each image is paired with a copy of itself and the image quality is adjusted as described above. In other words, the only difference between the two images is that image quality differs slightly, depending on the value that was drawn from the Gaussian distribution with mean \(\mu =30\) and standard deviation \(\sigma =10\), determining the strength of the image quality intervention. Both the original and the copy get the same value of identifier variable \(\mathrm {ID}\). We call this grouping scheme ‘Grouping setting 1’. Compare the left panels of Figs. 9 and 10 for examples.
While we used explicit changes in image quality both above and here, we referred to grouping setting 2 as ‘unknown image quality interventions’ as the training sample as in the left panel of Fig. 9 does not immediately reveal that image quality is the important style variable. In contrast, the augmented data samples (grouping setting 1) we use here differ only in their image quality for a constant \((Y,\mathrm {ID})\).
Figure 10 shows examples and results. The pooled estimator performs more or less identically to the previous dataset. The explicit augmentation did not help as the association between image quality and whether eyeglasses are worn is not changed in the pooled data after including the augmented data samples. The misclassification error of the CoRe estimator is substantially better than the error rate of the pooled estimator. The error rate on test set 2 of 13% also improves on the rate of 28% of the CoRe estimator in grouping setting 2. We see that using grouping setting 1 works best since we could explicitly control that only \(S\equiv \textit{image quality}\) varies between grouped examples. In grouping setting 2, different images of the same person can vary in many factors, making it more challenging to isolate image quality as the factor to be invariant against.
A similar example where \(S\equiv \textit{brightness}\) is summarized in “Appendix D.1”.
5.4 Stickmen image-based age classification with unknown movement interventions
In this example we consider synthetically generated stickmen images; see Fig. 11 for some examples. The target of interest is \(Y \in \{\text {adult}, \text {child}\}\). The core feature \(C\) is here the height of each person. The class Y is causal for height, and height cannot easily be intervened on, nor does it change across domains. Height is thus a robust predictor for differentiating between children and adults. As style feature we have here the movement of a person (distribution of angles between body, arms and legs). For the training data we created a dependence between age and the style feature ‘movement’, which can be thought to arise through a hidden common cause \(D\), namely the place of observation. For instance, the images of children might mostly show children playing while the images of adults typically show them in more “static” postures. The left panel of Fig. 11 shows examples from the training set where large movements are associated with children and small movements are associated with adults. Test set 1 follows the same distribution, as shown in the middle panel. A standard CNN will exploit this relationship between movement and the label Y of interest, whereas this is discouraged by the conditional variance penalty of CoRe. The latter pairs images of the same person in slightly different movements, as shown by the red boxes in the leftmost panel of Fig. 11. If the learned model exploits this dependence between movement and age for predicting Y, it will fail when presented with images of, say, dancing adults. The right panel of Fig. 11 shows such examples (test set 2). The standard CNN suffers in this case from a 41% misclassification rate, as opposed to the 3% on test set 1 data. With as few as \(c=50\) paired observations, the network with an added CoRe penalty, in contrast, achieves a 4% error on test set 1 and a 9% error on test set 2, whereas the pooled estimator fails on test set 2 with a test error of 41%.
These results suggest that the learned representation of the pooled estimator uses movement as a predictor for age while CoRe does not use this feature due to the conditional variance regularization. Importantly, including more grouped examples would not improve the performance of the pooled estimator as these would be subject to the same bias and hence also predominantly have examples of heavily moving children and “static” adults (also see Fig. 23 which shows results for \(c\in \{ 20, 500, 2000\}\)).
5.5 MNIST: more sample efficient data augmentation
The goal of using CoRe in this example is to make data augmentation more efficient in terms of the required samples. In data augmentation, one creates additional samples by modifying the original inputs, e.g. by rotating, translating, or flipping the images (Schölkopf et al. 1996). In other words, additional samples are generated by interventions on style features. Using this augmented data set for training results in invariance of the estimator with respect to the transformations (style features) of interest. For CoRe we can use the grouping information that the original and the augmented samples belong to the same object. This enforces the invariance with respect to the style features more strongly compared to normal data augmentation which just pools all samples. We assess this for the style feature ‘rotation’ on MNIST (LeCun et al. 1998) and only include \(c=200\) augmented training examples for \(m=10{,}000\) original samples, resulting in a total sample size of \(n=10{,}200\). The degree of the rotations is sampled uniformly at random from [35, 70]. Figure 12 shows examples from the training set. By using CoRe the average test error on rotated examples is reduced from 22% to 10%. Very few augmented samples are thus sufficient to lead to stronger rotational invariance. The standard approach of creating augmented data and pooling all images requires, in contrast, many more samples to achieve the same effect. Additional results for \(m\in \{1000, 10{,}000\}\) and \(c\) ranging from 100 to 5000 can be found in Fig. 22 in Appendix Sect. D.5.
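The bookkeeping behind this augmentation scheme can be sketched as below: each original image keeps its own index as \(\mathrm {ID}\), and each augmented copy inherits the \(\mathrm {ID}\) of the image it was generated from, so that the pair forms a group for the penalty. The function name and the record layout are hypothetical, which originals receive a rotated copy is chosen at random here as an assumption, and the rotation itself (to be applied with an image library of choice) is not shown.

```python
import random


def plan_rotation_augmentation(m, c, low=35.0, high=70.0, seed=0):
    """Return one record per training sample: m originals (angle 0) plus
    c rotated copies that share the ID of their source image."""
    rng = random.Random(seed)
    originals = rng.sample(range(m), c)      # which originals get a rotated copy
    plan = [{"id": i, "source": i, "angle": 0.0} for i in range(m)]
    plan += [
        {"id": i, "source": i, "angle": rng.uniform(low, high)}
        for i in originals
    ]
    return plan


plan = plan_rotation_augmentation(m=10000, c=200)
```

The pooled estimator would train on all records in `plan` and ignore the `id` field, whereas CoRe additionally penalizes the variance of the prediction within each pair sharing an `id`.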
5.6 Elmer the Elephant
In this example, we want to assess whether invariance with respect to the style feature ‘color’ can be achieved. In the children’s book ‘Elmer the elephant’Footnote 13 one instance of a colored elephant suffices to recognize it as being an elephant, making the color ‘gray’ no longer an integral part of the object ‘elephant’. Motivated by this process of concept formation, we would like to assess whether CoRe can exclude ‘color’ from its learned representation by penalizing conditional variance appropriately.
We work with the ‘Animals with attributes 2’ (AwA2) dataset (Xian et al. 2017) and consider classifying images of horses and elephants. We include additional examples by adding grayscale images for \(c=250\) images of elephants. These additional examples do not distinguish themselves strongly from the original training data as the elephant images are already close to grayscale images. The total training sample size is 1850.
Figure 13 shows examples from the training set and the test sets, together with misclassification rates for CoRe and the pooled estimator on the different test sets. Examples from these and more test sets can be found in Fig. 24. Test set 1 contains original, colored images only. In test set 2 images of horses are in grayscale and the colorspace of elephant images is modified, effectively changing the color gray to red-brown. We observe that the pooled estimator does not perform well on test set 2 as its learned representation seems to exploit the fact that ‘gray’ is predictive for ‘elephant’ in the training set. This association is no longer valid for test set 2. In contrast, the predictive performance of CoRe is hardly affected by the changing color distributions. More details can be found in “Appendix D.7”.
It is noteworthy that a colored elephant can be recognized as an elephant by adding a few examples of a grayscale elephant to the very lightly colored pictures of natural elephants. If we just pool over these examples, there is still a strong bias that elephants are gray. The CoRe estimator, in contrast, demands invariance of the prediction for instances of the same elephant and we can learn color invariance with a few added grayscale images.
6 Further related work
Encoding certain invariances in estimators is a well-studied area in computer vision and machine learning with an extensive body of literature. While a large part of this work assumes the desired invariance to be known, fewer approaches aim to learn the required invariances from data and the focus often lies on geometric transformations of the input data or explicitly creating augmented observations (Sohn and Lee 2012; Khasanova and Frossard 2017; Hashimoto et al. 2017; Devries and Taylor 2017). The main differences between this line of work and CoRe are that we do not need to know the style feature explicitly, that the set of possible style features is not restricted to a particular class of transformations, and that we do not aim to create augmented observations in a generative framework.
Recently, various approaches have been proposed that leverage causal motivations for deep learning or use deep learning for causal inference, related to e.g. the problems of cause-effect inference and generative adversarial networks (Chalupka et al. 2014; Lopez-Paz et al. 2017; Lopez-Paz and Oquab 2017; Goudet et al. 2017; Bahadori et al. 2017; Besserve et al. 2018; Kocaoglu et al. 2018).
Kilbertus et al. (2017) exploit causal reasoning to characterize fairness considerations in machine learning. Distinguishing between the protected attribute and its proxies, they derive causal non-discrimination criteria. The resulting algorithms avoiding proxy discrimination require classifiers to be constant as a function of the proxy variables in the causal graph, thereby bearing some structural similarity to our style features.
Distinguishing between core and style features can be seen as some form of disentangling factors of variation. Estimating disentangled factors of variation has gathered a lot of interest in the context of generative modeling. As in CoRe, Bouchacourt et al. (2018) exploit grouped observations. In a variational autoencoder framework, they aim to separate style and content: they assume that samples within a group share a common but unknown value for one of the factors of variation while the style can differ. Denton and Birodkar (2017) propose an autoencoder framework to disentangle style and content in videos using an adversarial loss term where the grouping structure induced by clip identity is exploited. Here we try to solve a classification task directly without estimating the latent factors explicitly as in a generative framework.
In the computer vision literature, various works have used identity information to achieve pose invariance in the context of face recognition (Bartlett and Sejnowski 1996; Tran et al. 2017). More generally, the idea of exploiting various observations of the same underlying object is related to multi-view learning (Xu et al. 2013). In the context of adversarial examples, Kannan et al. (2018) recently proposed the defense “Adversarial logit pairing” which is methodologically equivalent to the CoRe penalty \({C}_{f,1,\theta }\) when using the squared error loss. Several empirical studies have shown mixed results regarding the performance on \(\ell _\infty\) perturbations (Engstrom et al. 2018; Mosbach et al. 2018). So far this setting has not been analyzed theoretically, and hence it is an open question whether a CoRe-type penalty constitutes an effective defense against adversarial examples.
7 Conclusion
Distinguishing the latent features in an image into core and style features, we have proposed conditional variance regularization (CoRe) to achieve robustness with respect to interventions on the style or “orthogonal” features. The main idea of the CoRe estimator is to exploit the fact that we often have instances of the same object in the training data. By demanding invariance of the classifier amongst a group of instances that relate to the same object, we can achieve invariance of the classification performance with respect to interventions on style features such as image quality, fashion type, color, or body posture. The training also works despite sampling biases in the data.
There are two main application areas:
1. If the style features are known explicitly, we can achieve the same classification performance as standard data augmentation approaches with substantially fewer augmented samples, as shown for example in Sect. 5.5.
2. Perhaps more interesting are settings in which it is unknown what the style features are, with examples in Sects. 5.1, 5.2, 5.3, 5.4 and “Appendix D.1”. CoRe regularization forces predictions to be based on features that do not vary strongly between instances of the same object. We could show in the examples and in Theorems 1 and 2 that this regularization achieves distributional robustness with respect to changes in the distribution of the (unknown) style variables.
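To make the grouping idea concrete, the following minimal sketch computes a conditional-variance penalty on predictions, grouping observations that share \((Y,\mathrm {ID})\). The function name `core_penalty`, the use of `None` for unobserved identifiers, and the averaging over groups are assumptions made for this illustration; it is not the implementation used in the experiments.

```python
from collections import defaultdict


def core_penalty(preds, labels, ids):
    """Average within-group variance of predictions, grouping on (label, id).

    Observations with id None are treated as singleton groups, whose
    within-group variance is zero.
    """
    groups = defaultdict(list)
    for k, (p, y, i) in enumerate(zip(preds, labels, ids)):
        key = (y, i) if i is not None else ("singleton", k)
        groups[key].append(p)
    total = 0.0
    for values in groups.values():
        mean = sum(values) / len(values)
        total += sum((v - mean) ** 2 for v in values) / len(values)
    return total / len(groups)


preds = [0.9, 0.5, 0.2, 0.7]
labels = [1, 1, -1, 1]
ids = ["a", "a", "b", None]
penalty = core_penalty(preds, labels, ids)
```

Only the first two observations share a class and an identifier, so only their disagreement contributes to the penalty. In training, such a term would be added to the pooled loss with a tuning weight.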
An interesting line of work would be to use larger models such as Inception or large ResNet architectures (Szegedy et al. 2015; He et al. 2016). These models have been trained to be invariant to an array of explicitly defined style features. In Sect. 5.2.2 we include results which show that using Inception V3 features does not guard against interventions on more implicit style features. We would thus like to assess what benefits CoRe can bring for training Inception-style models end-to-end, both in terms of sample efficiency and in terms of generalization performance. While we showed some examples where the necessary grouping information is available, an interesting possible future direction would be to use video data since objects display temporal constancy and the temporal information can hence be used for grouping and conditional variance regularization. Beyond that, our results show that it can be worthwhile to collect \(\mathrm {ID}\) information when new datasets are created. As CoRe only requires a subset of the observations to have \(\mathrm {ID}\) annotations, in many cases this information might be cheap to collect while it can improve performance substantially when future test data is subject to domain shifts.
Notes
Segmenting eyelashes from the iris is not entirely accurate which implies that the iris images can still contain parts of eyelashes, occluding the iris. As mascara causes the eyelashes to be thicker and darker, it is difficult to entirely remove the presence of cosmetics from the iris images.
The distinction between ‘conditionally independent’ features and ‘conditionally transferable’ (which is the former modulo location and scale transformations) is for our purposes not relevant as we do not make a linearity assumption in general.
If an existing image is classified by a human, then the image is certainly ancestral for the attached label. If the label Y refers, however, to the underlying true object (say if you generate images by asking people to take pictures of objects or if you record the status of cells after performing a gene knockout), then the more fitting model is the one where Y is ancestral for X. We here focus on modeling the underlying true object since ultimately the goal is to deploy the trained system in the real world.
The type of features we regard as style and which ones we regard as core features can conceivably change depending on the circumstances—for instance, is the color “gray” an integral part of the object “elephant” or can it be changed so that a colored elephant is still considered to be an elephant?
Observations where the \(\mathrm {ID}\) variable is unobserved are not grouped, that is each such observation is counted as a unique observation of \((Y,\mathrm {ID})\).
As an example, if the change in distribution for \(S\) is caused by random shift-interventions \(\varDelta\), then \(\tilde{S}\leftarrow S+\varDelta\), and the distance \(D_{\text {style}}\) induced in the distributions is
$$\begin{aligned} D_{\text {style}} (F_0,F) \le E\big [ E( \varDelta ^t \Sigma ^{-1}_{y,\mathrm {id}} \varDelta |Y=y,\mathrm {ID}=\mathrm {id}) \big ],\end{aligned}$$
ensuring that the strength of the shifts is measured against the natural variability \(\Sigma _{y,\mathrm {id}}\) of the style features.
The right hand side can also be interpreted as the graph Laplacian (Belkin et al. 2006) of an appropriately weighted graph that fully connects all observations \(i \in G_j\) for each \(j\in \{1,\ldots ,m\}\).
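The graph Laplacian interpretation can be checked numerically. The sketch below assumes that the quantity in question is a sum of within-group squared deviations from the group mean and that every edge inside a group \(G_j\) carries weight \(1/|G_j|\); both the weighting and the helper names are assumptions made for this illustration.

```python
def within_group_sum_of_squares(f, groups):
    """Sum over groups of squared deviations of f from the group mean."""
    total = 0.0
    for g in groups:
        mean = sum(f[i] for i in g) / len(g)
        total += sum((f[i] - mean) ** 2 for i in g)
    return total


def group_laplacian(n, groups):
    """Laplacian of the graph that fully connects each group with weight 1/|G_j|."""
    L = [[0.0] * n for _ in range(n)]
    for g in groups:
        w = 1.0 / len(g)
        for i in g:
            for k in g:
                if i != k:
                    L[i][i] += w  # degree term
                    L[i][k] -= w  # adjacency term
    return L


def quadratic_form(L, f):
    """Compute f^t L f."""
    n = len(f)
    return sum(f[i] * L[i][k] * f[k] for i in range(n) for k in range(n))


f = [0.9, 0.5, 0.1, 0.2, 0.4]          # predictions
groups = [[0, 1, 2], [3, 4]]           # index sets G_1, G_2
ss = within_group_sum_of_squares(f, groups)
qf = quadratic_form(group_laplacian(len(f), groups), f)
```

Under these assumptions the two quantities `ss` and `qf` agree, which is the sense in which the penalty acts like a Laplacian regularizer on the grouped observations.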
If some labeled test data are available and our goal is to perform optimally for samples that come from the same distribution as those test samples, then we can and should adjust the parameter to minimize estimated test error.
For example, the home office dataset of Venkateswara et al. (2017) contains images of various objects (kettle, clock, etc.) from different domains (real-world photographs, cliparts, art images and product images). If we aim to predict the category of an object, it seems difficult to identify an appropriate \(\mathrm {ID}\) variable in the dataset. However, if we aim instead to predict properties of an object (does it use electricity?), then we can use the object category (kettle, clock, etc.) as \(\mathrm {ID}\).
In Fig. 18 in the Appendix, the numerator and the denominator are plotted separately as a function of the CoRe penalty weight.
Retrieved from https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1.
We use ImageMagick (https://www.imagemagick.org) to change the level of the JPEG compression through convert -quality q_ij input.jpg output.jpg where \(q_{i,j}\sim {\mathcal {N}}(30,100)\).
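A possible way to generate such commands is sketched below. It assumes that \({\mathcal {N}}(30,100)\) is parameterized by its variance (standard deviation 10) and that sampled values are rounded and clipped to the range 1 to 100; the function name and the output file naming are hypothetical.

```python
import random


def jpeg_quality_commands(filenames, mean=30.0, sd=10.0, seed=0):
    """Build one ImageMagick `convert -quality` command per image.

    Assumptions: sd = 10 corresponds to variance 100, and sampled values
    are rounded and clipped to 1..100.
    """
    rng = random.Random(seed)
    commands = []
    for name in filenames:
        q = int(round(rng.gauss(mean, sd)))
        q = min(max(q, 1), 100)
        # "out_" + name is a placeholder output path for this sketch.
        commands.append(["convert", "-quality", str(q), name, "out_" + name])
    return commands


commands = jpeg_quality_commands(["a.jpg", "b.jpg", "c.jpg"])
```

Each command is returned as an argument list, which could be passed to `subprocess.run` on a system where ImageMagick is installed.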
Recall that \((y_i,\mathrm {id}_i)=(y_{i'},\mathrm {id}_{i'})\) if \(i,i'\in G_j\) as the subsets \(G_j\), \(j=1,\ldots ,m\), collect all observations that have a unique realization of \((Y,\mathrm {ID})\).
Specifically, we use ImageMagick (https://www.imagemagick.org) and modify the brightness of each image by applying the command convert -modulate b_ij,100,100 input.jpg output.jpg to the image.
For more details, see http://www.imagemagick.org/Usage/color_mods/#color_mods.
References
Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., & Zheng, X. (2015). TensorFlow: Large-scale machine learning on heterogeneous systems. https://www.tensorflow.org/, software available from tensorflow.org.
Aldrich, J. (1989). Autonomy. Oxford Economic Papers, 41, 15–34.
Bagnell, J. (2005). Robust supervised learning. In Proceedings of the national conference on artificial intelligence, Menlo Park, CA; Cambridge, MA; London; AAAI Press; MIT Press; (1999) Vol. 20, p. 714.
Bahadori, M. T., Chalupka, K., Choi, E., Chen, R., Stewart, W. F., & Sun, J. (2017). Causal regularization. arXiv:170202604.
Barocas, S., & Selbst, A.D. (2016). Big Data’s Disparate Impact. 104 California Law Review 671.
Bartlett, M. S., & Sejnowski, T. J. (1996). Viewpoint invariant face recognition using independent component analysis and attractor networks. In Proceedings of the 9th International Conference on Neural Information Processing Systems (pp. 817–823), MIT Press, Cambridge, MA, USA, NIPS’96.
Belkin, M., Niyogi, P., & Sindhwani, V. (2006). Manifold regularization: A geometric framework for learning from labeled and unlabeled examples. Journal of Machine Learning Research, 7(Nov), 2399–2434.
Ben-David, S., Blitzer, J., Crammer, K., & Pereira, F. (2007). Analysis of representations for domain adaptation. In Advances in Neural Information Processing Systems, Vol. 19.
Ben-Tal, A., Den Hertog, D., De Waegenaere, A., Melenberg, B., & Rennen, G. (2013). Robust solutions of optimization problems affected by uncertain probabilities. Management Science, 59(2), 341–357.
Besserve, M., Shajarisales, N., Schölkopf, B., & Janzing, D. (2018). Group invariance principles for causal generative models. In Proceedings of the 21st International Conference on Artificial Intelligence and Statistics (AISTATS), PMLR, Proceedings of Machine Learning Research (pp. 557–565), Vol. 84.
Bolukbasi, T., Chang, K. W., Zou, J. Y., Saligrama, V., & Kalai, A. T. (2016). Man is to computer programmer as woman is to homemaker? Debiasing word embeddings. In Advances in Neural Information Processing Systems, Vol. 29.
Bouchacourt, D., Tomioka, R., & Nowozin, S. (2018). Multi-level variational autoencoder: Learning disentangled representations from grouped observations. In AAAI Conference on Artificial Intelligence.
Chalupka, K., Perona, P., & Eberhardt, F. (2014). Visual Causal Feature Learning. Uncertainty in Artificial Intelligence.
Crawford, K. (2016). Artificial intelligence’s white guy problem. The New York Times, June 25 2016 https://www.nytimes.com/2016/06/26/opinion/sunday/artificial-intelligences-white-guy-problem.html.
Csurka, G. (2017). A comprehensive survey on domain adaptation for visual applications. In Domain Adaptation in Computer Vision Applications (pp. 1–35).
Denton, E. L., & Birodkar, V. (2017). Unsupervised learning of disentangled representations from video. In Advances in Neural Information Processing Systems, Vol. 30.
Devries, T., & Taylor, G. W. (2017). Dataset augmentation in feature space. ICLR Workshop Track.
Emspak, J. (2016). How a machine learns prejudice. Scientific American, December 29 2016 https://www.scientificamerican.com/article/how-a-machine-learns-prejudice/.
Engstrom, L., Ilyas, A., & Athalye, A. (2018). Evaluating and understanding the robustness of adversarial logit pairing. arXiv:180710272.
Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., et al. (2016). Domain-adversarial training of neural networks. Journal of Machine Learning Research, 17(1), 2096–2030.
Gao, R., Chen, X., & Kleywegt, A. (2017). arXiv:171206050.
Gong, B., Grauman, K., & Sha, F. (2013). Reshaping visual datasets for domain adaptation. In Advances in Neural Information Processing Systems (Vol. 26, pp. 1286–1294), Curran Associates, Inc.
Gong, M., Zhang, K., Liu, T., Tao, D., Glymour, C., & Schölkopf, B. (2016). Domain adaptation with conditional transferable components. In International Conference on Machine Learning.
Goodfellow, I., Shlens, J., & Szegedy C. (2015). Explaining and harnessing adversarial examples. In International conference on learning representations.
Goudet, O., Kalainathan, D., Caillou, P., Lopez-Paz, D., Guyon, I., Sebag, M., Tritas, A., & Tubaro, P. (2017). Learning functional causal models with generative neural networks. arXiv:170905321.
Haavelmo, T. (1944). The probability approach in econometrics. Econometrica 12:S1–S115 (supplement).
Hashimoto, T. B., Liang, P. S., & Duchi, J. C. (2017). Unsupervised transformation learning via convex relaxations. In Advances in neural information processing systems (Vol. 30, pp. 6875–6883), Curran Associates, Inc.
He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the 2015 IEEE international conference on computer vision (ICCV) (pp. 1026–1034).
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In 2016 IEEE conference on computer vision and pattern recognition (CVPR) (pp. 770–778).
Hoffman, J., Kulis, B., Darrell, T., & Saenko, K. (2012). Discovering latent domains for multisource domain adaptation. In Computer Vision - ECCV 2012 (pp. 702–715). Berlin Heidelberg, Berlin, Heidelberg: Springer.
Kannan, H., Kurakin, A., & Goodfellow, I. J. (2018). Adversarial logit pairing. arXiv:180306373.
Khasanova, R., & Frossard, P. (2017). Graph-based isometry invariant representation learning. In Proceedings of the 34th international conference on machine learning (Vol. 70, pp. 1847–1856).
Kilbertus, N., Rojas Carulla, M., Parascandolo, G., Hardt, M., Janzing, D., & Schölkopf, B. (2017). Avoiding discrimination through causal reasoning. Advances in Neural Information Processing Systems, 30, 656–666.
Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. International Conference on Learning Representations (ICLR).
Kocaoglu, M., Snyder, C., Dimakis, A., & Vishwanath, S. (2018). CausalGAN: Learning causal implicit generative models with adversarial training. In International conference on learning representations.
Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, Vol. 25.
Kuehlkamp, A., Becker, B., & Bowyer, K. (2017). Gender-from-iris or gender-from-mascara? In 2017 IEEE winter conference on applications of computer vision (WACV).
LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE.
Li, K. C. (1991). Sliced inverse regression for dimension reduction. Journal of the American Statistical Association, 86(414), 316–327.
Liu, Z., Luo, P., Wang, X., & Tang, X. (2015). Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV).
Lopez-Paz, D., & Oquab, M. (2017). Revisiting classifier two-sample tests. In International conference on learning representations (ICLR).
Lopez-Paz, D., Nishihara, R., Chintala, S., Schölkopf, B., & Bottou, L. (2017). Discovering causal signals in images. In The IEEE conference on computer vision and pattern recognition (CVPR 2017).
Magliacane, S., van Ommen, T., Claassen, T., Bongers, S., Versteeg, P., & Mooij, J. (2018). Domain adaptation by using causal inference to predict invariant conditional distributions. In Advances in neural information processing systems.
Meinshausen, N. (2018). Causality from a distributional robustness point of view. In 2018 IEEE Data Science Workshop (DSW) (pp. 6–10).
Mosbach, M., Andriushchenko, M., Trost, T., Hein, M., & Klakow, D. (2018). Logit pairing methods can fool gradient-based attacks. arXiv:181012042.
Namkoong, H., & Duchi, J. (2017). Variance-based regularization with convex objectives. In Advances in Neural Information Processing Systems (pp. 2975–2984).
Pearl, J. (2009). Causality: Models, reasoning, and inference (2nd ed.). New York: Cambridge University Press.
Peters, J., Bühlmann, P., & Meinshausen, N. (2016). Causal inference using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society, Series B, 78, 947–1012.
Quionero-Candela, J., Sugiyama, M., Schwaighofer, A., & Lawrence, N. D. (2009). Dataset shift in machine learning. Cambridge: The MIT Press.
Richardson, T., & Robins, J. M. (2013). Single world intervention graphs (SWIGs): A unification of the counterfactual and graphical approaches to causality. Center for the Statistics and the Social Sciences, University of Washington Series Working Paper 128, 30 April 2013.
Rojas-Carulla, M., Schölkopf, B., Turner, R., & Peters, J. (2018). Causal transfer in machine learning. To appear in Journal of Machine Learning Research.
Rothenhäusler, D., Bühlmann, P., Meinshausen, N., & Peters, J. (2018). Anchor regression: heterogeneous data meets causality. arXiv:180106229.
Schölkopf, B., Burges, C., & Vapnik, V. (1996). Incorporating invariances in support vector learning machines. Artificial Neural Networks – ICANN 96 (pp. 47–52). Berlin Heidelberg, Berlin, Heidelberg: Springer.
Schölkopf, B., Janzing, D., Peters, J., Sgouritsa, E., Zhang, K., & Mooij, J. (2012). On causal and anticausal learning. In Proceedings of the 29th international conference on machine learning (ICML) (pp. 1255–1262).
Shafieezadeh-Abadeh, S., Kuhn, D., & Esfahani, P. (2017). Regularization via mass transportation. arXiv:171010016.
Sinha, A., Namkoong, H., & Duchi, J. (2018). Certifiable distributional robustness with principled adversarial training. In International conference on learning representations.
Sohn, K., & Lee, H. (2012). Learning invariant representations with local transformations. In Proceedings of the 29th international conference on machine learning, Omnipress, USA, ICML’12, pp. 1339–1346.
Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., & Rabinovich, A. (2015). Going deeper with convolutions. In Computer vision and pattern recognition (CVPR).
Szegedy, C., Zaremba, W., Sutskever, I., Bruna, J., Erhan, D., Goodfellow, I., & Fergus, R. (2014). Intriguing properties of neural networks. In International conference on learning representations.
Torralba, A., & Efros, A. A. (2011). Unbiased look at dataset bias. In Computer Vision and Pattern Recognition (CVPR).
Tran, L., Yin, X., & Liu, X. (2017). Disentangled representation learning gan for pose-invariant face recognition. In Proceeding of IEEE computer vision and pattern recognition, Honolulu, HI.
Venkateswara, H., Eusebio, J., Chakraborty, S., & Panchanathan, S. (2017). Deep hashing network for unsupervised domain adaptation. In (IEEE) Conference on Computer Vision and Pattern Recognition (CVPR).
Villani, C. (2003). Topics in optimal transportation (Vol. 58). Providence: American Mathematical Society.
Volpi, R., Namkoong, H., Sener, O., Duchi, J., Murino, V., & Savarese, S. (2018). Generalizing to unseen domains via adversarial data augmentation. arXiv:180512018.
Xian, Y., Lampert, C. H., Schiele, B., & Akata, Z. (2017). Zero-shot learning—A comprehensive evaluation of the good, the bad and the ugly. arXiv:170700600.
Xu, H., Caramanis, C., & Mannor, S. (2009). Robust regression and lasso. In Advances in Neural Information Processing Systems (pp. 1801–1808).
Xu, C., Tao, D., & Xu, C. (2013). A survey on multi-view learning. arXiv:13045634.
Yu, X., Liu, T., Gong, M., Zhang, K., & Tao, D. (2017). Transfer learning with label noise. arXiv:170709724.
Zhang, K., Gong, M., & Schölkopf, B. (2015). Multi-source domain adaptation: A causal view. In Proceedings of the Twenty-Ninth AAAI Conference on Artificial Intelligence.
Zhang, K., Schölkopf, B., Muandet, K., & Wang, Z. (2013). Domain adaptation under target and conditional shift. In International Conference on Machine Learning.
Acknowledgements
We thank Brian McWilliams, Jonas Peters, and Martin Arjovsky for helpful comments and discussions and CSCS for provision of computational resources. A preliminary version of this work was presented at the NIPS 2017 Interpretable ML Symposium and we thank participants of the symposium for very helpful discussions.
Funding
Open access funding provided by Swiss Federal Institute of Technology Zurich.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Editor: Paolo Frasconi.
Appendices
Proof of Theorem 1
First part To show the first part, namely that with probability 1,
$$\begin{aligned} L_\infty ({\hat{\theta }}^{pool}) = \infty ,\end{aligned}$$
we need to show that \({W}^t {\hat{\theta }}^{pool} \ne {0}\) with probability 1. The reason this is sufficient is as follows: if \({W}^t \theta \ne {0}\), then \(L_\infty (\theta )=\infty\) as we can then find a \(v\in \mathbb {R}^q\) such that \(\gamma := \theta ^t {W}v \ne 0\). Assume without loss of generality that \(v\) is normalized such that \(E(E(v^t \Sigma _{y,\mathrm {id}}^{-1} v|Y=y,\mathrm {ID}=\mathrm {id}))=1\). Setting \(\varDelta _\xi = \xi v\) for \(\xi \in \mathbb {R}\), we have that \((\mathrm {ID},Y,S+ \varDelta _\xi )\) is in the class \(F_{|\xi |}\) if the distribution of \((\mathrm {ID},Y,S)\) is equal to \(F_0\). Furthermore, \(x(\varDelta _\xi )^t \theta = x(\varDelta =0)^t \theta + \xi \gamma\). Hence \(\log (1+\exp (-y\cdot x(\varDelta _\xi )^t \theta ))\rightarrow \infty\) for either \(\xi \rightarrow \infty\) or \(\xi \rightarrow -\infty\).
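The divergence argument can be illustrated numerically. In the sketch below, `base_score` stands for \(x(\varDelta =0)^t \theta\) and `gamma` for \(\gamma\); the concrete values are arbitrary choices made for this illustration.

```python
import math


def logistic_loss(margin):
    """log(1 + exp(-margin)), evaluated in a numerically stable way."""
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


def shifted_loss(y, base_score, gamma, xi):
    """Logistic loss after a style shift of size xi along direction v."""
    return logistic_loss(y * (base_score + xi * gamma))


y, base_score, gamma = 1, 2.0, 0.5
# Shifting against the label drives the loss up without bound ...
losses_bad = [shifted_loss(y, base_score, gamma, xi) for xi in (0, -10, -100, -1000)]
# ... while shifting in the other direction drives it to zero.
losses_good = [shifted_loss(y, base_score, gamma, xi) for xi in (0, 10, 100, 1000)]
```

Whenever \(\gamma \ne 0\), one of the two shift directions makes the loss grow roughly linearly in \(|\xi |\), which is why the supremum over arbitrarily large shifts is infinite.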
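To show that \({W}^t {\hat{\theta }}^{pool} \ne {0}\) with probability 1, let \({\hat{\theta }}^*\) be the oracle estimator that is constrained to be orthogonal to the column space of \({W}\):
$$\begin{aligned} {\hat{\theta }}^* = \text {argmin}_{\theta :\, {W}^t \theta = {0}} \; L_n(\theta ), \qquad (14)\end{aligned}$$
with \(L_n\) the training loss.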
We show \({W}^t {\hat{\theta }}^{pool} \ne {0}\) by contradiction. Assume hence that \({W}^t {\hat{\theta }}^{pool} = {0}\). If this is indeed the case, then the constraint \({W}^t \theta = {0}\) in (14) becomes non-active and we have \({\hat{\theta }}^{pool} = {\hat{\theta }}^*\). This would imply that taking the directional derivative of the training loss with respect to any \(\delta \in \mathbb {R}^p\) in the column space of \({W}\) should vanish at the solution \({\hat{\theta }}^*\). In other words, define the gradient as \(g(\theta )=\nabla _\theta L_n(\theta ) \in \mathbb {R}^p\). The implication is then that for all \(\delta\) in the column-space of W,
$$\begin{aligned} \delta ^t g({\hat{\theta }}^*) = 0, \qquad (15)\end{aligned}$$
and we will show the latter condition is violated almost surely.
As we work with the logistic loss and \({\mathcal {Y}}\in \{-1,1\}\), the loss is given by \(\ell (y_i, f_\theta (x_i)) = \log (1+\exp (- y_i x_i^t\theta )) .\) Define \(r_i(\theta ):=y_i /(1+\exp ( y_i x_i^t \theta ))\). For all \(i=1,\ldots ,n\) we have \(r_i\ne 0\). Then
The training images can be written according to the model as \(x_i = x^0_i + W s_i\), where \(X^0:= k_x(C, \varepsilon _X)\) are the images in absence of any style variation. Since the style features only have an effect on the column space of \({W}\) in X, the oracle estimator \({\hat{\theta }}^*\) is identical under the true training data and the (hypothetical) training data \(x^0_i\), \(i=1,\ldots ,n\) in absence of style variation. As \(X -X^0= {W}S\), Eq. (16) can also be written as
Since \(\delta\) is in the column-space of \({W}\), there exists \(u\in \mathbb {R}^q\) such that \(\delta ={W}u\) and we can write (17) as
From (A2) we have that the eigenvalues of \({W}^t {W}\) are all positive. Also \(r_i({\hat{\theta }}^*)\) is not a function of the interventions \(s_{i}\), \(i=1,\ldots ,n\) since, as above, the estimator \({\hat{\theta }}^*\) is identical whether trained on the original data \(x_{i}\) or on the intervention-free data \(x^0_{i}\), \(i=1,\ldots , n\). If we condition on everything except for the random interventions by conditioning on \((x^0_i,y_i)\) for \(i=1,\ldots ,n\), then the rhs of (18) can be written as
where \(a\in \mathbb {R}^q\) is fixed (conditionally) and \(B=\frac{1}{n}\sum _{i=1}^nr_i({\hat{\theta }}^*) (s_{i})^t W^t W\in \mathbb {R}^q\) is a random vector and \(B\ne -a\in \mathbb {R}^q\) with probability 1 by (A1) and (A2). Hence the left hand side of (18) is not identically 0 with probability 1 for any given \(\delta\) in the column-space of W. This shows that the implication (15) is incorrect with probability 1 and hence completes the proof of the first part by contradiction.
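The role of the quantities \(r_i(\theta )\) in the gradient of the logistic training loss can be verified with a small finite-difference check. With \(r_i\) defined as in the proof, the chain rule gives \(\nabla _\theta \frac{1}{n}\sum _i \log (1+\exp (-y_i x_i^t\theta )) = -\frac{1}{n}\sum _i r_i(\theta ) x_i\); the overall sign is immaterial for the question of whether a directional derivative vanishes. The data below are arbitrary toy values.

```python
import math


def r(y, x, theta):
    """r_i(theta) = y_i / (1 + exp(y_i x_i^t theta))."""
    score = sum(a * b for a, b in zip(x, theta))
    return y / (1.0 + math.exp(y * score))


def loss(theta, xs, ys):
    """Average logistic loss over the sample."""
    total = 0.0
    for x, y in zip(xs, ys):
        score = sum(a * b for a, b in zip(x, theta))
        total += math.log1p(math.exp(-y * score))
    return total / len(xs)


def gradient(theta, xs, ys):
    """Analytic gradient: -(1/n) sum_i r_i(theta) x_i."""
    n = len(xs)
    g = [0.0] * len(theta)
    for x, y in zip(xs, ys):
        ri = r(y, x, theta)
        for k, xk in enumerate(x):
            g[k] -= ri * xk / n
    return g


xs = [[1.0, 2.0], [0.5, -1.0], [-1.5, 0.3]]
ys = [1, -1, 1]
theta = [0.2, -0.4]
g = gradient(theta, xs, ys)

# Central finite differences as an independent check of the formula.
h = 1e-6
fd = []
for k in range(len(theta)):
    up = list(theta)
    up[k] += h
    dn = list(theta)
    dn[k] -= h
    fd.append((loss(up, xs, ys) - loss(dn, xs, ys)) / (2 * h))
```

Since every \(r_i\) is nonzero, each observation contributes to the gradient, which is what allows the random style components \(s_i\) to perturb the directional derivative.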
Invariant parameter space Before continuing with the second part of the proof, some definitions. Let I be the invariant parameter space
For all \(\theta \in I\), the loss (6) for any \(F\in {F}_\xi\) is identical to the loss under \(F_0\). That is for all \(\xi \ge 0\),
The optimal predictor in the invariant space I is
If \(f_\theta\) is only a function of the core features \(C\), then \(\theta \in I\). The challenge is that the core features are not directly observable and we have to infer the invariant space I from data.
Second part For the second part, we first show that with probability at least \(p_n\), as defined in (A3), \({\hat{\theta }}^{core}={\hat{\theta }}^*\) with \({\hat{\theta }}^*\) defined as in (14). The invariant space for this model is the linear subspace \(I=\{\theta : {W}^t\theta =0\}\) and by their respective definitions,
Since we use \(I_n=I_n(\tau )\) with \(\tau =0\),
This implies that for \(\theta \in I_n\), \(f_\theta (x_{i})=f_\theta (x_{i'})\) if \(i,i'\in G_j\) for some \(j\in \{1,\ldots ,m\}.\)Footnote 14 Since \(f_\theta (x)=f_\theta (x')\) implies \((x-x')^t \theta =0\), it follows that \((x_{i} - x_{i'})^t \theta =0\) if \(i,i'\in G_j\) for some \(j\in \{1,\ldots ,m\}\) and hence
Since \(S\) has a linear influence on X in (11), \(x_{i} - x_{i'}= {W}(\varDelta _{i}-\varDelta _{i'})\) if \(i,i'\) are in the same group \(G_j\) of observations for some \(j\in \{1,\ldots ,m\}\). Note that the number of grouped examples \(n-m\) is equal to or exceeds the rank q of \({W}\) with probability \(p_n\), using (A3), and \(p_n\rightarrow 1\) for \(n\rightarrow \infty\). By (A2), it follows then with probability at least \(p_n\) that \(I_n\subseteq \{\theta : {W}^t\theta =0\}=I\). As, by definition, \(I\subseteq I_n\) is always true, we have with probability \(p_n\) that \(I=I_n\). Hence, with probability \(p_n\) (and \(p_n\rightarrow 1\) for \(n\rightarrow \infty\)), \({\hat{\theta }}^{core}={\hat{\theta }}^*\). It thus remains to be shown that
Since \({\hat{\theta }}^*\) is in I, we have \(\ell (y, x(\varDelta )) = \ell (y, x^0)\), where \(x^0\) are the previously defined data in absence of any style variance. Hence
that is the estimator is unchanged if we use the (hypothetical) data \(x^0_i\), \(i=1,\ldots ,n\) as training data. The population optimal parameter vector defined in (19) as
is for all \(\xi \ge 0\) identical to
Hence (21) and (22) can be written as
By uniform convergence of \(L_n^{(0)}\) to the population loss \(L^{(0)}\), we have \(L^{(0)}({\hat{\theta }}^*) \rightarrow _p L^{(0)}(\theta ^*)\). By definition of I and \(\theta ^*\), we have \(L_\infty ^* = L_\infty (\theta ^*)=L^{(0)}(\theta ^*)\). As \({\hat{\theta }}^*\) is in I, we also have \(L_\infty ({\hat{\theta }}^*) = L^{(0)}({\hat{\theta }}^*)\). Since, from above, \(L^{(0)}({\hat{\theta }}^*) \rightarrow _p L^{(0)}(\theta ^*)\), this also implies \(L_\infty ({\hat{\theta }}^*) \rightarrow _p L_\infty (\theta ^*)=L_\infty ^*\). Using the previously established result that \({\hat{\theta }}^{core}= {\hat{\theta }}^*\) with probability at least \(p_n\) and \(p_n \rightarrow 1\) for \(n\rightarrow \infty\), this completes the proof.
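The step in the second part that identifies the empirical invariant space \(I_n\) with \(I=\{\theta : {W}^t\theta =0\}\) can be checked on simulated data. The sketch below draws a random \(W\) and random within-group differences \(\varDelta _{i}-\varDelta _{i'}\); the dimensions, the Gaussian sampling and the helper `null_space` are assumptions made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
p, q = 6, 2
W = rng.normal(size=(p, q))             # style directions, full column rank a.s.

# At least q within-group differences x_i - x_i' = W (Delta_i - Delta_i').
delta_diffs = rng.normal(size=(3, q))
X_diffs = delta_diffs @ W.T             # shape (3, p)


def null_space(A, tol=1e-10):
    """Orthonormal basis (as columns) of {theta : A theta = 0}, via the SVD."""
    _, s, vt = np.linalg.svd(A)
    rank = int((s > tol).sum())
    return vt[rank:].T


I_n = null_space(X_diffs)               # theta with (x_i - x_i')^t theta = 0
I = null_space(W.T)                     # theta with W^t theta = 0
```

Both bases have \(p-q\) columns, and every vector in one space is annihilated by the constraints defining the other, so the two subspaces coincide once the differences span the full style space.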
Proof of Theorem 2
Let \(F_0\) be the training distribution of \((\mathrm {ID},Y,S)\) and F a distribution for \((\mathrm {ID},Y,\tilde{S})\) in \({\mathcal {F}}_\xi\). By definition of \({\mathcal {F}}_\xi\), we can write \(\tilde{S}=S+\varDelta\) for a suitable random variable \(\varDelta \in \mathbb {R}^q\) with
Vice versa: if we can write \(\tilde{S}=S+ \varDelta\) with \(\varDelta \in {\mathcal {U}}_\xi\), then the distribution is in \(F_\xi\). While X under \(F_0\) can be written as \(X(\varDelta =0)\), the distribution of X under F is of the form \(X(\varDelta )\) or, alternatively, \(X(\sqrt{\xi } U)\) with \(U\in {\mathcal {U}}_1\). Adopting from now on the latter constraint that \(U\in {\mathcal {U}}_1\), and using (B2),
where \(\nabla h_\theta\) is the gradient of \(h_\theta (\delta )\) with respect to \(\delta\), evaluated at \(\delta \equiv 0\). Hence
The proof is complete if we can show that
On the one hand,
This follows for a matrix \(\Sigma\) with Cholesky decomposition \(\Sigma =V^t V\),
On the other hand, the conditional-variance-of-loss can be expanded as
which completes the proof.
Network architectures
We implemented the considered models in TensorFlow (Abadi et al. 2015). The model architectures used are detailed in Table 5. CoRe and the pooled estimator use the same network architecture and training procedure; merely the loss function differs by the CoRe regularization term. In all experiments we use the Adam optimizer (Kingma and Ba 2015). All experimental results are based on training the respective model five times (using the same data) to assess the variance due to the randomness in the training procedure. In each epoch of the training, the training data \(x_{i}, i = 1, \ldots , n\) are randomly shuffled, keeping the grouped observations \((x_{i})_{i\in I_j}\) for \(j\in \{1,\ldots ,m\}\) together to ensure that mini batches will contain grouped observations. In all experiments the mini batch size is set to 120. For small \(c\) this implies that not all mini batches contain grouped observations, making the optimization more challenging.
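The group-preserving shuffling can be sketched as follows. The function `grouped_minibatches` and its arguments are hypothetical; the sketch shuffles whole groups rather than single observations and then cuts the resulting order into batches.

```python
import random


def grouped_minibatches(groups, batch_size, seed=None):
    """Shuffle groups (not individual samples) and cut the order into batches.

    groups: list of lists of sample indices; ungrouped samples are lists of
    length 1.
    """
    rng = random.Random(seed)
    order = list(groups)
    rng.shuffle(order)
    flat = [i for g in order for i in g]  # members of a group stay adjacent
    return [flat[k:k + batch_size] for k in range(0, len(flat), batch_size)]


groups = [[0, 1], [2], [3, 4, 5], [6]]
batches = grouped_minibatches(groups, batch_size=4, seed=1)
```

In this simple version a group can still be split at a batch boundary; how such boundary cases were handled in the experiments is not stated in the text.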
Additional experiments
1.1 Eyeglasses detection: known and unknown brightness interventions
As in Sect. 5.3 we work with the CelebA dataset and try to classify whether the person in the image is wearing eyeglasses. Here we analyze a confounded setting that could arise as follows. Say the hidden common cause \(D\) of Y and \(S\) is a binary variable and indicates whether the image was taken outdoors or indoors. If it was taken outdoors, then the person tends to wear (sun-)glasses more often and the image tends to be brighter. If the image was taken indoors, then the person tends not to wear (sun-)glasses and the image tends to be darker. In other words, the style variable \(S\) is here equivalent to brightness and the structure of the data generating process is equivalent to the one shown in Fig. 3. Figure 14 shows examples from the training set and test sets. As previously, we compute the conditional variance over images of the same person, sharing the same class label (and the CoRe estimator is hence not using the knowledge that brightness is important). Two alternatives for constructing grouped observations in this setting are discussed further below. We use \(c=2000\) and \(n=20{,}000\). For the brightness intervention, we sample the magnitude of the brightness increase or decrease from an exponential distribution with mean \(\beta = 20\). In the training set and test set 1, we sample the brightness value as \(b_{i,j} = [100 + y_i e_{i,j}]_+\) where \(e_{i,j} \sim \text {Exp}(\beta ^{-1})\) and \(y_i \in \{-1, 1\}\), where \(y_i = 1\) indicates presence of glasses and \(y_i=-1\) indicates absence. For test set 2, we use instead \(b_{i,j} = [100 - y_i e_{i,j}]_+\), so that the relation between brightness and glasses is flipped.
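The sampling rule for the brightness value can be illustrated with a short sketch. The function name `sample_brightness` and the `flip` flag are hypothetical; the only ingredients taken from the text are the formula \(b_{i,j} = [100 \pm y_i e_{i,j}]_+\) and the exponential distribution with mean \(\beta\). How the sampled value is then applied to the image is not covered here.

```python
import random


def sample_brightness(y, beta=20.0, flip=False, rng=random):
    """Sample b = [100 + y*e]_+ (training set, test set 1) or, with flip=True,
    b = [100 - y*e]_+ (test set 2), where e is exponential with mean beta."""
    if y not in (-1, 1):
        raise ValueError("y must be 1 (glasses) or -1 (no glasses)")
    # expovariate takes the rate, so rate 1/beta gives mean beta.
    e = rng.expovariate(1.0 / beta)
    sign = -1.0 if flip else 1.0
    # [.]_+ denotes the positive part.
    return max(0.0, 100.0 + sign * y * e)
```

For example, `sample_brightness(1)` always returns a value of at least 100 (brighter image for glasses), while `sample_brightness(1, flip=True)` returns a value of at most 100, which is the reversed association used in test set 2.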
Figure 14 shows misclassification rates for CoRe and the pooled estimator on different test sets. Examples from all test sets can be found in Fig. 15. First, we notice that the pooled estimator performs better than CoRe on test set 1. This can be explained by the fact that it can exploit the predictive information contained in the brightness of an image while CoRe is restricted not to do so. Second, we observe that the pooled estimator does not perform well on test set 2 as its learned representation seems to use the image’s brightness as a predictor for the response which fails when the brightness distribution in the test set differs significantly from the training set. In contrast, the predictive performance of CoRe is hardly affected by the changing brightness distributions.
We now discuss two alternatives for constructing different test sets and we vary the number of grouped observations in \(c\in \{200,2000,5000\}\) as well as the strength of the brightness interventions in \(\beta \in \{ 5, 10, 20\}\), all with sample size \(n=20{,}000\). Generation of training and test sets 1 and 2 were already described above. Here, we consider additionally test set 3 where all images are left unchanged (no brightness interventions at all) and in test set 4 the brightness of all images is increased. Furthermore, we consider three different ways of grouping images. Above, we used images of the same person to create a grouped observation by sampling a different value for the brightness intervention. We refer to this as ‘Grouping setting 2’ here. An alternative is to use the same image of the same person in different brightnesses (drawn from the same distribution) as a group over which the conditional variance is calculated. We call this ‘Grouping setting 1’ and it can be useful if we know that we want to protect against brightness interventions in the future. For comparison, we also evaluate grouping with an image of a different person (but sharing the same class label) as a baseline (‘Grouping setting 3’). Examples from the training sets using grouping settings 1, 2 and 3 can be found in Fig. 15.
Results for all grouping settings, \(\beta \in \{ 5, 10, 20\}\) and \(c\in \{200, 5000\}\) can be found in Fig. 16. We see that grouping setting 1 works best, since there we explicitly control that only \(S\equiv \textit{brightness}\) varies between grouped examples. In grouping setting 2, different images of the same person can vary in many factors, making it more challenging to isolate brightness as the factor to be invariant against. Lastly, we see that if we group images of different persons (‘Grouping setting 3’), the difference between the CoRe estimator and the pooled estimator becomes much smaller than in the previous settings. Figure 17 shows some examples of misclassified observations for Grouping setting 1.
1.2 Eyeglasses detection with small sample size
Figure 18 shows the numerator and the denominator of the variance ratio defined in Eq. (13) separately as a function of the CoRe penalty weight. In conjunction with Fig. 6b, we observe that a ridge penalty decreases both the within- and between-group variance while the CoRe penalty penalizes the within-group variance selectively.
1.3 Gender classification
1.3.1 Additional baselines: Unconditional variance regularization and grouping by class label
As additional baselines, we consider the following two schemes: (i) we group all examples sharing the same class label and penalize with the conditional variance of the predicted logits, computed over these two groups; (ii) we penalize the overall variance of the predicted logits, i.e., a form of unconditional variance regularization. Figure 19 shows the performance of these two approaches. In contrast to the CoRe penalty, regularizing with the variance of the predicted logits conditional on Y only does not yield performance improvements on test set 2, compared to the pooled estimator (corresponding to a penalty weight of 0). Interestingly, using baseline (i) without a ridge penalty does yield an improvement on test set 1, compared to the pooled estimator with various strengths of the ridge penalty.
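To make the difference between the two baselines and the CoRe grouping concrete, the following sketch computes the penalties on a list of scalar logits. It is an illustration under stated assumptions, not the implementation used for Fig. 19: the helper names are hypothetical, logits are taken to be scalar, the population variance is used, and all groups are weighted equally.

```python
from collections import defaultdict
from statistics import pvariance


def conditional_variance(logits, keys):
    """Average within-group variance of the logits, with groups defined by keys."""
    groups = defaultdict(list)
    for z, k in zip(logits, keys):
        groups[k].append(z)
    # A group with a single member contributes a variance of zero.
    return sum(pvariance(v) for v in groups.values()) / len(groups)


def unconditional_variance(logits):
    """Baseline (ii): overall variance of the predicted logits."""
    return pvariance(logits)


logits = [1.0, 3.0, 2.0, 2.0]
labels = [0, 0, 1, 1]
ids = ["a", "b", "c", "c"]

baseline_i = conditional_variance(logits, labels)             # group by Y only
baseline_ii = unconditional_variance(logits)                  # no grouping
core_style = conditional_variance(logits, list(zip(labels, ids)))  # group by (Y, ID)
```

In this toy input the two examples with label 0 belong to different identifiers, so grouping by \((Y,\mathrm {ID})\) does not penalize the difference between their logits, whereas grouping by Y alone does.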
1.3.2 Additional results
Table 6 additionally reports the standard errors for the results discussed in Sect. 5.2.
1.4 Eyeglasses detection: image quality intervention
Here, we show further results for the experiments introduced in Sect. 5.3. Specifically, we consider interventions of different strengths by varying the mean of the quality intervention in \(\mu \in \{30, 40, 50\}\). Recall that we use ImageMagick to modify the image quality. In the training set and in test set 1, we sample the image quality value as \(q_{i,j} \sim \mathcal {N}(\mu , \sigma = 10)\) and apply the command convert -quality q_ij input.jpg output.jpg if \(y_i \equiv \textit{glasses}\). If \(y_i \equiv \textit{no glasses}\), the image is not modified. In test set 2, the above command is applied if \(y_i \equiv \textit{no glasses}\) while images with \(y_i \equiv \textit{glasses}\) are not changed. In test set 3 all images are left unchanged and in test set 4 the command is applied to all images, i.e. the quality of all images is reduced.
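The class-dependent quality intervention can be sketched as a small helper that builds the ImageMagick call for one image. The function name `quality_command` and its parameters are hypothetical, and rounding the sampled value and clamping it to the range 1 to 100 is an assumption made here so that the argument is a valid JPEG quality setting; the text above does not state how out-of-range samples were handled.

```python
import random


def quality_command(has_glasses, mu=30.0, sigma=10.0,
                    src="input.jpg", dst="output.jpg",
                    intervene_on_glasses=True, rng=random):
    """Return the ImageMagick argument list for one image, or None if the
    image is left unchanged.

    intervene_on_glasses=True mimics the training set and test set 1;
    intervene_on_glasses=False mimics test set 2.
    """
    if has_glasses != intervene_on_glasses:
        return None
    q = int(round(rng.gauss(mu, sigma)))  # q_ij ~ N(mu, sigma)
    q = min(100, max(1, q))               # assumption: clamp to a valid quality
    return ["convert", "-quality", str(q), src, dst]
```

The returned list could be passed to `subprocess.run(cmd, check=True)` on a machine where ImageMagick is installed. Test sets 3 and 4 correspond to never calling the helper and to applying the command to every image, respectively.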
We run experiments for grouping settings 1–3 and for \(c=5000\), where the definition of the grouping settings 1–3 is identical to “Appendix D.1”. Figure 20 shows examples from the respective training and test sets and Fig. 21 shows the corresponding misclassification rates. Again, we observe that grouping setting 1 works best, followed by grouping setting 2. Interestingly, there is a large performance difference between \(\mu =40\) and \(\mu =50\) for the pooled estimator. Possibly, with \(\mu =50\) the image quality is not sufficiently predictive for the target.
1.5 MNIST: more sample efficient data augmentation
Here, we show further results for the experiment introduced in Sect. 5.5. We vary the number of augmented training examples \(c\) from 100 to 5000 for \(m=10{,}000\) and \(c\in \{100, 200, 500, 1000\}\) for \(m=1000\). The degree of the rotations is sampled uniformly at random from [35, 70]. Figure 22 shows the misclassification rates. Test set 1 contains rotated digits only, test set 2 is the usual MNIST test set. We see that the misclassification rates of CoRe are always lower on test set 1, showing that it makes data augmentation more efficient. For \(m=1000\), it even turns out to be beneficial for performance on test set 2.
1.6 Stickmen image-based age classification
Here, we show further results for the experiment introduced in Sect. 5.4. Recall that test set 1 follows the same distribution as the training set. In test sets 2 and 3 large movements are associated with both children and adults, while the movements are heavier in test set 3 than in test set 2. Figure D.10b shows results for different numbers of grouping examples. For \(c=20\) the misclassification rate of the CoRe estimator has a large variance. For \(c\in \{ 50, 500, 2000 \}\), the CoRe estimator shows similar results. Its performance is thus not sensitive to the number of grouped examples, once there are sufficiently many grouped observations in the training set. The pooled estimator fails to achieve good predictive performance on test sets 2 and 3 as it seems to use “movement” as a predictor for “age” (Fig. 23).
1.7 Elmer the Elephant
The color interventions for the experiment introduced in Sect. 5.6 were created as follows. In the training set, if \(y_i \equiv \textit{elephant}\) we apply the following ImageMagick command for the grouped examples: convert -modulate 100,0,100 input.jpg output.jpg. Test sets 1 and 2 were already discussed in Sect. 5.6: in test set 1, all images are left unchanged. In test set 2, the above command is applied if \(y_i \equiv \text {horse}\). If \(y_i \equiv \text {elephant}\), we sample \(c_{i,j} \sim \mathcal {N}(\mu = 20, \sigma = 1)\) and apply convert -modulate 100,100,100-c_ij input.jpg output.jpg to the image. Here, we consider again some more test sets than in Sect. 5.6. In test set 4, the latter command is applied to all images. It rotates the colors of the image in a cyclic manner. In test set 3, all images are changed to grayscale. Examples from all four test sets are shown in Fig. 24 and classification results are shown in Fig. 25.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Heinze-Deml, C., Meinshausen, N. Conditional variance penalties and domain shift robustness. Mach Learn 110, 303–348 (2021). https://doi.org/10.1007/s10994-020-05924-1