Entropy-SGD is a first-order optimization method which has been used successfully to train deep neural networks. This algorithm, which was motivated by statistical physics, is now interpreted as gradient descent on a modified loss function. The modified, or relaxed, loss function is the solution of a viscous Hamilton–Jacobi partial differential equation (PDE). Experimental results on modern, high-dimensional neural networks demonstrate that the algorithm converges faster than the benchmark stochastic gradient descent (SGD). Well-established PDE regularity results allow us to analyze the geometry of the relaxed energy landscape, confirming empirical evidence. Stochastic homogenization theory allows us to better understand the convergence of the algorithm. A stochastic control interpretation is used to prove that a modified algorithm converges faster than SGD in expectation.
The empirical loss is a sample approximation of the expected loss, \(\mathbb{E}_{x \sim P} f(x)\), which cannot be computed since the data distribution \(P\) is unknown. The extent to which the empirical loss (or a regularized version thereof) approximates the expected loss relates to generalization, i.e., the value of the loss function on (“test” or “validation”) data drawn from \(P\) but not part of the training set \(D\).
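As a rough illustration of the sample approximation described above, the following Python sketch estimates an expected loss by averaging over random draws. The distribution (a standard normal standing in for the unknown \(P\)), the per-sample loss \(f(x) = x^2\), and all function names are assumptions made for illustration; they are chosen only because the expectation is then known to equal 1, and none of them come from the text.

```python
import random

def f(x):
    # Hypothetical per-sample loss; under a standard normal, E[f(x)] = 1.
    return x * x

def empirical_loss(samples):
    # Average of f over a finite sample: the sample approximation of E_{x~P} f(x).
    return sum(f(x) for x in samples) / len(samples)

def draw(n, rng):
    # Stand-in for the unknown distribution P (assumed standard normal here).
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

rng = random.Random(0)          # fixed seed so the run is reproducible
train = draw(100_000, rng)      # plays the role of the training set D
test = draw(100_000, rng)       # fresh draws from P, not part of D

train_loss = empirical_loss(train)
test_loss = empirical_loss(test)
print(f"empirical loss on D:      {train_loss:.4f}")
print(f"loss on held-out samples: {test_loss:.4f}")
```

In this toy setting nothing is fitted to \(D\), so the two averages differ only through sampling noise and both should lie close to 1. For a trained model, the gap between the two quantities is what the notion of generalization refers to.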
For example, the ImageNet dataset [38] has \(N = 1.25\) million RGB images of size \(224\times 224\) (i.e., \(d \approx 10^5\)) and \(K=1000\) distinct classes. A typical model, e.g., the Inception network [65] used for classification on this dataset, has about \(N = 10\) million parameters and is trained by running (7) for \(k \approx 10^5\) updates; this takes roughly 100 hours with 8 graphics processing units (GPUs).
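The orders of magnitude quoted above can be checked with a few lines of arithmetic. The sketch below assumes three color channels per RGB image and takes the rounded figures (100 hours, \(10^5\) updates) at face value, so the per-update time is only a rough average; the variable names are illustrative.

```python
# Input dimension of one 224 x 224 RGB image (three color channels assumed).
height, width, channels = 224, 224, 3
d = height * width * channels
print(d)  # 150528, i.e. on the order of 10^5

# Average wall-clock time per update, using the rounded figures from the text.
hours, updates = 100, 10**5
seconds_per_update = hours * 3600 / updates
print(seconds_per_update)  # 3.6
```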
AO is supported by a grant from the Simons Foundation (395980); PC and SS by ONR N000141712072, AFOSR FA95501510229 and ARO W911NF151056466731CS; SO by ONR N000141410683, N000141210838, N000141712162 and DOE DE-SC00183838. AO would like to thank the UCLA mathematics department, where this work was completed, for its hospitality.
