1 Introduction

Deep neural networks (DNNs) have been the main driving force behind the recent wave of artificial intelligence (AI). They have achieved remarkable success in a number of domains including computer vision [14, 19], reinforcement learning [18, 23] and natural language processing [4], to name a few. However, due to the huge number of model parameters, the deployment of DNNs can be computationally and memory intensive. As such, it remains a great challenge to deploy DNNs on mobile devices with limited computational budgets and memory storage.

Recent efforts have been devoted to quantizing the weights and activations of DNNs while maintaining accuracy. More specifically, quantization techniques constrain the weights and/or activations to low-precision values (e.g., 4-bit) instead of the conventional 32-bit floating-point representation [2, 12, 17, 31, 32, 33]. In this way, inference of a quantized DNN reduces to hardware-friendly low-bit computations rather than floating-point operations. Consequently, quantization brings three critical benefits for AI systems: energy efficiency, memory savings and inference acceleration.

The approximation power of weight-quantized DNNs was investigated in [6, 8], while the recent paper [22] studies the approximation power of DNNs with discretized activations. On the computational side, training quantized DNNs typically calls for solving a large-scale optimization problem with additional computational and mathematical challenges. Although the weights and activations of DNNs are often quantized together, the two can be viewed as relatively independent subproblems. Weight quantization essentially introduces an additional set constraint that characterizes the quantized model parameters, which can be handled efficiently by projected gradient-type methods [5, 10, 15, 16, 28, 30]. Activation quantization (i.e., quantizing ReLU), on the other hand, replaces the subdifferentiable ReLU with a staircase activation function whose derivative is zero almost everywhere (a.e.). Therefore, the resulting composite loss function is piecewise constant and cannot be minimized via (stochastic) gradient methods, because its gradient vanishes a.e.

To overcome this issue, a simple and hardware-friendly approach is to use a straight-through estimator (STE) [1, 9, 26]. More precisely, one replaces the a.e. zero derivative of the quantized ReLU with an ad hoc surrogate in the backward pass, while keeping the original quantized function in the forward pass. Mathematically, STE gives rise to a biased first-order oracle computed by an unusual chain rule. This first-order oracle is not the gradient of the original loss function, because the forward and backward passes are mismatched. Throughout this paper, this STE-induced “gradient” is called the coarse gradient. Although the coarse gradient is not the true gradient, it works in practice, as it surprisingly points toward a descent direction (see [26] for a thorough study in the regression setting). Moreover, computing the coarse gradient has the same complexity as computing the standard gradient. Just as in standard gradient descent, training an activation-quantized network proceeds by repeatedly stepping from the current point, with some step size, in the direction opposite to the coarse gradient. The performance of the resulting coarse gradient method, e.g., its convergence, naturally depends on the choice of STE. How to choose a proper STE so that the resulting training algorithm is provably convergent is still poorly understood, especially in the nonlinear classification setting.

1.1 Related works

The idea of STE dates back to the classical perceptron algorithm [20, 21] for binary classification. Specifically, the perceptron algorithm attempts to solve the empirical risk minimization problem:

$$\begin{aligned} \min _{{{\varvec{w}}}} \; \sum _{i=1}^N (\text{ sign }({{\varvec{x}}}_i^{\top }{{\varvec{w}}}) - y_i)^2, \end{aligned}$$
(1.1)

where \(({{\varvec{x}}}_i, y_i)\) is the \(i^{\mathrm {th}}\) training sample with \(y_i\in \{\pm 1\}\) being a binary label; for a given input \({{\varvec{x}}}_i\), the single-layer perceptron model with weights \({{\varvec{w}}}\) outputs the class prediction \(\text{ sign }({{\varvec{x}}}_i^{\top }{{\varvec{w}}})\). To train perceptrons, Rosenblatt [20] proposed the following iteration for solving (1.1) with the step size \(\eta >0\):

$$\begin{aligned} {{\varvec{w}}}^{t+1} = {{\varvec{w}}}^{t} - \eta \sum _{i=1}^N (\text{ sign }({{\varvec{x}}}_i^{\top }{{\varvec{w}}}^t) - y_i)\cdot {{\varvec{x}}}_i \end{aligned}$$
(1.2)

Note that the perceptron algorithm above is not the same as gradient descent. Assuming differentiability, the standard chain rule computes the gradient of the \(i^{\mathrm {th}}\) sample loss function as (up to a factor of 2, which can be absorbed into \(\eta \))

$$\begin{aligned} (\text{ sign }({{\varvec{x}}}_i^{\top }{{\varvec{w}}}^t) - y_i)\cdot (\text{ sign})^\prime ({{\varvec{x}}}_i^{\top }{{\varvec{w}}}^t)\cdot {{\varvec{x}}}_i. \end{aligned}$$
(1.3)

Comparing (1.3) with (1.2), we observe that the perceptron algorithm essentially uses a coarse (and fake) gradient, as if the factor \((\text{ sign})^\prime \) in the chain rule were the derivative of the identity function, namely the constant 1.
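To make the comparison concrete, here is a minimal plain-Python sketch of the perceptron iteration (1.2). The toy data, the step size, the stopping rule and the convention \(\text{sign}(0)=+1\) are our own illustrative assumptions. Note that the update multiplies each residual directly by \({{\varvec{x}}}_i\), i.e., the factor \((\text{ sign})^\prime \) in (1.3) is replaced by the constant 1.

```python
def sign(z):
    # Convention sign(0) = +1 (an assumption; (1.1) leaves sign(0) unspecified).
    return 1.0 if z >= 0 else -1.0


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def perceptron(X, Y, eta=0.1, max_iter=1000):
    """Batch perceptron iteration (1.2):
    w <- w - eta * sum_i (sign(x_i^T w) - y_i) * x_i,
    i.e. the chain rule with (sign)' replaced by the identity STE (constant 1)."""
    d = len(X[0])
    w = [0.0] * d
    for t in range(max_iter):
        residuals = [sign(dot(x, w)) - y for x, y in zip(X, Y)]
        if all(r == 0 for r in residuals):
            return w, t  # every training sample is classified correctly
        for c in range(d):
            w[c] -= eta * sum(r * x[c] for r, x in zip(residuals, X))
    return w, max_iter


# Toy linearly separable data (no bias term, so the separator passes through 0).
X = [[1.0, 2.0], [2.0, 1.0], [-1.0, -2.0], [-2.0, -1.0]]
Y = [1.0, 1.0, -1.0, -1.0]
w, steps = perceptron(X, Y)
```

On this toy set a single update already separates the data.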

The idea of STE was extended to train deep networks with binary activations [9]. Successful experimental results have demonstrated the effectiveness of the empirical STE approach. For example, [1] proposed an STE variant that uses the derivative of the sigmoid function instead of the identity function. [11] used the derivative of the hard tanh function, i.e., \(1_{\{|x|\le 1\}}\), as an STE in training binarized neural networks. To reduce accuracy degradation, STE was later employed to train DNNs with quantized activations at higher bit-widths [2, 3, 12, 29, 32], where other STEs were proposed, including the derivatives of the standard ReLU (\(\max \{x, 0\}\)) and the clipped ReLU (\(\min \{\max \{x, 0\}, 1\}\)).
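The surrogate derivatives listed above can be written down directly. In the sketch below we assume the standard logistic sigmoid for [1] (the cited work may use a scaled variant), and the values at the kinks (\(x=0\), \(|x|=1\)) are conventions of our own; `binary_act` and `backward_through_activation` are hypothetical helper names that only show where the surrogate enters the backward pass.

```python
import math


def ste_identity(x):
    # Identity STE (perceptron): derivative of f(x) = x.
    return 1.0


def ste_sigmoid(x):
    # Derivative of the logistic sigmoid s(x) = 1 / (1 + e^{-x}), i.e. s(x) * (1 - s(x)).
    if x >= 0:
        s = 1.0 / (1.0 + math.exp(-x))
    else:  # numerically stable branch for large negative x
        e = math.exp(x)
        s = e / (1.0 + e)
    return s * (1.0 - s)


def ste_hard_tanh(x):
    # Derivative of hard tanh: the indicator 1_{|x| <= 1}.
    return 1.0 if abs(x) <= 1.0 else 0.0


def ste_relu(x):
    # Derivative of max{x, 0}; the value at x = 0 is a convention (0 here).
    return 1.0 if x > 0 else 0.0


def ste_clipped_relu(x):
    # Derivative of min{max{x, 0}, 1}; the values at x = 0 and x = 1 are conventions.
    return 1.0 if 0 < x < 1 else 0.0


def binary_act(x):
    # 1-bit quantized ReLU used in the forward pass; its true derivative is 0 a.e.
    return 1.0 if x > 0 else 0.0


def backward_through_activation(x, upstream, ste):
    """Chain-rule step through binary_act, with its a.e.-zero derivative
    replaced by the surrogate ste(x)."""
    return upstream * ste(x)


# Surrogate derivatives evaluated at a pre-activation of 0.5.
surrogates = [ste_identity, ste_sigmoid, ste_hard_tanh, ste_relu, ste_clipped_relu]
grads = [backward_through_activation(0.5, 1.0, ste) for ste in surrogates]
```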

Regarding theoretical justification, it has been established that the perceptron algorithm (1.2) with the identity STE converges and perfectly classifies linearly separable data; see, for example, [7, 25] and references therein. Apart from that, to our knowledge, there was almost no theoretical justification of STE until recently: [26] considered a two-linear-layer network with binary activation for regression problems. There, the training data are instead assumed to be linearly non-separable, generated by an underlying model with true parameters. In this setting, [26] proved that the working STE is non-unique and that the coarse gradient algorithm is a descent method and converges to a valid critical point if the STE is chosen as the proxy derivative of either the ReLU (i.e., \(\max \{x, 0\}\)) or the clipped ReLU function (i.e., \(\min \{\max \{x, 0\}, 1\}\)). Moreover, they proved that the identity STE fails to give a convergent algorithm for learning two-layer networks, although it works for the single-layer perceptron.

Fig. 1 Quantized activation functions. \(\tau \) is a value determined in the network training; see Sect. 8.2

1.2 Main contributions

Figure 1 shows examples of 1-bit (binary) and 2-bit (ternary) activations. A quantized activation function zeros out any negative input, while being non-decreasing on the positive half-line. Intuitively, a working surrogate of the quantized function used in the backward pass should also enjoy this monotonicity, as conjectured by [26], which proved the effectiveness of the coarse gradient for binarized activations with two specific STEs: the derivatives of ReLU and clipped ReLU. In this work, we take a further step toward understanding the convergence of coarse gradient methods for training networks with general quantized activations and for classification of linearly non-separable data. A major analytical challenge here is that, in sharp contrast to [26], the network loss function does not admit a closed analytical form. We present more general results that provide meaningful guidance on how to choose the STE in activation quantization. Specifically, we study multi-category classification of linearly non-separable data by a two-linear-layer network with multi-bit activations and the hinge loss function. We establish the convergence of coarse gradient methods for a broad class of surrogate functions. More precisely, if a function \(g:{\mathbb {R}}\rightarrow {\mathbb {R}}\) satisfies the following properties:

  • \(g(x) = 0\) for all \(x\le 0\),

  • \(g'(x) \ge \delta >0\) for all \(x>0\) with some constant \(\delta \),

then, with a proper learning rate, the corresponding coarse gradient method converges and perfectly classifies the linearly non-separable data when \(g^\prime \) serves as the STE during the backward pass. This affirms a conjecture in [26] regarding good choices of STE, for a classification (rather than regression) task and under weaker data assumptions, e.g., allowing non-Gaussian distributions.

1.3 Notations

Table 1 lists the notation used in this paper.

Table 1 Frequently used notations

2 Problem setup

2.1 Data assumptions

In this section, we consider the n-ary classification problem in the d-dimensional space \({\mathcal {X}}={\mathbb {R}}^{d}\). Let \({\mathcal {Y}}=[n]\) be the set of labels, and for \(i\in [n]\) let \({\mathcal {D}}_i\) be probability distributions over \({\mathcal {X}}\times {\mathcal {Y}}\). Throughout this paper, we make the following assumptions on the data:

  1. (Separability) There are n mutually orthogonal subspaces \({\mathcal {V}}_i \subseteq {\mathcal {X}}\), \(i\in [n]\), with \(\dim {\mathcal {V}}_i=d_i\), such that

    $$\begin{aligned} \mathop {{\mathbb {P}}}_{\{{\varvec{x}},y\}\sim {\mathcal {D}}_i}\left[ {\varvec{x}}\in {\mathcal {V}}_i\text { and }y=i\right] =1, \; \text{ for } \text{ all } i \in [n]. \end{aligned}$$

  2. (Boundedness of data) There exist positive constants m and M such that

    $$\begin{aligned} \mathop {{\mathbb {P}}}_{\{{\varvec{x}},y\}\sim {\mathcal {D}}_i}\left[ m<\left| {\varvec{x}}\right| <M\right] =1, \; \text{ for } \text{ all } i \in [n]. \end{aligned}$$

  3. (Boundedness of p.d.f.) For \(i\in [n]\), let \(p_i\) be the marginal probability density function of \({\mathcal {D}}_i\) on \({\mathcal {V}}_i\). For any \({{\varvec{x}}}\in {\mathcal {V}}_i\) with \(m<\left| {\varvec{x}}\right| <M\), it holds that

    $$\begin{aligned} 0< p_i( {{\varvec{x}}})<p_{\text {max}}<\infty . \end{aligned}$$

In what follows, \({\mathcal {D}}\) denotes the evenly mixed distribution of the \({\mathcal {D}}_i\), \(i\in [n]\), i.e., each \({\mathcal {D}}_i\) has mixture weight 1/n.
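One simple way to generate data satisfying assumptions (1)–(3) is sketched below, under construction choices of our own: each \({\mathcal {V}}_i\) is a block of coordinate axes (so the subspaces are orthogonal), leftover axes span the complement, the direction is uniform on the unit sphere of \({\mathcal {V}}_i\), and the radius is uniform on (m, M). The resulting density behaves like \(r^{1-d_i}\) in the radius r, so it is positive and bounded on the annulus because \(r>m>0\). Labels are 0-based in the code, whereas the text uses \([n]=\{1,\ldots ,n\}\); `make_blocks` and `sample_xy` are hypothetical names.

```python
import math
import random


def make_blocks(dims):
    """Assumed construction: V_i is spanned by the coordinate axes in blocks[i]."""
    blocks, start = [], 0
    for di in dims:
        blocks.append(list(range(start, start + di)))
        start += di
    return blocks


def sample_xy(blocks, d, m, M, rng):
    """Draw one (x, y) from the evenly mixed distribution D (labels 0..n-1)."""
    y = rng.randrange(len(blocks))           # each class with probability 1/n
    idx = blocks[y]
    while True:                               # uniform direction on the unit sphere of V_y
        g = [rng.gauss(0.0, 1.0) for _ in idx]
        norm = math.sqrt(sum(v * v for v in g))
        if norm > 0.0:
            break
    r = rng.uniform(m, M)                     # radius in the annulus m < |x| < M
    x = [0.0] * d
    for c, v in zip(idx, g):
        x[c] = r * v / norm
    return x, y


rng = random.Random(0)
blocks = make_blocks([2, 3])                  # n = 2, d_1 = 2, d_2 = 3
d = 6                                         # one extra axis spans the complement of V
data = [sample_xy(blocks, d, 1.0, 2.0, rng) for _ in range(1000)]
```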

Remark 1

The orthogonality of the subspaces \({\mathcal {V}}_i\) in data assumption (1) above is technically needed for our proof. However, convergence to a perfect classification with random initialization, as in Theorem 3.1, is observed in more general settings where the \({\mathcal {V}}_i\) form acute angles and the data contain a certain level of noise. We refer to Sect. 8.1 for supporting experimental results.

Remark 2

Assumption (3) can be relaxed to the following, while the proof remains basically the same.

\({\mathcal {D}}_i\) is a mixture of \(n_i\) distributions \({\mathcal {D}}_{i,j}\), \(j\in [n_i]\). There exists a direct-sum decomposition \({\mathcal {V}}_i=\bigoplus _{j=1}^{n_i}{\mathcal {V}}_{i,j}\) such that each \({\mathcal {D}}_{i,j}\) has a marginal probability density function \(p_{i,j}\) on \({\mathcal {V}}_{i,j}\). For any \({{\varvec{x}}}\in {\mathcal {V}}_{i,j}\) with \(m<|{{\varvec{x}}}|<M\), it holds that

$$\begin{aligned} 0<p_{i,j}({{\varvec{x}}})\le p_{\text {max}}<\infty . \end{aligned}$$

2.2 Network architecture

We consider a two-layer neural architecture with k hidden neurons. Denote by \({{\varvec{W}}}=\left[ {\varvec{w}}_1,\ldots ,{\varvec{w}}_{k}\right] \in {\mathbb {R}}^{d\times k}\) the weight matrix in the hidden layer. Let

$$\begin{aligned} h_j=\left\langle {\varvec{w}}_j,{\varvec{x}}\right\rangle \end{aligned}$$

be the input to the j-th activation function, the so-called pre-activation. Throughout this paper, we make the following assumption:

Assumption 1

The weight matrix in the second layer \({\varvec{V}}=[{\varvec{v}}_1,\ldots ,{\varvec{v}}_n]\) is fixed and known in the training process and satisfies:

  1. For any \(i\in [n]\), there exists some \(j\in [k]\) such that \(v_{i,j}>0\).

  2. If \(v_{i,j}>0\), then for any \(r\in [n]\) and \(r\not =i\), we have \(v_{r,j}=0\).

  3. For any \(i\in [n]\) and \(j\in [k]\), we have \(v_{i,j}<1\).

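As long as \(k\ge n\), such a matrix \({\varvec{V}}=(v_{i,j})\) is easy to construct: for instance, assign each class \(i\in [n]\) its own hidden neuron j with \(0<v_{i,j}<1\), and set all other entries to 0.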

For any input data \({\varvec{x}}\in {\mathcal {X}}={\mathbb {R}}^{d}\), the neural net output is

$$\begin{aligned} f({{\varvec{W}}};{{\varvec{x}}})=[o_1,\ldots ,o_n], \end{aligned}$$
(2.1)

where

$$\begin{aligned} o_i = \left\langle {\varvec{v}}_i,\sigma \left( {\varvec{h}}\right) \right\rangle =\sum _{j=1}^kv_{i,j}\sigma (h_j). \end{aligned}$$

Here \(\sigma (\cdot )\) is the quantized ReLU function acting element-wise; see Fig. 1 for examples of binary and ternary activation functions. More generally, the quantized ReLU function of bit-width b can be defined as follows:

$$\begin{aligned} \sigma (x)={\left\{ \begin{array}{ll} 0 &{} \text {if} \quad x\le 0, \\ \text {ceil}(x) &{} \text {if} \quad 0<x<2^b-1, \\ 2^b -1 &{} \text {if} \quad x\ge 2^b-1. \\ \end{array}\right. } \end{aligned}$$
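A direct transcription of this definition is sketched below. It is a minimal sketch that follows the unit-step formula above and does not model the trained value \(\tau \) mentioned in the caption of Fig. 1.

```python
import math


def quantized_relu(x, b):
    """b-bit quantized ReLU: 0 for x <= 0, ceil(x) for 0 < x < 2^b - 1,
    and the saturation level 2^b - 1 for x >= 2^b - 1."""
    q = 2 ** b - 1                 # maximum quantization level q_b
    if x <= 0:
        return 0
    if x < q:
        return math.ceil(x)
    return q


levels = [quantized_relu(v, 2) for v in (-1.0, 0.0, 0.2, 1.0, 1.5, 2.9, 3.0, 7.5)]
```

With b = 1 the function takes the values {0, 1} (binary activation); for general b it takes the \(2^b\) levels \(0,1,\ldots ,2^b-1\).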

The prediction is given by the network output label

$$\begin{aligned} {\hat{y}}({{\varvec{W}}},{{\varvec{x}}})=\mathop {\text {argmax}}_{r\in [n]}o_r, \end{aligned}$$

Ideally, \({\hat{y}}({{\varvec{W}}},{{\varvec{x}}})=i\) for all \({{\varvec{x}}}\in {\mathcal {V}}_i\). The classification accuracy (in percent) is the frequency with which the predicted label \({\hat{y}}\) matches the true label on a validation data set.

Given the data sample \(\{{{\varvec{x}}}, y\}\), the associated hinge loss function reads

$$\begin{aligned} l({{\varvec{W}}}; \{ {{\varvec{x}}}, y \}) := \max \left\{ 0, 1 - f_y\right\} :=\max \left\{ 0, 1 - \left( o_y-\max _{i\not =y}o_i\right) \right\} . \end{aligned}$$
(2.2)
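The forward pass (2.1), the predicted label and the hinge loss (2.2) can be sketched in plain Python as follows. The storage conventions are our own assumptions: `W` is a list of the k hidden weight vectors \({\varvec{w}}_j\), `V[i][j]` holds \(v_{i,j}\), classes are 0-based, and ties in the argmax go to the smallest index.

```python
import math


def quantized_relu(x, b):
    """b-bit quantized ReLU sigma from the definition above."""
    q = 2 ** b - 1
    if x <= 0:
        return 0
    return math.ceil(x) if x < q else q


def forward(W, V, x, b):
    """Network output (2.1): o_i = sum_j v_{i,j} * sigma(<w_j, x>)."""
    s = [quantized_relu(sum(wc * xc for wc, xc in zip(w, x)), b) for w in W]
    return [sum(vij * sj for vij, sj in zip(vi, s)) for vi in V]


def predict(W, V, x, b):
    """Predicted label: argmax_r o_r (ties go to the smallest index)."""
    o = forward(W, V, x, b)
    return max(range(len(o)), key=lambda r: o[r])


def hinge_loss(W, V, x, y, b):
    """Sample hinge loss (2.2): max{0, 1 - (o_y - max_{i != y} o_i)}."""
    o = forward(W, V, x, b)
    margin = o[y] - max(o[i] for i in range(len(o)) if i != y)
    return max(0.0, 1.0 - margin)


# Toy network with n = 2 classes, k = 2 hidden neurons, d = 2, 2-bit activations.
W = [[1.0, 0.0], [0.0, 1.0]]      # w_1, w_2
V = [[0.9, 0.0], [0.0, 0.9]]      # V[i][j] = v_{i,j}
loss_far = hinge_loss(W, V, [2.0, 0.0], 0, 2)
loss_near = hinge_loss(W, V, [0.5, 0.0], 0, 2)
```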

To train the network with quantized activation \(\sigma \), we consider the following population loss minimization problem

$$\begin{aligned} \min _{{{\varvec{W}}}\in {\mathbb {R}}^{d\times k}}\; l\left( {{\varvec{W}}}\right) := \mathop {{\mathbb {E}}}_{\{ {{\varvec{x}}}, y \}\sim {\mathcal {D}}}\left[ l\left( {{\varvec{W}}}; \{{{\varvec{x}}}, y\}\right) \right] , \end{aligned}$$
(2.3)

where the sample loss \(l\left( {{\varvec{W}}}; \{{{\varvec{x}}}, y\}\right) \) is defined in (2.2). Let \(l_i\) be the population loss function of class i with the label \(y=i\), \(i \in [n]\). More precisely,

$$\begin{aligned} \begin{aligned} l_i({{\varvec{W}}})=&\mathop {{\mathbb {E}}}_{\{{\varvec{x}},y\}\sim {\mathcal {D}}_i}\left[ \max \left\{ 0, 1 - f_i\right\} \right] \\ =&\mathop {{\mathbb {E}}}_{\{{\varvec{x}},y\}\sim {\mathcal {D}}_i}\left[ \max \left\{ 0, 1 - \left( o_i - \max _{r\not =i}o_r\right) \right\} \right] . \end{aligned} \end{aligned}$$

Thus, we can rewrite the loss function as

$$\begin{aligned} l({{\varvec{W}}})=\frac{1}{n}\sum _{i=1}^nl_i({{\varvec{W}}}). \end{aligned}$$

Note that the population loss

$$\begin{aligned} l_i({{\varvec{W}}})=\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left[ l({{\varvec{W}}};\{{{\varvec{x}}},y\})\right] \end{aligned}$$

does not have a simple closed form, even if the \(p_i\) are constant on their supports. Without a closed-form formula at hand to analyze the learning process, our analysis becomes more challenging.

For notational convenience, we define:

$$\begin{aligned} \Omega _{{{\varvec{W}}}}&=\left\{ {{\varvec{x}}}\in {\mathcal {X}}: l({{\varvec{W}}};\{{{\varvec{x}}},y\})>0\right\} ,\\ \Omega _{{\varvec{v}}}^a&=\left\{ {{\varvec{x}}}\in {\mathcal {X}}:\left\langle {\varvec{v}},{\varvec{x}}\right\rangle >a\right\} , \end{aligned}$$

and

$$\begin{aligned} \Omega _{{{\varvec{W}}}}^j=\Omega _{{{\varvec{W}}}}\cap \Omega _{{\varvec{w}}_j}^0. \end{aligned}$$

2.3 Coarse gradient methods

The derivative of the quantized ReLU function \(\sigma \) is zero a.e., which gives a trivial gradient of the sample loss function with respect to (w.r.t.) \({\varvec{w}}_j\). Indeed, differentiating the sample loss function w.r.t. \({\varvec{w}}_j\), we have

$$\begin{aligned} \nabla _{{\varvec{w}}_j} l({{\varvec{W}}};\{{{\varvec{x}}},y\})=-\left( v_{y,j}-v_{\xi ,j}\right) \,\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\,\sigma '\left( h_j\right) {{\varvec{x}}}= {\mathbf {0}}, \text{ a.e. }, \quad 1\le j\le k \end{aligned}$$

where \(\xi =\mathop {\text {argmax}}_{i\not =y}o_i\).

The partial coarse gradient w.r.t. \({{\varvec{w}}}_j\) associated with the sample \(\{{{\varvec{x}}}, y\}\) is obtained by replacing \(\sigma '\) with a straight-through estimator (STE), namely the derivative of a function g:

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j}l({{\varvec{W}}};\{{{\varvec{x}}},y\}) := -\left( v_{y,j}-v_{\xi ,j}\right) \,\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\,g'(h_j){{\varvec{x}}}. \end{aligned}$$
(2.4)
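A sketch of the sample coarse gradient (2.4) follows, with the same plain-Python storage conventions as a forward pass (W as a list of k vectors \({\varvec{w}}_j\), `V[i][j]` \(=v_{i,j}\), 0-based classes). We assume g = ReLU, so \(g'(h)=1\) for h > 0 and 0 otherwise (the value at 0 is a convention), and we break ties in \(\xi \) toward the smallest index. The forward pass still uses \(\sigma \); only the backward pass swaps \(\sigma '\) for \(g'\).

```python
import math


def quantized_relu(x, b):
    q = 2 ** b - 1
    if x <= 0:
        return 0
    return math.ceil(x) if x < q else q


def relu_ste(z):
    # g'(z) for g(z) = max(z, 0); g'(0) := 0 is a convention.
    return 1.0 if z > 0 else 0.0


def coarse_grad_sample(W, V, x, y, b, g_prime=relu_ste):
    """Sample coarse gradient (2.4), one vector per hidden neuron j:
    -(v_{y,j} - v_{xi,j}) * 1_{Omega_W}(x) * g'(h_j) * x,  xi = argmax_{i != y} o_i."""
    k, n = len(W), len(V)
    h = [sum(wc * xc for wc, xc in zip(w, x)) for w in W]      # pre-activations h_j
    s = [quantized_relu(hj, b) for hj in h]                     # forward pass uses sigma
    o = [sum(V[i][j] * s[j] for j in range(k)) for i in range(n)]
    xi = max((i for i in range(n) if i != y), key=lambda i: o[i])
    in_omega = 1.0 - (o[y] - o[xi]) > 0.0                       # hinge loss is positive
    grads = []
    for j in range(k):
        # backward pass: sigma'(h_j) (zero a.e.) is replaced by g'(h_j)
        coef = -(V[y][j] - V[xi][j]) * g_prime(h[j]) if in_omega else 0.0
        grads.append([coef * xc for xc in x])
    return grads


W = [[0.5, 0.0], [0.0, 0.5]]
V = [[0.9, 0.0], [0.0, 0.9]]
G = coarse_grad_sample(W, V, [1.0, 0.0], 0, 2)
```

Here the first neuron receives \(-0.9\,{{\varvec{x}}}\), so a descent step moves \({\varvec{w}}_1\) to \({\varvec{w}}_1+0.9\,\eta \,{{\varvec{x}}}\), i.e., toward the sample; once the margin reaches 1, the sample leaves \(\Omega _{{{\varvec{W}}}}\) and contributes nothing.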

The sample coarse gradient \({\tilde{\nabla }}l({{\varvec{W}}};\{{{\varvec{x}}},y\})\) is the concatenation of the \({\tilde{\nabla }}_{{\varvec{w}}_j}l({{\varvec{W}}};\{{{\varvec{x}}},y\})\), \(j\in [k]\). Note that the coarse gradient is not an actual gradient, but a biased first-order oracle that depends on the choice of g.

Throughout this paper, we consider a class of surrogate functions during the backward pass with the following properties:

Assumption 2

\(g:{\mathbb {R}}\rightarrow {\mathbb {R}}\) satisfies

  1. \(g(x) = 0\) for all \(x\le 0\).

  2. \(g'(x)\in [\delta ,{{\tilde{\delta }}}]\) for all \(x>0\) with some constants \(0<\delta<{{\tilde{\delta }}}<\infty \).

Such functions g are common in training quantized deep networks; see Fig. 2 for examples of g(x) satisfying Assumption 2. Typical examples include the classical ReLU \(g(x) = \max (x, 0)\) and the log-tailed ReLU [2]:

$$\begin{aligned} g(x)=\left\{ \begin{array}{lll} 0&{}\quad \text {if} &{} x\le 0 ,\\ x&{}\quad \text {if} &{} 0<x\le q_b ,\\ q_b+\log (x-q_b+1)&{}\quad \text {if} &{} x>q_b ,\\ \end{array}\right. \end{aligned}$$

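where \(q_b := 2^b-1\) is the maximum quantization level. In addition, if the input of the activation function is bounded by a constant, one can also use \(g(x)=\max \{0,q_b (1- e^{-x/q_b})\}\), which we call the reverse exponential STE. The same boundedness is needed for the log-tailed ReLU, since its derivative \(1/(x-q_b+1)\) decays to zero as \(x\rightarrow \infty \), so the lower bound \(\delta \) in Assumption 2 only holds on a bounded range of inputs.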
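The three surrogates and their derivatives can be coded directly. The sketch below also estimates the constants \(\delta \) and \({\tilde{\delta }}\) of Assumption 2 as the minimum and maximum of \(g'\) over a grid of a bounded interval (0, H] of pre-activations; the grid-based `derivative_range` helper and the choice of H are illustrative assumptions, not part of the method.

```python
import math


def relu_prime(x, b=None):
    # g(x) = max(x, 0): g'(x) = 1 for x > 0 (b is unused).
    return 1.0 if x > 0 else 0.0


def log_tailed_relu(x, b):
    q = 2 ** b - 1                      # q_b
    if x <= 0:
        return 0.0
    if x <= q:
        return x
    return q + math.log(x - q + 1)


def log_tailed_relu_prime(x, b):
    q = 2 ** b - 1
    if x <= 0:
        return 0.0
    if x <= q:
        return 1.0
    return 1.0 / (x - q + 1)            # decays like 1/x for large x


def reverse_exp(x, b):
    q = 2 ** b - 1
    return max(0.0, q * (1.0 - math.exp(-x / q)))


def reverse_exp_prime(x, b):
    q = 2 ** b - 1
    return math.exp(-x / q) if x > 0 else 0.0   # in (0, 1], decays exponentially


def derivative_range(g_prime, H, b, num=10000):
    """Grid estimate of [delta, delta_tilde] = [min g', max g'] on (0, H]."""
    vals = [g_prime(H * t / num, b) for t in range(1, num + 1)]
    return min(vals), max(vals)


# Pre-activations are bounded by H, e.g. H = R * M when |w_j| <= R and |x| < M.
bounds = {
    name: derivative_range(fn, 10.0, 2)
    for name, fn in [("relu", relu_prime),
                     ("log-tailed", log_tailed_relu_prime),
                     ("reverse-exp", reverse_exp_prime)]
}
```

For the two decaying surrogates the estimated \(\delta \) shrinks as H grows, which is why a bound on the pre-activations matters for them.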

Fig. 2 Different choices of g(x) for the straight-through estimator

To train the network with quantized activation \(\sigma \), we use the expected coarse gradient over the data distribution:

$$\begin{aligned} {\tilde{\nabla }} l({{\varvec{W}}}): = \mathop {{\mathbb {E}}}_{\{ {{\varvec{x}}}, y \}\sim {\mathcal {D}}} {\tilde{\nabla }}l({{\varvec{W}}};\{{{\varvec{x}}},y\}), \end{aligned}$$

where \({\tilde{\nabla }}l({{\varvec{W}}};\{{{\varvec{x}}},y\})\) is given by (2.4). In this paper, we study the convergence of the coarse gradient algorithm for solving the minimization problem (2.3), which takes the following iteration with some learning rate \(\eta >0\):

$$\begin{aligned} {{\varvec{W}}}^{t+1}= {{\varvec{W}}}^{t}-\eta \,{\tilde{\nabla }} l({{\varvec{W}}}^{t}). \end{aligned}$$
(2.5)
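Below is a minimal sketch of iteration (2.5) under assumptions of our own: the expectation over \({\mathcal {D}}\) is replaced by an average over a fixed finite sample, the \({\mathcal {V}}_i\) are coordinate blocks, g = ReLU, and the sizes, seed and step size are arbitrary. Besides producing the iterates, it lets one check two properties established in Sects. 4 and 5 for their sample-average analogues: the component of each \({\varvec{w}}_j\) in \({\mathcal {V}}_{n+1}\) never moves, and \(|{\varvec{w}}_{j,i}|\) is non-decreasing whenever \(v_{i,j}>0\).

```python
import math
import random


def quantized_relu(x, b):
    q = 2 ** b - 1                        # q_b = 2^b - 1
    if x <= 0:
        return 0
    return math.ceil(x) if x < q else q


def relu_prime(z):
    # g'(z) for g = ReLU; g'(0) := 0 by convention.
    return 1.0 if z > 0 else 0.0


def sample_data(blocks, d, m, M, N, rng):
    """N draws from the evenly mixed D; V_i = span of the axes in blocks[i] (assumption)."""
    data = []
    for _ in range(N):
        y = rng.randrange(len(blocks))
        while True:
            g = [rng.gauss(0.0, 1.0) for _ in blocks[y]]
            norm = math.sqrt(sum(v * v for v in g))
            if norm > 0.0:
                break
        r = rng.uniform(m, M)
        x = [0.0] * d
        for c, v in zip(blocks[y], g):
            x[c] = r * v / norm
        data.append((x, y))
    return data


def coarse_grad(W, V, data, b, g_prime):
    """Average of the sample coarse gradients (2.4) over `data` (stands in for the expectation)."""
    k, n, d = len(W), len(V), len(W[0])
    G = [[0.0] * d for _ in range(k)]
    for x, y in data:
        h = [sum(wc * xc for wc, xc in zip(w, x)) for w in W]
        s = [quantized_relu(hj, b) for hj in h]
        o = [sum(V[i][j] * s[j] for j in range(k)) for i in range(n)]
        xi = max((i for i in range(n) if i != y), key=lambda i: o[i])
        if 1.0 - (o[y] - o[xi]) <= 0.0:
            continue                      # zero hinge loss: x lies outside Omega_W
        for j in range(k):
            coef = -(V[y][j] - V[xi][j]) * g_prime(h[j])
            for c in range(d):
                G[j][c] += coef * x[c] / len(data)
    return G


def train(W, V, data, b, eta, T, g_prime=relu_prime):
    """Coarse gradient descent (2.5): W <- W - eta * coarse gradient; returns all iterates."""
    history = [[list(w) for w in W]]
    for _ in range(T):
        G = coarse_grad(W, V, data, b, g_prime)
        W = [[wc - eta * gc for wc, gc in zip(w, gj)] for w, gj in zip(W, G)]
        history.append([list(w) for w in W])
    return history


rng = random.Random(0)
blocks = [[0, 1], [2, 3]]                 # V_1, V_2; axis 4 spans V_{n+1}
data = sample_data(blocks, 5, 1.0, 2.0, 200, rng)
V = [[0.9, 0.0], [0.0, 0.9]]              # satisfies Assumption 1 with k = n = 2
W0 = [[rng.uniform(-1.0, 1.0) for _ in range(5)] for _ in range(2)]
hist = train(W0, V, data, b=2, eta=0.5, T=50)
```

This toy setting only exercises the update rule; it is not meant to reproduce the experiments of Sect. 8.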

3 Main result and outline of proof

We show that if the iterates \(\{{{\varvec{W}}}^t\}\) are uniformly bounded in t, coarse gradient descent with a proxy function g satisfying Assumption 2 converges to a global minimizer of the population loss, resulting in a perfect classification.

Theorem 3.1

Suppose data assumptions (1)–(3) and Assumptions 1 and 2 hold. If the network initialization satisfies \({{\varvec{w}}}_{j,i}^0\not =0\) for all \(j\in [k]\) and \(i\in [n]\), and \(\{{{\varvec{W}}}^t\}\) is uniformly bounded by R in t, then for all \(v_{i,j}>0\) we have

$$\begin{aligned} \lim _{t\rightarrow \infty }\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| =0. \end{aligned}$$

Furthermore, if \({{\varvec{W}}}^\infty \) is an accumulation point of \(\{{{\varvec{W}}}^t\}\) and all nonzero unit vectors \(\tilde{{\varvec{w}}}_{j,i}^\infty \)’s are distinct for all \(j\in [k]\) and \(i\in [n]\), then

$$\begin{aligned} \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}}\left( {\hat{y}}\left( {{\varvec{W}}}^{\infty },{{\varvec{x}}}\right) \ne y\right) =0. \end{aligned}$$

We outline the major steps in the proof below.

Step 1: Decompose the population loss into n components. Recall that \(l_i\) is the population loss function for \(\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i\). In Sect. 4, we show that, under a suitable decomposition of \({{\varvec{W}}}\), coarse gradient descent on each \(l_i\) only affects the corresponding component of \({{\varvec{W}}}\).

Step 2: Bound the total increment of the weight norms from above. In Sect. 5, we show that for all \(v_{i,j}>0\), the norms \(|{\varvec{w}}_{j,i}|\) are monotonically non-decreasing under coarse gradient descent. Using the boundedness of \({{\varvec{W}}}\), we further give an upper bound on the total increment of all \(|{\varvec{w}}_{j,i}|\)’s, from which the convergence of coarse gradient descent follows.

Step 3: Show that when the coarse gradient vanishes, so does the population loss. In Sect. 6, we show that when the coarse gradient vanishes toward the end of training, the population loss is zero which implies a perfect classification.

4 Space decomposition

Let \({\mathcal {V}}=\bigoplus _{i=1}^n {\mathcal {V}}_i\), and let \({\mathcal {V}}_{n+1}\) denote the orthogonal complement of \({\mathcal {V}}\) in \({\mathcal {X}}={\mathbb {R}}^d\). Then \({\mathcal {X}}={\mathbb {R}}^d\) decomposes into \(n+1\) mutually orthogonal parts:

$$\begin{aligned} {\mathbb {R}}^d={\mathcal {V}}\bigoplus {\mathcal {V}}_{n+1}=\bigoplus _{i=1}^{n+1}{\mathcal {V}}_i \end{aligned}$$

and for any vector \({\varvec{w}}_j\in {\mathbb {R}}^d\), we have a unique decomposition of \({\varvec{w}}_j\):

$$\begin{aligned} {{\varvec{w}}}_j=\sum _{i=1}^{n+1}{{\varvec{w}}}_{j,i}, \end{aligned}$$

where \({{\varvec{w}}}_{j,i}\in {\mathcal {V}}_i\) for \(i\in [n+1]\). To simplify notation, we let

$$\begin{aligned} {{\varvec{W}}}_i=\left[ {{\varvec{w}}}_{1,i},\ldots ,{{\varvec{w}}}_{k,i}\right] . \end{aligned}$$

Lemma 4.1

For any \({{\varvec{W}}}\in {\mathbb {R}}^{d\times k}\) and \(i\in [n]\), we have

$$\begin{aligned} l_i\left( {{\varvec{W}}}\right) =l_i\left( \sum _{r=1}^n{{\varvec{W}}}_r\right) =l_i({\varvec{W}}_i). \end{aligned}$$

Proof

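Fix \({{\varvec{x}}}\in {\mathcal {V}}_i\) and \(j\in [k]\). Since the subspaces \({\mathcal {V}}_r\), \(r\in [n+1]\), are mutually orthogonal, we have \(\langle {{\varvec{w}}}_{j,r},{{\varvec{x}}}\rangle =0\) for all \(r\not =i\); in particular, since \({{\varvec{x}}}\in {\mathcal {V}}\),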

$$\begin{aligned} \left\langle {{\varvec{w}}}_{j,n+1},{{\varvec{x}}}\right\rangle =0 \end{aligned}$$

and

$$\begin{aligned} h_j=\left\langle {{\varvec{w}}}_{j},{{\varvec{x}}}\right\rangle =\left\langle \sum _{r=1}^{n+1}{{\varvec{w}}}_{j,r},{{\varvec{x}}}\right\rangle =\left\langle {\varvec{w}}_{j,i},{{\varvec{x}}}\right\rangle . \end{aligned}$$

Hence,

$$\begin{aligned} f\left( {{\varvec{W}}};{{\varvec{x}}}\right) =f\left( \sum _{r=1}^n{{\varvec{W}}}_r;{{\varvec{x}}}\right) =f\left( {\varvec{W}}_i;{{\varvec{x}}}\right) \end{aligned}$$

for all \({{\varvec{W}}}\in {\mathbb {R}}^{d\times k}\), \({{\varvec{x}}}\in {\mathcal {V}}_i\). The desired result follows. \(\square \)
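The identity \(f({{\varvec{W}}};{{\varvec{x}}})=f({{\varvec{W}}}_i;{{\varvec{x}}})\) for \({{\varvec{x}}}\in {\mathcal {V}}_i\), on which Lemma 4.1 rests, can be sanity-checked numerically. The sketch below assumes coordinate-block subspaces, so the orthogonal projection onto \({\mathcal {V}}_i\) simply keeps the coordinates of block i; `project` is a hypothetical helper for this check, and the sizes and seed are arbitrary (Assumption 1 is not needed here, so `V` is random).

```python
import math
import random


def quantized_relu(x, b):
    q = 2 ** b - 1
    if x <= 0:
        return 0
    return math.ceil(x) if x < q else q


def forward(W, V, x, b):
    """Network output (2.1); W is a list of k vectors w_j, V[i][j] = v_{i,j}."""
    s = [quantized_relu(sum(wc * xc for wc, xc in zip(w, x)), b) for w in W]
    return [sum(vij * sj for vij, sj in zip(vi, s)) for vi in V]


def project(w, idx):
    """Orthogonal projection of w onto the coordinate subspace spanned by idx."""
    return [wc if c in idx else 0.0 for c, wc in enumerate(w)]


rng = random.Random(1)
blocks = [[0, 1], [2, 3, 4]]          # V_1, V_2; coordinate 5 spans V_{n+1}
d, k, n, b = 6, 3, 2, 2
checks = []
for _ in range(200):
    W = [[rng.uniform(-3, 3) for _ in range(d)] for _ in range(k)]
    V = [[rng.uniform(0, 1) for _ in range(k)] for _ in range(n)]
    i = rng.randrange(n)
    x = [0.0] * d
    for c in blocks[i]:
        x[c] = rng.uniform(-2, 2)     # x lies in V_i
    Wi = [project(w, blocks[i]) for w in W]   # the matrix W_i of the text
    checks.append(forward(W, V, x, b) == forward(Wi, V, x, b))
```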

Lemma 4.2

Running the algorithm (2.5) on \(l_i\) only does not change the value of \({{\varvec{W}}}_r\) for all \(r\not =i\). More precisely, for any \({{\varvec{W}}}\in {\mathbb {R}}^{d\times k}\), let

$$\begin{aligned} {{\varvec{W}}}'={{\varvec{W}}}-\eta {\tilde{\nabla }}l_i({{\varvec{W}}}), \end{aligned}$$

then for any \(r\in [n]\) and \(r\not =i\)

$$\begin{aligned} {{\varvec{W}}}_r'={{\varvec{W}}}_r. \end{aligned}$$

Proof of Lemma 4.2

Assume \(i,r\in [n]\) and \(i\not =r\). Note that

$$\begin{aligned} {{\varvec{w}}}_j'={{\varvec{w}}}_j-\eta {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}) \end{aligned}$$

and

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}})=-\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left[ \left( v_{y,j}-v_{\xi ,j}\right) \,\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\,g'(h_j){{\varvec{x}}}\right] \in {\mathcal {V}}_i. \end{aligned}$$

Since the subspaces \({\mathcal {V}}_r\), \(r\in [n+1]\), are mutually orthogonal, we have

$$\begin{aligned} {{\varvec{w}}}_{j,i}'={{\varvec{w}}}_{j,i}-\eta {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}) \end{aligned}$$

and

$$\begin{aligned} {\varvec{w}}_{j,r}'={\varvec{w}}_{j,r}. \end{aligned}$$

\(\square \)

By Lemma 4.2, the coarse gradient descent (2.5) is equivalent to the following iterations for \(i\in [n]\):

$$\begin{aligned} {{\varvec{W}}}_i^{t+1} = {{\varvec{W}}}_i^t - \frac{\eta }{n} {\tilde{\nabla }} l_i\left( {{\varvec{W}}}^t\right) . \end{aligned}$$
(4.1)
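For completeness, the step from (2.5) to (4.1) can be written out. Since \({\mathcal {D}}\) mixes the \({\mathcal {D}}_r\) with equal weights 1/n, the expected coarse gradient splits into class-wise terms, and by the proof of Lemma 4.2 the r-th term lies in \({\mathcal {V}}_r\). Projecting the update onto each \({\mathcal {V}}_i\) therefore keeps only the term r = i, while the component in \({\mathcal {V}}_{n+1}\) never moves:

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j} l({{\varvec{W}}})&=\frac{1}{n}\sum _{r=1}^n{\tilde{\nabla }}_{{\varvec{w}}_j} l_r({{\varvec{W}}}),\qquad {\tilde{\nabla }}_{{\varvec{w}}_j} l_r({{\varvec{W}}})\in {\mathcal {V}}_r,\\ {{\varvec{w}}}_{j,i}^{t+1}&={{\varvec{w}}}_{j,i}^{t}-\frac{\eta }{n}{\tilde{\nabla }}_{{\varvec{w}}_j} l_i({{\varvec{W}}}^t),\quad i\in [n],\qquad {{\varvec{w}}}_{j,n+1}^{t+1}={{\varvec{w}}}_{j,n+1}^{t}. \end{aligned}$$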

5 Learning dynamics

In this section, we show that certain components of the weight iterates have strictly increasing magnitude whenever the coarse gradient does not vanish, and we quantify the increment in each iteration.

Lemma 5.1

Let

$$\begin{aligned} {\hat{v}}_j=\max _{i_1,i_2\in [n]}\left( v_{i_1,j}-v_{i_2,j}\right) . \end{aligned}$$

Then we have the following estimate:

$$\begin{aligned} \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _W^j\right) \ge \frac{1}{{\hat{v}}_j{\tilde{\delta }}M}\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i\left( {{\varvec{W}}}\right) \right| . \end{aligned}$$

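Proof of Lemma 5.1

Since \(|{{\varvec{x}}}|<M\), \(|v_{y,j}-v_{\xi ,j}|\le {\hat{v}}_j\), \(0\le g'\le {\tilde{\delta }}\), and \(g'(h_j)=0\) for \(h_j\le 0\) (taking \(g'(0)=0\)), the integrand below vanishes outside \(\Omega _{{{\varvec{W}}}}^j\), and we have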

$$\begin{aligned} \left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}})\right| =&\left| \mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left[ \left( v_{y,j}-v_{\xi ,j}\right) \,\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\,g'(h_j){{\varvec{x}}}\right] \right| \\ \le&{\hat{v}}_j{\tilde{\delta }}M\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left[ \mathbb {1}_{\Omega ^j_{{{\varvec{W}}}}}({{\varvec{x}}})\right] \\ =&{\hat{v}}_j{\tilde{\delta }}M\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\right) \end{aligned}$$

\(\square \)

Lemma 5.2

For any \(j\in [k]\), if

$$\begin{aligned} {\tilde{v}}_{i,j}:=v_{i,j}-\max _{r\not =i}v_{r,j}>0 \end{aligned}$$

we have

$$\begin{aligned} \left\langle {\tilde{{{\varvec{w}}}}}_{j,i},-{\tilde{\nabla }}_{{{\varvec{w}}}_j}l_i({{\varvec{W}}})\right\rangle \ge \frac{{\tilde{v}}_{i,j}\delta }{2C_p}\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _W^j\right) ^2, \end{aligned}$$

where

$$\begin{aligned} C_p=\sup _{{\varvec{v}}\in {\mathcal {V}}_i,\,|{\varvec{v}}|=1,\;a\in {\mathbb {R}}}\int _{\{{\varvec{x}}\in {\mathcal {V}}_i:\,\langle {\varvec{v}}, {\varvec{x}}\rangle =a\}}p_i({{\varvec{x}}})\;d\,{\mathcal {H}}^{d_i-1}({{\varvec{x}}}). \end{aligned}$$

Proof of Lemma 5.2

First, we prove an inequality that will be used later. Recall that \(|{{\varvec{x}}}|<M\) and that \({\tilde{\nabla }}_{{\varvec{w}}_j}l({{\varvec{W}}};\{{{\varvec{x}}},y\})\not =0\) only when \({\varvec{x}}\in \Omega _{{{\varvec{W}}}}^j\), in which case \(\left\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\right\rangle >0\). For \(t>0\), we have

$$\begin{aligned} \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\cap \left\{ {{\varvec{x}}}:\left\langle {\tilde{{{\varvec{w}}}}}_{j,i},{{\varvec{x}}}\right\rangle<t\right\} \right) =&\int _{\Omega _{{\varvec{W}}}^j}\mathbb {1}_{\left\{ \left\langle {\tilde{{{\varvec{w}}}}}_{j,i},{{\varvec{x}}}\right\rangle <t\right\} }({{\varvec{x}}})p_i({{\varvec{x}}})\;d\,{{\varvec{x}}}\\ =&\int _0^t\int _{\left\langle {\tilde{{{\varvec{w}}}}}_{j,i},{{\varvec{x}}}\right\rangle =s}p_i({{\varvec{x}}})\;d\,{\mathcal {H}}^{d_i-1}({{\varvec{x}}})\;d\,s \\ \le&t\ C_p. \end{aligned}$$

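Next, we write \(\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle =\int _0^\infty \mathbb {1}_{\{\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle >t\}}\,d\,t\) on \(\Omega _{{{\varvec{W}}}}^j\) and use Fubini’s theorem to simplify the inner product: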

$$\begin{aligned} \left\langle \tilde{{\varvec{w}}}_{j,i},-{\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}})\right\rangle =&\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left[ \left( v_{y,j}-v_{\xi ,j}\right) \mathbb {1}_{\Omega _{{{\varvec{W}}}}^j}({{\varvec{x}}})\,g'(h_j)\,\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle \right] \\ \ge&{\tilde{v}}_{i,j}\,\delta \int _{\Omega _{{{\varvec{W}}}}^j\cap V_i}\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle p_i({{\varvec{x}}})\;d\,{{\varvec{x}}}\\ =&{\tilde{v}}_{i,j}\,\delta \int _{\Omega _{{{\varvec{W}}}}^j\cap V_i}\int _0^\infty \mathbb {1}_{\left\{ \langle {\tilde{{{\varvec{w}}}}}_{j,i},{{\varvec{x}}}\rangle>t\right\} }\;d\,t\;p_i({{\varvec{x}}})\;d\,{{\varvec{x}}}\\ =&{\tilde{v}}_{i,j}\,\delta \int _0^\infty \int _{\Omega _{{{\varvec{W}}}}^j\cap V_i}\mathbb {1}_{\left\{ \langle {\tilde{{{\varvec{w}}}}}_{j,i},{{\varvec{x}}}\rangle>t\right\} }\;p_i({{\varvec{x}}})\;d\,{{\varvec{x}}}\;d\,t\\ =&{\tilde{v}}_{i,j}\,\delta \int _0^\infty \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\cap \left\{ {{\varvec{x}}}:\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle >t\right\} \right) d\,t. \end{aligned}$$

Now, using the inequality just proved above, we have

$$\begin{aligned}&\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\cap \left\{ {{\varvec{x}}}:\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle >t\right\} \right) \\ =&\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{\varvec{W}}}^j\right) -\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\cap \left\{ {{\varvec{x}}}:\langle \tilde{{\varvec{w}}}_{j,i},{{\varvec{x}}}\rangle <t\right\} \right) \\ \ge&\max \left\{ \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{\varvec{W}}}^j\right) -t\;C_p,0\right\} . \end{aligned}$$

Combining the above two inequalities, we have

$$\begin{aligned} \left\langle \tilde{{\varvec{w}}}_{j,i},-{\tilde{\nabla }}_{{\varvec{w}}_{j}}l_i({{\varvec{W}}})\right\rangle \ge&{\tilde{v}}_{i,j}\,\delta \int _0^\infty \max \left\{ \mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{\varvec{W}}}^j\right) -t\;C_p,0\right\} \;d\,t\\ \ge&\frac{{\tilde{v}}_{i,j}\,\delta }{2C_p}\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\right) ^2. \end{aligned}$$

\(\square \)
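The last inequality in the proof above is in fact an equality. Writing \(P:=\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}^j\right) \), the integrand vanishes for \(t\ge P/C_p\), so

$$\begin{aligned} \int _0^\infty \max \left\{ P-t\,C_p,0\right\} d\,t=\int _0^{P/C_p}\left( P-t\,C_p\right) d\,t=\frac{P^2}{C_p}-\frac{C_p}{2}\left( \frac{P}{C_p}\right) ^2=\frac{P^2}{2C_p}. \end{aligned}$$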

Lemma 5.3

If \({\tilde{v}}_{i,j}>0\) in Lemma 5.2, then the sequence \(\{|{\varvec{w}}_{j,i}^{t}|\}\) generated by coarse gradient descent (2.5) is non-decreasing. Moreover, under the same assumption, we have

$$\begin{aligned} \left| {\varvec{w}}_{j,i}^{t+1}\right| -\left| {\varvec{w}}_{j,i}^t\right| \ge \frac{ \eta {\tilde{v}}_{i,j}\delta }{2nC_p {\hat{v}}_j^2{\tilde{\delta }}^2M^2}\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| ^2, \end{aligned}$$

where \(C_p\) is defined as in Lemma 5.2 and \({\hat{v}}_j\) as in Lemma 5.1.

Proof of Lemma 5.3

Since \({\varvec{w}}_{j,i}^{t+1}={\varvec{w}}_{j,i}^{t}-\frac{\eta }{n}{\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\), we have

$$\begin{aligned} \left| {\varvec{w}}_{j,i}^{t+1}\right| -\left| {\varvec{w}}_{j,i}^t\right| \ge \left\langle {\varvec{w}}_{j,i}^{t+1}-{\varvec{w}}_{j,i}^t,\tilde{{\varvec{w}}}_{j,i}^{t}\right\rangle =\left\langle -\frac{\eta }{n}{\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t),\tilde{{\varvec{w}}}_{j,i}^t\right\rangle . \end{aligned}$$

Hence, it follows from Lemmas 5.1 and 5.2 that

$$\begin{aligned} \left| {\varvec{w}}_{j,i}^{t+1}\right| -\left| {\varvec{w}}_{j,i}^t\right| \ge \frac{ \eta {\tilde{v}}_{i,j}\delta }{2nC_p {\hat{v}}_j^2{\tilde{\delta }}^2M^2}\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| ^2, \end{aligned}$$
(5.1)

which is the desired result. \(\square \)
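Spelled out, the proof chains three bounds. The first inequality in the display above follows from the Cauchy–Schwarz inequality, \(|{\varvec{a}}+{\varvec{b}}|\ge \langle {\varvec{a}}+{\varvec{b}},{\varvec{a}}/|{\varvec{a}}|\rangle =|{\varvec{a}}|+\langle {\varvec{b}},{\varvec{a}}/|{\varvec{a}}|\rangle \) for \({\varvec{a}}\not ={\mathbf {0}}\); Lemma 5.2 then bounds the inner product from below, and Lemma 5.1 bounds the probability from below. Writing \(P^t:=\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}^t}^j\right) \),

$$\begin{aligned} \left| {\varvec{w}}_{j,i}^{t+1}\right| -\left| {\varvec{w}}_{j,i}^t\right| &\ge \frac{\eta }{n}\left\langle \tilde{{\varvec{w}}}_{j,i}^t,-{\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right\rangle \ge \frac{\eta \,{\tilde{v}}_{i,j}\delta }{2nC_p}\left( P^t\right) ^2\\ &\ge \frac{\eta \,{\tilde{v}}_{i,j}\delta }{2nC_p}\left( \frac{\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| }{{\hat{v}}_j{\tilde{\delta }}M}\right) ^2. \end{aligned}$$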

Since the norm \(|{\varvec{w}}_{j,i}^t|\) is non-decreasing while the weights are bounded by assumption, the sum of the increments over all iterations is bounded as well. This gives the following proposition:

Proposition 1

Assume that \(\{|{\varvec{w}}_j^t|\}\) is bounded by R. If \({\tilde{v}}_{i,j}>0\) in Lemma 5.2, then

$$\begin{aligned} \sum _{t=1}^\infty \left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| ^2\le \frac{2nC_p {\hat{v}}_j^2{\tilde{\delta }}^2M^2R}{ \eta {\tilde{v}}_{i,j}\delta }<\infty , \end{aligned}$$

where \(C_p\) is as defined in Lemma 5.2 and \({\hat{v}}_j\) is as defined in Lemma 5.1. This implies that

$$\begin{aligned} \lim _{t\rightarrow \infty } \left| {\tilde{\nabla }}_{{{\varvec{w}}}_j}l_i({{\varvec{W}}}^t)\right| = 0 \end{aligned}$$

as long as \({\tilde{v}}_{i,j}>0\).
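The summability bound in Proposition 1 follows from Lemma 5.3 by telescoping. A sketch of this step, under the assumed convention that \(|{\varvec{w}}_{j,i}^t|\le |{\varvec{w}}_j^t|\le R\) (i.e., \({\varvec{w}}_{j,i}\) is a sub-block of \({\varvec{w}}_j\)): summing (5.1) over \(t=1,\ldots ,T\) gives

$$\begin{aligned} \frac{\eta {\tilde{v}}_{i,j}\delta }{2nC_p{\hat{v}}_j^2{\tilde{\delta }}^2M^2}\sum _{t=1}^{T}\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| ^2\le \sum _{t=1}^{T}\left( \left| {\varvec{w}}_{j,i}^{t+1}\right| -\left| {\varvec{w}}_{j,i}^{t}\right| \right) =\left| {\varvec{w}}_{j,i}^{T+1}\right| -\left| {\varvec{w}}_{j,i}^{1}\right| \le R, \end{aligned}$$

and letting \(T\rightarrow \infty \) yields the stated bound. Since the terms of a convergent series tend to zero, the limit \(\lim _{t\rightarrow \infty }|{\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)|=0\) follows.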

Remark 3

Lemmas 5.1, 5.2 and 5.3 and Proposition 1 were proved without Assumption 1. Under Assumption 1, we have \({\hat{v}}_j=\max _{i\in [n]}v_{i,j}\) in Lemma 5.1, and, in Lemma 5.2, \({\tilde{v}}_{i,j}={\hat{v}}_j\) if \(v_{i,j}>0\) and \({\tilde{v}}_{i,j}=-{\hat{v}}_j\) if \(v_{i,j}=0\).

6 Landscape properties

We have shown that, under boundedness assumptions, the coarse gradient vanishes along the iterates of the algorithm. However, this does not immediately imply convergence to a valid point, because the coarse gradient is not the true gradient of the loss. We need the following lemma to prove Proposition 2, which confirms that the points with zero coarse gradient are indeed global minima.

Lemma 6.1

Let \(\Omega =\left\{ {{\varvec{x}}}\in {\mathbb {R}}^l:m<|{{\varvec{x}}}|<M\right\} \), where \(0<m<M<\infty \). For \(j\in [k]\), let \(\Omega _j=\left\{ {{\varvec{x}}}:\langle {{\varvec{w}}}_j,{{\varvec{x}}}\rangle >a\right\} \), where \(a\ge 0\) and \(\Omega _i\not =\Omega _j\) for all \(i\not =j\). If for every \(i\in [k]\) and every \({{\varvec{x}}}\in \Omega _i\cap \Omega \) there exists some \(j\not =i\) such that \({{\varvec{x}}}\in \Omega _j\), then

$$\begin{aligned} \left( \mathop {\cup }_{j=1}^k\Omega _{j}\right) \cap \Omega =\emptyset \ \text { or }\ \Omega . \end{aligned}$$

Proof of Lemma 6.1

Define \({\tilde{\Omega }}=\bigcup _{j=1}^k\Omega _j\). By De Morgan’s law, we have

$$\begin{aligned} {\tilde{\Omega }}^c=\left( \mathop {\cup }_{j=1}^k\Omega _j\right) ^c=\mathop {\cap }_{j=1}^k\Omega _j^c. \end{aligned}$$

Since k is finite and \({\varvec{0}}\in \Omega _j^c\) for all \(j\in [k]\), the set \({\tilde{\Omega }}^c\) is a generalized polyhedron, and hence either

$$\begin{aligned} \left( \partial {\tilde{\Omega }}\right) \cap \Omega =\emptyset \end{aligned}$$

or

$$\begin{aligned} {\mathcal {H}}^{l-1}\left( \left( \partial {\tilde{\Omega }}\right) \cap \Omega \right) >0. \end{aligned}$$

The first case is trivial. We show that the second case contradicts our assumption. Since

$$\begin{aligned} \partial {\tilde{\Omega }}=\partial \left( \mathop {\cup }_{j=1}^k\Omega _j\right) \subseteq \mathop {\cup }_{j=1}^k\partial \Omega _j, \end{aligned}$$

we know there exists some \(j^\star \in [k]\) such that \({\mathcal {H}}^{l-1}\left( \partial \Omega _{j^\star }\cap \Omega \right) >0.\) It follows from our assumption that \({\tilde{\Omega }}=\mathop {\cup }_{j=1}^k\Omega _j=\mathop {\cup }_{j\not =j^\star }\Omega _j\), and hence,

$$\begin{aligned} {\mathcal {H}}^{l-1}\left( \partial \Omega _{j^\star }\cap \partial \Omega _j\right) >0. \end{aligned}$$

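Since the \(\partial \Omega _j\)’s are hyperplanes, and two distinct hyperplanes intersect in a set of zero \((l-1)\)-dimensional Hausdorff measure, we obtain \(\partial \Omega _j=\partial \Omega _{j^\star }\), and therefore \(\Omega _j=\Omega _{j^\star }\), contradicting our assumption that all \(\Omega _j\)’s are distinct. \(\square \)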

The following result shows that the coarse gradient vanishes only at a global minimizer with zero loss, except for some degenerate cases.

Proposition 2

Under Assumption 1, if \({\tilde{\nabla }}_{{\varvec{w}}_j} l_i({{\varvec{W}}})={\varvec{0}}\) for all \(j\) with \({\tilde{v}}_{i,j}>0\), and the \(\tilde{{\varvec{w}}}_{j,i}\)’s are distinct, then \(l_i({{\varvec{W}}})=0\).

Proof of Proposition 2

For the quantized ReLU function, let \(q_b :=\max \limits _{x\in {\mathbb {R}}}\sigma (x)\) be the maximum quantization level, so that

$$\begin{aligned} \sigma (x)=\sum _{a=0}^{q_b-1}\mathbb {1}_{\{x>a\}}(x). \end{aligned}$$
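The identity above can be checked numerically. Below is a minimal NumPy sketch, assuming the quantized ReLU has the clipped-ceiling form \(\sigma (x)=\min \{\max \{\lceil x\rceil ,0\},q_b\}\), which matches the 4-bit activation of Sect. 8.1 with \(q_b=15\); the function names are illustrative, not from the paper.

```python
import numpy as np

def quantized_relu(x, q_b):
    """Staircase activation: 0 for x < 0, ceil(x) on [0, q_b), and q_b for x >= q_b."""
    return np.clip(np.ceil(x), 0, q_b)

def quantized_relu_as_steps(x, q_b):
    """The same function written as a sum of q_b step functions 1{x > a}, a = 0, ..., q_b - 1."""
    x = np.asarray(x, dtype=float)
    thresholds = np.arange(q_b)                      # a = 0, 1, ..., q_b - 1
    return (x[..., None] > thresholds).sum(axis=-1).astype(float)
```

Comparing both functions on a fine grid of \(x\in [-3,20]\) with \(q_b=15\) gives identical outputs, since for \(x\ge 0\) both count the integers \(a\in \{0,\ldots ,q_b-1\}\) with \(a<x\).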

Note that

$$\begin{aligned} f_i\left( {{\varvec{W}}};{{\varvec{x}}}\right) =o_i-o_\xi =\sum _{j=1}^k\left( v_{i,j}-v_{\xi ,j}\right) \sigma (h_j) =\sum _{j=1}^k\left( v_{i,j}- v_{\xi ,j}\right) \sum _{a=0}^{q_b-1}\mathbb {1}_{\Omega _{{\varvec{w}}_j}^a}({{\varvec{x}}}). \end{aligned}$$

By assumption, \({\tilde{\nabla }}_{{\varvec{w}}_j} l_i({{\varvec{W}}})={\varvec{0}}\) for all \(j\) with \({\tilde{v}}_{i,j}>0\), which implies \(\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\mathbb {1}_{\Omega _{{\varvec{w}}_j}^a}({{\varvec{x}}})=0\) almost surely for all such \(j\) and all \(a\in [n]\). Hence, almost surely, every \({{\varvec{x}}}\in \Omega _{{\varvec{w}}_j}^a\) satisfies \({{\varvec{x}}}\not \in \Omega _{{{\varvec{W}}}}\). Since \({{\varvec{x}}}\not \in \Omega _{{{\varvec{W}}}}\) if and only if \(o_i-o_\xi \ge 1\), and since \(v_{i,j}-v_{\xi ,j}<1\), for any such \({\varvec{x}}\in \Omega _{{\varvec{w}}_j}^a\) there exist \(j'\not =j\) and \(a'\in [n]\) such that \(v_{i,j'}>0\) and \({{\varvec{x}}}\in \Omega _{{\varvec{w}}_{j'}}^{a'}\). By Lemma 6.1, \(\mathop {{\mathbb {P}}}_{\{{{\varvec{x}}},y\}\sim {\mathcal {D}}_i}\left( \Omega _{{{\varvec{W}}}}\right) =0\), and thus \(l_i({{\varvec{W}}})=0\). \(\square \)

The following lemma shows that the expected coarse gradient is continuous except where \({\varvec{w}}_{j,i}={\varvec{0}}\) for some \(j\in [k]\) and \(i\in [n]\).

Lemma 6.2

Consider the network in (2.1). \({\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}})\) is continuous on

$$\begin{aligned} \left\{ {{\varvec{W}}}\in {\mathbb {R}}^{k\times d}:|{\varvec{w}}_{j,i}|>0\text { for all }j\in [k],i\in [n]\right\} . \end{aligned}$$

Proof of Lemma 6.2

It suffices to prove the result for each fixed \(j\in [k]\). Note that

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}) =\mathop {{\mathbb {E}}}_{\{{\varvec{x}},y\}\sim {\mathcal {D}}_i}\left[ -\left( v_{y,j}-v_{\xi ,j}\right) \,\mathbb {1}_{\Omega _{{{\varvec{W}}}}}({{\varvec{x}}})\,g'(h_j){{\varvec{x}}}\right] \end{aligned}$$

For any \({{\varvec{W}}}^0\) in this set, we have

$$\begin{aligned} \lim _{{{\varvec{W}}}\rightarrow {{\varvec{W}}}^0}\mathbb {1}_{\Omega _{{\varvec{W}}}}({{\varvec{x}}})g'(h_j)=\mathbb {1}_{\Omega _{{{\varvec{W}}}^0}}({{\varvec{x}}})g'(h_j^0), \text{ a.e. } \end{aligned}$$

The desired result follows from the dominated convergence theorem. \(\square \)

7 Proof of main results

Equipped with these technical lemmas, we now present the proof of our main result.

Proof of Theorem 3.1

It follows readily from Assumption 1 (see Remark 3) that \(v_{i,j}>0\) if and only if \({\tilde{v}}_{i,j}>0\). By Lemma 5.3, if \(v_{i,j}>0\) and \(|{\varvec{w}}_{j,i}^0|>0\), then \(|{\varvec{w}}_{j,i}^t|>0\) for all t. Since \({{\varvec{W}}}\) is randomly initialized, we can ignore the possibility that \({\varvec{w}}_{j,i}^0={\varvec{0}}\) for some \(j\in [k]\) and \(i\in [n]\). Moreover, Proposition 1 and Eq. (2.5) imply that, for all \(v_{i,j}>0\),

$$\begin{aligned} \lim _{t\rightarrow \infty }\left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^t)\right| =0. \end{aligned}$$

Suppose \({{\varvec{W}}}^\infty \) is an accumulation point with \({\varvec{w}}_{j,r}^{\infty }\not ={\varvec{0}}\) for all \(j\in [k]\) and \(r\in [n]\). Then, by the continuity of the coarse gradient (Lemma 6.2), we know that for all \(v_{i,j}>0\)

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j}l_i\left( {{\varvec{W}}}^\infty \right) ={\varvec{0}}. \end{aligned}$$

Next, we consider the case when \({\varvec{w}}_{j,r}^\infty ={\varvec{0}}\) for some \(j\in [k]\) and \(r\in [n]\). Lemma 5.2 implies \(v_{r,j}=0\). We construct a new sequence

$$\begin{aligned} \hat{{\varvec{w}}}_{j,r}^t=\left\{ \begin{aligned} {\varvec{w}}_{j,r}^t&\;\; \text { if }{\varvec{w}}_{j,r}^\infty \not ={\varvec{0}} ,\\ {\varvec{0}}&\;\; \text { if }{\varvec{w}}_{j,r}^\infty ={\varvec{0}} ,\\ \end{aligned} \right. \end{aligned}$$

and

$$\begin{aligned} \hat{{\varvec{W}}}_r^t=\left[ \hat{{\varvec{w}}}_{1,r}^t,\ldots ,\hat{{\varvec{w}}}_{k,r}^t\right] . \end{aligned}$$

With

$$\begin{aligned} {\hat{o}}_r=\sum _{j=1}^kv_{r,j}\sigma ({\hat{h}}_j)=\sum _{j=1}^kv_{r,j}\sigma \left( \left\langle \hat{{\varvec{w}}}_{j,r},{{\varvec{x}}}\right\rangle \right) , \end{aligned}$$

we know \({\hat{o}}_r=o_r\) for all \(r\in [n]\). Hence, we have

$$\begin{aligned} l\left( \hat{{\varvec{W}}}^t,\{{{\varvec{x}}},i\}\right) =\text {ReLU}\left( 1-{\hat{o}}_i+{\hat{o}}_\xi \right) = l\left( {\varvec{W}}^t,\{{{\varvec{x}}},i\}\right) . \end{aligned}$$

This implies that \(\Omega _{{\hat{{{\varvec{W}}}}}^t}=\Omega _{{{\varvec{W}}}^t}\), so we have for all \(j\in [k]\),

$$\begin{aligned} \left| \left\langle {\tilde{\nabla }}_{{\varvec{w}}_j}l_i(\hat{{\varvec{W}}}_1^t),{\tilde{{{\varvec{w}}}}}_{j,i}^t\right\rangle \right| \le \left| \left\langle {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({\varvec{W}}_i^t),{\tilde{{{\varvec{w}}}}}_{j,i}^t\right\rangle \right| \le \left| {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({\varvec{W}}_i^t)\right| . \end{aligned}$$

Letting t go to infinity on both sides (along a subsequence converging to \({{\varvec{W}}}^\infty \)), we get

$$\begin{aligned} \left| \left\langle {\tilde{\nabla }}_{{\varvec{w}}_j}l_i(\hat{{\varvec{W}}}^\infty ),{\tilde{{{\varvec{w}}}}}_{j,i}^\infty \right\rangle \right| =0. \end{aligned}$$

By Lemmas 5.1 and 5.2, we know

$$\begin{aligned} {\tilde{\nabla }}_{{\varvec{w}}_j}l_i({{\varvec{W}}}^\infty )={\tilde{\nabla }}_{{\varvec{w}}_j}l_i({\varvec{W}}_i^\infty )=0, \end{aligned}$$

so \({\tilde{\nabla }}_{{{\varvec{W}}}}l_i({{\varvec{W}}}^\infty )=0.\) By Proposition 2, \(l_i({{\varvec{W}}}^\infty )=0\), which completes the proof. \(\square \)

8 Experiments

In this section, we conduct experiments on synthetic, MNIST and CIFAR-10 data to verify and complement our theoretical findings. Further experiments on larger networks and data sets are left for future work.

8.1 Synthetic data

Let \(\left\{ {\varvec{e}}_1,{\varvec{e}}_2,{\varvec{e}}_3,{\varvec{e}}_4\right\} \) be an orthonormal basis of \({\mathbb {R}}^4\), let \(\theta \) be an acute angle, and let \({\varvec{v}}_1={\varvec{e}}_1\), \({\varvec{v}}_2=\sin \theta \,{\varvec{e}}_2+\cos \theta \,{\varvec{e}}_3\), \({\varvec{v}}_3={\varvec{e}}_3\), \({\varvec{v}}_4={\varvec{e}}_4\). We then have two linearly independent subspaces of \({\mathbb {R}}^4\), namely \({\mathcal {V}}_1=\text {Span}\left( \left\{ {\varvec{v}}_1,{\varvec{v}}_2\right\} \right) \) and \({\mathcal {V}}_2=\text {Span}\left( \left\{ {\varvec{v}}_3,{\varvec{v}}_4\right\} \right) \), and one can easily verify that the angle between \({\mathcal {V}}_1\) and \({\mathcal {V}}_2\) is \(\theta \) (the largest inner product between unit vectors of the two subspaces is \(\langle {\varvec{v}}_2,{\varvec{v}}_3\rangle =\cos \theta \)). Next, with

$$\begin{aligned} S_r=\left\{ \frac{j}{10}:j\in [20]-[9]\right\} , \; S_\varphi =\left\{ \frac{j\pi }{40}:j\in [80]\right\} , \end{aligned}$$

we define

$$\begin{aligned} \hat{{\mathcal {X}}}_1=\left\{ r\left( \cos \varphi \,{\varvec{v}}_1+\sin \varphi \,{\varvec{v}}_2\right) :r\in S_r,\varphi \in S_\varphi \right\} \end{aligned}$$

and

$$\begin{aligned} \hat{{\mathcal {X}}}_2=\left\{ r\left( \cos \varphi \,{\varvec{v}}_3+\sin \varphi \,{\varvec{v}}_4\right) :r\in S_r,\varphi \in S_\varphi \right\} . \end{aligned}$$

Let \(\hat{{\mathcal {D}}}_i\) be the uniform distribution on \(\hat{{\mathcal {X}}}_i\times \{i\}\) and \(\hat{{\mathcal {D}}}\) be a mixture of \(\hat{{\mathcal {D}}}_1\) and \(\hat{{\mathcal {D}}}_2\). Let \(\hat{{\mathcal {X}}}=\hat{{\mathcal {X}}}_1\cup \hat{{\mathcal {X}}}_2\). The activation function \(\sigma \) is the 4-bit quantized ReLU:

$$\begin{aligned} \sigma (x)=\left\{ \begin{array}{lll} 0&{} \text {if} &{}x<0,\\ \text {ceil}(x)&{} \text {if} &{}0\le x<15,\\ 15&{} \text {if} &{}x\ge 15.\\ \end{array}\right. \end{aligned}$$
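A minimal NumPy sketch of this construction follows. It assumes the convention \([n]=\{1,\ldots ,n\}\), so that \(S_r=\{1.0,1.1,\ldots ,2.0\}\) and \(S_\varphi \) contains 80 angles, and it reads "the angle between \({\mathcal {V}}_1\) and \({\mathcal {V}}_2\)" as the smallest principal angle, computed from orthonormal bases of the two subspaces via QR and SVD. The helper names are hypothetical.

```python
import numpy as np

def synthetic_subspace_data(theta):
    """Return X_hat_1, X_hat_2 (each 880 x 4) and bases of V_1, V_2 stored as rows."""
    e = np.eye(4)                                      # e_1, ..., e_4
    v1, v2 = e[0], np.sin(theta) * e[1] + np.cos(theta) * e[2]
    v3, v4 = e[2], e[3]
    radii = np.arange(10, 21) / 10.0                   # S_r = {j/10 : j = 10, ..., 20}
    phis = np.arange(1, 81) * np.pi / 40.0             # S_phi = {j*pi/40 : j = 1, ..., 80}
    r, phi = np.meshgrid(radii, phis, indexing="ij")
    r, phi = r.reshape(-1, 1), phi.reshape(-1, 1)      # all 11 * 80 (r, phi) pairs
    X1 = r * (np.cos(phi) * v1 + np.sin(phi) * v2)
    X2 = r * (np.cos(phi) * v3 + np.sin(phi) * v4)
    return X1, X2, np.stack([v1, v2]), np.stack([v3, v4])

def smallest_principal_angle(A, B):
    """Smallest principal angle between the row spaces of A and B (few rows, 4 columns)."""
    Qa, _ = np.linalg.qr(A.T)                          # orthonormal basis of span(rows of A)
    Qb, _ = np.linalg.qr(B.T)
    cosines = np.linalg.svd(Qa.T @ Qb, compute_uv=False)
    return float(np.arccos(np.clip(cosines.max(), -1.0, 1.0)))
```

For instance, with `X1, X2, B1, B2 = synthetic_subspace_data(np.pi / 6)`, the call `smallest_principal_angle(B1, B2)` recovers \(\pi /6\) up to rounding, and all points have norms in \([1,2]\).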

For simplicity, we take \(k=24\) and, for \(i\in [2]\) and \(j\in [24]\), set \(v_{i,j}=\frac{1}{2}\) if \(j-12(i-1)\in [12]\) and \(v_{i,j}=0\) otherwise. Our neural network then becomes

$$\begin{aligned} f_i=\frac{(-1)^{i-1}}{2}\left[ \sum _{j=1}^{12}\sigma (h_j)-\sum _{j=1}^{12}\sigma (h_{j+12})\right] \end{aligned}$$

where \(h_j=\langle {\varvec{w}}_j,{{\varvec{x}}}\rangle \) and \({{\varvec{x}}}\in {\mathbb {R}}^4\). The population loss is given by

$$\begin{aligned} l({{\varvec{W}}})=\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim \hat{{\mathcal {D}}}}\left[ l({{\varvec{W}}};\{{{\varvec{x}}},y\})\right] =\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim \hat{{\mathcal {D}}}}\left[ \max \left\{ 1-f_y,0\right\} \right] . \end{aligned}$$

We choose the ReLU STE (i.e., \(g(x) = \max \{0,x\}\)) and use the coarse gradient

$$\begin{aligned} \begin{aligned}&{\tilde{\nabla }}_{{{\varvec{W}}}}l({{\varvec{W}}})=\mathop {{\mathbb {E}}}_{\{{{\varvec{x}}},y\}\sim \hat{{\mathcal {D}}}}\left[ {\tilde{\nabla }}_{{{\varvec{W}}}}l\left( {{\varvec{W}}},\{{{\varvec{x}}},y\}\right) \right] \\ =&\frac{1}{|\hat{{\mathcal {X}}}|}\left[ \sum _{{{\varvec{x}}}\in \hat{{\mathcal {X}}}_1}{\tilde{\nabla }}_{{{\varvec{W}}}}l\left( {{\varvec{W}}};\{{{\varvec{x}}},1\}\right) +\sum _{{{\varvec{x}}}\in \hat{{\mathcal {X}}}_2}{\tilde{\nabla }}_{{{\varvec{W}}}}l\left( {{\varvec{W}}};\{{{\varvec{x}}},2\}\right) \right] . \end{aligned} \end{aligned}$$

Taking the learning rate \(\eta =1\), Eq. (2.5) becomes

$$\begin{aligned} {{\varvec{W}}}^{t+1}={{\varvec{W}}}^t-{\tilde{\nabla }}_{{{\varvec{W}}}}l\left( {{\varvec{W}}}^t\right) . \end{aligned}$$
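The following NumPy sketch puts the pieces of this subsection together: the data \(\hat{{\mathcal {X}}}\), the 4-bit \(\sigma \), the fixed second layer \(v_{i,j}\), the population hinge loss, the coarse gradient with the ReLU STE (so \(\sigma '\) is replaced by \(g'(h)=\mathbb {1}_{\{h>0\}}\)), and the update with \(\eta =1\). It is an illustrative sketch, not the authors' code: the 0/1 label encoding, the standard normal initialization of \({{\varvec{W}}}\), the iteration cap and all function names are assumptions.

```python
import numpy as np

def make_data(theta):
    """Points of X_hat_1 (label 0) and X_hat_2 (label 1) as constructed in Sect. 8.1."""
    e = np.eye(4)
    v1, v2 = e[0], np.sin(theta) * e[1] + np.cos(theta) * e[2]
    v3, v4 = e[2], e[3]
    r = np.repeat(np.arange(10, 21) / 10.0, 80)[:, None]          # each radius in S_r, 80 times
    phi = np.tile(np.arange(1, 81) * np.pi / 40.0, 11)[:, None]   # S_phi, repeated per radius
    X1 = r * (np.cos(phi) * v1 + np.sin(phi) * v2)
    X2 = r * (np.cos(phi) * v3 + np.sin(phi) * v4)
    return np.vstack([X1, X2]), np.repeat([0, 1], len(X1))

# Fixed second layer: v_{i,j} = 1/2 on the 12 hidden units assigned to class i, 0 otherwise.
V = np.zeros((2, 24))
V[0, :12] = 0.5
V[1, 12:] = 0.5

def sigma(h):
    """4-bit quantized ReLU (q_b = 15)."""
    return np.clip(np.ceil(h), 0, 15)

def loss_and_coarse_grad(W, X, y):
    """Population hinge loss and its coarse gradient with respect to W (shape 24 x 4)."""
    n = len(y)
    idx = np.arange(n)
    xi = 1 - y                                    # index of the other class
    H = X @ W.T                                   # pre-activations h_j, shape (n, 24)
    O = sigma(H) @ V.T                            # outputs o_1, o_2, shape (n, 2)
    f = O[idx, y] - O[idx, xi]                    # f_y = o_y - o_xi
    loss = np.maximum(1.0 - f, 0.0).mean()
    active = (f < 1.0).astype(float)              # samples with positive hinge loss
    # ReLU STE: sigma'(h), which is zero a.e., is replaced by g'(h) = 1{h > 0}.
    coef = -(V[y] - V[xi]) * active[:, None] * (H > 0)
    return loss, coef.T @ X / n

def coarse_gradient_descent(theta, eta=1.0, max_iter=10_000, seed=0):
    """Iterate W <- W - eta * coarse gradient until the loss is exactly zero or max_iter is hit."""
    X, y = make_data(theta)
    W = np.random.default_rng(seed).standard_normal((24, 4))
    for t in range(max_iter):
        loss, grad = loss_and_coarse_grad(W, X, y)
        if loss == 0.0:          # zero loss means no active samples, so the coarse gradient is zero
            return W, t
        W = W - eta * grad
    return W, max_iter
```

Calling `coarse_gradient_descent(np.pi / 3)` returns the final weights and the number of iterations used; how fast it reaches zero loss depends on the random initialization, so the sketch makes no claim about matching the iteration counts in Fig. 3.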

We find that the coarse gradient method converges to a global minimum with zero loss. As shown in the box plots of Fig. 3, convergence still holds when the subspaces \({\mathcal {V}}_1\) and \({\mathcal {V}}_2\) form an acute angle, and even when the data come from two levels of Gaussian noise perturbations of \({\mathcal {V}}_1\) and \({\mathcal {V}}_2\). Convergence is faster, and the resulting weight norm is smaller, as \(\theta \) increases toward \(\frac{\pi }{2}\), i.e., as \({\mathcal {V}}_1\) and \({\mathcal {V}}_2\) become orthogonal. This observation supports the robustness of Theorem 3.1 beyond the regime of orthogonal classes.

Fig. 3

Left: iterations to convergence vs. \(\theta \). Right: norm of weights vs. \(\theta \)

Fig. 4

Validation accuracies in training LeNet-5 with quantized (2-bit and 4-bit) ReLU activation

Fig. 5

2D projections of MNIST features from a trained convolutional neural network [24] with a quantized activation function. The 10 classes are color coded, and the feature points cluster near linearly independent subspaces.

8.2 MNIST experiments

Our theory applies to a broad range of STEs, although their empirical performance on deeper networks may differ. In this subsection, we compare the performance of the three types of STEs shown in Fig. 2.

As in [2], we use a modified batch normalization layer [13] and insert it before each activation layer, so that the inputs to the quantized activation layers (approximately) follow a unit Gaussian distribution. The scaling factor \(\tau \) applied to the output of the quantized activation layers can then be pre-computed via a k-means approach and kept fixed during the whole training process. The optimizer we use to train the quantized LeNet-5 is the (stochastic) coarse gradient method with momentum 0.9. The batch size is 64, and the learning rate is initialized to 0.1 and then decays by a factor of 10 every 20 epochs. The three backward pass substitutions g for the straight-through estimator are (1) ReLU \(g(x) = \max \{x,0\}\), (2) reverse exponential \(g(x)=\max \{0,q_b(1- e^{-x/q_b})\}\) and (3) log-tailed ReLU. The validation accuracy for each epoch is shown in Fig. 4, and the validation accuracies at bit-widths 2 and 4 are listed in Table 2. All three STEs perform well and give satisfactory accuracy. Specifically, the reverse exponential and log-tailed STEs are comparable, and both are slightly better than the ReLU STE.

In Fig. 5, we show 2D projections of MNIST features after 100 epochs of training a 7-layer convolutional neural network [24] with quantized activation. The features are extracted from the input to the last fully connected layer, and the data points cluster near linearly independent subspaces. Together with Sect. 8.1, this provides numerical evidence that the linearly independent subspace data structure (an extension of subspace orthogonality) occurs for high-level features in a deep network achieving nearly perfect classification, lending support to the realism of our theoretical study. Enlarging the angles between linear subspaces can improve classification accuracy; see [27] for such an effort on the MNIST and CIFAR-10 data sets via linear feature transforms.
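The backward-pass substitutions can be sketched as follows: the forward pass still uses the staircase \(\sigma \), and only \(\sigma '\) (zero almost everywhere) is replaced by \(g'\). This minimal NumPy sketch assumes \(\sigma (x)=\min \{\max \{\lceil x\rceil ,0\},q_b\}\) with \(q_b\) the maximum quantization level; the function names are hypothetical, and log-tailed ReLU is omitted because its formula is given in Fig. 2 rather than restated here.

```python
import numpy as np

def quantized_relu(x, q_b):
    """Forward pass: staircase activation clip(ceil(x), 0, q_b); its derivative is 0 a.e."""
    return np.clip(np.ceil(x), 0, q_b)

def ste_derivative(x, kind, q_b):
    """Derivative g'(x) of the backward-pass proxy g that replaces sigma'(x)."""
    x = np.asarray(x, dtype=float)
    if kind == "relu":            # g(x) = max{x, 0}
        return (x > 0).astype(float)
    if kind == "reverse_exp":     # g(x) = max{0, q_b (1 - exp(-x / q_b))}
        # exp is evaluated on max(x, 0) only, to avoid overflow for very negative x
        return np.where(x > 0, np.exp(-np.maximum(x, 0.0) / q_b), 0.0)
    raise ValueError(f"unknown STE kind: {kind!r}")

def ste_backward(x, upstream_grad, kind, q_b):
    """Coarse chain rule: dL/dx is approximated by dL/dsigma(x) * g'(x)."""
    return upstream_grad * ste_derivative(x, kind, q_b)
```

With the reverse exponential STE, the upstream gradient passes through unchanged at \(x=0^+\) and is attenuated smoothly as \(x\) grows, whereas the ReLU STE passes it through unchanged for every \(x>0\).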

Table 2 Validation accuracy (%) on MNIST with LeNet-5

8.3 CIFAR-10 experiments

In this experiment, we train VGG-11 and ResNet-20 with 4-bit activations on the CIFAR-10 data set to numerically check the boundedness assumption on the \(\ell _2\)-norm of the weights. The optimizer is momentum SGD with no weight decay. We use an initial learning rate of 0.1, decayed by a factor of 0.1 at the 80th and 140th epochs.
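The learning-rate schedules above are multi-step decays. A small Python sketch, assuming 0-indexed epochs and that each decay takes effect from the listed epoch onward (the text does not pin down this off-by-one detail); it behaves like PyTorch's `torch.optim.lr_scheduler.MultiStepLR`. The norm helper computes a single global \(\ell _2\)-norm over all weight tensors, which is one possible reading of the quantity plotted in Fig. 6. Both function names are hypothetical.

```python
import math
from typing import Iterable, Sequence

def multistep_lr(epoch: int, base_lr: float = 0.1, gamma: float = 0.1,
                 milestones: Sequence[int] = (80, 140)) -> float:
    """Learning rate at a 0-indexed epoch: base_lr times gamma once per milestone reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

def global_l2_norm(weight_tensors: Iterable[Iterable[float]]) -> float:
    """l2-norm of all (flattened) weights concatenated into one vector."""
    return math.sqrt(sum(w * w for tensor in weight_tensors for w in tensor))
```

For the MNIST schedule of Sect. 8.2 (decay by a factor of 10 every 20 epochs), one could pass, e.g., `milestones=range(20, 100, 20)` for a run of up to 100 epochs.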

We see from Fig. 6 that the \(\ell _2\)-norm of the weights remains bounded throughout training. The figure also shows that the weight norm generally increases over the epochs, which is consistent with our theoretical finding in Lemma 5.3.

Fig. 6

CIFAR-10 experiments for VGG-11 and ResNet-20: weight \(\ell _2\)-norm vs. epoch

9 Summary

We studied a novel and important biased first-order oracle, called the coarse gradient, for training quantized neural networks. The effectiveness of the coarse gradient depends on the choice of the STE, which is used only in the backward pass. We proved the convergence of coarse gradient methods for a class of STEs bearing certain monotonicity in nonlinear classification using one-hidden-layer networks. In experiments training LeNet-5 with quantized activations on the MNIST data set, we considered three different proxy functions for the backward pass that satisfy the monotonicity condition: ReLU, the reverse exponential function and log-tailed ReLU. All of them exhibited good performance, which supports our theoretical findings. In future work, we plan to extend the theoretical understanding of coarse gradient descent to deep activation-quantized networks.