1 Introduction

Most MCMC algorithms contain user-controlled hyperparameters which need to be carefully selected to ensure that the MCMC algorithm explores the posterior distribution efficiently. Optimal tuning guidelines for many popular MCMC algorithms, such as the random-walk Metropolis algorithm (Gelman et al. 1997) or the Metropolis-adjusted Langevin algorithm (Roberts and Rosenthal 1998), rely on setting the tuning parameters according to the Metropolis-Hastings acceptance rate. Using metrics such as the acceptance rate, hyperparameters can be optimized on-the-fly within the MCMC algorithm using adaptive MCMC (Andrieu and Thoms 2008; Vihola 2012). However, in the context of stochastic gradient MCMC (SGMCMC), there is no acceptance rate to tune against, and the trade-off between bias and variance for a fixed computation budget means that tuning approaches designed for target-invariant MCMC algorithms are not applicable.

1.1 Related work

Previous adaptive SGMCMC algorithms have focused on embedding ideas from the optimization literature within the SGMCMC framework, e.g. gradient preconditioning (Li et al. 2016), RMSprop (Chen et al. 2016) and Adam (Kim et al. 2020). However, all of these algorithms still rely on hyperparameters such as learning rates and subsample sizes which need to be optimized. To the best of our knowledge, no principled approach has been developed to optimize the SGMCMC hyperparameters. In practice, users often take a trial-and-error approach: they run multiple short chains with different hyperparameter configurations and select the setting which minimizes a metric of choice, such as the kernel Stein discrepancy (Nemeth and Fearnhead 2020) or cross-validation (Izmailov et al. 2021). However, this laborious approach is inefficient and is not guaranteed to produce the best hyperparameter configuration.

1.2 Contribution

In this paper we propose a principled adaptive SGMCMC scheme that allows users to tune the hyperparameters, e.g. step-sizes h (also known as the learning rate) and data subsample size n. Our approach provides an automated trade-off between bias and variance in the posterior approximation for a given computational time budget. Our adaptive scheme uses a multi-armed bandit algorithm to select SGMCMC hyperparameters which minimize the Stein discrepancy between the approximate and true posterior distributions. The approach only requires a user-defined computational budget as well as unbiased estimates of the gradients of the log-posterior, which are already available to us via the stochastic gradient MCMC algorithm. A second contribution in this paper is a rigorous assessment of existing tuning methods for SGMCMC, which to our knowledge is not present in the literature.

2 Background

2.1 Stochastic Gradient Langevin Algorithm

We are interested in sampling from a target density \(\pi (\varvec{\theta })\), where for some parameters of interest \(\varvec{\theta }\in {\mathbb {R}}^d\) the unnormalized density is of the form \(\pi (\varvec{\theta }) \propto \exp \{ -U(\varvec{\theta })\}\). We assume that the potential function \(U(\varvec{\theta })\) is continuous and differentiable almost everywhere. If we have independent data, \(y_1,\ldots ,y_N\) then \(\pi (\varvec{\theta }) \propto p(\varvec{\theta }) \prod _{i=1}^N f(y_i \mid \varvec{\theta })\) is the posterior density, where \(p(\varvec{\theta })\) is the prior density and \(f(y_i\mid \varvec{\theta })\) is the likelihood for the \(i\)th observation. In this setting, we can define \(U(\varvec{\theta })=\sum _{i=1}^N U_i(\varvec{\theta })\), where \(U_i(\varvec{\theta })= - \log f(y_i \mid \varvec{\theta })-(1/N)\log p(\varvec{\theta })\).
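As a concrete instance of the decomposition \(U=\sum_i U_i\), the sketch below writes out \(U_i\) and \(\nabla U_i\) for Bayesian logistic regression. The isotropic Gaussian prior \(N(0, 10 I)\) and all helper names are assumptions for illustration, not the model specification used in the experiments.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def log_sigmoid(z):
    """log(sigmoid(z)) computed without overflow."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))

def log_prior(theta, prior_var=10.0):
    """Assumed isotropic Gaussian prior N(0, prior_var * I)."""
    d = len(theta)
    return (-0.5 * sum(t * t for t in theta) / prior_var
            - 0.5 * d * math.log(2.0 * math.pi * prior_var))

def U_i(theta, x, y, N, prior_var=10.0):
    """U_i(theta) = -log f(y_i | theta) - (1/N) log p(theta) for logistic regression."""
    z = sum(t * xj for t, xj in zip(theta, x))
    log_lik = y * log_sigmoid(z) + (1 - y) * log_sigmoid(-z)
    return -log_lik - log_prior(theta, prior_var) / N

def grad_U_i(theta, x, y, N, prior_var=10.0):
    """Gradient of U_i: -(y - sigmoid(x . theta)) x + theta / (N * prior_var)."""
    z = sum(t * xj for t, xj in zip(theta, x))
    resid = y - sigmoid(z)
    return [-resid * xj + t / (N * prior_var) for t, xj in zip(theta, x)]

# Summing the N per-datum terms recovers U = -log p - sum_i log f(y_i | theta).
xs = [[1.0, 0.5], [1.0, -1.2], [1.0, 2.0]]
ys = [1, 0, 1]
theta = [0.3, -0.7]
U = sum(U_i(theta, x, y, len(xs)) for x, y in zip(xs, ys))
```

Splitting the log-prior evenly across the N terms is what makes \((N/n)\sum_{i\in\mathcal S_n}\nabla U_i\) an unbiased estimate of the full gradient, prior included.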

We can sample from \(\pi (\varvec{\theta })\) by simulating a stochastic process that has \(\pi \) as its stationary distribution. Under mild regularity conditions, the Langevin diffusion (Roberts and Tweedie 1996; Pillai et al. 2012) has \(\pi \) as its stationary distribution. However, in practice it is not possible to simulate the Langevin diffusion exactly in continuous time, so instead we sample from a discretized version. That is, for a small time interval \(h>0\), the Langevin diffusion has approximate dynamics given by

$$\begin{aligned} \varvec{\theta }_{k+1} = \varvec{\theta }_{k} - \frac{h}{2} \nabla U(\varvec{\theta }_{k}) + \sqrt{h} \varvec{\xi }_k, \quad k=0,\ldots ,K \end{aligned}\tag{1}$$

where \(\varvec{\xi }_k\) is a vector of d independent standard Gaussian random variables. In the large data setting, we replace \(\nabla U(\varvec{\theta })\) with an unbiased estimate \(\nabla {\tilde{U}}(\varvec{\theta }) = \frac{N}{n} \sum _{i \in {\mathcal {S}}_n} \nabla U_i(\varvec{\theta })\), calculated using a subsample of the data of size \(n \ll N\), where \({\mathcal {S}}_n\) is a random sample, without replacement, from \(\{1,\ldots ,N\}\). This algorithm is known as stochastic gradient Langevin dynamics (SGLD, Welling and Teh 2011).
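The following pure-Python sketch implements the SGLD update (1) with the subsampled gradient \(\nabla\tilde U\) described above. The helper names and the toy Gaussian-mean model used to exercise it (data \(y_i\sim N(2,1)\), prior \(N(0,100)\)) are hypothetical. For that model the gradient is linear in \(\theta\), so the chain's long-run mean matches the exact posterior mean even though its variance is inflated by gradient noise and discretization.

```python
import math
import random

def stochastic_grad(theta, grad_U_i, N, n, rng):
    """Unbiased estimate (N/n) * sum_{i in S_n} grad U_i(theta), S_n sampled without replacement."""
    g = [0.0] * len(theta)
    for i in rng.sample(range(N), n):
        for j, gij in enumerate(grad_U_i(theta, i)):
            g[j] += gij
    return [N / n * gj for gj in g]

def sgld(theta0, grad_U_i, N, n, h, num_steps, seed=0):
    """Run SGLD: theta <- theta - (h/2) * grad_estimate + sqrt(h) * xi, xi ~ N(0, I)."""
    rng = random.Random(seed)
    theta = list(theta0)
    samples = []
    for _ in range(num_steps):
        g = stochastic_grad(theta, grad_U_i, N, n, rng)
        theta = [t - 0.5 * h * gj + math.sqrt(h) * rng.gauss(0.0, 1.0)
                 for t, gj in zip(theta, g)]
        samples.append(theta)
    return samples

# Toy model (hypothetical): y_i ~ N(theta, 1), prior theta ~ N(0, 100).
data_rng = random.Random(1)
N, prior_var = 1000, 100.0
y = [data_rng.gauss(2.0, 1.0) for _ in range(N)]

def grad_U_i(theta, i):
    # U_i = (y_i - theta)^2 / 2 + theta^2 / (2 * N * prior_var)
    return [(theta[0] - y[i]) + theta[0] / (N * prior_var)]

samples = sgld([0.0], grad_U_i, N, n=100, h=1e-4, num_steps=2000)
post_mean = sum(y) / (N + 1.0 / prior_var)  # exact Gaussian posterior mean
```

With \(hN/2 = 0.05\), the chain forgets its starting point within a few hundred iterations, so discarding the first 500 draws as burn-in is ample for this toy problem.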

In this paper we present our adaptive stochastic gradient MCMC scheme in the context of the SGLD algorithm for simplicity of exposition. However, our proposed approach is readily generalizable to all other stochastic gradient MCMC methods, e.g. stochastic gradient Hamiltonian Monte Carlo (Chen et al. 2014). Details of the general class of stochastic gradient MCMC methods presented under the complete recipe framework are given in Ma et al. (2015). See Sect. C of the Supplementary Material for a summary of the SGMCMC algorithms used in this paper.

2.2 Stein discrepancy

We define \({\tilde{\pi }}\) as the empirical distribution generated by the stochastic gradient MCMC algorithm (1). We can define a measure of how well this distribution approximates our target distribution of interest, \(\pi \), by defining a discrepancy metric between the two distributions. Following Gorham and Mackey (2015) we consider the Stein discrepancy

$$\begin{aligned} D({\tilde{\pi }},\pi ) = \sup _{\phi \in {\mathcal {F}}} \left|\mathbb {E}_{{{\tilde{\pi }}}} \left[ {\underbrace{-\nabla _{\varvec{\theta }}U(\varvec{\theta })^{\top } \phi (\varvec{\theta }) + \nabla _{\varvec{\theta }}^{\top } \phi (\varvec{\theta })}_{\text {Stein operator: }{\mathcal {A}}_{\pi }\phi (\varvec{\theta })}}\right] \right|\end{aligned}\tag{2}$$

where the supremum is taken over the Stein set \({\mathcal {F}}\) of smooth functions \(\phi : {\mathbb {R}}^d \rightarrow {\mathbb {R}}^d\), chosen so that Stein’s identity \(\mathbb {E}_{{\pi }} \left[ {{\mathcal {A}}_{\pi }\phi (\varvec{\theta })}\right] =0\) holds for all \(\phi \in {\mathcal {F}}\). Consequently, \(D({\tilde{\pi }},\pi )=0\) when \({\tilde{\pi }}=\pi \).
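Stein's identity can be checked numerically in one dimension. For \(\pi=N(0,1)\) we have \(U(\theta)=\theta^2/2\), and with the test function \(\phi(\theta)=\sin\theta\) the Stein operator gives \(\mathcal A_\pi\phi(\theta)=-\theta\sin\theta+\cos\theta\). The sketch averages this over a deterministic grid of normal quantiles, an assumption chosen only to make the result reproducible. The average is close to zero under \(\pi\), but close to \(-e^{-1/2}\sin 1\approx-0.51\) under the shifted distribution \(N(1,1)\), which is why a non-zero Stein expectation signals a mismatch between \(\tilde\pi\) and \(\pi\).

```python
import math
from statistics import NormalDist

def grad_U(theta):
    # pi = N(0, 1), so U(theta) = theta^2 / 2 and U'(theta) = theta
    return theta

def stein_operator_sin(theta):
    # A_pi phi(theta) = -U'(theta) * phi(theta) + phi'(theta) with phi = sin
    return -grad_U(theta) * math.sin(theta) + math.cos(theta)

def quantile_grid_mean(f, dist, n=10_000):
    # Deterministic approximation of E[f(theta)] using n equally spaced quantiles of dist
    return sum(f(dist.inv_cdf((k + 0.5) / n)) for k in range(n)) / n

under_pi = quantile_grid_mean(stein_operator_sin, NormalDist(0.0, 1.0))     # approximately 0
under_shift = quantile_grid_mean(stein_operator_sin, NormalDist(1.0, 1.0))  # approximately -0.51
```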

2.2.1 Kernel Stein discrepancy

To obtain an analytic form of the Stein discrepancy, Liu et al. (2016) and Chwialkowski et al. (2016) introduced the kernelized Stein discrepancy (KSD), where \({\mathcal {F}}\) is the unit ball of a d-dimensional reproducing kernel Hilbert space. The KSD has the closed-form expression

$$\begin{aligned} \textrm{KSD}({\tilde{\pi }},\pi ) := \sqrt{\mathbb {E}_{{{\tilde{\pi }}(\varvec{\theta }){\tilde{\pi }}(\varvec{\theta }')}} \left[ {k_{\pi }(\varvec{\theta },\varvec{\theta }')}\right] } \end{aligned}\tag{3}$$

where
$$\begin{aligned} k_{\pi }(\varvec{\theta },\varvec{\theta }^\prime ) =&\nabla _{\varvec{\theta }} U(\varvec{\theta })^{\top } \nabla _{\varvec{\theta }'} U(\varvec{\theta }') k(\varvec{\theta },\varvec{\theta }') \\ -&\nabla _{\varvec{\theta }} U(\varvec{\theta })^{\top } \nabla _{\varvec{\theta }'} k(\varvec{\theta },\varvec{\theta }') \\ -&\nabla _{\varvec{\theta }'} U(\varvec{\theta }')^{\top } \nabla _{\varvec{\theta }} k(\varvec{\theta },\varvec{\theta }') + \nabla _{\varvec{\theta }}^{\top } \nabla _{\varvec{\theta }'} k(\varvec{\theta },\varvec{\theta }^\prime ). \end{aligned}$$

The kernel k must be positive definite, which is a condition satisfied by most popular kernels, including the Gaussian and Matérn kernels. Gorham and Mackey (2017) recommend using the inverse multi-quadric kernel, \(k(\varvec{\theta },\varvec{\theta }^\prime ) = (c^2 + ||\varvec{\theta }-\varvec{\theta }^\prime ||_2^2)^\beta \), which they prove detects non-convergence when \(c>0\) and \(\beta \in (-1,0)\).
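A minimal sketch of the V-statistic estimate of (3) with the inverse multi-quadric kernel, written out term by term from the Stein kernel \(k_\pi\) above. The defaults \(c=1\), \(\beta=-1/2\) are our assumption (any \(c>0\), \(\beta\in(-1,0)\) satisfies the condition above), `grad_U` is a hypothetical user-supplied gradient of the potential, and the double loop makes the quadratic cost in the number of samples explicit.

```python
import math
from statistics import NormalDist

def imq_stein_kernel(x, y, gx, gy, c=1.0, beta=-0.5):
    """Stein kernel k_pi(x, y) for the IMQ base kernel k = (c^2 + ||x - y||^2)^beta.

    gx and gy are grad U evaluated at x and y.
    """
    d = len(x)
    r = [xi - yi for xi, yi in zip(x, y)]
    r2 = sum(ri * ri for ri in r)
    u = c * c + r2
    k = u ** beta
    coef = 2.0 * beta * u ** (beta - 1.0)  # grad_x k = coef * r, grad_y k = -coef * r
    term1 = sum(a * b for a, b in zip(gx, gy)) * k
    # -gx . grad_y k - gy . grad_x k
    term23 = coef * sum((a - b) * ri for a, b, ri in zip(gx, gy, r))
    # trace of the mixed second derivative grad_x . grad_y k
    term4 = (-4.0 * beta * (beta - 1.0) * u ** (beta - 2.0) * r2
             - 2.0 * beta * d * u ** (beta - 1.0))
    return term1 + term23 + term4

def ksd(samples, grad_U, c=1.0, beta=-0.5):
    """V-statistic estimate of KSD: sqrt of the mean of k_pi over all sample pairs."""
    grads = [grad_U(s) for s in samples]
    m = len(samples)
    total = 0.0
    for i in range(m):
        for j in range(m):
            total += imq_stein_kernel(samples[i], samples[j], grads[i], grads[j], c, beta)
    return math.sqrt(max(total / (m * m), 0.0))  # guard against tiny negative round-off

# Target pi = N(0, 1): U(theta) = theta^2 / 2, grad U(theta) = theta.
grad_U = lambda th: [th[0]]
n = 200
good = [[NormalDist().inv_cdf((k + 0.5) / n)] for k in range(n)]
bad = [[s[0] + 1.0] for s in good]  # same spread, mean shifted by 1
print(ksd(good, grad_U), ksd(bad, grad_U))
```

Because \(k_\pi\) is itself a positive definite kernel, the V-statistic is non-negative, and it includes the diagonal terms \(k_\pi(\theta_i,\theta_i)\), so it carries a small positive bias of order \(1/m\) even for exact samples.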

2.2.2 Finite set Stein discrepancy

KSD is a natural discrepancy measure for stochastic gradient MCMC algorithms as \(\pi (\varvec{\theta })\) is only required up to a normalization constant and the gradients of the log-posterior density are readily available. The drawback of KSD is that the computational cost is quadratic in the number of samples. Linear versions of the KSD (Liu et al. 2016) are an order of magnitude faster, but the computational advantage is outweighed by a significant decrease in the accuracy of the Stein estimator.

Jitkrittum et al. (2017) propose a linear-time Stein discrepancy, the Finite Set Stein Discrepancy (FSSD), which utilizes the Stein witness function \(g(\varvec{\theta }^\prime ):=\mathbb {E}_{{\varvec{\theta }\sim {\tilde{\pi }}}} \left[ {-\nabla _{\varvec{\theta }}U(\varvec{\theta }) k(\varvec{\theta },\varvec{\theta }^\prime ) + \nabla _{\varvec{\theta }} k(\varvec{\theta },\varvec{\theta }^\prime )}\right] \), a vector-valued function with components \(g_1,\ldots ,g_d\). The function g can be thought of as witnessing the differences between \({\tilde{\pi }}\) and \(\pi \): a discrepancy in the region around \(\varvec{\theta }\) is indicated by a large \(\Vert g(\varvec{\theta }) \Vert \). The Stein discrepancy is then essentially measured via the flatness of g, which can be computed in linear time. The key to FSSD is to use real analytic kernels k, e.g. the Gaussian kernel, which makes \(g_1,\ldots ,g_d\) real analytic as well. Consequently, if \(g_i\) is not identically zero, then \(g_i({\textbf{v}}_1),\ldots ,g_i({\textbf{v}}_J)\) are almost surely non-zero for a finite set of test locations \(V=\{{\textbf{v}}_1,\ldots ,{\textbf{v}}_J\}\) drawn from a distribution with a density. Under the same assumptions as KSD, FSSD is defined as

$$\begin{aligned} \textrm{FSSD}({\tilde{\pi }},\pi ):=\sqrt{\frac{1}{dJ}\sum _{i=1}^d\sum _{j=1}^{J}g_i^2({\textbf{v}}_j)}. \end{aligned}\tag{4}$$

Theorem 1 of Jitkrittum et al. (2017) guarantees that \(\textrm{FSSD}^2=0\) if and only if \({\tilde{\pi }}=\pi \), almost surely for test locations \(\{{\textbf{v}}_j\}_{j=1}^J\) drawn from a distribution with a density. However, some test locations give higher test power for finite samples, so, following Jitkrittum et al. (2017), we optimize the test locations by first sampling them from a Gaussian fit to the posterior samples and then using gradient ascent to maximize the FSSD.
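A minimal plug-in sketch of (4) with a Gaussian kernel, evaluated at fixed test locations. It simplifies the method in several labelled ways: the bandwidth `sigma2` is fixed, the test locations are hand-picked rather than optimized by gradient ascent, and each \(g_i(\mathbf v_j)\) is estimated by a simple sample mean rather than by the unbiased U-statistic estimator of Jitkrittum et al. (2017). The cost is \(O(mJd)\) for m samples, i.e. linear in m.

```python
import math
from statistics import NormalDist

def gaussian_kernel_and_grad(theta, v, sigma2):
    """k(theta, v) = exp(-||theta - v||^2 / (2 sigma2)) and its gradient in theta."""
    diff = [t - vi for t, vi in zip(theta, v)]
    k = math.exp(-sum(di * di for di in diff) / (2.0 * sigma2))
    return k, [-di / sigma2 * k for di in diff]

def fssd(samples, grad_U, test_locations, sigma2=1.0):
    """Plug-in FSSD: witness g estimated by a sample mean at each test location."""
    d, J, m = len(samples[0]), len(test_locations), len(samples)
    grads = [grad_U(s) for s in samples]
    total = 0.0
    for v in test_locations:
        g = [0.0] * d
        for theta, gu in zip(samples, grads):
            k, dk = gaussian_kernel_and_grad(theta, v, sigma2)
            for i in range(d):
                g[i] += (-gu[i] * k + dk[i]) / m  # -dU/dtheta_i * k + dk/dtheta_i
        total += sum(gi * gi for gi in g)
    return math.sqrt(total / (d * J))

grad_U = lambda th: [th[0]]  # target pi = N(0, 1)
n = 1000
good = [[NormalDist().inv_cdf((k + 0.5) / n)] for k in range(n)]
bad = [[s[0] + 1.0] for s in good]
V = [[-1.0], [0.5], [2.0]]  # hand-picked test locations (not optimized)
```

Under the shifted samples the witness equals \(-\mathbb E_{\tilde\pi}[k(\theta,\mathbf v)]\), which is strictly negative at every test location, so any choice of V detects this particular mismatch.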

3 Hyperparameter learning

In this section we introduce an automated and generally-applicable approach to learning the user-controlled parameters of a stochastic gradient MCMC algorithm, which throughout we will refer to as hyperparameters. For example, in the case of SGLD, these would be the step size h and batch size n, and in the case of stochastic gradient Hamiltonian Monte Carlo, they would also include the number of leapfrog steps. Our adaptive scheme relies on multi-armed bandits (Slivkins 2019) to identify the optimal setting for the hyperparameters such that, for a given time budget, the selected parameters minimize the Stein discrepancy, and therefore maximize the accuracy of the posterior approximation. Our proposed approach, the Multi-Armed MCMC Bandit Algorithm (MAMBA), works by sequentially identifying and pruning, i.e. removing, poor hyperparameter configurations in a principled, automatic and online setting to speed up hyperparameter learning. The MAMBA algorithm can be used within any stochastic gradient MCMC algorithm and only requires the user to specify the training budget and the number of hyperparameter sets.

3.1 Multi-armed bandits with successive halving

Multi-armed bandits are a class of algorithms for sequential decision-making that iteratively select actions from a set of possible decisions. These algorithms can be split into two categories: 1) best arm identification, in which the goal is to identify the action with the highest average reward, and 2) exploration vs. exploitation, where the goal is to maximize the cumulative reward over time (Bubeck and Cesa-Bianchi 2012). In the best-arm identification setting, an action, also known as an arm, is selected and produces a reward, where the reward is drawn from a fixed probability distribution corresponding to the chosen arm. At the end of the exploration phase, a single arm is chosen which maximizes the expected reward. This differs from the typical multi-armed bandit setting, where the strategy for selecting arms is based on minimizing cumulative regret (Lattimore and Szepesvári 2020).

The successive halving algorithm (Karnin et al. 2013; Jamieson and Talwalkar 2016) is a multi-armed bandit algorithm based on best arm identification. Successive halving learns the best hyperparameter settings, i.e. the best arm, using a principled early-stopping criterion to identify the best arm within a set level of confidence, or for a fixed computational budget. In this paper, we consider the fixed computational budget setting, where the algorithm proceeds as follows: 1) uniformly allocate a computational budget to a set of arms, 2) evaluate the performance of all arms against a chosen metric, 3) promote the best \(1/\eta \) of arms to the next stage, where typically \(\eta =2 \ \text {or} \ 3\), and prune the remaining arms from the set. The process is repeated until only one arm remains. As the total computational budget is fixed, pruning the least promising arms allows the algorithm to allocate exponentially more computational resources to the most promising hyperparameter sets.
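The three steps of fixed-budget successive halving can be sketched generically as follows. Here `pull(arm, budget)` is a hypothetical callback returning a noisy reward (higher is better). In this sketch each round re-evaluates the surviving arms with a fresh budget, keeps the top \(\lfloor |S|/\eta\rfloor\) arms (at least one), and multiplies the per-arm budget by \(\eta\), so each round costs roughly the same total budget. The toy arms with Gaussian rewards are purely illustrative.

```python
import random

def successive_halving(arms, pull, eta=2, budget_per_round=160):
    """Fixed-budget successive halving; returns the single surviving arm.

    pull(arm, budget) returns a noisy reward (higher is better).
    """
    surviving = list(arms)
    per_arm = budget_per_round / len(surviving)                   # 1) uniform allocation
    while len(surviving) > 1:
        rewards = {arm: pull(arm, per_arm) for arm in surviving}  # 2) evaluate
        surviving.sort(key=lambda arm: rewards[arm], reverse=True)
        surviving = surviving[: max(1, len(surviving) // eta)]    # 3) keep top 1/eta
        per_arm *= eta                                            # reallocate pruned budget
    return surviving[0]

# Toy arms (hypothetical): arm a has mean reward a and Gaussian noise with sd 0.5.
rng = random.Random(0)

def pull(arm, budget):
    draws = max(1, int(budget))
    return sum(rng.gauss(float(arm), 0.5) for _ in range(draws)) / draws

best = successive_halving(range(16), pull, eta=2, budget_per_round=160)
```

With 16 arms and \(\eta=2\), the per-arm budget grows 10, 20, 40, 80 over four rounds, so the final comparison between the two leading arms is made with eight times the initial per-arm effort.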

3.2 Tuning stochastic gradients with a multi-armed MCMC bandit algorithm (MAMBA)

We describe our proposed algorithm, MAMBA, to tune the hyperparameters of a generic stochastic gradient MCMC algorithm. For ease of exposition, we present MAMBA in the context of the SGLD algorithm (1), where a user tunes the step size h and batch size n. Details on other SGMCMC algorithms can be found in Appendix C. We present MAMBA as the following three-stage process:

3.2.1 Initialize

In our multi-armed bandit setting we assume M possible stochastic gradient MCMC hyperparameter configurations, which we refer to as arms. Each arm s in the initial set \(S_0 = \{1,\ldots ,M\}\) represents a hyperparameter tuple \(\phi _s=(h_s,n_s)\). The hyperparameters in the initial set are chosen from a uniform grid.
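One way to enumerate the initial set \(S_0\) is as a Cartesian product of step sizes and batch sizes. The sketch below assumes the step-size grid is uniform on a \(\log_{10}\) scale, since step sizes typically span several orders of magnitude, and takes batch sizes as fractions of N; the helper name and the specific values are hypothetical.

```python
import itertools

def initial_arms(log10_step_sizes, batch_sizes):
    """Hypothetical helper: the initial arm set S_0 as (h_s, n_s) tuples."""
    return [(10.0 ** lh, n) for lh, n in itertools.product(log10_step_sizes, batch_sizes)]

N = 1_000_000
batch_sizes = [round(f * N) for f in (0.001, 0.01, 0.1)]   # 0.1%, 1% and 10% of the data
arms = initial_arms([-6, -5, -4, -3], batch_sizes)          # M = 4 * 3 = 12 arms
```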

3.2.2 Evaluate and prune

At each iteration of MAMBA, \(i = 0, 1, \ldots \), the SGLD sampler associated with each arm \(s \in S_i\) is run for \(r_i\) seconds using the hyperparameter configuration \(\phi _s\). Each arm is associated with a reward \(\nu _s\) that measures the quality of the posterior approximation. We use the negative Stein discrepancy as the reward function that we aim to maximize. Specifically, we calculate the Stein discrepancy from the SGMCMC output using KSD (3) or FSSD (4), i.e. \(\nu _s = -\textrm{KSD}({\tilde{\pi }}_{s},\pi )\) or \(\nu _s = -\textrm{FSSD}({\tilde{\pi }}_{s},\pi )\). Without loss of generality, we can order the arms in \(S_i\) by their rewards, i.e. \(\nu _1 \ge \nu _2 \ge \ldots \ge \nu _{|S_i|}\), so that arm 1 has the highest reward at that iteration of MAMBA. The top \(100/\eta \%\) of arms in \(S_i\), ranked by reward, are retained to produce the set \(S_{i+1}\). The remaining arms are pruned from the set and not evaluated again at future iterations.

3.2.3 Reallocate time

Computation time allocated to the pruned samplers is reallocated to the remaining samplers, so that \(r_{i+1} = \eta r_i\). As a result, after i iterations, each of the remaining SGLD samplers has run for a total time of \(R = r_0 + \eta r_0 + \eta ^2r_0 + \cdots + \eta ^{i-1}r_0 = r_0(\eta ^i-1)/(\eta -1)\) seconds, where \(r_0\) is the time budget for the first MAMBA iteration. This process is repeated for a total of \( \lfloor \log _{\eta }M\rfloor \) MAMBA iterations. We use base \(\eta \) for the logarithm because the number of arms is divided by \(\eta \) at every iteration, and we use a floor function for cases where the initial number of arms M is not a power of \(\eta \). The MAMBA algorithm is summarized in Algorithm 1.

[Algorithm 1: MAMBA]
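The reallocation rule \(r_{i+1}=\eta r_i\) and the iteration count \(\lfloor\log_\eta M\rfloor\) from Sect. 3.2.3 can be tabulated directly. This sketch assumes an integer \(\eta\ge2\), that pruning keeps \(\lfloor|S_i|/\eta\rfloor\) arms (at least one), and that each surviving sampler continues from where it stopped, so its cumulative run time after iteration i is \(r_0(1+\eta+\cdots+\eta^{i})\). The function name is hypothetical.

```python
def mamba_schedule(M, eta, r0):
    """Per-iteration (i, arms evaluated, r_i, cumulative run time per surviving sampler)."""
    num_iters = 0
    while eta ** (num_iters + 1) <= M:   # floor(log_eta M) without floating-point logs
        num_iters += 1
    rows, arms, cumulative = [], M, 0.0
    for i in range(num_iters):
        r_i = r0 * eta ** i              # r_{i+1} = eta * r_i
        cumulative += r_i
        rows.append((i, arms, r_i, cumulative))
        arms = max(1, arms // eta)       # keep the top 1/eta of arms
    return rows, arms

rows, final_arms = mamba_schedule(M=27, eta=3, r0=0.1)
for row in rows:
    print(row)
```

For example, \(M=27\) arms with \(\eta=3\) gives three iterations over 27, 9 and 3 arms, and with \(r_0=0.1\) s the single surviving sampler has run for 1.3 s in total.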

3.2.4 Algorithmic guarantees

It is possible that MAMBA will eliminate the optimal hyperparameter set during one of the arm-pruning phases. Through examination of the \(1-1/2\eta \) quantile, we can derive a bound on the probability that MAMBA will incorrectly prune the best hyperparameter configuration (see Theorem 1). Using this result, we are also able to bound the maximum computational budget required for MAMBA to identify the optimal hyperparameters.

Definition 1

Let \(s \in \{2,\ldots ,M\}\) be an arm with reward \(\nu _s\), then we define the suboptimality gap between \(\nu _s\) and the optimal reward \(\nu _1\) as \(\alpha _s:= \nu _1 - \nu _s\), and we define \(H_2:= \max _{s \ne 1} s/\alpha _s^2\) as the complexity measure, see Audibert et al. (2010) for details.

Theorem 1

i) MAMBA correctly identifies the best hyperparameter configuration for a stochastic gradient MCMC algorithm with probability at least

$$\begin{aligned}1 - (2\eta -1) \log _\eta M \cdot \exp {\left( -\frac{\eta T}{4\sigma ^2_{\textrm{KSD}} H_2 (\log _\eta M+1)}\right) }, \end{aligned}$$

where \(\sigma ^2_{\textrm{KSD}} = \max _{s \in S}\textrm{Var}_{{\tilde{\pi }}_{s}}(\mathbb {E}_{{{\tilde{\pi }}_{s}}} \left[ {k_\pi (\varvec{\theta },\varvec{\theta }')}\right] )\).

ii) To identify the optimal hyperparameter set with probability at least \(1-\delta \), MAMBA requires a computational budget of

$$\begin{aligned} T = O\left( \sigma ^2_{\textrm{KSD}} \log _\eta M\log \left( \frac{(2\eta -1)\log _\eta M}{\delta }\right) \right) . \end{aligned}$$

A proof of Theorem 1 is given in Appendix A and builds on the existing work of Karnin et al. (2013) for fixed-time best-arm identification bandits. Theorem 1 highlights the contribution of KSD variance in identifying the optimal arm. In particular, the total computation budget depends on the arm with the largest KSD variance.
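To get a feel for Theorem 1, the bound in (i) can be evaluated numerically, and solving it for T gives an explicit budget that guarantees success probability at least \(1-\delta\), which can be compared with the order of growth stated in (ii). The sketch also computes \(H_2\) from Definition 1. In practice \(\sigma^2_{\textrm{KSD}}\) and the suboptimality gaps are unknown, so the inputs below are purely hypothetical, and T is in whatever units \(\sigma^2_{\textrm{KSD}}\) and the rewards are expressed in.

```python
import math

def h2_complexity(rewards):
    """Definition 1: H_2 = max_{s >= 2} s / alpha_s^2 with arms ranked by reward."""
    nu = sorted(rewards, reverse=True)
    # nu[s] is the arm with rank s + 1; assumes a unique best arm (alpha_s > 0)
    return max((s + 1) / (nu[0] - nu[s]) ** 2 for s in range(1, len(nu)))

def success_lower_bound(T, M, eta, sigma2_ksd, H2):
    """Theorem 1(i); a value of 0 means the bound is vacuous."""
    L = math.log(M, eta)
    p = 1.0 - (2 * eta - 1) * L * math.exp(-eta * T / (4.0 * sigma2_ksd * H2 * (L + 1.0)))
    return max(0.0, p)

def budget_for_confidence(delta, M, eta, sigma2_ksd, H2):
    """Smallest T for which the bound in Theorem 1(i) reaches 1 - delta."""
    L = math.log(M, eta)
    return 4.0 * sigma2_ksd * H2 * (L + 1.0) / eta * math.log((2 * eta - 1) * L / delta)

# Hypothetical inputs: rewards are negative KSD values of three arms.
H2 = h2_complexity([-0.5, -1.5, -2.5])      # max(2 / 1^2, 3 / 2^2) = 2.0
T = budget_for_confidence(0.05, M=27, eta=3, sigma2_ksd=0.1, H2=H2)
```

The budget scales linearly with \(\sigma^2_{\textrm{KSD}}\), so doubling the largest KSD variance doubles the time needed for the same guarantee.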

3.3 Practical guidance for using MAMBA

3.3.1 Choice of budget

There is flexibility in the choice of budget in MAMBA. We advocate for the use of a compute time budget for fast but biased sampling algorithms like SGMCMC because it allows users to view these algorithms as a trade-off between statistical accuracy and runtime. The goal is then to identify the hyperparameters that produce the best Monte Carlo estimates under a given time constraint. A compute time budget allows users to optimise the batch size in a principled way and ties the hyperparameter optimisation to the available hardware and software, such as whether or not the model was implemented using vectorisation.

Alternative choices for the budget could be based on the total number of iterations or the total number of gradient evaluations. The former would be helpful for storage constraints but has no natural mechanism for tuning the batch size. The total number of gradient evaluations would allow for batch size tuning but is less closely linked to the available hardware and software, and would not take into account implementation decisions such as vectorising the gradient estimator.

3.3.2 Estimating KSD/FSSD

Calculating KSD/FSSD using (3) or (4) requires the gradients of the log-posterior and the SGMCMC samples. Typically, one would calculate the KSD/FSSD using fullbatch gradients (i.e. using the entire dataset) on the full chain of samples. However, as we only use SGMCMC when the dataset is large, this would be a computationally expensive approach. Two natural solutions are to i) use stochastic gradients (Gorham et al. 2020), calculated through subsampling, or ii) use a thinned chain of samples. We investigate both options in terms of KSD/FSSD accuracy in Sect. 3.4.2. We find that using the stochastic KSD/FSSD produces results similar to the fullbatch KSD/FSSD. However, calculating the KSD/FSSD for a large number of high-dimensional samples is computationally expensive, so for our experimental study in Sect. 4 we use fullbatch gradients with thinned samples. This leads to lower-variance KSD/FSSD estimates at a reasonable computational cost. Note that fullbatch gradients are only used for MAMBA iterations and not SGMCMC iterations. We find that this does not significantly increase the overall computational cost, as for each iteration of MAMBA there are thousands of SGMCMC iterations.
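The trade-off between thinning and subsampling can be made concrete by counting work. The hypothetical helper below counts per-datum gradient evaluations and Stein-kernel evaluations for one KSD estimate; the chain length, N and batch size in the example are illustrative, not values from our experiments.

```python
def ksd_estimation_cost(num_samples, N, thin=1, batch_size=None):
    """Count work for one KSD estimate (hypothetical helper).

    batch_size=None means fullbatch gradients (all N data points per sample).
    """
    kept = len(range(0, num_samples, thin))           # samples left after thinning
    per_sample = N if batch_size is None else batch_size
    return {
        "samples": kept,
        "per_datum_grad_evals": kept * per_sample,    # linear in kept samples
        "kernel_evals": kept * kept,                   # quadratic in kept samples
    }

N = 1_000_000
full = ksd_estimation_cost(10_000, N)                           # fullbatch, all samples
thinned = ksd_estimation_cost(10_000, N, thin=5)                # fullbatch, thinned
stochastic = ksd_estimation_cost(10_000, N, batch_size=10_000)  # stochastic, all samples
```

Thinning by a factor of 5 cuts the quadratic kernel cost by a factor of 25, whereas stochastic gradients reduce the gradient cost but keep every sample in the quadratic term.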

3.3.3 Alternative metrics

Stein-based discrepancies are a natural metric to assess the quality of the posterior approximation as they only require the SGMCMC samples and log-posterior gradients. However, alternative metrics to tune SGMCMC can readily be applied within the MAMBA framework. For example, there is currently significant interest in understanding uncertainty in neural networks via metrics such as expected calibration error (ECE), maximum calibration error (MCE, Guo et al. 2017), and out-of-distribution (OOD) tests (Lakshminarayanan et al. 2017). These metrics have the advantage that they are more scalable to very high dimensional problems, compared to the KSD (Gong et al. 2020). As a result, although KSD is a sensible choice when aiming for posterior accuracy, alternative metrics may be more appropriate for some problems, for example, in the case of very high-dimensional deep neural networks.
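As an example of an alternative reward, expected and maximum calibration error can be computed by binning predictive confidences. This sketch uses equal-width bins on \([0,1)\) with the value 1.0 placed in the top bin, a common convention but an assumption here; since MAMBA maximizes its reward, one natural choice would be the negative ECE.

```python
def calibration_errors(confidences, correct, num_bins=10):
    """Expected (ECE) and maximum (MCE) calibration error with equal-width bins.

    confidences: predicted probability of the predicted class, in [0, 1].
    correct: 1 if the prediction was right, else 0.
    """
    bins = [[] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        b = min(int(conf * num_bins), num_bins - 1)   # conf == 1.0 goes in the top bin
        bins[b].append((conf, ok))
    n = len(confidences)
    ece, mce = 0.0, 0.0
    for contents in bins:
        if not contents:
            continue
        accuracy = sum(ok for _, ok in contents) / len(contents)
        avg_conf = sum(conf for conf, _ in contents) / len(contents)
        gap = abs(accuracy - avg_conf)
        ece += len(contents) / n * gap                 # weighted by bin size
        mce = max(mce, gap)
    return ece, mce

# Four predictions at 95% confidence, three correct: over-confident by 0.2.
ece, mce = calibration_errors([0.95, 0.95, 0.95, 0.95], [1, 1, 1, 0])
```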

3.4 Tuning methods

3.4.1 Grid search and heuristic method

We compare MAMBA against a simpler grid search approach. For the grid search method, we run the sampler using the training data and calculate the RMSE/log-loss/accuracy on the test dataset. To allow a fair comparison with MAMBA (see Sect. 3.4.2), we always start the sampler from the maximum a posteriori estimate (the MAP, found using optimization). As a result, we need to add noise around the MAP; otherwise, the grid search tuning method will recommend the smallest available step size, which results in the sampler not moving away from the starting point. This happens because the MAP has the smallest RMSE/log-loss (or highest accuracy). To fix this, we add Gaussian noise to the MAP, and report the scale of the noise for each model in Appendix B.

The heuristic method fixes the step size to be inversely proportional to the dataset size, i.e. \(h=\frac{1}{N}\) (Brosse et al. 2018). For both the grid search and heuristic approaches, we use a \(10\%\) batch size throughout.

3.4.2 MAMBA

We investigate the trade-offs involved in estimating the KSD from samples in MAMBA. We can estimate the KSD using the stochastic gradients already computed by the SGMCMC algorithm. However, we can also calculate the fullbatch gradients and use these to estimate the KSD. Although the latter option is too computationally expensive in the big data setting, we can also thin the samples, which may make the fullbatch gradients computationally tractable.

In Fig. 1 we estimate the KSD of samples over a grid of step sizes for the logistic regression, PMF and NN models. We run SGLD on each of the three models for 1 s with a batch size of \(1\%\). We estimate the KSD in four ways: fullbatch gradients using all the samples, fullbatch gradients using thinned samples (thinned by a factor of 5), stochastic gradients using all samples, and stochastic gradients using thinned samples. In Fig. 2 we do the same but vary the batch size (keeping the step size fixed at \(h = 10^{-4.5}\)). We can see that the KSD estimated using stochastic gradients and unthinned samples closely follows the fullbatch KSD. However, as calculating the KSD for many high-dimensional samples is computationally expensive, we opt for using fullbatch gradients with thinned samples in all our experiments.

Fig. 1

Grid search for different step sizes using both fullbatch and stochastic-KSD for logistic regression, PMF, and NN (from left to right). The sampler used is SGLD

Fig. 2

Grid search for different batch sizes using both fullbatch and stochastic-KSD for logistic regression, PMF, and NN (from left to right). The sampler used is SGLD

4 Experimental study

In this section we illustrate MAMBA (Algorithm 1) on three different models and compare it to alternative tuning methods. We use three core tuning methods for all three models: i) MAMBA-KSD, ii) grid search with log-loss as the metric, and iii) the heuristic approach. For the logistic regression model only, we also try two alternative tuning methods: iv) MAMBA-FSSD and v) grid search-KSD. Table 4 gives an overview of the tuning methods used in these experiments.

The initial arms in MAMBA are set as an equally spaced grid over batch sizes and step sizes (and number of leapfrog steps for SGHMC).

Note that only the tuning methods that use KSD/FSSD are able to tune both the step size and the batch size. This is because the log-loss metric used for grid search is not particularly sensitive to the choice of batch size: over a range of batch sizes, the log-loss produces similar values. In contrast, KSD and FSSD measure how well the samples approximate the posterior, which is strongly affected by the batch size as well as by the available computational budget.

Full details of the experiments can be found in Appendix B. Experiments were conducted using the Python package SGMCMCJax (Coullon and Nemeth 2022) and code to replicate the experiments can be found at https://github.com/jeremiecoullon/SGMCMC_bandit_tuning. All experiments were carried out on a laptop CPU (MacBook Pro 1.4 GHz Quad-Core Intel Core i5). For each example, the figures show results over a small number of tuning iterations and the tables give results for longer runs.

Fig. 3

KSD curves for the six samplers applied to a logistic regression model

4.1 Logistic regression

We consider logistic regression on a simulated dataset with 10 dimensions and 1 million data points (details of the model and prior are in Appendix B.1). We sample from the posterior using six samplers: SGLD, SGLD with control variates (SGLD-CV, Baker et al. 2019), stochastic gradient Hamiltonian Monte Carlo (SGHMC, Chen et al. 2014), SGHMC-CV, stochastic gradient Nosé-Hoover thermostats (SGNHT), and SGNHT-CV (Ding et al. 2014a).

Recall that we tune each sampler’s hyperparameters using i) MAMBA-KSD, ii) grid search with log-loss, and iii) the heuristic approach. In this section, we also run MAMBA using FSSD as the metric, as well as grid search with KSD as the metric, to assess the practicality of these approaches.

For MAMBA, we set \(R=1\) s (i.e. the run time of the longest sampler). We point out that this time budget is small compared to what would be used by most practitioners. However, this example illustrates the MAMBA methodology and compares it against a full MCMC algorithm which provides us with “ground-truth” posterior samples. To calculate the KSD/FSSD efficiently, we thin the samples and use fullbatch gradients.

Table 1 Logistic regression. For each tuning method and each SGMCMC sampler we report the relative standard deviation error and the KSD. We abbreviate MAMBA-KSD and MAMBA-FSSD to M-KSD and M-FSSD respectively. In bold are the best results for a given sampler and metric
Fig. 4

KSD curves for the six samplers applied to the probabilistic matrix factorization model

When applying grid search with KSD as the metric, we thin the samples and use fullbatch gradients, as we do for MAMBA-KSD and MAMBA-FSSD. To allow a close comparison with MAMBA-KSD, we fix a time budget rather than a number of iterations, and we tune the batch size as well as the step size. The objective of this experiment is to see how grid search performs with a metric that can capture whether the samples are from the correct distribution. We choose 1 s as the time budget, which is the same amount of time that the final sampler will have run for in MAMBA-KSD. So, for the best arm in MAMBA-KSD, as well as for every combination in grid search-KSD, the sampler will have run for 1 s before the final KSD is computed. We discuss these results below and present them in Table 6 in the appendix.

In Fig. 3, we plot the KSD calculated from the posterior samples for each of the tuning methods. We calculated the KSD curves for ten independent runs and plotted the mean curve along with a confidence interval (two standard deviations). The optimal hyperparameters given by each method can be found in Table 5 of Appendix B.1. Our results from Fig. 3 show that optimizing the hyperparameters with MAMBA, using either KSD or FSSD, produces samples with the lowest KSD for five of the six samplers. For the SGNHT sampler, the heuristic approach gives the lowest KSD; however, as shown in Table 5 in Appendix B.1, MAMBA-FSSD finds an optimal step size of \(h=N^{-1}\), which coincides with the step size given by the heuristic approach. Therefore, the difference in KSD between these two methods is a result of the batch size: taking computation time into account, MAMBA-FSSD finds a batch size of \(1\%\) to be optimal, whereas the heuristic method does not learn the batch size, which is fixed at \(10\%\). Ignoring computation time, a larger batch size is expected to produce a better posterior approximation. However, it is interesting to note that for the five out of six samplers where MAMBA performs best in terms of KSD, MAMBA chooses an optimal batch size of \(1\%\).

For this simulated data example with only 1 million data points, we can compare the posterior accuracy of the SGMCMC algorithms against the ground truth using NumPyro’s (Bingham et al. 2018; Phan et al. 2019) implementation of NUTS (Hoffman and Gelman 2014) on the full dataset for 20K iterations (after a burn-in of 1K iterations). We then calculate the relative error in the posterior standard deviation for each sampler: \(\xi ({\hat{\sigma }}):= \Vert {\hat{\sigma }}-\sigma _{\text {NUTS}}\Vert _2 / \Vert \sigma _{\text {NUTS}}\Vert _2\). The results are given in Table 1, and further results, including predictive accuracy on a test dataset and the number of samples obtained within the time budget, are given in Table 6 of Appendix B.1. For this comparison, each tuned sampler was run for 10 s.
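The relative error \(\xi(\hat\sigma)\) is a one-liner once per-coordinate posterior standard deviations are available. In the sketch, `sigma_ref` stands in for the NUTS estimates and `posterior_sd` is a hypothetical helper computing sample standard deviations from post-burn-in draws.

```python
import math
import statistics

def posterior_sd(samples):
    """Per-coordinate sample standard deviation of post-burn-in draws (hypothetical helper)."""
    return [statistics.stdev(coord) for coord in zip(*samples)]

def relative_sd_error(sigma_hat, sigma_ref):
    """xi(sigma_hat) = ||sigma_hat - sigma_ref||_2 / ||sigma_ref||_2."""
    return math.dist(sigma_hat, sigma_ref) / math.hypot(*sigma_ref)

sigma_ref = [1.0, 1.0]                                     # stand-in for the NUTS estimates
sigma_hat = posterior_sd([[0.0, 1.0], [2.0, 3.0], [1.0, 2.0]])
print(relative_sd_error(sigma_hat, sigma_ref))
print(relative_sd_error([1.1, 0.9], sigma_ref))            # a 10% error in each coordinate
```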

We find that the MAMBA-optimized samplers perform among the best in terms of KSD. As a result, the Monte Carlo estimates of the posterior standard deviations are generally accurate. As described above, we also run grid search with KSD as the metric and tune the step size as well as the batch size. We find that although grid search with KSD gives results that are comparable to MAMBA-KSD, it is slower: grid search with KSD was between 1.2 and 2.2 times slower than MAMBA-KSD. As a result, we do not use this method for the models in the following sections, where this computational cost would only increase.

Furthermore, when tuning SGHMC and SGHMC-CV, we tune three hyperparameters using MAMBA (step size, batch size, and number of leapfrog steps), and two hyperparameters using grid search (step size and number of leapfrog steps). We find that although MAMBA tunes more hyperparameters, it finds optimal hyperparameters approximately 3x faster. As the cost of grid search grows exponentially with the number of hyperparameters, we expect this gap to widen when tuning more hyperparameters.

Table 2 Probabilistic matrix factorization. For each tuning method and each sampler we report KSD and the relative error of the standard deviation estimates. In bold are the best results for a given sampler and metric
Fig. 5

ECE curves for the six samplers applied to the Bayesian neural network model

4.2 Probabilistic matrix factorization

We consider the probabilistic matrix factorization model (Salakhutdinov and Mnih 2008) on the MovieLens dataset (Harper and Konstan 2015), which contains 100K ratings for 1682 movies from 943 users (see Appendix B.2.1 for model details). We optimize the hyperparameters for six samplers: SGLD, SGLD-CV, SGHMC, SGHMC-CV, SGNHT, and SGNHT-CV.

To tune these samplers we use a similar setup as for logistic regression and use (i) MAMBA-KSD, (ii) grid search with log-loss, and (iii) the heuristic approach. Details are given in Appendix B.2.
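The two hyperparameters tuned for every sampler, the step size h and the subsample size n, enter the update as in the SGLD step sketched below. This is a minimal NumPy sketch under stated assumptions: it uses one common SGLD convention (drift \(\tfrac{h}{2}\hat{\nabla}\log\pi\) plus Gaussian noise with variance h), a toy Gaussian-mean model in place of the matrix factorization model, and hypothetical function names; it is not the implementation used in our experiments.

```python
import numpy as np


def sgld_step(theta, data, h, n, grad_log_prior, grad_log_lik, rng):
    """One SGLD update with step size h and a subsample of n data points.

    grad_log_lik(theta, x) must return per-datum gradients of shape (n, dim).
    """
    N = data.shape[0]
    idx = rng.choice(N, size=n, replace=False)
    # Unbiased estimate of the full-data log-posterior gradient
    grad = grad_log_prior(theta) + (N / n) * grad_log_lik(theta, data[idx]).sum(axis=0)
    # Euler-Maruyama step of the Langevin diffusion with injected Gaussian noise
    return theta + 0.5 * h * grad + np.sqrt(h) * rng.standard_normal(theta.shape)


# Toy model (assumption): x_i ~ N(theta, 1) with prior theta ~ N(0, 10^2)
def grad_log_prior(th):
    return -th / 100.0


def grad_log_lik(th, x):
    return x - th  # d/dtheta log N(x | theta, 1), one row per datum


rng = np.random.default_rng(0)
data = rng.normal(2.0, 1.0, size=(1000, 1))
theta = np.zeros(1)
chain = []
for _ in range(3000):
    theta = sgld_step(theta, data, h=1e-4, n=100, grad_log_prior=grad_log_prior,
                      grad_log_lik=grad_log_lik, rng=rng)
    chain.append(theta.copy())
chain = np.array(chain)[1000:]  # discard burn-in
print(chain.mean(), data.mean())
```

A larger h moves faster but adds discretization bias, and a smaller n makes each step cheaper but noisier; this is the bias-variance trade-off the tuning methods navigate within a fixed time budget.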

From Fig. 4 we can see that the samplers tuned using MAMBA tend to outperform those tuned with the other two methods. We also test the quality of the posterior samples against NumPyro’s (Phan et al. 2019; Bingham et al. 2018) implementation of NUTS (Hoffman and Gelman 2014), which produces 10K samples after 1K burn-in samples. This state-of-the-art sampler obtains high-quality samples but is significantly more computationally expensive, taking around six hours on a laptop CPU. We treat the posterior standard deviations estimated from these samples as the ground truth. We run each SGMCMC sampler for 20 s, estimate the posterior standard deviation after removing the burn-in, and report the relative errors and KSD in Table 2 (further results are given in Table 8 in Appendix B.2). MAMBA consistently identifies hyperparameters that give the lowest KSD, but for some samplers the heuristic approach gives a lower error in the estimated standard deviation. This could be due to the random realisation of the SGMCMC chain. Moreover, although the error in the standard deviation is fast to compute, it is a less informative metric than KSD: it measures only the accuracy of the second moment, whereas KSD measures the quality of the full distribution.

Moreover, as with logistic regression in the previous section, tuning SGHMC and SGHMC-CV using MAMBA-KSD is approximately 2x faster than grid search with log-loss. This supports our expectation that MAMBA scales better than grid search when many hyperparameters are tuned.

4.3 Bayesian neural network

In this section we consider a feedforward Bayesian neural network with two hidden layers on the MNIST dataset (LeCun and Cortes 2010) (see Appendix B.3.1 for details). Here we tune six samplers: SGLD, SGLD-CV, SGHMC, SGHMC-CV, SGNHT, and SGNHT-CV.

For this example, as with the previous two, we tune these samplers using (i) MAMBA-KSD, (ii) grid search with log-loss, and (iii) the heuristic approach. Unlike in the previous examples, however, we evaluate the tuning approaches using the expected calibration error (ECE) and the maximum calibration error (MCE); the ECE curves are plotted in Fig. 5. The samplers tuned using MAMBA tend to outperform the other approaches in terms of ECE. We also assess the performance of the MAMBA-optimized samplers over a longer time budget by running them for 300 s, starting from the maximum a posteriori value. We then remove the visible burn-in and calculate the ECE and MCE to compare the quality of the posterior samples. The results are reported in Table 3, where ECE and MCE are given as percentages (lower is better).
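For reference, ECE and MCE can be computed from predictive class probabilities (for example, posterior predictive probabilities averaged over the SGMCMC samples) as in this NumPy sketch. It assumes equal-width confidence bins with 10 bins by default; the binning scheme is an assumption and may differ from the one behind Table 3.

```python
import numpy as np


def calibration_errors(probs, labels, n_bins=10):
    """Return (ECE, MCE) in percent for predictive probabilities of shape (n, classes)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    confidences = probs.max(axis=1)               # confidence of the predicted class
    correct = (probs.argmax(axis=1) == labels).astype(float)
    # Interior edges of equal-width bins on [0, 1]; right=True makes the bins right-closed
    inner_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    bin_ids = np.digitize(confidences, inner_edges, right=True)  # values in 0..n_bins-1
    ece, mce = 0.0, 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue
        gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
        ece += in_bin.mean() * gap                # weight by fraction of points in the bin
        mce = max(mce, gap)
    return 100.0 * ece, 100.0 * mce


# Usage: two confident-but-half-wrong predictions and two under-confident correct ones
probs = [[0.6, 0.4], [0.6, 0.4], [1.0, 0.0], [1.0, 0.0]]
print(calibration_errors(probs, [0, 0, 0, 1]))
```

ECE averages the confidence-accuracy gap over bins weighted by how many predictions fall in each, while MCE reports the worst bin, which is why the two can rank samplers differently.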

Table 3 Bayesian neural network. For each tuning method and each sampler we report the ECE and MCE (as percentages). In bold are the best results for a given sampler and metric

Overall, the results in Table 3 show that the MAMBA-optimized samplers tend to perform best in terms of ECE and MCE, and when they are not the best, their results are very close to those of the best-performing method. For all samplers, MAMBA finds an optimal batch size of \(1\%\) of the data, which is ten times smaller than the batch size used by the other methods and therefore gives a faster yet highly accurate algorithm. For SGNHT, both MAMBA and grid search found a step size that was slightly too large (\(\log _{10}(h)=-4.5\) and \(\log _{10}(h)=-4\), respectively), which caused the sampler to lose stability over longer chains; only the sampler tuned using the heuristic method remained stable. We therefore re-ran these two tuning methods on a grid of smaller step sizes, \(\log _{10}(h) \in \{-5, -5.5, -6, -6.5, -7, -7.5\}\). With this grid, both methods found a stable step size (\(\log _{10}(h)=-5\)), so this slight decrease in step size was enough to make the sampler stable. We note that there exist samplers with more stable numerical schemes, such as the BADODAB sampler, which solves the same diffusion as SGNHT but with a more stable splitting method (Leimkuhler and Xiaocheng 2016). Such samplers might be easier to tune with MAMBA or grid search.

Finally, as with the models in the previous two sections, tuning SGHMC and SGHMC-CV using MAMBA is faster than using grid search, here by a factor of 4 to 6, which is again consistent with MAMBA scaling better than grid search as the number of tuned hyperparameters grows.

5 Discussion and future work

5.1 Final remarks

In this paper we have proposed a multi-armed bandit approach to estimate the hyperparameters for any stochastic gradient MCMC algorithm. Our approach optimizes the hyperparameters to produce posterior samples which accurately approximate the posterior distribution within a fixed time budget. We use Stein-based discrepancies as natural metrics to assess the quality of the posterior approximation.
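As an illustration of a Stein-based metric, the sketch below computes a V-statistic estimate of the KSD with an inverse multiquadric (IMQ) kernel \(k(x,y)=(c^2+\Vert x-y\Vert_2^2)^{\beta}\). The kernel and its defaults c = 1, β = -1/2 are assumptions (a common default), not necessarily the settings used in our experiments. Only the samples and the gradients of the log-posterior at those samples ("scores") are needed; the O(m²) pairwise cost is what makes the KSD expensive for long chains.

```python
import numpy as np


def imq_ksd(samples, scores, c=1.0, beta=-0.5):
    """KSD (V-statistic) with IMQ kernel k(x, y) = (c^2 + ||x - y||^2)^beta.

    samples: (m, d) draws; scores: (m, d) gradients of the log target at the draws.
    """
    x = np.asarray(samples, dtype=float)
    s = np.asarray(scores, dtype=float)
    d = x.shape[1]
    r = x[:, None, :] - x[None, :, :]            # r_ij = x_i - x_j, shape (m, m, d)
    r2 = np.sum(r ** 2, axis=-1)
    u = c ** 2 + r2
    # Langevin Stein kernel k_p(x_i, x_j), assembled term by term
    t1 = (s @ s.T) * u ** beta                                        # s_i . s_j k
    t2 = -2 * beta * u ** (beta - 1) * np.einsum("id,ijd->ij", s, r)  # s_i . grad_y k
    t3 = 2 * beta * u ** (beta - 1) * np.einsum("jd,ijd->ij", s, r)   # s_j . grad_x k
    t4 = (-4 * beta * (beta - 1) * u ** (beta - 2) * r2
          - 2 * beta * d * u ** (beta - 1))                           # trace(grad_x grad_y k)
    # Guard against tiny negative values from floating-point round-off
    return np.sqrt(max(np.mean(t1 + t2 + t3 + t4), 0.0))


# Usage: draws from N(0, I) scored against the right and a wrong target
rng = np.random.default_rng(1)
draws = rng.standard_normal((200, 2))
ksd_good = imq_ksd(draws, -draws)            # target N(0, I): score is -x
ksd_bad = imq_ksd(draws, -(draws - 3.0))     # target N(3, I)
print(ksd_good, ksd_bad)
```

A smaller value indicates that the samples match the target more closely, so lower is better when the KSD is used as a tuning metric.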

The generality of the MAMBA algorithm means that other metrics, such as predictive accuracy, can easily be used in place of a Stein-based metric. We have also compared MAMBA with a grid search approach using the KSD and found that, although the results are comparable, MAMBA finds the optimal hyperparameters faster (1.2 to 2.2 times faster for logistic regression).

When tuning SGHMC and SGHMC-CV (for all three models), we tune three hyperparameters using MAMBA (step size, batch size, and number of leapfrog steps) and two using grid search (step size and number of leapfrog steps). Although MAMBA tunes more hyperparameters, it finds optimal hyperparameters 2x to 6x faster than grid search. This suggests that the speed gap between the two methods widens as the number of hyperparameters to tune increases.

Although not explored in this paper, MAMBA can also be applied beyond the stochastic gradient MCMC setting, directly to standard MCMC algorithms such as Hamiltonian Monte Carlo, to estimate their hyperparameters. A variety of metrics could be used in this context, including the KSD and the absolute difference between the average and optimal acceptance rates. However, existing algorithms such as adaptive MCMC (Andrieu and Thoms 2008; Vihola 2012) may be more efficient for standard MCMC: the fixed computational budget that makes MAMBA useful for tuning batch sizes in SGMCMC is less necessary when there is no inherent bias-variance trade-off.
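For instance, the acceptance-rate metric could be written as a function of the per-iteration Metropolis-Hastings acceptance probabilities of an HMC run. The target of 0.65 is an assumed value, close to the figure often quoted as optimal for HMC; any other target could be substituted.

```python
import numpy as np


def acceptance_gap(accept_probs, target=0.65):
    """|mean Metropolis-Hastings acceptance probability - target|; lower is better."""
    return abs(float(np.mean(accept_probs)) - target)


print(acceptance_gap([0.9, 0.8, 0.95]))  # a chain accepting too often, e.g. a step size that is too small
```

Because it returns a single number where lower is better, such a function could be used as the tuning metric in place of a Stein discrepancy.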

Finally, in this paper we performed a systematic study of different SGMCMC tuning methods for various models and samplers, which to our knowledge is the first rigorous comparison of these methods. While these alternative approaches can work well, they are only able to tune the step size parameter; unlike MAMBA, they do not tune the batch size or other SGMCMC hyperparameters, such as the number of leapfrog steps.

5.2 Future work

A limitation of this method is that computing the KSD can be expensive when there are many posterior samples. One solution we explored in this paper is to use the FSSD, a linear-time metric. For the KSD, we significantly lowered the cost by thinning the Markov chain, but the KSD remains expensive to compute. The KSD also suffers from the curse of dimensionality (Gong et al. 2020), although our results show that it performed well even for our two high-dimensional problems. Further work should therefore explore alternative discrepancy metrics that scale well in both sample size and dimension; for example, scalable alternatives to the KSD, such as the sliced KSD (Gong et al. 2020), could be appropriate for very high-dimensional problems.
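To show why the FSSD is linear in the number of samples, the sketch below computes the unbiased estimate of FSSD² with a Gaussian kernel and a fixed set of J test locations: each sample contributes a single dJ-dimensional feature vector, so the cost is O(mdJ) rather than the O(m²) pairwise cost of the KSD. The Gaussian kernel, the fixed bandwidth and the fixed (rather than optimized) test locations are simplifying assumptions.

```python
import numpy as np


def fssd_squared(samples, scores, test_locations, bandwidth=1.0):
    """Unbiased FSSD^2 estimate with Gaussian kernel and fixed test locations.

    samples, scores: (m, d) draws and log-target gradients; test_locations: (J, d).
    """
    x = np.asarray(samples, dtype=float)
    s = np.asarray(scores, dtype=float)
    v = np.asarray(test_locations, dtype=float)
    m, d = x.shape
    J = v.shape[0]
    diff = x[:, None, :] - v[None, :, :]                          # (m, J, d)
    k = np.exp(-np.sum(diff ** 2, axis=-1) / (2 * bandwidth ** 2))  # (m, J)
    grad_k = -diff / bandwidth ** 2 * k[:, :, None]               # gradient of k in x
    xi = s[:, None, :] * k[:, :, None] + grad_k                   # Stein feature per location
    tau = xi.reshape(m, J * d) / np.sqrt(d * J)                   # one feature vector per sample
    total = tau.sum(axis=0)
    # Sum over pairs i != j of tau_i . tau_j, divided by m(m - 1)
    return (total @ total - np.sum(tau * tau)) / (m * (m - 1))


# Usage: draws from N(0, I) scored against the right and a wrong target
rng = np.random.default_rng(2)
draws = rng.standard_normal((500, 2))
locs = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]])
good = fssd_squared(draws, -draws, locs)
bad = fssd_squared(draws, -(draws - 3.0), locs)
print(good, bad)
```

Because the estimate is unbiased, it can be slightly negative when the samples match the target well.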

6 Supplementary information

The supplementary material for this article is available online. It contains a proof of Theorem 1, further details of the experimental settings, and details of the various SGMCMC algorithms considered.