LSALSA: accelerated source separation via learned sparse coding

Abstract

We propose an efficient algorithm for the generalized sparse coding (SC) inference problem. The proposed framework applies both to the single-dictionary setting, where each data point is represented as a sparse combination of the columns of one dictionary matrix, and to the multiple-dictionary setting of morphological component analysis (MCA), where the goal is to separate a signal into additive parts such that each part has a distinct sparse representation within an appropriately chosen corresponding dictionary. Both the SC task and its MCA generalization have been cast as \(\ell _1\)-regularized optimization problems of minimizing quadratic reconstruction error. To accelerate the traditional acquisition of sparse codes, we propose a deep learning architecture that constitutes a trainable time-unfolded version of the split augmented Lagrangian shrinkage algorithm (SALSA), a special case of the alternating direction method of multipliers (ADMM). We empirically validate both variants of the algorithm, which we refer to as learned SALSA (LSALSA), on image data and demonstrate that at inference our networks achieve vast improvements over more common baselines in both running time and the quality of the estimated sparse codes, on both classic SC and MCA problems. We also demonstrate the visual advantage of our technique on the task of source separation. Finally, we present a theoretical framework for analyzing the LSALSA network: we show that the proposed approach exactly implements a truncated ADMM applied to a new, learned cost function with curvature modified by one of the learned parameterized matrices. We extend a very recent stochastic alternating optimization analysis framework to show that a gradient descent step along this learned loss landscape is equivalent to a modified gradient descent step along the original loss landscape. 
In this framework, the acceleration achieved by LSALSA could potentially be explained by the network’s ability to learn a correction to the gradient direction that yields steeper descent.

Introduction

In the SC framework, we seek to efficiently represent data by using only a sparse combination of available basis vectors. We therefore assume that an M-dimensional data vector \(\mathbf {y}\in {\mathbb {R}}^M\) can be approximated as

$$\begin{aligned} \mathbf {y}\approx \mathbf {A}\mathbf {x}^*, \end{aligned}$$
(1)

where \(\mathbf {x}^*\in {\mathbb {R}}^N\) is sparse and \(\mathbf {A}\in {\mathbb {R}}^{M\times N}\) is a dictionary, sometimes referred to as the synthesis matrix, whose columns are the basis vectors. This paper focuses on the generalized SC problem of decomposing a signal into morphologically distinct components. A typical assumption for this problem is that the data is a linear combination of D source signals:

$$\begin{aligned} \mathbf {y}= \sum _{i=1}^D\mathbf {y}_i. \end{aligned}$$
(2)

The MCA framework (Starck et al. 2004) for addressing additive mixtures requires that each component \(\mathbf {y}_i\) admits a sparse representation within the corresponding dictionary \(\mathbf {A}_i\), leading to a generalized signal approximation model:

$$\begin{aligned} \mathbf {y}\approx \sum _{i=1}^D\mathbf {A}_i\mathbf {x}_i^*. \end{aligned}$$
(3)

We then seek to recover the \(\mathbf {x}_i^*\)s given \(\mathbf {y}\) and the dictionaries \(\mathbf {A}_i\)s. We may trivially satisfy (3) by setting, for example, \(\mathbf {x}^*_i=0\) for all \(i\ne j\), and performing traditional SC using only dictionary \(\mathbf {A}_j\). Thus, MCA further assumes that the dictionaries \(\mathbf {A}_i\)s are distinct in the sense that each source-specific dictionary allows obtaining a sparse representation of the corresponding source signal, while being highly inefficient in representing the other content in the mixture. This assumption is difficult to enforce on harder problems, i.e. when the components \(\mathbf {y}_i\) have similar characteristics and do not admit intuitive a priori sparsifying bases. In practice, the \(\mathbf {A}_i\)s often have significant overlap in sparse representation, making the problem of jointly recovering the \(\mathbf {x}_i\)s highly ill-conditioned.

There exist iterative optimization algorithms for performing SC and MCA. The bottleneck of these techniques is that at inference a sparse code has to be computed for each data point or data patch (as in the case of high-resolution images). In the single-dictionary setting, ISTA (Daubechies et al. 2004) and FISTA (Beck and Teboulle 2009) are classical algorithmic choices for this purpose. For the MCA problem, the standard choice is SALSA (Afonso et al. 2011), an instance of ADMM (Boyd et al. 2011). The iterative optimization process is prohibitively slow for high-throughput real-time applications, especially in the ill-conditioned MCA setting. Thus our goal is to provide algorithms performing efficient inference, i.e. algorithms that find good approximations of the optimal codes in significantly shorter time than FISTA or SALSA.

The first key contribution of this paper is an efficient and accurate deep learning architecture that is general enough to approximate optimal codes well both for classic SC in a single-dictionary framework and for MCA-based signal separation. By accelerating SALSA via learning, we provide a means for fast approximate source separation. We call our deep learning approximator Learned SALSA (LSALSA). The proposed encoder is formulated as a time-unfolded version of the SALSA algorithm with a fixed number of iterations, where the depth of the deep learning model corresponds to the number of SALSA iterations. We train the deep model in a supervised fashion to predict optimal sparse codes for a given input and show that shallow fixed-depth architectures, corresponding to only a few iterations of the original SALSA, achieve superior performance to the classic algorithm.

The SALSA algorithm uses second-order information about the cost function, which gives it an advantage over popular comparators such as ISTA on ill-conditioned problems (Figueiredo et al. 2009). Our second key contribution is an empirical demonstration that this advantage carries over to the deep-learning-accelerated counterparts of SALSA and ISTA, namely LSALSA and LISTA (Gregor and LeCun 2010), while LSALSA preserves SALSA’s applicability to a broader class of learning problems such as MCA-based source separation (LISTA is used only in the single-dictionary setting). To the best of our knowledge, our approach is the first to utilize an instance of ADMM unrolled into a deep learning architecture to address a source separation problem.

Our third key contribution is a theoretical framework that provides insight into how LSALSA is able to surpass SALSA, namely describing how the learning procedure can enhance the second-order information that is characteristically exploited by SALSA. In particular, we show that the forward-propagation of a signal through the LSALSA network is equivalent to the application of truncated-ADMM to a new, learned cost function, and present a theoretical framework for characterizing this function in relation to the original Augmented Lagrangian. To the best of our knowledge, our work is the first to attempt to analyze a learning-accelerated ADMM algorithm.

To summarize, our contributions are threefold:

  1. We achieve significant acceleration in both SC and MCA: classic SALSA takes up to \(100\times \) longer to achieve LSALSA’s performance. This opens up the MCA framework to potentially be used in high-throughput, real-time applications.

  2. We carefully compare an ADMM-based algorithm (SALSA) with our proposed learnable counterpart (LSALSA) and with popular baselines (ISTA and FISTA). For a large variety of computational constraints (i.e. fixed numbers of iterations), we perform comprehensive hyperparameter testing for each encoding method to ensure a fair comparison.

  3. We present a theoretical framework for analyzing the LSALSA network, giving insight into how it uses information learned from data to accelerate SALSA.

This paper is organized as follows: Sect. 2 provides a literature review, Sect. 3 formulates the SC problem in detail, and Sect. 4 shows how to derive predictive single-dictionary SC and multiple-dictionary MCA from their iterative counterparts and explains our approach (LSALSA). Section 5 elaborates on our theoretical framework for analyzing LSALSA and provides insight into its empirically demonstrated advantages. Section 6 shows experimental results for both the single-dictionary setting and MCA. Finally, Sect. 7 concludes the paper. We provide an open-source implementation of the sparse coding and source separation experiments presented herein.

Related work

Sparse code inference aims at computing sparse codes for given data and is most widely addressed via iterative schemes such as the aforementioned ISTA and FISTA. Approximations of optimal codes can also be predicted using deep feed-forward architectures based on truncated convex solvers. This family of approaches lies at the core of this paper. A notable approach in this family, LISTA (Gregor and LeCun 2010), stems from earlier predictive sparse decomposition methods (Kavukcuoglu et al. 2010; Jarrett et al. 2009), which, however, produced approximations of the sparse codes of insufficient quality. LISTA improves over these techniques and enhances ISTA by unfolding a fixed number of iterations into a fixed-depth deep neural network that is trained on input vectors paired with their corresponding optimal sparse codes, obtained by conventional methods such as ISTA or FISTA. LISTA was shown to provide high-quality approximations of optimal sparse codes at a fixed computational cost. The unrolling methodology has since been applied to algorithms solving SC with \(\ell _0\)-regularization (Wang et al. 2016) and to message passing schemes (Borgerding and Schniter 2016). In other prior work, ISTA was recast as a recurrent neural network unit, giving rise to a variant of LSTM (Gers et al. 2003; Zhou et al. 2018). Recently, theoretical analyses have been provided for LISTA (Chen et al. 2018; Moreau and Bruna 2016), in which the authors derive convergence results by imposing constraints on the LISTA algorithm. These analyses do not apply to the MCA problem, as they cannot handle multiple dictionaries; in other words, they would approach the MCA problem by casting it as an SC problem with access to a single dictionary that is a concatenation of source-specific dictionaries, e.g. \([\mathbf {A}_1,\mathbf {A}_2,\dots ,\mathbf {A}_D]\). Furthermore, these analyses do not address the saddle-point setting required for ADMM-type methods such as SALSA.

MCA has been used successfully in a number of applications that include decomposing images into textures and cartoons for denoising and inpainting (Elad et al. 2005; Peyré et al. 2007, 2010; Shoham and Elad 2008; Starck et al. 2005a, b), detecting text in natural scene images (Liu et al. 2017), as well as other source separation problems such as separating non-stationary clutter from weather radar signals (Uysal et al. 2016), transients from sustained rhythmic components in EEG signals (Parekh et al. 2014), and stationary from dynamic components of MRI videos (Otazo et al. 2015). The MCA problem is frequently solved via the SALSA algorithm, which constitutes a special case of the ADMM method.

There exist a few approaches in the literature utilizing highly specialized trainable ADMM algorithms. One such framework (Yang et al. 2016) was demonstrated to improve the reconstruction accuracy and inference speed over a variety of state-of-the-art solvers for the problem of compressive sensing magnetic resonance imaging (MRI). A variety of papers followed up on this work for various image reconstruction tasks, such as the Learned Primal-dual Algorithm (Adler and Öktem 2017). However, these approaches do not give a detailed iteration-by-iteration comparison of the baseline method versus the learned method, making it difficult to understand the accuracy/speed tradeoff. Another related framework (Sprechmann et al. 2013) was applied to efficiently learn task-specific (reconstruction or classification) sparse models via sparsity-promoting convolutional operators. None of the above methods was applied to MCA or other source separation problems, and it is non-trivial to extend these works in that direction. An unrolled nonnegative matrix factorization (NMF) algorithm (Roux et al. 2015) was implemented as a deep network for the task of speech separation. In another work (Wisdom et al. 2017), the NMF-based speech separation task was solved with an ISTA-like unfolded network.

Problem formulation

This paper focuses on the inference problem in SC: given data vector \(\mathbf {y}\) and dictionary matrix \(\mathbf {A}\), we consider algorithms for finding the unique coefficient vector \(\mathbf {x}^*\) that minimizes the \(\ell _1\)-regularized linear least squares cost function:

$$\begin{aligned} \mathbf {x}^* = \text {arg}\min _{\mathbf {x}}\left\{ E_{\mathbf {A}}(\mathbf {x};\mathbf {y}) = \tfrac{1}{2}\left\| \mathbf {y}-\mathbf {A}\mathbf {x}\right\| _2^2 + \alpha \left\| \mathbf {x}\right\| _1\right\} , \end{aligned}$$
(4)

where the scalar constant \(\alpha \ge 0\) balances sparsity with data fidelity. Since this problem is convex, every local minimizer is a global one; the minimizer is unique under mild conditions (for example, \(\alpha >0\) and the columns of \(\mathbf {A}\) in general position), which we assume throughout, and we refer to it as the optimal code for \(\mathbf {y}\) with respect to \(\mathbf {A}\). The dictionary matrix \(\mathbf {A}\) is usually learned by minimizing the loss function given below (Olshausen and Field 1996)

$$\begin{aligned} {\mathcal {L}}_{\text {Dict}}(\mathbf {A}) = \frac{1}{P}\sum _{p=1}^P E_{\mathbf {A}}(\mathbf {x}^{*,p}; \mathbf {y}^p) \end{aligned}$$
(5)

with respect to \(\mathbf {A}\) using stochastic gradient descent (SGD), where P is the size of the training data set, \(\mathbf {y}^p\) is the \(p\mathrm{th}\) training sample, and \(\mathbf {x}^{*,p}\) is the corresponding optimal sparse code. In this paper, the optimal sparse codes in each iteration are obtained with FISTA. When training dictionaries, we require the columns of \(\mathbf {A}\) to have unit norm, as is common practice for regularizing the dictionary learning process (Olshausen and Field 1996); however, this is not necessary for code inference.
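For readers who want to reproduce the "optimal codes" that serve as targets, the following is a minimal NumPy sketch of FISTA applied to Eq. 4, under our own assumptions (fixed step \(1/L\) with \(L=\|\mathbf {A}\|_2^2\), zero initialization, a fixed iteration budget, and a synthetic unit-norm dictionary). The function names `soft`, `sc_energy`, and `fista` are hypothetical illustrations, not part of the released implementation.

```python
import numpy as np


def soft(z, thresh):
    """Element-wise soft thresholding, as in Eq. 12."""
    return np.sign(z) * np.maximum(np.abs(z) - thresh, 0.0)


def sc_energy(A, x, y, alpha):
    """The l1-regularized least-squares cost E_A(x; y) of Eq. 4."""
    r = y - A @ x
    return 0.5 * r @ r + alpha * np.abs(x).sum()


def fista(A, y, alpha, n_iter=2000):
    """Approximate x* = argmin_x E_A(x; y) with FISTA (assumed fixed step size)."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2   # 1/L, L = largest eigenvalue of A^T A
    x = np.zeros(A.shape[1])
    z = x.copy()                              # extrapolated (momentum) point
    t = 1.0
    for _ in range(n_iter):
        # Proximal gradient step on the quadratic term, taken at z.
        x_next = soft(z - step * (A.T @ (A @ z - y)), step * alpha)
        t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        z = x_next + ((t - 1.0) / t_next) * (x_next - x)
        x, t = x_next, t_next
    return x


# Synthetic example: unit-norm random dictionary, sparse ground truth, small noise.
rng = np.random.default_rng(0)
A = rng.standard_normal((30, 60))
A /= np.linalg.norm(A, axis=0)
x_true = np.zeros(60)
x_true[[3, 17, 29, 44, 58]] = [1.5, -2.0, 1.0, -1.2, 2.5]
y = A @ x_true + 0.01 * rng.standard_normal(30)
alpha = 0.1
x_star = fista(A, y, alpha)
```

A convenient sanity check is the optimality condition of Eq. 4: at \(\mathbf {x}^*\), every entry of \(\mathbf {A}^{\text {T}}(\mathbf {y}-\mathbf {A}\mathbf {x}^*)\) has magnitude at most \(\alpha \).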

In the MCA framework, a generalization of the cost function from Eq. 4 is minimized to estimate \(\mathbf {x}_1^*,\mathbf {x}_2^*,\dots ,\mathbf {x}_D^*\) from the model given in Eq. 3. Thus one minimizes

$$\begin{aligned} E_{\mathbf {A}}(\mathbf {x};\mathbf {y}) = \tfrac{1}{2}\left\| \mathbf {y}-\mathbf {A}\mathbf {x}\right\| _2^2 + \sum _{i=1}^D\alpha _i\left\| \mathbf {x}_i\right\| _1, \end{aligned}$$
(6)

using \(\mathbf {A}:=[\mathbf {A}_1,\mathbf {A}_2,\dots ,\mathbf {A}_D]\in {\mathbb {R}}^{M\times N}\) and

$$\begin{aligned} \mathbf {x}:=\left[ \begin{array}{c} \mathbf {x}_1\\ \mathbf {x}_2 \\ \vdots \\ \mathbf {x}_D \end{array} \right] \in {\mathbb {R}}^N, \end{aligned}$$
(7)

where \(\mathbf {x}_i \in {\mathbb {R}}^{N_i}\) for \(i \in \{1,2,\dots ,D\}\), \(N = \sum _{i=1}^D N_i\), and the \(\alpha _i\)s are the coefficients controlling the sparsity penalties. We denote the concatenated optimal codes by \(\mathbf {x}^* = \text {arg}\min _{\mathbf {x}}E_{\mathbf {A}}(\mathbf {x};\mathbf {y})\). To recover the single-dictionary case, simply set \(\alpha _i=\alpha _j\) for all \(i,j=1,\ldots ,D\) and let the \(\mathbf {A}_i\) be column partitions of \(\mathbf {A}\).
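The stacking in Eqs. 6 and 7 can be made concrete with a short NumPy sketch. It assumes each dictionary is a dense array and each code a 1-D array; the helpers `mca_energy` and `per_coefficient_alpha` are hypothetical names introduced only for illustration. The second helper expands the component-wise penalties into one threshold per entry of the stacked code, which is how a single-dictionary solver can be reused for MCA.

```python
import numpy as np


def mca_energy(dicts, codes, y, alphas):
    """E_A(x; y) of Eq. 6 for dictionaries A_1..A_D and codes x_1..x_D."""
    A = np.hstack(dicts)          # A = [A_1, ..., A_D], shape (M, N)
    x = np.concatenate(codes)     # x = [x_1; ...; x_D], shape (N,)
    r = y - A @ x
    penalty = sum(a * np.abs(xi).sum() for a, xi in zip(alphas, codes))
    return 0.5 * r @ r + penalty


def per_coefficient_alpha(sizes, alphas):
    """Expand (alpha_1, ..., alpha_D) into one weight per entry of the stacked x."""
    return np.concatenate([np.full(n, a) for n, a in zip(sizes, alphas)])


# Two hypothetical components with N_1 = 10 and N_2 = 14 atoms.
rng = np.random.default_rng(1)
M, sizes = 16, (10, 14)
dicts = [rng.standard_normal((M, n)) for n in sizes]
codes = [rng.standard_normal(n) for n in sizes]
y = rng.standard_normal(M)
```

With equal \(\alpha _i\), the MCA cost coincides with the single-dictionary cost of Eq. 4 evaluated on the concatenated dictionary and code.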

In the classic MCA works, the dictionaries \(\mathbf {A}_i\)s are selected to be well-known filter banks with explicitly designed sparsification properties. Such hand-designed transforms have good generalization abilities and help to prevent overfitting. Also, MCA algorithms often require solving large systems of equations involving \(\mathbf {A}^{\text {T}}\mathbf {A}\) or \(\mathbf {A}\mathbf {A}^{\text {T}}\). An appropriate constraining of \(\mathbf {A}_i\) leads to a banded system of equations and in consequence reduces the computational complexity of these algorithms, e.g. Parekh et al. (2014). More recent MCA works use learned dictionaries for image analysis (Shoham and Elad 2008; Peyré et al. 2007). Some extensions of MCA consider learning dictionaries \(\mathbf {A}_i\)s and sparse codes jointly (Peyré et al. 2007, 2010).

Remark 1

(Learning dictionaries) In our paper, we learn the dictionaries \(\mathbf {A}_i\)s independently. In particular, for each i we minimize

$$\begin{aligned} {\mathcal {L}}_{\text {Dict}}(\mathbf {A}_i) = \frac{1}{P}\sum _{p=1}^P E_{\mathbf {A}_i}(\mathbf {x}_i^{*,p}; \mathbf {y}_i^p) \end{aligned}$$
(8)

with respect to \(\mathbf {A}_i\) using SGD, where \(\mathbf {y}_i^p\) is the \(i\mathrm{th}\) mixture component of the \(p\mathrm{th}\) training sample and \(\mathbf {x}_i^{*,p}\) is the corresponding optimal sparse code. The columns are constrained to have unit norm. The sparse codes in each iteration are obtained with FISTA.
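One SGD step on Eq. 8 for a single training pair can be sketched as follows, assuming plain (unmomentumed) SGD with a fixed learning rate and projection onto unit-norm columns by renormalization after each step. The helper name `dict_sgd_step` is hypothetical, and the code used below is a stand-in rather than an actual FISTA output.

```python
import numpy as np


def dict_sgd_step(A, y, x, lr):
    """One SGD step on E_A(x; y) w.r.t. A with the code x held fixed,
    followed by renormalizing every column to unit l2 norm."""
    residual = y - A @ x
    # The l1 term does not depend on A, so dE/dA = -(y - A x) x^T.
    A = A + lr * np.outer(residual, x)
    norms = np.linalg.norm(A, axis=0)
    return A / np.maximum(norms, 1e-12)   # guard against an all-zero column


rng = np.random.default_rng(2)
A_i = rng.standard_normal((16, 32))
A_i /= np.linalg.norm(A_i, axis=0)
y_i = rng.standard_normal(16)
x_i = np.zeros(32)
x_i[[3, 17]] = [1.0, -0.5]        # stand-in for a FISTA code x_i^{*,p}
A_next = dict_sgd_step(A_i, y_i, x_i, lr=0.05)
```

Because the gradient is an outer product with the code, only atoms that are active in \(\mathbf {x}_i^{*,p}\) move in a given step.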

From iterative to predictive SC and MCA

Split augmented Lagrangian shrinkage algorithm (SALSA)

The objective functions used in SC (Eq. 4) and MCA (Eq. 6) are each convex with respect to \(\mathbf {x}\), allowing a wide variety of optimization algorithms with well-studied convergence results to be applied (Bauschke and Combettes 2011). Here we describe SALSA (Afonso et al. 2010), a popular instance of ADMM that is general enough to solve both problems. ADMM (Boyd et al. 2011) addresses an optimization problem of the form

$$\begin{aligned} \min _{\mathbf {x}} f_1(\mathbf {x}) + f_2(\mathbf {x}) \end{aligned}$$
(9)

by re-casting it as the equivalent, constrained problem

$$\begin{aligned} \min _{\mathbf {u},\mathbf {x}} f_1(\mathbf {x}) + f_2(\mathbf {u})\,\,\, \text {such that }\, \mathbf {x}=\mathbf {u}. \end{aligned}$$
(10)

ADMM then optimizes the corresponding scaled Augmented Lagrangian,

$$\begin{aligned} {\mathcal {L}}_A= f_1(\mathbf {x}) + f_2(\mathbf {u})+\frac{\mu }{2}\left\| \mathbf {u}-\mathbf {x}-\mathbf {d}\right\| _2^2 - \frac{\mu }{2}\left\| \mathbf {d}\right\| _2^2, \end{aligned}$$
(11)

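where \(\mathbf {d}\) corresponds to the (scaled) Lagrange multipliers. The optimization proceeds one variable at a time, alternating minimization over \(\mathbf {x}\) and \(\mathbf {u}\) with an update of \(\mathbf {d}\), until convergence.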

SALSA, proposed in Afonso et al. (2010), addresses an instance of the general optimization problem from Eq. 10 for which convergence has been proved in Eckstein and Bertsekas (1992). Namely, SALSA requires that (1) \(f_1\) is a least-squares term, and (2) the proximity operator of \(f_2\) can be computed exactly. For our most general cost function in Eq. 6, requirement (1) is clearly satisfied, and our \(f_2\) is the weighted sum of \(\ell _1\) norms. In Supplemental Section A, we show that the proximity operator of \(f_2\) reduces to element-wise soft thresholding for each component, which in scalar form is given by

$$\begin{aligned} \text {soft}(z;\alpha ) = {\left\{ \begin{array}{ll} z-\alpha , &{}\quad z>\alpha \\ 0, &{}\quad |z|\le \alpha \\ z+\alpha , &{}\quad z<-\alpha \end{array}\right. }. \end{aligned}$$
(12)

When applied to a vector, \(\text {soft}(\mathbf {z};\alpha )\) performs soft thresholding element-wise. Thus, SALSA is guaranteed to converge for the multiple-dictionary sparse coding problem.
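Eq. 12 is a one-line NumPy function. The sketch below also checks numerically, by brute-force search over a fine 1-D grid, that \(\text {soft}(z;\alpha )\) minimizes \(\alpha |u| + \frac{1}{2}(u-z)^2\), i.e. that it is the proximity operator of the scaled absolute value. The helper `prox_l1_bruteforce` and the grid resolution are our own illustrative choices.

```python
import numpy as np


def soft(z, alpha):
    """Soft thresholding of Eq. 12, applied element-wise to scalars or arrays."""
    return np.sign(z) * np.maximum(np.abs(z) - alpha, 0.0)


def prox_l1_bruteforce(z, alpha, grid):
    """argmin_u alpha*|u| + 0.5*(u - z)**2, found by exhaustive search on a grid."""
    objective = alpha * np.abs(grid) + 0.5 * (grid - z) ** 2
    return grid[np.argmin(objective)]
```

In SALSA the threshold passed to this function is \(\alpha /\mu \) (or \(\alpha _i/\mu \) per component in MCA), because the \(\mathbf {u}\)-subproblem carries the extra factor \(\mu \) from the augmented term in Eq. 11.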

Algorithm 1: SALSA (single-dictionary case)
Algorithm 2: SALSA (MCA case with two dictionaries)

SALSA is given in Algorithms 1 and 2 for the single-dictionary case and the MCA case involving two dictionaries,Footnote 1 respectively. Note that in Algorithm 2, the \(\mathbf {u}\) and \(\mathbf {d}\) updates can be performed with element-wise operations. The \(\mathbf {x}\)-update, however, is non-separable with respect to the components \(\{\mathbf {x}_i\}_{i=1}^D\) for general \(\mathbf {A}\): the system of equations in the \(\mathbf {x}\)-update cannot be broken down into D sub-problems, one for each component (in contrast, first-order methods such as FISTA update the components independently). We call this the splitting step.

As mentioned in Sect. 3, the \(\mathbf {x}\)-update is often simplified to inexpensive operations by constraining the matrix \(\mathbf {A}\) to have special properties. For example, requiring \(\mathbf {A}\mathbf {A}^{\text {T}}=\rho \mathbf {I}\), \(\rho \in {\mathbb {R}}_+\), reduces the \(\mathbf {x}\)-update (after applying the matrix inversion lemma) to multiplications by \(\mathbf {A}\) and \(\mathbf {A}^{\text {T}}\) followed by a scalar rescaling. In Yang et al. (2016), \(\mathbf {A}\) is set to be the partial Fourier transform, reducing the system of equations of the \(\mathbf {x}\)-update to a series of convolutions and element-wise operations. In our work, as is typical in SC, \(\mathbf {A}\) is a learned dictionary without any imposed structure.
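The matrix-inversion-lemma simplification for \(\mathbf {A}\mathbf {A}^{\text {T}}=\rho \mathbf {I}\) can be checked numerically. The identity \((\mu \mathbf {I}+\mathbf {A}^{\text {T}}\mathbf {A})^{-1}=\mu ^{-1}(\mathbf {I}-\mathbf {A}^{\text {T}}(\mu \mathbf {I}+\mathbf {A}\mathbf {A}^{\text {T}})^{-1}\mathbf {A})\) becomes \(\mu ^{-1}(\mathbf {I}-\mathbf {A}^{\text {T}}\mathbf {A}/(\mu +\rho ))\) in that case. The sketch builds such an \(\mathbf {A}\) by concatenating two random orthogonal matrices (so \(\rho =2\)); this construction and the helper `x_update_tight_frame` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
M, mu = 8, 0.7
# A = [Q1, Q2] with orthogonal Q1, Q2 satisfies A A^T = 2 I, i.e. rho = 2.
Q1, _ = np.linalg.qr(rng.standard_normal((M, M)))
Q2, _ = np.linalg.qr(rng.standard_normal((M, M)))
A = np.hstack([Q1, Q2])
rho = 2.0
N = A.shape[1]

# Direct inverse of the regularized Hessian versus the matrix inversion lemma.
S_direct = np.linalg.inv(mu * np.eye(N) + A.T @ A)
S_lemma = (np.eye(N) - A.T @ A / (mu + rho)) / mu


def x_update_tight_frame(A, b, mu, rho):
    """Apply S to b using only products with A and A^T plus scalar rescaling."""
    return (b - A.T @ (A @ b) / (mu + rho)) / mu
```

When \(\mathbf {A}\) and \(\mathbf {A}^{\text {T}}\) have fast implementations (e.g. wavelet frames), this avoids forming or inverting any \(N\times N\) matrix.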

Fig. 1

A block diagram of SALSA. The one-time initialization \(\mathbf {x}= \mathbf {A}^{\text {T}}\mathbf {y}\) is represented by a gate on the left

Note that one way to solve for \(\mathbf {x}\) in Algorithms 1 and 2 is to compute the inverse of the regularized Hessian matrix \(\mu \mathbf {I} + \mathbf {A}^{\text {T}}\mathbf {A}\). This needs to be done only once, at the very beginning, since this matrix remains fixed during the entire run of SALSA. We abbreviate the inverted matrix as

$$\begin{aligned} \mathbf {S}= (\mu \mathbf {I} + \mathbf {A}^{\text {T}}\mathbf {A})^{-1}. \end{aligned}$$
(13)

We call this matrix a splitting operator. Note that the inversion process couples together the dictionary elements (and hence also the dictionaries) in a non-linear fashion. This is an advanced utilization of prior knowledge not seen in the comparator methods of Sect. 6. The recursive block diagram of SALSA is depicted in Fig. 1.
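The iterations described above can be sketched in NumPy as follows. This is a minimal sketch under our own assumptions, not the authors' implementation: the updates are ordered \(\mathbf {u}\), then \(\mathbf {x}\), then \(\mathbf {d}\) (so that each \(\mathbf {x}\)-update uses the newest \(\mathbf {u}\), as in Sect. 5.1); \(\mathbf {x}\) starts at \(\mathbf {A}^{\text {T}}\mathbf {y}\) (Fig. 1) and \(\mathbf {d}\) at zero; and the MCA case of Algorithm 2 is handled by passing one threshold per coefficient. The function name `salsa` and its defaults are hypothetical.

```python
import numpy as np


def soft(z, thresh):
    """Element-wise soft thresholding (Eq. 12); thresh may be a scalar or a vector."""
    return np.sign(z) * np.maximum(np.abs(z) - thresh, 0.0)


def salsa(A, y, alpha, mu=1.0, n_iter=500):
    """Sketch of SALSA for Eq. 4 (scalar alpha) or Eq. 6 (one alpha per coefficient).

    Assumptions: updates in the order u, x, d; x starts at A^T y and d at 0.
    """
    N = A.shape[1]
    alpha = np.broadcast_to(np.asarray(alpha, dtype=float), (N,))
    S = np.linalg.inv(mu * np.eye(N) + A.T @ A)  # splitting operator (Eq. 13), built once
    Aty = A.T @ y                                 # also reused by every iteration
    x = Aty.copy()
    d = np.zeros(N)
    for _ in range(n_iter):
        u = soft(x + d, alpha / mu)   # u-update: prox of the weighted l1 term
        x = S @ (Aty + mu * (u - d))  # x-update: the non-separable splitting step
        d = d + x - u                 # d-update: scaled multipliers for x = u
    return soft(x + d, alpha / mu)    # report an exactly sparse code


# MCA usage with two unit-norm dictionaries and different penalties per component.
rng = np.random.default_rng(4)
M, N1, N2 = 30, 40, 40
A1 = rng.standard_normal((M, N1))
A1 /= np.linalg.norm(A1, axis=0)
A2 = rng.standard_normal((M, N2))
A2 /= np.linalg.norm(A2, axis=0)
A = np.hstack([A1, A2])
x_true = np.zeros(N1 + N2)
x_true[[2, 11, 45, 70]] = [1.5, -1.0, 2.0, -1.2]
y = A @ x_true + 0.01 * rng.standard_normal(M)
alphas = np.concatenate([np.full(N1, 0.1), np.full(N2, 0.2)])
x_hat = salsa(A, y, alphas, mu=1.0, n_iter=3000)
x1_hat, x2_hat = x_hat[:N1], x_hat[N1:]   # per-source codes
```

Precomputing \(\mathbf {S}\) trades one \(O(N^3)\) inversion for cheap matrix-vector products in every iteration; the same matrix is the one LSALSA later treats as a learnable parameter.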

Learned SALSA (LSALSA)

Fig. 2

The deep learning architecture of LSALSA for \(T=3\). The soft-thresholding function, defined in Eq. 12, serves as the activation function in each layer of the network and at its output

We now describe our proposed deep encoder architecture, which we refer to as Learned SALSA (LSALSA). Consider truncating the SALSA algorithm to a fixed number of iterations T and then time-unfolding it into a deep neural network architecture that matches the truncated SALSA’s output exactly. The obtained architecture is illustrated in Fig. 2 for \(T=3\), and the formulas for the \(t\mathrm{th}\) layer w.r.t. the \((t-1)\mathrm{th}\) iterates are described via pseudocode in Algorithms 3 and 4 for the single-dictionary and MCA cases, respectively. Note that Algorithms 2 and 4 are the most general algorithms we consider, whereas Algorithms 1 and 3 are their single-dictionary special cases.

The LSALSA model has two matrices of learnable parameters: \(\mathbf {S}\) and \(\mathbf {W_e}\). We initialize these to achieve an exact correspondence with SALSA:

$$\begin{aligned} \mathbf {W_e}= \mathbf {A}^{\text {T}}\in {\mathbb {R}}^{N\times M}\,\,\,\text {and}\,\,\,\mathbf {S}= \left( \mu \mathbf {I} + \mathbf {A}^{\text {T}}\mathbf {A}\right) ^{-1} \in {\mathbb {R}}^{N\times N}, \end{aligned}$$
(14)

where \(N=N_1+N_2\) in the MCA case. The splitting operator \(\mathbf {S}\) is shared across all layers of the network. LSALSA’s two parameter matrices can be trained with standard backpropagation. Let \(\mathbf {x}= f_e(\mathbf {W}_e,\mathbf {S},\mathbf {y})\) denote the output of the LSALSA architecture after a forward propagation of \(\mathbf {y}\). The cost function used for training the model is defined as

$$\begin{aligned} {\mathcal {L}}(\mathbf {W}_e,\mathbf {S}) = \frac{1}{2P}\sum _{p=1}^P\left\| \mathbf {x}^{*,p}-f_e(\mathbf {W}_e,\mathbf {S},\mathbf {y}^p)\right\| _2^2. \end{aligned}$$
(15)
Algorithm 3: LSALSA (single-dictionary case)
Algorithm 4: LSALSA (MCA case)

To summarize, LSALSA extends SALSA: SALSA is meant to run until convergence, whereas LSALSA is meant to run for T iterations, where T is the depth of the network. Intuitively, the backpropagation steps applied during LSALSA training fine-tune the “splitting step” so that T iterations can suffice to obtain good-quality sparse codes (the sparsity itself is produced by the soft-thresholding nonlinearities). The SALSA algorithm relies on cumulative Lagrange multiplier updates to “explain away” code components while separating sources. This is especially important in MCA, where similar atoms from different dictionaries compete to represent the same segment of a mixed signal. The Lagrange multiplier updates translate to a cross-layer connectivity pattern in the corresponding LSALSA network (see the d-updates in Fig. 2), which has been shown to be a beneficial architectural feature (e.g. Greff et al. 2016; Liao and Poggio 2016; Orhan and Pitkow 2018). During training, LSALSA fine-tunes the splitting operator \(\mathbf {S}\) so that it need not rely on a large number of cumulative updates. Nevertheless, we show in Sect. 5 that even after training, forward propagation through an LSALSA network is equivalent to the application of a truncated ADMM algorithm to a new, learned cost function that generalizes the original problem.
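A forward pass through the unfolded network and the training cost of Eq. 15 can be sketched in NumPy as below. This sketch rests on our own assumptions: each layer applies the \(\mathbf {u}\), \(\mathbf {x}\), \(\mathbf {d}\) updates in that order, the network starts from \(\mathbf {x}=\mathbf {W_e}\mathbf {y}\) and \(\mathbf {d}=0\), and the output is a final soft-thresholding of \(\mathbf {x}+\mathbf {d}\). The names `init_lsalsa`, `lsalsa_forward`, and `lsalsa_loss` are hypothetical. Only the forward computation is shown; in practice the gradients of Eq. 15 with respect to \(\mathbf {W_e}\) and \(\mathbf {S}\) would be obtained with an automatic-differentiation framework.

```python
import numpy as np


def soft(z, thresh):
    """Element-wise soft thresholding (Eq. 12), used as the activation."""
    return np.sign(z) * np.maximum(np.abs(z) - thresh, 0.0)


def init_lsalsa(A, mu):
    """Eq. 14: W_e = A^T and S = (mu I + A^T A)^{-1}, matching SALSA exactly."""
    N = A.shape[1]
    return A.T.copy(), np.linalg.inv(mu * np.eye(N) + A.T @ A)


def lsalsa_forward(y, We, S, alpha, mu, T):
    """f_e(W_e, S, y): T unfolded layers that all share the same W_e and S."""
    b = We @ y
    x, d = b.copy(), np.zeros_like(b)
    for _ in range(T):
        u = soft(x + d, alpha / mu)
        x = S @ (b + mu * (u - d))
        d = d + x - u          # cross-layer connection carrying the multipliers
    return soft(x + d, alpha / mu)


def lsalsa_loss(We, S, Y, X_star, alpha, mu, T):
    """Eq. 15: (1 / 2P) * sum_p ||x^{*,p} - f_e(W_e, S, y^p)||^2, samples in columns."""
    P = Y.shape[1]
    preds = np.stack([lsalsa_forward(Y[:, p], We, S, alpha, mu, T) for p in range(P)], axis=1)
    return 0.5 * np.sum((X_star - preds) ** 2) / P


rng = np.random.default_rng(5)
A = rng.standard_normal((20, 40))
A /= np.linalg.norm(A, axis=0)
We, S = init_lsalsa(A, mu=1.0)
Y = rng.standard_normal((20, 8))
codes_T3 = np.stack([lsalsa_forward(Y[:, p], We, S, 0.3, 1.0, T=3) for p in range(8)], axis=1)
```

At the Eq. 14 initialization this forward pass is just T iterations of truncated SALSA; training moves \(\mathbf {W_e}\) and \(\mathbf {S}\) away from those values.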

Analysis of LSALSA

Optimality property for LSALSA

Typically, analyses of ADMM-like algorithms rely on the optimality of each primal update, e.g. that \(\mathbf {x}^{(k+1)}=\text {arg}\min _{\mathbf {x}}{\mathcal {L}}_A(\mathbf {x},\mathbf {u}^{(k+1)};\mathbf {d}^{(k)})\) (Boyd et al. 2011; Goldstein et al. 2014; Wang et al. 2019). In Theorem 1 we show that LSALSA provides optimal primal updates with respect to a generalization of the Augmented Lagrangian (11) parameterized by \(\mathbf {S}\). The proof is provided in Supplemental Section C.

Theorem 1

(LSALSA Optimality) Given a neural network with the LSALSA architecture as described in Sect. 4.2, there exists an Augmented Lagrangian for which the LSALSA network provides optimal primal updates. In particular, for learned matrices \(\mathbf {S}\) and \(\mathbf {W_e}\), we have

$$\begin{aligned} \hat{\mathcal {L}}_A= \hat{f_1}(\mathbf {x};\mathbf {S})+\ell _1(\mathbf {u}) + \frac{\mu }{2}\left\| \mathbf {u}-\mathbf {x}-\mathbf {d}\right\| ^2-\frac{\mu }{2}\left\| \mathbf {d}\right\| ^2, \end{aligned}$$
(16)

where

$$\begin{aligned} \hat{f_1}(\mathbf {x};\mathbf {S}) = \frac{1}{2}\mathbf {x}^{\text {T}}\left[ \mathbf {S}^{-1}-\mu I\right] \mathbf {x}-(\mathbf {W_e}\mathbf {y})^{\text {T}}\mathbf {x}+ \frac{1}{2}\mathbf {y}^{\text {T}}\mathbf {y}, \end{aligned}$$
(17)

and \(\ell _1(\mathbf {u})\) represents a sum of \(\ell _1\)-terms as in (6).
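The optimality claim of Theorem 1 for the \(\mathbf {x}\)-update can be verified numerically: setting the gradient of \(\hat{\mathcal {L}}_A\) in Eqs. 16 and 17 to zero gives \(\mathbf {S}^{-1}\mathbf {x}=\mathbf {W_e}\mathbf {y}+\mu (\mathbf {u}-\mathbf {d})\), so the LSALSA update \(\mathbf {x}=\mathbf {S}(\mathbf {W_e}\mathbf {y}+\mu (\mathbf {u}-\mathbf {d}))\) is a stationary point for any invertible \(\mathbf {S}\). In the sketch, the "learned" \(\mathbf {S}\) and \(\mathbf {W_e}\) are arbitrary random perturbations of the Eq. 14 initialization (a hypothetical stand-in for trained weights), chosen so that \(\mathbf {S}\) is symmetric and \(\mathbf {S}^{-1}-\mu \mathbf {I}\) is positive definite.

```python
import numpy as np

rng = np.random.default_rng(6)
M, N, mu = 12, 20, 0.5
A = rng.standard_normal((M, N))
y = rng.standard_normal(M)
u = rng.standard_normal(N)
d = rng.standard_normal(N)

# Hypothetical "learned" parameters: perturbations of the Eq. 14 initialization.
B = rng.standard_normal((N, N))
S_inv = mu * np.eye(N) + A.T @ A + 0.1 * (B @ B.T)   # keeps S^{-1} - mu*I positive definite
S = np.linalg.inv(S_inv)
We = A.T + 0.05 * rng.standard_normal((N, M))


def grad_x_hat_LA(x):
    """Gradient w.r.t. x of hat L_A (Eqs. 16-17) at fixed u and d."""
    return (S_inv - mu * np.eye(N)) @ x - We @ y - mu * (u - x - d)


# LSALSA's x-update with the learned S and W_e.
x_next = S @ (We @ y + mu * (u - d))
```

Since \(\hat{f_1}\) is convex when \(\mathbf {S}^{-1}-\mu \mathbf {I}\) is positive definite, this stationary point is the unique minimizer over \(\mathbf {x}\).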

Remark 2

(LSALSA as an instance of ADMM) Note that by plugging in the initializations of \(\mathbf {S}\) and \(\mathbf {W_e}\) given in Eq. 14, we recover the original Augmented Lagrangian. Then, from the perspective of Theorem 1, LSALSA at inference is equivalent to applying \(T\) iterations of ADMM to a new, learned cost function that generalizes the original Augmented Lagrangian in Eq. 11.
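Remark 2 can be checked directly: with \(\mathbf {W_e}=\mathbf {A}^{\text {T}}\) and \(\mathbf {S}=(\mu \mathbf {I}+\mathbf {A}^{\text {T}}\mathbf {A})^{-1}\), the quadratic form in Eq. 17 reduces to \(\mathbf {A}^{\text {T}}\mathbf {A}\), so \(\hat{f_1}(\mathbf {x};\mathbf {S})=\frac{1}{2}\mathbf {x}^{\text {T}}\mathbf {A}^{\text {T}}\mathbf {A}\mathbf {x}-\mathbf {y}^{\text {T}}\mathbf {A}\mathbf {x}+\frac{1}{2}\mathbf {y}^{\text {T}}\mathbf {y}=\frac{1}{2}\Vert \mathbf {y}-\mathbf {A}\mathbf {x}\Vert _2^2\). The sketch evaluates both sides on random data; `f1_hat` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(7)
M, N, mu = 10, 25, 0.8
A = rng.standard_normal((M, N))
y = rng.standard_normal(M)
x = rng.standard_normal(N)

We = A.T                                       # Eq. 14
S = np.linalg.inv(mu * np.eye(N) + A.T @ A)    # Eq. 14


def f1_hat(x, S, We, y, mu):
    """Learned data-fidelity term of Eq. 17."""
    Q = np.linalg.inv(S) - mu * np.eye(len(x))
    return 0.5 * x @ Q @ x - (We @ y) @ x + 0.5 * y @ y


f1 = 0.5 * np.sum((y - A @ x) ** 2)            # original least-squares term
```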

Remark 3

(LSALSA provides sparse solutions) Since \(\hat{\mathcal {L}}_A\) employs the \(\ell _1\)-norm in the usual way and LSALSA’s \(\mathbf {u}\)-update is standard soft-thresholding, we can expect LSALSA to enforce sparsity given sufficient iterations.

We show in Sect. 5.2 that the optimal direction for \(\hat{\mathcal {L}}_A\) is related to the optimal direction for \(\mathcal {L}_A\), and in Sect. 5.3 that gradient descent along \(\hat{\mathcal {L}}_A\) is equivalent to a modified gradient descent along \(\mathcal {L}_A.\) For simplicity, we consider the case of a learned, symmetric \(\mathbf {S}\) while holding \(\mathbf {W_e}\equiv \mathbf {A}^{\text {T}}\) fixed.

Modified descent direction: deterministic framework

Though \(\hat{\mathcal {L}}_A\)’s dependence on \(\mathbf {u}\) and \(\mathbf {d}\) is standard in ADMM settings (Boyd et al. 2011), the learned data-fidelity term \(\hat{f_1}\), which governs the \(\mathbf {x}\)-updates, is now a data-driven quadratic form that depends on the weight matrix \(\mathbf {S}\) parameterizing LSALSA. We next rewrite the new cost function in terms of the original Augmented Lagrangian:

$$\begin{aligned} \hat{\mathcal {L}}_A(\mathbf {x},\mathbf {u},\mathbf {d}) = \mathcal {L}_A(\mathbf {x},\mathbf {u},\mathbf {d}) + \hat{f_1}(\mathbf {x};\mathbf {S}) - \frac{1}{2}\left\| \mathbf {y}-\mathbf {A}\mathbf {x}\right\| ^2_2. \end{aligned}$$
(18)

The optimality condition for \(\hat{\mathcal {L}}_A\) with respect to \(\mathbf {x}\) can be written as

$$\begin{aligned} 0&=\nabla _{\mathbf {x}}\hat{\mathcal {L}}_A(\mathbf {x}^*, \mathbf {u}, \mathbf {d})\\&=\nabla _{\mathbf {x}}\left( \mathcal {L}_A(\mathbf {x}^*, \mathbf {u}, \mathbf {d}) +\hat{f_1}(\mathbf {x}^*;\mathbf {S}) - \frac{1}{2}\left\| \mathbf {y}-\mathbf {A}\mathbf {x}^*\right\| ^2_2\right) \\&=\nabla _{\mathbf {x}}\mathcal {L}_A(\mathbf {x}^*, \mathbf {u}, \mathbf {d}) +\left[ \mathbf {S}^{-1}-\mu I - \mathbf {A}^{\text {T}}\mathbf {A}\right] \mathbf {x}^*. \end{aligned}$$

Then, using \(\nabla _{\mathbf {x}}^2\mathcal {L}_A=\mu I + \mathbf {A}^{\text {T}}\mathbf {A}\) we can write the LSALSA update as

$$\begin{aligned} 0&=\nabla _{\mathbf {x}}\mathcal {L}_A(\mathbf {x}^*, \mathbf {u}, \mathbf {d}) + \left[ \mathbf {S}^{-1} -\nabla _{\mathbf {x}}^2\mathcal {L}_A \right] \mathbf {x}^* \end{aligned}$$
(19)
$$\begin{aligned} \Rightarrow&\left[ \mathbf {S}^{-1} - \nabla _{\mathbf {x}}^2\mathcal {L}_A \right] \mathbf {x}^* = -\nabla _{\mathbf {x}}\mathcal {L}_A(\mathbf {x}^*, \mathbf {u}, \mathbf {d}). \end{aligned}$$
(20)

The root-finding problem posed in (19) and the equivalent system of equations in (20) resemble a Newton-like update, but with a learned modification of the original Lagrangian’s Hessian matrix. Note that at initialization (using Eq. 14), the left-hand side of (20) vanishes, recovering the optimality condition for the original problem. This also suggests that LSALSA incorporates prior knowledge, learned from the training data, that could balance optimality for the original problem against maintaining a relationship with the training data distribution.
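The Newton interpretation can be made concrete. Because \(\mathcal {L}_A\) is quadratic in \(\mathbf {x}\), one Newton step from any point \(\mathbf {x}_0\), \(\mathbf {x}_0-(\nabla _{\mathbf {x}}^2\mathcal {L}_A)^{-1}\nabla _{\mathbf {x}}\mathcal {L}_A(\mathbf {x}_0)\), lands exactly on the SALSA \(\mathbf {x}\)-update, and the bracket in (20) is zero at the Eq. 14 initialization. The sketch checks both facts on random data, assuming \(f_1=\frac{1}{2}\Vert \mathbf {y}-\mathbf {A}\mathbf {x}\Vert _2^2\) as in Eq. 4; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(8)
M, N, mu = 10, 16, 1.0
A = rng.standard_normal((M, N))
y, u, d = rng.standard_normal(M), rng.standard_normal(N), rng.standard_normal(N)
x0 = rng.standard_normal(N)          # arbitrary current iterate

H = mu * np.eye(N) + A.T @ A         # Hessian of L_A with respect to x


def grad_LA(x):
    """Gradient in x of L_A (Eq. 11) with f1 = 0.5 * ||y - A x||^2."""
    return A.T @ (A @ x - y) - mu * (u - x - d)


S0 = np.linalg.inv(H)                                # Eq. 14 initialization
x_salsa = S0 @ (A.T @ y + mu * (u - d))              # x-update with S = S0
x_newton = x0 - np.linalg.solve(H, grad_LA(x0))      # one Newton step on L_A
correction = np.linalg.inv(S0) - H                   # bracket in Eq. 20
```

After training, \(\mathbf {S}^{-1}\) no longer equals the Hessian, and the nonzero bracket in (20) is precisely the learned curvature correction.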

Modified descent direction: stochastic framework

We will next look at (L)SALSA through the prism of worst-case analysis, i.e. by replacing the optimal primal steps with stochastic gradient descent. This effectively enables us to analyze (L)SALSA as a stochastic alternated optimization approach solving a general saddle point problem, and we show that LSALSA leads to faster convergence under certain assumptions that we stipulate. Our analysis is a direct extension of that in Choromanska et al. (2019). We provide the final statement of the theorem below and defer all proofs to the supplement.

Problem formulation

Consider the following general saddle-point problem:

$$\begin{aligned} \max _{\phi _1,\ldots ,\phi _{K_2}}\min _{\theta _1,\ldots ,\theta _{K_1}}&\mathcal {L}_{}(\theta _1,\ldots ,\theta _{K_1};\phi _1,\ldots ,\phi _{K_2})\end{aligned}$$
(21)
$$\begin{aligned}&\Updownarrow \nonumber \\ \max _{\varvec{\phi }}\min _{\varvec{\theta }}&\mathcal {L}_{}(\varvec{\theta };\varvec{\phi }), \end{aligned}$$
(22)

using \(\varvec{\theta }= [\theta _1,\ldots ,\theta _{K_1}]\) to denote the collection of variables to be minimized, and \(\varvec{\phi }= [\phi _1,\ldots ,\phi _{K_2}]\) the variables to be maximized. We denote the entire collection of variables as \(\mathbf {x}=[\varvec{\theta }, \varvec{\phi }]\in {\mathbb {R}}^{K},\) where \(K=K_1+K_2\) is the total number of arguments, and we denote by \(x_d\) the \(d\mathrm{th}\) entry of \(\mathbf {x}\). For the theoretical analysis we consider a smooth function \(\mathcal {L}_{}\), as is often done in the literature (especially for \(\ell _1\) problems, whose non-smooth penalty can be replaced by a smooth approximation, as discussed in Lange et al. 2014; Schmidt et al. 2007).

Let \((x_1^*,\ldots ,x_K^*)\) be the optimal solution of the saddle-point problem in (22), where \(\mathcal {L}_{}\) is computed over the global data population (i.e. averaged over an infinite number of samples). For each variable \(x_d\), we assume that \(r_d>0\) is a lower bound on its radius of convergence. Let \(\nabla _d^1 \mathcal {L}_{}\) denote the gradient of \(\mathcal {L}_{}\) with respect to the \(d\mathrm{th}\) argument evaluated on a single data sample (stochastic gradient), and \(\nabla _d \mathcal {L}_{}\) the gradient with respect to the same argument computed over the global data population (i.e. an “oracle gradient”).

We analyze an Alternating Optimization algorithm that, at the \(d\mathrm{th}\) step, optimizes \(\mathcal {L}_{}\) with respect to \(x_d\) while holding all other \(x_{i\ne d}\) fixed:

$$\begin{aligned} x_d^{t+1} = \varPi _d\left( x_d^t \pm \eta ^t\nabla _d^1\mathcal {L}_{x_d}^t\right) , \end{aligned}$$
(23)

using the ± symbol to denote gradient descent for \(d\le K_1\) and gradient ascent for \(d>K_1\). \(\varPi _d\) is the projection onto the Euclidean ball \(B_2(\frac{r_d}{2},x_d^*)\) of radius \(\frac{r_d}{2}\) centered at the optimal value \(x_d^*\); this ensures that, for each d, all iterates of \(x_d\) remain within the \(r_d\)-ball around \(x_d^*\) (Footnote 2).
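To make the scheme in Eq. 23 concrete, the following is a minimal sketch on a hypothetical toy saddle-point problem with one minimized variable \(\theta\) and one maximized variable \(\phi\), \(\mathcal{L}(\theta;\phi)=\frac{a}{2}\theta^2+b\theta\phi-\frac{c}{2}\phi^2\), whose saddle point is \((0,0)\) for \(a,c>0\). Gaussian noise stands in for the single-sample gradient \(\nabla_d^1\mathcal{L}\), the projection \(\varPi_d\) reduces to a clip in one dimension, and the step-size schedule \(\eta^t=1/(t+3)\) is our assumption, not a choice taken from the analysis.

```python
import random


def project_interval(value, center, radius):
    """Euclidean projection onto the 1-D ball B_2(radius, center), i.e. a clip."""
    return max(center - radius, min(center + radius, value))


def stochastic_alternating_opt(a=1.0, b=0.5, c=1.0, r=2.0, steps=5000,
                               noise=0.1, seed=0):
    """Projected stochastic alternating descent/ascent (in the spirit of Eq. 23)
    on the toy saddle L(theta; phi) = a/2*theta^2 + b*theta*phi - c/2*phi^2,
    whose saddle point is (theta*, phi*) = (0, 0) for a, c > 0."""
    rng = random.Random(seed)
    theta_star, phi_star = 0.0, 0.0
    theta, phi = 0.8, -0.8  # start inside the r/2-balls around the optimum
    for t in range(steps):
        eta = 1.0 / (t + 3)  # hypothetical decaying step size eta^t
        # d = 1 (minimized variable): noisy gradient descent, then projection
        g_theta = a * theta + b * phi + rng.gauss(0.0, noise)
        theta = project_interval(theta - eta * g_theta, theta_star, r / 2)
        # d = 2 (maximized variable): noisy gradient ascent using the new theta
        g_phi = b * theta - c * phi + rng.gauss(0.0, noise)
        phi = project_interval(phi + eta * g_phi, phi_star, r / 2)
    return theta, phi
```

Each variable is updated in turn while the other is held at its latest value, which is what makes the scheme "alternating"; the projection keeps every iterate within \(r_d/2\) of the optimum, as described above.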

Assumptions

The following assumptions are required for the results in Sect. 5.3.3. The definitions of strong convexity, strong concavity, and smoothness follow Nesterov (2013).

Assumption 1

(Convex–Concave) For each \(d\le K_1\), \(\mathcal {L}_{x_d}^*\) is \(\beta _d\)-convex, and for each \(d>K_1\), \(\mathcal {L}_{x_d}^*\) is \(\beta _d\)-concave, within a ball of radius \(r_d\) around the solution \(x_d^*\).

Assumption 2

(Smoothness) For all \(d\in \{1,\ldots ,K\}\), the function \(\mathcal {L}_{x_d}^*\) is \(\alpha _d\)-smooth.

In summary, for every \(d=1,\ldots ,K\), \(\mathcal {L}_{x_d}^*\) is either \(\beta _d\)-convex or \(\beta _d\)-concave in a neighborhood of the optimal point, and \(\alpha _d\)-smooth. Next we assume two standard properties of the gradient of the cost function.

Assumption 3

(Gradient Stability, \(GS(\gamma _d)\)) We assume that for each \(d=1,\ldots ,K,\) the following gradient stability condition holds for \(\gamma _d\ge 0\) over the Euclidean ball \(x_d\in B_2(r_d,x_d^*)\):

$$\begin{aligned} \left\| \nabla _d\mathcal {L}_{x_d}^* - \nabla _d\mathcal {L}_{x_d}\right\| \le \gamma _d\sum _{i\ne d}\left\| x_i-x_i^*\right\| . \end{aligned}$$
(24)

Assumption 4

(Assumption A.6: Bounded Gradient) We assume that the gradients of our objective function \(\mathcal {L}\) are bounded in expectation, as quantified by \(\sigma = \sqrt{\sum _{d=1}^K \sigma _d^2}\), where:

$$\begin{aligned} \sigma _d = \sup \left\{ {\mathbb {E}}\left[ \left\| \nabla _d\mathcal {L}_{x_d}\right\| ^2\right] : x_d\in B_2(r_d,x_d^*),\ \forall d=1,\ldots ,K\right\} . \end{aligned}$$
(25)

Convergence statement

Denote with \(\varDelta _d^t=x_d^t-x_d^*\) the error of the \(t\mathrm{th}\) estimate of \(d\mathrm{th}\) element of the global optimizer \(\mathbf {x}^*\). Define the following:

$$\begin{aligned} {\mathcal {E}}_{\textsf {SALSA}}(\beta )=\left( \frac{2}{t+3}\right) ^{\frac{3}{2}}{\mathbb {E}}\left[ \sum _{d=1}^K\left\| \varDelta _d^0\right\| ^2\right] + \frac{9\sigma ^2}{[2\xi (\beta )-\gamma (2K-1)]^2(t+3)}, \end{aligned}$$
(26)

where \(\xi (\beta )\) increases monotonically with increasing \(\beta .\)
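The role of \(\xi(\beta)\) in Eq. 26 can be checked numerically. The sketch below evaluates the right-hand side of Eq. 26 with \(\xi\) passed in directly as a number, since this section does not give the functional form of \(\xi(\beta)\), only that it grows with \(\beta\). We also treat \(2\xi>\gamma(2K-1)\) as a requirement; that positivity condition is our assumption about when the bound is meaningful.

```python
def salsa_error_bound(t, init_sq_err, sigma, xi, gamma, K):
    """Evaluate the right-hand side of Eq. (26).

    xi is a plain number standing in for xi(beta); init_sq_err stands in for
    E[sum_d ||Delta_d^0||^2]. All inputs here are illustrative.
    """
    margin = 2.0 * xi - gamma * (2 * K - 1)
    if margin <= 0:
        # Assumed requirement: the term 2*xi - gamma*(2K - 1) must be positive.
        raise ValueError("expected 2*xi > gamma*(2K - 1)")
    transient = (2.0 / (t + 3)) ** 1.5 * init_sq_err  # decays like t^(-3/2)
    noise_floor = 9.0 * sigma ** 2 / (margin ** 2 * (t + 3))  # decays like 1/t
    return transient + noise_floor


# A larger xi (i.e. a larger convexity modulus beta) shrinks the 1/t term.
for xi in (1.0, 2.0):
    print(xi, salsa_error_bound(t=97, init_sq_err=1.0, sigma=1.0,
                                xi=xi, gamma=0.1, K=2))
```

Only the second, noise-driven term depends on \(\xi\); the first term is identical for both methods when they start from the same point, which is why a larger learned modulus \({\hat{\beta }}\) can only lower the bound.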

Theorem 2

(Convergence of SALSA and LSALSA) Suppose that the cost functions underlying SALSA, \(\mathcal {L}_A\), and LSALSA, \(\hat{\mathcal {L}}_A\), satisfy the Assumptions in Sect. 5.3.2 with convexity moduli \(\beta \) and \({\hat{\beta }}\), respectively (the latter is implicitly learned from the data). Assume also that the deep model representing LSALSA has enough capacity to learn \({\hat{\beta }}\) such that \({\hat{\beta }}>\beta ,\) while keeping the global optimal fixed point \(\mathbf {x}^*\) at the same location (Footnote 3).

Then, using the Stochastic Alternating Optimization scheme in Eq. 23 on \(\mathcal {L}_A\) and \(\hat{\mathcal {L}}_A\) such that the requirements from Theorem 4 are satisfied, and starting from the same initial point, the error satisfies the following:

for SALSA:

$$\begin{aligned} \sum _{d=1}^K\left\| \varDelta _d^{t+1}\right\| ^2 \le {\mathcal {E}}_{\textsf {SALSA}}(\beta ), \end{aligned}$$
(27)

and for LSALSA:

$$\begin{aligned} \sum _{d=1}^K\left\| \varDelta _d^{t+1}\right\| ^2 \le {\mathcal {E}}_{\textsf {LSALSA}}({\hat{\beta }}) = {\mathcal {E}}_{\textsf {SALSA}}(\beta ) - \varDelta _{\beta }, \end{aligned}$$
(28)

where

$$\begin{aligned} \varDelta _{\beta } = {\mathcal {O}} \left( \frac{{\hat{\beta }}^2 - \beta ^2}{(2\beta {\hat{\beta }})^2}\right) . \end{aligned}$$
(29)

The above theorem states that, given enough capacity of the deep model, LSALSA can learn a steeper descent direction than SALSA. We next give some intuition for this. Consider the gradient descent step (or its stochastic approximation) for \(\hat{\mathcal {L}}_A\) in the \(\mathbf {x}\)-direction:

$$\begin{aligned} \mathbf {x}^{(k+1)}&=\mathbf {x}^{(k)}-\eta ^k\nabla _{\mathbf {x}}\hat{\mathcal {L}}_A(\mathbf {x}^{(k)}, \mathbf {u}^{(k+1)}, \mathbf {d}^{(k)}) \nonumber \\&=\mathbf {x}^{(k)}-\eta ^k\nabla _{\mathbf {x}}\left( \mathcal {L}_A^k +\phi (\mathbf {x};\mathbf {S}) - \frac{1}{2}\left\| \mathbf {y}-\mathbf {A}\mathbf {x}\right\| ^2_2\right) \nonumber \\&=\mathbf {x}^{(k)}-\eta ^k\nabla _{\mathbf {x}}\mathcal {L}_A^k -\eta ^k\left[ \mathbf {S}^{-1}-\mu I - \mathbf {A}^{\text {T}}\mathbf {A}\right] \mathbf {x}^{(k)} \nonumber \\&=\underbrace{\mathbf {x}^{(k)}-\eta ^k\nabla _{\mathbf {x}}\mathcal {L}_A^k}_{\text {unlearned descent step}} -\eta ^k\left[ \mathbf {S}^{-1} -\nabla _{\mathbf {x}}^2\mathcal {L}_A \right] \mathbf {x}^{(k)} \nonumber \\&=\left[ I- \eta ^k P\right] \mathbf {x}^{(k)}-\eta ^k\nabla _{\mathbf {x}}\mathcal {L}_A^k, \end{aligned}$$
(30)

where \(P:=\mathbf {S}^{-1} -\nabla _{\mathbf {x}}^2\mathcal {L}_A\).

This update can be seen as first taking an ordinary gradient descent step and then pushing the optimizer further along the learned direction \(-P\mathbf {x}^{(k)}\), which, as we show empirically, is a faster direction of descent.
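The last line of Eq. 30 can be sketched directly. We assume, for illustration, that the \(\mathbf{x}\)-subproblem has the SALSA form \(\mathcal{L}_A(\mathbf{x})=\frac12\|\mathbf{y}-\mathbf{A}\mathbf{x}\|_2^2+\frac{\mu}{2}\|\mathbf{x}-\mathbf{v}\|_2^2\), where \(\mathbf{v}\) is a hypothetical stand-in for the fixed combination of \(\mathbf{u}^{(k+1)}\) and \(\mathbf{d}^{(k)}\); its Hessian is \(\mu I+\mathbf{A}^{\text{T}}\mathbf{A}\), matching \(\nabla^2_{\mathbf{x}}\mathcal{L}_A\) in Eq. 30. The matrix \(\mathbf{S}^{-1}\) is simply passed in rather than learned.

```python
def matvec(M, v):
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def grad_LA(A, y, mu, v, x):
    """Gradient of the assumed L_A(x) = 1/2||y - Ax||^2 + mu/2 ||x - v||^2."""
    residual = [ax - yi for ax, yi in zip(matvec(A, x), y)]
    g = matvec(transpose(A), residual)
    return [gi + mu * (xi - vi) for gi, xi, vi in zip(g, x, v)]


def hessian_LA(A, mu):
    """Hessian of the assumed L_A: A^T A + mu I (constant in x)."""
    At = transpose(A)
    n = len(At)
    return [[sum(At[i][k] * At[j][k] for k in range(len(A))) + (mu if i == j else 0.0)
             for j in range(n)] for i in range(n)]


def corrected_step(A, y, mu, v, S_inv, x, eta):
    """x_next = (I - eta P) x - eta * grad L_A(x), with P = S^{-1} - Hessian(L_A)."""
    H = hessian_LA(A, mu)
    n = len(x)
    P = [[S_inv[i][j] - H[i][j] for j in range(n)] for i in range(n)]
    g = grad_LA(A, y, mu, v, x)   # "unlearned" descent direction
    Px = matvec(P, x)             # learned correction term
    return [xi - eta * pxi - eta * gi for xi, pxi, gi in zip(x, Px, g)]
```

When \(\mathbf{S}^{-1}\) equals the Hessian, \(P=0\) and the update collapses to a plain gradient step; any other \(\mathbf{S}^{-1}\) adds the extra push \(-\eta^k P\mathbf{x}^{(k)}\).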

Numerical experiments

We now present a variety of sparse coding inference tasks to evaluate our algorithm’s speed, accuracy, and sparsity trade-offs. For each task (including both SC and MCA), we consider a range of values of T, the number of iterations, and perform a full hyperparameter grid search for each setting. In other words, we ask: “how well can each encoding algorithm approximate the optimal codes, given a fixed number of stages?” We compare LSALSA, truncated SALSA, truncated FISTA, and LISTA (Gregor and LeCun 2010) in terms of their RMSE proximity to the optimal codes, their sparsity levels, and their performance on classification tasks. Both LSALSA and LISTA are implemented as feedforward neural networks. For MCA experiments, we run FISTA and LISTA using the concatenated dictionary \(\mathbf {A}\).
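As an illustration of what "truncated" means for the FISTA baseline, here is a minimal sketch that runs exactly T iterations of FISTA (Beck and Teboulle 2009) from a zero code. It assumes the standard \(\ell_1\) objective \(\frac12\|\mathbf{y}-\mathbf{A}\mathbf{x}\|_2^2+\alpha\|\mathbf{x}\|_1\), and it uses \(\|\mathbf{A}\|_F^2\) as a conservative (always valid) Lipschitz constant, a simplification we chose instead of computing \(\|\mathbf{A}\|_2^2\). For MCA, \(\mathbf{A}\) would be the concatenated dictionary.

```python
import math


def soft_threshold(v, thr):
    """Elementwise soft-thresholding, the proximal operator of thr*||.||_1."""
    return [math.copysign(max(abs(vi) - thr, 0.0), vi) for vi in v]


def truncated_fista(A, y, alpha, T):
    """Run exactly T FISTA iterations on 1/2||y - Ax||^2 + alpha||x||_1 from x = 0."""
    m, n = len(A), len(A[0])
    L = sum(a * a for row in A for a in row)  # ||A||_F^2 >= ||A||_2^2 (safe step)
    x = [0.0] * n
    z = list(x)
    t = 1.0
    for _ in range(T):
        # gradient of the quadratic term at the extrapolated point z
        r = [sum(A[i][j] * z[j] for j in range(n)) - y[i] for i in range(m)]
        grad = [sum(A[i][j] * r[i] for i in range(m)) for j in range(n)]
        x_new = soft_threshold([z[j] - grad[j] / L for j in range(n)], alpha / L)
        # Nesterov momentum schedule from Beck and Teboulle (2009)
        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t * t)) / 2.0
        z = [x_new[j] + ((t - 1.0) / t_new) * (x_new[j] - x[j]) for j in range(n)]
        x, t = x_new, t_new
    return x
```

With a small T the output is only a rough estimate of the optimal code; running it long enough (e.g. the 200 or 700 iterations used for the optimal codes below) approaches the true minimizer.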

We focus on the inference problem and thus learn the dictionaries off-line as described in Sect. 3. Dictionary learning is performed only once for each data set, and the resulting dictionaries are held constant across all methods and experiments herein (visualization of the atoms of the obtained dictionaries can be found in Section F in the Supplement). For MCA, the independently-learned dictionaries are still used, creating difficult ill-conditioned problems (because each dictionary is able to at least partially represent both components).

To train the encoders, we minimize Eq. 15 with respect to \(\mathbf {W_e}\) and \(\mathbf {S}\) using vanilla stochastic gradient descent (SGD). We considered the optimization complete after a fixed number of epochs, or when the relative change in the cost function fell below a threshold of \(10^{-6}\). During hyperparameter grid searches, only 10 epochs through the training data were allowed; for testing, 100 epochs of training were allowed (usually the tolerance was reached before 100 epochs). The optimal codes are determined prior to training by solving the convex inference problem with fixed \(\alpha ^*\) and \(\mu ^*\), e.g. by running FISTA or SALSA to convergence (details are given in each section). To set \(\alpha ^*\) and \(\mu ^*\), we fix \(\mu ^*=10\) and tune \(\alpha ^*\) to yield an average sparsity of at least 89%. We then slowly increase the \(\alpha ^*\) values until just before the optimal sparse codes fail to provide recognizable image reconstructions. We take the simplest approach to image reconstruction: multiplying each sparse code by its corresponding dictionary. No additional learning is performed to obtain reconstructions, i.e. for LSALSA we compute \(\mathbf {A}_i\cdot (f_e(\mathbf {W}_e,\mathbf {S},\mathbf {y}))_i\), where \((f_e(\mathbf {W}_e,\mathbf {S},\mathbf {y}))_i\) represents the i-th component of the encoder’s output.
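The reconstruction rule above involves no learned decoder, so it fits in a few lines. This sketch (the function name `reconstruct` is ours) multiplies each component code by its own dictionary and also returns the sum, which approximates the observed mixture in the MCA case.

```python
def reconstruct(dictionaries, codes):
    """Multiply each code by its own dictionary (no extra learning) and sum.

    Returns (per-component reconstructions A_i x_i, their sum).
    """
    components = [
        [sum(a * c for a, c in zip(row, x_i)) for row in A_i]
        for A_i, x_i in zip(dictionaries, codes)
    ]
    mixture = [sum(vals) for vals in zip(*components)]
    return components, mixture


# Two toy 2x2 dictionaries and their codes.
parts, total = reconstruct([[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[2, 0], [0, 3]])
print(parts, total)
```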

We implemented the experiments in Lua using Torch7, and executed the experiments on a 64-bit Linux machine with 32GB RAM, i7-6850K CPU at 3.6 GHz, and GTX 1080 8GB GPU. The hyperparameters were selected via a grid search with specific values listed in the Supplement, Section E.

Single dictionary (SC) case

We run SC experiments with four data sets: Fashion MNIST (Xiao et al. 2017) (10 classes), ASIRRA (Elson et al. 2007) (2 classes), MNIST (LeCun et al. 2009) (10 classes), and CIFAR-10 (Krizhevsky and Hinton 2009) (10 classes). The ASIRRA data set is a collection of natural images of cats and dogs. We use a subset of the whole data set, 4000 training images and 1000 testing images, as is commonly done (Golle 2008). The results for MNIST and CIFAR-10 are reported in Section G in the Supplement.

The \(32\times 32\) Fashion MNIST images were first divided into \(10\times 10\) non-overlapping patches (ignoring the extra pixels on two edges), resulting in 9 patches per image. Then, optimal codes were computed for each vectorized patch by minimizing the objective from Eq. 4 with FISTA for 200 iterations. The ASIRRA images come in varying sizes. We resized them to a resolution of \(224\times 224\) via Torch7’s bilinear interpolation and converted each image to grayscale. Then we divided them into \(16\times 16\) non-overlapping patches, resulting in 196 patches per image. Optimal codes were computed patch-wise as for Fashion MNIST, but with 700 iterations to ensure convergence on this more difficult SC problem. Using the criteria described earlier in this section, we selected \(\alpha ^*=0.15\) for Fashion MNIST and \(\alpha ^*=0.5\) for ASIRRA.
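The patch counts follow from integer division of the image side by the patch side: \(\lfloor 32/10\rfloor^2=9\) and \((224/16)^2=196\). The sketch below reproduces that tiling; which two edges lose their leftover pixels is not stated above, so dropping the bottom rows and right columns is our assumption.

```python
def extract_patches(image, patch):
    """Split a 2-D image (list of rows) into non-overlapping patch x patch blocks.

    Leftover rows/columns are dropped at the bottom and right edges (assumed),
    and each block is vectorized in row-major order.
    """
    h, w = len(image), len(image[0])
    patches = []
    for r in range(0, h - patch + 1, patch):
        for c in range(0, w - patch + 1, patch):
            patches.append([image[r + i][c + j]
                            for i in range(patch) for j in range(patch)])
    return patches


fashion = [[0.0] * 32 for _ in range(32)]
asirra = [[0.0] * 224 for _ in range(224)]
print(len(extract_patches(fashion, 10)), len(extract_patches(asirra, 16)))
```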

The data sets were then separated into training and testing sets. The training patches were used to produce the dictionaries. Visualizations of the dictionary atoms are provided in Section F in the Supplement. An exhaustive hyper-parameter search (Footnote 4) was performed for each encoding method and for each number of iterations T, to minimize the RMSE between the obtained and optimal codes. The hyper-parameter search included \(\alpha \) for all methods, \(\mu \) for SALSA and LSALSA, as well as SGD learning rates and learning rate decay schedules for LSALSA and LISTA training.

The obtained encoders were used to compute sparse codes on the test set, which were then compared with the optimal codes via RMSE. The results for Fashion MNIST are shown both in terms of the number of iterations and the wallclock time in seconds used to make the prediction (Fig. 3). It takes FISTA more than 15 iterations and SALSA more than 5 to reach the error achieved by LSALSA in just one. Near \(T=100\), both FISTA and SALSA finally converge to the optimal codes. LISTA outperforms FISTA at first, but shows little improvement for \(T>10\). Similar results for ASIRRA are shown in the same figure. On this more difficult problem, it takes FISTA more than 50 iterations and SALSA more than 20 to catch up with a single iteration of LSALSA. LISTA and LSALSA are comparable for \(T\le 5\), after which LSALSA dramatically improves its optimal code prediction and, as in the case of Fashion MNIST, shows an advantage over the other methods in terms of the number of iterations, inference time, and the quality of the recovered sparse codes.

Fig. 3

Code prediction error as a function of iteration count and of inference wallclock time, for Fashion MNIST (a, b) and ASIRRA (c, d)

We also investigated which method yields better codes in terms of a classification task. We trained a logistic regression classifier to predict the label from the corresponding optimal sparse code, and then asked: “can the classifier still recognize a fast encoder’s estimate of the optimal code?” For Fashion MNIST, each image is associated with 9 optimal codes (one for each patch), yielding a total feature length of \(9\times 10\times 10=900\). The Fashion MNIST classifier was trained until it achieved \(0\%\) classification error on the optimal codes. For ASIRRA, each concatenated optimal code had length \(196\times 16\times 16=50{,}176\); to reduce the dimensionality we applied a random Gaussian projection \({\mathcal {G}}:{\mathbb {R}}^{50{,}176}\rightarrow {\mathbb {R}}^{500}\) before inputting the codes into the classifier. The classifier was trained on the optimal projected codes of length 500 until it achieved \(0.5\%\) error. The results for Fashion MNIST and ASIRRA are shown in Tables 3 and 4, respectively, in Section G in the Supplement. Note that the classifier was trained on the target test codes, so that the resulting classification error is due only to the difference between the optimal and estimated codes. In conclusion, although the FISTA, LISTA, or SALSA codes may not look much worse than the LSALSA codes in terms of RMSE, the Tables show that the expert classifiers cannot recognize the codes these methods extract, despite being trained to recognize the optimal codes that the algorithms seek to approximate.
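The dimensionality reduction step can be sketched as a fixed random matrix with i.i.d. Gaussian entries. The \(\mathcal{N}(0,1/500)\) scaling and the fixed seed are our assumptions (the text does not specify them); the essential point is that one matrix \(\mathcal{G}\) is drawn once and reused for both the optimal codes and the estimated codes.

```python
import random


def gaussian_projection(dim_in, dim_out, seed=0):
    """Hypothetical random linear map R^dim_in -> R^dim_out with i.i.d.
    N(0, 1/dim_out) entries; the same seed gives the same matrix G."""
    rng = random.Random(seed)
    scale = 1.0 / dim_out ** 0.5
    G = [[rng.gauss(0.0, scale) for _ in range(dim_in)] for _ in range(dim_out)]

    def project(code):
        return [sum(g * c for g, c in zip(row, code)) for row in G]

    return project


# ASIRRA feature length: 196 patches x (16*16)-dimensional codes.
print(196 * 16 * 16)
# Small demo; the real 500 x 50,176 matrix is better handled by a numerical library.
project = gaussian_projection(dim_in=50, dim_out=5, seed=1)
print(len(project([1.0] * 50)))
```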

MCA: two-dictionary case

Data preparation

We now describe the data set that we curated for the MCA experiments. We address the problem of decoupling numerals (text) from natural images, a topic closely related to text detection in natural scenes (Liu et al. 2017; Tian et al. 2015; Yan et al. 2018). Following the notation introduced previously in the paper, we set the \(\mathbf {y}_1^p\)s to be the whole \(32\times 32\) MNIST images and the \(\mathbf {y}_2^p\)s to be non-overlapping \(32\times 32\) patches from ASIRRA (thus we have 49 patches per image). We obtain 196k training and 49k testing patches from ASIRRA, and 60k training and 10k testing images from MNIST. We add together randomly selected MNIST images and ASIRRA patches to generate 588k mixed training images and 49k mixed testing images. Optimal codes were computed using SALSA (Algorithm 2) for 100 iterations, ensuring that each component had a sparsity level greater than \(89\%\) while retaining visually recognizable reconstructions. The values selected were \(\alpha _1^*=0.125\), \(\alpha _2^*=0.2\), and \(\mu ^*=10\). We also performed MCA experiments on additive mixtures of CIFAR-10 and MNIST images; those results can be found in Section H in the Supplement.
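For the two-dictionary problem, SALSA applies a separate shrinkage threshold to each block of the code. The sketch below follows one standard form of SALSA (Afonso et al. 2010) for \(\min \frac12\|\mathbf{y}-\mathbf{A}_1\mathbf{x}_1-\mathbf{A}_2\mathbf{x}_2\|_2^2+\alpha_1\|\mathbf{x}_1\|_1+\alpha_2\|\mathbf{x}_2\|_1\); the ordering of the updates is our assumption and may differ from the paper's Algorithm 2, which is not reproduced in this section. The small dense solver is included only to keep the example self-contained.

```python
import math


def solve(M, b):
    """Solve M x = b by Gaussian elimination with partial pivoting (small dense M)."""
    n = len(M)
    aug = [list(row) + [bi] for row, bi in zip(M, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[piv] = aug[piv], aug[col]
        for r in range(col + 1, n):
            f = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= f * aug[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (aug[r][n] - sum(aug[r][c] * x[c] for c in range(r + 1, n))) / aug[r][r]
    return x


def soft(v, thr):
    return math.copysign(max(abs(v) - thr, 0.0), v)


def salsa_mca(A1, A2, y, alpha1, alpha2, mu, T):
    """T iterations of a SALSA-style ADMM for
    min 1/2||y - A1 x1 - A2 x2||^2 + alpha1||x1||_1 + alpha2||x2||_1."""
    m, n1, n2 = len(y), len(A1[0]), len(A2[0])
    n = n1 + n2
    A = [list(A1[i]) + list(A2[i]) for i in range(m)]  # concatenated [A1 A2]
    # x-subproblem matrix A^T A + mu I is fixed across iterations
    M = [[sum(A[k][i] * A[k][j] for k in range(m)) + (mu if i == j else 0.0)
          for j in range(n)] for i in range(n)]
    Aty = [sum(A[k][i] * y[k] for k in range(m)) for i in range(n)]
    thr = [alpha1 / mu] * n1 + [alpha2 / mu] * n2  # per-dictionary thresholds
    u, d = [0.0] * n, [0.0] * n
    for _ in range(T):
        # x-step: argmin 1/2||y - Ax||^2 + mu/2 ||x - u - d||^2
        x = solve(M, [Aty[i] + mu * (u[i] + d[i]) for i in range(n)])
        # u-step: blockwise soft-thresholding of x - d
        u = [soft(x[i] - d[i], thr[i]) for i in range(n)]
        # scaled dual update
        d = [d[i] - (x[i] - u[i]) for i in range(n)]
    return u[:n1], u[n1:]
```

The per-block thresholds \(\alpha_1/\mu\) and \(\alpha_2/\mu\) are what let SALSA treat the two dictionaries differently, unlike FISTA and LISTA run on the concatenated dictionary with a single \(\alpha\).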

Results

An exhaustive hyper-parameter search was performed for each encoding method and each number of iterations T. The search included \(\alpha \) for FISTA and LISTA, \(\alpha _1,\alpha _2,\mu \) for SALSA and LSALSA, as well as SGD learning rates for LSALSA and LISTA training. The code prediction error curves are presented in Fig. 4. LSALSA consistently outperforms the others until SALSA catches up around \(T=50\). FISTA and LISTA, lacking a mechanism for distinguishing the two dictionaries, struggle to estimate the optimal codes (Fig. 5).

Fig. 4

MCA experiment using MNIST + ASIRRA data set. (left) Code prediction errors for varying numbers of iterations. (right) Code prediction error versus inference wallclock time

Fig. 5

MCA experiment separating MNIST + ASIRRA components: the trade-off between the classification error of the sparse codes and their inference time, for different network lengths, on (left) MNIST and (right) ASIRRA

Fig. 6

Sparsity/accuracy trade-off analysis for ASIRRA in the source separation experiment with the MNIST + ASIRRA data set. Each method corresponds to a colored point cloud, where each point corresponds to one sample from the ASIRRA test data set. LSALSA (black) achieves higher sparsity and/or lower code estimation error than the other methods for each T

In Fig. 6 we illustrate each method’s sparsity/accuracy trade-off on the ASIRRA test data set while varying T (Supplemental Section I contains a similar plot for MNIST). For each data point in the test set, we plot its sparsity versus its RMSE code error, resulting in a point cloud for each algorithm. For example, a sparsity value of 0.6 means that 60% of the code elements are equal to zero. These point clouds represent the trade-off between sparsity and fidelity to the original targets (i.e. proximity to the global solution of the original convex problem). For each T, the (black) LSALSA point cloud generally lies further to the right of and/or below the other point clouds, representing higher sparsity and/or lower error, respectively. For example, while FISTA achieves some slightly sparser solutions for \(T=10, 20\), it does so at a significant cost in RMSE. In this sense, we argue that LSALSA enjoys the best sparsity-accuracy trade-off among the four methods.
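Each point in these clouds pairs two per-sample numbers. A minimal sketch of both, assuming sparsity counts exact zeros (as in the definition above) and RMSE is taken over all entries of one sample's code:

```python
import math


def sparsity(code, tol=0.0):
    """Fraction of code entries whose magnitude is <= tol (exact zeros by default)."""
    return sum(1 for c in code if abs(c) <= tol) / len(code)


def rmse(estimate, optimal):
    """Root-mean-square difference between an estimated and an optimal code."""
    return math.sqrt(sum((e - o) ** 2 for e, o in zip(estimate, optimal)) / len(optimal))


code = [0.0, 0.0, 0.4, 0.0, 0.0, 0.0, -1.0, 0.0, 0.2, 0.9]
print(sparsity(code))  # 6 of 10 entries are zero
print(rmse([1.0, 2.0], [1.0, 0.0]))
```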

Table 1 MNIST classification error obtained after source separation (10 classes). The best performer is in bold
Table 2 ASIRRA classification error obtained after source separation (2 classes). The best performer is in bold

As in the single-dictionary case, we performed an evaluation on the classification task. A separate classifier was trained for each data set using the separated optimal codes \(\mathbf {x}_1^{*,p}\) and \(\mathbf {x}_2^{*,p}\), respectively. Again, a random Gaussian projection was used to reduce the ASIRRA codes to length 500 before inputting them to the classifier. The classification results are reported in Table 1 for MNIST and Table 2 for ASIRRA.

Fig. 7

MCA experiment using MNIST + ASIRRA. Image reconstructions obtained by SALSA, LSALSA, FISTA, LISTA for \(T = 1,5\). Top row: original data (components and mixed)

Finally, in Fig. 7 we present example reconstructed images obtained by the different methods when performing source separation (more reconstruction results can be found in Section J in the Supplement). FISTA and LISTA are unable to separate the components without severely corrupting the ASIRRA component. LSALSA produces visually recognizable separations even at \(T=1\), and the MNIST component is almost gone by \(T=5\). Recall that no additional learning is employed to generate the reconstructions; they are simply the codes multiplied by the corresponding dictionary matrices.

Conclusions

In this paper we propose LSALSA, a deep encoder architecture obtained by time-unfolding the split augmented Lagrangian shrinkage algorithm (SALSA). We empirically demonstrate that LSALSA inherits desirable properties from SALSA and outperforms baseline methods such as SALSA, FISTA, and LISTA in terms of both the quality of the predicted sparse codes and the running time, in both the single and multiple (MCA) dictionary cases. In the two-dictionary MCA setting, we furthermore show that LSALSA separates image components faster, and with better visual quality, than SALSA. The LSALSA network can tackle both the single and multiple dictionary coding problems without extension, unlike common competitors.

We also present a theoretical framework to analyze LSALSA. We show that the forward propagation of a signal through the LSALSA network is equivalent to a truncated ADMM algorithm applied to a new, learned cost function that generalizes the original problem. We show via the optimality conditions for this new cost function that the LSALSA update is related to a “learned pseudo-Newton” update down the original loss landscape, whose descent direction is corrected by a learned modification of the Hessian of the original cost function. Finally, we extend a very recent Stochastic Alternating Optimization analysis framework to show that a gradient descent step down the learned loss landscape is equivalent to taking a modified gradient descent step along the original loss landscape. In this framework we provide conditions under which LSALSA’s descent direction modification can speed up convergence.

Notes

  1. In this paper we consider the MCA framework with two dictionaries. Extensions to more than two dictionaries are straightforward.

  2. This assumption can potentially be eliminated with carefully selected initial step sizes.

  3. LSALSA is trained to keep the same global fixed point, see Eq. 15.

  4. The parameter settings that we explored in all our experiments are provided in the Supplement.

References

  1. Adler, J., & Öktem, O. (2017). Learned primal-dual reconstruction. CoRR arXiv:1707.06474.

  2. Afonso, M., Bioucas-Dias, J., & Figueiredo, M. (2010). Fast image recovery using variable splitting and constrained optimization. IEEE Transactions on Image Processing, 19(9), 2345–2356.

  3. Afonso, M., Bioucas-Dias, J., & Figueiredo, M. (2011). An augmented Lagrangian approach to the constrained optimization formulation of imaging inverse problems. IEEE Transactions on Image Processing, 20(3), 681–695.

  4. Bauschke, H. H., & Combettes, P. L. (2011). Convex analysis and monotone operator theory in Hilbert spaces (1st ed.). Berlin: Springer.

  5. Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM Journal on Imaging Sciences, 2(1), 183–202.

  6. Borgerding, M., & Schniter, P. (2016) Onsager-corrected deep learning for sparse linear inverse problems. In GlobalSIP.

  7. Boyd, S., Parikh, N., Chu, E., Peleato, B., & Eckstein, J. (2011). Distributed optimization and statistical learning via the alternating direction method of multipliers. Foundations and Trends® in Machine Learning, 3(1), 1–122.

  8. Chen, X., Liu, J., Wang, Z., & Yin, W. (2018). Theoretical linear convergence of unfolded ISTA and its practical weights and thresholds. arXiv preprint arXiv:1808.10038.

  9. Choromanska, A., Cowen, B., Kumaravel, S., Luss, R., Rish, I., Kingsbury, B., Tejwani, R., & Bouneffouf, D. (2019). Beyond backprop: Alternating minimization with co-activation memory. arXiv preprint arXiv:1806.09077v3.

  10. Daubechies, I., Defrise, M., & De Mol, C. (2004). An iterative thresholding algorithm for linear inverse problems with a sparsity constraint. Communications on Pure and Applied Mathematics, 57(11), 1413–1457.

  11. Eckstein, J., & Bertsekas, D. (1992). On the Douglas–Rachford splitting method and the proximal point algorithm for maximal monotone operators. Mathematical Programming, 55, 293–318.

  12. Elad, M., Starck, J. L., Querre, P., & Donoho, D. L. (2005). Simultaneous cartoon and texture image inpainting using morphological component analysis (MCA). Applied and Computational Harmonic Analysis, 19(3), 340–358.

  13. Elson, J., Douceur, J., Howell, J., & Saul, J. (2007). Asirra: A CAPTCHA that exploits interest-aligned manual image categorization. In ACM CCS.

  14. Figueiredo, M., Bioucas-Dias, J., & Afonso, M. (2009). Fast frame-based image deconvolution using variable splitting and constrained optimization. In Proceedings of IEEE workshop on statistical signal processing (pp. 109–112).

  15. Gers, F. A., Schraudolph, N. N., & Schmidhuber, J. (2003). Learning precise timing with LSTM recurrent networks. Journal of Machine Learning Research, 3, 115–143.

  16. Goldstein, T., O’Donoghue, B., & Setzer, S. (2014). Fast alternating direction optimization methods. SIAM Journal on Imaging Sciences, 7, 1588–1623.

  17. Golle, P. (2008). Machine learning attacks against the Asirra CAPTCHA. In ACM CCS.

  18. Greff, K., Srivastava, R. K., & Schmidhuber, J. (2016). Highway and residual networks learn unrolled iterative estimation. arXiv preprint arXiv:1612.07771.

  19. Gregor, K., & LeCun, Y. (2010). Learning fast approximations of sparse coding. In ICML.

  20. Jarrett, K., Kavukcuoglu, K., Koray, M., & LeCun, Y. (2009). What is the best multi-stage architecture for object recognition? In ICCV.

  21. Kavukcuoglu, K., Ranzato, M. A., & LeCun, Y. (2010). Fast inference in sparse coding algorithms with applications to object recognition. CoRR arXiv:1010.3467.

  22. Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images (Vol. 1, No. 4, p. 7). Technical report, University of Toronto.

  23. Lange, M., Zühlke, D., Holz, O., Villmann, T. (2014). Applications of LP-norms and their smooth approximations for gradient based learning vector quantization. In ESANN.

  24. Le Roux, J., Hershey, J. R., & Weninger, F. (2015). Deep NMF for speech separation. In ICASSP.

  25. LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (2009). Gradient-based learning applied to document recognition. In Proceedings of the IEEE.

  26. Liao, Q., & Poggio, T. (2016). Bridging the gaps between residual learning, recurrent neural networks and visual cortex. arXiv preprint arXiv:1604.03640.

  27. Liu, S., Xian, Y., Li, H., & Yu, Z. (2017). Text detection in natural scene images using morphological component analysis and laplacian dictionary. IEEE/CAA Journal of Automatica Sinica, PP(99), 1–9.

  28. Moreau, T., & Bruna, J. (2016). Understanding trainable sparse coding with matrix factorization. arXiv preprint arXiv:1609.00285.

  29. Nesterov, Y. (2013). Introductory lectures on convex optimization: A basic course (Vol. 87). Berlin: Springer.

  30. Olshausen, B., & Field, D. (1996). Emergence of simple-cell receptive field properties by learning a sparse code for natural images. Nature, 381, 607–609.

  31. Orhan, E., & Pitkow, X. (2018). Skip connections eliminate singularities. In International conference on learning representations.

  32. Otazo, R., Candès, E., & Sodickson, D. K. (2015). Low-rank and sparse matrix decomposition for accelerated dynamic MRI with separation of background and dynamic components. Magnetic Resonance in Medicine, 73(3), 1125–36.

  33. Parekh, A., Selesnick, I., Rapoport, D., & Ayappa, I. (2014). Sleep spindle detection using time-frequency sparsity. In IEEE SPMB.

  34. Peyré, G., Fadili, J., & Starck, J. L. (2007). Learning adapted dictionaries for geometry and texture separation. In SPIE Wavelets.

  35. Peyré, G., Fadili, J., & Starck, J. L. (2010). Learning the morphological diversity. SIAM Journal on Imaging Sciences, 3(3), 646–669.

  36. Schmidt, M., Fung, G., & Rosales, R. (2007). Fast optimization methods for l1 regularization: A comparative study and two new approaches. In J. N. Kok, J. Koronacki, R. L. D. Mantaras, S. Matwin, D. Mladenič, A. Skowron (Eds.), ECML.

  37. Selesnick, I. (2014). L1-norm penalized least squares with salsa. Connexions (p. 66). Retrieved March 1, 2017 from http://cnx.org/contents/e980d3cd-f201-4ef6-8992-d712bf0a88a3@5.

  38. Shoham, N., & Elad, M. (2008). Algorithms for signal separation exploiting sparse representations, with application to texture image separation. In Proceedings of the IEEE 25th convention of electrical and electronics engineers in Israel.

  39. Sprechmann, P., Litman, R., Yakar, T., Bronstein, A., & Sapiro, G. (2013). Efficient supervised sparse analysis and synthesis operators. In NIPS.

  40. Starck, J. L., Elad, M., & Donoho, D. (2004). Redundant multiscale transforms and their application for morphological component separation. Advances in Imaging and Electron Physics, 132, 287–348.

  41. Starck, J. L., Elad, M., & Donoho, D. (2005a). Image decomposition via the combination of sparse representations and a variational approach. IEEE Transactions on Image Processing, 14(10), 1570–1582.

  42. Starck, J. L., Moudden, Y., Bobina, J., Elad, M., Donoho, D. (2005b). Morphological component analysis. In Proceedings of SPIE Wavelets.

  43. Tian, S., Pan, Y., Huang, C., Lu, S., Yu, K., & Lim Tan, C. (2015). Text flow: A unified text detection system in natural scene images. In Proceedings of the IEEE international conference on computer vision (pp. 4651–4659).

  44. Uysal, F., Selesnick, I., & Isom, B. (2016). Mitigation of wind turbine clutter for weather radar by signal separation. IEEE Transactions on Geoscience and Remote Sensing, 54(5), 2925–2934.

  45. Wang, Y., Yin, W., & Zeng, J. (2019). Global convergence of ADMM in nonconvex nonsmooth optimization. Journal of Scientific Computing, 78(1), 29–63. https://doi.org/10.1007/s10915-018-0757-z.

  46. Wang, Z., Ling, Q., & Huang, T. (2016). Learning deep L0 encoders. In AAAI.

  47. Wisdom, S., Powers, T., Pitton, J., & Atlas, L. (2017). Deep recurrent NMF for speech separation by unfolding iterative thresholding. In IEEE workshop on applications of signal processing to audio and acoustics (WASPAA) (pp. 254–258).

  48. Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-MNIST: A novel image dataset for benchmarking machine learning algorithms. CoRR arXiv:1708.07747.

  49. Yan, C., Xie, H., Liu, S., Yin, J., Zhang, Y., & Dai, Q. (2018). Effective Uyghur language text detection in complex background images for traffic prompt identification. IEEE Transactions on Intelligent Transportation Systems, 19(1), 220–229.

  50. Yang, Y., Sun, J., Li, H., & Xu, Z. (2016). Deep ADMM-net for compressive sensing MRI. In NIPS.

  51. Zhou, J., Di, K., Du, J., Peng, X., Yang, H., Pan, S.J., Tsang, I. W., Liu, Y., Qin, Z., & Goh, R. (2018). SC2Net: Sparse LSTMs for sparse coding. In AAAI.

Author information

Corresponding author

Correspondence to Benjamin Cowen.

Additional information

Editors: Karsten Borgwardt, Po-Ling Loh, Evimaria Terzi, Antti Ukkonen.

Electronic supplementary material

Supplementary material 1 (pdf 4165 KB)

About this article

Cite this article

Cowen, B., Saridena, A.N. & Choromanska, A. LSALSA: accelerated source separation via learned sparse coding. Mach Learn 108, 1307–1327 (2019). https://doi.org/10.1007/s10994-019-05812-3

Keywords

  • Sparse coding
  • Morphological component analysis
  • Deep learning