Abstract
In this work we explore a new framework for approximate Bayesian inference in large datasets based on stochastic control. We advocate stochastic control as a finite-time and low-variance alternative to popular steady-state methods such as stochastic gradient Langevin dynamics. Furthermore, we discuss and adapt the existing theoretical guarantees of this framework and establish connections to existing VI routines in SDE-based models.
1 Introduction
Steering a stochastic flow from one distribution to another across the space of probability measures is a well-studied problem initially proposed in Schrödinger (1932). There has been recent interest in the machine learning community in these methods for generative modelling, sampling, dataset imputation and optimal transport (Wang et al. 2021; De Bortoli et al. 2021; Huang et al. 2021; Bernton et al. 2019; Vargas et al. 2021; Chen et al. 2022; Cuturi 2013; Maoutsa and Opper 2021; Reich 2019).
We consider a particular instance of the Schrödinger bridge problem (SBP), known as the Schrödinger–Föllmer process (SFP). In machine learning, this process has been proposed for sampling and generative modelling (Huang et al. 2021; Tzen and Raginsky 2019b) and in molecular dynamics for rare event simulation and importance sampling (Hartmann and Schütte 2012; Hartmann et al. 2017); here we apply it to Bayesian inference. We show that a control-based formulation of the SFP has deep-rooted connections to variational inference and is particularly well suited to Bayesian inference in high dimensions. This capability arises from the SFP’s characterisation as an optimisation problem and its parametrisation through neural networks (Tzen and Raginsky 2019b). Finally, due to the variational characterisation that these methods possess, many low-variance estimators (Richter et al. 2020; Nüsken and Richter 2021; Roeder et al. 2017; Xu et al. 2021) are applicable to the SFP formulation we consider.
We reformulate the Bayesian inference problem by constructing a stochastic process \({{\varvec{\Theta }}}_t\) which at a fixed time \(t=1\) will generate samples from a prespecified posterior \(p({{\varvec{\theta }}}\vert {{\varvec{X}}})\), i.e. \({\textrm{Law}}{{\varvec{\Theta }}}_1 =p({{\varvec{\theta }}}\vert {{\varvec{X}}}) \), with dataset \({{\varvec{X}}}=\{{\textbf{x}}_i\}_{i=1}^N\), and where the model is given by:
Here the prior \(p({{\varvec{\theta }}})\) and the likelihood \(p({\textbf{x}}_i \vert {{\varvec{\theta }}})\) are user-specified. Our target is \(\pi _1({{\varvec{\theta }}}) = \frac{p({{\varvec{X}}}\vert {{\varvec{\theta }}})p({{\varvec{\theta }}})}{{\mathcal {Z}}}\), where \({\mathcal {Z}}= \int \prod _i p({{\varvec{x}}}_i \vert {{\varvec{\theta }}})p({{\varvec{\theta }}}) d{{\varvec{\theta }}}\). This formulation is reminiscent of the setup proposed in previous works (Grenander and Miller 1994; Roberts and Tweedie 1996; Girolami and Calderhead 2011; Welling and Teh 2011) and covers many Bayesian machine-learning models, but our formulation has an important difference. Stochastic gradient Langevin dynamics (SGLD) relies on a diffusion that reaches the posterior as its equilibrium state when time approaches infinity. In contrast, our dynamics are controlled and the posterior is reached in finite (bounded) time. The benefit of this property is elegantly illustrated in Sect. 3.2 of Huang et al. (2021), where they rigorously demonstrate that even under an Euler approximation the proposed approach reaches a Gaussian target at time \(t\!=\!1\) whilst SGLD does not.
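For reference, the model described above can be written out explicitly. The display below is a sketch that assumes the observations are conditionally independent given \({{\varvec{\theta }}}\), as implied by the product inside \({\mathcal {Z}}\):

\[
{{\varvec{\theta }}} \sim p({{\varvec{\theta }}}), \qquad {\textbf{x}}_i \mid {{\varvec{\theta }}} \sim p({\textbf{x}}_i \vert {{\varvec{\theta }}}), \quad i = 1, \dots , N,
\]

so that the posterior targeted at \(t=1\) is

\[
p({{\varvec{\theta }}}\vert {{\varvec{X}}}) = \frac{1}{{\mathcal {Z}}}\, p({{\varvec{\theta }}}) \prod _{i=1}^N p({\textbf{x}}_i \vert {{\varvec{\theta }}}).
\]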
Contributions The main contributions of this work can be detailed as follows:

In this work we scale and apply the theoretical framework proposed in Dai Pra (1991) and Tzen and Raginsky (2019a) to sample from posteriors in large-scale Bayesian machine learning tasks such as Bayesian deep learning. We study the robustness of the predictions under this framework and evaluate their uncertainty quantification.

More precisely we propose an amortised parametrisation that allows scaling models with local and global variables to large datasets.

We explore and provide further theoretical backing (Sect. 2.2) to the “sticking the landing” estimator provided by Xu et al. (2021).

Overall we empirically demonstrate that the stochastic control framework offers a promising direction in Bayesian machine learning, striking the balance between theoretical/asymptotic guarantees found in MCMC methods (Hastings 1970; Duane et al. 1987; Neal 2011; Brooks et al. 2011) and more practical approaches such as variational inference (Blei et al. 2003).
1.1 Notation
Throughout the paper we consider path measures (denoted as \({\mathbb {Q}}\) or \({\mathbb {S}}\)) on the space of continuous functions \(\Omega = C([0, 1], {\mathbb {R}}^d)\). Random processes associated with such path measures \({\mathbb {Q}}\) are denoted as \({{\varvec{\Theta }}}\) and their time-marginal distributions as \({\mathbb {Q}}_t = ({{\varvec{\Theta }}}_t)_\#{\mathbb {Q}}\) (which are just pushforward measures). Given two marginal distributions \(\pi _0\) and \(\pi _1\) we write \({\mathcal {D}}(\pi _0, \pi _1) = \{{\mathbb {Q}}: {\mathbb {Q}}_0 = \pi _0, {\mathbb {Q}}_1 = \pi _1 \}\) for the set of all path measures with given marginal distributions at the initial and final times. We denote by \({\mathbb {Q}}^{{\textbf{u}}, \pi }\) the path measure of the following Stochastic Differential Equation (SDE):
\[
d{{\varvec{\Theta }}}_t = {\textbf{u}}(t, {{\varvec{\Theta }}}_t)\,dt + \sqrt{\gamma }\, d{\textbf{B}}_t, \qquad {{\varvec{\Theta }}}_0 \sim \pi ,
\]
(we drop the dependence on \(\gamma \) since it is fixed) and we write \(\mathbb {W}^\gamma = {\mathbb {Q}}^{0, \delta _0}\) for the Wiener measure. We will write \(\frac{d{\mathbb {Q}}}{d{\mathbb {S}}}\) for the Radon–Nikodym derivative (RND) of \({\mathbb {Q}}\) w.r.t. \({\mathbb {S}}\).
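To make the notation concrete, the following is a minimal sketch of how one might draw samples of \({{\varvec{\Theta }}}_1\) under \({\mathbb {Q}}^{{\textbf{u}}, \delta _0}\) with an Euler–Maruyama scheme. It assumes the controlled SDE has the form \(d{{\varvec{\Theta }}}_t = {\textbf{u}}(t, {{\varvec{\Theta }}}_t)dt + \sqrt{\gamma }\,d{\textbf{B}}_t\), matching the dynamics written in Sect. 2.1; the function name `simulate_controlled_sde` and its arguments are illustrative and not taken from the paper.

```python
import numpy as np

def simulate_controlled_sde(drift, dim, gamma=1.0, n_steps=100, n_paths=1000, seed=0):
    """Euler-Maruyama simulation of dTheta_t = drift(t, Theta_t) dt + sqrt(gamma) dB_t
    on [0, 1], started at Theta_0 = 0 (i.e. pi = delta_0). Returns samples of Theta_1."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    theta = np.zeros((n_paths, dim))
    for k in range(n_steps):
        t = k * dt
        noise = rng.standard_normal((n_paths, dim))
        # One Euler-Maruyama step: deterministic drift plus scaled Gaussian increment.
        theta = theta + drift(t, theta) * dt + np.sqrt(gamma * dt) * noise
    return theta

# With zero drift the path measure is the Wiener measure W^gamma,
# so Theta_1 should be approximately N(0, gamma I).
samples = simulate_controlled_sde(lambda t, x: np.zeros_like(x), dim=2, gamma=0.5)
```

With the zero drift used here the empirical variance of `samples` should be close to \(\gamma \); any other Markov drift can be passed in as a callable of `(t, theta)`.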
1.2 Schrödinger–Föllmer processes
Definition 1
(Schrödinger-Bridge Process) Given a reference process \({\mathbb {S}}\) and two measures \(\pi _0\) and \(\pi _1\) the Schrödinger bridge distribution is given by
where \({\mathbb {S}}\) acts as a “prior”.
It is known (Léonard 2013) that if \({\mathbb {S}}= {\mathbb {Q}}^{{\textbf{u}}, \pi }\), \({\mathbb {Q}}^*\) is induced by an SDE with a modified drift:
i.e. \({\mathbb {Q}}^* = {\mathbb {Q}}^{{\textbf{u}}^*, \pi _0}\). The solution of this SDE is called the Schrödinger-Bridge Process (SBP).
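In symbols, and following the standard formulation in the cited literature (Léonard 2013), the Schrödinger bridge of Definition 1 is the relative-entropy projection of the reference process onto the set of path measures with the prescribed marginals. The notation \(D_{\textrm{KL}}\) is introduced here for illustration:

\[
{\mathbb {Q}}^* = \mathop {\textrm{arg}\,\textrm{min}}\limits _{{\mathbb {Q}}\in {\mathcal {D}}(\pi _0, \pi _1)} D_{\textrm{KL}}\left( {\mathbb {Q}}\,\Vert \,{\mathbb {S}}\right) , \qquad D_{\textrm{KL}}\left( {\mathbb {Q}}\,\Vert \,{\mathbb {S}}\right) = \mathbb {E}_{{\mathbb {Q}}}\left[ \ln \frac{d{\mathbb {Q}}}{d{\mathbb {S}}}\right] .
\]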
Definition 2
(Schrödinger–Föllmer Process) The SFP is an SBP where \(\pi _0=\delta _0\) and the reference process \({\mathbb {S}}= \mathbb {W}^\gamma \) is the Wiener measure.
The SFP differs from the general SBP in that, rather than constraining the initial distribution to \(\delta _0\), the SBP considers any initial distribution \(\pi _0\). The SBP also involves general Itô SDEs associated with \({\mathbb {Q}}^{{{\varvec{u}}}, \pi }\) as the dynamical prior, compared to the SFP which restricts attention to Wiener processes as priors.
The advantage of considering this more limited version of the SBP is that it admits a closed-form characterisation of the solution to the Schrödinger system (Léonard 2013; Wang et al. 2021; Pavon et al. 2018) which allows for an unconstrained formulation of the problem. For accessible introductions to the SBP we suggest (Pavon et al. 2018; Vargas et al. 2021). Now we will consider instances of the SBP and the SFP where \(\pi _1=p({{\varvec{\theta }}}\vert {{\varvec{X}}})\).
1.2.1 Analytic solutions and the heat semigroup
Prior work (Pavon 1989; Dai Pra 1991; Tzen and Raginsky 2019b; Huang et al. 2021) has explored the properties of SFPs via a closed-form formulation of the Föllmer drift expressed in terms of expectations over Gaussian random variables known as the heat semigroup. The seminal works (Pavon 1989; Dai Pra 1991; Tzen and Raginsky 2019b) highlight how this formulation of the Föllmer drift characterises an exact sampling scheme for a target distribution and how it could potentially be used in practice. The recent work by Huang et al. (2021) builds on Tzen and Raginsky (2019b) and explores estimating the optimal drift in practice via the heat semigroup formulation using a Monte Carlo approximation. Our work aims to take the next step and scale the estimation of the Föllmer drift to high-dimensional cases (Graves 2011; Hoffman et al. 2013). In order to do this we must move away from the heat semigroup and instead consider the dual formulation of the Föllmer drift in terms of a stochastic control problem (Tzen and Raginsky 2019b).
In the setting when \(\pi _0=\delta _0\) we can express the optimal SBP drift as follows
Definition 3
The Euclidean heat semigroup \(Q_t^\gamma , \; t\ge 0\), acts on bounded measurable functions \(f: {\mathbb {R}}^d \rightarrow {\mathbb {R}}\) as \(Q^\gamma _t f({{\varvec{x}}}) = \int _{{\mathbb {R}}^d} f\left( {{\varvec{x}}}+\sqrt{t}{{\varvec{z}}}\right) {\mathcal {N}}({{\varvec{z}}}\vert {{\varvec{0}}}, \gamma {\mathbb {I}}) d{{\varvec{z}}}=\mathbb {E}_{{{\varvec{z}}}\sim {\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}\left[ f\left( {{\varvec{x}}}+\sqrt{t}{{\varvec{z}}}\right) \right] .\)
In the SFP case where \({\mathbb {S}}=\mathbb {W}^\gamma \), the optimal drift from Eq. 5 can be written in terms of the heat semigroup, \({\textbf{u}}^*(t, {{\varvec{x}}}) = \nabla _{{\varvec{x}}}\ln Q^\gamma _{1-t} \left[ \frac{d \pi _1}{d{\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}({{\varvec{x}}})\right] \). Note that an SDE with the heat semigroup induced drift
satisfies \( {\textrm{Law}}{{\varvec{\Theta }}}_1 = \pi _1\), that is, at \(t=1\) these processes are distributed according to our target distribution of interest \(\pi _1\).
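The expectation form of the heat semigroup in Definition 3 suggests a direct Monte Carlo approximation. The snippet below is a minimal sketch of that idea; the helper name `heat_semigroup_mc` and the quadratic test function are illustrative choices, not part of the original method.

```python
import numpy as np

def heat_semigroup_mc(f, x, t, gamma=1.0, n_samples=10000, seed=0):
    """Monte Carlo estimate of Q_t^gamma f(x) = E[f(x + sqrt(t) z)], z ~ N(0, gamma I)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    z = np.sqrt(gamma) * rng.standard_normal((n_samples, x.shape[0]))
    return np.mean(f(x + np.sqrt(t) * z))

# Sanity check with f(x) = ||x||^2, for which Q_t f(x) = ||x||^2 + t * gamma * d.
f = lambda pts: np.sum(pts ** 2, axis=-1)
x0 = np.array([1.0, -2.0])
estimate = heat_semigroup_mc(f, x0, t=0.5, gamma=2.0)
exact = np.sum(x0 ** 2) + 0.5 * 2.0 * x0.shape[0]
```

For the quadratic test function the estimate can be compared against the closed-form value `exact`, which gives a quick check that the scaling by \(\sqrt{t}\) and \(\gamma \) is implemented as in Definition 3.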
1.2.2 Schrödinger–Föllmer samplers
Huang et al. (2021) carried out preliminary work on empirically exploring the success of using the heat semigroup formulation of SFPs in combination with the Euler–Maruyama (EM) discretisation to sample from target distributions in a method they call Schrödinger–Föllmer samplers (SFS). More precisely the SFS approach proposes estimating the Föllmer drift via:
where \({{\varvec{z}}}_s \sim {\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})\) and \(f=\frac{d\pi _1}{d{\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}\). Whilst this estimator enjoys sound theoretical properties (Huang et al. 2021) it falls short in practice for the following reasons:

The term f involves the product of PDFs evaluated at samples rather than a log product and is thus often very unstable numerically. In Appendix 1 we provide a more stable implementation of Eq. 7 exploiting the log-sum-exp trick and properties of the Lebesgue integral.

In its current form the estimator does not admit low-variance estimators (e.g. variational inference); being a Monte Carlo estimator, it is prone to high variance.

Both empirically and theoretically we found the above approach to be considerably slower than the other methods we compare to. At test time SFS has a computational complexity of \({\mathcal {O}}(T S\#_f(d) )\) where \(T=\Delta t^{-1}\), S is the number of Monte Carlo samples and \(\#_f(d)\) is the cost of evaluating the RND f, which at best is linear in d. Meanwhile our proposed approach enjoys a cost of \({\mathcal {O}}(T \#_{{{\varvec{u}}}_\phi }(d) )\) where \(\#_{{{\varvec{u}}}_\phi }(d)\) is the cost of a forward pass through a neural network approximating the Föllmer drift.
In practice we found this implementation to be too numerically unstable to produce reasonable results, even in low-dimensional examples. To carry out a fair comparison we reformulated Eq. 7 stably; the stable formulation and its derivation can be found in Appendix 1.
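One way to see how a log-space reformulation helps is the following sketch. It is not necessarily the formulation derived in the appendix; it only assumes the Monte Carlo form of the heat semigroup drift, \(\nabla _{{\varvec{x}}}\ln Q^\gamma _{1-t} f({{\varvec{x}}}) \approx \sum _s \nabla f({{\varvec{x}}}+\sqrt{1-t}{{\varvec{z}}}_s) / \sum _s f({{\varvec{x}}}+\sqrt{1-t}{{\varvec{z}}}_s)\), and uses \(\nabla f = f \nabla \ln f\) so that only \(\ln f\) and its gradient are ever evaluated. The names `sfs_drift_stable`, `log_f` and `grad_log_f` are hypothetical.

```python
import numpy as np

def sfs_drift_stable(log_f, grad_log_f, x, t, gamma=1.0, n_samples=256, seed=0):
    """Monte Carlo estimate of grad_x ln Q_{1-t} f(x), computed in log space.

    Since grad f = f * grad ln f, the ratio sum(grad f) / sum(f) is a
    softmax-weighted average of grad ln f over the perturbed points.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    z = np.sqrt(gamma) * rng.standard_normal((n_samples, x.shape[0]))
    pts = x + np.sqrt(1.0 - t) * z
    log_w = log_f(pts)                      # shape (n_samples,)
    log_w = log_w - np.max(log_w)           # log-sum-exp shift avoids overflow/underflow
    w = np.exp(log_w)
    w = w / np.sum(w)                       # normalised softmax weights
    return np.sum(w[:, None] * grad_log_f(pts), axis=0)

# Toy check (gamma = 1): target N(mu, I) against reference N(0, I),
# where ln f(x) = mu.x - |mu|^2 / 2 and the Föllmer drift is the constant mu.
mu = np.array([1.0, -0.5])
log_f = lambda p: p @ mu - 0.5 * mu @ mu
grad_log_f = lambda p: np.broadcast_to(mu, p.shape)
drift = sfs_drift_stable(log_f, grad_log_f, x=np.zeros(2), t=0.3)
```

In this Gaussian toy case the weights multiply a constant gradient, so the estimate equals `mu` regardless of the number of samples; for realistic posteriors the estimator remains a Monte Carlo average and inherits the variance issues listed above.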
In this work we build on Huang et al. (2021) by considering a formulation of the Schrödinger–Föllmer process that is suitable for the high-dimensional settings arising in Bayesian ML. Our work will focus on a dual formulation of the optimal drift that is closer to variational inference and thus admits the scalable and flexible parametrisations used in ML.
2 Stochastic control formulation
In this section, we introduce a particular formulation of the Schrödinger–Föllmer process in the context of the Bayesian inference problem in Eq. 1. In its most general setting of sampling from a target distribution, this formulation was known to Dai Pra (1991). Tzen and Raginsky (2019b) study the theoretical properties of this approach in the context of generative models (Kingma et al. 2021; Goodfellow et al. 2014), and Opper (2019) applies this formulation to time series modelling. In contrast, our focus is on the estimation of a Bayesian posterior for a broader class of models than Tzen and Raginsky explore.
Corollary 1
Define
Then the minimiser (with \({\mathcal {U}}\) being the set of admissible controls^{Footnote 1})
satisfies \({\mathbb {Q}}_1^{\gamma , {\textbf{u}}^*, \delta _0} = \frac{p({{\varvec{X}}}\vert {{\varvec{\theta }}})p({{\varvec{\theta }}})}{{\mathcal {Z}}}d{{\varvec{\theta }}}\).
Moreover, \({\textbf{u}}^{*}\) solves the SFP with \(\pi _1=p({{\varvec{\theta }}}\vert {{\varvec{X}}})\).
The objective in Eq. 8 can be estimated using an SDE discretisation, such as the EM method. Since the drift \({\textbf{u}}^{*}\) is Markov, it can be parametrised by a flexible function estimator such as a neural network, as in Tzen and Raginsky (2019b). In addition, unbiased estimators for the gradient of the objective in Eq. 8 can be formed by subsampling the data. In this work we will refer to the above formulation of the SFP as the Neural Schrödinger–Föllmer sampler (NSFS) when we parametrise the drift with a neural network and implement unbiased minibatched estimators for this objective (Appendix 1). This formulation of SFPs has been previously studied in the context of generative modelling and marginal likelihood estimation (Tzen and Raginsky 2019b), while we focus on Bayesian inference.
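As an illustration of estimating the objective with an EM discretisation, the sketch below assumes that Eq. 8 takes the usual Föllmer control form \(\mathbb {E}\big [\frac{1}{2\gamma }\int _0^1 \Vert {\textbf{u}}(t, {{\varvec{\Theta }}}_t)\Vert ^2 dt - \ln \frac{p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1)p({{\varvec{\Theta }}}_1)}{{\mathcal {N}}({{\varvec{\Theta }}}_1\vert {{\varvec{0}}}, \gamma {\mathbb {I}})}\big ]\), consistent with Observation 1. The drift is passed as a plain callable rather than a neural network, and `nsfs_objective` and `log_ratio` are hypothetical names.

```python
import numpy as np

def nsfs_objective(drift, log_ratio, dim, gamma=1.0, n_steps=50, n_paths=20000, seed=0):
    """Euler-Maruyama estimate of
        E[ 1/(2 gamma) * int_0^1 ||u(t, Theta_t)||^2 dt - log_ratio(Theta_1) ],
    with log_ratio(theta) = ln p(X|theta) + ln p(theta) - ln N(theta | 0, gamma I)
    (assumed form of the control objective)."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    theta = np.zeros((n_paths, dim))
    control_cost = np.zeros(n_paths)
    for k in range(n_steps):
        u = drift(k * dt, theta)
        # Accumulate the quadratic control cost along each simulated path.
        control_cost += np.sum(u ** 2, axis=1) * dt / (2.0 * gamma)
        theta = theta + u * dt + np.sqrt(gamma * dt) * rng.standard_normal((n_paths, dim))
    return np.mean(control_cost - log_ratio(theta))

# Toy check (gamma = 1): normalised target N(mu, I), so log_ratio = mu.theta - |mu|^2 / 2.
mu = np.array([1.0, -0.5])
log_ratio = lambda th: th @ mu - 0.5 * mu @ mu
j_opt = nsfs_objective(lambda t, th: np.broadcast_to(mu, th.shape), log_ratio, dim=2)
j_zero = nsfs_objective(lambda t, th: np.zeros_like(th), log_ratio, dim=2)
```

In this toy case the constant drift `mu` is optimal and the objective is close to \(-\ln {\mathcal {Z}} = 0\), while the zero drift gives roughly \(\Vert \mu \Vert ^2/2\). In a real model `log_ratio` would contain the data log-likelihood, which is the term one would subsample to obtain minibatch estimates.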
We note that recent concurrent work (Zhang et al. 2022)^{Footnote 2} proposes an algorithm akin to ours based on Dai Pra (1991) and Tzen and Raginsky (2019b). However, their focus is on estimating the normalising constant of unnormalised densities, while ours is on Bayesian ML tasks such as Bayesian regression, classification and LVMs; thus our work leads to different insights and algorithmic motivations.
2.1 Theoretical guarantees for neural SFS
While the focus in Tzen and Raginsky (2019b) is on providing guarantees for generative models of the form \( {{\varvec{x}}}\sim q_\phi ({{\varvec{x}}}\vert {{\varvec{Z}}}_1)\;,d{{\varvec{Z}}}_t = {{\varvec{u}}}_{\phi }({{\varvec{Z}}}_t,t)dt + \sqrt{\gamma } d{\textbf{B}}_t, \;{{\varvec{Z}}}_0 =\! 0,\) their results extend to our setting as they explore approximating the Föllmer drift for a generic target \(\pi _1\).
Theorem 4 in Tzen and Raginsky (restated as Theorem 2 in Appendix 1) motivates using neural networks to parametrise the drift in Eq. 8 as it provides a guarantee regarding the expressivity of a network parametrised drift via an upper bound on the target distribution error in terms of the size of the network.
We will now proceed to highlight how this error is affected by the EM discretisation:
Corollary 2
Given the network \({{\varvec{v}}}\) from Theorem 2 it follows that the Euler–Maruyama discretisation of (2) with \({\textbf{u}}={{\varvec{v}}}\) induces an approximate target \(\hat{\pi }^{{{\varvec{v}}}}_1\) that satisfies
This result provides a bound on the error in terms of the depth \(\Delta t^{-1}\) of the stochastic flow (Chen et al. 2022; Zhang et al. 2021) and the size of the network that we parametrise the drift with. Under the view that NN parametrised SDEs can be interpreted as ResNets (Li et al. 2020) we find that this result illustrates that increasing the ResNets’ depth will lead to more accurate results.
2.2 Sticking the landing and low variance estimators
As with VI (Richter et al. 2020; Roeder et al. 2017), the gradient of the objective in this study admits several low variance estimators (Nüsken and Richter 2021; Xu et al. 2021). In this section we formally recap what it means for an estimator to “stick the landing” and we prove that the estimator proposed in Xu et al. satisfies said property.
The full objective being minimised in our approach is (where expectations are taken over \({{\varvec{\Theta }}}\sim {\mathbb {Q}}^{{\textbf{u}}, \delta _0}\)):
noticing that in previous formulations we have omitted the Itô integral as it has zero expectation (but the integral appears naturally through Girsanov’s theorem). We call the estimator calculated by taking gradients of the above objective the relative-entropy estimator. The estimator proposed in Xu et al. (2021) (the sticking-the-landing estimator) is given by:
where \(\perp \) means that the gradient is stopped/detached as in Xu et al. (2021); Roeder et al. (2017).
We study perturbations of \({\mathcal {F}}\) around \({\textbf{u}}^*\) by considering \({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}\), with \({{\varvec{\phi }}}\) arbitrary, and \(\varepsilon \) small. More precisely, we set out to compute (where dependence on \({{\varvec{\theta }}}\) is dropped):
through which we define “sticking the landing”:
Definition 4
We say that an estimator “sticks the landing” when
almost surely, for all smooth and bounded perturbations \({{\varvec{\phi }}}\).
Notice that by construction, \({{\varvec{u}}}^*\) is a global minimiser of J, and hence all directional derivatives vanish,
Definition 4 additionally demands that this quantity is zero almost surely, and not just on average. Consequently, “sticking the landing” estimators will have zero variance at \({\textbf{u}}^*\).
Remark 1
The relative-entropy stochastic control estimator does not stick the landing.
Proof
See Nüsken and Richter (2021), Theorem 5.3.1, clause 3, Eq. 133 clearly indicates \(\frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0}\ne 0\). \(\square \)
We can now go ahead and prove that the estimator proposed by Xu et al. (2021) does indeed stick the landing.
Theorem 1
The STL estimator proposed in Xu et al. (2021) satisfies
almost surely, for all smooth and bounded perturbations \({{\varvec{\phi }}}\).
The proof for the above result can be found in Appendix 1, and combines results from Nüsken and Richter (2021).
2.3 Structured SVI in models with local and global variables
Algorithm 1 produces unbiased estimates of the gradient^{Footnote 3} as demonstrated in Appendix 1 only under the assumption that the parameters are global, that is when there is not a local parameter for each data point. In the setting where we have local and global variables we can no longer do minibatch updates as in Algorithm 1 since the energy term in the objective does not decouple as a sum over the datapoints (Hoffman et al. 2013; Hoffman and Blei 2015). In this section we discuss said limitation and propose a reasonable heuristic to overcome it.
We consider the general setting where our model has global and local variables \({\varvec{\Phi }}, \{{{\varvec{\theta }}}_i\}\) satisfying (Hoffman et al. 2013). This case is particularly challenging as the local variables scale with the size of the dataset and so will the state space. This is a fundamental setting as many hierarchical latent variable models in machine learning admit such a dependency structure, such as topic models (Pritchard et al. 2000; Blei et al. 2003); Bayesian factor analysis (Amari et al. 1996; Bishop 1999; Klami et al. 2013; Daxberger et al. 2019); variational GP regression (Hensman et al. 2013); and others.
Remark 2
The heat semigroup does not preserve conditional independence structure in the drift, i.e. the optimal drift does not decouple and thus depends on the full state space (Appendix 1).
Remark 2 tells us that the drift is not structured in a way that admits scalable sampling approaches such as stochastic variational inference (SVI) (Hoffman et al. 2013). Additionally this also highlights that the method by Huang et al. (2021) does not scale to models like this as the dimension of the state space will be linear in the size of the dataset.
In a similar fashion to Hoffman and Blei (2015), who focussed on structured SVI, we suggest parametrising the drift via \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} \!\!=\!u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i) \); this way the dimension of the drift depends only on the respective local variables and the global variable \(\Phi \). While the Föllmer drift does not admit this particular decoupling we can show that this drift is flexible enough to represent fairly general distributions, thus it is expected to have the capacity to reach the target distribution. Via this parametrisation we can sample in the same fashion as SVI and maintain unbiased gradient estimates.
Remark 3
An SDE parametrised with a decoupled drift \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} =u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i)\) can reach transition densities which do not factor (See Appendix 1 for proof).
It is important to highlight that whilst the parametrisation in Remark 3 may be flexible, it may not satisfy the previous theory developed for the Föllmer drift and SBPs; thus an interesting direction would be to recast the SBP such that the optimal drift is decoupled. However, we found in practice that the decoupled and amortised drift worked very well, outperforming SGLD and the non-decoupled NSFS.
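The decoupled, amortised parametrisation \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} =u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i)\) can be sketched as a single shared network applied row-wise to a minibatch. The code below is only a structural illustration: the two-layer architecture, the random untrained weights and the name `make_decoupled_drift` are assumptions, and the drift component for the global variable is omitted.

```python
import numpy as np

def make_decoupled_drift(local_dim, global_dim, data_dim, hidden=16, seed=0):
    """Build a shared (amortised) drift u(t, theta_i, Phi, x_i) for the local variables.

    The same small MLP is applied to every data point, so the number of
    parameters does not grow with the size of the dataset.
    """
    rng = np.random.default_rng(seed)
    in_dim = 1 + local_dim + global_dim + data_dim
    w1 = rng.standard_normal((in_dim, hidden)) / np.sqrt(in_dim)
    w2 = rng.standard_normal((hidden, local_dim)) / np.sqrt(hidden)

    def drift(t, theta_local, phi, x):
        # theta_local: (B, local_dim), phi: (global_dim,), x: (B, data_dim)
        batch = theta_local.shape[0]
        inputs = np.concatenate(
            [np.full((batch, 1), t), theta_local, np.tile(phi, (batch, 1)), x], axis=1
        )
        return np.tanh(inputs @ w1) @ w2    # (B, local_dim), one row per data point

    return drift

# Evaluate the drift on a minibatch of 4 data points.
drift = make_decoupled_drift(local_dim=3, global_dim=1, data_dim=5)
rng = np.random.default_rng(1)
theta = rng.standard_normal((4, 3))
phi = np.array([0.2])
x = rng.standard_normal((4, 5))
out = drift(0.5, theta, phi, x)
```

Because row i of the output depends only on \((t, {{\varvec{\theta }}}_i, \Phi , {{\varvec{x}}}_i)\), the drift can be evaluated on any minibatch of local variables, which is what allows SVI-style subsampling.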
3 Connections between SBPs and variational inference in latent diffusion models
In this section, we highlight the connection between the objective in Eq. 8 to variational inference in models with an SDE as the latent object, as studied in Tzen and Raginsky (2019a). We first start by making the connection in a simpler case – when the prior of our Bayesian model is given by a Gaussian distribution with variance \(\gamma \), that is \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\).
Observation 1
When \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\), it follows that the NSFP objective in Eq. 8 corresponds to the negative ELBO of the model:
While the above observation highlights a specific connection between NSFP and traditional VBI (Variational Bayesian Inference), it is limited to Bayesian models that are specified with Gaussian priors. In Lemma 1 of Appendix 1 we extend this result to more general priors and reference processes by exploiting the general recursive nature of Bayesian updates (Khan and Rue 2021). In short, we can view the objective in Eq. 8 as an instance of variational Bayesian inference with an SDE prior. Note that this provides a succinct connection between variational inference and maximum entropy in path space (Léonard 2012). In more detail, this observation establishes an explicit connection between the ELBO of an SDE-based generative model where the SDE is latent and the SBP/stochastic-control objectives we explore in this work.
Note that Lemma 1 induces a new two stage algorithm in which we first estimate a prior reference process as in Eq. B10 and then we optimise the ELBO for the model in Eq. B11. This raises the question as to what effect the dynamical prior can have within SBPbased frameworks. In practice we do not explore this formulation as the Föllmer drift of the prior may not be available in closed form and thus may require resorting to additional approximations.
4 Experimental results
We ran experiments on Bayesian NN regression, classification, logistic regression and ICA (Amari et al. 1996), reporting accuracies, log joints (Welling and Teh 2011; Izmailov et al. 2021) and expected calibration error (ECE) (Guo et al. 2017). For details on exact experimental setups please see Appendix 1. Across experiments we compare to SGLD as it has been shown to be a competitive baseline in Bayesian deep learning (Izmailov et al. 2021). Notice that we do not compare to more standard MCMC methodologies (Duane et al. 1987; Neal 2011; Doucet et al. 2001) as they do not scale well to very high dimensional tasks such as Bayesian DL (Izmailov et al. 2021) which are central to our experiments. However, Huang et al. (2021) contrasts the performance of the heat semigroup SFS sampler with more traditional MCMC samplers in 2D toy examples, finding SFS to be competitive.^{Footnote 4}
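For readers unfamiliar with the calibration metric, the following is a minimal sketch of the binned expected calibration error in the style of Guo et al. (2017). The equal-width binning with 10 bins and the function name are illustrative choices and may differ from the exact evaluation code used in the experiments.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Binned ECE: weighted average over bins of |accuracy - mean confidence|.

    confidences: predicted probability of the chosen class, in (0, 1].
    correct: 1 if the prediction was right, 0 otherwise.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if np.any(in_bin):
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap      # weight by the fraction of points in the bin
    return ece

# Overconfident toy predictor: 85% confidence but only 50% accuracy.
ece_overconfident = expected_calibration_error([0.85] * 10, [1, 0] * 5)
```

A perfectly calibrated predictor has an ECE of zero, while the overconfident toy predictor above has a gap of 0.35 between its confidence and its accuracy.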
4.1 Bayesian linear regression and comparison with MC-SFS
In this section we explore a Bayesian linear regression model with a prior on the regression weights. As this model has a Gaussian closed form for the posterior predictive distribution, we report the error of the MC-SFS and NSFS posterior predictive mean and variance with respect to the true posterior predictive moments, as seen in Fig. 2. The datasets were generated by sampling the inputs randomly from a spherical Gaussian distribution and transforming them via:
we then estimated the posterior of the model:
where we use \({{\varvec{x}}}\oplus 1\) to denote adding an extra dimension with a 1 to the vector \({{\varvec{x}}}\). We carried out this experiment increasing the dimension of \({{\varvec{x}}}\) from \(2^5\) to \(2^{11}\). We observe that the NSFS-based approaches have an overall notably smaller posterior predictive error than the MC-SFS approach. Finally, we note that the STL method is more concentrated in its predictions than the naive NSFS approach, whilst having similar errors.
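The closed-form reference moments used for this kind of comparison can be sketched as follows. The code assumes a zero-mean isotropic Gaussian prior on the weights and Gaussian observation noise with known variance; `prior_var` and `noise_var` are placeholder values, not the settings used in the experiment.

```python
import numpy as np

def blr_posterior(X, y, prior_var=1.0, noise_var=1.0):
    """Gaussian posterior over weights for y = w^T (x ⊕ 1) + noise,
    with prior w ~ N(0, prior_var I) and noise ~ N(0, noise_var) (assumed model)."""
    Phi = np.hstack([X, np.ones((X.shape[0], 1))])        # append the constant feature
    precision = np.eye(Phi.shape[1]) / prior_var + Phi.T @ Phi / noise_var
    cov = np.linalg.inv(precision)
    mean = cov @ Phi.T @ y / noise_var
    return mean, cov

def blr_predictive(x_star, mean, cov, noise_var=1.0):
    """Posterior predictive mean and variance at a new input x_star."""
    phi = np.append(x_star, 1.0)
    return phi @ mean, noise_var + phi @ cov @ phi

# Tiny example: one observation y = 1 at x = 0.
mean, cov = blr_posterior(np.array([[0.0]]), np.array([1.0]))
pred_mean, pred_var = blr_predictive(np.array([0.0]), mean, cov)
```

Sampler-based estimates of the predictive mean and variance, obtained by averaging over posterior samples, can then be compared against `pred_mean` and `pred_var`.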
4.2 Bayesian logistic regression/independent component analysis—a9a/MEG datasets
Following Welling and Teh (2011) we explore a logistic regression model on the a9a dataset. Results can be found in Table 1, which shows that NSFS achieves a test accuracy, ECE and log likelihood comparable to SGLD. We then explore the performance of our approach on the Bayesian variant of ICA studied in Welling and Teh (2011) on the MEG dataset (Vigario 1997). We can observe (Table 3) that here NSFS also achieves results comparable to SGLD.
4.3 Bayesian deep learning
In these tasks we use models of the form
where \(f_{{\varvec{\theta }}}\) is a neural network. In these settings we are interested in using the posterior predictive distribution \(p({{\varvec{y}}}^* \vert {{\varvec{x}}}^*, {{\varvec{X}}}) \!=\!\int p({{\varvec{y}}}^*\vert f_{{\varvec{\theta }}}({{\varvec{x}}}^*)) dP({{\varvec{\theta }}}\vert {{\varvec{X}}}) \) to make robust predictions. Across the image experiments we use the LeNet5 (LeCun et al. 1998) architecture. Future works should explore recent architectures for images such as VGG16 (Simonyan and Zisserman 2014) and ResNet32 (He et al. 2016).
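In practice the posterior predictive integral is approximated by averaging over posterior samples, \(p({{\varvec{y}}}^* \vert {{\varvec{x}}}^*, {{\varvec{X}}}) \approx \frac{1}{S}\sum _{s=1}^S p({{\varvec{y}}}^*\vert f_{{{\varvec{\theta }}}_s}({{\varvec{x}}}^*))\). The sketch below assumes a categorical likelihood with softmax outputs, as is typical for classification; the function names are illustrative.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def posterior_predictive(logits_per_sample):
    """Average class probabilities over posterior samples.

    logits_per_sample: array of shape (S, N, K) holding the network outputs
    f_theta_s(x_n) for S posterior samples, N inputs and K classes.
    """
    probs = softmax(np.asarray(logits_per_sample, dtype=float))
    return probs.mean(axis=0)               # shape (N, K)

# Two posterior samples, one input, two classes.
logits = np.array([[[0.0, 0.0]], [[np.log(3.0), 0.0]]])
pred = posterior_predictive(logits)
```

Note that the averaging is done over probabilities rather than logits, which is what the integral over the posterior prescribes.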
Nonlinear regression—step function We fit a 2-hidden-layer neural network with a total of 14876 parameters on a toy step function dataset. We can see in Fig. 3 how both the SGD and SGLD fits interpolate the noise, whilst NSFS has straight lines, thus both achieving a better test error and having well-calibrated error bars. We believe it is a great milestone to see how an overparameterised neural network is able to achieve such well-calibrated predictions.
Digits classification—LeNet5 We train the standard LeNet5 (LeCun et al. 1998) architecture (with 44426 parameters) on the MNIST dataset (LeCun and Cortes 2010). At test time we evaluate the methods on the MNIST test set augmented by random rotations of up to 30\(^\circ \) (Ferianc et al. 2021). Table 4 shows how NSFS has the highest accuracy whilst obtaining the lowest calibration error among the considered methods, highlighting that our approach has the most well-calibrated and accurate predictions when considering a slightly perturbed test set. We highlight that LeNet5 falls into an interesting regime as the number of parameters is considerably less than the size of the training set, and thus we can argue it is not in the overparameterised regime. This regime (Belkin et al. 2019) has been shown to be challenging in achieving good generalisation errors, thus we believe the predictive and calibrated accuracy achieved by NSFS is a strong milestone.
Additionally we provide results on the regular MNIST test set. We can observe that NSFS maintains a high test accuracy and at the same time preserves a low ECE score. We believe the reason SGD and SGLD obtain slightly better ECE performances is that the MNIST test set has very little variation to the MNIST training set, and thus all results seem well calibrated. We can see this observation confirmed by how the distribution of ECE scores changes dramatically on the Rotated MNIST set, a similar argument to that developed in Ferianc et al. (2021). We note that across both experiments SGLD achieves a slightly better log likelihood which comes at the cost of lower predictive performance and less calibrated predictions.
Image classification—CIFAR10 We fit a variation of the LeNet5 (Appendix 1) architecture with 62006 parameters on the CIFAR10 dataset (Krizhevsky et al. 2009). We note that the predictive test accuracies and log-likelihoods of NSFS\(_{\textrm{stl}}\), SGLD and SGD are comparable. However, we can see that NSFS\(_{\textrm{stl}}\) has an ECE an order of magnitude smaller. We notice that the STL estimator made a significant difference on CIFAR10, making the training faster and more stable.
4.4 Hyperspectral image unmixing
To assess our method’s performance visually, we use it to sample from Hyperspectral Unmixing Models (Bioucas-Dias et al. 2012). Hyperspectral images are high spectral resolution but low spatial resolution images typically taken of vast areas via satellites. High spectral resolution provides much more information about the materials present in each pixel. However, due to the low spatial resolution, each pixel of an image can correspond to a \(50\,\textrm{m}^2\) area, containing several materials. Such pixels will therefore have mixed and uninformative spectra. The task of Hyperspectral Unmixing is to determine the presence of given materials in each pixel.
We use the Indian Pines image,^{Footnote 5} denoted as \({\varvec{Y}}\), which has a spatial resolution of \(P = 145 \times 145 = 21025\) pixels and a spectral resolution of \(B = 200\) bands, i.e. \({\varvec{Y}}= [{\varvec{y}}_1, \dots , {\varvec{y}}_P] \in [0, 1]^{B \times P}\). \(R = 3\) materials have been chosen automatically using the Pixel Purity Index and the collection of their spectra will be denoted as \({\varvec{M}}= [{\varvec{m}}_1, {\varvec{m}}_2, {\varvec{m}}_3] \in [0, 1]^{B \times 3}\). The task of Hyperspectral Unmixing is to determine for each pixel p a vector \({\varvec{a}}_p \in \Delta _{R}\) in the probability simplex, where \([{{\varvec{A}}}]_{p,i}=a_{p, i}\) represents the fraction of the ith material in pixel p. To determine the presence of each material, we use the Normal Compositional Model (Eches et al. 2010) as it is a challenging model to sample from. Specifically, it has parameters \(({{\varvec{\Phi }}}, {{\varvec{\Theta }}}) = (\sigma ^2, {\varvec{A}})\) and is defined by:
First note that this model follows the structured model setting discussed in Sect. 2.3: it has one global parameter \(\sigma ^2\) and a local parameter \({\varvec{a}}_p\) for each pixel. Finally, while all the parameters are constrained to lie on the probability simplices, this sampling problem can be cast into an unconstrained sampling problem via Lagrange transformations as in Hsieh et al. (2018). The Normal Compositional Model (Eches et al. 2010) is primarily of interest to us because the unusual noise scaling in the likelihood can produce several modes in each pixel, making it especially easy for sampling algorithms to get stuck in modes.
We compared three approaches for this problem: 1) SGLD 2) NSFS 3) NSFS with decoupled drift, where the decoupled drift is defined as:
Unmixing results are shown in Fig. 4. We stress that to run SGLD successfully we had to tune the approach heavily: we used separate step sizes (which act as a preconditioning) and step size schedules for the parameters \(\sigma ^2\) and \({\varvec{A}}\), and only one combination of these gave decent unmixing results. Without the amortised drift, NSFS struggled with multiple modes in certain patches of the image; however, decoupling the drift resulted in almost perfect unmixing. With a slight deviation from the optimal step size schedule, SGLD fails to explore modes and produces speckly images. In contrast, the only tunable parameter for NSFS was \(\gamma \), which gave similar results for all values tried. Further sensitivity results for SGLD/NSFS are provided in Appendix 1.
4.5 Analysis of NSFS training dynamics
In addition to the experiments above, we investigate our method’s performance in a synthetic multimodal scenario. Here, NSFS is used to fit a Gaussian mixture posterior distribution whose modes are aligned on the x-axis, as shown in Fig. 5. In one case, there are 4 modes: 2 inner modes (those closer to 0) and 2 outer modes (those further away from 0). We notice that in the presence of the 2 inner modes NSFS is unable to discover the outer modes. In contrast, when considering a posterior with only the 2 outer modes, the distribution is fit correctly. This phenomenon could be explained by previously indicated connections between stochastic control and agent-based learning via the Hamilton–Jacobi–Bellman equation (Powell 2019) and the exploration-exploitation trade-off. More concretely, the optimisation objective (8) implies the following training dynamics: random samples are generated from a diffusion (a Brownian motion to begin with), which is then refined to produce more samples in areas where previous samples had high posterior density. This implies that after some modes are discovered, the diffusion will be adjusted to fit them, i.e. the algorithm immediately starts exploiting the detected modes. Other modes will only be discovered if some random sample accidentally hits them, which is very unlikely if the modes are far away. This indicates that the algorithm could be improved by incorporating exploration techniques found in the agent-based learning literature.
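To put a rough number on how unlikely such accidental hits are, the following sketch computes the probability that a sample from the untrained sampler lands beyond a given distance from the origin. It assumes a one-dimensional state and that, before training, the terminal sample is distributed as \({\mathcal {N}}(0, \gamma )\) (a zero-drift diffusion run for unit time). The function name `prob_beyond` and the distances 1 and 5 are illustrative assumptions, not values taken from our experiments.

```python
import math


def prob_beyond(distance, gamma=1.0):
    """P(|Theta_1| > distance) for Theta_1 ~ N(0, gamma).

    Assumption: N(0, gamma) is the terminal law of the untrained
    (zero-drift) sampler in one dimension.
    """
    return math.erfc(distance / math.sqrt(2.0 * gamma))


# Hypothetical mode locations: inner modes about 1 unit from the origin,
# outer modes about 5 units away.
p_inner = prob_beyond(1.0)
p_outer = prob_beyond(5.0)
print(f"P(|theta| > 1) = {p_inner:.3f}")
print(f"P(|theta| > 5) = {p_outer:.2e}")
```

Under these assumptions roughly a third of the initial samples fall beyond the inner distance, while fewer than one in a million reach the outer one, which is consistent with the outer modes going unnoticed once the inner modes start being exploited.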
Given the behaviour of NSFS on this multimodal example, it is natural to ask whether the same happens in Bayesian deep learning applications. To examine this, we look at the marginal distributions of a pair of weights of a Bayesian neural network for MNIST classification, as given by the NSFS and SGLD samples shown in Fig. 6. Note that compared to SGLD, NSFS samples from a dramatically wider distribution while maintaining a comparable predictive log-likelihood score, and therefore does not suffer from the lack of exploration.
5 Discussion and future directions
Overall we achieve predictive performance competitive with SGLD across a variety of tasks whilst obtaining better calibrated predictions as measured by the ECE metric. We hypothesise that the gain in performance is due to the flexible and low variance VI parametrisation of the proposed approach. We would like to highlight that these results were achieved with minimal tuning and simple NN architectures. We find that the decoupled and amortised drift we propose achieves very strong results, making our approach tractable for Bayesian models with local and global structure. Additionally, we notice that the architecture used in the drift network can influence results; future work in this area should therefore develop the drift architectures further.
A key advantage of our approach is that at training time the objective effectively minimises an ELBO-style objective parameterised via a ResNet. This allows us to monitor training using the traditional techniques from deep learning, without the challenges arising from mixing times and correlation of samples found in traditional MCMC methods; once NSFS is trained, generating samples at test time is a fast forward pass through a ResNet that does not require retraining. Finally, as we demonstrated, our approach allows the learned sampler to be amortised (Zhang et al. 2018), which not only allows the drift to be more tractably parameterised but also creates the prospect of meta-learning the posterior (Edwards and Storkey 2016; Yoon et al. 2018; Gordon et al. 2018; Gordon 2018). We believe that this work shows how stochastic control opens an exciting and promising new direction in Bayesian ML/DL.
Notes
This work was made public on arXiv within a month of our arXiv preprint release.
Supporting code at https://anonymous.4open.science/r/ControlledFollmerDrift23F6/README.md.
References
Amari, S.-i., Cichocki, A., Yang, H.H., et al.: A new learning algorithm for blind signal separation. In: Advances in Neural Information Processing Systems, pp. 757–763. Morgan Kaufmann Publishers (1996)
Bartholomew-Biggs, M., Brown, S., Christianson, B., Dixon, L.: Automatic differentiation of algorithms. J. Comput. Appl. Math. 124(1–2), 171–190 (2000)
Belkin, M., Hsu, D., Ma, S., Mandal, S.: Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proc. Natl. Acad. Sci. 116(32), 15849–15854 (2019)
Bernton, E., Heng, J., Doucet, A., Jacob, P.E.: Schrödinger bridge samplers (2019). arXiv preprint
Bioucas-Dias, J.M., Plaza, A., Dobigeon, N., Parente, M., Du, Q., Gader, P., Chanussot, J.: Hyperspectral unmixing overview: geometrical, statistical, and sparse regression-based approaches. IEEE J. Select. Topics Appl. Earth Obs. Remote Sens. 5(2), 354–379 (2012)
Bishop, C.M.: Bayesian PCA. Adv. Neural Inf. Process. Syst. pp. 382–388 (1999)
Blei, D.M., Ng, A.Y., Jordan, M.I.: Latent Dirichlet allocation. J. Mach. Learn. Res. 3, 993–1022 (2003)
Boué, M., Dupuis, P.: A variational representation for certain functionals of Brownian motion. Ann. Probab. 26(4), 1641–1659 (1998)
Brooks, S., Gelman, A., Jones, G., Meng, X.-L.: Handbook of Markov chain Monte Carlo. CRC Press, Boca Raton (2011)
Chen, T., Liu, G.-H., Theodorou, E.: Likelihood training of Schrödinger bridge using forward-backward SDEs theory. In: International Conference on Learning Representations (2022)
Chen, T., Liu, G.-H., Theodorou, E.A.: Likelihood training of Schrödinger bridge using forward-backward SDEs theory (2021). arXiv preprint arXiv:2110.11291
Cuturi, M.: Sinkhorn distances: lightspeed computation of optimal transport. Adv. Neural Inf. Process. Syst. (2013)
Dai Pra, P.: A stochastic control approach to reciprocal diffusion processes. Appl. Math. Optim. 23(1), 313–329 (1991)
Daxberger, E., Hernández-Lobato, J.M.: Bayesian variational autoencoders for unsupervised out-of-distribution detection (2019). arXiv preprint arXiv:1912.05651
De Bortoli, V., Thornton, J., Heng, J., Doucet, A.: Diffusion Schrödinger bridge with applications to score-based generative modeling (2021). arXiv preprint arXiv:2106.01357
Diethe, T.: 13 Benchmark datasets derived from the UCI, DELVE and STATLOG repositories (2015). https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets/
Doucet, A., De Freitas, N., Gordon, N.J., et al.: Sequential Monte Carlo methods in practice, vol. 1. Springer, Cham (2001)
Duane, S., Kennedy, A.D., Pendleton, B.J., Roweth, D.: Hybrid Monte Carlo. Phys. Lett. B 195(2), 216–222 (1987)
Eches, O., Dobigeon, N., Mailhes, C., Tourneret, J.-Y.: Bayesian estimation of linear mixtures using the normal compositional model. Application to hyperspectral imagery. IEEE Trans. Image Process. 19(6), 1403–1413 (2010)
Edwards, H., Storkey, A.: Towards a neural statistician (2016). arXiv preprint arXiv:1606.02185
Ferianc, M., Maji, P., Mattina, M., Rodrigues, M.: On the effects of quantisation on model uncertainty in Bayesian neural networks (2021). arXiv preprint arXiv:2102.11062
Giles, M.: An extended collection of matrix derivative results for forward and reverse mode automatic differentiation (2008)
Girolami, M., Calderhead, B.: Riemann manifold Langevin and Hamiltonian Monte Carlo methods. J. Royal Stat. Soc. Series B (Stat. Methodol.) 73(2), 123–214 (2011)
Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., Bengio, Y.: Generative adversarial nets. In: Adv. Neural Inf. Process. Syst. pp. 2672–2680 (2014)
Gordon, J.: Advances in Probabilistic Meta-Learning and the Neural Process Family. PhD thesis, University of Cambridge (2018)
Gordon, J., Bronskill, J., Bauer, M., Nowozin, S., Turner, R.E.: Meta-learning probabilistic inference for prediction (2018). arXiv preprint arXiv:1805.09921
Graves, A.: Practical variational inference for neural networks. Adv. Neural Inf. Process. Syst. 24 (2011)
Grenander, U., Miller, M.I.: Representations of knowledge in complex systems. J. Roy. Stat. Soc. Ser. B (Methodol.) 56(4), 549–581 (1994)
Guo, C., Pleiss, G., Sun, Y., Weinberger, K.Q.: On calibration of modern neural networks. In: International Conference on Machine Learning, pp. 1321–1330. PMLR (2017)
Gyöngy, I., Krylov, N.: Existence of strong solutions for Itô’s stochastic equations via approximations. Probab. Theory Relat. Fields 105(2), 143–158 (1996)
Hartmann, C., Richter, L., Schütte, C., Zhang, W.: Variational characterization of free energy: theory and algorithms. Entropy 19(11), 626 (2017)
Hartmann, C., Schütte, C.: Efficient rare event simulation by optimal nonequilibrium forcing. J. Stat. Mech Theory Exp. 2012(11), P11004 (2012)
Hastings, W.K.: Monte Carlo sampling methods using Markov chains and their applications (1970)
He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
Hensman, J., Fusi, N., Lawrence, N.D.: Gaussian processes for big data (2013). arXiv preprint arXiv:1309.6835
Hoffman, M.D., Blei, D.M.: Structured stochastic variational inference. In: Artificial Intelligence and Statistics, pp. 361–369 (2015)
Hoffman, M.D., Blei, D.M., Wang, C., Paisley, J.: Stochastic variational inference. J. Mach. Learn. Res. 14(5) (2013)
Hsieh, Y.-P., Kavis, A., Rolland, P., Cevher, V.: Mirrored Langevin dynamics. In: Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., Garnett, R. (eds.) Advances in Neural Information Processing Systems, vol. 31. Curran Associates Inc (2018)
Huang, J., Jiao, Y., Kang, L., Liao, X., Liu, J., Liu, Y.: Schrödinger-Föllmer sampler: sampling without ergodicity (2021). arXiv preprint arXiv:2106.10880
Izmailov, P., Vikram, S., Hoffman, M.D., Wilson, A.G.: What are Bayesian neural network posteriors really like? (2021). arXiv preprint arXiv:2104.14421
Kappen, H.J.: Linear theory for control of nonlinear stochastic systems. Phys. Rev. Lett. 95(20), 200201 (2005)
Khan, M.E., Rue, H.: The Bayesian learning rule (2021). arXiv preprint arXiv:2107.04562
Kingma, D.P., Salimans, T., Poole, B., Ho, J.: Variational diffusion models (2021). arXiv preprint arXiv:2107.00630
Kingma, D.P., Welling, M.: Auto-encoding variational Bayes (2013). arXiv preprint arXiv:1312.6114
Klami, A., Virtanen, S., Kaski, S.: Bayesian canonical correlation analysis. J. Mach. Learn. Res. 14(4) (2013)
Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images (2009)
LeCun, Y., Bottou, L., Bengio, Y., Haffner, P.: Gradient-based learning applied to document recognition. Proc. IEEE 86(11), 2278–2324 (1998)
LeCun, Y., Cortes, C.: MNIST handwritten digit database (2010)
Léonard, C.: From the Schrödinger problem to the Monge–Kantorovich problem. J. Funct. Anal. 262(4), 1879–1920 (2012)
Léonard, C.: A survey of the Schrödinger problem and some of its connections with optimal transport (2013). arXiv preprint arXiv:1308.0215
Li, X., Wong, T.K.L., Chen, R. T.Q., Duvenaud, D.K.: Scalable gradients and variational inference for stochastic differential equations. In: Symposium on Advances in Approximate Bayesian Inference, pp. 1–28. PMLR (2020)
Maoutsa, D., Opper, M.: Deterministic particle flows for constraining SDEs (2021). arXiv preprint arXiv:2110.13020
Neal, R.M., et al.: MCMC using Hamiltonian dynamics. Handbook of Markov chain Monte Carlo 2(11), 2 (2011)
Nüsken, N., Richter, L.: Solving high-dimensional Hamilton–Jacobi–Bellman PDEs using neural networks: perspectives from the theory of controlled diffusions and measures on path space. Partial Differ. Equ. Appl. 2(4), 1–48 (2021)
Opper, M.: Variational inference for stochastic differential equations. Ann. Phys. 531(3), 1800233 (2019)
Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al.: PyTorch: an imperative style, high-performance deep learning library. Adv. Neural Inf. Process. Syst. 32 (2019)
Pavon, M.: Stochastic control and nonequilibrium thermodynamical systems. Appl. Math. Optim. 19(1), 187–202 (1989)
Pavon, M., Tabak, E.G., Trigila, G.: The data-driven Schrödinger bridge. arXiv preprint (2018)
Powell, W.B.: From reinforcement learning to optimal control: a unified framework for sequential decisions (2019). arXiv preprint arXiv:1912.03513
Pritchard, J., Stephen, M., Donnelly, P.: Inference of population structure using multilocus genotype data. Genetics 155(2), 945–959 (2000)
Reich, S.: Data assimilation: the Schrödinger perspective. Acta Numer. 28, 635–711 (2019)
Richter, L., Boustati, A., Nüsken, N., Ruiz, F.J., Akyildiz, Ö.D.: VarGrad: a low-variance gradient estimator for variational inference (2020). arXiv preprint arXiv:2010.10436
Roberts, G.O., Tweedie, R.L.: Exponential convergence of Langevin distributions and their discrete approximations. Bernoulli, pp. 341–363 (1996)
Roeder, G., Wu, Y., Duvenaud, D.: Sticking the landing: simple, lower-variance gradient estimators for variational inference (2017). arXiv preprint arXiv:1703.09194
Schrödinger, E.: Sur la théorie relativiste de l’électron et l’interprétation de la mécanique quantique. Annales de l’institut Henri Poincaré 2, 269–310 (1932)
Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition (2014). arXiv preprint arXiv:1409.1556
Thijssen, S., Kappen, H.: Path integral control and state-dependent feedback. Phys. Rev. E 91(3), 032104 (2015)
Tzen, B., Raginsky, M.: Neural stochastic differential equations: Deep latent Gaussian models in the diffusion limit (2019a). arXiv preprint arXiv:1905.09883
Tzen, B., Raginsky, M.: Theoretical guarantees for sampling and inference in generative models with latent diffusions. In: Conference on Learning Theory, pp. 3084–3114. PMLR (2019b)
Vargas, F., Thodoroff, P., Lamacraft, A., Lawrence, N.: Solving Schrödinger bridges via maximum likelihood. Entropy 23(9), 1134 (2021)
Vigario, R.: MEG data for studies using independent component analysis (1997). http://www.cis.hut.fi/projects/ica/eegmeg/MEG_data.html
Wang, G., Jiao, Y., Xu, Q., Wang, Y., Yang, C.: Deep generative learning via Schrödinger bridge (2021). arXiv preprint arXiv:2106.10410
Welling, M., Teh, Y.W.: Bayesian learning via stochastic gradient Langevin dynamics. In: Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688. Citeseer (2011)
Xu, W., Chen, R.T.Q., Li, X., Duvenaud, D.: Infinitely deep Bayesian neural networks with stochastic differential equations (2021). arXiv preprint arXiv:2102.06559
Yoon, J., Kim, T., Dia, O., Kim, S., Bengio, Y., Ahn, S.: Bayesian model-agnostic meta-learning. In: Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 7343–7353 (2018)
Zhang, C., Bütepage, J., Kjellström, H., Mandt, S.: Advances in variational inference. IEEE Trans. Pattern Anal. Mach. Intell. 41(8), 2008–2026 (2018)
Zhang, Q., Chen, Y.: Diffusion normalizing flow. arXiv preprint arXiv:2110.07579 (2021)
Zhang, Q., Chen, Y.: Path integral sampler: a stochastic control approach for sampling. In: International Conference on Learning Representations (2022)
Acknowledgements
Francisco Vargas is funded by Huawei Technologies Co. This research has been partially funded by Deutsche Forschungsgemeinschaft (DFG) through the grant CRC 1114 ‘Scaling Cascades in Complex Systems’ (project A02, project number 235221301). Andrius Ovsianas is funded by EPSRC iCASE Award EP/T517677/1. Mark Girolami is supported by a Royal Academy of Engineering Research Chair, and EPSRC grants EP/T000414/1, EP/R018413/2, EP/P020720/2, EP/R034710/1, EP/R004889/1.
Appendices
Appendix A Main results
A.1 Posterior drift
Corollary 1 The minimiser
satisfies \({\textrm{Law}}{{\varvec{\Theta }}}_1^{{\textbf{u}}^{*}} = \frac{p({{\varvec{X}}}\vert {{\varvec{\theta }}})p({{\varvec{\theta }}})}{{\mathcal {Z}}}\).
Proof
This follows directly after substituting the Radon–Nikodym derivative between the Gaussian distribution and the posterior into Theorem 1 in Tzen and Raginsky (2019b) or Theorem 3.1 in Dai Pra (1991). \(\square \)
A.2 EM-discretisation result
First we would like to introduce the following auxiliary theorem from Tzen and Raginsky (2019b):
Theorem 2
(Tzen and Raginsky 2019b) Given the standard regularity assumptions presented for \(f=\frac{d\pi _1}{d{\mathcal {N}}({\varvec{0}}, \gamma {\mathbb {I}})}\) in Tzen and Raginsky (2019b), let \(L= \max \{\textrm{Lip}(f), \textrm{Lip}(\nabla f)\}\) and assume that there exists a constant \(c \in (0,1]\) such that \(f \ge c\). Then for any \(\epsilon \in \left( 0, 16 \frac{L^2}{c^2}\right) \) there exists a neural net \({{\varvec{v}}}: {\mathbb {R}}^d \times [0, 1] \rightarrow {\mathbb {R}}^d\) with size polynomial in \(1/\epsilon , d, L, c, 1/c, \gamma \), such that the activation function of each neuron follows the regularity assumptions in Tzen and Raginsky (2019b) (e.g. \(\textrm{ReLU}, \textrm{Sigmoid}, \textrm{Softplus}\)) and
where \(\pi ^{{{\varvec{v}}}}_1={\textrm{Law}}({{\varvec{\Theta }}}_1^{{\varvec{v}}})\) is the terminal distribution of the diffusion process
We can now proceed to prove the direct corollary of the above theorem when using the EM scheme for simulation.
Corollary 2 Given the network \({{\varvec{v}}}\) from Theorem 2, it follows that the Euler–Maruyama discretisation \(\hat{X}_t^{{\varvec{v}}}\) of \(X_t^{{\varvec{v}}}\) has a KL-divergence to the target distribution \(\pi _1\) of:
Proof
Consider the path-wise KL-divergence between the exact Schrödinger–Föllmer process and its EM-discretised neural approximation:
Defining , it is clear that \(d({{\varvec{x}}}, {{\varvec{y}}})\) satisfies the triangle inequality as it is the \({\mathcal {L}}^2({{\mathbb {Q}}^{\gamma , {\textbf{u}}^*, \delta _0}})\) metric between drifts, thus applying the triangle inequality at the drift level we have that (for simplicity letting \(\gamma =1\)):
From Tzen and Raginsky (2019b) we can bound the first term resulting in:
Now remembering that the EM drift is given by \(\hat{{{\varvec{v}}}}_{\sqrt{1-t}}({{\varvec{\Theta }}}_t) = {{\varvec{v}}}(\hat{{{\varvec{\Theta }}}}_t, \sqrt{1-\Delta t \lceil t/\Delta t\rceil })\), we can use that \({{\varvec{v}}}\) is \(L'\)-Lipschitz in both arguments, thus:
which, using the strong convergence of the EM approximation (Gyöngy and Krylov 1996), implies:
thus:
Squaring both sides and applying the data processing inequality completes the proof. \(\square \)
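As a concrete illustration of the EM scheme that the corollary refers to, the following sketch simulates a controlled SDE on \([0, 1]\) started at 0 and checks the terminal variance. It is a minimal example under stated assumptions: the neural drift \({{\varvec{v}}}\) is replaced by a closed-form Föllmer drift for a one-dimensional \({\mathcal {N}}(0, s^2)\) target with \(\gamma = 1\), obtained from \(\nabla \ln Q_{1-t} f\) with \(f \propto \exp (a\theta ^2/2)\) and \(a = 1 - 1/s^2\). All function names and the choices of 50 steps and 4000 samples are hypothetical.

```python
import math
import random


def follmer_drift_gaussian(theta, t, target_var):
    """Closed-form Föllmer drift for a N(0, target_var) target with gamma = 1.

    Stand-in for the neural drift v (an assumption made for illustration).
    """
    a = 1.0 - 1.0 / target_var
    return a * theta / (1.0 - a * (1.0 - t))


def euler_maruyama_sample(drift, n_steps, rng):
    """Simulate d Theta = drift(Theta, t) dt + dW on [0, 1] from Theta_0 = 0."""
    dt = 1.0 / n_steps
    theta = 0.0
    for k in range(n_steps):
        t = k * dt
        theta += drift(theta, t) * dt + math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return theta


def sample_variance(xs):
    """Unbiased sample variance."""
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)


rng = random.Random(0)
target_var = 0.25


def drift(theta, t):
    return follmer_drift_gaussian(theta, t, target_var)


samples = [euler_maruyama_sample(drift, 50, rng) for _ in range(4000)]
print("empirical variance:", sample_variance(samples))
```

With the step size used here the empirical variance should land close to the target value of 0.25, up to discretisation and Monte Carlo error; shrinking the step size reduces the former, in line with the role of \(\Delta t\) in the bound.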
Appendix B Connections to VI
We first start by making the connection in a simpler case – when the prior of our Bayesian model is given by a Gaussian distribution with variance \(\gamma \), that is \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\).
Observation 1 When \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\), it follows that the NSFP objective in Eq. 8 corresponds to the negative ELBO of the model:
Proof
Substituting \(p({{\varvec{\theta }}})\) into Eq. 8 yields
Then, from Boué and Dupuis (1998) and Tzen and Raginsky (2019a, b) we know that the term \(\mathbb {E}\left[ \int _0^1\left\| {{\varvec{u}}}_t\right\| ^2 dt - \ln p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1)\right] \) is the negative ELBO of the model specified in Eq. B7. \(\square \)
While the above observation highlights a specific connection between NSFP and traditional VBI (Variational Bayesian Inference), it is limited to Bayesian models that are specified with Gaussian priors. To extend the result, we take inspiration from the recursive nature of Bayesian updates in the following result.
Lemma 1
The SBP \(\;\inf _{{\mathbb {Q}}\in {\mathcal {D}}\left( \delta _0,\; p({{\varvec{\theta }}}\vert {{\varvec{X}}})\right) } D_{\textrm{KL}}\left( {\mathbb {Q}}\big \vert \big \vert {\mathbb {S}}\right) \) with reference process \({\mathbb {S}}\) described by
corresponds to maximising the ELBO of the model:
Proof
For brevity let \({\textbf{u}}^0(t, {{\varvec{\theta }}})=\nabla \ln Q^\gamma _{1-t}\left[ \frac{p({{\varvec{\theta }}})}{{\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})}\right] \). First notice that the time-one marginals of \({\mathbb {S}}\) are given by the Bayesian prior:
Now from Léonard (2012); Pavon et al. (2018) we know that the Schrödinger system is given by:
where Eq. B12 can be given a rigorous meaning in weak form (that is, by integrating against suitable test functions). Notice \( \phi _0=\delta _0\) and thus it follows that
By Pavon (1989); Dai Pra (1991); Pavon et al. (2018) the optimal drift is given by:
where the expectation is taken with respect to the reference process \({\mathbb {S}}\). Now if we let \(v({{\varvec{\theta }}},t) =\ln \mathbb {E}[p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1) \vert {{\varvec{\Theta }}}_t = {{\varvec{\theta }}}]\) be our value function, then via the linearisation of the Hamilton–Jacobi–Bellman equation through Fleming’s logarithmic transform (Kappen 2005; Thijssen and Kappen 2015; Tzen and Raginsky 2019b) it follows that said value function satisfies:
and thus \({\textbf{u}}^*(t, {{\varvec{\theta }}}) = \gamma \nabla \ln \mathbb {E}[p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1) \vert {{\varvec{\Theta }}}_t = {{\varvec{\theta }}}]\) is a minimiser to:
\(\square \)
Appendix C Stochastic variational inference
For a Bayesian model having the structure specified by (1) the objective in (8) can be written as follows:
where the last term can be written as:
That is, it is possible to obtain an unbiased estimate of the objective (and its gradients) by subsampling the data with random batches of size \(B\) and using the scaling \(\frac{N}{B}\). A version of the algorithm with Euler–Maruyama discretisation of the SDE is given in Algorithm 1.
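The \(\frac{N}{B}\) scaling can be checked exhaustively on a toy problem. The sketch below is an illustration under assumptions rather than part of Algorithm 1: it uses a placeholder Gaussian likelihood, a five-point dataset, and hypothetical function names, and it averages the scaled minibatch estimate over every size-\(B\) subset to confirm that it matches the full-data log-likelihood.

```python
import itertools
import math


def gaussian_log_lik(x, theta, sigma=1.0):
    """ln N(x | theta, sigma^2); a placeholder likelihood for illustration."""
    return -0.5 * ((x - theta) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def full_log_lik(data, theta):
    """Log-likelihood summed over the whole dataset."""
    return sum(gaussian_log_lik(x, theta) for x in data)


def minibatch_log_lik(data, theta, batch_idx):
    """Estimate of the full-data log-likelihood from a batch, scaled by N / B."""
    scale = len(data) / len(batch_idx)
    return scale * sum(gaussian_log_lik(data[i], theta) for i in batch_idx)


def mean_over_all_batches(data, theta, batch_size):
    """Average the estimator over every size-B subset (uniform subsampling)."""
    batches = list(itertools.combinations(range(len(data)), batch_size))
    return sum(minibatch_log_lik(data, theta, b) for b in batches) / len(batches)


data = [0.3, -1.2, 0.8, 2.1, -0.4]
theta = 0.5
print(full_log_lik(data, theta), mean_over_all_batches(data, theta, batch_size=2))
```

Each data point appears in a fraction \(B/N\) of the subsets, so the two printed values agree up to floating-point error; in practice a single random batch is drawn per gradient step instead of enumerating all of them.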
Appendix D Decoupled drift results
First let us consider the setting where the local variables are fully independent, that is, .
Remark 4
The heat semigroup preserves fully factored (mean-field) distributions, thus the Föllmer drift is decoupled.
In this setting we can parametrise the dimensions of the drift which correspond to local variables in a decoupled manner, \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} =u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i, {{\varvec{x}}}_i)\). This amortised parametrisation (Kingma and Welling 2013) allows us to carry out gradient estimates using a minibatch (Hoffman et al. 2013) rather than hold the whole state space in memory.
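The decoupling claim for fully factored distributions can be checked numerically. The sketch below is a toy illustration under assumptions: it works in two dimensions, approximates \(Q_\tau f\) by uniform-grid quadrature against the standard normal density, and uses two hypothetical test functions, one that factors and one that does not. A drift \(\nabla \ln Q_\tau f\) is decoupled exactly when the mixed derivative \(\partial _x \partial _y \ln Q_\tau f\) vanishes, which is what the finite difference measures.

```python
import math


def heat_semigroup_2d(f, x, y, tau, n=121, half_width=6.0):
    """Approximate Q_tau f(x, y) = E[f(x + sqrt(tau) Z1, y + sqrt(tau) Z2)]
    by a uniform-grid quadrature against the standard normal density."""
    step = 2.0 * half_width / (n - 1)
    nodes = [-half_width + i * step for i in range(n)]
    weights = [step * math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi) for z in nodes]
    s = math.sqrt(tau)
    total = 0.0
    for z1, w1 in zip(nodes, weights):
        for z2, w2 in zip(nodes, weights):
            total += w1 * w2 * f(x + s * z1, y + s * z2)
    return total


def mixed_log_derivative(f, x, y, tau, h=0.1):
    """Central finite difference of d^2/(dx dy) ln Q_tau f at (x, y)."""
    def log_q(a, b):
        return math.log(heat_semigroup_2d(f, a, b, tau))
    return (log_q(x + h, y + h) - log_q(x + h, y - h)
            - log_q(x - h, y + h) + log_q(x - h, y - h)) / (4.0 * h * h)


def factored(x, y):
    """Hypothetical density ratio that is a product of one-dimensional factors."""
    return math.exp(-x * x) * (1.0 + y * y)


def coupled(x, y):
    """Hypothetical density ratio that does not factor."""
    return math.exp(-0.5 * (x - y) ** 2)


print(mixed_log_derivative(factored, 0.3, -0.2, tau=0.5))
print(mixed_log_derivative(coupled, 0.3, -0.2, tau=0.5))
```

For the factored function the mixed derivative is zero up to rounding, so each drift component depends only on its own coordinate. For the coupled function the Gaussian integral can be done by hand and gives \(1/(1 + 2\tau )\), i.e. 0.5 here, so the drift components interact.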
Remark 2 The heat semigroup does not preserve conditional independence structure in the drift. That is, the optimal drift does not decouple and as a result depends on the full state space.
Proof
Consider the following distribution:
We want to estimate:
where \(X,Y,Z \sim {\mathcal {N}}(0, \sqrt{1-t})\). From
we can easily see that the above no longer has conditional independence structure and thus when taking its logarithmic derivative the drift does not decouple. \(\square \)
Remark 3 An SDE parametrised with a decoupled drift \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} =u(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i)\) can reach transition densities which do not factor.
Proof
Consider the linear time-homogeneous SDE:
where:
then this SDE admits a closed form solution:
which is a Gauss–Markov process with 0 mean and covariance matrix:
We can carry out the matrix exponential through the eigendecomposition of \({{\varvec{A}}}\); for simplicity let us consider the 3-dimensional case:
From this we see that:
Integrating with respect to \(s\) yields:
The covariance matrix is dense at all times and thus the density \({\textrm{Law}}({{\varvec{\Theta }}}_t) = {\mathcal {N}}({\varvec{\mu }}(t), {\varvec{\Sigma }}(t))\) does not factor (is a fully joint distribution). This example motivates that even with the decoupled drift we can reach coupled distributions. \(\square \)
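The same conclusion can be reproduced numerically on a small example. The sketch below is an illustration under assumptions: the drift matrix `A` is a hypothetical choice (not necessarily the one used in the proof) in which coordinate 0 plays the role of the global variable \(\Phi \) and coordinates 1 and 2 are local variables whose drift depends only on themselves and on \(\Phi \). The covariance of \(d{{\varvec{\Theta }}}_t = {{\varvec{A}}}{{\varvec{\Theta }}}_t\,dt + d{{\varvec{W}}}_t\) started at 0 is obtained by integrating the Lyapunov equation \(\dot{{\varvec{\Sigma }}} = {{\varvec{A}}}{\varvec{\Sigma }} + {\varvec{\Sigma }}{{\varvec{A}}}^\top + {\mathbb {I}}\) with explicit Euler steps.

```python
import math


def matmul(a, b):
    """Product of two square matrices stored as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def transpose(a):
    return [list(row) for row in zip(*a)]


def covariance_at(A, t_end, n_steps=2000):
    """Integrate d Sigma/dt = A Sigma + Sigma A^T + I from Sigma(0) = 0 (explicit Euler)."""
    n = len(A)
    A_t = transpose(A)
    sigma = [[0.0] * n for _ in range(n)]
    dt = t_end / n_steps
    for _ in range(n_steps):
        a_sigma = matmul(A, sigma)
        sigma_a_t = matmul(sigma, A_t)
        sigma = [
            [sigma[i][j] + dt * (a_sigma[i][j] + sigma_a_t[i][j] + (1.0 if i == j else 0.0))
             for j in range(n)]
            for i in range(n)
        ]
    return sigma


# Assumed drift matrix: row 0 is the global variable Phi; rows 1 and 2 are
# local variables driven only by themselves and by Phi (a decoupled drift).
A = [[-1.0, 0.0, 0.0],
     [1.0, -1.0, 0.0],
     [1.0, 0.0, -1.0]]

sigma_1 = covariance_at(A, t_end=1.0)
for row in sigma_1:
    print(["%.4f" % v for v in row])
```

For this choice of `A` the entry coupling the two local variables has the closed form \(\int _0^1 s^2 e^{-2s}\,ds = \tfrac{1}{4} - \tfrac{5}{4}e^{-2}\), which is strictly positive: the local variables become correlated through the shared global variable even though neither appears in the other's drift.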
Appendix E Low variance estimators and sticking the landing
Theorem 1 The STL estimator proposed in Xu et al. (2021) satisfies
almost surely, for all smooth and bounded perturbations \({{\varvec{\phi }}}\).
Proof
Let us decompose \({\mathcal {F}}\) in the following way:
where (denoting the terminal cost with g):
Denoting \({{\varvec{\Theta }}}^{\textbf{u}}\sim {\mathbb {Q}}^{{\textbf{u}}, \delta _0}\), from Nüsken and Richter (2021), Theorem 5.3.1, Eq. 133 it follows that:
almost surely, where \({{\varvec{A}}}_t\) is defined as
and satisfies:
Similarly via the chain rule it follows that:
almost surely. Combining these results, we see that \( \frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0} =0\) almost surely, as required. \(\square \)
Appendix F Stabilising MCSFS implementation
We found the estimators proposed in Huang et al. (2021) (Eqs. 2.20 or 2.21, and Algorithm 2 therein) to be very numerically unstable. Even in two dimensions the Monte Carlo estimator of the drift evaluated to NaNs and infs on more than 50% of the generated samples. This is due to the RND \(f\) of Eq. 7 often evaluating either to 0 due to underflow or to a very small number, resulting in Eq. 7 becoming very large and unstable.
In order to alleviate this we propose a novel modified logsumexp reformulation of Eq. 7:
Lemma 2
(Stable MCSFS) The MCSFS estimator
where \(\hat{P}\) is the empirical measure:
can be re-expressed as:
where:
and \(\ln {\mathcal {Z}}_s = \ln \sqrt{1-t} + \ln f({{\varvec{x}}}+ \sqrt{1-t} {{\varvec{z}}}_s) \)
Proof
Firstly notice that the logsumexp formula cannot be applied to the numerator, as the terms \({{\varvec{z}}}_s f({{\varvec{x}}}+ \sqrt{1-t} {{\varvec{z}}}_s)\) in the numerator can take on negative values and thus we cannot take the log.
In order to take the log, note that \(\mathbb {E}_{\hat{P}}[f]\) is a Lebesgue–Stieltjes integral and thus by construction we can decompose it into positive and negative parts:
Without loss of generality, consider the first term:
and similarly for the second. At this point we can trivially apply the logsumexp formula to each of the exponents separately, as their integrands are positive. \(\square \)
For efficient implementation we first separate the samples into positive and negative and then proceed to compute each of the \(g^+\) and \(g^-\) terms separately, which avoids evaluating any \({{\,\textrm{ln}\,}}0\) terms. We found this formula to have no numerical instabilities in our experiments, which ranged up to high-dimensional cases with \(d=2^{12}\).
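The positive/negative split described above can be sketched for a single coordinate as follows. This is a minimal illustration under assumptions, not our released implementation: it computes only the ratio \(\sum _s z_s f_s / \sum _s f_s\) from the log-values \(\ln f_s\) (any prefactor in front of the ratio is omitted), and the function names are hypothetical.

```python
import math


def logsumexp(vals):
    """Numerically stable ln(sum(exp(v) for v in vals)); returns -inf for an empty list."""
    if not vals:
        return -math.inf
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def stable_weighted_mean(z, log_f):
    """Return sum_s z_s f_s / sum_s f_s given ln f_s, for one coordinate of z.

    The samples are split by the sign of z_s so that every term passed to
    logsumexp is the log of a positive number; z_s == 0 contributes nothing.
    """
    log_den = logsumexp(log_f)
    pos = [math.log(zs) + lf for zs, lf in zip(z, log_f) if zs > 0]
    neg = [math.log(-zs) + lf for zs, lf in zip(z, log_f) if zs < 0]
    g_pos = math.exp(logsumexp(pos) - log_den)
    g_neg = math.exp(logsumexp(neg) - log_den)
    return g_pos - g_neg


# Log-values this small underflow to 0 under exp, so the naive ratio would be 0 / 0.
z = [1.0, -2.0, 0.5]
log_f = [-1000.0, -1001.0, -1002.0]
print(stable_weighted_mean(z, log_f))
```

The design choice is to keep everything in log space until the final subtraction of the two normalised parts, so that underflow in individual \(f_s\) values never reaches a division.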
Appendix G Sensitivity of hyperspectral unmixing results to hyperparameters
While we were able to find step size schedules for SGLD that work well for the hyperspectral image data, it is important to note that this was due to heavy tuning and a stroke of luck. As shown in Table 5, there are four parameters to adjust for the step size scheduling of SGLD, and the resulting performance is very sensitive to all of them. To illustrate this, we fixed the parameters associated with \(\sigma ^2\) as given in Table 5, and varied the others. The resulting samples are provided in Fig. 7.
In contrast, NSFS has only one tunable parameter, which impacts the results much less, as shown in Figs. 8 and 9.
Appendix H Experimental details and further results
H.1 Method hyperparameters
In Table 5 we show the experimental configuration of the trialled algorithms across all datasets. For the selected values of \(\gamma \) we ran a small grid search \(\gamma \in \{0.5^2, 0.2^2, 0.1^2, 0.05^2, 0.01^2\}\) and selected the \(\gamma \) with best training set results.
H.2 Step function dataset
Here we describe in detail how the step function dataset was generated:
where:
\(\sigma _y=0.1\)
\(N_{\textrm{train}} = 100\), \(N_{\textrm{test}} = 100\)
\(x_{\textrm{train}} \in (-3.5, 3.5)\)
\(x_{\textrm{test}} \in (-10, 10)\)
H.3 Föllmer drift architecture
Across all experiments (with the exception of the MNIST dataset) we used the same architecture to parametrise the Föllmer drift:
Note the weights and biases of the final layer are initialised to 0 in order to start the process at a Brownian motion matching the SBP prior.
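The effect of this initialisation can be seen in a few lines. The sketch below is a hypothetical stand-in, not the architecture used in our experiments: a one-hidden-layer network with tanh activations (the class name `TinyDriftNet` and all sizes are assumptions) whose final layer is zero-initialised, so that the drift is identically zero before training.

```python
import math
import random


class TinyDriftNet:
    """One-hidden-layer MLP u(theta, t) with a zero-initialised output layer."""

    def __init__(self, dim, hidden, rng):
        # Hidden layer: small random weights acting on the input (theta, t).
        self.w1 = [[rng.gauss(0.0, 0.1) for _ in range(dim + 1)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        # Final layer: weights and biases set to zero, so the initial drift is zero.
        self.w2 = [[0.0] * hidden for _ in range(dim)]
        self.b2 = [0.0] * dim

    def __call__(self, theta, t):
        inp = list(theta) + [t]
        hid = [math.tanh(sum(w * v for w, v in zip(row, inp)) + b)
               for row, b in zip(self.w1, self.b1)]
        return [sum(w * h for w, h in zip(row, hid)) + b
                for row, b in zip(self.w2, self.b2)]


net = TinyDriftNet(dim=3, hidden=8, rng=random.Random(0))
# With a zero drift the controlled SDE reduces to the uncontrolled reference process.
print(net([0.5, -1.0, 2.0], t=0.3))
```

Because only the last layer is zeroed, the hidden features remain random, so gradients with respect to the final-layer weights are non-zero and training can move the drift away from zero straight away.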
For the MNIST dataset we used the score network proposed in Chen et al. (2021). We aimed to use this same architecture for the CIFAR-10 experiments; however, we were unable to train it stably.
For the Hyperspectral Unmixing dataset we used this architecture for NSFS with full drift, but had to devise a different architecture for decoupled drifts, as shown below.
H.4 BNN architectures
For the step function dataset we used the following architecture:
For LeNet-5 the architecture used was:
The same layer structure as in LeNet-5 was used for the CIFAR-10 dataset, with a difference in the number of channels and the size of the filters. Exact details can be found in the code repository.
H.5 Likelihood and prior hyperparameters
In Table 6 we describe the hyperparameters of each Bayesian model as well as their priors and likelihood.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
Vargas, F., Ovsianas, A., Fernandes, D. et al. Bayesian learning via neural Schrödinger–Föllmer flows. Stat Comput 33, 3 (2023). https://doi.org/10.1007/s11222-022-10172-5