1 Introduction

Steering a stochastic flow from one distribution to another across the space of probability measures is a well-studied problem initially proposed in Schrödinger (1932). There has been recent interest in the machine learning community in these methods for generative modelling, sampling, dataset imputation and optimal transport (Wang et al. 2021; De Bortoli et al. 2021; Huang et al. 2021; Bernton et al. 2019; Vargas et al. 2021; Chen et al. 2022; Cuturi 2013; Maoutsa and Opper 2021; Reich 2019).

We consider a particular instance of the Schrödinger bridge problem (SBP), known as the Schrödinger–Föllmer process (SFP). In machine learning, this process has been proposed for sampling and generative modelling (Huang et al. 2021; Tzen and Raginsky 2019b) and in molecular dynamics for rare event simulation and importance sampling (Hartmann and Schütte 2012; Hartmann et al. 2017); here we apply it to Bayesian inference. We show that a control-based formulation of the SFP has deep-rooted connections to variational inference and is particularly well suited to Bayesian inference in high dimensions. This capability arises from the SFP’s characterisation as an optimisation problem and its parametrisation through neural networks (Tzen and Raginsky 2019b). Finally, due to the variational characterisation that these methods possess, many low-variance estimators (Richter et al. 2020; Nüsken and Richter 2021; Roeder et al. 2017; Xu et al. 2021) are applicable to the SFP formulation we consider.

We reformulate the Bayesian inference problem by constructing a stochastic process \({{\varvec{\Theta }}}_t\) which at a fixed time \(t=1\) will generate samples from a pre-specified posterior \(p({{\varvec{\theta }}}\vert {{\varvec{X}}})\), i.e. \({\textrm{Law}}{{\varvec{\Theta }}}_1 =p({{\varvec{\theta }}}\vert {{\varvec{X}}}) \), with dataset \({{\varvec{X}}}=\{{\textbf{x}}_i\}_{i=1}^N\), and where the model is given by:

$$\begin{aligned} {{\varvec{\theta }}}&\sim p({{\varvec{\theta }}}),\nonumber \\ {\textbf{x}}_i \vert {{\varvec{\theta }}}&\sim p({{\varvec{x}}}_i \vert {{\varvec{\theta }}}). \quad \textrm{iid} \end{aligned}$$
(1)

Here the prior \(p({{\varvec{\theta }}})\) and the likelihood \(p({\textbf{x}}_i \vert {{\varvec{\theta }}})\) are user-specified. Our target is \(\pi _1({{\varvec{\theta }}}) = \frac{p({{\varvec{X}}}\vert {{\varvec{\theta }}})p({{\varvec{\theta }}})}{{\mathcal {Z}}}\), where \({\mathcal {Z}}= \int \prod _i p({{\varvec{x}}}_i \vert {{\varvec{\theta }}})p({{\varvec{\theta }}}) d{{\varvec{\theta }}}\). This formulation is reminiscent of the setup proposed in previous works (Grenander and Miller 1994; Roberts and Tweedie 1996; Girolami and Calderhead 2011; Welling and Teh 2011) and covers many Bayesian machine-learning models, but our formulation has an important difference: stochastic gradient Langevin dynamics (SGLD; Welling and Teh 2011) relies on a diffusion that reaches the posterior as its equilibrium state as time approaches infinity, whereas our dynamics are controlled and the posterior is reached in finite (bounded) time. The benefit of this property is elegantly illustrated in Sect. 3.2 of Huang et al. (2021), where it is rigorously demonstrated that even under an Euler approximation the proposed approach reaches a Gaussian target at time \(t\!=\!1\) whilst SGLD does not.

Contributions The main contributions of this work can be detailed as follows:

  • In this work we scale and apply the theoretical framework proposed in (Dai Pra 1991; Tzen and Raginsky 2019a) to sample from posteriors in large-scale Bayesian machine learning tasks such as Bayesian deep learning. We study the robustness of the resulting predictions and evaluate their uncertainty quantification.

  • More precisely we propose an amortised parametrisation that allows scaling models with local and global variables to large datasets.

  • We explore and provide further theoretical backing (Sect. 2.2) to the “sticking the landing” estimator provided by Xu et al. (2021).

  • Overall we empirically demonstrate that the stochastic control framework offers a promising direction in Bayesian machine learning, striking the balance between theoretical/asymptotic guarantees found in MCMC methods (Hastings 1970; Duane et al. 1987; Neal 2011; Brooks et al. 2011) and more practical approaches such as variational inference (Blei et al. 2003).

Fig. 1

Predictive posterior contour plots on the banana dataset (Diethe 2015). Test accuracies: \(0.8928 \pm 0.0056, 0.8913\pm 0.0105, 0.8800\pm 0.0063\) and test ECEs: \( 0.0229\pm 0.0062, 0.0253 \pm 0.0042, 0.0267 \pm 0.0083\) for N-SFS, SGLD, and SGD respectively. We observe that N-SFS obtains the highest test accuracy whilst preserving the lowest ECE

1.1 Notation

Throughout the paper we consider path measures (denoted as \({\mathbb {Q}}\) or \({\mathbb {S}}\)) on the space of continuous functions \(\Omega = C([0, 1], {\mathbb {R}}^d)\). Random processes associated with such path measures \({\mathbb {Q}}\) are denoted as \({{\varvec{\Theta }}}\) and their time-marginal distributions as \({\mathbb {Q}}_t = ({{\varvec{\Theta }}}_t)_\#{\mathbb {Q}}\) (which are just pushforward measures). Given two marginal distributions \(\pi _0\) and \(\pi _1\) we write \({\mathcal {D}}(\pi _0, \pi _1) = \{{\mathbb {Q}}: {\mathbb {Q}}_0 = \pi _0, {\mathbb {Q}}_1 = \pi _1 \}\) for the set of all path measures with given marginal distributions at the initial and final times. We denote by \({\mathbb {Q}}^{{\textbf{u}}, \pi }\) the path measure of the following Stochastic Differential Equation (SDE):

$$\begin{aligned} d {{\varvec{\Theta }}}_t = {\textbf{u}}(t, {{\varvec{\Theta }}}_t)d t + \sqrt{\gamma } d {\textbf{B}}_t, \;\;\; {{\varvec{\Theta }}}_0 \sim \pi \end{aligned}$$
(2)

(we drop the dependence on \(\gamma \) since it is fixed) and we write \(\mathbb {W}^\gamma = {\mathbb {Q}}^{0, \delta _0}\) for the Wiener measure. We will write \(\frac{d{\mathbb {Q}}}{d{\mathbb {S}}}\) for the Radon-Nikodym derivative (RND) of \({\mathbb {Q}}\) w.r.t. \({\mathbb {S}}\).
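For illustration, the SDE in Eq. 2 can be simulated with the Euler–Maruyama scheme used later in the paper. The sketch below is our own minimal pure-Python implementation of one sample path; the drift `u` is an arbitrary stand-in.

```python
import math
import random

def euler_maruyama(u, gamma, d, n_steps, seed=0):
    """Simulate one path of the d-dimensional SDE
    dTheta_t = u(t, Theta_t) dt + sqrt(gamma) dB_t,  Theta_0 = 0,
    (the delta_0 initial condition) on [0, 1] and return Theta_1."""
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    theta = [0.0] * d
    for k in range(n_steps):
        t = k * dt
        drift = u(t, theta)
        theta = [
            theta[i] + drift[i] * dt + math.sqrt(gamma * dt) * rng.gauss(0.0, 1.0)
            for i in range(d)
        ]
    return theta

# With zero drift this samples from the Wiener measure W^gamma, so
# Theta_1 ~ N(0, gamma I).
sample = euler_maruyama(lambda t, x: [0.0] * len(x), gamma=1.0, d=2, n_steps=100)
```

With a zero drift the terminal variance is exactly \(\gamma\) regardless of the step size, since the Brownian increments are simulated exactly.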

1.2 Schrödinger–Föllmer processes

Definition 1

(Schrödinger-Bridge Process) Given a reference process \({\mathbb {S}}\) and two measures \(\pi _0\) and \(\pi _1\) the Schrödinger bridge distribution is given by

$$\begin{aligned} {\mathbb {Q}}^* = \mathop {\mathrm {arg\; inf}}\limits _{{\mathbb {Q}}\in {\mathcal {D}}(\pi _0, \pi _1)} D_{\textrm{KL}}\left( {\mathbb {Q}}\big \vert \big \vert {\mathbb {S}}\right) , \end{aligned}$$
(3)

where \({\mathbb {S}}\) acts as a “prior”.

It is known (Léonard 2013) that if \({\mathbb {S}}= {\mathbb {Q}}^{{\textbf{u}}, \pi }\), \({\mathbb {Q}}^*\) is induced by an SDE with a modified drift:

$$\begin{aligned} \textrm{d}{{\varvec{\Theta }}}_t = {\textbf{u}}^*(t, {{\varvec{\Theta }}}_t)\textrm{d}t + \sqrt{\gamma } \textrm{d}{\textbf{B}}_t, \;\;\; {{\varvec{\Theta }}}_0 \sim \pi _0, \end{aligned}$$
(4)

i.e. \({\mathbb {Q}}^* = {\mathbb {Q}}^{{\textbf{u}}^*, \pi _0}\). The solution of this SDE is called the Schrödinger-Bridge Process (SBP).

Definition 2

(Schrödinger–Föllmer Process) The SFP is an SBP where \(\pi _0=\delta _0\) and the reference process \({\mathbb {S}}= \mathbb {W}^\gamma \) is the Wiener measure.

The SFP differs from the general SBP in that it fixes the initial distribution to \(\delta _0\), whereas the SBP allows any initial distribution \(\pi _0\). The SBP also allows general Itô SDEs associated with \({\mathbb {Q}}^{{{\varvec{u}}}, \pi }\) as the dynamical prior, whereas the SFP restricts attention to Wiener processes as priors.

The advantage of considering this more limited version of the SBP is that it admits a closed-form characterisation of the solution to the Schrödinger system (Léonard 2013; Wang et al. 2021; Pavon et al. 2018) which allows for an unconstrained formulation of the problem. For accessible introductions to the SBP we suggest (Pavon et al. 2018; Vargas et al. 2021). Now we will consider instances of the SBP and the SFP where \(\pi _1=p({{\varvec{\theta }}}\vert {{\varvec{X}}})\).

1.2.1 Analytic solutions and the heat semigroup

Prior work (Pavon 1989; Dai Pra 1991; Tzen and Raginsky 2019b; Huang et al. 2021) has explored the properties of SFPs via a closed form formulation of the Föllmer drift expressed in terms of expectations over Gaussian random variables known as the heat semigroup. The seminal works (Pavon 1989; Dai Pra 1991; Tzen and Raginsky 2019b) highlight how this formulation of the Föllmer drift characterises an exact sampling scheme for a target distribution and how it could potentially be used in practice. The recent work by Huang et al. (2021) builds on Tzen and Raginsky (2019b) and explores estimating the optimal drift in practice via the heat semigroup formulation using a Monte Carlo approximation. Our work aims to take the next step and scale the estimation of the Föllmer drift to high dimensional cases (Graves 2011; Hoffman et al. 2013). In order to do this we must move away from the heat semigroup and instead consider the dual formulation of the Föllmer drift in terms of a stochastic control problem (Tzen and Raginsky 2019b).

In the setting when \(\pi _0=\delta _0\) we can express the optimal SBP drift as follows

$$\begin{aligned} {\textbf{u}}^*(t, {{\varvec{x}}}) = \nabla _{{\varvec{x}}}\ln \mathbb {E}_{{{\varvec{\Theta }}}\sim {\mathbb {S}}}\left[ \frac{d\pi _1}{d {\mathbb {S}}_1}({{\varvec{\Theta }}}_1)\Big \vert {{\varvec{\Theta }}}_t = {{\varvec{x}}}\right] \end{aligned}$$
(5)

Definition 3

The Euclidean heat semigroup \(Q_t^\gamma , \; t\ge 0\), acts on bounded measurable functions \(f: {\mathbb {R}}^d \rightarrow {\mathbb {R}}\) as \(Q^\gamma _t f({{\varvec{x}}}) = \int _{{\mathbb {R}}^d} f\left( {{\varvec{x}}}+\sqrt{t}{{\varvec{z}}}\right) {\mathcal {N}}({{\varvec{z}}}\vert {{\varvec{0}}}, \gamma {\mathbb {I}}) d{{\varvec{z}}}=\mathbb {E}_{{{\varvec{z}}}\sim {\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}\left[ f\left( {{\varvec{x}}}+\sqrt{t}{{\varvec{z}}}\right) \right] .\)
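To make Definition 3 concrete in one dimension: for \(f(x)=x^2\) we have \(Q^\gamma _t f(x) = \mathbb {E}[(x+\sqrt{t}z)^2] = x^2 + t\gamma \), which a direct Monte Carlo sketch (our own illustration) reproduces.

```python
import random

def heat_semigroup_mc(f, x, t, gamma, n_samples=200_000, seed=0):
    """Monte Carlo estimate of Q_t^gamma f(x) = E_{z ~ N(0, gamma)}[f(x + sqrt(t) z)]
    for scalar f on R (Definition 3 with d = 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, gamma ** 0.5)  # z ~ N(0, gamma)
        total += f(x + t ** 0.5 * z)
    return total / n_samples

# For f(x) = x^2 the semigroup acts exactly: Q_t^gamma f(x) = x^2 + t * gamma,
# here 1.5^2 + 0.5 * 2.0 = 3.25.
estimate = heat_semigroup_mc(lambda v: v * v, x=1.5, t=0.5, gamma=2.0)
```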

In the SFP case where \({\mathbb {S}}=\mathbb {W}^\gamma \), the optimal drift from Eq. 5 can be written in terms of the heat semigroup, \({\textbf{u}}^*(t, {{\varvec{x}}}) = \nabla _{{\varvec{x}}}\ln Q^\gamma _{1-t} \left[ \frac{d \pi _1}{d{\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}({{\varvec{x}}})\right] \). Note that an SDE with the heat semigroup induced drift

$$\begin{aligned} d{{\varvec{\Theta }}}_t = \nabla _{{{\varvec{\Theta }}}_t}\ln Q^\gamma _{1-t} \left[ \frac{d\pi _1}{d{\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}({{\varvec{\Theta }}}_t )\right] d t + \sqrt{\gamma } d{\textbf{B}}_t \end{aligned}$$
(6)

satisfies \( {\textrm{Law}}{{\varvec{\Theta }}}_1 = \pi _1\), that is, at \(t=1\) these processes are distributed according to our target distribution of interest \(\pi _1\).

1.2.2 Schrödinger–Föllmer samplers

Huang et al. (2021) carried out preliminary work on empirically exploring the success of using the heat semigroup formulation of SFPs in combination with the Euler–Maruyama (EM) discretisation to sample from target distributions, in a method they call Schrödinger–Föllmer samplers (SFS). More precisely, the SFS approach proposes estimating the Föllmer drift via:

$$\begin{aligned} \hat{{\textbf{u}}}^*(t, {{\varvec{x}}}) = \frac{\frac{1}{S}\sum _{s=1}^S {{\varvec{z}}}_s f({{\varvec{x}}}+ \sqrt{1-t} {{\varvec{z}}}_s)}{\frac{\sqrt{1-t}}{S}\sum _{s=1}^S f({{\varvec{x}}}+ \sqrt{1-t} {{\varvec{z}}}_s)}, \end{aligned}$$
(7)

where \({{\varvec{z}}}_s \sim {\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})\) and \(f=\frac{d\pi _1}{d{\mathcal {N}}({{\varvec{0}}}, \gamma {\mathbb {I}})}\). Whilst this estimator enjoys sound theoretical properties (Huang et al. 2021) it falls short in practice for the following reasons:

Fig. 2

Comparison between MC-SFS and N-SFS under similar computational constraints. The target distribution is the Gaussian posterior induced by a Bayesian linear regression model; we plot the error of the first and second posterior predictive moments between the true posterior predictive and the listed approximations. We found that increasing the number of steps in SGLD drove the errors closer to 0; however, as the dimension increased this threshold also increased notably. This illustrates the advantage of having a target at a finite time rather than at equilibrium

  • The term f involves the product of PDFs evaluated at samples rather than a log product and is thus often very unstable numerically. In Appendix 1 we provide a more stable implementation of Eq. 7 exploiting the logsumexp trick and properties of the Lebesgue integral.

  • In its current form the estimator does not admit the low-variance gradient estimators available in variational inference; being a plain Monte Carlo estimator, it is prone to high variance.

  • Both empirically and theoretically we found the computational running time of the above approach to be considerably slower than the other methods we compare to. At test time SFS has a computational complexity of \({\mathcal {O}}(T S\#_f(d) )\) where \(T=\Delta t^{-1}\), S is the number of Monte Carlo samples and \(\#_f(d)\) is the cost of evaluating the RND f which at best is linear in d. Meanwhile our proposed approach enjoys a cost of \({\mathcal {O}}(T \#_{{{\varvec{u}}}_\phi }(d) )\) where \(\#_{{{\varvec{u}}}_\phi }(d)\) is the forward pass through a neural network approximating the Föllmer drift.

In practice we found this implementation to be too numerically unstable to produce reasonable results even in low-dimensional examples. In order to carry out a fair comparison we therefore reformulated Eq. 7 stably; the stable formulation and its derivation can be found in Appendix 1.
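A numerically stable way to evaluate Eq. 7 works in the log domain: writing \(w_s = \ln f({{\varvec{x}}}+\sqrt{1-t}{{\varvec{z}}}_s)\), the ratio of Monte Carlo sums reduces to a softmax-weighted average of the \({{\varvec{z}}}_s\) divided by \(\sqrt{1-t}\), so no density value is ever exponentiated on its own. The sketch below is our own one-dimensional illustration (the formulation in Appendix 1 may differ in details); `log_f` is \(\ln \frac{d\pi _1}{d{\mathcal {N}}(0, \gamma )}\).

```python
import math
import random

def sfs_drift_stable(log_f, x, t, gamma, n_samples=1000, seed=0):
    """Log-domain version of the SFS drift estimator in Eq. 7 (d = 1).

    Takes log_f = ln(d pi_1 / d N(0, gamma)) rather than f itself; the
    ratio of sums becomes a softmax-weighted average of the z_s."""
    rng = random.Random(seed)
    zs = [rng.gauss(0.0, math.sqrt(gamma)) for _ in range(n_samples)]
    logw = [log_f(x + math.sqrt(1.0 - t) * z) for z in zs]
    m = max(logw)                                  # logsumexp-style shift
    w = [math.exp(lw - m) for lw in logw]
    total = sum(w)
    return sum((wi / total) * z for wi, z in zip(w, zs)) / math.sqrt(1.0 - t)
```

As a sanity check, for the Gaussian target \(\pi _1 = {\mathcal {N}}(\mu , 1)\) with \(\gamma =1\) one can verify that the exact Föllmer drift is the constant \(\mu \), which this estimator recovers up to Monte Carlo error.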

In this work we build on Huang et al. (2021) by considering a formulation of the Schrödinger–Föllmer process that is suitable for the high-dimensional settings arising in Bayesian ML. Our work focuses on a dual formulation of the optimal drift that is closer to variational inference and thus admits the scalable and flexible parametrisations used in ML.

Algorithm 1

2 Stochastic control formulation

In this section, we introduce a particular formulation of the Schrödinger–Föllmer process in the context of the Bayesian inference problem in Eq. 1. In its most general setting of sampling from a target distribution, this formulation was known to Dai Pra (1991). Tzen and Raginsky (2019b) study the theoretical properties of this approach in the context of generative models (Kingma et al. 2021; Goodfellow et al. 2014), and Opper (2019) applies this formulation to time-series modelling. In contrast, our focus is on the estimation of a Bayesian posterior for a broader class of models than Tzen and Raginsky explore.

Corollary 1

Define

$$\begin{aligned}{} & {} {\mathcal {F}}_{\textrm{DET}}({\textbf{u}}, {{\varvec{\theta }}}) = \frac{1}{2\gamma }\int _0^1\Vert {\textbf{u}}(t, {{\varvec{\theta }}}_t)\Vert ^2 dt - \ln \frac{ p({{\varvec{X}}}\vert {{\varvec{\theta }}}_1)p({{\varvec{\theta }}}_1)}{{\mathcal {N}}({{\varvec{\theta }}}_1\vert {\varvec{0}}, \gamma {\mathbb {I}_d})} \\{} & {} J({\textbf{u}}) = \mathbb {E}_{{{\varvec{\Theta }}}\sim {\mathbb {Q}}^{{\textbf{u}}, \delta _0}}\left[ {\mathcal {F}}_{\textrm{DET}}({\textbf{u}}, {{\varvec{\Theta }}})\right] \end{aligned}$$

Then the minimiser (with \({\mathcal {U}}\) being the set of admissible controlsFootnote 1)

$$\begin{aligned} {\textbf{u}}^{*}\!=\!\mathop {\mathrm {\arg \; min}}\limits _{{\textbf{u}}\in {\mathcal {U}}} J({\textbf{u}}) \end{aligned}$$
(8)

satisfies \({\mathbb {Q}}_1^{{\textbf{u}}^*, \delta _0} = \frac{p({{\varvec{X}}}\vert {{\varvec{\theta }}})p({{\varvec{\theta }}})}{{\mathcal {Z}}}d{{\varvec{\theta }}}\).

Moreover, \({\textbf{u}}^{*}\) solves the SFP with \(\pi _1=p({{\varvec{\theta }}}\vert {{\varvec{X}}})\).

The objective in Eq. 8 can be estimated using an SDE discretisation, such as the EM method. Since the drift \({\textbf{u}}^{*}\) is Markov, it can be parametrised by a flexible function estimator such as a neural network, as in Tzen and Raginsky (2019b). In addition, unbiased estimators for the gradient of the objective in (8) can be formed by subsampling the data. In this work we refer to the above formulation of the SFP as the Neural Schrödinger–Föllmer sampler (N-SFS) when we parametrise the drift with a neural network and implement unbiased mini-batched estimators for this objective (Appendix 1). This formulation of SFPs has been previously studied in the context of generative modelling / marginal likelihood estimation (Tzen and Raginsky 2019b), while we focus on Bayesian inference.
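To make the control objective of Corollary 1 concrete, a Monte Carlo / EM estimate of \(J({\textbf{u}})\) for a fixed drift can be sketched as below. This is our own one-dimensional illustration (the paper parametrises \({\textbf{u}}\) with a neural network and minimises this estimate by stochastic gradient descent, which we omit); `log_target_ratio` stands for \(\ln \frac{p({{\varvec{X}}}\vert \theta )p(\theta )}{{\mathcal {N}}(\theta \vert 0, \gamma )}\).

```python
import math
import random

def j_estimate(u, log_target_ratio, gamma, n_steps, n_paths, seed=0):
    """Monte Carlo / Euler-Maruyama estimate of
    J(u) = E[ (1/2 gamma) int_0^1 ||u(t, Theta_t)||^2 dt
              - ln( p(X|Theta_1) p(Theta_1) / N(Theta_1 | 0, gamma) ) ]
    for a one-dimensional state (Corollary 1 with d = 1)."""
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    total = 0.0
    for _ in range(n_paths):
        theta, energy = 0.0, 0.0
        for k in range(n_steps):
            t = k * dt
            drift = u(t, theta)
            energy += drift * drift / (2.0 * gamma) * dt    # control cost
            theta += drift * dt + math.sqrt(gamma * dt) * rng.gauss(0.0, 1.0)
        total += energy - log_target_ratio(theta)           # terminal cost
    return total / n_paths
```

For the normalised Gaussian target \({\mathcal {N}}(\mu , 1)\) with \(\gamma =1\), one can check that the optimal drift is the constant \(\mu \) with \(J({\textbf{u}}^*)=0\), whilst the zero drift gives \(J(0)=\mu ^2/2\).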

We note that recent concurrent work (Zhang et al. 2022)Footnote 2 proposes an algorithm akin to ours based on Dai Pra (1991); Tzen and Raginsky (2019b), however their focus is on estimating the normalising constant of unnormalised densities, while ours is on Bayesian ML tasks such as Bayesian regression, classification and LVMs, thus our work leads to different insights and algorithmic motivations.

2.1 Theoretical guarantees for neural SFS

While the focus in Tzen and Raginsky (2019b) is on providing guarantees for generative models of the form \( {{\varvec{x}}}\sim q_\phi ({{\varvec{x}}}\vert {{\varvec{Z}}}_1)\;,d{{\varvec{Z}}}_t = {{\varvec{u}}}_{\phi }({{\varvec{Z}}}_t,t)dt + \sqrt{\gamma } d{\textbf{B}}_t, \;{{\varvec{Z}}}_0 =\! 0,\) their results extend to our setting, as they explore approximating the Föllmer drift for a generic target \(\pi _1\).

Theorem 4 in Tzen and Raginsky (restated as Theorem 2 in Appendix 1) motivates using neural networks to parametrise the drift in Eq. 8 as it provides a guarantee regarding the expressivity of a network parametrised drift via an upper bound on the target distribution error in terms of the size of the network.

We will now proceed to highlight how this error is affected by the EM discretisation:

Corollary 2

Given the network \({{\varvec{v}}}\) from Theorem 2 it follows that the Euler-Maruyama discretisation of (2) with \({\textbf{u}}={{\varvec{v}}}\) induces an approximate target \(\hat{\pi }^{{{\varvec{v}}}}_1\) that satisfies

$$\begin{aligned} D_{\textrm{KL}}(\pi _1\vert \vert \hat{\pi }^{{{\varvec{v}}}}_1 ) \le \left( \epsilon ^{1/2} + \mathcal {O}(\sqrt{\Delta t}) \right) ^2. \end{aligned}$$
(9)

This result provides a bound on the error in terms of the depth \(\Delta t^{-1}\) of the stochastic flow (Chen et al. 2022; Zhang et al. 2021) and the size of the network with which we parametrise the drift. Under the view that NN-parametrised SDEs can be interpreted as ResNets (Li et al. 2020), this result illustrates that increasing the ResNet's depth will lead to more accurate results.

2.2 Sticking the landing and low variance estimators

As with VI (Richter et al. 2020; Roeder et al. 2017), the gradient of the objective in this study admits several low variance estimators (Nüsken and Richter 2021; Xu et al. 2021). In this section we formally recap what it means for an estimator to “stick the landing” and we prove that the estimator proposed in Xu et al. satisfies said property.

The full objective being minimised in our approach is (where expectations are taken over \({{\varvec{\Theta }}}\sim {\mathbb {Q}}^{{\textbf{u}}, \delta _0}\)):

$$\begin{aligned} J({\textbf{u}}) =\mathbb {E}[&{\mathcal {F}}_{\textrm{DET}}({\textbf{u}}, {{\varvec{\Theta }}})]\! \nonumber \\ =\mathbb {E}[&{\mathcal {F}}({\textbf{u}}, {{\varvec{\Theta }}})]\! \nonumber \\ =\!\mathbb {E}\!\Bigg [&\!\frac{1}{2\gamma }\!\!\int _0^1\!\!\!\!\vert \vert {\textbf{u}}_t({{\varvec{\Theta }}}_t)\vert \vert ^2 \!dt + \frac{1}{\sqrt{\gamma } }\int _0^1\!\!\!\!{\textbf{u}}_t({{\varvec{\Theta }}}_t)^\top \!d{\textbf{B}}_t \nonumber \\&\!-\! \ln \!\Big (\!\frac{ p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1)p({{\varvec{\Theta }}}_1)}{{\mathcal {N}}({{\varvec{\Theta }}}_1\vert {\varvec{0}}, \gamma {\mathbb {I}_d})}\!\!\Big )\!\Bigg ], \end{aligned}$$
(10)

where we note that in previous formulations we omitted the Itô integral, as it has zero expectation (the integral appears naturally through Girsanov's theorem). We call the estimator calculated by taking gradients of the above objective the relative-entropy estimator. The estimator proposed in Xu et al. (2021) (the “sticking the landing” estimator) is given by:

$$\begin{aligned} J_{\textrm{STL}}({\textbf{u}}) =\mathbb {E}[&{\mathcal {F}}_{\textrm{STL}}({\textbf{u}}, {{\varvec{\Theta }}})]\! \nonumber \\ =\!\mathbb {E}\!\Bigg [&\!\frac{1}{2\gamma }\!\!\int _0^1\!\!\!\!\vert \vert {\textbf{u}}_t({{\varvec{\Theta }}}_t)\vert \vert ^2 \!dt \!+ \!\frac{1}{\sqrt{\gamma } }\!\!\int _0^1\!\!\!\!\!{\textbf{u}}^{\perp }_t({{\varvec{\Theta }}}_t)^\top \!\!d{\textbf{B}}_t\nonumber \\ {}&\!-\! \ln \!\Big (\!\frac{ p({{\varvec{X}}}\vert {{\varvec{\Theta }}}_1)p({{\varvec{\Theta }}}_1)}{{\mathcal {N}}({{\varvec{\Theta }}}_1\vert {\varvec{0}}, \gamma {\mathbb {I}_d})}\!\!\Big )\!\Bigg ], \end{aligned}$$
(11)

where \(\perp \) means that the gradient is stopped/detached as in Xu et al. (2021); Roeder et al. (2017).
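A useful sanity check: the objectives in Eqs. 10 and 11 take identical values on any fixed path, since detaching \({\textbf{u}}\) inside the Itô integral changes only which terms gradients flow through, not the number computed. In an autodiff framework \(\perp \) is a stop-gradient (e.g. `detach` in PyTorch); the pure-Python sketch below (our own) uses an identity placeholder for it.

```python
import math
import random

def stop_grad(x):
    """Placeholder for an autodiff stop-gradient (e.g. tensor.detach() in
    PyTorch). Numerically the identity; it only matters when gradients are
    traced through the computation."""
    return x

def discretised_losses(u, log_target_ratio, gamma, n_steps, seed=0):
    """Evaluate the relative-entropy loss (Eq. 10) and the STL loss (Eq. 11)
    along one Euler-Maruyama path (d = 1). They coincide in value and differ
    only in which terms the gradient would flow through."""
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    theta, energy, ito, ito_stl = 0.0, 0.0, 0.0, 0.0
    for k in range(n_steps):
        t = k * dt
        drift = u(t, theta)
        db = math.sqrt(dt) * rng.gauss(0.0, 1.0)             # Brownian increment
        energy += drift * drift / (2.0 * gamma) * dt
        ito += drift * db / math.sqrt(gamma)                  # Eq. 10 Ito term
        ito_stl += stop_grad(drift) * db / math.sqrt(gamma)   # Eq. 11: u detached
        theta += drift * dt + math.sqrt(gamma) * db
    terminal = -log_target_ratio(theta)
    return energy + ito + terminal, energy + ito_stl + terminal
```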

We study perturbations of \({\mathcal {F}}\) around \({\textbf{u}}^*\) by considering \({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}\), with \({{\varvec{\phi }}}\) arbitrary, and \(\varepsilon \) small. More precisely, we set out to compute (where dependence on \({{\varvec{\theta }}}\) is dropped):

$$\begin{aligned} \frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}})\Big \vert _{\varepsilon = 0}, \end{aligned}$$
(12)

which motivates the following definition of “sticking the landing”:

Definition 4

We say that an estimator “sticks the landing” when

$$\begin{aligned} \frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0}=0, \end{aligned}$$
(13)

almost surely, for all smooth and bounded perturbations \({{\varvec{\phi }}}\).

Notice that by construction, \({{\varvec{u}}}^*\) is a global minimiser of J, and hence all directional derivatives vanish,

$$\begin{aligned} \frac{\textrm{d}}{\textrm{d}\varepsilon } J({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0} = \frac{\textrm{d}}{\textrm{d}\varepsilon } \mathbb {E}[\mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}, {{\varvec{\Theta }}})] \Big \vert _{\varepsilon = 0} =0. \end{aligned}$$
(14)

Definition 4 additionally demands that this quantity be zero almost surely, and not just on average. Consequently, “sticking the landing” estimators have zero variance at \({\textbf{u}}^*\).

Remark 1

The relative-entropy stochastic control estimator does not stick the landing.

Proof

See Nüsken and Richter (2021), Theorem 5.3.1, clause 3; their Eq. 133 clearly indicates that \(\frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0}\ne 0\). \(\square \)

We now prove that the estimator proposed by Xu et al. (2021) does indeed stick the landing.

Theorem 1

The STL estimator proposed in (Xu et al. 2021) satisfies

$$\begin{aligned} \frac{\textrm{d}}{\textrm{d}\varepsilon } \mathcal {F}({\textbf{u}}^* + \varepsilon {{\varvec{\phi }}}) \Big \vert _{\varepsilon = 0} =0, \end{aligned}$$
(15)

almost surely, for all smooth and bounded perturbations \({{\varvec{\phi }}}\).

The proof for the above result can be found in Appendix 1, and combines results from Nüsken and Richter (2021).

2.3 Structured SVI in models with local and global variables

Algorithm 1 produces unbiased estimates of the gradientFootnote 3 as demonstrated in Appendix 1, but only under the assumption that the parameters are global, that is, when there is no local parameter for each data point. In the setting where we have local and global variables we can no longer do mini-batch updates as in Algorithm 1, since the energy term in the objective does not decouple as a sum over the datapoints (Hoffman et al. 2013; Hoffman and Blei 2015). In this section we discuss this limitation and propose a reasonable heuristic to overcome it.

We consider the general setting where our model has global and local variables \({\varvec{\Phi }}, \{{{\varvec{\theta }}}_i\}\), following (Hoffman et al. 2013). This case is particularly challenging as the local variables scale with the size of the dataset, and so will the state space. This is a fundamental setting, as many hierarchical latent variable models in machine learning admit such a dependency structure: topic models (Pritchard et al. 2000; Blei et al. 2003); Bayesian factor analysis (Amari et al. 1996; Bishop 1999; Klami et al. 2013; Daxberger et al. 2019); variational GP regression (Hensman et al. 2013); and others.

Remark 2

The heat semigroup does not preserve conditional independence structure in the drift, i.e. the optimal drift does not decouple and thus depends on the full state-space (Appendix 1).

Fig. 3

Visual comparison on step function data. We can see how the N-SFS based fits have the best generalisation while SGD and SGLD interpolate the noise

Table 1 a9a Dataset

Remark 2 tells us that the drift is not structured in a way that admits scalable sampling approaches such as stochastic variational inference (SVI) (Hoffman et al. 2013). It also highlights that the method of Huang et al. (2021) does not scale to models like this, as the dimension of the state space will be linear in the size of the dataset.

In a similar fashion to Hoffman and Blei (2015), who focussed on structured SVI, we suggest parametrising the drift via \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} \!\!=\!u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i) \); this way each component of the drift depends only on the respective local variable and the global variable \(\Phi \). While the Föllmer drift does not admit this particular decoupling, we can show that this drift is flexible enough to represent fairly general distributions, and thus it is expected to have the capacity to reach the target distribution. Via this parametrisation we can sample in the same fashion as SVI and maintain unbiased gradient estimates.
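Schematically, the amortised parametrisation assembles the full drift from per-datapoint blocks, each seeing only its own local variable, the shared global variable, and its datapoint. The sketch below is our own illustration with scalar variables; `u_local` and `u_global` are toy stand-ins for the drift networks.

```python
def decoupled_drift(t, thetas, phi, xs, u_local, u_global):
    """Assemble the drift for the state (theta_1, ..., theta_N, Phi).

    Each local block sees only (t, theta_i, Phi, x_i), so its cost is
    independent of the dataset size and the sum over i can be mini-batched,
    exactly as in stochastic variational inference."""
    local = [u_local(t, th, phi, x) for th, x in zip(thetas, xs)]
    return local, u_global(t, phi)

# Toy scalar stand-ins for the drift networks (hypothetical forms, for shape
# illustration only).
u_loc = lambda t, th, phi, x: -th + phi * x
u_glob = lambda t, phi: -phi

local_drift, global_drift = decoupled_drift(
    0.5, thetas=[0.1, -0.2, 0.3], phi=2.0, xs=[1.0, 2.0, 3.0],
    u_local=u_loc, u_global=u_glob)
```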

Remark 3

An SDE parametrised with a decoupled drift \([{\textbf{u}}_{t}]_{{{\varvec{\theta }}}_i} =u^{{{\varvec{\theta }}}_i}(t, {{\varvec{\theta }}}_i,\Phi ,{{\varvec{x}}}_i)\) can reach transition densities which do not factor (See Appendix 1 for proof).

It is important to highlight that whilst the parametrisation in Remark 3 may be flexible, it may not satisfy the previous theory developed for the Föllmer drift and SBPs; thus an interesting direction would be to recast the SBP such that the optimal drift is decoupled. However, we found in practice that the decoupled and amortised drift worked very well, outperforming SGLD and the non-decoupled N-SFS.

3 Connections between SBPs and variational inference in latent diffusion models

In this section, we highlight the connection of the objective in Eq. 8 to variational inference in models with an SDE as the latent object, as studied in Tzen and Raginsky (2019a). We start by making the connection in a simpler case – when the prior of our Bayesian model is given by a Gaussian distribution with variance \(\gamma \), that is \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\).

Observation 1

When \(p({{\varvec{\theta }}})={\mathcal {N}}({{\varvec{\theta }}}\vert {\varvec{0}}, \gamma {\mathbb {I}_d})\), it follows that the N-SFP objective in Eq. 8 corresponds to the negative ELBO of the model:

$$\begin{aligned} d{{\varvec{\Theta }}}_t&= \sqrt{\gamma } d{\textbf{B}}_t, \;\;\; {{\varvec{\Theta }}}_0 \sim \delta _0, \nonumber \\ {\textbf{x}}_i&\sim p({\textbf{x}}_i \vert {{\varvec{\Theta }}}_1) . \end{aligned}$$
(16)

While the above observation highlights a specific connection between N-SFP and traditional VBI (variational Bayesian inference), it is limited to Bayesian models specified with Gaussian priors. In Lemma 1 of Appendix 1 we extend this result to more general priors and reference processes by exploiting the recursive nature of Bayesian updates (Khan and Rue 2021). In short, we can view the objective in Eq. 8 as an instance of variational Bayesian inference with an SDE prior. Note that this provides a succinct connection between variational inference and maximum entropy in path space (Léonard 2012). In more detail, this observation establishes an explicit connection between the ELBO of an SDE-based generative model where the SDE is latent and the SBP/stochastic-control objectives we explore in this work.

Note that Lemma 1 induces a new two-stage algorithm in which we first estimate a prior reference process as in Eq. B10 and then optimise the ELBO for the model in Eq. B11. This raises the question of what effect the dynamical prior can have within SBP-based frameworks. In practice we do not explore this formulation, as the Föllmer drift of the prior may not be available in closed form and may thus require additional approximations.

4 Experimental results

We ran experiments on Bayesian NN regression, classification, logistic regression and ICA (Amari et al. 1996), reporting accuracies, log joints (Welling and Teh 2011; Izmailov et al. 2021) and expected calibration error (ECE) (Guo et al. 2017). For details on exact experimental setups please see Appendix 1. Across experiments we compare to SGLD as it has been shown to be a competitive baseline in Bayesian deep learning (Izmailov et al. 2021). Notice that we do not compare to more standard MCMC methodologies (Duane et al. 1987; Neal 2011; Doucet et al. 2001) as they do not scale well to very high dimensional tasks such as Bayesian DL (Izmailov et al. 2021) which are central to our experiments. However, Huang et al. (2021) contrasts the performance of the heat semigroup SFS sampler with more traditional MCMC samplers in 2D toy examples, finding SFS to be competitive.Footnote 4

Table 2 Step function dataset
Table 3 MEG dataset

4.1 Bayesian linear regression and comparison with MC-SFS

In this section we explore a Bayesian linear regression model with a prior on the regression weights. As this model has a Gaussian closed form for the posterior predictive distribution, we report the error of the MC-SFS and N-SFS posterior predictive mean and variance with respect to the true posterior predictive moments, as seen in Fig. 2. The datasets were generated by sampling the inputs randomly from a spherical Gaussian distribution and transforming them via:

$$\begin{aligned} y_i = {{\varvec{1}}}^\top {{\varvec{x}}}_i + 1 \end{aligned}$$

we then estimated the posterior of the model:

$$\begin{aligned} {{\varvec{\theta }}}&\sim {\mathcal {N}}({\varvec{0}}, \sigma ^2_\theta \mathbb {I}),\\ y_i \vert {{\varvec{x}}}_i&,{{\varvec{\theta }}}\sim {\mathcal {N}}(y_i\vert {{\varvec{\theta }}}^\top ({{\varvec{x}}}_i \oplus 1), \sigma _y^2{\mathbb {I}}), \end{aligned}$$

where we use \({{\varvec{x}}}\oplus 1\) to denote appending an extra dimension with value 1 to the vector \({{\varvec{x}}}\). We carried out this experiment increasing the dimension of \({{\varvec{x}}}\) from \(2^5\) to \(2^{11}\). We observe that the N-SFS based approaches have an overall notably smaller posterior predictive error than the MC-SFS approach. Finally, we note that the STL method is more concentrated in its predictions than the naive N-SFS approach, whilst having similar errors.
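For intuition, in one dimension (and without the bias term handled by \({{\varvec{x}}}\oplus 1\)) the ground-truth posterior in this experiment is available in closed form; the sketch below is our own illustration of the standard conjugate-Gaussian computation.

```python
def gaussian_posterior_1d(xs, ys, sigma_theta2, sigma_y2):
    """Exact posterior N(mu_post, var_post) for the 1-D conjugate model
    theta ~ N(0, sigma_theta^2),  y_i | x_i, theta ~ N(theta * x_i, sigma_y^2)."""
    precision = 1.0 / sigma_theta2 + sum(x * x for x in xs) / sigma_y2
    var_post = 1.0 / precision
    mu_post = var_post * sum(x * y for x, y in zip(xs, ys)) / sigma_y2
    return mu_post, var_post

# Noiseless data generated with theta_true = 1 concentrates the posterior
# near 1 as the likelihood precision dominates the prior.
mu, var = gaussian_posterior_1d([1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
                                sigma_theta2=1.0, sigma_y2=0.1)
```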

4.2 Bayesian logistic regression/independent component analysis—a9a/MEG datasets

Following Welling and Teh (2011) we explore a logistic regression model on the a9a dataset. Results can be found in Table 1 which show that N-SFS achieves a test accuracy, ECE and log likelihood comparable to SGLD. We then explore the performance of our approach on the Bayesian variant of ICA studied in Welling and Teh (2011) on the MEG-Dataset (Vigario 1997). We can observe (Table 3) that here N-SFS also achieves results comparable to SGLD.

4.3 Bayesian deep learning

In these tasks we use models of the form

$$\begin{aligned} {{\varvec{\theta }}}&\sim {\mathcal {N}}({\varvec{0}}, \sigma ^2_\theta \mathbb {I}),\\ {{\varvec{y}}}_i \vert {{\varvec{x}}}_i&,{{\varvec{\theta }}}\sim p({{\varvec{y}}}_i\vert f_{{\varvec{\theta }}}({{\varvec{x}}}_i)), \end{aligned}$$

where \(f_{{\varvec{\theta }}}\) is a neural network. In these settings we are interested in using the posterior predictive distribution \(p({{\varvec{y}}}^* \vert {{\varvec{x}}}^*, {{\varvec{X}}}) \!=\!\int p({{\varvec{y}}}^*\vert f_{{\varvec{\theta }}}({{\varvec{x}}}^*)) dP({{\varvec{\theta }}}\vert {{\varvec{X}}}) \) to make robust predictions. Across the image experiments we use the LeNet5 (LeCun et al. 1998) architecture. Future works should explore recent architectures for images such as VGG-16 (Simonyan and Zisserman 2014) and ResNet32 (He et al. 2016).
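In practice the posterior predictive integral is approximated by averaging \(p({{\varvec{y}}}^*\vert f_{{\varvec{\theta }}_s}({{\varvec{x}}}^*))\) over posterior samples \({{\varvec{\theta }}}_s\). For classification this amounts to averaging the per-sample class probabilities; a minimal NumPy sketch (function names are our own, not part of the method's codebase):

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax along the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def posterior_predictive_probs(logits_per_sample):
    """Monte Carlo posterior predictive for classification: average
    p(y* | f_theta_s(x*)) over posterior samples theta_s.
    `logits_per_sample` has shape (S, num_classes), one row per sample
    theta_s (e.g. drawn from a trained N-SFS sampler)."""
    return softmax(logits_per_sample).mean(axis=0)
```

Averaging probabilities rather than logits is what makes the predictive distribution reflect posterior uncertainty over the weights.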

Non-linear regression—step function We fit a 2-hidden-layer neural network with a total of 14876 parameters on a toy step function dataset. Figure 3 shows that both the SGD and SGLD fits interpolate the noise, whilst N-SFS recovers the flat segments of the step function, thus achieving both a better test error and well-calibrated error bars. We consider it a notable result that an overparameterised neural network achieves such well-calibrated predictions.

Table 4 Test set results on MNIST, Rotated MNIST and CIFAR10. The Log-likelihood column is the mean posterior predictive and is thus not estimated for SGD

Digits classification—LeNet5 We train the standard LeNet5 (LeCun et al. 1998) architecture (with 44426 parameters) on the MNIST dataset (LeCun and Cortes 2010). At test time we evaluate the methods on the MNIST test set augmented by random rotations of up to 30\(^\circ \) (Ferianc et al. 2021). Table 4 shows that N-SFS has the highest accuracy whilst obtaining the lowest calibration error among the considered methods, highlighting that our approach yields the best-calibrated and most accurate predictions on a slightly perturbed test set. We highlight that LeNet5 falls into an interesting regime: the number of parameters is considerably smaller than the size of the training set, so it is arguably not in the overparameterised regime. Achieving good generalisation errors in this regime has been shown to be challenging (Belkin et al. 2019), and we therefore regard the predictive accuracy and calibration achieved by N-SFS as a strong result.

Additionally, we provide results on the regular MNIST test set. We observe that N-SFS maintains a high test accuracy while preserving a low ECE score. We believe SGD and SGLD obtain slightly better ECE performance here because the MNIST test set differs very little from the MNIST training set, so all methods appear well calibrated. This observation is confirmed by how dramatically the distribution of ECE scores changes on the rotated MNIST set, a similar argument to that developed in Ferianc et al. (2021). We note that across both experiments SGLD achieves a slightly better log-likelihood, which comes at the cost of lower predictive performance and less calibrated predictions.
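For concreteness, the ECE metric used throughout is the standard binned estimator. A minimal sketch (our own implementation, assuming equal-width confidence bins; not the exact evaluation code used for the tables):

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=10):
    """Binned ECE: weighted average over confidence bins of the gap
    between per-bin accuracy and per-bin mean confidence."""
    conf = probs.max(axis=1)                  # predicted-class confidence
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)     # points falling in this bin
        if mask.any():
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece
```

A perfectly calibrated model has accuracy equal to confidence in every bin, giving an ECE of zero.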

Image classification—CIFAR10 We fit a variation of the LeNet5 (Appendix 1) architecture with 62006 parameters on the CIFAR10 dataset (Krizhevsky et al. 2009). We note that the predictive test accuracies and log-likelihoods of N-SFS\(_{\textrm{stl}}\), SGLD and SGD are comparable. However, we can see that N-SFS\(_{\textrm{stl}}\) has an ECE an order of magnitude smaller. We notice that the STL estimator made a significant difference on CIFAR10, making the training faster and more stable.

4.4 Hyperspectral image unmixing

To assess our method’s performance visually, we use it to sample from hyperspectral unmixing models (Bioucas-Dias et al. 2012). Hyperspectral images have high spectral but low spatial resolution and are typically taken of vast areas via satellites. The high spectral resolution provides rich information about the materials present in each pixel. However, due to the low spatial resolution, each pixel of an image can correspond to a \(50\,\mathrm{m}^2\) area containing several materials; such pixels therefore have mixed and uninformative spectra. The task of hyperspectral unmixing is to determine the presence of given materials in each pixel.

Fig. 4
figure 4

False-color composites with channels given by the unmixed matrices \({{\varvec{A}}}\) obtained via SGLD, N-SFS and N-SFS with a decoupled drift. Speckles illustrate mode collapse

Fig. 5
figure 5

N-SFS performance on a Gaussian mixture posterior distribution with several modes. Outer modes are only detected when the posterior does not contain the interior modes, indicating an exploration failure of N-SFS

We use the Indian Pines image, denoted as \({\varvec{Y}}\), which has a spatial resolution of \(P = 145 \times 145 = 21025\) pixels and a spectral resolution of \(B = 200\) bands, i.e. \({\varvec{Y}}= [{\varvec{y}}_1, \dots , {\varvec{y}}_P] \in [0, 1]^{B \times P}\). \(R = 3\) materials have been chosen automatically using the Pixel Purity Index, and the collection of their spectra is denoted as \({\varvec{M}}= [{\varvec{m}}_1, {\varvec{m}}_2, {\varvec{m}}_3] \in [0, 1]^{B \times 3}\). The task of hyperspectral unmixing is to determine for each pixel p a vector \({\varvec{a}}_p \in \Delta _{R}\) in the probability simplex, where \([{{\varvec{A}}}]_{p,i}=a_{p, i}\) represents the fraction of the i-th material in pixel p. To determine the presence of each material, we use the Normal Compositional Model (Eches et al. 2010), as it is a challenging model to sample from. Specifically, it has parameters \(({{\varvec{\Phi }}}, {{\varvec{\Theta }}}) = (\sigma ^2, {\varvec{A}})\) and is defined by:

$$\begin{aligned} p\left( \sigma ^2\right) = {\varvec{1}}_{[0, 1]}\left( \sigma ^2\right) , \quad p\left( {\varvec{A}}\right) = \prod _{p=1}^P {\varvec{1}}_{\Delta _{R}} \left( {\varvec{a}}_p\right) ,\\ p\left( {\varvec{Y}}\vert {\varvec{A}}, \sigma ^2\right) = \prod _{p=1}^P {\mathcal {N}}\left( {\varvec{y}}_p; {\varvec{M}}{\varvec{a}}_p, \vert \vert {\varvec{a}}_p\vert \vert ^2 \sigma ^2 I\right) . \end{aligned}$$

First note that this model follows the structured model setting discussed in Sect. 2.2: it has one global parameter \(\sigma ^2\) and a local parameter \({\varvec{a}}_p\) for each pixel. Furthermore, while the parameters are constrained (to the unit interval and the probability simplex, respectively), this sampling problem can be cast into an unconstrained one via Lagrange transformations as in Hsieh et al. (2018). The Normal Compositional Model (Eches et al. 2010) is primarily of interest to us because the unusual noise scaling in the likelihood can produce several modes in each pixel, making it especially easy for sampling algorithms to get stuck in a single mode.
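As an illustration, the NCM log-likelihood can be evaluated on unconstrained variables by mapping logits to the simplex via a softmax (one common reparameterisation; this sketch is our own and is not necessarily the exact transformation used in Hsieh et al. 2018):

```python
import numpy as np

def softmax(z):
    # Map unconstrained logits to the probability simplex.
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def ncm_log_likelihood(Y, M, Z, sigma2):
    """log p(Y | A, sigma^2) for the Normal Compositional Model, with the
    abundance vector a_p = softmax(Z[:, p]) so that sampling can run on
    the unconstrained logits Z (shape (R, P)) instead of the simplex."""
    B, P = Y.shape
    ll = 0.0
    for p in range(P):
        a = softmax(Z[:, p])            # a_p in the probability simplex
        var = (a @ a) * sigma2          # ||a_p||^2 sigma^2 noise scaling
        resid = Y[:, p] - M @ a
        ll += -0.5 * resid @ resid / var - 0.5 * B * np.log(2.0 * np.pi * var)
    return ll
```

Note the per-pixel variance \(\Vert {\varvec{a}}_p\Vert ^2 \sigma ^2\): it is this coupling between the abundances and the noise level that induces the multi-modality discussed above.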

We compared three approaches for this problem: (1) SGLD, (2) N-SFS, and (3) N-SFS with a decoupled drift.

Unmixing results are shown in Fig. 4. We stress that to run SGLD successfully we had to tune it heavily: we used separate step sizes (which act as preconditioning) and step-size schedules for the parameters \(\sigma ^2\) and \({\varvec{A}}\), and only one combination yielded decent unmixing results. Without the amortised drift, N-SFS struggled with multiple modes in certain patches of the image; decoupling the drift, however, resulted in almost perfect unmixing. With a slight deviation from the optimal step-size schedule, SGLD fails to explore modes and produces speckled images. In contrast, the only tunable parameter for N-SFS was \(\gamma \), which gave similar results for all values tried. Further sensitivity results for SGLD/N-SFS are provided in Appendix 1.

4.5 Analysis of N-SFS training dynamics

Fig. 6
figure 6

Distribution of log posterior values of samples from N-SFS and SGLD (left) and marginal distribution of a pair of weights in a neural network obtained from samples of N-SFS and SGLD (right)

In addition to the experiments above, we investigate our method’s performance in a synthetic multi-modal scenario. Here, N-SFS is used to fit a Gaussian mixture posterior distribution whose modes are aligned on the x-axis, as shown in Fig. 5. In one case there are 4 modes: 2 inner modes (those closer to 0) and 2 outer modes (those further from 0). We notice that in the presence of the 2 inner modes N-SFS is unable to discover the outer modes. In contrast, when considering a posterior with only the 2 outer modes, the distribution is fit correctly. This phenomenon could be explained by previously indicated connections between stochastic control and agent-based learning via the Hamilton–Jacobi–Bellman equation (Powell et al. 2019) and the exploration–exploitation trade-off. More concretely, the optimisation objective (8) implies the following training dynamics: random samples are generated from a diffusion (initially a Brownian motion), which is then refined to produce more samples in areas where previous samples had high posterior density. Thus, once some modes are discovered, the diffusion is adjusted to fit them, i.e. the algorithm immediately starts exploiting the detected modes. Other modes will only be discovered if a random sample accidentally hits them, which is very unlikely when the modes are far apart. This suggests that the algorithm could be improved by incorporating exploration techniques from the agent-based learning literature.
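A target of this kind is easy to reproduce. The sketch below is our own toy construction (equal weights and isotropic components, with mode locations and scales chosen arbitrarily rather than matching the figure) for evaluating the log-density of such a mixture:

```python
import numpy as np

def gmm_log_density(x, means, sigma=0.5):
    """Log-density of an equal-weight mixture of isotropic Gaussians
    N(mu_k, sigma^2 I), evaluated with a log-sum-exp for stability."""
    d = len(x)
    log_comps = np.array([
        -0.5 * (x - mu) @ (x - mu) / sigma**2
        - 0.5 * d * np.log(2.0 * np.pi * sigma**2)
        for mu in means
    ])
    m = log_comps.max()
    # Equal mixture weights 1/K are folded in via .mean().
    return m + np.log(np.exp(log_comps - m).mean())
```

Placing the inner modes at, say, \(\pm 1\) and the outer modes at \(\pm 4\) on the x-axis reproduces the qualitative failure mode: samples concentrated near the origin rarely reach the outer components.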

Given the behaviour of N-SFS on this multi-modal example, it is natural to ask whether the same occurs in Bayesian deep learning applications. To examine this, we look at the marginal distributions of a pair of weights of a Bayesian neural network for MNIST classification, obtained from the samples of N-SFS and SGLD shown in Fig. 6. Note that compared to SGLD, N-SFS samples from a dramatically wider distribution while maintaining a comparable predictive log-likelihood score, and therefore does not suffer from a lack of exploration.

5 Discussion and future directions

Overall, we achieve predictive performance competitive with SGLD across a variety of tasks whilst obtaining better-calibrated predictions as measured by the ECE metric. We hypothesise that the gain in performance is due to the flexible and low-variance VI parametrisation of the proposed approach. We highlight that these results were achieved with minimal tuning and simple NN architectures. We find that the decoupled and amortised drift we propose achieves very strong results, making our approach tractable for Bayesian models with local and global structure. Additionally, we notice that the architecture used in the drift network can influence results; future work in this area should therefore develop drift architectures further.

A key advantage of our approach is that at training time the objective effectively minimises an ELBO-style objective parameterised via a ResNet. This allows us to monitor training using traditional techniques from deep learning, without the challenges arising from mixing times and correlation of samples found in traditional MCMC methods; once N-SFS is trained, generating samples at test time is a fast forward pass through a ResNet that does not require re-training. Finally, as we demonstrated, our approach allows the learned sampler to be amortised (Zhang et al. 2018), which not only makes the drift more tractably parameterised but also opens the prospect of meta-learning the posterior (Edwards and Storkey 2016; Yoon et al. 2018; Gordon et al. 2018; Gordon 2018). We believe this work demonstrates that stochastic control opens an exciting and promising new direction in Bayesian ML/DL.