1 Introduction

VAEs (Kingma and Welling 2014; Rezende et al. 2014) are powerful latent variable models that routinely use neural networks to parameterise conditional distributions of observations given a latent representation. This renders maximum likelihood estimation (MLE) of such models intractable, so one commonly resorts to extensions of expectation-maximisation (EM) approaches that maximise a lower bound on the data log-likelihood. These objectives introduce a variational or encoding distribution of the latent variables that approximates the true posterior distribution of the latent variable given the observation. However, VAEs have shortcomings; for example, they can struggle to generate high-quality images. These shortcomings have been attributed to failures to match corresponding distributions in the latent space. First, the VAE prior can significantly differ from the aggregated approximate posterior (Hoffman and Johnson 2016; Rosca et al. 2018). To alleviate this prior hole phenomenon, previous work has considered more flexible priors, such as mixtures (Tomczak and Welling 2017), normalising flows (Kingma et al. 2016), hierarchical priors (Sønderby et al. 2016; Klushyn et al. 2019), energy-based models (Du and Mordatch 2019; Aneja et al. 2021) or diffusion models (Vahdat et al. 2021; Sinha et al. 2021). Second, the encoding distribution can be significantly different from the true posterior distribution. It has been an ongoing challenge to reduce this approximation error by constructing new flexible variational families such as parametric constructions (Barber and Bishop 1998; Tran et al. 2015; Han et al. 2016; Guo et al. 2016; Abadi et al. 2016; Louizos and Welling 2016; Locatello et al. 2018; Louizos and Welling 2017), with normalising flows (Rezende and Mohamed 2015; Kingma et al. 2016; Papamakarios et al. 2019) being a popular example. Other works resort to auxiliary variables (Ranganath et al. 2016) with implicit (Tran et al. 2017; Mescheder et al. 2017) or semi-implicit (Yin and Zhou 2018; Molchanov et al. 2019; Titsias and Ruiz 2019; Yu et al. 2023) models that require appropriate adjustments to the variational objectives.

This work utilises adaptive MCMC kernels to construct an implicit variational distribution that, by the reversibility of the associated Markov kernel, decreases an upper bound of the Kullback–Leibler (KL) divergence between an initial encoding distribution and the true posterior. In summary, this paper (i) develops gradient-based adaptive MCMC methods that give rise to flexible implicit variational densities for training VAEs; (ii) shows that non-diagonal preconditioning schemes are beneficial for learning hierarchical structures within VAEs; and (iii) illustrates the improved generative performance for different data sets and MCMC schemes. Our code is available at https://github.com/kreouzisv/smvaes.

2 Background

We are interested in learning deep generative latent variable models using VAEs. Let \(\textsf{X}\subset \mathbb {R}^{d_x}\), \(\textsf{Z}\subset \mathbb {R}^{d_z}\) and assume some prior density \(p_{\theta }(z)\) for \(z \in \textsf{Z}\), with all densities assumed with respect to the Lebesgue measure. The prior density can be fixed or made dependent on some parameters \(\theta \in \Theta \). Consider a conditional density \(p_{\theta }(x|z)\), also called decoder, with \(z \in \textsf{Z}\), \(x \in \textsf{X}\) and parameters also denoted \(\theta \). We can interpret this decoder as a generative network that tries to explain a data point x using a latent variable z. This latent structure yields the following generative distribution of the data

$$\begin{aligned} p_{\theta }(x)= \int _{\textsf{Z}} p_{\theta }(x|z)p_{\theta }(z)\textrm{d}z.\end{aligned}$$

Assume a ground truth measure \(p_d\) on \(\textsf{X}\), which can be seen as the empirical distribution of some observed data set. We want to maximise the log-likelihood with respect to \(p_d\), i.e.  \( \max _{\theta \in \Theta } \int _{\mathcal {X}} \log p_{\theta }(x) p_d(\textrm{d}x)\). Variational inference approaches for maximising this log-likelihood proceed by introducing so-called encoder distributions \(q_{\phi }(z|x)\) with parameter \(\phi \in \Phi \). These encoder distributions can be used to construct a tractable surrogate objective which minorises the log-likelihood and becomes tight if the encoder distribution coincides with the posterior distribution. In particular, letting \(\{q_{\phi }(z|x) :\phi \in \Phi \}\) be a parameterised family of encoders, one can define the so-called evidence lower bound (ELBO),

$$\begin{aligned}\mathcal {L}(\theta ,\phi , x)=\mathbb {E}_{q_{\phi }(z|x)} \left[ \log p_{\theta }(x|z) \right] - \textrm{KL}(q_{\phi }(z|x)|p_{\theta }(z)) \end{aligned}$$

averaged over \(x \sim p_d\). Here, \(\textrm{KL}(q(z)|p(z))=\int _{\textsf{Z}} q(z)\left( \log q(z)-\log p(z) \right) \textrm{d}z \ge 0\) denotes the Kullback–Leibler divergence between two densities q and p. Recalling the posterior density \(p_{\theta }(z|x)\propto p_{\theta }(z) p_{\theta }(x|z)\), one can see directly that the ELBO constitutes a surrogate objective that minorises the log-likelihood,

$$\begin{aligned}\mathcal {L}(\theta ,\phi ,x)= \log p_{\theta }(x) - \textrm{KL}(q_{\phi }(z|x)|p_{\theta }(z|x)).\end{aligned}$$
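The decomposition of the ELBO into the log-likelihood minus a KL term can be checked numerically on a toy model where every quantity is available in closed form. The sketch below assumes a scalar linear-Gaussian model, \(z \sim \mathcal {N}(0,1)\) and \(x|z \sim \mathcal {N}(z, s^2)\), with a Gaussian encoder; the model, the function names and the numerical values are illustrative assumptions and are not taken from the paper.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log-density of a scalar Gaussian N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)

def kl_normal(m_q, v_q, m_p, v_p):
    """KL(N(m_q, v_q) | N(m_p, v_p)) for scalar Gaussians."""
    return 0.5 * (np.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1.0)

def elbo(x, m, v, s2):
    """ELBO for the toy model z ~ N(0, 1), x | z ~ N(z, s2) with encoder q = N(m, v)."""
    expected_loglik = -0.5 * np.log(2.0 * np.pi * s2) - ((x - m) ** 2 + v) / (2.0 * s2)
    return expected_loglik - kl_normal(m, v, 0.0, 1.0)

x, s2 = 1.3, 0.5
m, v = 0.4, 0.7                                   # a deliberately imperfect encoder
log_px = log_normal(x, 0.0, 1.0 + s2)             # exact marginal log-likelihood
post_mean, post_var = x / (1.0 + s2), s2 / (1.0 + s2)
gap = kl_normal(m, v, post_mean, post_var)        # KL(q | posterior)

print("ELBO               :", elbo(x, m, v, s2))
print("log p(x) - KL gap  :", log_px - gap)
print("ELBO at posterior  :", elbo(x, post_mean, post_var, s2), "vs log p(x):", log_px)
```

The first two printed values coincide, and the bound is tight when the encoder equals the exact posterior.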

3 Related work

Many approaches have been proposed for combining MCMC with variational inference. Salimans et al. (2015) and Wolf et al. (2016) construct a variational bound on an extended state space that includes multiple samples of the Markov chain. This was extended in Caterini et al. (2018) using tempering, which also illustrated connections with SMC samplers. Instead of considering variational objectives on augmented state spaces, our approach follows more closely the work of Hoffman (2017), Levy et al. (2018), Hoffman et al. (2019). In particular, we follow their approach to estimate the gradients of the decoder parameters and the initial variational distribution. However, our approach considers an unexplored gradient-based adaptation of the Markov chain that also allows us to learn, for instance, non-diagonal pre-conditioning matrices. Titsias (2017) suggested a model reparameterisation using a transport mapping, while Ruiz and Titsias (2019) suggested using a variational contrastive divergence instead of the KL divergence used herein. Thin et al. (2020) presented a variational objective on an extended space of the accept/reject variables that allows for entropy estimates of the distribution of the final state of the Markov chain. Nijkamp et al. (2020) have used short-run MCMC approximations based on unadjusted Langevin samplers to train multi-layer latent variable models without learning an auto-encoding model, with extensions to learn energy-based priors in Pang et al. (2020). We follow the approach in Pang et al. (2020) for learning the generative parameters, with the difference being that we utilise adaptive and Metropolis-adjusted algorithms, instead of unadjusted Langevin samplers. Ruiz et al. (2021) used couplings for Markov chains to construct unbiased estimates of the marginal log-likelihood for VAEs. The introduced Markov chains are samples from an extension of the iterated sampling importance resampling algorithm (Andrieu et al. 2010) and target an augmented posterior distribution for the IWAE bound (Burda et al. 2015). Non-adaptive MCMC transitions have also been utilised in Wu et al. (2020) to build stochastic normalising flows that approximate the posterior distribution in VAEs, but are trained by minimising a KL divergence between the forward and backward path probabilities, see also Hagemann et al. (2022). More recently, Taniguchi et al. (2022) considered an amortised energy function over the encoder parameters and used a MALA algorithm to sample from its invariant distribution. In particular, they considered a MALA algorithm that operates on the parameter space of the encoder parameters instead of utilising MCMC algorithms in the latent space as in our work. Peis et al. (2022) learn an initial encoding distribution based on a sliced kernel Stein discrepancy and then apply a non-adapted HMC algorithm.

4 Training VAEs with MCMC speed measures

In this work, we construct an approximation to the posterior distribution \(p_\theta (z|x)\) by first sampling from an initial tractable distribution \(q^0_{\phi _0}(z|x)\) and then recursively updating z by applying a sequence of K Markov kernels. More precisely, for \(x \in \textsf{X}, \theta \in \Theta , \phi \in \Phi \), let \(M^k_{\theta ,\phi _k}(\cdot |x)\) denote a parameterised Markov kernel which is reversible with respect to the posterior \(p_\theta (z|x)\). We then define the following variational family

$$\begin{aligned} \mathcal {Q}_x =\{&q^K_{\theta ,\phi }(\cdot |x)=q^0_{\phi _0}(\cdot |x) M^1_{\theta ,\phi _1}(\cdot |x) \cdots M^K_{\theta ,\phi _K}(\cdot |x) ~,\\ {}&\phi _k \in \Phi _k, \phi =(\phi _0, \ldots , \phi _K), \theta \in \Theta \}, \end{aligned}$$

where \((qM)(z'|x)=\int _{\textsf{Z}} q(z|x)M(z, z'|x) \textrm{d}z\) for a conditional density \(q(\cdot |x)\) and Markov kernel \(M(\cdot |x)\) that depends on x. Although \(q_{\theta ,\phi }^K\) can be evaluated explicitly for the choice of Markov kernels considered here (Thin et al. 2020), we do not require this. Instead, we rely on the fact (Ambrosio et al. 2005, Lemma 9.5.4) that, due to the reversibility of the Markov kernels with respect to \(p_{\theta }(z|x)\), it holds that

$$\begin{aligned} \textrm{KL}\left( q^K_{\theta ,\phi }(z |x)| p_{\theta }(z|x) \right) \le \textrm{KL}\left( q^0_{\phi _0}(z|x)| p_{\theta }(z|x) \right) . \end{aligned}$$
(1)
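Inequality (1) can be illustrated on a finite state space, where the kernel is a transition matrix and all KL divergences can be computed exactly. The sketch below is an illustrative analogue rather than the continuous setting of the paper: it assumes a nearest-neighbour Metropolis chain on eight states with a randomly drawn target, and it tracks the KL divergence to the target as the kernel is applied repeatedly. All names are hypothetical.

```python
import numpy as np

def metropolis_kernel(pi):
    """Nearest-neighbour Metropolis transition matrix on {0, ..., n-1}, reversible w.r.t. pi."""
    n = len(pi)
    M = np.zeros((n, n))
    for i in range(n):
        for j in (i - 1, i + 1):
            if 0 <= j < n:
                # propose each neighbour with probability 1/2, accept with min(1, pi_j / pi_i)
                M[i, j] = 0.5 * min(1.0, pi[j] / pi[i])
        M[i, i] = 1.0 - M[i].sum()   # rejected or out-of-range proposals stay put
    return M

def kl(q, p):
    """KL divergence between two strictly positive probability vectors."""
    return float(np.sum(q * np.log(q / p)))

rng = np.random.default_rng(0)
pi = rng.random(8)
pi /= pi.sum()                       # target distribution (plays the role of the posterior)
q = rng.random(8)
q /= q.sum()                         # initial distribution (plays the role of q^0)
M = metropolis_kernel(pi)

kls = [kl(q, pi)]
for _ in range(10):
    q = q @ M                        # one application of the Markov kernel
    kls.append(kl(q, pi))
print(np.round(kls, 4))
```

The printed sequence is non-increasing, so the distribution after K kernel applications is never further from the target in KL than the initial distribution.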

The non-asymptotic convergence of the Markov chain depends on the posterior distribution as well as on the specific MCMC algorithm used; see, for example, Dwivedi et al. (2019), Mangoubi and Vishnoi (2019), Chewi et al. (2021), Wu et al. (2022), Altschuler and Chewi (2023), Chen and Gatmiry (2023) for the MALA case and Chen et al. (2019b), Lee et al. (2020, 2021) for HMC, often under convexity or smoothness with isoperimetry assumptions.

4.1 Learning the warm start distribution

We consider first a standard ELBO

$$\begin{aligned}&{\mathcal {L}_0}(\theta ,\phi _0, x) \\ {}&= \mathbb {E}_{q^0_{\phi _0}(z|x)} \left[ \log p_{\theta }(x|z) \right] - \textrm{KL}(q^0_{\phi _0}(z|x)|p_{\theta }(z)). \nonumber \end{aligned}$$
(2)

Relation (1) motivates learning \(\phi _0\) by maximising \({\mathcal {L}_0}(\theta ,\phi _0, x)\). Indeed, due to

$$\begin{aligned} \textrm{KL}\left( q^K_{\theta ,\phi }(z |x)| p_{\theta }(z|x) \right)&\le \textrm{KL}\left( q^0_{\phi _0}(z |x)| p_{\theta }(z|x) \right) \\&= \log p_{\theta }(x) -{\mathcal {L}_0}(\theta ,\phi _0, x),\end{aligned}$$

maximising \({\mathcal {L}_0}(\theta ,\phi _0, x)\) decreases an upper bound of the KL divergence between the variational density \(q^K_{\theta ,\phi }(\cdot |x)\) and the posterior density for fixed \(\theta \) and \(\phi _1, \ldots , \phi _K\). While decreasing an upper bound of the KL divergence may not necessarily decrease the actual KL divergence, we found this choice to work well in practice. It also allows us to utilise pre-trained encoding distributions with parameters \(\phi _0\) from a standard VAE as a parameter initialisation. On a high level, upper and lower bounds on the mixing time of MALA or HMC for log-concave targets hinge on a well-chosen warm initial distribution, as well as a small condition number of the target distribution, adjusted for the preconditioning matrix of the sampler, see for example Wu et al. (2022); Altschuler and Chewi (2023). For m-strongly convex and L-smooth targets \(p_\theta (z|x)\), one can obtain, see Dwivedi et al. (2019), a \(\beta \)-warm distribution \(q^0_{\phi _0}(\cdot |x)=\mathcal {N}(\mu ^\star ,L^{-1} {\text {I}})\), i.e. it holds that

$$\begin{aligned} \sup _A \frac{\int _A q^0_{\phi _0}(z|x) \textrm{d}z}{\int _A p_{\theta }(z|x) \textrm{d}z} \le \beta \end{aligned}$$

over all measurable sets A for \(\beta =\kappa ^{d_z/2}\), with condition number \(\kappa =L/m\), where \(\mu ^\star \) is the mode of \(p_\theta (z|x)\). By optimising the bound \({\mathcal {L}_0}(\theta ,\phi _0, x)\), we can expect to find parameters \(\phi _0\) so that the mean of the variational distribution is close to a mode of the true posterior. Because chains started in regions that are unlikely under the target \(p_\theta (z|x)\) tend to have small acceptance probabilities, warm start distributions lead to faster convergence as they avoid such bottlenecks in the state space. One can obtain warm starts by controlling the forward chi-squared divergence

$$\begin{aligned}\chi ^2(q^0_{\phi _0}(z |x)| p_{\theta }(z|x)) =\int _{\textsf{Z}} \left( \frac{q^0_{\phi _0}(z|x)}{p_\theta (z|x)} \right) ^2 p_{\theta }(z|x)\textrm{d}z - 1,\end{aligned}$$

or more generally a Rényi divergence of order strictly larger than one, see Altschuler and Chewi (2023). Such objectives are more challenging to optimise, with variational approaches typically requiring multiple Monte Carlo samples (Hernandez-Lobato et al. 2016; Finke and Thiery 2019; Geffner and Domke 2021; Li et al. 2023). It may be of interest to explore in future work such different variational objectives for learning \(\phi _0\).

4.2 Markov kernels

We also need to specify the Markov kernels. We use reparameterisable Metropolis-Hastings kernels with the potential function \(U_{\theta }(z|x)=-\log p_{\theta }(x|z)-\log p_{\theta }(z)\) corresponding to the target \(\pi _{\theta }(z)=p_{\theta }(z|x) \propto \exp (-U_{\theta }(z|x))\). More precisely, for \(A \in \mathcal {B}(\textsf{Z})\),

$$\begin{aligned} M^k_{\theta ,\phi _1}(z,A|x) =&\int _{\textsf{Z}} \nu (\textrm{d}v) \Big [ \left( 1- \alpha (z,z'|x) \right) \delta _{z} (A) \\ {}&+\alpha (z, z'|x) \delta _{z'} (A) \Big ]_{z'=\mathcal {T}_{\theta ,\phi _1}(v|z,x)} \\ =&\int _{\textsf{Z}} \Big [ \left( 1- \alpha (z,z'|x) \right) \delta _{z} (A) \\ {}&+\alpha (z, z'|x) \delta _{z'} (A) \Big ] r_{\theta ,\phi _1}(z,z'|x) \textrm{d}z' \end{aligned}$$

where \(\alpha (z,z'|x)\) is an acceptance rate for moving from state z to \(z'\), \(\mathcal {T}_{\theta ,\phi _1}(\cdot |z,x)\) is a proposal mapping and \(\nu \) is a parameter-free density over \(\textsf{Z}\). Expressing the proposal density \(r_{\theta ,\phi _1}(z,z'|x)\) through a proposal mapping \(\mathcal {T}_{\theta ,\phi _1}(\cdot |z,x)\) having as input a parameter-free variable v with density \(\nu \), allows us to apply the reparameterisation trick (Kingma and Welling 2014; Rezende et al. 2014; Titsias and Lázaro-Gredilla 2014). Although the different Markov kernels could have different parameters \(\phi _k\) for \(k \in \{1, \ldots , K\}\), we assume for simplicity that they all share the parameters \(\phi _1\), thereby helping the method to scale more easily to large values of K.

4.3 Speed measure adaptation

For a random walk Markov chain with isotropic proposal density \(r(z,\cdot |x)=\mathcal {N}(z, \sigma ^2 {\text {I}})\) at position z, the speed measure (Roberts et al. 1997) is defined as \(\sigma ^2 \times a(z|x)\), where \(a(z|x)=\int \alpha (z,z'|x) r(z,z'|x) \textrm{d}z'\) is the average acceptance rate. To encourage fast mixing for the Markov chain across all dimensions jointly, Titsias and Dellaportas (2019) suggested a generalisation of this speed measure that amounts to choosing the parameters of the proposal, such as a step size h and a preconditioning matrix C, so that the proposal has both a high acceptance rate and a high entropy

$$\begin{aligned} \mathcal {H}_{\theta ,\phi _1}&=-\int _{\textsf{Z}} r_{\theta ,\phi _1}(z,z'|x) \log r_{\theta ,\phi _1}(z,z'|x) \textrm{d}z'. \end{aligned}$$

More precisely, we consider the generalised speed measure

$$\begin{aligned} s_{\theta ,\phi _1}(z|x)= e^{\beta \mathcal {H}_{\theta ,\phi _1}} \times a(z|x) \end{aligned}$$

for some hyper-parameter \(\beta >0\). While maximising \(s_{\theta ,\phi _1}(z|x)\), or equivalently,

$$\begin{aligned}\log s_{\theta ,\phi _1}(z|x)=&\log \left[ \int _{\textsf{Z}} \alpha (z,\mathcal {T}_{\theta ,\phi _1}(v|z,x)) \nu (\textrm{d}v) \right] \\ {}&+ \beta \mathcal {H}_{\theta ,\phi _1}, \end{aligned}$$

is intractable, we follow Titsias and Dellaportas (2019) and maximise a lower bound thereof due to Jensen’s inequality,

$$\begin{aligned}&\log s_{\theta ,\phi _1}(z|x) \ge \mathcal {F}(\phi _1, z,x) \\&= \left[ \int _{\textsf{Z}} \log \alpha (z,\mathcal {T}_{\theta ,\phi _1}(v|z,x)) \nu (\textrm{d}v) + \beta \mathcal {H}_{\theta ,\phi _1} \right] , \end{aligned}$$

averaged over \((x, z) \sim p_d(x) q^{0}_{\phi _0}(z|x)\), where the hyper-parameter \(\beta \) can be updated online to achieve a desirable average acceptance rate \(\alpha ^\star \).
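The trade-off encoded in the lower bound \(\mathcal {F}\) can be made concrete with a small Monte Carlo sketch. The example below is an illustrative assumption rather than the paper's setup: it uses an isotropic random-walk proposal \(z'=z+\sigma v\) on a standard Gaussian target, for which the acceptance rate is \(\min \{1, \pi (z')/\pi (z)\}\) and the proposal entropy equals \(d \log \sigma \) up to an additive constant. The function names and the value of \(\beta \) are hypothetical.

```python
import numpy as np

def log_target(z):
    """Unnormalised log-density of a standard Gaussian target (toy assumption)."""
    return -0.5 * np.sum(z ** 2, axis=-1)

def speed_objective(z, sigma, beta, n_samples=2000, seed=0):
    """Monte Carlo estimate of the two terms of the lower bound F at state z.

    Returns (E_v[log alpha(z, z + sigma v)], beta * entropy), where the entropy of
    N(z, sigma^2 I) is d * log(sigma) after dropping an additive constant.
    """
    rng = np.random.default_rng(seed)
    d = z.shape[0]
    v = rng.standard_normal((n_samples, d))          # parameter-free noise
    z_prop = z + sigma * v                           # reparameterised proposal
    log_alpha = np.minimum(0.0, log_target(z_prop) - log_target(z))
    return log_alpha.mean(), beta * d * np.log(sigma)

z = np.ones(5)
for sigma in (0.1, 0.5, 2.0):
    acc, ent = speed_objective(z, sigma, beta=0.2)
    print(f"sigma={sigma:4.1f}  E[log alpha]={acc:8.3f}  beta*H={ent:7.3f}  F={acc + ent:8.3f}")
```

Small step sizes give acceptance terms close to zero but a low entropy, whereas large step sizes increase the entropy at the cost of a strongly negative expected log-acceptance rate; maximising the sum balances the two.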

4.4 MALA

Consider first a Metropolis Adjusted Langevin Algorithm (MALA). We assume that \(\phi _1\) parameterises a non-singular matrix C, possibly dependent on x, which can be, for instance, a diagonal matrix or a Cholesky factor. In this case, we can write the proposed state \(z'\) as

$$\begin{aligned} z'=\mathcal {T}_{\theta ,\phi _1}(v|z,x) = z-\frac{h^2}{2} CC^\top \nabla U_{\theta }(z|x)+h Cv \end{aligned}$$
(3)

for some step size \(h>0\) that is part of the parameter \(\phi _1\) and where \(v\sim \nu =\mathcal {N}(0,{\text {I}})\). The log-acceptance rate is \(\log \alpha (z, z'|x)=\min \{0, -\Delta (v,z,z')\}\) based on the energy error

$$\begin{aligned} \Delta (v, z,z')&= U_{\theta }(z'|x)-U_{\theta }(z|x) -\frac{1}{2}{\Vert v \Vert }^2\\&\quad + \frac{1}{2} {\Big \Vert v- \frac{h}{2} C^\top \left\{ \nabla U_{\theta }(z|x)+ \nabla U_{\theta } (z'|x) \right\} \Big \Vert }^2, \end{aligned}$$

evaluated at \(z'=\mathcal {T}_{\theta ,\phi _1}(v|z,x)\). The proposal density of the Markov kernel

$$\begin{aligned}r_{\theta ,\phi _1}(z,z'|x) = \mathcal {N}\left( z-\frac{h^2}{2} CC^\top \nabla U_{\theta }(z|x), h^2CC^\top \right) \end{aligned}$$

can be viewed as the pushforward density of \(\mathcal {N}(0,{\text {I}})\) under the transformation \(v \mapsto \mathcal {T}_{\theta ,\phi _1}(v|z,x)\). Its entropy is

$$\begin{aligned} \mathcal {H}_{\theta ,\phi _1}&=-\int _{\textsf{Z}} r_{\theta ,\phi _1}(z,z'|x) \log r_{\theta ,\phi _1}(z,z'|x) \textrm{d}z'\\&= \text {const} +\log |\det (hC)|, \end{aligned}$$

which is constant for \(z \in \textsf{Z}\) in the standard MALA case, although it can depend on x for MALA with state-dependent proposals.
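A single preconditioned MALA transition following the proposal (3), the energy error and the entropy expression above can be sketched as follows. The Gaussian toy target, the choice of C as a Cholesky factor of the target covariance and all function names are assumptions made for illustration; in the VAE setting the potential would instead be \(U_{\theta }(z|x)\) with gradients obtained by automatic differentiation.

```python
import numpy as np

def mala_proposal(z, v, h, C, grad_U):
    """Preconditioned MALA proposal z' = z - (h^2 / 2) C C^T grad U(z) + h C v."""
    return z - 0.5 * h ** 2 * C @ (C.T @ grad_U(z)) + h * C @ v

def energy_error(z, z_new, v, h, C, U, grad_U):
    """Energy error Delta; the log-acceptance rate is min(0, -Delta)."""
    v_rev = v - 0.5 * h * C.T @ (grad_U(z) + grad_U(z_new))
    return U(z_new) - U(z) - 0.5 * v @ v + 0.5 * v_rev @ v_rev

def mala_step(z, h, C, U, grad_U, rng):
    """One Metropolis-adjusted transition; returns the next state and log alpha."""
    v = rng.standard_normal(z.shape[0])
    z_new = mala_proposal(z, v, h, C, grad_U)
    log_alpha = min(0.0, -energy_error(z, z_new, v, h, C, U, grad_U))
    accept = np.log(rng.uniform()) < log_alpha
    return (z_new if accept else z), log_alpha

def proposal_entropy(h, C):
    """Entropy of N(., h^2 C C^T) up to an additive constant: log|det(h C)|."""
    return np.linalg.slogdet(h * C)[1]

# Toy Gaussian target with precision matrix P (an assumption for illustration).
P = np.array([[2.0, 0.6], [0.6, 1.0]])
U = lambda z: 0.5 * z @ P @ z
grad_U = lambda z: P @ z
C = np.linalg.cholesky(np.linalg.inv(P))      # non-diagonal preconditioner, C C^T = P^{-1}

rng = np.random.default_rng(0)
z, log_alphas = np.array([3.0, -3.0]), []
for _ in range(200):
    z, la = mala_step(z, 0.8, C, U, grad_U, rng)
    log_alphas.append(la)
print("mean acceptance rate:", np.exp(log_alphas).mean())
print("entropy term        :", proposal_entropy(0.8, C))
```

Because the proposal is written as a deterministic map of the noise v, both the log-acceptance rate and the entropy are differentiable functions of h and C, which is what a gradient-based adaptation of the kernel relies on.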

4.5 HMC

Consider next a Hamiltonian Monte Carlo Algorithm (HMC) based on a leapfrog or velocity Verlet integrator with L steps (Hairer et al. 2003; Bou-Rabee and Sanz-Serna 2018). We assume that \(\phi _1\) parameterises a Cholesky factor matrix C of the inverse mass matrix \(M^{-1}=CC^\top \). Starting from \(q_0=z\), the proposed state \(z'=q_L\) is commonly computed recursively for \(\ell \in \{0, \ldots , L-1\}\) via

$$\begin{aligned} p_{\ell + \frac{1}{2}}&= p_{\ell } - \frac{h}{2} \nabla U_{\theta }(q_{\ell }|x) \\ q_{\ell +1}&= q_{\ell } + h M^{-1} p_{\ell + \frac{1}{2}} \\ p_{\ell +1}&= p_{\ell +\frac{1}{2}} - \frac{h}{2} \nabla U_{\theta }(q_{\ell +1}|x), \end{aligned}$$

where \(p_{\ell }\) is a sequence of momentum variables initialised at \(p_0=C^{-\top }v\) for \(v \sim \mathcal {N}(0,{\text {I}})\). It is possible (Livingstone et al. 2019; Durmus et al. 2017) to write the proposed state \(z'=\mathcal {T}_{\theta , \phi _1}(v|z,x)\) in the representation

$$\begin{aligned} z'=z-\frac{Lh^2}{2} CC^\top \nabla U_{\theta }(z|x)+LhCv - h^2CC^\top \Xi _{L}(v)\end{aligned}$$

where

$$\begin{aligned} \Xi _L(v)=\sum _{\ell =1}^{L-1}(L-\ell ) \nabla U_{\theta }(q_{\ell }) \end{aligned}$$
(4)

is a weighted average of the potential energy gradients along the leapfrog trajectory. Consequently, the proposal density can be written as

$$\begin{aligned} \log r_{\theta ,\phi _1}(z,\mathcal {T}_L(v)) =&\log \nu (v) - d \log L - \log |\det C| \\ {}&- \log \left| \det \textsf{D}\Xi _L(v) \right| , \end{aligned}$$

where \(\textsf{D}\Xi _L(v)\) is the Jacobian of the non-linear function \(\Xi _L\) in (4). However, the computational complexity of evaluating the log-determinant of the Jacobian of \(\Xi _L\) scales poorly for high dimensional latent variables. We, therefore, consider the approximation suggested in Hirt et al. (2021) based on a local Gaussian assumption that the Hessian of the potential function \(U_{\theta }\) along the leapfrog trajectory can be approximated by its value at the mid-point \(q_{\left\lfloor L/2 \right\rfloor }\) of the trajectory. Under this assumption, the log-determinant of the Jacobian can be written as

$$\begin{aligned}&\log \left| \det \textsf{D}\Xi _L(v) \right| \\&\quad \approx \log \left| \det \left( {\text {I}}- \frac{L^2-1}{6} C^\top \nabla ^2 U_{\theta }(q_{\left\lfloor L/2 \right\rfloor } |x) C \right) \right| , \end{aligned}$$

which can be estimated by resorting to Russian roulette estimators (Behrmann et al. 2019; Chen et al. 2019a). The above approximation becomes exact for Gaussian targets with covariance matrix \(\Sigma \), since \(\nabla ^2 U_{\theta }(q)=\Sigma ^{-1}\) for any point q in the state space.
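The representation of the HMC proposal in terms of \(\Xi _L(v)\) can be verified numerically against a direct leapfrog integration. The sketch below assumes leapfrog half-steps of size \(h/2\) for the momentum and a non-quadratic toy potential \(U(q)=\sum _i (q_i^4/4+q_i^2/2)\); the potential, the matrix C and the function names are hypothetical choices for illustration.

```python
import numpy as np

def leapfrog(z, v, h, C, grad_U, L):
    """L leapfrog steps with inverse mass matrix C C^T and initial momentum C^{-T} v.

    Returns the proposed state q_L and the visited positions q_0, ..., q_L.
    """
    M_inv = C @ C.T
    q, p = z.copy(), np.linalg.solve(C.T, v)
    path = [q.copy()]
    for _ in range(L):
        p = p - 0.5 * h * grad_U(q)       # half step for the momentum
        q = q + h * M_inv @ p             # full step for the position
        p = p - 0.5 * h * grad_U(q)       # half step for the momentum
        path.append(q.copy())
    return q, path

def closed_form(z, v, h, C, grad_U, path):
    """Proposed state computed from the representation involving Xi_L(v)."""
    L = len(path) - 1
    xi = sum(((L - l) * grad_U(path[l]) for l in range(1, L)), np.zeros_like(z))
    M_inv = C @ C.T
    return z - 0.5 * L * h ** 2 * M_inv @ grad_U(z) + L * h * C @ v - h ** 2 * M_inv @ xi

C = np.array([[1.0, 0.0], [0.4, 0.8]])    # lower-triangular Cholesky factor (assumed)
grad_U = lambda q: q ** 3 + q             # gradient of the toy potential
z, v = np.array([0.5, -0.3]), np.array([0.2, 1.1])

for L in (1, 2, 5):
    q_L, path = leapfrog(z, v, 0.1, C, grad_U, L)
    print(L, np.max(np.abs(q_L - closed_form(z, v, 0.1, C, grad_U, path))))
```

The printed discrepancies are at the level of floating-point round-off. For \(L=1\) the sum defining \(\Xi _L\) is empty and the proposal reduces to the MALA proposal (3).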

4.6 Learning the generative model

Maximising the log-likelihood function directly using

$$\begin{aligned} \nabla _{\theta } \log p_{\theta }(x)=\int _{\textsf{Z}} p_{\theta }(z|x) \nabla _{\theta } \log p_{\theta }(x,z) \textrm{d}z\end{aligned}$$

is usually intractable as it requires samples from \(p_{\theta }(z|x)\). On the other hand, optimising the generative parameters by optimising the classic variational bound \({\mathcal {L}_0}(\theta ,\phi _0, x)\) based on the initial variational distribution does not allow us to leverage samples from the MCMC chain. Conversely, using a variational bound based on the implicit variational distribution \(q^K_{\theta ,\phi }(z|x)\) requires more refined approaches to compute its entropy (Thin et al. 2020). Instead, we use samples from an MCMC chain in conjunction with a perturbation of the MLE, as used previously, see, for instance, Han et al. (2017), Hoffman (2017), Nijkamp et al. (2020). More precisely, at iteration t, let \(\theta ^{(t)}\) and \(\phi ^{(t)}\) be the current estimates of the generative and variational parameters. Since maximising the log-likelihood is equivalent to minimising the KL divergence loss \(D(\theta )=\textrm{KL}(p_d(x)| p_{\theta }(x))\) over the generative parameters \(\theta \), we consider the following perturbed loss function

$$\begin{aligned} S(\theta )&= D(\theta )+ \mathbb {E}_{p_d(x)} \left[ \textrm{KL}(q^K_{\theta ^{(t)},\phi ^{(t)}}(z|x)|p_{\theta }(z|x)) \right] \\&=\textrm{KL}(p_d(x) q^K_{\theta ^{(t)},\phi ^{(t)}}(z|x) | p_\theta (z,x)), \end{aligned}$$

see also Pang et al. (2020), Han et al. (2020). Note first that \(S(\theta )\) becomes a tractable objective as it involves joint distributions over the latent variables and the data, in contrast to the log-likelihood objective involving marginal distributions. Second, \(S(\theta )\) majorises \(D(\theta )\), that is \(S(\theta )\ge D(\theta )\). An EM-type algorithm would update \(\theta ^{(t)}\) to \(\theta ^{(t+1)}\) by minimising \(S(\theta )\) for fixed variational parameters \(\phi ^{(t)}\) so that \(S(\theta ^{(t+1) })\le S(\theta ^{(t)})\). We consider instead an alternating approach that follows the negative gradient \(-\nabla S(\theta ^{(t)})\), given by the average of

$$\begin{aligned} \int _{\textsf{Z}} q^K_{\theta ^{(t)},\phi ^{(t)}}(z|x) \left[ \nabla _{\theta } \log p_{\theta }(z) + \nabla _{\theta }\log p_{\theta }(x|z) \right] \textrm{d}z \end{aligned}$$

over \(x \sim p_d\), while also updating the variational and MCMC parameters \(\phi ^{(t)}\) in a single iteration.
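The role of the MCMC samples in this gradient can be illustrated on a toy model where the log-likelihood gradient is available in closed form. The sketch below assumes \(z \sim \mathcal {N}(0,1)\) and \(x|z \sim \mathcal {N}(\theta z, 1)\), so that only the decoder depends on \(\theta \), and it uses exact posterior draws as a stand-in for a well-mixed chain; the model and all names are hypothetical.

```python
import numpy as np

def grad_log_marginal(x, theta):
    """Exact d/dtheta log p_theta(x) for z ~ N(0, 1), x | z ~ N(theta z, 1)."""
    s = 1.0 + theta ** 2
    return -theta / s + theta * x ** 2 / s ** 2

def grad_estimate(x, theta, z_samples):
    """Average of d/dtheta [log p(z) + log p_theta(x | z)] over latent samples."""
    return np.mean((x - theta * z_samples) * z_samples)

x, theta = 1.5, 0.8
post_mean = theta * x / (1.0 + theta ** 2)
post_var = 1.0 / (1.0 + theta ** 2)

rng = np.random.default_rng(0)
z_post = post_mean + np.sqrt(post_var) * rng.standard_normal(200_000)  # stand-in for converged MCMC
z_prior = rng.standard_normal(200_000)                                  # a poor variational distribution

print("exact gradient            :", grad_log_marginal(x, theta))
print("using posterior samples   :", grad_estimate(x, theta, z_post))
print("using prior samples       :", grad_estimate(x, theta, z_prior))
```

With samples close to the posterior, the estimate matches the log-likelihood gradient, whereas samples from a distribution far from the posterior give a biased direction. This is the motivation for refining the initial encoder samples with MCMC steps before updating \(\theta \).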

4.7 Algorithm

Pseudo-code for the suggested algorithm is given in Algorithm 1 at a given iteration t, for illustration based on a mini-batch of size one. We have found that pre-training the decoder and encoder parameters \(\theta \) and \(\phi _0\), respectively, by optimising the standard ELBO (2) before applying Algorithm 1 can decrease the overall training time. While we only consider MALA or HMC proposals in our experiments, other proposals with a tractable entropy, for instance those suggested in Li et al. (2020), can be used analogously.

Algorithm 1 Single training step for updating the generative model, initial encoding distribution and MCMC kernel.

5 Extension to hierarchical VAEs

We consider top-down hierarchical VAE (hVAE) architectures. Such models can leverage multiple layers L of latent variables \((z^1, \ldots , z^L)\), \(z^\ell \in \mathbb {R}^{n_{\ell }}\), where \(z^L\) is the latent variable at the top and \(z^1\) the latent variable at the bottom. Often \(n_{\ell +1} \le n_\ell \) to account for multiple resolutions. The generation of the latent variables follows the same order in both the prior

$$\begin{aligned} (z^1, \ldots , z^L) \sim p_{\theta }(z^1) p_{\theta }(z^2|z^1) \cdots p_{\theta }(z^L| z^{\le L-1}) , \end{aligned}$$
(5)

for \(z^{\le \ell }=(z^1, \dots z^\ell )\), and in the approximate posterior,

$$\begin{aligned} (z^1, \ldots , z^L)|x \sim q^0_{\phi _0,\theta }(z^1|x) \cdots q^0_{\phi _0,\theta }(z^L|x, z^{\le L-1}) \end{aligned}$$
(6)

cf. Sønderby et al. (2016), Kingma et al. (2016), Nijkamp et al. (2020), Maaløe et al. (2019), Vahdat and Kautz (2020), Child (2021). More concretely, to build the auto-regressive densities, we consider a sequence of variables \(d^{\ell } \in \mathbb {R}^{n'_{\ell }}\) that are deterministic given \(z^{\le \ell -1}\) and defined recursively as

$$\begin{aligned} d^{\ell }=h_{\ell ,\theta }(z^{\ell -1}, d^{\ell -1}) \end{aligned}$$
(7)

for some neural network function \(h_{\ell ,\theta }\), where the \(d^{\ell -1}\)-argument is a possible skip connection in a residual architecture for \(\ell >1\) and some constant \(d^1\). This implies that the dependence on all previous latent variables \(z^{\le \ell }\) is implemented via the first-order Markov model of the residual discrete states \(d^1, \dots , d^\ell \). Suppose further that we instantiate (5) in the form

$$\begin{aligned} z^{\ell } = \mu _{\ell ,\theta }(d^{\ell }) + \sigma _{\ell ,\theta }(d^{\ell }) \odot \epsilon ^{\ell } \end{aligned}$$
(8)

for some functions \(\mu _{\ell ,\theta }\) and \(\sigma _{\ell ,\theta }\), with \(\epsilon ^{\ell }\) denoting iid Gaussian random variables. This construction leads to the auto-regressive structure in the prior (5). To describe the variational approximation in (6), we consider a bottom-up network that defines deterministic variables \(d^{'\ell }\in \mathbb {R}^{n'_{\ell }}\) recursively by setting \(d^{'L+1}=x\) and \(d^{'\ell }=h'_{\ell , \phi _0}(d^{'\ell +1})\) for \(1 \le \ell \le L\) for functions \(h'_{\ell ,\phi _0}\). We assume a residual parameterisation (Vahdat and Kautz 2020; Vahdat et al. 2021) for \(q^0_{\phi _0}(z^{\ell }|x, z^{\le \ell -1})\) in the form

$$\begin{aligned} z^{\ell }&=\mu _{\ell ,\theta }(d^{\ell }) + \sigma _{\ell ,\theta }(d^{\ell }) \odot \mu '_{\ell ,\phi _0}(d^{\ell },d^{'\ell }) \nonumber \\ {}&\quad + (\sigma _{\ell ,\theta }(d^{\ell }) \odot \sigma '_{\ell ,\phi _0}(d^{\ell },d^{'\ell }) ) \odot \epsilon ^{\ell } \end{aligned}$$
(9)

for some functions \(\mu '_{\ell ,\phi _0}\) and \(\sigma '_{\ell , \phi _0}\). This implies that

$$\begin{aligned}&\textrm{KL}(q^0_{\phi _0}(z^\ell |x,z^{\le \ell -1})|p_{\theta }(z^\ell |z^{\le \ell -1})) \\&= \frac{1}{2} \sum _{i=1}^{n_\ell } \Big [ \sigma '_{\ell , \phi _0}(d^{\ell }, d^{'\ell })_i ^2 + \mu '_{\ell ,\phi _0}(d^{\ell },d^{'\ell })_i^2 - 1 \nonumber \\ {}&\quad - \log \sigma '_{\ell , \phi _0}(d^\ell , d^{'\ell })_i^2 \Big ]. \nonumber \end{aligned}$$
(10)
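Under the residual parameterisation (9), the KL divergence between the encoder and prior conditionals of a layer depends only on the relative location \(\mu '\) and relative scale \(\sigma '\). The sketch below checks this numerically against the generic KL divergence between two diagonal Gaussians; the randomly drawn values stand in for network outputs and all names are hypothetical.

```python
import numpy as np

def kl_diag_gauss(m_q, s_q, m_p, s_p):
    """KL(N(m_q, diag(s_q^2)) | N(m_p, diag(s_p^2))), summed over coordinates."""
    return np.sum(np.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2.0 * s_p ** 2) - 0.5)

def kl_residual(mu_res, sigma_res):
    """The same KL expressed through the residual quantities mu' and sigma' only."""
    return 0.5 * np.sum(sigma_res ** 2 + mu_res ** 2 - 1.0 - np.log(sigma_res ** 2))

rng = np.random.default_rng(1)
n = 6
mu_p, sigma_p = rng.standard_normal(n), np.exp(0.3 * rng.standard_normal(n))      # prior layer
mu_res, sigma_res = rng.standard_normal(n), np.exp(0.3 * rng.standard_normal(n))  # encoder residuals
mu_q = mu_p + sigma_p * mu_res            # encoder mean under the residual parameterisation
sigma_q = sigma_p * sigma_res             # encoder scale under the residual parameterisation

print(kl_diag_gauss(mu_q, sigma_q, mu_p, sigma_p), kl_residual(mu_res, sigma_res))
```

Both expressions agree, and the divergence vanishes when \(\mu '=0\) and \(\sigma '=1\), that is, when the encoder conditional coincides with the prior conditional.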

The observations x are assumed to depend explicitly only on \(z^L\) and \(d^L\) through some function \(g_{\theta }\) in the sense that \(x|z^1, \ldots , z^L \sim p_{\theta }(x| g_{\theta }(z^L))\). The generative model of the latent variables \(z^1, \ldots z^L\) in (5) is written in a centred parameterisation (Papaspiliopoulos et al. 2007) that makes them dependent a priori. Our experiments will illustrate that these dependencies can make it challenging to sample from the posterior distribution for MCMC schemes that are not adaptive.

We want to clarify that we interpret a hVAE as a special case of the VAE with a hierarchical structure of the latent variables \(z=(z^1, \ldots , z^L) \in \mathbb {R}^{n}\), \(z^\ell \in \mathbb {R}^{n_\ell }\), \(n=\sum _{\ell =1}^L n_\ell .\) An alternative viewpoint would be to consider the VAE in Sect. 4 that utilises MCMC steps as a hierarchical VAE wherein each step of the Markov chain corresponds to a new layer of an hVAE, with all latent variables \(z_0, \ldots , z_K \in \mathbb {R}^n\) living in the same latent space. More precisely, from such an alternative perspective, the latent variable \(z_0\) sampled from the prior \(p_\theta \) or initial encoder \(q^0_{\phi _0}\) can be seen as the first latent variable of an hVAE at the bottom layer, while the transition densities in the generative auto-regressive distributions in (5) are modelled as Metropolis-Hastings kernels. In our viewpoint, we minimise the KL of the joint latent variables \((z^1, \ldots , z^L)\) as in (10) for learning the initial variational parameters \(\phi _0\) that parameterise the encodings of \((z^1, \ldots , z^L)\) jointly. However, performing variational inference in the alternative viewpoint would require different approaches. Our approach also differs from those in score-based diffusion models that utilise score functions to transition between hierarchical latent variables, see Appendix A for details.

6 Numerical experiments

6.1 Evaluating model performance with marginal log likelihood

We start by considering different VAE models and inference strategies on four standard image data sets (MNIST, Fashion-MNIST, Omniglot and SVHN) and evaluate their performance in terms of their test log-likelihood estimates.

Table 1 Importance sampling estimate of the log-likelihood (with highest values in bold for each setup of either a standard Gaussian or a learnable prior) on the test set based on \(S=10000\) and \(\tau = 1.5\)
Table 2 Estimates of KID for each model considered across different datasets

6.2 Marginal log-likelihood estimation

We evaluate the performance of different variations of VAEs using the marginal log-likelihood of the model on a held-out test set for a variety of benchmark datasets. In doing so, we resort to importance sampling to estimate the marginal log-likelihood using S importance samples via

$$\begin{aligned}\log \hat{p}_{\text {IS}} (x) = \log \frac{1}{S}\sum _{s=1}^{S}\frac{p_\theta (x|z_{s})p_\theta (z_{s})}{r(z_{s}|x)} ~, z_s \sim r(\cdot |x),\end{aligned}$$

where r is an importance sampling density. Following Ruiz and Titsias (2019), in the case of a standard VAE, we choose \(r(z|x)=\mathcal {N}(\mu _{\phi _0}^z(x), \tau \Sigma _{\phi _0}^z(x))\) for some scaling constant \(\tau \ge 1\), assuming that \(q^0_{\phi _0}(z_0|x)=\mathcal {N}(\mu _{\phi _0}^z(x), \Sigma _{\phi _0}^z(x))\) with diagonal covariance matrix \(\Sigma _{\phi _0}^z(x)\). For the case with MCMC sampling using K steps, we choose \(r(z_s|x)= \mathcal {N}(z_K(x), \tau \Sigma _{\phi _0}^z(x))\), where \(z_K(x)\) is an estimate of the posterior mean from the MCMC chain.
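The importance sampling estimator can be sketched on a toy model for which the marginal likelihood is known exactly. The example below assumes \(z \sim \mathcal {N}(0,1)\) and \(x|z \sim \mathcal {N}(z, s^2)\), centres the proposal at the exact posterior mean as a stand-in for an encoder or MCMC estimate, and inflates the proposal variance by \(\tau \); the model and function names are illustrative assumptions.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log-density of a scalar Gaussian N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)

def log_p_hat_is(x, centre, var, tau, S, rng, s2=0.5):
    """Importance sampling estimate of log p(x) for z ~ N(0, 1), x | z ~ N(z, s2).

    The proposal is r(z | x) = N(centre, tau * var). The average of the importance
    weights is computed in log-space for numerical stability.
    """
    z = centre + np.sqrt(tau * var) * rng.standard_normal(S)
    log_w = log_normal(x, z, s2) + log_normal(z, 0.0, 1.0) - log_normal(z, centre, tau * var)
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))

rng = np.random.default_rng(0)
x, s2 = 1.3, 0.5
post_mean, post_var = x / (1.0 + s2), s2 / (1.0 + s2)
estimate = log_p_hat_is(x, post_mean, post_var, tau=1.5, S=10_000, rng=rng, s2=s2)
exact = log_normal(x, 0.0, 1.0 + s2)
print("IS estimate:", estimate, " exact:", exact)
```

Choosing \(\tau \ge 1\) makes the proposal heavier-tailed than the approximate posterior, which keeps the importance weights bounded in this Gaussian example and therefore stabilises the estimate.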

6.3 VAE models

Using the metric described above, we evaluate our model and compare it against other popular adjustments of VAEs for various data sets. In terms of comparing models, we focused on comparing our model, denoted VAE-gradMALA and VAE-gradHMC, against i) a Vanilla VAE, ii) VAEs utilising MCMC samplers that are adapted using a dual-averaging scheme (Hoffman and Gelman 2014; Nesterov 2009) that we refer to as VAE-gradMALA and VAE-gradHMC. We also compare against iii) VAEs using more expressive priors such as a Mixture of Gaussians (MoG), denoted VAE-MoG cf. Jiang et al. (2017), Dilokthanakul et al. (2016), or a Variational Mixture of Posteriors Prior (VAMP), see Tomczak and Welling (2017), denoted VAE-VAMP. For the MNIST example, we consider a Bernoulli likelihood with a latent space of dimension 10. We pre-trained the model for 90 epochs with a standard VAE, and subsequently trained the model for 10 epochs with MCMC. We used a learning rate of 0.001 for both algorithms. For the remaining datasets, we pre-trained for 290 epochs with a standard VAE, followed by training for 10 epochs with MCMC. We used a learning rate of 0.005, while the dimension of the latent space is 10, 20, and 64 for Fashion-MNIST, Omniglot and SVHN, respectively. For the SVHN dataset, we considered a 256-logistic likelihood with a variance fixed at \(\sigma ^2 = 0.1\), see Salimans et al. (2017) for details. In terms of the neural network architecture used for the encoder and the decoder, more information can be found in the codebase. All models use the same decoders and (initial) encoding distributions. The inference times of the models trained either with dual-averaging or with the gradient-based generalised speed measure objective are comparable. We use \(K=10\) MCMC steps.

6.4 Experimental results

Table 1 summarises the estimated log-likelihoods for the different data sets. The results therein show the means of three independent runs, with their standard deviations in brackets. For SVHN, the estimate is expressed in bits per dimension. We observe that among the considered methods that utilise MCMC samplers within VAEs, our approach performs better across the datasets we explored. We note that for the considered decoder and encoder architectures, more flexible generative models, obtained by using more flexible priors such as a VAMP prior, can yield higher log-likelihoods. However, the choice of a more flexible prior is complementary to the inference approach suggested in this work. Indeed, we illustrate in Sects. 6.7 and 6.8 that our MCMC adaptation strategy performs well for more flexible hierarchical priors.
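For reference, a log-likelihood in nats is commonly converted to bits per dimension by negating it and dividing by the number of data dimensions times \(\log 2\). The sketch below assumes this standard convention (the text does not spell out the exact transformation used) and uses a made-up log-likelihood value purely for illustration; the helper name `bits_per_dim` is hypothetical.

```python
import math

def bits_per_dim(log_likelihood_nats, num_dims):
    """Convert a per-image log-likelihood in nats to bits per dimension."""
    return -log_likelihood_nats / (num_dims * math.log(2.0))

# SVHN images have 32 x 32 pixels with 3 colour channels.
svhn_dims = 32 * 32 * 3

# Hypothetical log-likelihood value, chosen only to illustrate the conversion.
example_log_likelihood = -8500.0
print(bits_per_dim(example_log_likelihood, svhn_dims))
```

Lower values are better under this convention, whereas higher values are better for the raw log-likelihoods.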

6.5 Evaluating generative performance with kernel inception distance (KID)

6.5.1 Generative metrics

The generative performance of our proposed model is additionally assessed quantitatively by computing the Kernel Inception Distance (KID) relative to a subset of the ground truth data. We chose the KID score instead of the more traditional Fréchet inception distance (FID), due to the inherent bias of the FID estimator (Bińkowski et al. 2018). To compute the KID score, for each image from a held-out test set, we sample a latent variable from the prior density and then pass it through the trained decoder of the corresponding model to generate a synthetic image. Images are resized to (150,150,3) using bicubic interpolation, followed by a forward pass through an Inception-v3 model using the ImageNet weights. This yields a set of Inception features for the synthetic and held-out test set. The computation of the KID score for these features is based on a polynomial kernel, similarly to Bińkowski et al. (2018). For all datasets, we utilised a learning rate of 0.001 for both the VAE and MCMC algorithms. We trained the VAE for 100 epochs and performed sampling with the MCMC algorithms for 50 epochs if applicable, yielding a total training of 150 epochs across all cases. The likelihood functions used were Bernoulli for the MNIST and Fashion-MNIST datasets, while the logistic-256 likelihood (Salimans et al. 2017) was used for the SVHN and CIFAR-10 datasets, with a fixed variance of \(\sigma ^2 = 0.1\) and \(\sigma ^2 = 0.05\), respectively. The dimension of the latent variable was fixed to 10 for the MNIST and Fashion-MNIST datasets, while it was set to 64 and 256 for the SVHN and CIFAR-10 datasets, respectively. More details regarding the neural network architecture used for training the VAE can be found in the codebase.
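The final step of this pipeline, computing KID from two sets of features, can be sketched as follows. This is a minimal illustration rather than the evaluation code used for the experiments: it assumes the cubic polynomial kernel \(k(x,y)=(x^\top y/d+1)^3\) and the unbiased squared-MMD estimator of Bińkowski et al. (2018), and it uses random vectors as stand-ins for Inception features, so the image resizing and the Inception-v3 forward pass are not shown. The function names are hypothetical.

```python
import numpy as np

def polynomial_kernel(a, b, degree=3, coef0=1.0):
    """Polynomial kernel (a . b / d + coef0) ** degree between rows of a and b."""
    d = a.shape[1]
    return (a @ b.T / d + coef0) ** degree

def kid_score(real_feats, fake_feats):
    """Unbiased estimate of the squared MMD with a cubic polynomial kernel."""
    m, n = real_feats.shape[0], fake_feats.shape[0]
    k_rr = polynomial_kernel(real_feats, real_feats)
    k_ff = polynomial_kernel(fake_feats, fake_feats)
    k_rf = polynomial_kernel(real_feats, fake_feats)
    # Exclude diagonal terms in the within-set averages (unbiased U-statistic).
    avg_rr = (k_rr.sum() - np.trace(k_rr)) / (m * (m - 1))
    avg_ff = (k_ff.sum() - np.trace(k_ff)) / (n * (n - 1))
    return avg_rr + avg_ff - 2.0 * k_rf.mean()

# Random stand-ins for Inception features (illustrative assumption).
rng = np.random.default_rng(0)
real = rng.standard_normal((500, 64))
fake_same = rng.standard_normal((500, 64))        # same distribution as `real`
fake_shift = rng.standard_normal((500, 64)) + 0.5  # shifted distribution

kid_same = kid_score(real, fake_same)
kid_shift = kid_score(real, fake_shift)
print(kid_same, kid_shift)
```

Because the estimator is unbiased, it can take small negative values when the two feature distributions coincide; lower scores indicate a closer match.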

6.5.2 VAE models and quantitative evaluation

Similarly to Sect. 6.1, we perform a series of experiments that compare our adaptation scheme to other popular VAE modifications across different data sets. In Table 2, we summarise the results of our experiments, reporting mean KID scores from three different seeds with the standard deviation in brackets. We notice a similar pattern to that from Sect. 6.1, where our proposed method outperforms other MCMC-related methods. At the same time, we observe that models with more expressive priors, such as the VAMP prior, can perform equally well or slightly better, particularly in the case of a low-dimensional latent state space, such as for MNIST and Fashion-MNIST. However, in the case of a higher-dimensional latent space, such as that used for CIFAR-10 with \(d_z=256\), we observe that our method shows considerable improvement compared to the other methods.

6.5.3 Qualitative results

In addition to computing the KID score, we qualitatively inspect the reconstructed images and the images sampled from the model. Figure 1 shows reconstructions for the three best-performing models. Figure 2 contains unconditionally generated samples for the same models. We observe that better KID scores are indeed accompanied by qualitatively more expressive generations and reconstructions. In particular, we observe a slight decrease in blurriness and an increase in the resolution of smaller details, such as the car light of the red car in Fig. 1. Moreover, the unconditionally generated images in Fig. 2 exhibit more expressive colour patterns.

Fig. 1

Model reconstruction images for the top three performing models tested on CIFAR-10 in terms of the KID-score evaluated on model samples. The first two rows illustrate the ground truth, the next two show reconstructions from the Vanilla VAE model, the next two illustrate reconstructions from the dsHMC model, and the last two rows illustrate reconstructions from the gradHMC coupled VAE

Fig. 2

Model sampled images for the top three performing models tested on CIFAR-10 in terms of the KID-score evaluated on model samples. The first three rows illustrate samples from the Vanilla VAE model, the next three rows illustrate samples from the dsHMC model, and the last three rows illustrate samples from the gradHMC coupled VAE

6.6 Evaluating model performance in small sample size data

6.6.1 Data augmentation task

In addition to testing our proposed approach on the above benchmark datasets, we also test it on a real-world dataset of complex images with a relatively small sample size. We chose the Alzheimer’s Disease Neuroimaging Initiative (ADNI; see footnote 1) brain MRI dataset, which comprises 4000 brain MRI scans of individuals suffering from dementia and of individuals from a healthy control group, in a ratio of 1:3, respectively. The small sample size as well as the imbalance in the dataset pose a problem for classification tasks, which is often addressed by data augmentation. We illustrate here that the proposed generative model can be used to generate additional samples that are representative of each sub-population in the dataset, namely healthy controls and diseased individuals. We first trained a separate VAE for each class of the dataset (see footnote 2), using a VAE learning rate of 0.001 and an MCMC learning rate of 0.01, whenever applicable.

6.6.2 Generative performance

The VAEs were trained for 2000 epochs with 100 epochs of MCMC coupled training, whenever it was applicable. The KID score presented in Table 3 was based on the whole dataset (that is, including both training and test sets), because the KID estimate can be unreliable for the small test set of the minority class. The neural networks utilised in the encoder and the decoder were similar to those of Sect. 6.5, consisting of two dense layers of 200 units each for the decoder and the encoder. Moreover, the latent dimension for all experiments was fixed at 20, while the likelihood utilised was a logistic-256 (Salimans et al. 2017) with a fixed variance of \(\sigma ^2 = 0.05\). After training, sets of 200, 500, 1000, and 2000 synthetic images were generated for the minority class, i.e. the dementia group, and added to the real images of that class. Classification performance for classifier models trained on this augmented dataset was then compared against classifier models trained on the non-augmented dataset. More details regarding the architectures used for the VAE and classifier models can be found in the codebase. We observed that the best performance in terms of the classification metrics is obtained for the dataset augmented with 200 images, and we thus report these values in Table 3. We find that a VAE with a gradient-based adaptation of the HMC sampler has better generative performance, particularly for the dementia group. Qualitative results showing the generated samples are given in Fig. 3 for the standard VAE model, in addition to those VAE models that are combined with MCMC. We notice that our proposed method captures more brain characteristics for both the demented and normal patients, as indicated by the presence of various brain structures throughout the generated samples, while also capturing class-specific characteristics, such as a greater degree of brain matter loss in the dementia class.

Table 3 Estimates of the KID score for each respective class in the ADNI brain MRI dataset and classification metrics from the data augmentation task across different models
Fig. 3

Model samples from VAE variations trained on either demented (first four columns) or normal patients (last four columns). The first two rows are samples from the Vanilla VAE model, the next two rows from the VAE using dual-average adaptation and the last two rows from our proposed method using a VAE with entropy-based adaptation

6.6.3 Classification results

We performed a classification between the two groups with results summarised in Table 3. It illustrates first that augmenting data with a trained VAE improves the classification in general, and second, that augmentations with our proposed method lead to a small, yet significant increase in Balanced Accuracy, True Positive Rate (TPR) and True Negative Rate (TNR). These results are consistent with the improved quality of the generated samples using our approach and we thus believe that our method can be leveraged for effective data augmentations.
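For readers less familiar with these metrics, the sketch below shows how TPR, TNR and Balanced Accuracy are related for a binary task with an imbalanced class ratio. The labels and predictions are made up for illustration and are not taken from Table 3; treating the dementia class as the positive class (label 1) is an assumption, and the helper name `classification_rates` is hypothetical.

```python
def classification_rates(y_true, y_pred, positive=1):
    """Return TPR, TNR and balanced accuracy for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = sum(1 for t, p in pairs if t != positive and p != positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    tpr = tp / (tp + fn)  # sensitivity: fraction of positives recovered
    tnr = tn / (tn + fp)  # specificity: fraction of negatives recovered
    return {"tpr": tpr, "tnr": tnr, "balanced_accuracy": 0.5 * (tpr + tnr)}

# Made-up labels with a 1:3 class ratio (1 = minority class, 0 = majority class).
y_true = [1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 0, 0, 1]

rates = classification_rates(y_true, y_pred)
print(rates)
```

Balanced Accuracy averages the two per-class rates, so a classifier that always predicts the majority class scores 0.5 regardless of the class ratio, which is why it is preferred to plain accuracy on imbalanced data.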

Table 4 Condition Number \(\kappa (\Sigma _{z|x}^{-1})\) of the posterior distribution for a smaller and larger linear hierarchical VAE model, each consisting of latent variables in two layers, having the dimensions (10,20) and (50,100), respectively

6.7 Linear hierarchical VAEs

We consider linear Gaussian models with a Gaussian prior \(p_{\theta }(z)=\mathcal {N}(\mu _z,\Sigma _z)\) and a linear decoder mapping so that \(p_{\theta }(x|z)=\mathcal {N}(Wz+b, \Sigma _{x|z})\) for \(\mu _z \in \mathbb {R}^{d_z}\), \(b \in \mathbb {R}^{d_x}\), \(W\in \mathbb {R}^{d_x \times d_z}\) and covariance matrices \(\Sigma _z\) and \(\Sigma _{x|z}\) of appropriate dimension. The resulting generative model corresponds to a probabilistic PCA model (Tipping and Bishop 1999), see also Dai et al. (2018), Lucas et al. (2019) for further connections to VAEs. This section aims to illustrate that adaptation with a non-diagonal pre-conditioning matrix becomes beneficial to account for the dependence structure of the latent variables prevalent in such hierarchical models.

6.7.1 Hierarchical generative model

We can sample from the Gaussian prior \(z\sim \mathcal {N}(\mu _z, \Sigma _z)\) in a hierarchical representation using two layers:

$$\begin{aligned} z^1 \sim \mathcal {N}(0, {\text {I}}), \quad z^2|z^1 \sim \mathcal {N}(A_2z^1+c_2^{\mu }, \Lambda _{z^2|z^1}) , \end{aligned}$$
(11)

where \(z=(z^1,z^2)\) and \(\Lambda _{z^2|z^1}=\text {diag}(\sigma _{z^2|z^1}^2)\). To recover (11) from the general auto-regressive prior factorisation (5), assume that \(d^1=0 \in \mathbb {R}^{n_1'}\) with \(n_1'=2n_1\). For \(d=(d^{\mu },d^{\sigma })\), suppose that \(\mu _{1,\theta }(d)=d^{\mu }\) is the projection on the first \(n_1\) components, while \(\sigma _{1,\theta }(d)=\exp (0.5 d^{\sigma })\) describes the standard deviation based on the last \(n_1\) components. Further, consider the linear top-down mapping

$$\begin{aligned}h_{2, \theta }:(z^1, d^1) \mapsto d^2=\begin{bmatrix} A_2 &{} B_2 \\ 0 &{} 0 \end{bmatrix} \begin{bmatrix} z^1 \\ d^1 \end{bmatrix} + \begin{bmatrix} c_2^\mu \\ c_2^\sigma \end{bmatrix}, \end{aligned}$$

for the deterministic variables, where \(c_2^\sigma =2 \log \sigma _{z^2|z^1}\). We assume the same parameterisation for the prior densities of \(z^2\) given \(d^2\) as in the first layer: \(\mu _{2,\theta }(d)=\mu _{1,\theta }(d)=d^\mu \), and \(\sigma _{2,\theta }(d)=\sigma _{1, \theta }(d)=\exp (0.5 d^{\sigma })\). We assume further that the decoder function depends explicitly only on the latent variables \(z^2\) and \(d^2\) at the bottom in the form of

$$\begin{aligned} p_{\theta }(x|z)&=\mathcal {N}(W_2^z z^2 +W_2^d d^2 +b, \Sigma _{x|z})\\ {}&= \mathcal {N}(W z+b+W_2^d c_2^{\mu }, \Sigma _{x|z}), \end{aligned}$$

for \(W=\begin{bmatrix} W_2^d A_2&W^z_2 \end{bmatrix}\). Observe that the covariance matrix of the prior density is

$$\begin{aligned} \Sigma _z = \begin{bmatrix} {\text {I}}&{} (A_2)^\top \\ A_2 &{} A_2 A_2 ^\top + \Lambda _{z^2|z^1}\end{bmatrix}.\end{aligned}$$

The marginal distribution of the data is \(x \sim \mathcal {N}(\mu _x, \Sigma _x)\), where \(\mu _x=(W_2^z + W_2^d) c_2^{\mu } + b\), since the prior mean of \(z\) is \((0, c_2^{\mu })\), and

$$\begin{aligned}\Sigma _x=W \Sigma _z W^\top + \Sigma _{x|z}.\end{aligned}$$

The covariance matrix of the posterior density becomes

$$\begin{aligned} \Sigma _{z|x} = \Sigma _z - (W\Sigma _z)^\top \Sigma _x^{-1} W\Sigma _z . \end{aligned}$$
(12)

Depending on the model parameters, \(\Sigma _{z|x}\) can be poorly conditioned, i.e. have a large condition number, which can hinder the performance of non-adaptive MCMC methods. Particularly for models that infer a high dependence between \(z^1\) and \(z^2\), the prior covariance \(\Sigma _z\) can be ill-conditioned, which can lead to ill-conditioned posteriors. By contrast, with suitable preconditioning, we can expect MALA, HMC, and other MCMC methods to become more performant at sampling from the posterior distribution.
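The quantities of this subsection are straightforward to compute numerically. The following sketch builds the prior covariance of the two-layer model, evaluates the posterior covariance (12), and checks it against the equivalent precision form \((\Sigma _z^{-1} + W^\top \Sigma _{x|z}^{-1} W)^{-1}\). It is illustrative only: the matrices are drawn at random rather than learned, \(W\) is sampled directly instead of being assembled from \(W_2^d A_2\) and \(W_2^z\), the conditional variances of \(z^2\) given \(z^1\) are set to one, and the function names are hypothetical.

```python
import numpy as np

def prior_covariance(A2, sigma2):
    """Covariance of z = (z1, z2) with z1 ~ N(0, I), z2 | z1 ~ N(A2 z1 + c, diag(sigma2))."""
    n1 = A2.shape[1]
    return np.block([[np.eye(n1), A2.T],
                     [A2, A2 @ A2.T + np.diag(sigma2)]])

def posterior_covariance(W, Sigma_z, Sigma_x_given_z):
    """Posterior covariance Sigma_z - (W Sigma_z)^T Sigma_x^{-1} W Sigma_z, as in (12)."""
    Sigma_x = W @ Sigma_z @ W.T + Sigma_x_given_z
    WS = W @ Sigma_z
    return Sigma_z - WS.T @ np.linalg.solve(Sigma_x, WS)

# Random toy parameters (assumptions for illustration, not learned values).
rng = np.random.default_rng(0)
n1, n2, d_x = 10, 20, 40
A2 = rng.standard_normal((n2, n1))
sigma2 = np.ones(n2)                 # unit conditional variances for z2 given z1
W = rng.standard_normal((d_x, n1 + n2))
Sigma_x_given_z = 0.25 * np.eye(d_x)  # observation standard deviation 0.5

Sigma_z = prior_covariance(A2, sigma2)
Sigma_post = posterior_covariance(W, Sigma_z, Sigma_x_given_z)

# Equivalent precision form of the Gaussian posterior covariance.
precision_form = np.linalg.inv(
    np.linalg.inv(Sigma_z) + W.T @ np.linalg.solve(Sigma_x_given_z, W)
)

kappa = np.linalg.cond(Sigma_post)
print(np.allclose(Sigma_post, precision_form, atol=1e-7), kappa)
```

The condition number of a symmetric positive definite matrix equals that of its inverse, so `kappa` can be read as either \(\kappa (\Sigma _{z|x})\) or \(\kappa (\Sigma _{z|x}^{-1})\).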

6.7.2 Encoding model

Assume a linear encoder model based on a linear bottom-up model so that \(d^{'3}=x\), and for \(1 \le \ell \le 2\), suppose that \(d^{'\ell }=W_{\ell }'d^{'\ell +1}+b_{\ell }'\) are bottom-up deterministic variables. We construct an encoding distribution by setting

$$\begin{aligned} \mu '_{\ell ,\theta } :(d^{\ell },d^{' \ell }) \mapsto B'_{\ell } \begin{bmatrix} d^{\ell } \\ d^{' \ell } \end{bmatrix} + c_{\ell }' \end{aligned}$$

and \( \sigma '_{\ell ,\theta }:(d^{\ell },d^{' \ell }) \mapsto \exp (b_{\ell }')\) in the residual parameterisation (9).

6.7.3 Experimental results

We first test if the adaptation scheme can adapt to the posterior covariance \(\Sigma _{z|x}\) given in (12) of a linear hVAE model, i.e. if the condition number of \(C^\top \Sigma _{z|x}^{-1} C\) becomes small. As choices of C, we consider (i) a diagonal preconditioning matrix (denoted D) and (ii) a lower-triangular preconditioning matrix (denoted LT). Note that the dual-averaging adaptation scheme used here and in Hoffman and Gelman (2014) adapts only a single step-size parameter, thereby leaving the condition number unchanged. We tested two simulated data sets with corresponding latent dimensions \((n_1,n_2)\) of (10,20) and (50,100). More specifically, we simulated datasets with 1000 samples for each configuration, using the linear observation model with a standard deviation of 0.5. We used a hierarchical VAE with two layers and a learning rate of 0.001. For the dataset from the model with a latent dimension of (10,20), we pre-trained the VAE for 1000 epochs without MCMC, followed by training for 1000 epochs with MCMC. The number of MCMC steps was fixed at \(K=2\). For the dataset generated from the model with a higher-dimensional latent space of dimension (50,100), we increased the number of training epochs from 1000 to 5000, while also increasing the number of MCMC steps from \(K=2\) to \(K=10\). For different choices for the size of the latent variables, Table 4 shows that both gradient-based adaptation schemes lead to a very small transformed condition number \(\kappa (C^\top \Sigma _{z|x}^{-1} C)\) when a full preconditioning matrix is learnt, with smallest values in bold for each configuration of the latent dimensions. Notice also that for all models, the posterior becomes increasingly ill-conditioned for higher dimensional latent variables, as confirmed by the large values of \(\kappa (\Sigma _{z|x}^{-1})\) in Table 4.
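The effect of the different preconditioner classes on the transformed condition number \(\kappa (C^\top \Sigma _{z|x}^{-1} C)\) can be illustrated with idealised choices of \(C\). The sketch below does not run the adaptation scheme: it assumes a synthetic covariance with strongly varying scales and, instead of learning \(C\), plugs in a scalar multiple of the identity, the square roots of the marginal variances, and the Cholesky factor of the covariance, which are the best-case targets one might hope an adaptation within each class approaches. All names and sizes are hypothetical.

```python
import numpy as np

def transformed_condition_number(C, Sigma):
    """Condition number of C^T Sigma^{-1} C for a preconditioner C."""
    return np.linalg.cond(C.T @ np.linalg.inv(Sigma) @ C)

# Synthetic covariance: a correlation matrix R rescaled by very different scales.
rng = np.random.default_rng(0)
d = 8
B = rng.standard_normal((d, d))
G = B @ B.T / d + np.eye(d)
R = G / np.sqrt(np.outer(np.diag(G), np.diag(G)))  # unit diagonal
scales = np.logspace(-1, 1, d)                      # standard deviations 0.1 .. 10
Sigma = R * np.outer(scales, scales)

C_scalar = 0.1 * np.eye(d)                  # single step size, as in dual averaging
C_diag = np.diag(np.sqrt(np.diag(Sigma)))   # idealised diagonal preconditioner (D)
C_lt = np.linalg.cholesky(Sigma)            # idealised lower-triangular one (LT)

kappa_none = np.linalg.cond(np.linalg.inv(Sigma))
kappa_scalar = transformed_condition_number(C_scalar, Sigma)
kappa_diag = transformed_condition_number(C_diag, Sigma)
kappa_lt = transformed_condition_number(C_lt, Sigma)
print(kappa_none, kappa_scalar, kappa_diag, kappa_lt)
```

A scalar step size rescales all eigenvalues equally and so leaves the condition number unchanged, the diagonal choice removes the scale differences but not the correlations, and the Cholesky factor satisfies \(CC^\top =\Sigma \) so that the transformed matrix is the identity.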

Table 5 Difference between true and estimated data log-likelihood \(\log p_{\theta }(x)\) for hierarchical VAEs with two layers, where the dimensions of the latent variables \((z^1,z^2)\) are set to (10,20) and (50,100), respectively

In addition to the condition number, we also investigate how the adaptation scheme affects the learned model in terms of the marginal log-likelihood, which is analytically tractable. The results summarised in Table 5 show that the gradient-based adaptation schemes indeed achieve a higher log-likelihood.

6.8 Non-linear hierarchical VAEs

Finally, we investigate the effect of combining MCMC with hVAE in the general non-linear case for hierarchical models. More precisely, we follow the general model setup in Sect. 5, which differs from the linear examples above by the inclusion of a ReLU activation in the considered neural networks. We consider an hVAE with two layers of size 5 and 10. The learning rate of the hVAE and MCMC algorithms was set to 0.001. We use 200 epochs for training overall. For models that included MCMC sampling, we used the first 190 epochs for pre-training without MCMC. Additionally, the prior of the model was trained only during the hVAE portion of the algorithm. The resulting KID scores for MNIST and Fashion-MNIST can be found in Table 6. In this scenario, our proposed method outperforms other sampling schemes when combined with an hVAE model.

Table 6 Estimates of KID for each model considered across different datasets with lowest KID scores for each dataset in bold

7 Conclusion

We have investigated the effect on performance of training VAEs and hierarchical VAEs with MCMC speed measures and compared our proposed method with other widely used MCMC adaptation schemes and VAE model variations. Adopting recent advances in the adaptive MCMC literature that are based on the notion of a generalised speed measure seems to provide, in the problems and datasets we tested, a more efficient learning algorithm for VAEs. Future research directions may focus on using our proposed method in models with deeper architectures in the encoder and the decoder, using our method in challenging inpainting problems and exploring its power at alleviating adversarial attacks as seen in Kuzina et al. (2022).