
On Large Batch Training and Sharp Minima: A Fokker–Planck Perspective

Abstract

We study the statistical properties of the dynamic trajectory of stochastic gradient descent (SGD). We approximate mini-batch SGD and momentum SGD by stochastic differential equations (SDEs). We exploit the continuous formulation of the SDEs, together with the theory of Fokker–Planck equations, to develop new results on the escaping phenomenon and its relationship with large batch sizes and sharp minima. In particular, we find that the stochastic process solution tends to converge to flatter minima regardless of the batch size in the asymptotic regime. However, the convergence rate is rigorously proven to depend on the batch size. These results are validated empirically with various datasets and models.



References

  1. Amodei D, Ananthanarayanan S, Anubhai R, Bai J, Battenberg E, Case C, Casper J, Catanzaro B, Cheng Q, Chen G, Chen J (2016) Deep speech 2: end-to-end speech recognition in English and Mandarin. In: International conference on machine learning (ICML), pp 173–182

  2. An J, Lu J, Ying L (2019) Stochastic modified equations for the asynchronous stochastic gradient descent. Inf Inference. https://doi.org/10.1093/imaiai/iaz030

  3. Berglund N (2013) Kramers’ law: validity, derivations and generalisations. Markov Process Relat Fields 19(3):459–490

  4. Bottou L, Curtis FE, Nocedal J (2018) Optimization methods for large-scale machine learning. SIAM Rev 60(2):223–311

  5. Bovier A, Eckhoff M, Gayrard V, Klein M (2004) Metastability in reversible diffusion processes I: sharp asymptotics for capacities and exit times. J Eur Math Soc 6(4):399–424

  6. Bovier A, Gayrard V, Klein M (2005) Metastability in reversible diffusion processes II: precise asymptotics for small eigenvalues. J Eur Math Soc 7(1):69–99

  7. Chaudhari P, Oberman A, Osher S, Soatto S, Carlier G (2017) Deep relaxation: partial differential equations for optimizing deep neural networks. In: International conference on learning representations (ICLR)

  8. Chaudhari P, Soatto S (2018) Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks. In: International conference on learning representations (ICLR)

  9. Dinh L, Pascanu R, Bengio S, Bengio Y (2017) Sharp minima can generalize for deep nets. In: International conference on machine learning (ICML)

  10. Evans LC (2010) Partial differential equations, vol 19. American Mathematical Society, Providence

  11. Goyal P, Dollár P, Girshick R, Noordhuis P, Wesolowski L, Kyrola A, Tulloch A, Jia Y, He K (2017) Accurate, large minibatch SGD: training ImageNet in 1 hour. arXiv:1706.02677

  12. He K, Zhang X, Ren S, Sun J (2016) Deep residual learning for image recognition. In: Proceedings of the IEEE conference on computer vision and pattern recognition (CVPR), pp 770–778

  13. Heskes TM, Kappen B (1993) On-line learning processes in artificial neural networks. In: Mathematical foundations of neural networks. Elsevier, Amsterdam, pp 199–233

  14. Hoffer E, Hubara I, Soudry D (2017) Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In: Advances in neural information processing systems (NIPS), pp 1729–1739

  15. Ioffe S, Szegedy C (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In: International conference on machine learning (ICML)

  16. Jastrzebski S, Kenton Z, Arpit D, Ballas N, Fischer A, Bengio Y, Storkey A (2017) Three factors influencing minima in SGD. arXiv:1711.04623

  17. Keskar NS, Mudigere D, Nocedal J, Smelyanskiy M, Tang PTP (2017) On large-batch training for deep learning: generalization gap and sharp minima. In: International conference on learning representations (ICLR)

  18. Kolpas A, Moehlis J, Kevrekidis IG (2007) Coarse-grained analysis of stochasticity-induced switching between collective motion states. Proc Natl Acad Sci 104(14):5931–5935

  19. Krizhevsky A, Sutskever I, Hinton GE (2012) ImageNet classification with deep convolutional neural networks. In: Advances in neural information processing systems (NIPS), pp 1097–1105

  20. LeCun Y, Cortes C, Burges CJC (1998) The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist

  21. Li Q, Tai C, Weinan E (2017) Stochastic modified equations and adaptive stochastic gradient algorithms. In: International conference on machine learning (ICML)

  22. Mandt S, Hoffman MD, Blei DM (2017) Stochastic gradient descent as approximate Bayesian inference. J Mach Learn Res 18:1–35

  23. Nesterov Y (2013) Introductory lectures on convex optimization: a basic course, vol 87. Springer, Berlin

  24. Pavliotis GA (2014) Stochastic processes and applications: diffusion processes, the Fokker–Planck and Langevin equations. Springer, Berlin

  25. Pólya G (1945) Remarks on computing the probability integral in one and two dimensions. In: Proceedings of the 1st Berkeley symposium on mathematical statistics and probability

  26. Qian N (1999) On the momentum term in gradient descent learning algorithms. Neural Netw 12(1):145–151

  27. Simonyan K, Zisserman A (2015) Very deep convolutional networks for large-scale image recognition. In: International conference on learning representations (ICLR)

  28. Smith SL, Le QV (2018) A Bayesian perspective on generalization and stochastic gradient descent. In: International conference on learning representations (ICLR)

  29. Sutskever I, Martens J, Dahl G, Hinton G (2013) On the importance of initialization and momentum in deep learning. In: International conference on machine learning (ICML), pp 1139–1147

  30. Villani C (2009) Hypocoercivity. Mem Am Math Soc 202(950)

  31. Wu L, Zhu Z (2017) Towards understanding generalization of deep learning: perspective of loss landscapes. In: International conference on machine learning (ICML) workshop on principled approaches to deep learning


Author information

Corresponding author

Correspondence to Xiaowu Dai.

Additional information

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

This article is part of the topical collection “Advances in Deep Learning” guest edited by David Banks, Ernest Fokoué, and Hailin Sang.

Appendices

Appendix A: Proofs for Section 2

Appendix A.1: Mean and Variance for Random Error Vector

By the mean value theorem with some \(\tau (h)\in (0,h)\),

$$\begin{aligned} \begin{aligned} \nabla L(\mathbf{w})&= \frac{\mathrm{{d}}}{\mathrm{{d}}\mathbf{w}}\mathbb {E}[L_n(\mathbf{w})] \\&= \lim _{h\rightarrow 0}\frac{1}{h}\left\{ \mathbb {E}[L_n(\mathbf{w}+h)] - \mathbb {E}[L_n(\mathbf{w})]\right\} \\&= \lim _{h\rightarrow 0}\mathbb {E}\left\{ \frac{L_n(\mathbf{w}+h) - L_n(\mathbf{w})}{h}\right\} = \lim _{h\rightarrow 0}\mathbb {E}\left\{ \nabla L_n(\mathbf{w}+\tau (h))\right\} . \end{aligned} \end{aligned}$$

By the continuity of \(\nabla L_n\) and the dominated convergence theorem,

$$\begin{aligned} \lim _{h\rightarrow 0}\mathbb {E}\left\{ \nabla L_n(\mathbf{w}+\tau (h))\right\} = \mathbb {E}\left\{ \lim _{h\rightarrow 0}\nabla L_n(\mathbf{w}+\tau (h))\right\} = \mathbb {E}\left\{ \nabla L_n(\mathbf{w})\right\} . \end{aligned}$$

Hence, \(\varvec{\epsilon }_k\) has mean 0. Since the mini-batch \(B_k\) is sampled independently and uniformly, we have \(\text {Var}[\varvec{\epsilon }_k] = \varvec{\sigma }^2(\mathbf{w})\) as desired.

We remark that a different view of the sampling distribution has been adopted in the literature, for example, Jastrzebski et al. [16] and Li et al. [21], where the expectation and variance are taken with respect to the sampling distribution of drawing the mini-batch \(B_k\) from \(\{1,\ldots ,N\}\). In contrast, we use the sampling distribution with respect to the joint distribution of the underlying population, since our interest is the risk function \(L(\cdot )\) instead of the sample average loss

$$\begin{aligned} \frac{1}{N}[L_1(\cdot )+\cdots +L_N(\cdot )], \end{aligned}$$

and we regard the training data as only a subset of the underlying population.
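To make this population view concrete, the following minimal sketch (not the paper's code; the linear-regression population, batch sizes, and Monte Carlo sizes are arbitrary illustrative choices) draws fresh mini-batches from the population and checks that the mini-batch gradient error has mean approximately zero, with a covariance of the order of the per-example covariance divided by the batch size M; how this error is scaled into \(\varvec{\epsilon }_k\) follows the definition in Sect. 2.

```python
import numpy as np

# Minimal sketch: mini-batch gradient noise for the squared loss under the
# population view above.  The population (x ~ N(0, I), y = x.w0 + noise), the
# batch sizes, and the Monte Carlo sizes are illustrative choices.
rng = np.random.default_rng(0)
d = 5
w0 = np.ones(d)                              # "true" parameter of the population
w = rng.normal(size=d)                       # point at which gradients are taken
full_grad = 2.0 * (w - w0)                   # nabla L(w), since E[x x^T] = I here

def minibatch_grad(M):
    """Gradient of the mini-batch average of the squared loss (w.x - y)^2."""
    X = rng.normal(size=(M, d))
    y = X @ w0 + rng.normal(size=M)
    return 2.0 * X.T @ (X @ w - y) / M

for M in (8, 64, 512):
    eps = np.array([minibatch_grad(M) - full_grad for _ in range(2000)])
    # mean error ~ 0; average variance shrinks roughly like 1/M
    print(M, np.linalg.norm(eps.mean(axis=0)), eps.var(axis=0).mean())
```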

Appendix A.2: Proof of Lemma 1

We first consider the special case in which \(\beta (\mathbf{w}) \equiv \beta\) is a constant and derive the Fokker–Planck equation by following Kolpas et al. [18]. If \(\mathbf{W}(t)=W(t)\in \mathbb R\), then W(t) is a Markov process and the Chapman–Kolmogorov equation gives the conditional probability density function for any \(t_1\le t_2\le t_3\),

$$\begin{aligned} p\left( W(t_3)|W(t_1)\right) = \int _{-\infty }^{+\infty } p\left( W(t_3)|W(t_2)=w\right) p\left( W(t_2)=w|W(t_1)\right) \mathrm{{d}}w. \end{aligned}$$

Denote the integral

$$\begin{aligned} I(h) = \int _{-\infty }^{+\infty }h(w)\partial _t p(w,t|W)\mathrm{{d}}w, \end{aligned}$$
(11)

where h(w) is a smooth function with compact support. Observe that

$$\begin{aligned} \int _{-\infty }^{+\infty }h(w)\partial _t p(w,t|W)\mathrm{{d}}w = \lim _{\Delta t\rightarrow 0 }\int _{-\infty }^{+\infty } h(w)\left( \frac{p(w,t+\Delta t|W) - p(w,t|W)}{\Delta t}\right) \mathrm{{d}}w. \end{aligned}$$

Let Z be an intermediate point between w and W. Applying the Chapman–Kolmogorov identity to the right-hand side yields

$$\begin{aligned} \lim _{\Delta t\rightarrow 0 }\frac{1}{\Delta t}\left( \int _{-\infty }^{+\infty }h(w)\int _{-\infty }^{+\infty }p(w,\Delta t|Z)p(Z,t|W)\mathrm{{d}}Z\mathrm{{d}}w-\int _{-\infty }^{+\infty }h(w)p(w,t|W)\mathrm{{d}}w\right) . \end{aligned}$$

By changing the order of integrations in the first term and letting w approach Z in the second term, we obtain that

$$\begin{aligned} \lim _{\Delta t\rightarrow 0}\frac{1}{\Delta t}\left( \int _{-\infty }^{+\infty }p(Z,t|W)\int _{-\infty }^{+\infty }p(w,\Delta t|Z)(h(w)-h(Z))\mathrm{{d}}w\mathrm{{d}}Z\right) . \end{aligned}$$

Expanding h(w) in a Taylor series about Z, we can write the above integral as

$$\begin{aligned} \lim _{\Delta t\rightarrow 0}\frac{1}{\Delta t}\left( \int _{-\infty }^{+\infty }p(Z,t|W)\int _{-\infty }^{+\infty }p(w,\Delta t|Z)\sum _{n=1}^{\infty }h^{(n)}(Z)\frac{(w-Z)^n}{n!}\right) \mathrm{{d}}w\mathrm{{d}}Z. \end{aligned}$$

Now we define the function

$$\begin{aligned} D^{(n)}(Z)=\frac{1}{n!}\lim _{\Delta t\rightarrow 0}\frac{1}{\Delta t}\int _{-\infty }^{+\infty }p(w,\Delta t|Z)(w-Z)^n\mathrm{{d}}w. \end{aligned}$$

We can write the integral I(h) defined in (11) as

$$\begin{aligned} \int _{-\infty }^{+\infty }h(w)\partial _t p(w,t|W)\mathrm{{d}}w = \int _{-\infty }^{+\infty }p(Z,t|W)\sum _{n=1}^{\infty }D^{(n)}(Z)h^{(n)}(Z)\mathrm{{d}}Z. \end{aligned}$$

Integrating by parts n times gives

$$\begin{aligned} \partial _t p(w,t) = \sum _{n=1}^{\infty }-\frac{\partial ^n}{\partial Z^n} \left[ D^{(n)}(Z)p(Z,t|W)\right] . \end{aligned}$$

Let \(D^{(1)}(w) = -\nabla L(w)\), \(D^{(2)}(w) = -\gamma (t)\beta /[2M(t)]\) and \(D^{(n)}(w) = 0\) for all \(n\ge 3\). Then, the above equation yields

$$\begin{aligned} \partial _t p(w,t) = \frac{\partial }{\partial w}\left[ \nabla L(w)p(w,t)\right] + \frac{\partial ^2}{\partial w^2}\left[ \frac{\gamma (t)\beta }{2M(t)}p(w,t)\right] , \end{aligned}$$

which is the Fokker–Planck equation in one variable. For the multidimensional case that \(\mathbf{W}=(W_1,W_2,\ldots ,W_p)\in \mathbb R^p\), we similarly generalize the above procedure to get

$$\begin{aligned} \begin{aligned} \partial _t p(\mathbf{w},t)&= \sum _{i=1}^p\frac{\partial }{\partial w_i}\left[ \nabla L(\mathbf{w})p(\mathbf{w},t)\right] +\sum _{i=1}^p \frac{\partial ^2}{\partial w_i^2}\left[ \frac{\gamma (t)\beta }{2M(t)}p(\mathbf{w},t)\right] \\&= \nabla \cdot \left( \nabla L(\mathbf{w})p +\frac{\gamma (t)\beta }{2M(t)}\nabla p\right) . \end{aligned} \end{aligned}$$
(12)

Since \(\mathbf{W}(0)=\mathbf{w}_0\), the initial condition is \(p(\mathbf{w},0)=\delta (\mathbf{w}-\mathbf{w}_0)\). This completes the derivation of the Fokker–Planck equation for constant \(\beta (\mathbf{w}) =\beta\).

To derive (5) for a general \(\beta (\mathbf{w})\), we simply apply (12) together with the fact that

$$\begin{aligned} \nabla \left[ \frac{\gamma (t)\beta (\mathbf{w})}{2M(t)}p\right] = \nabla \left[ \frac{\gamma (t)\beta (\mathbf{w})}{2M(t)}\right] p + \frac{\gamma (t)\beta (\mathbf{w})}{2M(t)}\nabla p. \end{aligned}$$

This completes the proof.
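For intuition, the constant-coefficient, one-dimensional case of (12) is the Fokker–Planck equation of the overdamped Langevin dynamics \(\mathrm{{d}}W(t) = -\nabla L(W(t))\,\mathrm{{d}}t + \sqrt{\gamma \beta /M}\,\mathrm{{d}}B(t)\). The following minimal simulation sketch (illustrative only; the double-well loss and all constants are arbitrary choices, not the paper's experiments) compares the long-run histogram of this dynamics with the density proportional to \(\exp (-2ML(w)/(\gamma \beta ))\), anticipating the stationary solution in Lemma 3.

```python
import numpy as np

# Sketch: 1-D overdamped Langevin dynamics whose law solves the constant-
# coefficient Fokker-Planck equation above.  gamma, beta, M, the double-well
# loss, and all discretization constants are illustrative choices.
rng = np.random.default_rng(1)
gamma, beta, M = 2.0, 1.0, 4
D = gamma * beta / (2.0 * M)                 # diffusion coefficient in (12)

L     = lambda w: 0.25 * w**4 - 0.5 * w**2   # double-well risk
gradL = lambda w: w**3 - w

dt, n_steps, n_chains = 1e-3, 50_000, 4096
w = rng.normal(size=n_chains)
for _ in range(n_steps):
    w += -gradL(w) * dt + np.sqrt(2.0 * D * dt) * rng.normal(size=n_chains)

# Compare the empirical histogram with p_infty(w) proportional to exp(-L(w)/D),
# i.e. exp(-2 M L(w) / (gamma * beta)) as in Lemma 3.
edges = np.linspace(-2.5, 2.5, 51)
mid = 0.5 * (edges[:-1] + edges[1:])
p = np.exp(-L(mid) / D)
p /= p.sum() * (edges[1] - edges[0])
hist, _ = np.histogram(w, bins=edges, density=True)
print(np.max(np.abs(hist - p)))              # rough agreement is expected
```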

Appendix B: Proofs for Section 3

Appendix B.1: Discussion on Main Assumptions (A.1)–(A.3)

We show that Assumptions (A.1)–(A.3) hold for the squared loss and the regularized mean cross-entropy loss. Denote by \(\{({\mathbf {x}}_n,y_n),1\le n\le N\}\) the set of training data. Without loss of generality, let \(\text {Var}[y_n|{\mathbf {x}}_n] = 1\). First, we consider the squared loss with the corresponding risk function

$$\begin{aligned} L(\mathbf{w}) = \big (\mathbf{w}-\mathbf{w}^0\big )^\top \mathbb {E}\big [{\mathbf {x}}_n{\mathbf {x}}_n^\top \big ]\big (\mathbf{w}-\mathbf{w}^0\big )+1, \end{aligned}$$

where \(\mathbf{w}^0\) is the true parameter vector. Since \(\text {Var}[\nabla L_n(\mathbf{w})]\equiv \varvec{\sigma }^2(\mathbf{w})\) is positive definite, we have

$$\begin{aligned} \begin{aligned} \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty }L(\mathbf{w})&\ge \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty } \lambda _{\min }\big \{\mathbb {E}\big [{\mathbf {x}}_n{\mathbf {x}}_n^\top \big ]\big \}\Vert \mathbf{w}-\mathbf{w}^0\Vert ^2 +1\\&\ge \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty } \lambda _{\min }\big \{\mathbb {E}\big [{\mathbf {x}}_n{\mathbf {x}}_n^\top \big ]\big \}\big [\Vert \mathbf{w}\Vert ^2/2-\Vert \mathbf{w}^0\Vert ^2/2\big ] +1 = +\infty , \end{aligned} \end{aligned}$$
(13)

where \(\lambda _{\min }\{\cdot \}\) denotes the minimal eigenvalue. Note that

$$\begin{aligned} \begin{aligned} \int \mathrm{{e}}^{-L(\mathbf{w})} d\mathbf{w}&= \int \exp \left( -(\mathbf{w}-\mathbf{w}^0)^\top \mathbb {E}[{\mathbf {x}}_n{\mathbf {x}}_n^\top ](\mathbf{w}-\mathbf{w}^0)-1\right) d\mathbf{w}\\&\le \int \exp \left( - \lambda _{\min }\{\mathbb {E}[{\mathbf {x}}_n{\mathbf {x}}_n^\top ]\}[\Vert \mathbf{w}\Vert ^2/2-\Vert \mathbf{w}^0\Vert ^2/2]-1\right) d\mathbf{w}< +\infty . \end{aligned} \end{aligned}$$

Hence, Assumption (A.1) holds. To prove (A.2), note that

$$\begin{aligned} \Vert \nabla L(\mathbf{w})\Vert ^2/2 = 2 \big (\mathbf{w}-\mathbf{w}^0\big )^\top \big \{\mathbb {E}\big [{\mathbf {x}}_n{\mathbf {x}}_n^\top \big ]\big \}^2\big (\mathbf{w}-\mathbf{w}^0\big ), \quad \text {Tr}\big (\nabla ^2 L(\mathbf{w})\big ) = 2\,\text {Tr}\big \{\mathbb {E}\big [{\mathbf {x}}_n{\mathbf {x}}_n^\top \big ]\big \}. \end{aligned}$$

Similar to (13), we can prove that

$$\begin{aligned} \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty } \left\{ \Vert \nabla L(\mathbf{w})\Vert ^2/2 - \text {Tr}(\nabla ^2 L(\mathbf{w}))\right\} = + \infty , \quad \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty } \left\{ \text {Tr}(\nabla ^2 L(\mathbf{w}))/\Vert \nabla L(\mathbf{w})\Vert ^2\right\} = 0. \end{aligned}$$

This finishes the proof for Assumption (A.2). Finally, (A.3) can be shown similarly by following the proof for (A.2) and we omit the details.

Next, we consider the mean cross-entropy loss with the \(l_2\)-penalty for the logistic regression. Without loss of generality, we consider the binary classification:

$$\begin{aligned} L(\mathbf{w}) = \mathbb {E}[-y_n\log \widehat{y}_n -(1-y_n)\log (1-\widehat{y}_n)]+\lambda \Vert \mathbf{w}\Vert ^2 \end{aligned}$$

with \(\widehat{y}_n = (1+\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n})^{-1}\). Note that

$$\begin{aligned} \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty }L(\mathbf{w}) \ge \lim _{\Vert \mathbf{w}\Vert \rightarrow +\infty }\lambda \Vert \mathbf{w}\Vert ^2 = +\infty , \quad \int \mathrm{{e}}^{-L(\mathbf{w})} \mathrm{{d}}\mathbf{w}\le \int \mathrm{{e}}^{-\lambda \Vert \mathbf{w}\Vert ^2}\mathrm{{d}}\mathbf{w}<+\infty , \end{aligned}$$

which proves (A.1). For (A.2), since

$$\begin{aligned} \nabla L(\mathbf{w}) = \mathbb {E}[-{\mathbf {x}}_ny_n + {\mathbf {x}}_n/(1+\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n})]+2\lambda \mathbf{w}, \end{aligned}$$

and

$$\begin{aligned} \text {Tr}(\nabla ^2 L(\mathbf{w})) = \mathbb {E}\left[ \frac{\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n}}{(1+\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n})^2}\text {Tr}({\mathbf {x}}_n{\mathbf {x}}_n^\top )\right] +2\lambda d, \end{aligned}$$

we have

$$\begin{aligned} \Vert \nabla L(\mathbf{w})\Vert ^2/2 - \text {Tr}(\nabla ^2 L(\mathbf{w}))\rightarrow \infty , \quad \text {Tr}(\nabla ^2L(\mathbf{w}))/\Vert \nabla L(\mathbf{w})\Vert ^2\rightarrow 0, \text { as } \Vert \mathbf{w}\Vert \rightarrow \infty . \end{aligned}$$

Similarly, Assumption (A.3) can be verified by following the proof for (A.2).
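The limits in Assumption (A.2) can also be checked numerically. The sketch below (illustrative only; the synthetic population, the penalty \(\lambda\), and the direction of \(\mathbf{w}\) are arbitrary choices) estimates \(\Vert \nabla L(\mathbf{w})\Vert ^2/2 - \text {Tr}(\nabla ^2 L(\mathbf{w}))\) and \(\text {Tr}(\nabla ^2L(\mathbf{w}))/\Vert \nabla L(\mathbf{w})\Vert ^2\) for the regularized mean cross-entropy loss along a ray \(\mathbf{w}= r\mathbf{u}\) with growing r; the first quantity grows without bound and the second decays toward zero.

```python
import numpy as np

# Sketch: numerical check of the limits in Assumption (A.2) for the
# l2-regularized mean cross-entropy loss.  The synthetic population, the
# penalty lam, and the direction u are illustrative choices.
rng = np.random.default_rng(2)
d, lam, n_mc = 10, 0.1, 200_000
X = rng.normal(size=(n_mc, d))                       # x_n ~ N(0, I)
y = rng.binomial(1, 1.0 / (1.0 + np.exp(-X @ np.ones(d))))

def grad_and_hess_trace(w):
    """Monte Carlo estimates of nabla L(w) and Tr(nabla^2 L(w))."""
    s = 1.0 / (1.0 + np.exp(-np.clip(X @ w, -500.0, 500.0)))   # yhat_n
    grad = np.mean((s - y)[:, None] * X, axis=0) + 2.0 * lam * w
    tr_hess = np.mean(s * (1.0 - s) * np.sum(X**2, axis=1)) + 2.0 * lam * d
    return grad, tr_hess

u = rng.normal(size=d)
u /= np.linalg.norm(u)                               # fixed direction for w = r u
for r in (1.0, 10.0, 100.0, 1000.0):
    g, tr = grad_and_hess_trace(r * u)
    # first printed quantity should grow, second should decay toward zero
    print(r, 0.5 * g @ g - tr, tr / (g @ g))
```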

Appendix B.2: Proof of Lemma 3

By Assumption (A.1), the density function \(p_\infty (\mathbf{w})\equiv \kappa \mathrm{{e}}^{-2M(\infty )L(\mathbf{w})/[\gamma (\infty )\beta ]}\) is well-defined. Moreover, \(p_\infty (\mathbf{w})\) satisfies

$$\begin{aligned} \nabla \cdot \left[ \nabla \left( L(\mathbf{w})+\frac{\gamma (\infty )\beta }{2M(\infty )}\right) p_\infty (\mathbf{w}) + \frac{\gamma (\infty )\beta }{2M(\infty )}\nabla p_\infty (\mathbf{w})\right] =0. \end{aligned}$$

Hence, \(p_\infty (\mathbf{w})\) is a stationary solution to the Fokker–Planck equation (5), obtained by setting \(\partial _tp(\mathbf{w},t) =0\).
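For completeness, the verification is the following one-line computation (a restatement of the argument above; the constant inside the gradient in the stationary equation contributes nothing since its gradient vanishes):

$$\begin{aligned} \nabla p_\infty (\mathbf{w}) = -\frac{2M(\infty )}{\gamma (\infty )\beta }\nabla L(\mathbf{w})\, p_\infty (\mathbf{w}) \quad \Longrightarrow \quad \nabla L(\mathbf{w})\,p_\infty (\mathbf{w}) + \frac{\gamma (\infty )\beta }{2M(\infty )}\nabla p_\infty (\mathbf{w}) = 0, \end{aligned}$$

so the probability flux vanishes identically and, in particular, its divergence is zero.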

Appendix B.3: Proof of Theorem 1

Parallel to the notation \(p_\infty (\mathbf{w}) = \kappa \exp (- \frac{2M(\infty )L(\mathbf{w})}{\gamma (\infty )\beta })\) in Lemma 3, we define

$$\begin{aligned} \hat{p}(\mathbf{w},t)\equiv \kappa (t) \exp \left( -\eta (t)L(\mathbf{w})\right) , \end{aligned}$$

where

$$\begin{aligned} \eta (t) \equiv 2M(t)/[\gamma (t)\beta ], \end{aligned}$$
(14)

and \(\kappa (t)\) is a time-dependent normalization factor such that

$$\begin{aligned} \int \hat{p}(\mathbf{w},t)\mathrm{{d}}\mathbf{w}=1. \end{aligned}$$

We can rewrite (5) as

$$\begin{aligned} \partial _tp = \frac{1}{\eta }\nabla _\mathbf{w}\cdot \left( \hat{p}\nabla _\mathbf{w}\left( \frac{p}{\hat{p}}\right) \right) . \end{aligned}$$
(15)

Let

$$\begin{aligned} \delta (t,\mathbf{w})\equiv \frac{\kappa (t)}{\kappa }\exp \left( L(\mathbf{w})\left( {\eta (\infty )} - {\eta (t)}\right) \right) . \end{aligned}$$

Then

$$\begin{aligned} \hat{p}(t,\mathbf{w}) = p_\infty (\mathbf{w})\delta (t,\mathbf{w}). \end{aligned}$$

Denote by \(h(\mathbf{w},t)\) the scaled distance between \(p(\mathbf{w},t)\) and \(p_\infty (\mathbf{w})\):

$$\begin{aligned} h(\mathbf{w},t)\equiv \frac{p(\mathbf{w},t) - p_\infty (\mathbf{w})}{\sqrt{p_\infty (\mathbf{w})}}, \end{aligned}$$

which satisfies the following equation:

$$\begin{aligned} \begin{aligned} \partial _t h &=\frac{1}{\eta \sqrt{p_\infty }}\nabla _\mathbf{w}\cdot \left[ \hat{p}\,\nabla _\mathbf{w}\left( \frac{1}{\delta } + \frac{h}{\sqrt{p_\infty } \delta } \right) \right] \\ &=\frac{1}{\eta \sqrt{p_\infty }}\nabla _\mathbf{w}\cdot \left[ p_\infty \left( \nabla _\mathbf{w}L\hat{\delta }+ \nabla _\mathbf{w}L\hat{\delta }\left( \frac{h}{\sqrt{p_\infty }}\right) +\nabla _\mathbf{w}\left( \frac{h}{\sqrt{p_\infty }}\right) \right) \right] . \end{aligned} \end{aligned}$$
(16)

Here, \(\hat{\delta }\) is defined as \(\hat{\delta }(t) =\eta (t) - \eta (\infty )\), where \(\eta (\infty ) = \lim _{t\rightarrow \infty }\eta (t)\). We multiply both sides of (16) by h and integrate over \(\mathbf{w}\). Using integration by parts, we can obtain

$$\begin{aligned} \begin{aligned} \frac{1}{2}\partial _t\left\Vert h \right\Vert ^2 &=\frac{\hat{\delta }}{\eta } {\underbrace{\int \frac{h}{\sqrt{p_\infty }}\nabla _{\mathbf{w}}\cdot \left( p_\infty \nabla _\mathbf{w}L\right) \mathrm{{d}}\mathbf{w}}_{I}} + \frac{\hat{\delta }}{\eta } {\underbrace{\int \frac{1}{2}\left\Vert \frac{h}{\sqrt{p_\infty }}\right\Vert ^2 {\nabla }_{\mathbf{w}}\cdot \left( {p_\infty }{\nabla }_{\mathbf{w}} L\right) \mathrm{{d}}\mathbf{w}}_{II}}\\&- \frac{1}{\eta } {\underbrace{\int p_\infty \left\Vert \nabla _\mathbf{w}\left( \frac{h}{\sqrt{p_\infty }}\right) \right\Vert ^2\mathrm{{d}}\mathbf{w}}_{III}}. \end{aligned} \end{aligned}$$
(17)

We study the parts I, II, and III on the right-hand side of the above equation separately.

For the part I, note that

$$\begin{aligned} \nabla _\mathbf{w}\cdot \left( p_\infty \nabla _\mathbf{w}L\right) = p_\infty \left( \nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 \right) . \end{aligned}$$

Hence, Assumption (A.3) yields that

$$\begin{aligned} \left|\nabla _\mathbf{w}\cdot \left( p_\infty \nabla _\mathbf{w}L\right) \right|\le p_\infty ^{2/3}\max \{1 , \eta (\infty )\}M(\infty ), \end{aligned}$$

which implies the following upper bound for the part I in (17):

$$\begin{aligned} I \le \frac{\max \{1 , \eta (\infty )\}M(\infty )}{2}\left( \left\Vert h \right\Vert ^2 + \int p_\infty ^{1/3}\mathrm{{d}}\mathbf{w}\right) . \end{aligned}$$

For the part II, note that Assumption (A.3) gives

$$\begin{aligned} \lim _{\left\Vert \mathbf{w}\right\Vert \rightarrow \infty } \frac{\nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L}{2\eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 } = 0, \end{aligned}$$

which together with Assumption (A.2) implies that

$$\begin{aligned} \lim _{\left\Vert \mathbf{w}\right\Vert \rightarrow \infty } \left\Vert \nabla _\mathbf{w}L\right\Vert ^2 = +\infty . \end{aligned}$$

Thus, there exists a constant R, such that

$$\begin{aligned} \nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - 2\eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 \le \eta (\infty ), \quad \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 \ge \eta (\infty ),\quad \text {for }\forall \left\Vert \mathbf{w}\right\Vert > R. \end{aligned}$$

Hence,

$$\begin{aligned} \nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 \le 0, \quad \text {for }\forall \left\Vert \mathbf{w}\right\Vert > R. \end{aligned}$$

By the continuity of \(L(\mathbf{w})\), there exists a constant \(C_2\) such that

$$\begin{aligned} \left|\nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2\right|\le C_2, \quad \text {for }\forall \left\Vert \mathbf{w}\right\Vert <R. \end{aligned}$$

Therefore, we have the following upper bound for the part II in (17):

$$\begin{aligned} \left|II \right|\le \frac{C_2}{2} \left\Vert h \right\Vert ^2 . \end{aligned}$$

By combining the estimates for the parts I and II, we have

$$\begin{aligned} I+II \le C_1 \left\Vert h \right\Vert ^2 + C_1, \end{aligned}$$

where \(C_1 = \frac{1}{2}\max \{1 , \eta (\infty )\}\max \left\{ \int p_\infty ^{1/3}\mathrm{{d}}\mathbf{w}, 1+C_2/2\right\} M(\infty )\).

For the part III, note that Assumption (A.2) implies the following Poincaré inequality (see, e.g., [24]),

$$\begin{aligned} \int \left\Vert \nabla _\mathbf{w}\left( \frac{h}{\sqrt{p_\infty }}\right) \right\Vert ^2 p_\infty \,\mathrm{{d}}\mathbf{w}\ge C_P\int \left( \frac{h}{\sqrt{p_\infty }} - \int h\sqrt{p_\infty } \mathrm{{d}}\mathbf{w}\right) ^2 p_\infty \,\mathrm{{d}}\mathbf{w}. \end{aligned}$$
(18)

We need to show that

$$\begin{aligned} \int h\sqrt{p_\infty }\,\mathrm{{d}}\mathbf{w}= 0. \end{aligned}$$
(19)

Equation (19) can be proven using the conservation of mass. In particular, integrating (15) over \(\mathbf{w}\) and using integration by parts gives

$$\begin{aligned} \partial _t\left( \int p(\mathbf{w},t)\, \mathrm{{d}}\mathbf{w}\right) = 0, \end{aligned}$$

which implies \(\int h\sqrt{p_\infty } \,\mathrm{{d}}\mathbf{w}= \int p\, \mathrm{{d}}\mathbf{w}- \int p_\infty \,\mathrm{{d}}\mathbf{w}= 0\). Combining (18) with (19) gives a lower bound for the part III:

$$\begin{aligned} III \ge C_P\left\Vert h \right\Vert ^2. \end{aligned}$$

Combining the estimates for the parts I, II, and III with (17) gives

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert h \right\Vert ^2 + \frac{C_P}{\eta } \left\Vert h \right\Vert ^2 \le \frac{C_1\hat{\delta }}{\eta } \left( \left\Vert h \right\Vert ^2 +1\right) \end{aligned}$$
(20)

Since \(\eta (t) \rightarrow \eta (\infty ) >0\) as \(t\rightarrow \infty\), there exists some T large enough such that for all \(t>T\),

$$\begin{aligned} \hat{\delta }= \left|\eta (t) - \eta (\infty ) \right|\le \min \left\{ \frac{\eta (\infty )}{3}, \frac{C_P}{3C_1}\right\} . \end{aligned}$$
(21)

Plugging \(\hat{\delta }\le C_P/3C_1\) into (20), we have

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert h \right\Vert ^2 + \frac{2C_P}{3\eta } \left\Vert h \right\Vert ^2 \le \frac{C_P}{3\eta }, \quad \text {for }\forall t>T . \end{aligned}$$
(22)

Note that (21) also implies that \(2\eta (\infty )/3 \le \eta (t) \le 4\eta (\infty )/3\). Thus,

$$\begin{aligned} \frac{2C_P}{3\eta }\ge \frac{C_P}{2\eta (\infty )}, \quad \frac{C_P}{3\eta } \le \frac{C_P}{2\eta (\infty )}. \end{aligned}$$

Plugging back to (22), we arrive at

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert h \right\Vert ^2 + \frac{C_P}{2\eta (\infty )} \left\Vert h \right\Vert ^2 \le \frac{C_P}{2\eta (\infty )}, \quad \text {for }\forall t>T . \end{aligned}$$

Integrating the above inequality from T to \(t>T\), we have

$$\begin{aligned} \left\Vert h(t) \right\Vert ^2 \le \left( \left\Vert h(T) \right\Vert ^2 + \frac{C_P}{\eta (\infty )}(t-T)\right) - \frac{C_P}{\eta (\infty )}\int _T^t \left\Vert h(s) \right\Vert ^2 \mathrm{{d}}s. \end{aligned}$$

By Gronwall’s Inequality, we finally get

$$\begin{aligned} \left\Vert h(t)\right\Vert ^2 \le \left( \frac{C_P}{\eta (\infty )}(t-T) + \left\Vert h(T) \right\Vert ^2\right) \exp \left( -\frac{C_P}{\eta (\infty )}(t-T)\right) . \end{aligned}$$

This completes the proof.

Appendix B.4: Quantification of T in Theorem 1

We quantify T by giving a condition that a minimum T should satisfy. From the proof of Theorem 1 in “Appendix B.3”, it is clear that T should be large enough such that for all \(t>T\),

$$\begin{aligned} \left|\eta (t) - \eta (\infty ) \right|\le \min \left\{ \frac{\eta (\infty )}{3}, \frac{C_P}{3C_1}\right\} , \end{aligned}$$

where \(\eta (t)\) is defined in (14) and \(\eta (\infty ) = \lim _{t\rightarrow \infty }\eta (t)\), and

$$\begin{aligned} C_1 = \frac{M}{2}\max \{1 , \eta (\infty )\}\max \left\{ \int p_\infty ^{1/3}\mathrm{{d}}\mathbf{w}, 1+\frac{C_2}{2}\right\} , \end{aligned}$$

and \(C_2>0\) is an upper bound for \(\left|\nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2\right|\) in the bounded domain \(\{\left\Vert \mathbf{w}\right\Vert <R\}\) such that

$$\begin{aligned} \nabla _\mathbf{w}\cdot \nabla _\mathbf{w}L - \eta (\infty )\left\Vert \nabla _\mathbf{w}L\right\Vert ^2 \le \left\{ \begin{aligned}&0, \quad \text {for }\forall \left\Vert \mathbf{w}\right\Vert > R,\\&C_2, \quad \text {for }\forall \left\Vert \mathbf{w}\right\Vert <R. \end{aligned} \right. \end{aligned}$$

Appendix B.5: Proof of Theorem 2

Denote by \(P_\epsilon (\check{\mathbf{w}}) = {\mathbb {P}}(\left\Vert \mathbf{W}(\infty ) - \check{\mathbf{w}}\right\Vert \le \epsilon )\) the probability that \(\mathbf{W}(\infty )\) is trapped in an \(\epsilon\)-neighborhood of the minimum \(\check{\mathbf{w}}\). Recall the probability density function of \(\mathbf{W}(\infty )\) is \(p_\infty (\mathbf{w})\). Then,

$$\begin{aligned} \begin{aligned} P_\epsilon (\check{\mathbf{w}})&= \int _{\left\Vert \mathbf{w}- \check{\mathbf{w}}\right\Vert ^2\le \epsilon ^2}\kappa \mathrm{{e}}^{-\eta (\infty ) L(\mathbf{w})}\mathrm{{d}}\mathbf{w}\\&= \int _{\left\Vert \mathbf{w}- \check{\mathbf{w}}\right\Vert ^2\le \epsilon ^2} \kappa \exp \left( -\eta (\infty )[L(\check{\mathbf{w}})+(\mathbf{w}-\check{\mathbf{w}})'\nabla ^2L(\check{\mathbf{w}})(\mathbf{w}-\check{\mathbf{w}})+o\{(\mathbf{w}- \check{\mathbf{w}})^2\}]\right) \mathrm{{d}}\mathbf{w}, \end{aligned} \end{aligned}$$

where \(\eta (t)\) is defined in (14) and \(\eta (\infty ) = \lim _{t\rightarrow \infty }\eta (t)\). Since \(\check{\mathbf{w}}\) is a local minimum of \(L(\mathbf{w})\), \(\nabla ^2L(\check{\mathbf{w}})\) is positive definite. There exist an orthogonal matrix O and a diagonal matrix \(F\) such that \(\nabla ^2L = O' FO\). For simplicity, we assume that \(\nabla ^2L = F= \text {diag}(\lambda _{1}, \ldots , \lambda _d)\). Then,

$$\begin{aligned} \begin{aligned}&\lim _{\epsilon \rightarrow 0}P_\epsilon (\check{\mathbf{w}})\\&= \lim _{\epsilon \rightarrow 0}\left[ \kappa \mathrm{{e}}^{-\eta (\infty ) L(\check{\mathbf{w}})} \int _{\left\Vert \mathbf{w}\right\Vert ^2\le \epsilon ^2} \prod _{j=1}^d \mathrm{{e}}^{-\eta (\infty ) \lambda _j w_j^2} \mathrm{{d}}\mathbf{w}\right] \mathrm{{e}}^{\eta (\infty )\epsilon ^2}\\&= \lim _{\epsilon \rightarrow 0}\left[ \kappa \mathrm{{e}}^{-\eta (\infty ) L(\check{\mathbf{w}})}\prod _{j=1}^d\frac{1}{\sqrt{\eta (\infty )\lambda _j}}\int _{-\epsilon \sqrt{\eta (\infty ) \lambda _j}}^{\epsilon \sqrt{\eta (\infty ) \lambda _j}}\mathrm{{e}}^{-w^2}\mathrm{{d}}w\right] \mathrm{{e}}^{\eta (\infty )\epsilon ^2}\\&=\lim _{\epsilon \rightarrow 0}\left[ \kappa \eta (\infty )^{-d/2} \mathrm{{e}}^{-\eta (\infty ) L(\check{\mathbf{w}})}\prod _{j=1}^d\frac{1}{\sqrt{\lambda _j}}\left( \Phi \left( \epsilon \sqrt{\eta (\infty ) \lambda _j}\right) - \Phi \left( -\epsilon \sqrt{\eta (\infty ) \lambda _j}\right) \right) \right] \mathrm{{e}}^{\eta (\infty )\epsilon ^2}, \end{aligned} \end{aligned}$$

where \(\Phi (\cdot )\) is the cumulative distribution function of the standard normal distribution. The first equality is from the change of variable by writing \(\mathbf{w}- \check{\mathbf{w}}\) as \(\mathbf{w}\). The second equality is from changing \(\sqrt{\eta (\infty )\lambda _j}\,\mathbf{w}_j\) to \(\mathbf{w}_j\). Using the approximation of the cumulative distribution function in Pólya [25], we can simplify the above equation as

$$\begin{aligned} \begin{aligned} \lim _{\epsilon \rightarrow 0}P_\epsilon (\check{\mathbf{w}}) &=\lim _{\epsilon \rightarrow 0}\left[ \frac{\kappa \mathrm{{e}}^{-2\eta (\infty ) L(\check{\mathbf{w}})}}{\eta (\infty )^{d/2}}\prod _{j=1}^d\sqrt{\frac{1-\mathrm{{e}}^{-\epsilon ^2\eta (\infty )\lambda _j/\pi }}{\lambda _j}}\right] \mathrm{{e}}^{\eta (\infty )\epsilon ^2} \\ &=\frac{\kappa \mathrm{{e}}^{-2\eta (\infty ) L(\check{\mathbf{w}})}}{\eta (\infty )^{d/2}|\nabla ^2L(\check{\mathbf{w}})|} \lim _{\epsilon \rightarrow 0}\left[ \mathrm{{e}}^{\eta (\infty )\epsilon ^2}\prod _{j=1}^d\sqrt{1-\mathrm{{e}}^{-\epsilon ^2\eta (\infty )\lambda _j/\pi }}\right] . \end{aligned} \end{aligned}$$

This completes the proof.

Appendix B.6: Proof of Equation 8

Denote by \(\lambda _j^k\) the eigenvalues of the Hessian \(\nabla ^2L(\check{\mathbf{w}}_k)\), for \(k=1,2\) and \(1\le j\le d\). By Theorem 2 and \(L(\check{\mathbf{w}}_1)=L(\check{\mathbf{w}}_2)\), we have that

$$\begin{aligned} \begin{aligned}&\lim _{\epsilon \rightarrow 0}\frac{{\mathbb {P}}(|\mathbf{W}(\infty ) - \check{\mathbf{w}}_1|\le \epsilon ) }{{\mathbb {P}}(|\mathbf{W}(\infty ) - \check{\mathbf{w}}_2|\le \epsilon )} = \frac{\left|\nabla ^2L(\check{\mathbf{w}}_2)\right|}{\left|\nabla ^2L(\check{\mathbf{w}}_1)\right|}\sqrt{\lim _{\epsilon \rightarrow 0}\prod _{j=1}^d\frac{1-\exp \left( -\frac{\epsilon ^2\eta (\infty )\lambda ^1_j}{\pi }\right) }{1-\exp \left( -\frac{\epsilon ^2\eta (\infty )\lambda ^2_j}{\pi }\right) }}\\ &=\frac{\left|\nabla ^2L(\check{\mathbf{w}}_2)\right|}{\left|\nabla ^2L(\check{\mathbf{w}}_1)\right|}\sqrt{\lim _{\epsilon \rightarrow 0}\prod _{j=1}^d\frac{\lambda ^1_j\exp \left( -\frac{\epsilon ^2\eta (\infty )\lambda ^1_j}{\pi }\right) }{\lambda ^2_j\exp \left( -\frac{\epsilon ^2\eta (\infty )\lambda ^2_j}{\pi }\right) }} = \frac{\left|\nabla ^2L(\check{\mathbf{w}}_2)\right|}{\left|\nabla ^2L(\check{\mathbf{w}}_1)\right|}\sqrt{\prod _{j=1}^d\frac{\lambda ^1_j}{\lambda ^2_j}} = \sqrt{ \frac{\left|\nabla ^2L(\check{\mathbf{w}}_2)\right|}{\left|\nabla ^2L(\check{\mathbf{w}}_1)\right|}}, \end{aligned} \end{aligned}$$

where \(\eta (t)\) is defined in (14) and \(\eta (\infty ) = \lim _{t\rightarrow \infty }\eta (t)\).
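Equation (8) can be probed with a simple numerical experiment (an illustrative sketch, not the paper's experiments): for a one-dimensional loss with two minima of equal depth, and a neighborhood radius that is small compared with the distance between the minima but large compared with the local fluctuation scale \(1/\sqrt{\eta (\infty )L''}\), the mass that \(p_\infty (w)\propto \mathrm{{e}}^{-\eta (\infty )L(w)}\) assigns near each minimum is approximately in the ratio \(\sqrt{L''(\check{w}_2)/L''(\check{w}_1)}\), favoring the flatter minimum. The landscape, \(\eta (\infty )\), and the radius below are arbitrary choices.

```python
import numpy as np

# Sketch: numerical check of the ratio in Eq. (8) for a one-dimensional loss
# with two minima of equal depth (L = 0): a flat one at w = -1 with L''(-1) = 36
# and a sharp one at w = 2 with L''(2) = 90.  The landscape, eta(infty), and
# epsilon below are arbitrary illustrative choices.
L = lambda w: (w + 1.0)**2 * (w - 2.0)**2 * (w**2 + 1.0)

eta, eps = 4.0, 0.25                       # plays the role of eta(infty); radius
w = np.linspace(-3.0, 4.0, 700_001)
dw = w[1] - w[0]
p = np.exp(-eta * L(w))
p /= p.sum() * dw                          # p_infty on the grid, normalized

mass = lambda c: p[np.abs(w - c) <= eps].sum() * dw
print(mass(-1.0) / mass(2.0))              # empirical ratio of trapped probabilities
print(np.sqrt(90.0 / 36.0))                # sqrt(L''(2)/L''(-1)) from Eq. (8), ~1.58
```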

Appendix C: Proofs for Section 4

Appendix C.1: Derivation of SDE for MSGD

For a constant learning rate and batch size, \(\gamma _k\equiv \gamma\) and \(M_k\equiv M\), we rewrite the MSGD as

$$\begin{aligned} \begin{aligned}&\frac{{\mathbf {z}}_{k+1}}{\sqrt{\gamma }} = \frac{{\mathbf {z}}_k}{\sqrt{\gamma }} + \sqrt{\gamma }\left( - \frac{1 - \xi }{\gamma } {\mathbf {z}}_k - \nabla L(\mathbf{w}_k)\right) + \sqrt{\gamma }\left( \nabla L(\mathbf{w}_k) - \left( \frac{1}{M}\sum _{n\in B_k}\nabla L_n(\mathbf{w}_k)\right) \right) \\&\mathbf{w}_{k+1} = \mathbf{w}_k + \frac{{\mathbf {z}}_{k+1}}{\sqrt{\gamma }}\sqrt{\gamma }. \end{aligned} \end{aligned}$$

Let \({\mathbf {v}}_k = {\mathbf {z}}_k/\sqrt{\gamma }\). We have the approximation for MSGD

$$\begin{aligned} \begin{aligned}&{\mathbf {v}}_{k+1} - {\mathbf {v}}_k = - \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {v}}_k \sqrt{\gamma } - \nabla L(\mathbf{w}_k) \sqrt{\gamma } + \frac{\gamma ^{1/4}}{\sqrt{M}} \sqrt{\beta }\,\Delta {\mathbf {B}}_k,\\&\mathbf{w}_{k+1} - \mathbf{w}_k = {\mathbf {v}}_{k+1}\sqrt{\gamma }, \end{aligned} \end{aligned}$$

where \(\Delta {\mathbf {B}}_k\) denotes the increment of a standard Brownian motion over a time step of length \(\sqrt{\gamma }\), and \(\beta (\mathbf{w})\) is the covariance function defined in (4). Hence, MSGD is approximated by the Euler–Maruyama discretization of the following SDE,

$$\begin{aligned} \left\{ \begin{aligned}&\mathrm{{d}}{\mathbf {V}}(t) = -\nabla L(\mathbf{W}(t)) \mathrm{{d}}t - \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {V}}(t) \mathrm{{d}}t + \frac{\gamma ^{1/4}}{\sqrt{M}}\sqrt{\beta (\mathbf{W}(t))} \mathrm{{d}}{\mathbf {B}}(t),\\&\mathrm{{d}} \mathbf{W}(t) = {\mathbf {V}}(t)\mathrm{{d}}t, \end{aligned} \right. \end{aligned}$$

where \({\mathbf {v}}_k \approx {\mathbf {V}}({k\sqrt{\gamma }})\), \(\mathbf{w}_k \approx \mathbf{W}({k\sqrt{\gamma }})\).
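As a sanity check on this approximation (a minimal sketch; the one-dimensional quadratic risk, the Gaussian model for the mini-batch gradient noise, and all constants are illustrative assumptions rather than the paper's setup), one can run the MSGD recursion \({\mathbf {z}}_{k+1} = \xi {\mathbf {z}}_k - \gamma \,\widehat{\nabla L}(\mathbf{w}_k)\), \(\mathbf{w}_{k+1} = \mathbf{w}_k + {\mathbf {z}}_{k+1}\) (the recursion rewritten at the start of this appendix) next to the Euler–Maruyama scheme for the SDE above with time step \(\sqrt{\gamma }\) and compare where the two iterates end up.

```python
import numpy as np

# Sketch: MSGD versus the Euler-Maruyama discretization of the SDE above, on an
# illustrative 1-D quadratic risk L(w) = w^2/2 with Gaussian mini-batch gradient
# noise of variance beta/M.  gamma, xi, M, beta, and w0 are arbitrary choices.
rng = np.random.default_rng(3)
gamma, xi, M, beta = 0.01, 0.9, 8, 1.0
gradL = lambda w: w                           # nabla L for L(w) = w^2 / 2
n_steps, w0 = 2000, 5.0

# MSGD, written as z_{k+1} = xi z_k - gamma * (mini-batch gradient).
w_sgd, z = w0, 0.0
for _ in range(n_steps):
    noisy_grad = gradL(w_sgd) + np.sqrt(beta / M) * rng.normal()
    z = xi * z - gamma * noisy_grad
    w_sgd += z

# Euler-Maruyama for the SDE, with step sqrt(gamma) so that w_k ~ W(k sqrt(gamma)).
dt = np.sqrt(gamma)
w_sde, v = w0, 0.0
for _ in range(n_steps):
    v += (-gradL(w_sde) - (1.0 - xi) / np.sqrt(gamma) * v) * dt \
         + gamma**0.25 / np.sqrt(M) * np.sqrt(beta) * np.sqrt(dt) * rng.normal()
    w_sde += v * dt

print(w_sgd, w_sde)                           # both fluctuate around the minimum at 0
```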

Appendix C.2: Proof of Lemma 4

We give a formal derivation, which is similar to the procedure in Pavliotis [24]. Let \(\phi (\cdot ,\cdot )\) be any bivariate function in \(C^\infty\) with compact support. Using Itô's formula,

$$\begin{aligned} \begin{aligned}&\mathrm{{d}}\phi (\mathbf{W}(t), {\mathbf {V}}(t)) = \frac{\gamma ^{1/4}}{\sqrt{M}}\sqrt{\beta }\, \nabla _{\mathbf {v}}\phi \cdot \mathrm{{d}}{\mathbf {B}}(t)\\&\quad \quad + \left( {\mathbf {V}}(t)\cdot \nabla _\mathbf{w}\phi + \left( -\nabla L(\mathbf{W}(t)) - \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {V}}(t)\right) \cdot \nabla _{\mathbf {v}}\phi + \frac{\gamma ^{1/2}}{2M}\beta (\mathbf{W}(t)) \nabla _{\mathbf {v}}\cdot \nabla _{\mathbf {v}}\phi \right) \mathrm{{d}}t. \end{aligned} \end{aligned}$$

By taking the expectation of the above equation and integrating it over the range \([t, t+h]\), we obtain that

$$\begin{aligned} \begin{aligned}&\frac{1}{h}\mathbb {E}\left( \phi (\mathbf{W}(t+h), {\mathbf {V}}(t+h)) - \phi (\mathbf{W}(t), {\mathbf {V}}(t))\right) \\ &=\frac{1}{h}\int _t^{t+h} \mathbb {E}\left( {\mathbf {V}}(s)\cdot \nabla _\mathbf{w}\phi + \left( -\nabla L(\mathbf{W}(s)) - \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {V}}(s)\right) \cdot \nabla _{\mathbf {v}}\phi + \frac{\gamma ^{1/2}\beta (\mathbf{W}(s))}{2M} \nabla _{\mathbf {v}}\cdot \nabla _{\mathbf {v}}\phi \right) \mathrm{{d}}s . \end{aligned} \end{aligned}$$

Let \(\psi (\mathbf{w},{\mathbf {v}},t)\) be the joint probability density function of \((\mathbf{W}(t), {\mathbf {V}}(t))\). The above equation can also be written as

$$\begin{aligned} \begin{aligned}&\frac{1}{h}\int \phi (\mathbf{w}, {\mathbf {v}}) \left( \psi (\mathbf{w},{\mathbf {v}},t+h) - \psi (\mathbf{w},{\mathbf {v}},t)\right) \,\mathrm{{d}}\mathbf{w}\,\mathrm{{d}}{\mathbf {v}}\\ &=\frac{1}{h}\int _t^{t+h} \int \left( {\mathbf {v}}\cdot \nabla _\mathbf{w}\phi + \left( -\nabla L(\mathbf{w}) - \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {v}}\right) \cdot \nabla _{\mathbf {v}}\phi + \frac{\gamma ^{1/2}\beta (\mathbf{w})}{2M} \nabla _{\mathbf {v}}\cdot \nabla _{\mathbf {v}}\phi \right) \psi (\mathbf{w},{\mathbf {v}},s)\,\mathrm{{d}}\mathbf{w}\,\mathrm{{d}}{\mathbf {v}}\, \mathrm{{d}}s. \end{aligned} \end{aligned}$$

Then, using integration by parts and letting \(h\rightarrow 0\) gives

$$\begin{aligned} \begin{aligned}&\int \phi (\mathbf{w}, {\mathbf {v}}) \partial _t\psi \,\mathrm{{d}}\mathbf{w}\,\mathrm{{d}}{\mathbf {v}}\\&= \int \phi \left( -{\mathbf {v}}\cdot \nabla _\mathbf{w}\psi +\nabla L(\mathbf{w})\cdot \nabla _{\mathbf {v}}\psi +\nabla _{\mathbf {v}}\cdot \left( \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {v}}\psi \right) + \frac{\gamma ^{1/2}\beta (\mathbf{w})}{2M}\nabla _{\mathbf {v}}\cdot \nabla _{\mathbf {v}}\psi \right) \,\mathrm{{d}}\mathbf{w}\,\mathrm{{d}}{\mathbf {v}}, \end{aligned} \end{aligned}$$

which holds for any test function \(\phi\). Therefore, the density function \(\psi (\mathbf{w},{\mathbf {v}},t)\) satisfies

$$\begin{aligned} \partial _t\psi + {\mathbf {v}}\cdot \nabla _\mathbf{w}\psi -\nabla L(\mathbf{w})\cdot \nabla _{\mathbf {v}}\psi = \nabla _{\mathbf {v}}\cdot \left( \frac{1 - \xi }{\sqrt{\gamma }} {\mathbf {v}}\psi + \frac{\gamma ^{1/2}\beta (\mathbf{w})}{2M} \nabla _{\mathbf {v}}\psi \right) , \end{aligned}$$

which agrees with (9).

Next, we can verify that \(\psi _\infty (\mathbf{w},{\mathbf {v}})\) is a stationary solution of the Vlasov–Fokker–Planck equation (9) by a direct calculation, as in “Appendix B.2”.

Appendix C.3: Discussion on Assumption (A.4)

We show that Assumption (A.4) holds for the squared loss and the regularized mean cross-entropy loss. Denote by \(\{({\mathbf {x}}_n,y_n),1\le n\le N\}\) the set of training data. Without loss of generality, let \(\text {Var}[y_n|{\mathbf {x}}_n] = 1\). For the squared loss,

$$\begin{aligned} \tilde{L}(\mathbf{w}) = \big (\mathbf{w}-\mathbf{w}^0\big )^\top \mathbb {E}[{\mathbf {x}}_n{\mathbf {x}}_n^\top ]\big (\mathbf{w}-\mathbf{w}^0\big )+1 - \frac{1}{2}C_L^2\Vert \mathbf{w}\Vert ^2, \end{aligned}$$

where \(\mathbf{w}^0\) is the true parameter vector. By a direct calculation,

$$\begin{aligned} \nabla ^2\tilde{L}(\mathbf{w}) = 2\mathbb {E}[{\mathbf {x}}_n{\mathbf {x}}_n^\top ] - C^2_L I_d. \end{aligned}$$

Since the eigenvalues of the design matrix \(\mathbb {E}[{\mathbf {x}}_n{\mathbf {x}}_n^\top ]\) are bounded, the eigenvalues of \(\nabla ^2\tilde{L}(\mathbf{w})\) are bounded for any \(C_L\). Hence, Assumption (A.4) holds for the squared loss.

Next, we consider the regularized mean cross-entropy loss for the logistic regression. Similar to “Appendix B.1”, letting \(C_L = \sqrt{2\lambda }\) yields that

$$\begin{aligned} \tilde{L}(\mathbf{w}) = \mathbb {E}[-y_n\log \widehat{y}_n -(1-y_n)\log (1-\widehat{y}_n)]. \end{aligned}$$

The (i, j)th entry of the Hessian \(\nabla ^2\tilde{L}(\mathbf{w})\) is

$$\begin{aligned} (\nabla ^2\tilde{L}(\mathbf{w}))_{ij} = \mathbb {E}\left[ x_{ni}x_{nj}\frac{\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n}}{(1+\mathrm{{e}}^{-\mathbf{w}\cdot {\mathbf {x}}_n})^2}\right] , \end{aligned}$$

where \(x_{ni}\) is the ith element of \({\mathbf {x}}_n\). Then,

$$\begin{aligned} (\nabla ^2\tilde{L}(\mathbf{w}))_{ij} \rightarrow 0\quad \text {as } \Vert \mathbf{w}\Vert \rightarrow \infty , \end{aligned}$$

which implies that there exist finite constants \(b_{ij}>0\) such that \(\Vert (\nabla ^2\tilde{L}(\mathbf{w}))_{ij} \Vert _{\infty }\le b_{ij}\) and the largest row sum of the matrix \(\{\Vert (\nabla ^2\tilde{L})_{ij}\Vert _{\infty }\}_{1\le i,j\le d}\) is upper bounded by \(b \equiv \max _i(\sum _{j}b_{ij})\). Since the largest eigenvalue of a non-negative matrix is upper bounded by its largest row sum, the eigenvalues of \(\{\Vert (\nabla ^2\tilde{L})_{ij}\Vert _{\infty }\}_{1\le i,j\le d}\) are bounded by b, and hence so are the eigenvalues of \(\nabla ^2\tilde{L}(\mathbf{w})\) for every \(\mathbf{w}\). Hence, Assumption (A.4) also holds for the regularized mean cross-entropy loss.

Appendix C.4: Proof of Theorem 3

Recall the function defined in Theorem 3:

$$\begin{aligned} h(\mathbf{w}, {\mathbf {v}},t) \equiv \frac{\psi (t,\mathbf{w},{\mathbf {v}}) - \psi _\infty (\mathbf{w},{\mathbf {v}})}{\ \psi _\infty (\mathbf{w},{\mathbf {v}})}, \end{aligned}$$

which is the weighted fluctuation function around the stationary solution \(\psi _\infty (\mathbf{w},{\mathbf {v}})\). Then, \(h(\mathbf{w}, {\mathbf {v}},t)\) satisfies the following partial differential equation,

$$\begin{aligned} \partial _th + Th = Fh, \end{aligned}$$
(23)

where

$$\begin{aligned} \begin{aligned}&T= {\mathbf {v}}\cdot \nabla _\mathbf{w}- \nabla L(\mathbf{w})\cdot \nabla _{\mathbf {v}}\quad \text { is the transport operator};\\&F= \frac{\gamma ^{1/2}\beta }{2M}\frac{1}{\psi _\infty }\nabla _{\mathbf {v}}\cdot \left( \psi _\infty \nabla _{\mathbf {v}}\right) \quad \text { is the Fokker Planck operator}.\\ \end{aligned} \end{aligned}$$

Also recall the norm \(\left\Vert \cdot \right\Vert _*\) defined in Theorem 3:

$$\begin{aligned} \begin{aligned} \text {For any }h(\mathbf{w},{\mathbf {v}},t), g(\mathbf{w},{\mathbf {v}},t):\quad&\left\langle h,g \right\rangle _* = \int hg \psi _\infty \,\mathrm{{d}}\mathbf{w}d{\mathbf {v}}, \quad \left\Vert h \right\Vert ^2_* = \int \left|h\right|^2 \psi _\infty \,\mathrm{{d}}\mathbf{w}d{\mathbf {v}}, \end{aligned} \end{aligned}$$

Lemma 5

One has the following properties of the operators \(T\) and \(F\):

  (1) \(\displaystyle \left\langle Tf, g \right\rangle _*= -\left\langle f, Tg \right\rangle _*\),

  (2) \(\displaystyle \left\langle Tf, f \right\rangle _*= 0\),

  (3) \(\displaystyle \left\langle Ff, g \right\rangle _*= -\frac{\gamma ^{1/2}\beta }{2M} \left\langle \nabla _{\mathbf {v}}f, \nabla _{\mathbf {v}}g \right\rangle _*\).

This lemma can be verified by direct calculations, and we omit the details. These properties of operators \(F, T\) will be frequently used later.

Lemma 6

For the positive definite matrix P defined in (10), the function \(h(t,\mathbf{w},{\mathbf {v}})\) satisfies

$$\begin{aligned} \begin{aligned}&\frac{1}{2}\frac{\mathrm{{d}}}{\mathrm{{d}}t}H(t) +\frac{1}{2}\int [ \nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h] K [ \nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h]^\top \psi _\infty \mathrm{{d}}\mathbf{w}\mathrm{{d}}{\mathbf {v}}\\&\quad \le \left\langle \nabla ^2\tilde{L}\nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*+ \left\langle \nabla ^2\tilde{L}\nabla _{\mathbf {v}}h, \nabla _{\mathbf {v}}h \right\rangle _*\\ \end{aligned} \end{aligned}$$

where the modified risk function \(\tilde{L}\) is defined in Assumption (A.4), and

$$\begin{aligned} K \equiv \begin{bmatrix} 2\hat{C}I_d & (C - C_L^2+\gamma \hat{C})I_d\\ (C - C_L^2+\gamma \hat{C})I_d & (2\gamma C - 2C_L^2\hat{C})I_d \end{bmatrix}. \end{aligned}$$
(24)

Proof

Taking the gradient \(\nabla _\mathbf{w}\) of (23), multiplying by \(\nabla _\mathbf{w}h\, \psi _\infty\), and integrating gives

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*- \left\langle T\nabla _\mathbf{w}h , \nabla _\mathbf{w}h \right\rangle _*- \left\langle \nabla ^2L \nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*= \left\langle F\nabla _\mathbf{w}h, \nabla _\mathbf{w}h \right\rangle _*\end{aligned}$$

Then, applying Lemma 5 yields

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*- \left\langle \nabla ^2L \nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*= -\frac{\gamma ^{1/2}\beta }{2M}\sum _{i = 1}^d \left\Vert \partial _{v_i}\nabla _\mathbf{w}h \right\Vert ^2_*. \end{aligned}$$

By Assumption (A.4), we have

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*- C_L^2\left\langle \nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*= -\frac{\gamma ^{1/2}\beta }{2M}\sum _{i = 1}^d \left\Vert \partial _{v_i}\nabla _\mathbf{w}h \right\Vert ^2_*+ \left\langle \nabla ^2\tilde{L}\nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*. \end{aligned}$$
(25)

Similarly, taking the gradient \(\nabla _{\mathbf {v}}\) of (23), multiplying by \(\nabla _{\mathbf {v}}h\, \psi _\infty\), integrating, and applying Lemma 5 gives

$$\begin{aligned} \frac{1}{2}\partial _t\left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*+ \left\langle \nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h \right\rangle _*= -\frac{\gamma ^{1/2}\beta }{2M}\sum _{i = 1}^d \left\Vert \partial _{v_i}\nabla _{\mathbf {v}}h \right\Vert ^2_*- \frac{1 - \xi }{\sqrt{\gamma }}\left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*. \end{aligned}$$
(26)

Taking the gradient \(\nabla _{\mathbf {v}}\) of (23) and multiplying by \(\nabla _\mathbf{w}h\, \psi _\infty\), then taking the gradient \(\nabla _\mathbf{w}\) of (23) and multiplying by \(\nabla _{\mathbf {v}}h\, \psi _\infty\), and combining the results gives

$$\begin{aligned} \begin{aligned}&\partial _t\left\langle \nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h \right\rangle _*-C_L^2 \left\langle \nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h \right\rangle _*+ \left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*\\&\quad = -\frac{\gamma ^{1/4}\sqrt{\beta }}{\sqrt{M}}\sum _{i = 1}^d \left\langle \partial _{v_i}\nabla _{\mathbf {v}}h , \partial _{v_i}\nabla _\mathbf{w}h \right\rangle _*- \frac{1 - \xi }{\sqrt{\gamma }}\left\langle \nabla _{\mathbf {v}}h, \nabla _\mathbf{w}h \right\rangle _*+ \left\langle \nabla ^2\tilde{L}\nabla _{\mathbf {v}}h, \nabla _{\mathbf {v}}h \right\rangle _*. \end{aligned} \end{aligned}$$
(27)

Finally, (25) \(+\) \(C\cdot\)(26) \(+\) \(2\hat{C}\cdot\)(27) yields

$$\begin{aligned} \begin{aligned}&\frac{1}{2}\partial _tH(t) + \frac{\gamma ^{1/2}\beta }{2M}\sum _{i = 1}^d \int [ \partial _{v_i}\nabla _\mathbf{w}h, \partial _{v_i}\nabla _{\mathbf {v}}h]^\top P [ \partial _{v_i}\nabla _\mathbf{w}h, \partial _{v_i}\nabla _{\mathbf {v}}h] \mathrm{{d}}\mathbf{w}d{\mathbf {v}}\\&+ \frac{1}{2}\int [\nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h]^\top K [\nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h] \mathrm{{d}}\mathbf{w}d{\mathbf {v}}= \left\langle \nabla ^2\tilde{L}\nabla _\mathbf{w}h, \nabla _{\mathbf {v}}h \right\rangle _*+ \left\langle \nabla ^2\tilde{L}\nabla _{\mathbf {v}}h, \nabla _{\mathbf {v}}h \right\rangle _*, \end{aligned} \end{aligned}$$
(28)

where the function H(t) and the positive definite matrix P are defined in (10). The positive definiteness of P implies that

$$\begin{aligned} \frac{\gamma ^{1/2}\beta }{2M}\sum _{i = 1}^d \int [ \partial _{v_i}\nabla _\mathbf{w}h, \partial _{v_i}\nabla _{\mathbf {v}}h]^\top P [ \partial _{v_i}\nabla _\mathbf{w}h, \partial _{v_i}\nabla _{\mathbf {v}}h] \mathrm{{d}}\mathbf{w}d{\mathbf {v}}\ge 0, \end{aligned}$$

which together with (28) completes the proof. \(\square\)

Lemma 7

For P and K defined in (10) and (24), respectively, there exist \(\mu\), C, and \(\hat{C}\) such that

$$\begin{aligned} K \ge 2 \mu P\ge 0, \end{aligned}$$

where the values of \(\mu\), C, and \(\hat{C}\) can be quantified as follows:

$$\begin{aligned} \left\{ \begin{aligned}&\text {when } \frac{1 - \xi }{\sqrt{\gamma }} < 2C_L: \mu \equiv \frac{1 - \xi }{\sqrt{\gamma }}, \ C \equiv C_L^2,\ \hat{C}\equiv \frac{1 - \xi }{2\sqrt{\gamma }};\\&\text {when } \frac{1 - \xi }{\sqrt{\gamma }} \ge 2C_L: \mu \equiv \frac{1 - \xi }{\sqrt{\gamma }} - \sqrt{ \frac{(1 - \xi )^2}{\gamma } - 4C_L^2}, \ C \equiv \frac{(1 - \xi )^2}{2\gamma }-C_L^2, \ \hat{C}\equiv \frac{1 - \xi }{2\sqrt{\gamma }}. \end{aligned} \right. \end{aligned}$$

This lemma can be verified by direct calculations and we omit the details. We now go back to the proof of Theorem 3.

Proof of Theorem 3

By Lemmas 6, 7, and Assumption (A.4), we obtain

$$\begin{aligned} \begin{aligned}&\frac{1}{2}\frac{\mathrm{{d}}}{\mathrm{{d}}t}H(t) + \mu H(t) \le \frac{1+\sqrt{2}}{2}b(\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*+ \left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*)\\ \end{aligned} \end{aligned}$$

Letting \(\lambda _{\min }\) be the smallest eigenvalue of the positive definite matrix P, we have

$$\begin{aligned} \begin{aligned}&\lambda _{\min } (\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*+ \left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*) \le H(t), \end{aligned} \end{aligned}$$
(29)

which implies

$$\begin{aligned} \begin{aligned}&\frac{1}{2}\frac{\mathrm{{d}}}{\mathrm{{d}}t}H(t) +( \mu - \hat{\mu }) H(t) \le 0,\\ \end{aligned} \end{aligned}$$

where \(\displaystyle \hat{\mu }= \frac{1+\sqrt{2}}{2}\frac{b}{\lambda _{\min }}\). Solving the above inequality yields,

$$\begin{aligned} \begin{aligned}&H(t)\le \mathrm{{e}}^{-2(\mu - \hat{\mu })t} H(0).\\ \end{aligned} \end{aligned}$$

Inserting this inequality into (29) gives

$$\begin{aligned} \left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*+ \left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*\le \frac{1}{\lambda _{\min }}\mathrm{{e}}^{-2(\mu - \hat{\mu })t} H(0). \end{aligned}$$
(30)

In addition, the Poincaré inequality with respect to the measure \(\psi _\infty (\mathbf{w},{\mathbf {v}})\) is

$$\begin{aligned} \begin{aligned}&\left\Vert \nabla _\mathbf{w}h \right\Vert ^2_*+ \left\Vert \nabla _{\mathbf {v}}h \right\Vert ^2_*\ge \frac{2M(1 - \xi )}{\gamma \beta }\min \{ C_P, d\} \left\Vert h \right\Vert ^2_*. \end{aligned} \end{aligned}$$

Inserting it back into (30) leads to

$$\begin{aligned} \begin{aligned}&\left\Vert h \right\Vert ^2_*\le \frac{\gamma \beta }{2M(1 - \xi )\min \{ C_P, d\}} \frac{1}{\lambda _{\min }}\mathrm{{e}}^{-2(\mu - \hat{\mu })t} H(0) \\ \end{aligned} \end{aligned}$$

\(\square\)

Appendix D: Networks and Dataset Used in Sect. 5.1

The N1 network is a shallow convolutional network, which is a modified AlexNet configuration (Krizhevsky et al. [19]). Let \(n\times [a,b,c,d]\) denote a stack of n convolution layers with a filters, a kernel size of \(b\times c\), and a stride length of d. Then, the N1 network uses two sets of [65, 5, 5, 2]–MaxPool(3) and two dense layers of sizes (384, 192), and finally an output layer of size 10. We use ReLU activations.
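For concreteness, a sketch of the N1 configuration in PyTorch is given below (not the authors' code: the input shape, convolution padding, and pooling stride are unstated in the text and are guessed here).

```python
import torch
import torch.nn as nn

# Sketch of the N1 configuration described above.  Unstated details (input
# shape, convolution padding, pooling stride) are guesses, not the paper's code.
class N1(nn.Module):
    def __init__(self, in_channels: int = 3, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            # two sets of [65, 5, 5, 2], each followed by MaxPool(3)
            nn.Conv2d(in_channels, 65, kernel_size=5, stride=2, padding=2), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(65, 65, kernel_size=5, stride=2, padding=2), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(384), nn.ReLU(),    # dense layers of sizes (384, 192)
            nn.Linear(384, 192), nn.ReLU(),
            nn.Linear(192, num_classes),      # output layer of size 10
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

model = N1()                                  # e.g. for 32 x 32 x 3 CIFAR-10 images
print(model(torch.randn(2, 3, 32, 32)).shape) # torch.Size([2, 10])
```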

The N2 network is a deep convolutional network, which is a modified VGG configuration (Simonyan and Zisserman [27]). The N2 network uses the configuration \(2\times [64, 3, 3, 1]\), \(2\times [128, 3, 3, 1]\), \(3\times [256, 3, 3, 1]\), \(3\times [512, 3, 3, 1]\), \(3\times [512, 3, 3, 1]\), with a MaxPool(2) after each stack. The convolutional stacks are followed by a 512-dimensional dense layer and, finally, a ten-dimensional output layer. We use ReLU activations.

The MNIST dataset (LeCun et al. [20]) contains 60,000 training images and 10,000 testing images, where each image is black and white, normalized to fit into a \(28\times 28\) pixel bounding box, and belongs to one of ten classes of handwritten digits (i.e., \(0,1,2,\ldots ,9\)).

The CIFAR-10 dataset consists of 50,000 training images and 10,000 testing images, where each image is a \(32 {\times } 32\) color image belonging to one of ten classes representing airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.

Cite this article

Dai, X., Zhu, Y. On Large Batch Training and Sharp Minima: A Fokker–Planck Perspective. J Stat Theory Pract 14, 53 (2020). https://doi.org/10.1007/s42519-020-00120-9


Keywords

  • Large batch training
  • Sharp minima
  • Fokker–Planck equation
  • Stochastic gradient algorithm
  • Deep neural network

Mathematics Subject Classification

  • 90C15
  • 35Q62
  • 65K05