
1 Introduction

Reinforcement learning (RL) embodies a learning paradigm inspired by biological systems and demonstrates significant potential in various domains such as computer games [1, 2], robotics [3, 4], and autonomous driving [5,6,7]. However, RL often requires extensive data gathering, making distributed training techniques like synchronous and asynchronous stochastic gradient descent (SGD) essential.

Synchronous SGD, as used in OpenAI’s synchronous advantage actor-critic (A2C), requires the server to wait for gradient updates from all workers before updating the model. While this enables near-linear scaling, delays caused by slower workers (stragglers) can undermine the intended acceleration [8]. Asynchronous SGD methods such as DeepMind’s asynchronous advantage actor-critic (A3C) mitigate this issue by allowing more frequent model updates without synchronization [9, 10]. However, asynchronous training risks applying gradients computed from outdated models, which can destabilize training [11]. Mitigating gradient staleness is thus a critical challenge. Dutta et al. [12] proposed an adaptive algorithm that balances straggling and staleness by regulating the number of synchronous workers. Chen et al. [11] suggested using backup workers to handle stragglers and prevent staleness. However, these methods merely strike a delicate balance between gradient straggling and gradient staleness, without addressing the underlying systemic issues that cause these problems in distributed training.

To tackle this issue, this paper proposes a gradient correction algorithm designed specifically for asynchronous SGD in RL, with a primary focus on addressing the stale gradient issue in asynchronous parallelism. Our primary contributions can be summarized as follows:

  1. We propose a gradient correction algorithm that leverages second-order information within the worker process and incorporates the current parameters from both the worker and server processes. This yields a corrected gradient that is closer to the desired one, effectively harnessing parallel resources while ensuring model convergence and stability.

  2. We present an asynchronous training scheme that incorporates gradient correction within the generalized policy iteration framework. Simulation results on an autonomous vehicle show that this approach achieves significantly faster convergence and better policy performance than the naive asynchronous update scheme.

2 Preliminaries

Reinforcement Learning. RL involves an agent interacting with the environment to learn behaviors that maximize the objective. Given state \(x_t\in \mathcal {X}\) at time t, the agent takes action \(u_t\in \mathcal {U}\) based on its policy \(\pi _\theta :\mathcal {X}\rightarrow \mathcal {U}\) parameterized by \(\theta \). It then receives utility \(l_t\) and transitions to the next state \(x_{t+1}\). The primary goal of RL is to find the optimal policy \(\pi _{\theta }^*\) that minimizes expected accumulated utilities, i.e., \(\min J(\theta )=\mathbb {E} \sum _{t=0}^{T} l(x_t, \pi _\theta (x_t))\). The generalized policy iteration (GPI) framework is commonly employed to iteratively solve RL problems. Specifically, the policy evaluation step utilizes the current policy to forecast accumulated future utilities, while the policy improvement step updates the policy parameters \(\theta \) in order to minimize the objective J [7].
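For concreteness, the GPI loop can be sketched as follows (a minimal illustration under our own naming; evaluate_policy and improve_policy are hypothetical placeholders for the two steps, not routines from [7]):

```python
# Minimal sketch of the generalized policy iteration (GPI) loop.
# evaluate_policy and improve_policy are hypothetical placeholders for the
# policy evaluation and policy improvement steps described above.
def generalized_policy_iteration(theta, evaluate_policy, improve_policy, num_iterations=1000):
    for _ in range(num_iterations):
        value_estimate = evaluate_policy(theta)        # predict accumulated future utilities
        theta = improve_policy(theta, value_estimate)  # update theta to reduce J(theta)
    return theta
```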

Stochastic Gradient Descent Optimization. RL commonly employs stochastic gradient descent (SGD) optimization. At iteration k, given the stochastic gradient \(\nabla _\theta J(\theta ^k)\) obtained by utilizing a batch of randomly sampled data, the parameter updating rule is as follows:

$$\begin{aligned} \begin{array}{lll} \theta ^{k+1} = \theta ^{k} - \alpha \nabla _\theta J(\theta ^k), \end{array} \end{aligned}$$
(1)

where \(\alpha \) is the learning rate that determines the extent of parameter adjustments.
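A minimal illustration of update rule (1) is given below (our sketch; the numerical values are arbitrary):

```python
import numpy as np

def sgd_step(theta, grad, alpha):
    """Update rule (1): theta^{k+1} = theta^k - alpha * grad."""
    return theta - alpha * grad

# Example with arbitrary numbers: one update of a 3-dimensional parameter vector.
theta = np.array([0.5, -0.3, 1.0])
grad = np.array([0.2, -0.1, 0.4])   # stochastic gradient from a sampled batch
theta = sgd_step(theta, grad, alpha=0.01)
```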

Synchronous Training. In synchronous training, a parameter server aggregates gradients from all worker nodes. Once gradients from every worker have been accumulated, the server updates the policy parameters and communicates them back to the workers, ensuring that all workers hold consistent, up-to-date parameters. However, this synchronization introduces waiting delays, leading to the straggler issue: an update can only occur after the slowest worker has finished assembling its batch and computing its gradient.
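A minimal sketch of the server-side synchronous step described above (our illustration; averaging the workers' gradients is an assumption, since the aggregation rule is not specified here):

```python
import numpy as np

def synchronous_server_step(theta, worker_grads, alpha):
    """Wait until every worker's gradient is available, aggregate, then update.
    This call cannot proceed until the slowest worker (straggler) has finished."""
    mean_grad = np.mean(worker_grads, axis=0)   # aggregate gradients from all workers
    return theta - alpha * mean_grad            # single synchronized update
```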

Asynchronous Training. In the asynchronous approach, workers independently update shared policy parameters without waiting for synchronization, enabling immediate parameter updates once each worker finishes gradient computation. However, this naive asynchronous approach is vulnerable to the stale gradient issue, where workers compute gradients based on outdated networks. For instance, while one worker is still processing data, other workers may have already updated the network multiple times due to variations in their computing intervals.
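The naive asynchronous server step can be sketched as follows (our illustration; the version counters are added here only to expose how staleness arises and are not part of the original scheme):

```python
def naive_async_server_step(theta, server_version, grad, worker_version, alpha):
    """Apply a worker's gradient as soon as it arrives, without synchronization.
    server_version - worker_version counts how many updates occurred since the
    worker copied theta, i.e. how stale its gradient is."""
    staleness = server_version - worker_version
    theta = theta - alpha * grad                 # the stale gradient is applied as-is
    return theta, server_version + 1, staleness
```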

3 Asynchronous Training with Gradient Correction

In this section, we propose a gradient correction algorithm to mitigate the stale gradient issue that arises from the lack of synchronization during the asynchronous training process. For simplicity, we consider a scenario where only two workers (A and B) are involved in the asynchronous training process. However, it is worth noting that our analysis also applies to more general cases.

Fig. 1. Gradient correction process.

The gradient correction process is shown in Fig. 1, where purple ellipses outline the objective contours. We assume that both workers A and B have identical policy parameters \(\theta _w\), and worker A completes the gradient calculation first. However, when the server uses worker A’s gradient to update parameters to \(\theta _s\), worker B’s gradient becomes stale as it lags behind the policy parameters on the server. Specifically, at point \(\theta _s\), the desired gradient \(\nabla J(\theta _s)\) should resemble the dotted purple line, deviating from the stale gradient \(\nabla J(\theta _w)\) represented by the solid blue line.

To correct the stale gradient by aligning it with the desired gradient, we first apply a Taylor expansion of the objective around \(\theta _w\), yielding:

$$\begin{aligned} J(\theta _s) = J(\theta _w) + \nabla J(\theta _w)^\top (\theta _s-\theta _w) + \dfrac{1}{2} (\theta _s - \theta _w)^{\top }H(\theta _w)(\theta _s - \theta _w) + \mathcal {O}(\Vert \theta _s - \theta _w\Vert ^3). \end{aligned}$$
(2)

Differentiating both sides with respect to \(\theta _s\) and neglecting the higher-order terms, we obtain:

$$\begin{aligned} \begin{array}{lll} \nabla J(\theta _s) = \nabla J(\theta _w) + H(\theta _w)(\theta _s - \theta _w), \end{array} \end{aligned}$$
(3)

where H is the Hessian matrix, and the gradient correction can be defined as

$$\begin{aligned} \begin{array}{lll} \varDelta = H(\theta _w) (\theta _s - \theta _w). \end{array} \end{aligned}$$
(4)
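As a sanity check (a worked example we add here under the assumption of a purely quadratic objective, which is not a case discussed in the paper), take \(J(\theta ) = \frac{1}{2}\theta ^\top Q\theta \) with constant Hessian Q. Then

$$\begin{aligned} \nabla J(\theta _s) = Q\theta _s = Q\theta _w + Q(\theta _s - \theta _w) = \nabla J(\theta _w) + H(\theta _w)(\theta _s - \theta _w), \end{aligned}$$

so the correction (4) recovers the desired gradient exactly; for general objectives it is accurate up to the higher-order terms neglected in (2).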

Hence, the utilization of second-order information enables the correction of stale gradients caused by parameter discrepancies between the worker and server. Moreover, we seamlessly integrate this gradient correction mechanism into asynchronous training, where workers are required to transmit gradient, Hessian matrix, and local parameters to the server for computing the corrected gradient:

$$\begin{aligned} \begin{array}{lll} G_c = G_i + H_i(\theta _i)(\theta ^k - \theta _i), \end{array} \end{aligned}$$
(5)

where \(G_i, H_i, \theta _i\) denote the gradient, Hessian matrix, and policy parameters from worker i respectively, \(\theta ^k\) are the current parameters on the server, and \(G_c\) represents the corrected gradient. Figure 2 illustrates the proposed asynchronous training scheme with gradient correction within the GPI framework.
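To make (5) concrete, a minimal server-side sketch is given below (our illustration, not the authors' implementation; grad_i, theta_i, and theta_server are assumed to be one-dimensional NumPy arrays, and hessian_i the corresponding square matrix reported by worker i). In practice, the explicit Hessian could be replaced by a Hessian-vector product to reduce communication and computation.

```python
import numpy as np

def corrected_gradient(grad_i, hessian_i, theta_i, theta_server):
    """Eq. (5): G_c = G_i + H_i(theta_i) (theta^k - theta_i)."""
    return grad_i + hessian_i @ (theta_server - theta_i)

def async_server_step_with_correction(theta_server, grad_i, hessian_i, theta_i, alpha):
    """Apply one asynchronous update using worker i's corrected gradient (sketch)."""
    g_c = corrected_gradient(grad_i, hessian_i, theta_i, theta_server)
    return theta_server - alpha * g_c
```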

Fig. 2. Asynchronous training with gradient correction.

4 Simulation

In this section, the efficacy of our proposed asynchronous training scheme with gradient correction is evaluated in sinusoidal trajectory tracking tasks for an autonomous vehicle. The learning curves are depicted in Fig. 3, where the naive asynchronous scheme is labeled as “Async”, and our proposed approach is denoted as “Modify”.

Fig. 3. Learning curves.

Figure 3a illustrates the loss curves plotted against training iterations. For a fair comparison, we employ identical random seeds and hyperparameters for both training schemes. Both loss curves start from the same level; however, our scheme declines more rapidly and converges in only about 1,000 iterations, whereas the naive scheme reduces the loss more slowly and fails to converge within the maximum number of iterations set for the experiment. These findings indicate that our scheme converges more effectively and is not hampered by stale gradients.

Figure 3b compares the gradient accuracy of the two training schemes during training. In both schemes, the server always holds and updates the latest policy parameters, which serve as the ground truth for computing the desired gradients, whereas the stale gradient is the one returned by the workers. We therefore quantify gradient accuracy as the error between these two gradients. As shown in Fig. 3b, our proposed scheme initially exhibits a higher gradient error but rapidly reduces it, ultimately converging to zero. In contrast, the gradient error of the naive scheme declines relatively slowly over iterations. These results highlight the effectiveness of the proposed training scheme in mitigating stale gradients by incorporating second-order information, corroborating the observations in Fig. 3a.

Figure 3c presents the tracking performance on the sinusoidal curve. Using the policy parameters saved at the final iteration and starting from the same initialization, the policy trained with our scheme achieves precise trajectory tracking, closely following the reference trajectory. In contrast, the policy trained with the naive scheme learns only a rudimentary feedback control and tracks poorly.

5 Conclusion

This paper presents a gradient correction algorithm aimed at tackling the stale gradient issue in asynchronous RL training. By leveraging second-order information from the worker and the current parameters of both the worker and the server, the algorithm refines the stale gradient so that it closely aligns with the desired one. We further incorporate this correction mechanism into an asynchronous training scheme, offering a novel approach. Validation on sinusoidal trajectory tracking tasks for an autonomous vehicle demonstrates accelerated convergence and effective resolution of gradient staleness. Future work includes addressing the generalization of neural network function approximators when second-order information is utilized, as well as a more comprehensive investigation of the method for improving training speed and driving policy performance in autonomous vehicles.