
Stabilize deep ResNet with a sharp scaling factor \(\tau\)

Abstract

We study the stability and convergence of training deep ResNets with gradient descent. Specifically, we show that the parametric branch in the residual block should be scaled down by a factor \(\tau =O(1/\sqrt{L})\) to guarantee a stable forward/backward process, where L is the number of residual blocks. Moreover, we establish a converse result that the forward process is unbounded when \(\tau >L^{-\frac{1}{2}+c}\) for any positive constant c. Together, these two results pin down a sharp value of the scaling factor for the stability of deep ResNet. Based on the stability result, we further show that gradient descent finds the global minima if the ResNet is properly over-parameterized, which significantly improves over previous work by admitting a much larger range of \(\tau\) for which global convergence holds. Moreover, we show that the convergence rate is independent of the depth, theoretically justifying the advantage of ResNet over vanilla feedforward networks. Empirically, with such a factor \(\tau\), one can train deep ResNets without normalization layers. Moreover, for ResNets with normalization layers, adding such a factor \(\tau\) also stabilizes training and yields significant performance gains for deep ResNets.
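
To make the scaling concrete, here is a minimal PyTorch-style sketch of a residual block that applies the factor \(\tau =1/\sqrt{L}\) to the parametric branch. It is an illustration only: the width, depth, the plain linear branch, and the default PyTorch initialization are assumptions and need not match the architectures used in the experiments.

```python
import torch
import torch.nn as nn

class TauResidualBlock(nn.Module):
    """Residual block h_l = relu(h_{l-1} + tau * W_l h_{l-1}) with tau = 1/sqrt(L)."""

    def __init__(self, width: int, num_blocks: int):
        super().__init__()
        self.tau = num_blocks ** -0.5                     # scaling factor tau = O(1/sqrt(L))
        self.linear = nn.Linear(width, width, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # identity branch plus tau-scaled parametric branch, followed by ReLU
        return torch.relu(h + self.tau * self.linear(h))

# A depth-L stack of such blocks, with no normalization layers.
L, width = 100, 256
blocks = nn.Sequential(*[TauResidualBlock(width, L) for _ in range(L)])
h0 = torch.randn(8, width)
print(blocks(h0).norm(dim=1))   # hidden norms remain on the order of ||h_0||
```

The only change relative to a plain residual block is the multiplication by \(\tau\); this is the same modification denoted by “\(+\tau\)" in Note 3 below, where the factor is placed on the output of the parametric branch of each block.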



Notes

  1. We use \(\tilde{O}(\cdot )\) to hide logarithmic factors.

  2. In (He et al., 2016), there is a ReLU after the building block \(y=x+F(x)\) (please refer to Figure 2 in He et al. (2016)), and hence a whole residual block is \(h_l = \phi (h_{l-1} + F(h_{l-1}))\) (if using the notations in our paper).

  3. Throughout the paper, the naming convention for ResNet is as follows. “ResNet" refers to the model defined in Sect. 2, “ResNet#" refers to the models in He et al. (2016) with all BN layers removed, e.g., ResNet1202, “ResNet#\(+\)BN" corresponds to the original model in He et al. (2016), “\(+\)Fixup" corresponds to initializing the model with Fixup, and “\(+\tau\)" refers to adding the factor \(\tau\) on the output of the parametric branch in each residual block.

  4. GD exhibits the same phenomenon. We use SGD due to the expensive per-iteration cost of GD.

References

  • Allen-Zhu, Z., & Li, Y. (2019). What can ResNet learn efficiently, going beyond kernels? Advances in Neural Information Processing Systems.

  • Allen-Zhu, Z., Li, Y., & Song, Z. (2018). A convergence theory for deep learning via over-parameterization. arXiv preprint arXiv:1811.03962.

  • Allen-Zhu, Z., Li, Y., & Liang, Y. (2019a). Learning and generalization in overparameterized neural networks, going beyond two layers. Advances in Neural Information Processing Systems, pp. 6155–6166.

  • Allen-Zhu, Z., Li, Y., & Song, Z. (2019b). On the convergence rate of training recurrent neural networks. Advances in Neural Information Processing Systems.

  • Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R., & Wang, R. (2019a). On exact computation with an infinitely wide neural net. Advances in Neural Information Processing Systems.

  • Arora, S., Du, S. S., Hu, W., Li, Z., & Wang, R. (2019b). Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. International Conference on Machine Learning (ICML).

  • Arpit, D., Campos, V., & Bengio, Y. (2019). How to initialize your network? robust initialization for weightnorm & resnets. Advances in Neural Information Processing Systems.

  • Balduzzi, D., Frean, M., Leary, L., Lewis, J. P., Wan-Duo Ma, K., & McWilliams, B. (2017). The shattered gradients problem: If resnets are the answer, then what is the question? In International Conference on Machine Learning (ICML), pp. 342–350.

  • Brutzkus, A., Globerson, A., Malach, E., & Shalev-Shwartz, S. (2018). SGD learns over-parameterized networks that provably generalize on linearly separable data. In Proceedings of the 6th international conference on learning representations (ICLR 2018).

  • Cao, Y., & Gu, Q. (2019). A generalization theory of gradient descent for learning over-parameterized deep ReLU networks. arXiv preprint arXiv:1902.01384.

  • Cao, Y., & Gu, Q. (2020). Generalization bounds of stochastic gradient descent for wide and deep neural networks. Advances in Neural Information Processing Systems (NeurIPS).

  • Chen, Z., Cao, Y., Zou, D., & Gu, Q. (2021). How much over-parameterization is sufficient to learn deep ReLU networks? In Proceedings of the international conference on learning representations (ICLR 2021).

  • Chizat, L., & Bach, F. (2018a). On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in Neural Information Processing Systems 31.

  • Chizat, L., & Bach, F. (2018b). A note on lazy training in supervised differentiable programming. arXiv preprint arXiv:1812.07956, 8.

  • Chizat, L., Oyallon, E., & Bach, F. (2019). On lazy training in differentiable programming. Advances in Neural Information Processing Systems.

  • Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2019). Bert: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT.

  • Du, S. S., Lee, J. D., Li, H., Wang, L., & Zhai, X. (2019a). Gradient descent finds global minima of deep neural networks. In: International Conference on Machine Learning (ICML).

  • Du, S. S., Zhai, X., Poczos, B., & Singh, A. (2019b). Gradient descent provably optimizes over-parameterized neural networks. In: International Conference on Learning Representations (ICLR).

  • Fang, C., Dong, H., & Zhang, T. (2019a). Over parameterized two-level neural networks can learn near optimal feature representations. arXiv preprint arXiv:1910.11508.

  • Fang, C., Gu, Y., Zhang, W., & Zhang, T. (2019b). Convex formulation of overparameterized deep neural networks. arXiv preprint arXiv:1911.07626.

  • Frei, S., Cao, Y., & Gu, Q. (2019). Algorithm-dependent generalization bounds for overparameterized deep residual networks. Advances in Neural Information Processing Systems, pp. 14769–14779.

  • Ghorbani, B., Mei, S., Misiakiewicz, T., Montanari, A. (2019). Limitations of lazy training of two-layers neural networks. Advances in Neural Information Processing Systems.

  • Haber, E., & Ruthotto, L. (2017). Stable architectures for deep neural networks. Inverse Problems, 34(1), 014004.


  • Hardt, M., & Ma, T. (2016). Identity matters in deep learning. In: International Conference on Learning Representations (ICLR).

  • He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In: The IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

  • Ioffe, S., & Szegedy, C. (2015). Batch normalization: accelerating deep network training by reducing internal covariate shift. In: International Conference on Machine Learning (ICML), pp. 448–456.

  • Jacot, A., Gabriel, F., & Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems, pp. 8571–8580.

  • Ji, Z., & Telgarsky, M. (2020). Polylogarithmic width suffices for gradient descent to achieve arbitrarily small test error with shallow ReLU networks. In Proceedings of the international conference on learning representations (ICLR 2020).

  • Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images.

  • Laurent, B., & Massart, P. (2000). Adaptive estimation of a quadratic functional by model selection. Annals of Statistics, pp. 1302–1338.

  • LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278–2324.


  • Li, Y., & Liang, Y. (2018). Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in Neural Information Processing Systems, pp. 8168–8177.

  • Mei, S., Montanari, A., & Nguyen, P.-M. (2018). A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33), E7665–E7671.


  • Mei, S., Misiakiewicz, T., & Montanari, A. (2019). Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. In Proceedings of the thirty-second conference on learning theory (pp. 2388–2464).

  • Neyshabur, B., Li, Z., Bhojanapalli, S., LeCun, Y., & Srebro, N. (2019). The role of over-parametrization in generalization of neural networks. In: International Conference on Learning Representations (ICLR).

  • Nguyen, P.-M. (2019). Mean field limit of the learning dynamics of multilayer neural networks. arXiv preprint arXiv:1902.02880.

  • Orhan, A. E., & Pitkow, X. (2018). Skip connections eliminate singularities. In: International Conference on Learning Representations (ICLR).

  • Oymak, S., & Soltanolkotabi, M. (2019). Overparameterized nonlinear learning: Gradient descent takes the shortest path? In: International Conference on Machine Learning (ICML).

  • Spielman, D. A., & Teng, S.-H. (2004). Smoothed analysis of algorithms: Why the simplex algorithm usually takes polynomial time. Journal of the ACM, 51(3), 385–463.

  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems.

  • Veit, A., Wilber, M. J., & Belongie, S. (2016). Residual networks behave like ensembles of relatively shallow networks. Advances in Neural Information Processing Systems, pp 550–558.

  • Vershynin, R. (2012). Introduction to the non-asymptotic analysis of random matrices. In Compressed sensing: Theory and applications (pp. 210–268).


  • Yang, G., & Schoenholz, S. (2017). Mean field residual networks: On the edge of chaos. Advances in Neural Information Processing Systems, pp. 7103–7114.

  • Yehudai, G., & Shamir, O. (2019). On the power and limitations of random features for understanding neural networks. Advances in Neural Information Processing Systems.

  • Zhang, H., Dauphin, Y. N., & Ma, T. (2019a). Fixup initialization: Residual learning without normalization. In: International Conference on Learning Representations (ICLR).

  • Zhang, H., Chen, W., & Liu, T.-Y. (2018). On the local hessian in back-propagation. In Advances in Neural Information Processing Systems, pp. 6521–6531.

  • Zhang, J., Han, B., Wynter, L., Low, K. H., & Kankanhalli, M. (2019b). Towards robust resnet: A small step but a giant leap. In: International Joint Conferences on Artificial Intelligence (IJCAI).

  • Zou, D., & Gu, Q. (2019). An improved analysis of training over-parameterized deep neural networks. Advances in Neural Information Processing Systems.

  • Zou, D., Cao, Y., Zhou, D., & Gu, Q. (2020). Stochastic gradient descent optimizes over-parameterized deep ReLU networks. Machine Learning, 109(3), 467–492.


Author information

Correspondence to Huishuai Zhang or Wei Chen.

Additional information

Editor: Paolo Frasconi.


Appendices

A Useful Lemmas

First, we list several useful bounds for the Gaussian distribution.

Lemma 1

Suppose \(X\sim \mathcal {N}(0,\sigma ^{2})\), then

$$\begin{aligned} \begin{aligned}&\mathbb {P}\{|X|\le x\} \ge 1-\exp \left( -\frac{x^{2}}{2\sigma ^{2}}\right) ,\\&\mathbb {P}\{|X|\le x\} \le \sqrt{\frac{2}{\pi }}\frac{x}{\sigma }. \end{aligned} \end{aligned}$$
(6)

Another bound concerns the spectral norm of a random matrix (Vershynin, 2012, Corollary 5.35).

Lemma 2

Let \(\varvec{A}\in \mathbb {R}^{N\times n}\) be a matrix whose entries are independent standard Gaussian random variables. Then for every \(t\ge 0\), with probability at least \(1-\exp (-t^{2}/2)\), one has

$$\begin{aligned} s_{\max }(\varvec{A})\le \sqrt{N}+\sqrt{n}+t, \end{aligned}$$
(7)

where \(s_{\max }(\varvec{A})\) is the largest singular value of \(\varvec{A}\).
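
As a quick numerical illustration of Lemma 2 (our own sanity check, not part of the original text), the following numpy snippet estimates how often the bound in (7) is violated; the dimensions, the value of t, and the number of trials are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, t, trials = 400, 100, 3.0, 200

violations = 0
for _ in range(trials):
    A = rng.standard_normal((N, n))        # independent standard Gaussian entries
    s_max = np.linalg.norm(A, ord=2)       # largest singular value of A
    violations += s_max > np.sqrt(N) + np.sqrt(n) + t

# Lemma 2 predicts a failure probability of at most exp(-t^2 / 2), about 1.1% here.
print(f"empirical violation rate: {violations / trials:.3f}")
```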

B Spectral norm bound at initialization

Next we present a spectral norm bound related to the forward process of ResNet with \(\tau\).

Proof

Since no ambiguity arises, we drop the superscript \(^{(0)}\) for notational simplicity. We first establish the claim for one fixed sample \(i\in [n]\) and drop the subscript i for convenience. Let \(g_l = h_{l-1}+\tau {\varvec{W}}_lh_{l-1}\) and \(h_l = {\varvec{D}}_l g_{l}\) for \(l\in \{a,\dots , b\}\). We will show that, for a vector \(h_{a-1}\) with \(\Vert h_{a-1}\Vert =1\), we have \(\Vert h_b\Vert \le 1+c\) with high probability, where

$$\begin{aligned} h_b = {\varvec{D}}_b (\varvec{I}+\tau {\varvec{W}}_{b}){\varvec{D}}_{b-1}\cdots {\varvec{D}}_{a}(\varvec{I}+\tau {\varvec{W}}_{a}) h_{a-1}. \end{aligned}$$
(8)

Then we have \(\Vert g_l\Vert \ge \Vert h_l\Vert\) due to the fact that \(\Vert {\varvec{D}}_l\Vert \le 1\). Hence we have

$$\begin{aligned} \Vert h_{b}\Vert ^2 = \frac{\Vert h_{b}\Vert ^2}{\Vert h_{b-1}\Vert ^2} \cdots \frac{\Vert h_{a}\Vert ^2}{\Vert h_{a-1}\Vert ^2}\Vert h_{a-1}\Vert ^2\le \frac{\Vert g_{b}\Vert ^2}{\Vert h_{b-1}\Vert ^2} \cdots \frac{\Vert g_{a}\Vert ^2}{\Vert h_{a-1}\Vert ^2}\Vert h_{a-1}\Vert ^2. \end{aligned}$$

Taking the logarithm of both sides, we have

$$\begin{aligned} \log {\Vert h_{b}\Vert ^2}\le \sum _{l=a}^{b}\log \Delta _{l},\quad \quad \text {where } \Delta _{l} := \frac{\Vert g_{l}\Vert ^2}{\Vert h_{l-1}\Vert ^2}. \end{aligned}$$
(9)

Letting \(\tilde{h}_{l-1} := \frac{h_{l-1}}{\Vert h_{l-1}\Vert }\), we obtain

$$\begin{aligned} \begin{aligned} \log {\Delta _{l}}&= \log \left( 1 + 2\tau \left\langle \tilde{h}_{l-1},{\varvec{W}}_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}_{l}\tilde{h}_{l-1}\Vert ^{2}\right) \\&\le 2\tau \left\langle \tilde{h}_{l-1},{\varvec{W}}_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}_{l}\tilde{h}_{l-1}\Vert ^{2}, \end{aligned} \end{aligned}$$

where the inequality is due to the fact that \(\log (1+x) \le x\) for all \(x>-1\). Let \(\xi _{l} := 2\tau \left\langle \tilde{h}_{l-1},{\varvec{W}}_{l}\tilde{h}_{l-1} \right\rangle\) and \(\zeta _{l}:= \tau ^{2}\Vert {\varvec{W}}_{l}\tilde{h}_{l-1}\Vert ^{2}\). Then, given \(h_{l-1}\), we have \(\xi _{l}\sim \mathcal {N}\left( 0, \frac{4\tau ^2}{m}\right)\) and \(\zeta _{l}\sim \frac{\tau ^{2}}{m}\chi _{m}^2\) because of the random initialization of \({\varvec{W}}_l\). We see that

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\log \Delta _{l}\ge c_1\right)&\le \mathbb {P}\left( \sum \limits _{l=a}^{b}\xi _{l}\ge \frac{c_1}{2}\right) + \mathbb {P}\left( \sum \limits _{l=a}^{b}\zeta _{l}\ge \frac{c_1}{2}\right) . \end{aligned} \end{aligned}$$
(10)

Next we bound the two terms on the right hand side one by one. For the first term we have

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\xi _{l}\ge \frac{c_1}{2}\right) =\mathbb {P}\left( \exp \left( \lambda \sum \limits _{l=a}^{b}\xi _{l}\right) \ge \exp \left( \frac{\lambda c_1}{2}\right) \right) \le \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b}\xi _{l} - \frac{\lambda c_1}{2}\right) \right] , \end{aligned} \end{aligned}$$
(11)

where \(\lambda\) is any positive number and the last inequality uses Markov's inequality. Moreover,

$$\begin{aligned} \begin{aligned} \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b}\xi _{l}\right) \right]&= \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b-1}\xi _{l}\right) \,\mathbb {E}\left[ \exp \left( \lambda \xi _{b}\right) \Big | \mathcal {F}_{b-1}\right] \right] \\&= \exp \left( \frac{4\tau ^{2}\lambda ^{2}}{m}\right) \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b-1}\xi _{l}\right) \right] \\&=\cdots =\exp \left( \frac{4\tau ^{2}\lambda ^{2}(b - a + 1)}{m}\right) . \end{aligned} \end{aligned}$$
(12)

Hence we obtain

$$\begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\xi _{l}\ge \frac{c_1}{2}\right) \le \exp \left( \frac{4m^2c_1^2\tau ^{2}(b - a + 1)}{256m\tau ^4L^2} -\frac{mc_1^{2}}{32\tau ^{2}L}\right) = \exp \left( -\frac{mc_1^{2}}{64\tau ^{2}L}\right) , \end{aligned}$$
(13)

by choosing \(\lambda = \frac{mc_1}{16\tau ^{2}L}\) and using \(b-a+1 \le L\). By the symmetry of \(\sum _{l=a}^{b}\xi _{l}\), the same bound extends to \(|\sum _{l=a}^{b}\xi _{l}|\), namely \(\mathbb {P}\left( \left| \sum \limits _{l=a}^{b}\xi _{l}\right| \ge \frac{c_1}{2}\right) \le 2\exp \left( -\frac{mc_1^{2}}{64\tau ^{2}L}\right)\).

Then, for the second term, we follow the above procedure but for a \(\chi _m^2\) variable. We note that the moment generating function of \(\chi _{m}^{2}\) is \((1-2t)^{-m/2}\) for \(t<1/2\). We will use the inequality \((1-\frac{x}{m})^{-m}\le e^{x}\) for \(x\ge 0\). By Markov's inequality, we first have for any \(\lambda >0\),

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\zeta _{l}\ge \frac{c_1}{2}\right) =\mathbb {P}\left( \exp \left( \lambda \sum \limits _{l=a}^{b}\zeta _{l}\right) \ge \exp \left( \frac{\lambda c_1}{2}\right) \right) \le \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b}\zeta _{l} - \frac{\lambda c_1}{2}\right) \right] . \end{aligned} \end{aligned}$$
(14)

Then we have

$$\begin{aligned} \begin{aligned} \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b}\zeta _{l}\right) \right]&= \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b-1}\zeta _{l}\right) \,\mathbb {E}\left[ \exp \left( \lambda \zeta _{b}\right) \Big | \mathcal {F}_{b-1}\right] \right] \\&=\left( 1-\frac{\lambda \tau ^2}{m/2}\right) ^{-m/2}\mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b-1}\zeta _{l}\right) \right] \\&\le \exp (\lambda \tau ^2) \mathbb {E}\left[ \exp \left( \lambda \sum \limits _{l=a}^{b-1}\zeta _{l}\right) \right] \\&\le \cdots \le \exp \left( \lambda \tau ^2(b - a + 1)\right) . \end{aligned} \end{aligned}$$
(15)

Hence we obtain

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\zeta _{l}\ge \frac{c_1}{2}\right) \le \exp \left( \lambda \tau ^2(b - a + 1)- \frac{\lambda c_1}{2}\right) \le \exp \left( -\frac{mc_1^2}{2\tau ^2L}\left( 1-\frac{2\tau ^2L}{c_1}\right) \right) , \end{aligned} \end{aligned}$$
(16)

by choosing \(\lambda = \frac{mc_1}{\tau ^2L}\) and using \(b-a+1 \le L\). If we further choose \(\tau\) such that \(\tau ^2 L\le \frac{c_1}{4}\), we have

$$\begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\zeta _{l}\ge \frac{c_1}{2}\right) \le \exp \left( -\frac{mc_1^2}{4\tau ^2L}\right) . \end{aligned}$$
(17)

Combining (13) and (17), we obtain \(\mathbb {P}\left( \sum \limits _{l=a}^{b}\log \Delta _{l}\ge c_1\right) \le 3\exp \left( -\frac{mc_1^2}{64\tau ^2L}\right)\) under the condition \(\tau ^2 L\le \frac{c_1}{4}\). Hence we have \(\mathbb {P}\left( \Vert h_b\Vert \ge 1+c\right) \le \mathbb {P}\left( \sum \limits _{l=a}^{b}\log \Delta _{l}\ge 2\log (1+c)\right) \le 3\exp \left( -\frac{m\log ^2(1+c)}{16\tau ^2L}\right)\) under the condition that \(\tau ^2 L\le \frac{1}{2} \log (1+c)\). We next use an \(\epsilon\)-net argument to prove the claim for all m-dimensional vectors \(h_{a-1}\). Let \(\mathcal {N}_{\epsilon }\) be an \(\epsilon\)-net over the unit ball in \(\mathbb {R}^m\) with \(\epsilon < 1\); then its cardinality satisfies \(|\mathcal {N}_{\epsilon }|\le (1+2/\epsilon )^m\). Taking a union bound over all vectors \(h_{a-1}\) in the net \(\mathcal {N}_{\epsilon }\), we obtain

$$\begin{aligned} \mathbb {P}\left\{ \max _{h_{a-1}\in \mathcal {N}_{\epsilon }}\Vert h_b\Vert > 1+c \right\}&\le (1+2/\epsilon )^m \cdot 3\exp \left( -\frac{m\log ^2(1+c)}{16\tau ^2L}\right) \\&= 3\exp \left( -m \left( \frac{\log ^2(1+c)}{16\tau ^2L} - \log (1+2/\epsilon )\right) \right) \le 3\exp \left( -m\right) , \end{aligned}$$

where the last inequality is obtained by choosing \(\tau\) appropriately so that \(\frac{\log ^2 (1+c)}{16\tau ^2L} -\log (1+2/\epsilon )>1\). Then we have the spectral norm bound

$$\begin{aligned} \left\| {\varvec{D}}_b\left( \varvec{I}+\tau {\varvec{W}}_{b}^{(0)}\right) {\varvec{D}}_{b-1}\cdots {\varvec{D}}_{a}\left( \varvec{I}+\tau {\varvec{W}}_{a}^{(0)}\right) \right\| \le (1-\epsilon )^{-1} \max _{h_{a-1}\in \mathcal {N}_\epsilon } \Vert h_b\Vert . \end{aligned}$$

This follows from the standard argument: for a matrix \(\varvec{M}\), let \(v_i\) be the vector in the net closest to a unit vector v; then \(\Vert \varvec{M}v\Vert \le \Vert \varvec{M}v_i \Vert + \Vert \varvec{M}(v-v_i)\Vert \le \Vert \varvec{M}v_i \Vert + \epsilon \Vert \varvec{M}\Vert\), and hence, taking the supremum over v, one obtains \((1-\epsilon ) \Vert \varvec{M}\Vert \le \max _i \Vert \varvec{M}v_i\Vert\).

Finally, taking a union bound over all a, b with \(1\le a\le b <L\) and over all samples \(i\in [n]\), we obtain the claimed result. \(\square\)
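
To illustrate the statement numerically, the following numpy sketch forms the interlayer product in (8) with \(\tau =1/\sqrt{L}\), using the activation patterns generated by one random unit-norm input, and prints its spectral norm. The width, depth, and the i.i.d. \(\mathcal {N}(0,1/m)\) weight scale are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
m, L = 256, 64
tau = 1.0 / np.sqrt(L)

# Build D_b(I + tau*W_b) ... D_a(I + tau*W_a) for a = 1, b = L, where the
# activation patterns D_l come from running one random unit-norm input forward.
h = rng.standard_normal(m)
h /= np.linalg.norm(h)
M = np.eye(m)
for _ in range(L):
    W = rng.standard_normal((m, m)) / np.sqrt(m)   # assumed N(0, 1/m) entries
    g = h + tau * (W @ h)
    D = (g > 0).astype(float)                      # diagonal of the activation matrix
    h = D * g
    M = (D[:, None] * (np.eye(m) + tau * W)) @ M   # left-multiply by D_l (I + tau*W_l)

print("tau^2 * L =", tau ** 2 * L)
print("spectral norm of the product:", np.linalg.norm(M, 2))  # stays O(1), in line with the bound
```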

C Bounded forward/backward process

1.1 C.1 Proof at initialization

Proof

We drop the superscript \(^{(0)}\) for simplicity. First we have

$$\begin{aligned} \Vert h_{i,l}\Vert = \Vert h_{i,0}\Vert \frac{\Vert h_{i,1}\Vert }{\Vert h_{i,0}\Vert }\cdots \frac{\Vert h_{i,l}\Vert }{\Vert h_{i,l-1}\Vert }. \end{aligned}$$
(18)

Then we see

$$\begin{aligned} \begin{aligned} \log {\Vert h_{i,l}\Vert ^{2}}&= \log {\Vert h_{i,0}\Vert ^{2}} + \sum \limits _{a = 1}^{l}\log {\frac{\Vert h_{i, a}\Vert ^{2}}{\Vert h_{i,a-1}\Vert ^{2}}}= \log {\Vert h_{i,0}\Vert ^{2}} + \sum \limits _{a = 1}^{l}\log \left( 1 + \frac{\Vert h_{i, a}\Vert ^{2} - \Vert h_{i,a-1}\Vert ^{2}}{\Vert h_{i,a-1}\Vert ^{2}}\right) . \end{aligned} \end{aligned}$$
(19)

We introduce notation \(\Delta _{a}:=\frac{\Vert h_{i,a}\Vert ^{2} - \Vert h_{i,a-1}\Vert ^{2}}{\Vert h_{i,a-1}\Vert ^{2}}\). We next give a lower bound on \(\Delta _{a}\). Let S be the set \(\{k: k\in [m] \text { and } (h_{i,a-1})_{k} +\tau ({\varvec{W}}_{a}h_{i,a-1})_{k} >0\}\). We have that

$$\begin{aligned} \Delta _{a}&= \frac{1}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k\in S}\left[ (h_{i,a-1})_{k}^{2} + 2\tau (h_{i,a-1})_{k}({\varvec{W}}_{a}h_{i,a-1})_{k} + (\tau {\varvec{W}}_{a}h_{i,a-1})_{k}^{2}\right] - \frac{1}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k=1}^{m}(h_{i,a-1})_{k}^{2} \nonumber \\&= -\frac{1}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k\notin S}(h_{i,a-1})_{k}^{2} + \frac{1}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k\in S}\tau ^{2}({\varvec{W}}_{a}h_{i,a-1})_{k}^{2} + \frac{2}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k\in S}\tau (h_{i,a-1})_{k}({\varvec{W}}_{a}h_{i,a-1})_{k}\nonumber \\&\ge -\frac{1}{\Vert h_{i,a-1}\Vert ^{2}}\sum \limits _{k=1}^{m}(\tau {\varvec{W}}_{a}h_{i,a-1})^{2} + \frac{2}{\Vert h_{i,a-1}\Vert ^{2}}\tau \sum \limits _{k=1}^{m}(h_{i,a-1})_{k}({\varvec{W}}_{a}h_{i,a-1})_{k}\nonumber \\&= -\frac{\Vert \tau {\varvec{W}}_{a}h_{i,a-1}\Vert ^{2}}{\Vert h_{i,a-1}\Vert ^{2}} + \frac{2\tau \left\langle h_{i,a-1},{\varvec{W}}_{a}h_{i,a-1} \right\rangle }{\Vert h_{i,a-1}\Vert ^{2}}, \end{aligned}$$
(20)

where the inequality is due to the fact that for \(k\notin S\), \(|(h_{i,a-1})_{k}|<|(\tau {\varvec{W}}_{a}h_{i,a-1})_{k}|\) and \((h_{i,a-1})_{k}({\varvec{W}}_{a}h_{i,a-1})_{k}\le 0\). Let \(\xi _{a}:=\frac{2\tau \left\langle h_{i,a-1},{\varvec{W}}_{a}h_{i,a-1} \right\rangle }{\Vert h_{i,a-1}\Vert ^{2}}\) and \(\zeta _{a}:=\frac{\Vert \tau {\varvec{W}}_{a}h_{i,a-1}\Vert ^{2}}{\Vert h_{i,a-1}\Vert ^{2}}\), then \(\Delta _{a} \ge \xi _{a}- \zeta _{a}\). We note that given \(h_{i,a-1}\), \(\xi _{a}\sim \mathcal {N}\left( 0, \frac{4\tau ^2}{m}\right) \) and \(\zeta _{a}\sim \frac{\tau ^{2}}{m}\chi _{m}^2\). We use a tail bound for a \(\chi ^2_m\) variable X (see Lemma 1 in Laurent and Massart (2000))

$$\begin{aligned} \mathbb {P}\left( |X-m|\ge u\right) \le e^{-\frac{u^{2}}{4m}}. \end{aligned}$$
(21)

By applying the tail bounds for Gaussian and chi-square variables, for a constant \(c_0\) such that \(4\tau ^2\le c_0 \), we have

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \Delta _{a}< -c_0\right)&= \mathbb {P}\left( \Delta _{a}< -c_0\text { and } \xi _{a}< -\frac{c_0}{2}\right) + \mathbb {P}\left( \Delta _{a}< -c_0\text { and } \xi _{a} \ge -\frac{c_0}{2}\right) \\&\le \mathbb {P}\left( \xi _{a}< -\frac{c_0}{2}\right) + \mathbb {P}\left( \zeta _{a}>\frac{c_0}{2}\right) \\&=\frac{1}{2}\exp \left( -\frac{mc_0^2}{32\tau ^2}\right) + \exp \left( -\frac{mc_0^2}{16\tau ^4}\right) \\&<\exp \left( -\frac{mc_0^2}{32\tau ^2}\right) . \end{aligned} \end{aligned}$$
(22)

Thus, by choosing \(c_0 = 0.5\), we have \(\mathbb {P}\left( \Delta _a \ge -0.5, \forall a\in [L-1]\right) \ge 1- L\exp \left( -\frac{m}{128\tau ^2}\right) \). On the event \(\{\Delta _a \ge -0.5, \forall a \in [L-1]\}\), we can use the relation \(\log (1 + x)\ge x - x^{2}\) for \(x\ge -0.5\) and have

$$\begin{aligned} \begin{aligned} \log {\Vert h_{i,l}\Vert ^{2}}\ge \log {\Vert h_{i,0}\Vert ^{2}} + \sum \limits _{a = 1}^{l}\left( \Delta _{a} - \Delta _{a}^{2}\right) . \end{aligned} \end{aligned}$$
(23)

Due to (13) and (17), we have for any \(c_1>0\), and \(\tau ^2L\le c_1/4\),

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{l=a}^{b}\xi _{l}\ge \frac{c_1}{2}\right) \le&\exp \left( -\frac{mc_1^{2}}{64\tau ^{2}L}\right) ,\; \mathbb {P}\left( \sum \limits _{l=a}^{b}\xi _{l}<- \frac{c_1}{2}\right) \le \exp \left( -\frac{mc_1^{2}}{64\tau ^{2}L}\right) ,\\ \mathbb {P}\left( \sum \limits _{l=a}^{b}\zeta _{l}\ge \frac{c_1}{2}\right) \le&\exp \left( -\frac{mc_1^2}{4\tau ^2L}\right) . \end{aligned} \end{aligned}$$
(24)

Thus we have for any \(c_1>0\), and \(\tau ^2L\le c_1/4\),

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum _{a=1}^{l}\Delta _{a}\le -c_1\right)&= \mathbb {P}\left( \sum _{a=1}^{l}\Delta _{a}\le -c_1, \sum _{a=1}^{l}\xi _{a}\ge -\frac{c_1}{2}\right) + \mathbb {P}\left( \sum _{a=1}^{l}\Delta _{a}\le -c_1, \sum _{a=1}^{l}\xi _{a}\le -\frac{c_1}{2}\right) \\&\le \mathbb {P}\left( \sum _{a=1}^{l}\zeta _{a}\ge \frac{c_1}{2}\right) + \mathbb {P}\left( \sum _{a=1}^{l}\xi _{a}\le -\frac{c_1}{2}\right) =2\exp \left( -\frac{mc_1^{2}}{64\tau ^{2}L}\right) . \end{aligned} \end{aligned}$$
(25)

We can derive a similar result that \(\mathbb {P}\left( \sum _{a=1}^{l}\Delta _{a}\ge c_1\right) \le \mathbb {P}\left( \sum _{a=1}^{l}\xi _{a}\ge c_1\right) \le \exp \left( -\frac{mc_1^{2}}{16\tau ^{2}L}\right) \). Letting \(a = b\) in (24), we obtain that for a single \(\Delta _{a}\) and a constant \(c_1\) such that \(4\tau ^2\le c_1 \),

$$\begin{aligned} \mathbb {P}\left( \left| \Delta _{a}\right| \ge c_1\right) \le 2\exp \left( -\frac{mc_1^{2}}{32\tau ^2}\right) . \end{aligned}$$
(26)

In addition, we see that for any \(c_1\) with \(16\tau ^4L\le c_1 \),

$$\begin{aligned} \mathbb {P}\left( \sum \limits _{a=1}^{l}\Delta ^{2}_{a}\ge c_1\right) \le \sum \limits _{a=1}^{l}\mathbb {P}\left( \Delta _{a}^{2}\ge \frac{c_1}{l}\right) = \sum \limits _{a=1}^{l}\mathbb {P}\left( |\Delta _{a}|\ge \sqrt{\frac{c_1}{l}}\right) \le 2l \exp \left( -\frac{mc_1}{32}\right) . \end{aligned}$$
(27)

Thus, similarly to (25), we obtain for any \(c_1>0\) with \(8\tau ^2L< c_1\),

$$\begin{aligned} \begin{aligned} \mathbb {P}\left( \sum \limits _{a=1}^{l}\left( \Delta _{a} - \Delta ^{2}_{a}\right) \le -c_1\right)&\le \mathbb {P}\left( \sum _{a=1}^{l}\Delta _{a}\le -\frac{c_1}{2}\right) + \mathbb {P}\left( \sum \limits _{a=1}^{l}\Delta ^{2}_{a}\ge \frac{c_1}{2}\right) \\&\le 2\exp \left( -\frac{mc_1^{2}}{256\tau ^{2}L}\right) + 2(L-1)\exp \left( -\frac{mc_1}{64}\right) \\&\le 2L\exp \left( -\frac{mc_1}{64}\right) \end{aligned} \end{aligned}$$
(28)

Thus, on the event \(\{\Delta _a \ge -0.5, \forall a\in [L-1]\}\), we have for any \(c_1>0\) with \(8\tau ^2L< c_1\),

$$\begin{aligned} \mathbb {P}\left( \log {\Vert h_{i,l}\Vert ^{2}} \le -c_1 \right) \le \mathbb {P}\left( \log {\Vert h_{i,0}\Vert ^{2}} + \sum \limits _{a = 1}^{l}\left( \Delta _{a} - \Delta _{a}^{2}\right) \le -c_1\right) \le 2L\exp \left( -\frac{mc_1}{64}\right) . \end{aligned}$$
(29)

We thus reach the conclusion \(\mathbb {P}\left( \Vert h_{i,l}\Vert < 1-c\right) = \mathbb {P}\left( \log {\Vert h_{i,l}\Vert ^{2}} \le -2\log (1-c)^{-1}\right) \le 2L\exp \left( -\frac{1}{32}m\log (1-c)^{-1}\right) \). Taking a union bound over \(i\in [n]\) and \(l\in [L-1]\), we obtain the claimed result with probability at least \(1-2nL^2\exp \left( -\frac{1}{32}m\log (1-c)^{-1}\right) \) under the condition \(\tau ^2L \le \frac{1}{4}\log (1-c)^{-1}\). \(\square \)
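
The following small simulation (the width and the i.i.d. \(\mathcal {N}(0,1/m)\) weight scale are assumed purely for illustration) tracks \(\Vert h_{i,l}\Vert\) along the forward pass at a random initialization and compares \(\tau =L^{-1/2}\) with the larger \(\tau =L^{-1/4}\).

```python
import numpy as np

def forward_norms(tau: float, L: int = 200, m: int = 256, seed: int = 0):
    """Track ||h_l|| for h_l = relu((I + tau*W_l) h_{l-1}) at random initialization."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal(m)
    h /= np.linalg.norm(h)                             # ||h_0|| = 1
    norms = []
    for _ in range(L):
        W = rng.standard_normal((m, m)) / np.sqrt(m)   # assumed N(0, 1/m) entries
        h = np.maximum(h + tau * (W @ h), 0.0)         # one residual block
        norms.append(np.linalg.norm(h))
    return norms

L = 200
for tau in (L ** -0.5, L ** -0.25):
    norms = forward_norms(tau, L)
    print(f"tau = {tau:.3f}: ||h_L|| = {norms[-1]:.2f}, max_l ||h_l|| = {max(norms):.2f}")
# With tau = 1/sqrt(L) the hidden norms stay Theta(1); with the larger tau they grow with depth.
```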

1.2 C.2 Lemmas and proofs after perturbation

We use \(\overrightarrow{{\varvec{W}}}^{(0)}\) to denote the weight matrices at initialization and use \(\overrightarrow{{\varvec{W}}}'\) to denote the perturbation matrices. Let \(\overrightarrow{{\varvec{W}}} = \overrightarrow{{\varvec{W}}}^{(0)} + \overrightarrow{{\varvec{W}}}'\). We define \(h_{i,l}^{(0)} = \phi ((\varvec{I}+\tau {\varvec{W}}_l^{(0)})h_{i,l-1}^{(0)})\) and \(h_{i,l} = \phi ((\varvec{I}+\tau {\varvec{W}}_l)h_{i,l-1})\) for \(l\in [L-1]\), and \(h_{i,L}^{(0)} = \phi ({\varvec{W}}_L^{(0)}h_{i,L-1}^{(0)})\) and \(h_{i,L} = \phi ({\varvec{W}}_L h_{i,L-1})\). Furthermore, let \(h'_{i,l} := h_{i,l}- h_{i,l}^{(0)}\) and \({\varvec{D}}'_{i,l} := {\varvec{D}}_{i,l} - {\varvec{D}}_{i,l}^{(0)}\). We note that \(\Vert \cdot \Vert _0\) is the number of nonzero entries in \(\cdot\). In the sequel, we will use notation O and \(\Omega\) to simplify the presentation. Then the spectral norm bound after perturbation is as follows.

Lemma 3

Suppose that \(\overrightarrow{{\varvec{W}}}^{(0)}\), \(\varvec{A}\) are randomly generated as in the initialization step, and \({\varvec{W}}'_{1},\dots ,{\varvec{W}}'_{L-1}\in \mathbb {R}^{m\times m}\) are perturbation matrices with \(\Vert {\varvec{W}}'_l\Vert <\tau \omega\) for all \(l\in [L-1]\) for some \(\omega <1\). Suppose \({\varvec{D}}_{i,0},\dots ,{\varvec{D}}_{i,L}\) are diagonal matrices representing the activation status of sample i. If \(\tau ^2L \le O(1)\), then with probability at least \(1-3nL^2\cdot \exp (-\Omega (m))\) over the initialization randomness we have

$$\begin{aligned} \Vert (\varvec{I}+\tau {\varvec{W}}_{b}^{(0)}+\tau {\varvec{W}}'_{b}){\varvec{D}}_{i,b-1}\cdots {\varvec{D}}_{i,a}(\varvec{I}+\tau {\varvec{W}}_{a}^{(0)}+\tau {\varvec{W}}'_{a})\Vert \le O(1). \end{aligned}$$
(30)

Proof

This proof is similar to the proof of Theorem 1. We first establish the claim for one fixed sample \(i\in [n]\) and drop the subscript i for convenience. We will show that, for a vector \(h_{a-1}\) with \(\Vert h_{a-1}\Vert =1\), we have \(\Vert h_b\Vert \le 1+c\) with high probability, where

$$\begin{aligned} h_b = {\varvec{D}}_b (\varvec{I}+\tau {\varvec{W}}^{(0)}_{b}+\tau {\varvec{W}}'_{b}){\varvec{D}}_{b-1}\cdots {\varvec{D}}_{a}(\varvec{I}+\tau {\varvec{W}}^{(0)}_{a}+\tau {\varvec{W}}'_{a}) h_{a-1}. \end{aligned}$$
(31)

Let \(g_l = h_{l-1}+\tau {\varvec{W}}^{(0)}_lh_{l-1} +\tau {\varvec{W}}'_lh_{l-1}\) and \(h_l = {\varvec{D}}_l g_{l}\) for \(l\in \{a,\dots , b\}\). Then we have \(\Vert g_l\Vert \ge \Vert h_l\Vert\) due to the fact that \(\Vert {\varvec{D}}_l\Vert \le 1\). Hence we have

$$\begin{aligned} \Vert h_{b}\Vert ^2 = \frac{\Vert h_{b}\Vert ^2}{\Vert h_{b-1}\Vert ^2} \cdots \frac{\Vert h_{a}\Vert ^2}{\Vert h_{a-1}\Vert ^2}\Vert h_{a-1}\Vert ^2\le \frac{\Vert g_{b}\Vert ^2}{\Vert h_{b-1}\Vert ^2} \cdots \frac{\Vert g_{a}\Vert ^2}{\Vert h_{a-1}\Vert ^2}\Vert h_{a-1}\Vert ^2. \end{aligned}$$

Taking the logarithm of both sides, we have

$$\begin{aligned} \log {\Vert h_{b}\Vert ^2}\le \sum _{l=a}^{b}\log \Delta _{l},\quad \quad \text {where } \Delta _{l} := \frac{\Vert g_{l}\Vert ^2}{\Vert h_{l-1}\Vert ^2}. \end{aligned}$$
(32)

Letting \(\tilde{h}_{l-1} := \frac{h_{l-1}}{\Vert h_{l-1}\Vert }\), we obtain

$$\begin{aligned} \begin{aligned} \log {\Delta _{l}}&= \log \left( 1 + 2\tau \left\langle \tilde{h}_{l-1},{\varvec{W}}^{(0)}_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}^{(0)}_{l}\tilde{h}_{l-1}\Vert ^{2} + 2\tau \left\langle (\varvec{I}+ \tau {\varvec{W}}^{(0)}_{l}) \tilde{h}_{l-1}, {\varvec{W}}'_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}'_{l}\tilde{h}_{l-1}\Vert ^{2}\right) \\&\le 2\tau \left\langle \tilde{h}_{l-1},{\varvec{W}}^{(0)}_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}^{(0)}_{l}\tilde{h}_{l-1}\Vert ^{2} + 2\tau \left\langle (\varvec{I}+ \tau {\varvec{W}}^{(0)}_{l}) \tilde{h}_{l-1}, {\varvec{W}}'_{l}\tilde{h}_{l-1} \right\rangle + \tau ^{2}\Vert {\varvec{W}}'_{l}\tilde{h}_{l-1}\Vert ^{2}, \end{aligned} \end{aligned}$$

where the inequality is due to the fact that \(\log (1+x) \le x\) for all \(x>-1\). The sum over layers of the first two terms can be bounded as in the proof of Theorem 1. Next we control the last two terms, which involve \({\varvec{W}}'_l\), on the high-probability event \(\{\Vert {\varvec{W}}^{(0)}_l \Vert \le 4 \text { for all } l \in [L-1]\}\):

$$\begin{aligned}&\sum _{l=a}^{b} 2\tau \left\langle (\varvec{I}+ \tau {\varvec{W}}^{(0)}_{l}) \tilde{h}_{l-1}, {\varvec{W}}'_{l}\tilde{h}_{l-1} \right\rangle \le \sum _{l=a}^{b}2 \tau \Vert \varvec{I}+ \tau {\varvec{W}}^{(0)}_l \Vert \Vert {\varvec{W}}'_l\Vert \Vert \tilde{h}_{l-1}\Vert ^2\le \sum _{l=a}^{b}2 \tau ^2 \omega (1+4\tau ),\nonumber \\&\quad \sum _{l=a}^{b} \tau ^{2}\Vert {\varvec{W}}'_{l}\tilde{h}_{l-1}\Vert ^{2} \le \sum _{l=a}^{b}2 \tau ^4 \omega ^2. \nonumber \end{aligned}$$
(33)

Hence, given \(\tau ^2 L \le c_1/4\) as in the proof of Theorem 1 and \(\omega\) being a small constant, the above two sums are well controlled, and we obtain a spectral norm bound as claimed. Here the result is established for one fixed \({\varvec{W}}'_l\). At the end of the whole proof, we will see that the number of iterations is \(\Omega (n^2)\). If we take a union bound over all the \({\varvec{W}}'_l\) encountered along the optimization trajectory, the overall probability is still at least \(1 - \Omega (n^3 L^2)\exp (-\Omega (m))\). \(\square\)

The output vector of each layer also changes only slightly after perturbation.

Lemma 4

Suppose that \(\omega \le O(1)\) and \(\tau ^2L\le O(1)\). If \(\Vert {\varvec{W}}_{L}'\Vert \le \omega\) and \(\Vert {\varvec{W}}_{l}'\Vert \le \tau \omega\) for \(l\in [L-1]\), then with probability at least \(1-\exp (-\Omega (m\omega ^{\frac {2}{3}}))\), the following bounds on \(h'_{i,l}\) and \({\varvec{D}}'_{i,l}\) hold for all \(i\in [n]\) and all \(l\in [L-1]\),

$$\begin{aligned}&\Vert h'_{i,l}\Vert \le O(\tau ^2 L\omega ),\;\;\Vert {\varvec{D}}'_{i,l}\Vert _{0}\le O\left( m(\omega \tau {L})^{\frac {2}{3}}\right) ,\;\; \Vert h'_{i,L}\Vert \le O(\omega ),\; \; \Vert {\varvec{D}}'_{i,L}\Vert _{0}\le O\left( m\omega ^{\frac {2}{3}}\right) . \end{aligned}$$

Proof

Fixing i and dropping the subscript i, by Claim 8.2 in Allen-Zhu et al. (2018), for \(l\in [L-1]\) there exists a diagonal matrix \({\varvec{D}}''_{l}\) with \(|({\varvec{D}}''_{l})_{k,k}|\le 1\) such that

$$\begin{aligned} h'_{l}&={\varvec{D}}''_{l}\left( (\varvec{I}+\tau {\varvec{W}}_{l}^{(0)}+\tau {\varvec{W}}'_{l})h_{l-1}-(\varvec{I}+\tau {\varvec{W}}_{l}^{(0)})h_{l-1}^{(0)}\right) \nonumber \\&={\varvec{D}}''_{l}\left( (\varvec{I}+\tau {\varvec{W}}_{l}^{(0)}+\tau {\varvec{W}}'_{l})h'_{l-1}+\tau {\varvec{W}}'_{l}h_{l-1}^{(0)}\right) \nonumber \\&={\varvec{D}}''_{l}(\varvec{I}+\tau {\varvec{W}}_{l}^{(0)}+\tau {\varvec{W}}'_{l}){\varvec{D}}''_{l-1}(\varvec{I}+\tau {\varvec{W}}_{l-1}^{(0)}+\tau {\varvec{W}}'_{l-1})h'_{l-2}\nonumber \\&\quad +\tau {\varvec{D}}''_{l}(\varvec{I}+\tau {\varvec{W}}_{l}^{(0)}+\tau {\varvec{W}}'_{l}){\varvec{D}}''_{l-1}{\varvec{W}}'_{l-1}h_{l-2}^{(0)}+\tau {\varvec{D}}''_{l}{\varvec{W}}'_{l}h_{l-1}^{(0)}\nonumber \\&=\cdots \nonumber \\&=\sum _{a=1}^{l}\tau {\varvec{D}}''_{l}(\varvec{I}+\tau {\varvec{W}}_{l}^{(0)}+\tau {\varvec{W}}'_{l})\cdots {\varvec{D}}''_{a+1}(\varvec{I}+\tau {\varvec{W}}_{a+1}^{(0)}+\tau {\varvec{W}}'_{a+1}){\varvec{D}}''_{a}{\varvec{W}}'_{a}h_{a-1}^{(0)}. \end{aligned}$$
(34)

We claim that

$$\begin{aligned} \Vert h'_{l}\Vert \le O(\tau ^2L\omega ) \end{aligned}$$
(35)

due to the fact that \(\Vert {\varvec{D}}''_{l}\Vert \le 1\) and the assumption \(\Vert {\varvec{W}}'_{l}\Vert \le \tau \omega\) for \(l\in [L-1]\). This implies that \(\Vert h'_{i,l}\Vert ,\Vert g'_{i,l}\Vert \le O(\tau ^2L\omega )\) for all \(l\in [L-1]\) and all i with probability at least \(1-O(nL)\cdot \exp (-\Omega (m))\). Going one step further, we have \(\Vert h'_{L}\Vert ,\Vert g'_{L}\Vert \le O(\omega )\).

As for the sparsity \(\Vert {\varvec{D}}'_{l}\Vert _{0}\), we have \(\Vert {\varvec{D}}'_{l}\Vert _{0}\le O(m(\omega \tau L)^{\frac {2}{3}})\) for every \(l\in [L-1]\) and \(\Vert {\varvec{D}}'_{L}\Vert _{0}\le O(m\omega ^{\frac {2}{3}})\).

The argument is as follows (adapted from Claim 5.3 in Allen-Zhu et al. (2018)).

We first study the case \(l\in [L-1]\). We see that if \(({\varvec{D}}'_{l})_{j,j}\ne 0\) one must have \(|(g'_{l})_{j}|>|(g_{l}^{(0)})_{j}|\).

We note that \((g_{l}^{(0)})_{j}=(h_{l-1}^{(0)}+\tau {\varvec{W}}_{l}^{(0)}h_{l-1}^{(0)})_{j}\sim \mathcal {N}\left( (h_{l-1}^{(0)})_{j},\frac{\tau ^{2}\Vert h_{l-1}^{(0)}\Vert ^{2}}{m}\right)\). Let \(\xi \le \frac{1}{\sqrt{m}}\) be a parameter to be chosen later. Let \(S_{1}\subseteq [m]\) be the index set \(S_{1}:=\{j:|(g_{l}^{(0)})_{j}|\le \xi \tau \}\). We have \(\mathbb {P}\{|(g_{l}^{(0)})_{j}|\le \xi \tau \}\le O(\xi \sqrt{m})\) for each \(j\in [m]\). By a Chernoff bound, with probability at least \(1-\exp (-\Omega (m^{3/2}\xi ))\) we have

$$\begin{aligned} |S_{1}|\le O(\xi m^{3/2}). \end{aligned}$$

Let \(S_{2}:=\{j:j\notin S_{1},\ \text {and }({\varvec{D}}'_{l})_{j,j}\ne 0\}\). Then for \(j\in S_{2}\), we have \(|(g'_{l})_{j}|>\xi \tau\). As we have proved that \(\Vert g'_{l}\Vert \le O(\tau ^2L\omega )\), we have

$$\begin{aligned} |S_{2}|\le \frac{\Vert g'_{l}\Vert ^{2}}{(\xi \tau )^{2}}=O((\omega \tau L)^{2}/\xi ^{2}). \end{aligned}$$

Choosing \(\xi\) to minimize \(|S_{1}|+|S_{2}|\), we have \(\xi =(\omega \tau L)^{\frac {2}{3}}/\sqrt{m}\) and consequently, \(\Vert {\varvec{D}}'_{l}\Vert _{0}\le O(m(\omega \tau L)^{\frac {2}{3}})\). Similarly, we have \(\Vert {\varvec{D}}'_{L}\Vert _{0}\le O(m\omega ^{\frac {2}{3}})\). \(\square\)

We next bound the norm of a sparse vector after the ResNet mapping.

Lemma 5

Suppose that \(s\ge \Omega (d/\log m), \tau ^2L\le O(1)\). If \({\varvec{W}}_l\) for \(l\in [L]\) satisfy the condition as in Lemma 3, then for all \(i\in [n]\) and \(a\in [L]\) and for all s-sparse vectors \(u\in \mathbb {R}^{m}\) and for all \(v\in \mathbb {R}^{d}\), the following bound holds with probability at least \(1-(nL)\cdot \exp (-\Omega (s\log m))\)

$$\begin{aligned} |v^T{\varvec{B}}{\varvec{D}}_{i, L}{\varvec{W}}_{L}{\varvec{D}}_{i, L-1}(\varvec{I}+\tau {\varvec{W}}_{L-1})\cdots {\varvec{D}}_{i, a}(\varvec{I}+\tau {\varvec{W}}_{a})u|\le O\left( \frac{\sqrt{s\log m}}{\sqrt{d}}\Vert u\Vert \Vert v\Vert \right) , \end{aligned}$$
(36)

where \({\varvec{D}}_{i, a}\) is diagonal activation matrix for sample i.

Proof

For any fixed vector \(u\in \mathbb {R}^{m}\), \(\Vert {\varvec{D}}_{i,L}{\varvec{W}}_{L}{\varvec{D}}_{i,L-1}(\varvec{I}+\tau {\varvec{W}}_{L-1})\cdots {\varvec{D}}_{i,a}(\varvec{I}+\tau {\varvec{W}}_{a})u\Vert \le 1.1 \Vert u\Vert\) holds with probability at least \(1-\exp (-\Omega (m))\) because of Lemma 3.

On the above event, for a fixed vector \(v\in \mathbb {R}^{d}\) and any fixed \({\varvec{W}}_{l}\) for \(l\in [L]\), the randomness comes only from \({\varvec{B}}\); hence \(v^{T}{\varvec{B}}{\varvec{D}}_{i,L}{\varvec{W}}_{L}{\varvec{D}}_{i,L-1}(\varvec{I}+\tau {\varvec{W}}_{L-1})\cdots {\varvec{D}}_{i,a}(\varvec{I}+\tau {\varvec{W}}_{a})u\) is a Gaussian variable with mean 0 and variance no larger than \(1.1^2\Vert u\Vert ^2\cdot \Vert v\Vert ^2/d\). Hence

$$\begin{aligned}&\mathbb {P}\big \{ |v^{T}{\varvec{B}}{\varvec{D}}_{i,L}{\varvec{W}}_{L}{\varvec{D}}_{i,L-1}(\varvec{I}+\tau {\varvec{W}}_{L-1})\cdots {\varvec{D}}_{i,a}(\varvec{I}+\tau {\varvec{W}}_{a})u|\ge \sqrt{s\log m}\cdot \Omega (\Vert u\Vert \Vert v\Vert /\sqrt{d})\big \}\\&\quad =\text {erfc}(\Omega (\sqrt{s\log m}))\le \exp (-\Omega (s\log m)). \end{aligned}$$

Taking an \(\epsilon\)-net over all s-sparse vectors u and all d-dimensional vectors v, if \(s\ge \Omega (d/\log m)\), then with probability \(1-\exp (-\Omega (s\log m))\) the claim holds for all s-sparse vectors u and all d-dimensional vectors v. Further taking a union bound over all \(i\in [n]\) and \(a\in [L]\) proves the lemma. \(\square\)

D Gradient lower/upper bounds and their proofs

Because the gradient is data-dependent and can behave pathologically, in order to bound it we need to consider all possible points and all configurations of the data. Hence we first introduce an arbitrary loss vector; the gradient bound is then obtained by taking a union bound.

We define the operator \(\mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v, \cdot )\). It back-propagates a vector v to its second argument, which can be either the intermediate output \(h_l\) or the parameter \({\varvec{W}}_l\) of a specific layer l, using the forward-propagation state of input i through the network with parameters \(\overrightarrow{{\varvec{W}}}\). Specifically,

$$\begin{aligned} \mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v, h_l)&:= (\varvec{I}+\tau {\varvec{W}}_{l+1})^T {\varvec{D}}_{i,l+1} \cdots (\varvec{I}+\tau {\varvec{W}}_{L-1})^T{\varvec{D}}_{i,L-1}{\varvec{W}}_L^T{\varvec{D}}_{i,L} {\varvec{B}}^T v,\\ \mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v, {\varvec{W}}_l)&:=\tau \left( {\varvec{D}}_{i,l} (\varvec{I}+\tau {\varvec{W}}_{l+1})^T\cdots (\varvec{I}+\tau {\varvec{W}}_{L-1})^T{\varvec{D}}_{i,L-1}{\varvec{W}}_L^T{\varvec{D}}_{i,L} {\varvec{B}}^T v\right) h_{i, l-1}^T \quad \forall l\in [L-1],\\ \mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v, {\varvec{W}}_L)&:=\left( {\varvec{D}}_{i,L} {\varvec{B}}^T v\right) h_{i, L-1}^T. \end{aligned}$$

Moreover, we introduce

$$\begin{aligned} \mathsf {BP}_{\overrightarrow{{\varvec{W}}}}(\overrightarrow{v}, {\varvec{W}}_l):=\sum _{i=1}^n \mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v_i, {\varvec{W}}_l) \quad \forall l\in [L], \end{aligned}$$

where \(\overrightarrow{v}\) is composed of n vectors \(v_i\) for \(i\in [n]\). If \(v_i\) is the error signal of input i, then \(\nabla _{{\varvec{W}}_l} F_i(\overrightarrow{{\varvec{W}}}) = \mathsf {BP}_{\overrightarrow{{\varvec{W}}},i}({\varvec{B}}h_{i,L}-y_i^* , {\varvec{W}}_l)\).
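
As a concrete check of these formulas (a sketch we add here, not code accompanying the paper), the numpy snippet below runs the forward pass, evaluates \(\mathsf {BP}_{\overrightarrow{{\varvec{W}}}, i}(v, {\varvec{W}}_l)\) exactly as displayed above with \(v = {\varvec{B}}h_{i,L}-y_i^*\), and compares one entry against a finite-difference approximation of \(\nabla _{{\varvec{W}}_l} F_i\). All dimensions and the \(\mathcal {N}(0,1/m)\), \(\mathcal {N}(0,1/d)\) initialization scales are assumptions made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d, L, tau = 40, 5, 8, 8 ** -0.5

Ws = [rng.standard_normal((m, m)) / np.sqrt(m) for _ in range(L)]   # W_1..W_L
B = rng.standard_normal((d, m)) / np.sqrt(d)
h0 = np.abs(rng.standard_normal(m)); h0 /= np.linalg.norm(h0)       # fixed h_{i,0} >= 0
y = rng.standard_normal(d)                                          # target y_i^*

def forward(Ws):
    """Return hidden states h_0..h_L and activation patterns D_1..D_L."""
    hs, Ds = [h0], []
    for l in range(L - 1):                       # residual layers l = 1..L-1
        g = hs[-1] + tau * Ws[l] @ hs[-1]
        Ds.append((g > 0).astype(float)); hs.append(Ds[-1] * g)
    g = Ws[L - 1] @ hs[-1]                       # top layer L has no identity branch
    Ds.append((g > 0).astype(float)); hs.append(Ds[-1] * g)
    return hs, Ds

def bp_grad(v, l, hs, Ds):
    """BP_{W,i}(v, W_l) for 1-indexed layer l, following the displayed formula."""
    x = Ds[L - 1] * (B.T @ v)                    # D_L B^T v
    if l == L:
        return np.outer(x, hs[L - 1])
    x = Ws[L - 1].T @ x                          # W_L^T D_L B^T v
    for k in range(L - 1, l, -1):                # layers L-1, ..., l+1
        x = (np.eye(m) + tau * Ws[k - 1]).T @ (Ds[k - 1] * x)
    return tau * np.outer(Ds[l - 1] * x, hs[l - 1])

def loss(Ws):
    hs, _ = forward(Ws)
    return 0.5 * np.sum((B @ hs[-1] - y) ** 2)   # F_i for the quadratic loss

hs, Ds = forward(Ws)
v = B @ hs[-1] - y                               # error signal of this input
l, (i, j), eps = 3, (1, 2), 1e-6                 # check one entry of one residual layer
Wp = [W.copy() for W in Ws]; Wp[l - 1][i, j] += eps
Wm = [W.copy() for W in Ws]; Wm[l - 1][i, j] -= eps
fd = (loss(Wp) - loss(Wm)) / (2 * eps)
print(f"BP formula: {bp_grad(v, l, hs, Ds)[i, j]:.6f}   finite difference: {fd:.6f}")
```

The two printed numbers agree up to the finite-difference error, reflecting the identity \(\nabla _{{\varvec{W}}_l} F_i(\overrightarrow{{\varvec{W}}}) = \mathsf {BP}_{\overrightarrow{{\varvec{W}}},i}({\varvec{B}}h_{i,L}-y_i^* , {\varvec{W}}_l)\) stated above.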

1.1 D.1 Gradient upper bound

Proof

We ignore the superscript \(^{(0)}\) for simplicity. Then for an \(i\in [n]\) we have

$$\begin{aligned} \left\| \nabla _{{\varvec{W}}_{L}}F_i(\overrightarrow{{\varvec{W}}})\right\| _F =\left\| \left( {\varvec{D}}_{i,L}\partial h_{i,L}\right) h_{i,L-1}^T\right\| _F =\left\| \left( {\varvec{D}}_{i,L}\partial h_{i,L}\right) \right\| \left\| h_{i,L-1}^T\right\| \le \frac{1+c}{1-\epsilon } \Vert \partial h_{i,L}\Vert , \end{aligned}$$

because of Theorem 1. Similarly, we have for \(l\in [L-1]\),

$$\begin{aligned} \left\| \nabla _{{\varvec{W}}_{l}}F_i(\overrightarrow{{\varvec{W}}})\right\| _F&=\left\| \tau \left( {\varvec{D}}_{i,l} (\varvec{I}+\tau {\varvec{W}}_{l+1})^T\cdots (\varvec{I}+\tau {\varvec{W}}_{L-1})^T{\varvec{D}}_{i,L-1}{\varvec{W}}_L^T{\varvec{D}}_{i,L} \partial h_{i,L}\right) h_{i, l-1}^T\right\| _F \\&\le \tau \Vert {\varvec{D}}_{i,l} (\varvec{I}+\tau {\varvec{W}}_{l+1})^T\cdots {\varvec{D}}_{i,L-1}\Vert \cdot \Vert {\varvec{W}}_L^T{\varvec{D}}_{i,L}\Vert \cdot \Vert \partial h_{i,L}\Vert \cdot \Vert h_{i,l-1}\Vert \\&\le \frac{(1+c)^2}{(1-\epsilon )^2}(2\sqrt{2}+c)\tau \Vert \partial h_{i,L}\Vert , \end{aligned}$$

because of Theorem 1 and Lemma 2. \(\square\)

The above upper bounds hold for the initialization \(\overrightarrow{{\varvec{W}}}^{(0)}\) because of Theorem 1 and Theorem 2. They also hold for all the \(\overrightarrow{{\varvec{W}}}\) such that \(\Vert \overrightarrow{{\varvec{W}}}-\overrightarrow{{\varvec{W}}}^{(0)}\Vert \le \omega\) due to Lemma 3.

For the quadratic loss function, we have \(\Vert \partial h_{i,L}\Vert ^2= \Vert {\varvec{B}}^T({\varvec{B}}h_{i,L}-y_{i}^{*})\Vert ^2= O(m/d)F_i(\overrightarrow{{\varvec{W}}})\). We have the gradient upper bound as follows.

Theorem 6

Suppose \(\omega = O(1)\). For every input sample \(i\in [n]\) and for every \(l\in [L-1]\) and for every \(\overrightarrow{{\varvec{W}}}\) such that \(\Vert {\varvec{W}}_L-{\varvec{W}}_L^{(0)}\Vert \le \omega\) and \(\Vert {\varvec{W}}_l-{\varvec{W}}_l^{(0)}\Vert \le \tau \omega\), the following holds with probability at least \(1- O(nL^2)\cdot \exp (-\Omega (m))\) over the randomness of \(\varvec{A},{\varvec{B}}\) and \(\overrightarrow{{\varvec{W}}}^{(0)}\)

$$\begin{aligned} \Vert \nabla _{{\varvec{W}}_{l}}F_i(\overrightarrow{{\varvec{W}}})\Vert _{F}^{2}\le&O\left( \frac{\tau ^{2}m}{d} F_i(\overrightarrow{{\varvec{W}}})\right) , \\ \Vert \nabla _{{\varvec{W}}_{L}}F_i(\overrightarrow{{\varvec{W}}})\Vert _{F}^{2}\le&O\left( \frac{m}{d} F_i(\overrightarrow{{\varvec{W}}})\right) . \end{aligned}$$
(37)

1.2 D.2 Gradient lower bound

For the quadratic loss function, we have the following gradient lower bound.

Theorem 7

Let \(\omega =O\left( \frac{\delta ^{3/2}}{n^{3}\log ^{3}m}\right)\). With probability at least \(1-\exp (-\Omega (m\omega ^{\frac {2}{3}}))\) over the randomness of \(\overrightarrow{{\varvec{W}}}^{(0)},\varvec{A},{\varvec{B}}\), the following holds for every \(\overrightarrow{{\varvec{W}}}\) with \(\Vert \overrightarrow{{\varvec{W}}}-\overrightarrow{{\varvec{W}}}^{(0)}\Vert \le \omega\):

$$\begin{aligned} \Vert \nabla _{{\varvec{W}}_{L}}F(\overrightarrow{{\varvec{W}}})\Vert _{F}^{2}\ge \Omega \left( \frac{F(\overrightarrow{{\varvec{W}}})}{dn/\delta }\times m\right) . \end{aligned}$$
(38)

This gradient lower bound on \(\Vert \nabla _{{\varvec{W}}_{L}}F(\overrightarrow{{\varvec{W}}})\Vert _{F}^{2}\) acts like the gradient dominance condition (Zou and Gu, 2019; Allen-Zhu et al., 2018) except that our range on \(\omega\) does not depend on the depth L.

Proof

The gradient lower bound at initialization is given by Section 6.2 in Allen-Zhu et al. (2018) and Lemma 4.1 in Zou and Gu (2019) via smoothed analysis (Spielman and Teng, 2004): with high probability the gradient is lower bounded, although in the worst case it might be 0. We adopt the same proof as for Lemma 4.1 in Zou and Gu (2019), based on the two preconditioning results Theorem 2 and Lemma 6, and we do not repeat it here.

Now suppose that we have \(\Vert \nabla _{{\varvec{W}}_{L}}F(\overrightarrow{{\varvec{W}}}^{(0)})\Vert _{F}^{2}\ge \Omega \left( \frac{F(\overrightarrow{{\varvec{W}}}^{(0)})}{dn/\delta }\times m\right)\). We next bound the change of the gradient after perturbing the parameter. Recall that

$$\begin{aligned} \mathsf {BP}_{\overrightarrow{{\varvec{W}}}^{(0)}}(\overrightarrow{v}, {\varvec{W}}_L)-\mathsf {BP}_{\overrightarrow{{\varvec{W}}}}(\overrightarrow{v}, {\varvec{W}}_L)=\sum _{i=1}^{n}\Big ((v_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L}^{(0)})^{T}(h_{i,L-1}^{(0)})^{T}-(v_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L})^{T}(h_{i,L-1})^{T}\Big ) \end{aligned}$$

By Lemma 4 and Lemma 5, we know,

$$\begin{aligned} \Vert v_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L}^{(0)}-v_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L}\Vert \le O(\sqrt{m\omega ^{\frac {2}{3}}}/\sqrt{d})\cdot \Vert v_{i}\Vert . \end{aligned}$$

Furthermore, we know

$$\begin{aligned} \Vert v_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L}\Vert \le O(\sqrt{m/d})\cdot \Vert v_{i}\Vert . \end{aligned}$$

By Theorem 2 and Lemma 4, we have

$$\begin{aligned} \Vert h_{i,L-1}^{(0)}\Vert \le 1.1\quad \text {and }\quad \Vert h_{i,L-1}-h_{i,L-1}^{(0)}\Vert \le O(\omega ). \end{aligned}$$

Combining the above bounds, we have

$$\begin{aligned} \Vert \mathsf {BP}_{\overrightarrow{{\varvec{W}}}^{(0)}}(\overrightarrow{v}, {\varvec{W}}_L)-\mathsf {BP}_{\overrightarrow{{\varvec{W}}}}(\overrightarrow{v}, {\varvec{W}}_L)\Vert _{F}^{2} \le n\Vert \overrightarrow{v}\Vert ^{2}\cdot O(\sqrt{m\omega ^{\frac {2}{3}}/d}+\omega \sqrt{m/d})^{2} \le n\Vert \overrightarrow{v}\Vert ^{2}\cdot O\left( \frac{m}{d}\omega ^{\frac {2}{3}}\right) \end{aligned}$$

Hence the gradient lower bound still holds for \(\overrightarrow{{\varvec{W}}}\) given \(\omega <O\left( \frac{\delta ^{3/2}}{n^{3}}\right)\).

Finally, taking an \(\epsilon\)-net over all possible vectors \(\overrightarrow{v}=(v_{1},\dots ,v_{n})\in (\mathbb {R}^{d})^{n}\), we prove that the above gradient lower bound holds for all \(\overrightarrow{v}\). In particular, we can now plug in the choice \(v_{i}={\varvec{B}}h_{i,L}-y_{i}^{*}\), which implies the desired bound on the true gradient. \(\square\)

The gradient lower bound requires the following property.

Lemma 6

For any \(\delta\) and any pair \((x_{i},x_{j})\) satisfying \(\Vert x_{i}-x_{j}\Vert \ge \delta\), the bound \(\Vert h_{i,l}-h_{j,l}\Vert \ge \Omega (\delta )\) holds for all \(l\in [L]\) with probability at least \(1-O(n^{2}L)\cdot \exp (-\Omega (\log ^{2}{m}))\), given that \(\tau \le O(1/(\sqrt{L}\log {m}))\) and \(m\ge \Omega (\tau ^2 L^2\delta ^{-2})\).

The proof of Lemma 6 follows the Appendix C in Allen-Zhu et al. (2018).

E Semi-smoothness for \(\tau \le O(1/\sqrt{L})\)

With the help of Theorem 6 and several other improvements, we can obtain a tighter bound on the semi-smoothness condition of the objective function.

Theorem 8

Let \(\omega = O\left( \frac{\delta ^{3/2}}{n^3L^{7/2}}\right)\) and \(\tau ^2L\le 1\). With high probability, we have for every \(\breve{\overrightarrow{{\varvec{W}}}}\in (\mathbb {R}^{m\times m})^{L}\) with \(\left\| \breve{\overrightarrow{{\varvec{W}}}}-\overrightarrow{{\varvec{W}}}^{(0)}\right\| \le \omega\) and for every \(\overrightarrow{{\varvec{W}}}'\in (\mathbb {R}^{m\times m})^{L}\) with \(\Vert \overrightarrow{{\varvec{W}}}'\Vert \le \omega\), we have

$$\begin{aligned} F(\breve{\overrightarrow{{\varvec{W}}}}+\overrightarrow{{\varvec{W}}}') \le&F(\breve{\overrightarrow{{\varvec{W}}}})+\langle \nabla F(\breve{\overrightarrow{{\varvec{W}}}}),\overrightarrow{{\varvec{W}}}'\rangle +O(\frac{nm}{d})\Vert \overrightarrow{{\varvec{W}}}'\Vert _F^2+O\left( \sqrt{\frac{m}{nd}}\omega ^{\frac{1}{3}}L^{\frac{7}{6}}\right) \Vert \overrightarrow{{\varvec{W}}}'\Vert _F\sqrt{F(\breve{\overrightarrow{{\varvec{W}}}})}. \end{aligned}$$

We will show the semi-smoothness theorem for a more general \(\omega \in \left[ \Omega \left( \left( d/(m\log m)\right) ^{\frac {3}{2}}\right) , O(1)\right]\) and the above high probability is at least \(1-\exp (-\Omega (m\omega ^{\frac {2}{3}}))\) over the randomness of \(\overrightarrow{{\varvec{W}}}^{(0)},\varvec{A},{\varvec{B}}\).

Before going to the proof of the theorem, we introduce a lemma.

Lemma 7

There exist diagonal matrices \({\varvec{D}}''_{i,l}\in \mathbb {R}^{m\times m}\) with entries in [-1,1] such that \(\forall i\in [n],\forall l\in [L-1]\),

$$\begin{aligned} h_{i,l}-\breve{h}_{i,l}=\sum _{a=1}^{l}(\breve{{\varvec{D}}}_{i,l}+{\varvec{D}}''_{i,l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{i,a}+{\varvec{D}}''_{i,a})\tau {\varvec{W}}'_{a}h_{i, a-1}, \end{aligned}$$
(39)

and

$$\begin{aligned} h_{i,L}-\breve{h}_{i,L}=&(\breve{{\varvec{D}}}_{i,L}+{\varvec{D}}''_{i,L}){\varvec{W}}'_{L}h_{i, L-1}\nonumber \\&+\sum _{a=1}^{L-1}(\breve{{\varvec{D}}}_{i,L}+{\varvec{D}}''_{i,L})\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{i,a}+{\varvec{D}}''_{i,a})\tau {\varvec{W}}'_{a}h_{i,a-1}. \end{aligned}$$
(40)

Furthermore, we then have \(\forall l\in [L-1],\Vert h_{i,l}-\breve{h}_{i,l}\Vert \le O(\tau ^2L\omega )\), \(\Vert {\varvec{D}}''_{i,l}\Vert _{0}\le O(m(\omega \tau L)^{\frac {2}{3}})\), and \(\Vert h_{i,L}-\breve{h}_{i,L}\Vert \le O((1+\tau \sqrt{L})\Vert {\varvec{W}}'\Vert _F)\), \(\Vert {\varvec{D}}''_{i,L}\Vert _{0}\le O(m\omega ^{\frac {2}{3}})\) and

$$\begin{aligned} \Vert {\varvec{B}}h_{i,L}-{\varvec{B}}\breve{h}_{i,L}\Vert \le O(\sqrt{m/d})\Vert {\varvec{W}}'\Vert _F \end{aligned}$$

hold with probability \(1-\exp (-\Omega (m\omega ^{\frac {2}{3}}))\) given \(\Vert {\varvec{W}}'_{L}\Vert \le \omega , \Vert {\varvec{W}}'_{l}\Vert \le \tau \omega\) for \(l\in [L-1]\) and \(\omega \le O(1), \tau \sqrt{L}\le 1\).

Proof of Theorem 8

First of all, writing \(\breve{loss}_{i}:={\varvec{B}}\breve{h}_{i,L}-y_{i}^{*}\), we have

$$\begin{aligned} \frac{1}{2}\Vert {\varvec{B}}h_{i,L}-y_{i}^{*}\Vert ^{2}&=\frac{1}{2}\Vert \breve{loss}_{i}+{\varvec{B}}(h_{i,L}-\breve{h}_{i,L})\Vert ^{2}\nonumber \\&=\frac{1}{2}\Vert \breve{loss}_{i}\Vert ^{2}+\breve{loss}_{i}^{T}{\varvec{B}}(h_{i,L}-\breve{h}_{i,L})+\frac{1}{2}\Vert {\varvec{B}}(h_{i,L}-\breve{h}_{i,L})\Vert ^{2}, \end{aligned}$$
(41)

and

$$\begin{aligned} \nabla _{{\varvec{W}}_{l}}F(\overrightarrow{{\varvec{W}}})&=\sum _{i=1}^{n}(loss_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L}{\varvec{W}}_{L}\cdots {\varvec{D}}_{i,l+1}(\varvec{I}+\tau {\varvec{W}}_{l+1}){\varvec{D}}_{i,l})^{T}(\tau h_{i,l-1})^{T},\\ \nabla _{{\varvec{W}}_{L}}F(\overrightarrow{{\varvec{W}}})&=\sum _{i=1}^{n}(loss_{i}^{T}{\varvec{B}}{\varvec{D}}_{i,L})^{T}(h_{i,L-1})^{T}. \end{aligned}$$
(42)

We use the relation that for two matrices A and B, \(\langle A, B\rangle = \text {tr}(A^TB)\). Then we can write

$$\begin{aligned} \langle \nabla _{{\varvec{W}}_{l}}F(\breve{\overrightarrow{{\varvec{W}}}}),{\varvec{W}}'_{l}\rangle = \sum _{i=1}^{n}\breve{loss}_{i}^{T}{\varvec{B}}\breve{{\varvec{D}}}_{i,L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})\breve{{\varvec{D}}}_{i,l}{\varvec{W}}'_{l}(\tau \breve{h}_{i,l-1}). \end{aligned}$$
(43)

Then further by Lemma 7, we have

$$\begin{aligned}&F(\breve{\overrightarrow{{\varvec{W}}}}+\overrightarrow{{\varvec{W}}}')-F(\breve{\overrightarrow{{\varvec{W}}}})-\langle \nabla F(\breve{\overrightarrow{{\varvec{W}}}}),\overrightarrow{{\varvec{W}}}'\rangle \nonumber \\&\quad =-\langle \nabla F(\breve{\overrightarrow{{\varvec{W}}}}),\overrightarrow{{\varvec{W}}}'\rangle +\frac{1}{2}\sum _{i=1}^{n}\Vert {\varvec{B}}h_{i,L}-y_{i}^{*}\Vert ^{2}-\Vert {\varvec{B}}\breve{h}_{i,L}-y_{i}^{*}\Vert ^{2}\nonumber \\&\quad =-\sum _{l=1}^{L}\langle \nabla _{{\varvec{W}}_{l}}F(\breve{\overrightarrow{{\varvec{W}}}}),{\varvec{W}}'_{l}\rangle +\sum _{i=1}^{n}\breve{loss}_{i}^{T}{\varvec{B}}(h_{i,L}-\breve{h}_{i,L})+\frac{1}{2}\Vert {\varvec{B}}(h_{i,L}-\breve{h}_{i,L})\Vert ^{2}\nonumber \\&\quad {\mathop {=}\limits ^{(a)}}\frac{1}{2}\sum _{i=1}^{n}\Vert {\varvec{B}}(h_{i,L}-\breve{h}_{i,L})\Vert ^{2}+\sum _{i=1}^{n}\breve{loss}_{i}^{T}{\varvec{B}}\left( (\breve{{\varvec{D}}}_{i,L}+{\varvec{D}}''_{i,L}){\varvec{W}}'_{L}h_{i,L-1}-(\breve{{\varvec{D}}}_{i,L}){\varvec{W}}'_{L}\breve{h}_{i,L-1}\right) \nonumber \\&\qquad +\sum _{i=1}^{n}\sum _{l=1}^{L-1}\breve{loss}_{i}^{T}{\varvec{B}}\Big ((\breve{{\varvec{D}}}_{i,L}+{\varvec{D}}''_{i,L})\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})(\breve{{\varvec{D}}}_{i,l}+{\varvec{D}}''_{i,l})\tau {\varvec{W}}'_{l}h_{i,l-1}\nonumber \\&\qquad -\breve{{\varvec{D}}}_{i,L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})\breve{{\varvec{D}}}_{i,l}{\varvec{W}}'_{l}(\tau \breve{h}_{i,l-1})\Big ), \end{aligned}$$
(44)

where (a) is due to Lemma 7.

We next bound the RHS of (44). We first use Lemma 7 to get

$$\begin{aligned} \Vert {\varvec{B}}(h_{i,L}-\breve{h}_{i,L})\Vert \le O(\sqrt{m/d})\Vert {\varvec{W}}'\Vert _F. \end{aligned}$$
(45)

Next we calculate that for \(l=L\),

$$\begin{aligned}&\Big |\breve{loss}_{i}^{T}{\varvec{B}}\left( (\breve{{\varvec{D}}}_{i,L}+{\varvec{D}}''_{i,L}){\varvec{W}}'_{L}h_{i,L-1}-(\breve{{\varvec{D}}}_{i,L}){\varvec{W}}'_{L}\breve{h}_{i,L-1}\right) \Big |\nonumber \\&\quad \le \Big |\breve{loss}_{i}^{T}{\varvec{B}}\left( {\varvec{D}}''_{i,L}{\varvec{W}}'_{L}h_{i,L-1}\right) \Big |+\Big |\breve{loss}_{i}^{T}{\varvec{B}}\left( \breve{{\varvec{D}}}_{i,L}{\varvec{W}}'_{L}(h_{i,L-1}-\breve{h}_{i,L-1})\right) \Big |. \end{aligned}$$
(46)

For the first term, by Lemma 5 and Lemma 7, we have

$$\begin{aligned} \left| \breve{loss}_{i}^{T}{\varvec{B}}\left( {\varvec{D}}''_{i,L}{\varvec{W}}'_{L}h_{i,L-1}\right) \right|&\le O\left( \frac{\sqrt{m\omega ^{\frac {2}{3}}}}{\sqrt{d}}\right) \Vert \breve{loss}_{i}\Vert \cdot \Vert {\varvec{W}}'_{L}h_{i,L-1}\Vert \nonumber \\&\le O\left( \frac{\sqrt{m\omega ^{\frac {2}{3}}}}{\sqrt{d}}\right) \Vert \breve{loss}_{i}\Vert \cdot \Vert {\varvec{W}}'_{L}\Vert , \end{aligned}$$
(47)

where the last inequality is due to \(\Vert h_{i,L-1}\Vert \le O(1)\). For the second term, by Lemma 7 we have

$$\begin{aligned}&\left| \breve{loss}_{i}^{T}{\varvec{B}}\left( \breve{{\varvec{D}}}_{i,L}{\varvec{W}}'_{L}(h_{i,L-1}-\breve{h}_{i,L-1})\right) \right| \nonumber \\&\quad \le \Vert \breve{loss}_{i}\Vert \cdot \left\| {\varvec{B}}\breve{{\varvec{D}}}_{i,L}\right\| _{2}\cdot \Vert {\varvec{W}}'_{L}\Vert \Vert h_{i,L-1}-\breve{h}_{i,L-1}\Vert \nonumber \\&\quad \le \Vert \breve{loss}_{i}\Vert \cdot O\left( \frac{\omega \sqrt{m}}{\sqrt{d}}\right) \cdot \Vert {\varvec{W}}'_{L}\Vert , \end{aligned}$$
(48)

where the last inequality is due to the assumption \(\Vert {\varvec{W}}'_{L}\Vert \le \omega\). Similarly, for \(l\in [L-1]\) (ignoring the index i for simplicity), we have

$$\begin{aligned}&\Big |\sum _{l=1}^{L-1}\breve{loss}^{T}\Big ({\varvec{B}}(\breve{{\varvec{D}}}_{L}+{\varvec{D}}''_{L})\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})(\breve{{\varvec{D}}}_{l}+{\varvec{D}}''_{l})\tau {\varvec{W}}'_{l}h_{l-1}-{\varvec{B}}\breve{{\varvec{D}}}_{L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})\breve{{\varvec{D}}}_{l}{\varvec{W}}'_{l}(\tau \breve{h}_{l-1})\Big )\Big |\nonumber \\&\quad \le \Big | \sum _{l=1}^{L-1}\breve{loss}^{T}{\varvec{B}}{\varvec{D}}''_{L}\breve{{\varvec{W}}}_{L}({\varvec{D}}_{L-1}+{\varvec{D}}''_{L-1})(\varvec{I}+\tau \breve{{\varvec{W}}}_{L-1})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})(\tau {\varvec{W}}'_l h_{l-1})\Big |\nonumber \\&\qquad +\Big | \sum _{l=1}^{L-1}\sum _{a=l}^{L-1}\breve{loss}^{T}{\varvec{B}}\breve{{\varvec{D}}}_{L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1}){\varvec{D}}''_{a}(\varvec{I}+\tau \breve{{\varvec{W}}}_{a})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})(\tau {\varvec{W}}'_l h_{l-1})\Big |\nonumber \\&\qquad + \Big |\sum _{l=1}^{L-1}\breve{loss}^{T}{\varvec{B}}\breve{{\varvec{D}}}_{L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})\breve{{\varvec{D}}}_{l}{\varvec{W}}'_{l}\tau (h_{l-1}-\breve{h}_{l-1})\Big |. \end{aligned}$$
(49)

We next bound the terms in (49) one by one. For the first term, by Lemma 5 and Lemma 7, we have

$$\begin{aligned}&\left| \sum _{l=1}^{L-1}\breve{loss}^{T}{\varvec{B}}{\varvec{D}}''_{L}\breve{{\varvec{W}}}_{L}({\varvec{D}}_{L-1}+{\varvec{D}}''_{L-1})(\varvec{I}+\tau \breve{{\varvec{W}}}_{L-1})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})(\tau {\varvec{W}}'_l h_{l-1})\right| \nonumber \\&\quad \le O\left( \frac{\sqrt{m\omega ^{\frac {2}{3}}}}{\sqrt{d}}\right) \left\| \breve{loss}\right\| \cdot \left\| \sum _{l=1}^{L-1}\breve{{\varvec{W}}}_{L}({\varvec{D}}_{L-1}+{\varvec{D}}''_{L-1})(\varvec{I}+\tau \breve{{\varvec{W}}}_{L-1})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})(\tau {\varvec{W}}'_l h_{l-1})\right\| \nonumber \\&\quad {\mathop {\le }\limits ^{(a)}} O\left( \frac{\sqrt{m\omega ^{\frac {2}{3}}}}{\sqrt{d}}\right) \cdot \Vert \breve{loss}\Vert \cdot \tau \sqrt{L} \Vert {\varvec{W}}'_{L-1:1}\Vert _F, \end{aligned}$$
(50)

where \(\Vert {\varvec{W}}'_{L-1:1}\Vert _F=\sqrt{\sum _{l=1}^{L-1}\Vert {\varvec{W}}'_l\Vert _F^2}\), and (a) follows from an argument similar to (56) in the proof of Lemma 7, together with the facts that \(\left\| \breve{{\varvec{W}}}_{L}({\varvec{D}}_{L-1}+{\varvec{D}}''_{L-1})(\varvec{I}+\tau \breve{{\varvec{W}}}_{L-1})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})\right\| = O(1)\) and \(\Vert h_{l-1}\Vert =O(1)\) hold with high probability. We note that inequality (a) saves a \(\sqrt{L}\) factor in our main theorem.

We have a similar bound for the second term of (49):

$$\begin{aligned}&\left| \sum _{l=1}^{L-1}\sum _{a=l}^{L-1}\breve{loss}^{T}{\varvec{B}}\breve{{\varvec{D}}}_{L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1}){\varvec{D}}''_{a}(\varvec{I}+\tau \breve{{\varvec{W}}}_{a})\cdots ({\varvec{D}}_{l}+{\varvec{D}}''_{l})(\tau {\varvec{W}}'_l h_{l-1})\right| \nonumber \\&\quad \le O\left( \frac{\sqrt{m(\omega \tau {L})^{\frac {2}{3}}}}{\sqrt{d}}\right) \cdot \Vert \breve{loss}\Vert \cdot \tau \sum _{a=1}^{L-1} \sqrt{a}\Vert {\varvec{W}}'_{a:1}\Vert _F\nonumber \\&\quad \le O\left( \frac{\sqrt{m(\omega \tau {L})^{\frac {2}{3}}}}{\sqrt{d}}\right) \cdot \Vert \breve{loss}\Vert \cdot \tau L^{3/2}\Vert {\varvec{W}}'_{L-1:1}\Vert _F. \end{aligned}$$
(51)

For the last term in (49), we have

$$\begin{aligned}&\left| \sum _{l=1}^{L-1}\breve{loss}^{T}{\varvec{B}}\breve{{\varvec{D}}}_{L}\breve{{\varvec{W}}}_{L}\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{l+1})\breve{{\varvec{D}}}_{l}{\varvec{W}}'_{l}\tau (h_{l-1}-\breve{h}_{l-1})\right| \nonumber \\&\quad \le \Vert \breve{loss}\Vert \cdot O\left( \sqrt{m/d}\right) \cdot \sum _{l=1}^{L-1}\Vert {\varvec{W}}'_l\Vert \cdot \tau ^3L\omega \nonumber \\&\quad \le \Vert \breve{loss}\Vert \cdot O\left( \sqrt{m/d}\right) \cdot \Vert {\varvec{W}}'_{L-1:1}\Vert _{F}\cdot (\tau ^2 L)^{3/2}, \end{aligned}$$
(52)

where the last inequality is due to the bound on \(\Vert h_{l-1}-\breve{h}_{l-1}\Vert\) in Lemma 7. Hence

$$\begin{aligned} \text{RHS of }(49)&\le O\left( \frac{\sqrt{m(\omega \tau {L})^{\frac {2}{3}}}}{\sqrt{d}}\right) \cdot \Vert \breve{loss}\Vert \cdot \tau L^{3/2}\Vert {\varvec{W}}'_{L-1:1}\Vert _F\nonumber \\&\le O\left( (\tau L)^{\frac {4}{3}}\frac{\sqrt{mL\omega ^{\frac {2}{3}}}}{\sqrt{d}}\right) \cdot \Vert \breve{loss}\Vert \cdot \Vert {\varvec{W}}'_{L-1:1}\Vert _{F}. \end{aligned}$$
(53)

Putting all the above together and using the triangle inequality, we obtain the result. \(\square\)

Proof of Lemma 7

The proof relies on the following lemma.

Lemma 8

(Proposition 8.3 in Allen-Zhu et al. (2018)) Given vectors \(a, b\in \mathbb {R}^m\), let \({\varvec{D}}\in \mathbb {R}^{m\times m}\) be the diagonal matrix with \({\varvec{D}}_{k,k}=\varvec{1}_{a_k\ge 0}\). Then there exists a diagonal matrix \({\varvec{D}}''\in \mathbb {R}^{m\times m}\) with

  • \(|{\varvec{D}}_{k,k}+{\varvec{D}}''_{k,k}|\le 1\) and \(|{\varvec{D}}''_{k,k}|\le 1\) for every \(k\in [m]\),

  • \({\varvec{D}}''_{k,k}\ne 0\) only when \(\varvec{1}_{a_k\ge 0}\ne \varvec{1}_{b_k\ge 0}\),

  • \(\phi (a)-\phi (b) = ({\varvec{D}}+{\varvec{D}}'')(a-b)\).

Fixing index i and ignoring the subscript in i for simplicity, by Lemma 8, for each \(l\in [L-1]\) there exists a \({\varvec{D}}''_{l}\) such that \(|({\varvec{D}}''_{l})_{k,k}|\le 1\) and

$$\begin{aligned} h_l - \breve{h}_{l}&=\phi ((\varvec{I}+\tau \breve{{\varvec{W}}}_{l}+\tau {\varvec{W}}'_{l})h_{l-1})-\phi ((\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\breve{h}_{l-1})\nonumber \\&=(\breve{{\varvec{D}}}_l+{\varvec{D}}''_{l})\left( (\varvec{I}+\tau \breve{{\varvec{W}}}_{l}+\tau {\varvec{W}}'_{l})h_{l-1}-(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\breve{h}_{l-1}\right) \nonumber \\&=(\breve{{\varvec{D}}}_l+{\varvec{D}}''_{l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})(h_{l-1} - \breve{h}_{l-1}) + (\breve{{\varvec{D}}}_l+{\varvec{D}}''_{l}) \tau {\varvec{W}}'_{l}h_{l-1}\nonumber \\&= \sum _{a=1}^{l} (\breve{{\varvec{D}}}_{l}+{\varvec{D}}''_{l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{a}+{\varvec{D}}''_{a})\tau {\varvec{W}}'_{a}h_{a-1}\nonumber \end{aligned}$$
(54)

Then we have the following properties. For \(l\in [L-1]\), \(\Vert h_l - \breve{h}_{l}\Vert \le O(\tau ^2 L \omega )\). This is because \(\Vert (\breve{{\varvec{D}}}_{l}+{\varvec{D}}''_{l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{a}+{\varvec{D}}''_{a})\Vert \le 1.1\) from Lemma 3; \(\Vert h_{a-1}\Vert \le O(1)\) from Theorem 2; and the assumption \(\Vert {\varvec{W}}'_l\Vert \le \tau \omega\) for \(l\in [L-1]\).
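Lemma 8 and the decomposition above can be checked numerically. The following is a minimal Python/NumPy sketch (not part of the original proof; the variable names mirror the lemma and are otherwise hypothetical) that constructs the diagonal correction \({\varvec{D}}''\) for random vectors a, b and verifies \(\phi (a)-\phi (b)=({\varvec{D}}+{\varvec{D}}'')(a-b)\) together with the two magnitude properties.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 8
a, b = rng.standard_normal(m), rng.standard_normal(m)

relu = lambda x: np.maximum(x, 0.0)

# D has D_{k,k} = 1{a_k >= 0}, as in Lemma 8.
D = np.diag((a >= 0).astype(float))

# D'' is nonzero only where a_k and b_k have different signs, chosen so that
# (D + D'')_{k,k} * (a_k - b_k) = relu(a_k) - relu(b_k).
Dpp = np.zeros((m, m))
for k in range(m):
    if (a[k] >= 0) != (b[k] >= 0) and a[k] != b[k]:
        Dpp[k, k] = (relu(a[k]) - relu(b[k])) / (a[k] - b[k]) - D[k, k]

# Check the three properties stated in Lemma 8.
assert np.all(np.abs(np.diag(D + Dpp)) <= 1 + 1e-12)
assert np.all(np.abs(np.diag(Dpp)) <= 1 + 1e-12)
assert np.allclose(relu(a) - relu(b), (D + Dpp) @ (a - b))
print("Lemma 8 identity verified")
```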

To have a tighter bound on \(\Vert h_L-\breve{h}_L\Vert\), let us introduce \({\varvec{W}}''_{b}:= \sum _{a=b}^{l} (\breve{{\varvec{D}}}_{l}+{\varvec{D}}''_{l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{a}+{\varvec{D}}''_{a}){\varvec{W}}'_{a}\), for \(b=1, ..., l\). Then we have

$$\begin{aligned} h_L - \breve{h}_{L} =\left[ {\varvec{W}}''_L, {\varvec{W}}''_{L-1}, ..., {\varvec{W}}''_{1}\right] [ h_{L-1}^T, \tau h_{L-2}^T, ..., \tau h_{0}^T]^T. \end{aligned}$$
(55)

It is easy to get

$$\begin{aligned} \Vert [\tau h_{l-1}^T, \tau h_{l-2}^T, ..., \tau h_{0}^T]^T\Vert = \sqrt{\tau ^2 \sum _{a=0}^{l-1}\Vert h_{a}\Vert ^2}\le \tau \sqrt{L}\cdot O(1), \end{aligned}$$

where the inequality is because of \(\Vert h_{a-1}\Vert \le O(1)\) from Theorem 2. Next, we have

$$\begin{aligned} \left\| \left[ {\varvec{W}}''_l, {\varvec{W}}''_{l-1}, ..., {\varvec{W}}''_{1}\right] \right\|&= \left\| \left[ {\varvec{W}}''_l, {\varvec{W}}''_{l-1}, ..., {\varvec{W}}''_{1}\right] ^T\right\| \le \sqrt{\sum _{a=1}^l \Vert ({\varvec{W}}''_a)^T\Vert ^2} \le 1.1 \sqrt{\sum _{a=1}^l \Vert ({\varvec{W}}'_a)^T\Vert ^2} \le 1.1 \Vert {\varvec{W}}'_{l:1}\Vert _F, \end{aligned}$$
(56)

where the first inequality bounds the spectral norm of a block matrix by the square root of the sum of the squared spectral norms of its blocks, the second inequality is because \(\Vert (\breve{{\varvec{D}}}_{l}+{\varvec{D}}''_{l})(\varvec{I}+\tau \breve{{\varvec{W}}}_{l})\cdots (\varvec{I}+\tau \breve{{\varvec{W}}}_{a+1})(\breve{{\varvec{D}}}_{a}+{\varvec{D}}''_{a})\Vert \le 1.1\) from Lemma 3, and the last inequality bounds the spectral norm by the Frobenius norm.

Hence we have \(\Vert h_L- \breve{h}_{L}\Vert \le O\left( (1+\tau \sqrt{L})\Vert {\varvec{W}}'\Vert _F\right) = O\left( \Vert {\varvec{W}}'\Vert _F\right)\) because of the assumption \(\tau \sqrt{L}\le 1\).

For \(l\in [L]\), \(\Vert {\varvec{D}}''_l\Vert _0\le O(m\omega ^{\frac {2}{3}})\). This is because \(({\varvec{D}}''_l)_{k,k}\) is non-zero only at coordinates k where \((\breve{g}_l)_k\) and \((g_l)_k\) have opposite signs; at such coordinates, either \(({\varvec{D}}_l^{(0)})_{k,k}\ne (\breve{{\varvec{D}}}_l)_{k,k}\) or \(({\varvec{D}}_l^{(0)})_{k,k}\ne ({\varvec{D}}_l)_{k,k}\) holds. Therefore, by Lemma 4, we have \(\Vert {\varvec{D}}''_l\Vert _0\le O(m(\omega \tau L)^{\frac {2}{3}})\) if \(\Vert {{\varvec{W}}}'_l\Vert \le \tau \omega\). \(\square\)

F Proof for Theorem 5

F.1 Convergence result for GD

Proof

Using Theorem 2, we have \(\Vert h^{(0)}_{i,L}\Vert \le 1.1\). Then, using the randomness of \({\varvec{B}}\), it is easy to show that \(\Vert {\varvec{B}}h_{i,L}^{(0)}-y_{i}^{*}\Vert ^{2}\le O(\log ^{2}m)\) with probability at least \(1-\exp (-\Omega (\log ^{2}m))\), and therefore

$$\begin{aligned} F(\overrightarrow{{\varvec{W}}}^{(0)})\le O(n\log ^{2}m). \end{aligned}$$
(57)

Assume that for every \(t=0,1,\dots ,T-1\), the following hold:

$$\begin{aligned} \Vert {\varvec{W}}_{L}^{(t)}-{\varvec{W}}_{L}^{(0)}\Vert _{F}&\le \omega {\mathop {=}\limits ^{\Delta }}O\left( \frac{\delta ^{3/2}}{n^{3}L^{7/2}}\right) \end{aligned}$$
(58)
$$\begin{aligned} \Vert {\varvec{W}}_{l}^{(t)}-{\varvec{W}}_{l}^{(0)}\Vert _{F}&\le \tau \omega . \end{aligned}$$
(59)

We shall prove the convergence of GD under the assumption that (58) and (59) hold, so that the previous statements can be applied. At the end, we shall verify that these conditions are indeed satisfied.

Letting \(\nabla _{t}=\nabla F(\overrightarrow{{\varvec{W}}}^{(t)})\), we calculate that

$$\begin{aligned} F(\overrightarrow{{\varvec{W}}}^{(t+1)})&\le F(\overrightarrow{{\varvec{W}}}^{(t)})-\eta \Vert \nabla _t\Vert _{F}^{2}+O(\eta ^{2}nm/d)\Vert \nabla _{t}\Vert _{F}^{2}+\eta \sqrt{F(\overrightarrow{{\varvec{W}}}^{(t)})}\cdot O\left( \sqrt{\frac{mnL\omega ^{\frac {2}{3}}}{d}}(\tau L)^{\frac {4}{3}}\right) \cdot \Vert \nabla _{t}\Vert _{F} \nonumber \\&\le \left( 1-\Omega \left( \frac{\eta \delta m}{dn}\right) \right) F(\overrightarrow{{\varvec{W}}}^{(t)}), \end{aligned}$$
(60)

where the first inequality uses Theorem 4, and the second inequality uses the gradient upper bound in Theorem 6, the gradient lower bound in Theorem 7, the choice of \(\eta =O(d/(mn))\) and the assumption on \(\omega\) in (58). That is, after \(T=\Omega \left( \frac{dn}{\eta \delta m}\log \frac{n\log ^{2}m}{\epsilon }\right)\) iterations, we have \(F(\overrightarrow{{\varvec{W}}}^{(T)})\le \epsilon\).

We need to verify that (58) holds for each t. Here we use a result from Lemma 4.2 in Zou and Gu (2019), which states that \(\Vert {\varvec{W}}_L^{(t)}-{\varvec{W}}_L^{(0)}\Vert _F\le O(\sqrt{\frac{n^2 d\log m }{m\delta }})\).

To guarantee that the iterates fall into the region given by \(\omega\) in (58), we obtain the requirement \(m\ge n^8\delta ^{-4}dL^7\log ^2 m\). \(\square\)
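The contraction in (60) directly determines the iteration count. The back-of-the-envelope Python sketch below (all numeric values are hypothetical placeholders, not quantities from the experiments) evaluates the per-step contraction factor and the resulting number of GD iterations; note that neither quantity depends on the depth L, which is the depth-independence of the convergence rate discussed in the main text.

```python
import math

# Illustrative only: the contraction in (60), F_{t+1} <= (1 - c) F_t with
# c = Theta(eta * delta * m / (d * n)), gives F_T <= eps once
# T >= log(F_0 / eps) / c.  All numbers below are hypothetical placeholders.
eta, delta, m, d, n = 1e-3, 0.1, 2 ** 16, 32, 100

c = eta * delta * m / (d * n)          # per-step contraction factor (no dependence on L)
F0 = n * math.log(m) ** 2              # initial loss scale F(W^(0)) = O(n log^2 m), cf. (57)
eps = 1e-3                             # target accuracy
T = math.ceil(math.log(F0 / eps) / c)  # iterations until F <= eps

print(f"contraction per step: {c:.4f}, iterations to reach eps: {T}")
```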

F.2 Convergence result for SGD

Theorem 9

For the ResNet defined and initialized as in Sect. 2, suppose the network width satisfies \(m\ge \Omega (n^{17}L^7b^{-4}\delta ^{-8}d\log ^2 m)\). Suppose we run the stochastic gradient descent update starting from \(\overrightarrow{{\varvec{W}}}^{(0)}\):

$$\begin{aligned} \overrightarrow{{\varvec{W}}}^{(t+1)} = \overrightarrow{{\varvec{W}}}^{(t)} - \eta \frac{n}{|S_t|}\sum _{i\in S_t} \nabla F_i(\overrightarrow{{\varvec{W}}}^{(t)}), \end{aligned}$$
(61)

where \(S_t\) is a random subset of [n] with \(|S_t|=b\). Then with probability at least \(1-\exp (-\Omega (\log ^{2}m))\), stochastic gradient descent (61) with learning rate \(\eta =\Theta (\frac{db\delta }{n^{3}m\log m})\) finds a point \(\overrightarrow{{\varvec{W}}}\) with \(F(\overrightarrow{{\varvec{W}}})\le \epsilon\) in \(T=\Omega (n^{5}b^{-1}\delta ^{-2}\log m \log ^2 \frac{1}{\epsilon })\) iterations.

Proof

The proof for the SGD case can be adapted from the proof of Theorem 3.8 in Zou and Gu (2019). \(\square\)
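For concreteness, the sampling and rescaling in the update (61) can be written as a short routine. The Python/NumPy sketch below is only an illustration of (61); `per_example_grad` is a hypothetical per-example gradient oracle for \(\nabla F_i\), not a function defined in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def sgd_step(W, per_example_grad, n, b, eta):
    """One update of (61): sample a random batch S_t of size b and apply
    W <- W - eta * (n / b) * sum_{i in S_t} grad F_i(W).

    `per_example_grad(W, i)` is a user-supplied (hypothetical) oracle returning
    the gradient of the i-th per-example loss at W, as an array shaped like W.
    """
    S_t = rng.choice(n, size=b, replace=False)        # random minibatch of indices
    g = sum(per_example_grad(W, i) for i in S_t)      # summed minibatch gradient
    return W - eta * (n / b) * g                      # rescaled SGD step of (61)
```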

G Proofs of Theorem 4 and Proposition 1

Proof

By induction we can show for any \(k\in [m]\) and \(l\in [L-1]\),

$$\begin{aligned} (h_{l})_{k}\ge \phi \left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) . \end{aligned}$$
(62)

It is easy to verify \((h_1)_k = \phi \left( (h_0)_k+(\tau {\varvec{W}}_1h_0)_k\right) \ge \phi \left( (\tau {\varvec{W}}_1h_0)_k\right)\) because of \((h_0)_k\ge 0\).

Now assume that \((h_l)_k \ge \phi \left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right)\); we show that the claim holds for \(l+1\):

$$\begin{aligned} (h_{l+1})_k&= \phi \left( (h_l)_k+(\tau {\varvec{W}}_{l+1}h_l)_k\right) \ge \phi \left( \phi \left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) +(\tau {\varvec{W}}_{l+1}h_l)_k\right) \ge \phi \left( \sum _{a=1}^{l+1}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) , \end{aligned}$$

where the last inequality can be verified by considering two cases: if \(\sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\ge 0\), then \(\phi\) acts as the identity on it and the two sides are equal; otherwise \(\phi\) of the partial sum is zero, so the left-hand side equals \(\phi \left( (\tau {\varvec{W}}_{l+1}h_l)_k\right)\), which dominates \(\phi \left( \sum _{a=1}^{l+1}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right)\) by monotonicity of \(\phi\).

Next we can compute the mean and variance of \(\sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\) by taking iterative conditioning. We have

$$\begin{aligned} \mathbb {E} \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}=0,\quad \mathbb {E} \left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) ^2=\frac{\tau ^2}{m}\sum _{a=1}^l\mathbb {E}\Vert h_{a-1}\Vert ^2. \end{aligned}$$
(63)

Moreover, the \((\tau {\varvec{W}}_a h_{a-1})_{k}\) are jointly Gaussian over a with mean 0 because the \({\varvec{W}}_a\)'s are drawn from independent Gaussian distributions. We use \(l=2\) as an example to illustrate the conclusion; the argument generalizes to other l. Assume that \(h_{0}\) is fixed. First, it is easy to verify that \((\tau {\varvec{W}}_1 h_{0})_{k}\) is a Gaussian variable with mean 0, and that \((\tau {\varvec{W}}_2 h_{1})_{k}\big |{\varvec{W}}_1\) is also a Gaussian variable with mean 0. Hence \([(\tau {\varvec{W}}_1 h_{0})_{k}, (\tau {\varvec{W}}_2 h_{1})_{k}]\) is jointly Gaussian with mean vector [0, 0], and thus \((\tau {\varvec{W}}_1 h_{0})_{k}+(\tau {\varvec{W}}_2 h_{1})_{k}\) is Gaussian with mean 0. By induction, \(\sum _{a=1}^l(\tau {\varvec{W}}_a h_{a-1})_{k}\) is Gaussian with mean 0. Then we have

$$\begin{aligned} \mathbb {E} \Vert h_{l}\Vert ^2&\ge \sum _{k=1}^{m}\mathbb {E}\left( \phi \left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) \right) ^{2} =\sum _{k=1}^{m}\frac{1}{2}\mathbb {E}\left( \sum _{a=1}^{l}\left( \tau {\varvec{W}}_{a}h_{a-1}\right) _{k}\right) ^{2} \nonumber \\&=\frac{1}{2} \sum _{k=1}^{m}\frac{\tau ^{2}\sum _{a=1}^{l}\mathbb {E}\left[ \Vert h_{a-1}\Vert ^{2}\right] }{m} = \frac{\tau ^{2}}{2}\sum _{a=1}^{l}\mathbb {E}\Vert h_{a-1}\Vert ^{2}, \end{aligned}$$
(64)

where the first step is due to (62), the second step is due to the symmetry of the Gaussian distribution and the third step is due to (63). Since \((h_{l})_{k} = \phi \left( (h_{l-1})_{k} + \left( \tau {\varvec{W}}_{l}h_{l-1}\right) _{k}\right)\), we can show \(\mathbb {E}(h_l)_k^2 \ge (h_{l-1})_{k}^2\) given \(h_{l-1}\) by evaluating the Gaussian integral over the relevant interval. Hence we have \(\mathbb {E}\Vert h_{l}\Vert ^{2}\ge \mathbb {E} \Vert h_{l-1}\Vert ^2\ge \cdots \ge \mathbb {E} \Vert h_{0}\Vert ^2=1\) by iteratively taking conditional expectations. Then, combined with (64) and the choice of \(\tau =L^{-\frac{1}{2}+c}\), we have \(\mathbb {E}\Vert h_{L-1}\Vert ^2 \ge \frac{1}{2}L^{2c}\). Because \(({\varvec{W}}_L)_{i,j}\sim \mathcal {N}(0,2/m)\) and \(h_L=\phi ({\varvec{W}}_L h_{L-1})\), we have \(\mathbb {E} \Vert h_L\Vert ^2 = \mathbb {E}\Vert h_{L-1}\Vert ^2\). Thus, the claim is proved. \(\square\)
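The blow-up behind this proof can also be observed numerically. The following Python/NumPy sketch (an illustration with hypothetical width and trial counts, not the paper's experiment) simulates the forward process \(h_l=\phi (h_{l-1}+\tau {\varvec{W}}_l h_{l-1})\) at the Gaussian initialization and compares \(\Vert h_{L-1}\Vert ^2\) for \(\tau =1/\sqrt{L}\) and \(\tau =L^{-1/4}\) (i.e., \(c=1/4\)).

```python
import numpy as np

def forward_norm(L, tau, m=256, trials=5, seed=0):
    """Average ||h_{L-1}||^2 for h_l = relu(h_{l-1} + tau * W_l h_{l-1}),
    with (W_l)_{ij} ~ N(0, 2/m) and a unit-norm h_0."""
    rng = np.random.default_rng(seed)
    out = []
    for _ in range(trials):
        h = rng.standard_normal(m)
        h /= np.linalg.norm(h)                          # ||h_0|| = 1
        for _ in range(L - 1):                          # L-1 residual blocks
            W = rng.standard_normal((m, m)) * np.sqrt(2.0 / m)
            h = np.maximum(h + tau * (W @ h), 0.0)
        out.append(np.linalg.norm(h) ** 2)
    return float(np.mean(out))

for L in (16, 64, 128):
    print(L,
          forward_norm(L, tau=L ** -0.5),   # tau = 1/sqrt(L): stays O(1)
          forward_norm(L, tau=L ** -0.25))  # tau = L^{-1/4}: grows with depth,
                                            # consistent with the Omega(L^{2c}) lower bound
```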

Proof

From the inequality (62) in the previous proof, we know for any \(k\in [m]\) and \(l\in [L-1]\),

$$\begin{aligned} (h_{l})_{k}\ge \phi \left( \sum _{a=1}^{l}\left( \tilde{z}_{a}\right) _{k}\right) . \end{aligned}$$
(65)

Next we can compute the mean and variance of \(\sum _{a=1}^{l}\left( \tilde{z}_a\right) _{k}\) by taking iterative conditioning. We have

$$\begin{aligned} \mathbb {E} \sum _{a=1}^{l}\left( \tilde{z}_a\right) _{k}=0,\quad \mathbb {E} \left( \sum _{a=1}^{l}\left( \tilde{z}_a\right) _{k}\right) ^2=\sum _{a=1}^l\mathbb {E}((\tilde{z}_a)_k)^2= l. \end{aligned}$$
(66)

Then we have

$$\begin{aligned} \mathbb {E} \Vert h_{l}\Vert ^2\ge \sum _{k=1}^{m}\mathbb {E}\left( \phi \left( \sum _{a=1}^{l}\left( \tilde{z}_a\right) _{k}\right) \right) ^{2} = \frac{1}{2}\sum _{k=1}^{m}\mathbb {E}\left( \sum _{a=1}^{l}\left( \tilde{z}_a\right) _{k}\right) ^2 = \frac{1}{2}ml, \end{aligned}$$
(67)

where the first step is due to (65), the second step is due to the symmetry of the random variable \((\tilde{z}_a)_k\) and the third step is due to (66). The proposition is proved. \(\square\)
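The symmetry step used in (64) and (67), namely \(\mathbb {E}[\phi (z)^2]=\frac{1}{2}\mathbb {E}[z^2]\) for a symmetric random variable z, can be checked with a quick Monte-Carlo estimate (an illustrative Python/NumPy sketch, not part of the proof):

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(2_000_000) * 3.0    # a symmetric (here Gaussian) random sample
lhs = np.mean(np.maximum(z, 0.0) ** 2)      # Monte-Carlo estimate of E[phi(z)^2]
rhs = 0.5 * np.mean(z ** 2)                 # (1/2) E[z^2]
print(lhs, rhs)                             # the two agree up to Monte-Carlo error
```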

H More empirical studies

Fig. 6

Validation accuracy on CIFAR10 of ResNets with different choices of \(\tau\) (\(\tau =1/L\), \(\tau =1/\sqrt{L}\), \(\tau =1/L^{1/4}\))

Table 3 Validation accuracy of ResNet110+\(\tau\) with different learning rates

We do more experiments to demonstrate the points in Sect. 5.

Besides the basic feedforward structure in Sect. 5.1, we run another experiment to demonstrate that \(\tau =1/\sqrt{L}\) is sharp for practical architectures (see Fig. 6). We can see that for ResNet110 and ResNet1202, \(\tau =1/L^{1/4}\) cannot train the network effectively.

One may wonder whether tuning the learning rate for the case \(\tau =1/L\) can match the validation accuracy achieved with \(\tau =1/\sqrt{L}\). We run an additional experiment to examine this (see Table 3). Specifically, for ResNet110 with fixed \(\tau =1/L\) and \(\tau =1/\sqrt{L}\) on the CIFAR10 classification task, we tune the learning rate from 0.1 to 1.6 and record the validation accuracy in Table 3. We can see that ResNet110 with \(\tau =1/L\) performs worse than with \(\tau =1/\sqrt{L}\) even under a grid search over learning rates. It is possible to obtain slightly better performance by further adjusting the learning rate for \(\tau =1/L\), but this requires tuning for each depth. In contrast, we have shown that with \(\tau =1/\sqrt{L}\), one learning rate works for all depths.
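For reference, the sketch below shows one plausible way to attach the factor \(\tau =1/\sqrt{L}\) to the output of the parametric branch of a basic residual block in Python/PyTorch. The class and argument names are hypothetical, and this is only an illustration of the \(+\tau\) modification rather than the exact code used for our experiments; setting `use_bn=False` corresponds to training without normalization layers.

```python
import torch.nn as nn

class TauBasicBlock(nn.Module):
    """Basic residual block with the parametric branch scaled by tau = 1/sqrt(L).

    `num_blocks` is the total number of residual blocks L in the network.
    With use_bn=False the block has no normalization layer; with use_bn=True
    the tau factor is simply added on top of the standard BN block.
    """

    def __init__(self, channels, num_blocks, use_bn=True):
        super().__init__()
        self.tau = num_blocks ** -0.5
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=not use_bn)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=not use_bn)
        self.bn1 = nn.BatchNorm2d(channels) if use_bn else nn.Identity()
        self.bn2 = nn.BatchNorm2d(channels) if use_bn else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(x + self.tau * out)   # scale only the parametric branch
```

The sketch omits downsampling blocks and stage transitions; one would stack such blocks as in a standard ResNet while keeping the identity branch untouched.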

Cite this article

Zhang, H., Yu, D., Yi, M. et al. Stabilize deep ResNet with a sharp scaling factor \(\tau\). Mach Learn 111, 3359–3392 (2022). https://doi.org/10.1007/s10994-022-06192-x
