1 Introduction

Deep learning has achieved excellent generalization on a wide variety of tasks (Jumper et al., 2021; Fawzi et al., 2022; Ramesh et al., 2021; He et al., 2015; Vaswani et al., 2017). However, deepnets are also known to be vulnerable to adversarial attacks (Papernot et al., 2017; Kurakin et al., 2017; Schönherr et al., 2019), due to the non-smooth nature of their associated loss landscapes. Adversarial training (AT) (Goodfellow et al., 2015) is established as the most effective mechanism for defending against these attacks and consists of training on adversarially perturbed samples instead of on the original ones. Yet, AT is also more prone to overfitting (Madry et al., 2018; Schmidt et al., 2018; Rice et al., 2020; Dong et al., 2022; Yu et al., 2022; Stutz, 2022) and therefore results in lower generalization than standard training. The reasons behind this are still under-explored. Nevertheless, several classical and modern deep learning remedies for overfitting, e.g., regularization (Zhang et al., 2022; Miyato et al., 2015; Gan et al., 2020; Chen et al., 2021) and data augmentation (Rice et al., 2020; Wu et al., 2020; Gowal et al., 2020), have been proposed without a full understanding of the root causes.

Recently, Dong et al. (2022) show that overfitting is due to augmenting the training data with ‘hard’ to classify adversarial examples that incentivize a complicated adjustment of the decision boundaries. The authors argue that these ‘hard’ adversarial examples are generated from samples close to the decision boundary with noisy or inappropriate ground-truth one-hot encoded labels. They show that previously proposed techniques for label and weight smoothing (Pang et al., 2021; Huang et al., 2020; Chen et al., 2021) can alleviate this issue.

Fig. 1

a In AT, adversarial examples generated from misclassified samples are always placed maximally away from the decision boundary, making the model more prone to overfitting; b A3T adjusts the adversarial objective to generate adversarial examples close to the decision boundary and thus encourages learning smoother loss landscapes which results in better generalization

In this paper, we identify another category of ‘hard’ adversarial examples. Specifically, we show that adversarial examples generated from misclassified samples in every iteration of training are, counter-intuitively, always placed maximally away from the decision boundary, as illustrated in Fig. 1. Hence, they induce a large loss, leading to substantial perturbations in the local region of the misclassified point. This problem is further exacerbated in high-capacity models, as they have the flexibility to make local changes to a decision boundary. To address this problem, we propose Accuracy Aware Adversarial Training (A3T), which controls the generation of adversarial examples in a manner that is aware of the current predicted label of the original training sample. While Wang et al. (2020) and Ding et al. (2020) proposed formulations that generate adversarial examples differently for misclassified samples, their insights were empirically driven, and they neither identify nor directly address the root cause of poor generalization, i.e., that adversarial examples from misclassified samples are generated maximally away from the decision boundary. In contrast, A3T proposes a simpler fix that directly addresses this root cause, i.e., it places adversarial examples from misclassified samples closer to the decision boundary.

We show that A3T improves upon previously adopted techniques for mitigating overfitting in AT and achieves better generalization while maintaining robustness comparable to state-of-the-art models, both in toy experiments and in computer vision, natural language processing, and tabular applications. This demonstrates that we have identified a root cause of overfitting that is unaddressed in the literature.

2 Related work

2.1 Adversarial training (AT)

AT (Goodfellow et al., 2015) is a defense mechanism against adversarial attacks, which aim at fooling a trained supervised machine learning model by adding imperceptible perturbations to the inputs to cause misclassification. Modern-day deepnets are particularly vulnerable to such attacks. Due to the non-smooth nature of their loss landscapes, it is possible to find, for any input sample, a direction of steep gradient (adversarial direction) along which perturbations can lead to high loss and potentially a different prediction (Szegedy et al., 2014). To address this, in AT, the model is trained to correctly classify perturbed versions of the inputs. Formally, AT is formulated as a min-max program, searching for the best parameters \(\theta\) of a classifier \(f_{\theta }\) under the worst-case perturbation \(\delta\) applied to an input x (Madry et al., 2018), i.e.,

$$\begin{aligned} \min _{\theta } {\mathbb {E}}_{(x,y)\sim {\mathcal {D}}}\big [ \max _{\delta \in \Delta } \ell (f_{\theta }(x+\delta ), y ) \big ], \end{aligned}$$
(1)

where \(\ell\) is a loss function (e.g., the cross-entropy loss) and \(\Delta\) is a constrained set of imperceptible perturbations. The outer optimization program is classically solved using gradient descent. Popular approaches for solving the constrained inner maximization program include the Fast Gradient Sign Method (FGSM) (Szegedy et al., 2014; Goodfellow et al., 2015) and Projected Gradient Descent (PGD) (Madry et al., 2018). In FGSM, the perturbation \(\delta\) is computed via:

$$\begin{aligned} \delta = \alpha \cdot sgn(\nabla _{x} \ell ( f_{\theta }(x), y ) ), \end{aligned}$$
(2)

with \(\alpha\) being the step size and sgn the sign operator. PGD is a multi-step variant of FGSM. It starts from a randomly initialized perturbation in the feasible set \(\Delta\) and iteratively applies a gradient ascent update followed by a projection onto \(\Delta\):

$$\begin{aligned} \delta ^{k}=\Pi _{\Delta } \left( \delta ^{k-1} + \alpha \nabla _{x}\ell (f_\theta (x+\delta ^{k-1}),y) \right) , \forall k\in [1,K] . \end{aligned}$$
(3)

Here \(\Pi _{\Delta }\) is the projection operator into \(\Delta\) and K is the number of steps.
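For concreteness, a minimal PyTorch-style sketch of this K-step inner maximization might look as follows; the function name, hyperparameter defaults, and uniform random initialization are illustrative assumptions rather than details from the cited works (practical implementations often use the sign of the gradient as the ascent direction):

```python
import torch

def pgd_perturbation(model, loss_fn, x, y, eps=8/255, alpha=2/255, steps=10):
    """Approximate the inner maximization of Eq. (1) with K-step PGD (Eq. (3)):
    gradient ascent on the loss, followed by projection onto the L-infinity
    ball Delta of radius eps."""
    delta = torch.empty_like(x).uniform_(-eps, eps)   # random start inside Delta
    for _ in range(steps):
        delta.requires_grad_(True)
        loss = loss_fn(model(x + delta), y)
        grad = torch.autograd.grad(loss, delta)[0]
        with torch.no_grad():
            # Eq. (3) uses the raw gradient; many implementations use grad.sign().
            delta = (delta + alpha * grad).clamp(-eps, eps)   # ascent + projection
    return delta.detach()
```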

AT is only defined for correctly classified samples. However, adversarial examples are generated following the same procedure for both correctly classified and misclassified samples. We will show that this makes the model more prone to overfitting.

2.2 Overfitting in AT

Overfitting in AT was first identified by Madry et al. (2018). Since then, different variants of the standard AT objective [Eq. (1)] have been proposed to alleviate overfitting. These variants can be classified into: (1) Local Distributional Smoothness (LDS) (Miyato et al., 2019; Zhang et al., 2019; Gan et al., 2020), (2) Margin Maximization (MM) (Balaji et al., 2019; Cheng et al., 2022), (3) Label-Smoothing Regularization (LSR) (Cheng et al., 2022), and (4) Misclassification-Aware (MA) (Wang et al., 2020) methods. The objectives of notable methods from these families are presented in Table 1. However, all these attempts try to fix overfitting without identifying the root causes behind it. Naturally, creating a truly robust model requires a deeper understanding of the sources of its vulnerability. Recently, Dong et al. (2022) have shown that overfitting is caused by generating ‘hard’ to classify adversarial examples that are memorized by the network. The authors argue that these ‘hard’ adversarial examples are generated from samples close to the decision boundary and hence their associated one-hot encoded ground-truth labels are likely to be inaccurate or noisy. To fix this, the authors propose label smoothing of the adversarial examples. Yu et al. (2022) empirically show that small-loss data is responsible for overfitting and propose a minimum-loss-constrained AT procedure that increases the loss of small-loss data.

Local distributional smoothness (LDS) approaches extend the robust optimization problem in Eq. (1) with a regularization term that encourages the label distribution around each sample point to be locally smooth. This is achieved by minimizing the KL divergence between the label distribution of the original sample and that of the corresponding adversarially generated example (Miyato et al., 2019; Zhang et al., 2019). Intuitively, this forces the decision boundary to be sufficiently far away from the training samples (Miyato et al., 2019; Zhang et al., 2019). AT objectives from this family include TRADES (Zhang et al., 2019) and VILLA (Gan et al., 2020) (lines 2 and 3 in Table 1). A major drawback of this approach is that it equally reinforces correct and incorrect predictions, i.e., it encourages the adversarial examples from misclassified samples to have the same label as the original (misclassified) samples. This results in a drop in generalization.
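As a concrete illustration of the LDS mechanism, the following is a minimal TRADES-style sketch in PyTorch; the function name, the weight `beta`, and the assumption that the adversarial example `x_adv` has already been generated (e.g., by maximizing the same KL term) are ours, not details taken from Table 1:

```python
import torch.nn.functional as F

def lds_loss(model, x, x_adv, y, beta=6.0):
    """Clean cross-entropy plus a KL term that pulls the label distribution
    at x_adv towards the one at x (local distributional smoothness)."""
    logits_clean = model(x)
    logits_adv = model(x_adv)
    ce = F.cross_entropy(logits_clean, y)
    kl = F.kl_div(F.log_softmax(logits_adv, dim=1),
                  F.softmax(logits_clean, dim=1),
                  reduction="batchmean")
    return ce + beta * kl
```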

Margin maximization (MM) methods propose to individually tune the perturbation strength for each sample. Intuitively, adopting a fixed level of perturbation for all the samples can potentially result in pushing adversarial examples further away from the boundary, causing them to be mixed with samples of other classes. This would force the model to learn a non-smooth and complex decision boundary. AT objectives from this class include MMA (Ding et al., 2020), IAAT (Balaji et al., 2019), and CAT (Cheng et al., 2022) (lines 4–6 in Table 1). Note that MM methods only generate adversarial examples from correctly classified samples.

Label-smoothing regularization (LSR) methods apply a label-smoothing regularization. Intuitively, one cause of overfitting is the use of one-hot encoded ground-truth labels for adversarial examples (Dong et al., 2022), which essentially forces the model to assign an overconfident probability of one to all samples of a given class. This can be noisy or inappropriate for adversarial examples, as they are deliberately generated by pushing close-to-the-boundary samples over the boundary to the side of another class. To address this, Cheng et al. (2022) apply label smoothing with a strength that depends on the perturbation tolerance of each sample.
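For reference, generic label smoothing (not the adaptive, per-sample scheme of Cheng et al. (2022)) can be sketched as follows; the function name and smoothing factor `alpha` are illustrative:

```python
import torch
import torch.nn.functional as F

def smooth_labels(y, num_classes, alpha=0.1):
    """Soften one-hot targets: probability (1 - alpha) on the ground-truth
    class and alpha / num_classes spread uniformly over all classes."""
    one_hot = F.one_hot(y, num_classes).float()
    return (1 - alpha) * one_hot + alpha / num_classes
```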

Misclassification-aware (MA) approaches extend the standard training loss with a misclassification-aware regularization (line 7 in Table 1). Wang et al. (2020) empirically show that adversarial examples from misclassified samples result in a drop in robustness. However, this has the same drawback as LDS approaches. In this work, we provide a deeper understanding of the drawback of the current practice of generating adversarial examples from misclassified samples and propose a simpler fix to address it.

Table 1 Optimization objectives for AT methods and their main characteristics

3 Accuracy aware adversarial training (A3T)

Given a data set \({\mathcal {D}} = \{(x_i,y_i)\}\), we consider a K-class classification setup. The classifier is a \(\theta\)-parameterized score function \(f_{\theta }(x) = (f^{1}_{\theta }(x),\ldots ,f^{K}_{\theta }(x))\), where \(f^{i}_{\theta }(x)\) corresponds to the score of the i-th class. The predicted label of x is denoted as \({\hat{y}} = {{\,\mathrm{arg\,max}\,}}_{i}f^{i}_{\theta }(x)\). The classic AT problem is formulated as a min-max optimization program (Madry et al., 2018):

$$\begin{aligned} \min _{\theta }\sum \limits _{(x,y) \in {\mathcal {D}}}\big [ \max _{\delta \in \Delta } \ell (f_{\theta }(x+ \delta ), y ) \big ], \end{aligned}$$
(4)

where \(\ell\) is the loss function used for training the classifier, and \(\Delta\) is typically an \(L_{\infty }\) norm ball with radius bounded by \(\epsilon\). We denote by \(\delta ^{*}\) the optimal perturbation solving the inner maximization program, i.e.,

$$\begin{aligned} \delta ^{*} = \mathop {\mathrm{arg~max}}\limits _{\delta \in \Delta } \ell (f_{\theta }(x+\delta ), y ). \end{aligned}$$

Classically, PGD or its variants (Madry et al., 2018) are used to solve the program above. In the case of correctly classified training samples, the optimal perturbation results in adversarial examples that are closer to the decision boundary than the original samples. However, in the case of misclassified samples, the same optimal perturbation can result in adversarial examples that are further away from the decision boundary than the original samples. For high-capacity models, this can result in a highly non-smooth boundary with poor generalization, as shown in Fig. 1.

A3T loss: to prevent overfitting, we propose A3T, an accuracy-aware AT loss, which aims at generating close-to-the-decision-boundary adversarial examples. To achieve this, A3T generates adversarial examples differently for correctly classified and misclassified samples, i.e., it entails different inner maximization objectives depending on the sample's prediction accuracy:

$$\begin{aligned} \min _{\theta } \left[ \sum _{(x,y)\in {\mathcal {D}_{\theta }^+}} \ell \left( f_{\theta }\left( x+\mathop {\mathrm{arg~max}}\limits _{\delta \in \Delta } \ell \left( f_{\theta }\left( x+\delta \right) , y \right) \right) ,y \right) + \sum _{(x,y)\in {\mathcal {D}_{\theta }^-}}\ell \left( f_{\theta }\left( x+\mathop {\mathrm{arg~max}}\limits _{\delta \in \Delta } \ell \left( f_{\theta }\left( x+\delta \right) , {\hat{y}} \right) \right) ,y \right) \right] . \end{aligned}$$
(5)

Here, \({\mathcal {D}_{\theta }^+}\) and \({\mathcal {D}_{\theta }^-}\) are the set of correctly and misclassified samples at a given training iteration, respectively. Note that the loss terms of the two inner maximizations have different arguments, i.e., y vs \({\hat{y}}\). In particular, for the misclassified samples, the loss uses the predicted label \({\hat{y}}\) as its second argument. The optimal perturbation \(\delta ^{*}\) can be computed using PGD as follows:

$$\begin{aligned} \delta ^* = {\left\{ \begin{array}{ll} \Pi _{\Delta } \left( \delta + \alpha \nabla _{x}\ell (f_\theta (x+\delta ),y) \right) &{} \text {if } (x,y)\in {\mathcal {D}_{\theta }^+}\\ \Pi _{\Delta } \left( \delta + \alpha \nabla _{x}\ell (f_\theta (x+\delta ),{\hat{y}}) \right) &{} \text {otherwise} \end{array}\right. }. \end{aligned}$$
(6)

Since \({\hat{y}}=y\) holds for samples in \({\mathcal {D}_{\theta }^+}\), Eq. (6) above can be simplified to

$$\begin{aligned} \delta ^* = \Pi _{\Delta } \left( \delta + \alpha \nabla _{x}\ell \left( f_\theta (x+\delta ),{\hat{y}}\right) \right) . \end{aligned}$$
(7)

Combining Eqs. (5) and (7) yields the following compact form of the training objective in Eq. (5):

$$\begin{aligned} \min _{\theta } \left[ \sum _{(x,y)\in {\mathcal {D}}} \ell \left( f_{\theta }(x+ \delta ^*),y \right) \right] . \end{aligned}$$
(8)
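A minimal PyTorch-style sketch of this perturbation step is given below; the function name, uniform initialization, and defaults are our assumptions, and the update follows the raw-gradient form of Eqs. (6)–(7):

```python
import torch

def a3t_perturbation(model, loss_fn, x, eps=8/255, alpha=2/255, steps=10):
    """Compute delta* as in Eq. (7): run PGD against the *predicted* label
    y_hat, which coincides with y for correctly classified samples and, for
    misclassified ones, pulls the example towards the decision boundary."""
    with torch.no_grad():
        y_hat = model(x).argmax(dim=1)                 # current prediction
    delta = torch.empty_like(x).uniform_(-eps, eps)    # random start inside Delta
    for _ in range(steps):
        delta.requires_grad_(True)
        loss = loss_fn(model(x + delta), y_hat)
        grad = torch.autograd.grad(loss, delta)[0]
        with torch.no_grad():
            delta = (delta + alpha * grad).clamp(-eps, eps)  # ascent + projection
    return delta.detach()
```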

Theorem 1

Let \(f_{\theta }(x)=\theta ^Tx + b\) be a linear model trained with a logistic loss \(\ell\). Assume that \((x_i,y_i)\) is a misclassified training sample and that \(\delta _1 = {{\,\mathrm{arg\,max}\,}}_{\delta \in \Delta }\ell (f_{\theta }(x_i + \delta ), y_i)\) and \(\delta _2 = {{\,\mathrm{arg\,max}\,}}_{\delta \in \Delta }\ell (f_{\theta }(x_i + \delta ), (1 - 2y_i))\) are the solutions to the inner maximization program using standard AT [Eq. (4)] and A3T [Eq. (5)], respectively. Then

$$\begin{aligned} \left|\theta ^T({\textbf{x}}_i + \varvec{\delta }_{2}) + b\right| \le \left|\theta ^T({\textbf{x}}_i + \varvec{\delta }_{1}) + b\right|. \end{aligned}$$

Proof

See Appendix. \(\square\)

Intuitively, Theorem 1 states that adversarial examples from misclassified samples generated using A3T are closer to the decision boundary than those generated using AT. To provide better intuition, we next compare our A3T approach to standard AT on two toy examples.

Example 1

Consider a 2-D binary linear classification setup, where the decision boundary is given by \(x_1 + x_2 = 1\) and labels are \(y\in \{-1,1\}\). At convergence, the linear classifier weights are therefore \((\theta _1,\theta _2) = (1,1)\). The optimal adversarial perturbation is \(\varvec{\delta } = -y\epsilon {\text {sgn}}(\theta )\). Suppose a data point \((x_1,x_2)\) with label +1 is misclassified as -1. Then the adversarially perturbed point based on the original label is \((x_1 - \epsilon , x_2 - \epsilon )\), while the one based on our approach is \((x_1 + \epsilon ,x_2 + \epsilon )\). Hence, close to convergence, our approach pushes the adversarial example corresponding to \((x_1,x_2)\) towards the boundary, while the classical approach, counter-intuitively, places it maximally away from the boundary.
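The geometry of this example can be verified numerically; the short script below (the specific point and \(\epsilon\) are chosen purely for illustration) measures the distance to the boundary \(x_1 + x_2 = 1\) before and after both perturbations:

```python
import numpy as np

theta, b = np.array([1.0, 1.0]), -1.0      # decision boundary: x1 + x2 - 1 = 0
x = np.array([0.3, 0.4])                   # true label +1, but theta @ x + b < 0, so misclassified
eps = 0.1

dist = lambda p: abs(theta @ p + b) / np.linalg.norm(theta)

x_at = x - eps * np.sign(theta)            # AT: perturb w.r.t. the true label y = +1
x_a3t = x + eps * np.sign(theta)           # A3T: perturb w.r.t. the predicted label -1

print(dist(x), dist(x_at), dist(x_a3t))    # A3T's example is the closest to the boundary
```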

Example 2

Consider the linear classification setup presented in Fig. 2. Note that the A3T decision boundary (green) is closer to the standard boundary (black) than the AT one (red). This is because the adversarial examples generated by A3T (\(A_{A3T}\), \(B_{A3T}\)) from the misclassified samples A and B are placed closer to the boundary than the ones generated by AT (\(A_{AT}\), \(B_{AT}\)). This also implies that A3T's performance on clean test samples is closer to that of the original classifier.

Fig. 2

Adversarial training of a linear classifier using standard AT [Eq. (3)] and A3T [Eq. (7)] on a two-dimensional synthetic dataset. The black, red, and green lines represent the decision boundaries of the model trained with standard training, AT, and A3T, respectively. Note that samples A and B are misclassified by all the models, and the corresponding adversarial examples generated by A3T (\(A_{A3T}\), \(B_{A3T}\)) are closer to the decision boundary than the ones generated by AT (\(A_{AT}\), \(B_{AT}\)). Also, the A3T boundary is closer to the original boundary than the AT one. This implies less sensitivity to adversarial examples and hence more robustness. We omit correctly classified training samples for better visualization

Our A3T approach is summarized in Algorithm 1. We also visualize one training iteration in Fig. 3.
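A rough sketch of one such iteration, reusing the `a3t_perturbation` helper from the sketch above, is given below; the hyperparameter defaults and the eval/train toggling are placeholders rather than the paper's exact settings:

```python
import torch

def a3t_training_step(model, loss_fn, optimizer, x, y,
                      eps=8/255, alpha=2/255, steps=10):
    """One A3T iteration: (1) take the current prediction inside
    a3t_perturbation, (2) compute delta* as in Eq. (7), and
    (3) update the parameters with the true labels as in Eq. (8)."""
    model.eval()
    delta = a3t_perturbation(model, loss_fn, x, eps=eps, alpha=alpha, steps=steps)
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x + delta), y)   # outer minimization uses the true label y
    loss.backward()
    optimizer.step()
    return loss.item()
```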

Fig. 3

Visualization of a training iteration of our A3T algorithm. A training iteration consists of three stages: (1) predicting the label of the input, (2) computing the optimal perturbation for the sample, and (3) updating the model parameters based on the A3T loss

Algorithm 1 Accuracy aware adversarial training (A3T)

Since A3T addresses a root cause of overfitting that is orthogonal to the ones mitigated by LSR and MM techniques, we propose A3T\(^+\) (line 8 in Table 1 and Algorithm 2), which combines the A3T, LSR, and MM objectives.

3.1 Improvement over prior work

So far, two misclassification-aware AT approaches have been proposed (Wang et al., 2020; Ding et al., 2020). Wang et al. (2020) were the first to explicitly recognize that misclassified and correctly classified samples should be treated differently and proposed a misclassification-aware AT method (MART). Their training objective consists of two separate loss terms (line 7 in Table 1). The first term corresponds to the adversarial loss used in conventional AT and does not distinguish between correctly classified and misclassified samples. The second term is a KL-divergence loss weighted by the confidence in classification, measured by \((1-f_{\theta }(x))\), as proposed by Miyato et al. (2019). Intuitively, this reduces the impact of correctly classified samples and encourages label smoothness around misclassified samples. In the context of Fig. 1, MART populates the space around the misclassified samples with new samples of both the true class and the incorrectly predicted one. The true-class samples will be located in a direction that is maximally away from the decision boundary and will serve as ‘hard’ examples for the model. In contrast, the regularization-based label smoothing will add samples of the incorrect class isotropically. In effect, the latter group of samples neutralizes the impact of the ‘hard’ examples created by the conventional AT method and prevents the model from exhibiting even worse overfitting behavior.

Ding et al. (2020) use margin maximization as an implicit treatment of misclassified samples. The perturbation margin \(\epsilon _{i}\) of each sample \((x_i,y_i)\) is individually determined by gradually increasing the perturbation strength until an adversarial example is generated. Since misclassified samples are adversarial in nature, they yield the lowest possible margin. Thus, adversarial examples created for misclassified samples are not significantly perturbed.

Our approach, A3T, fundamentally differs from the above methods in that the adversarial examples created include ‘easy’ samples that are placed towards the decision boundary, as opposed to creating no samples or samples that are further away from the boundary. These ‘easy’ samples essentially serve to correct for any potential deformation in that part of the decision boundary. In other words, we attribute the cause of a misclassified example to the presence of opposing class examples in the immediate vicinity of that example, preventing the model from successfully fitting them. By creating samples that yield lower loss, our method reduces the effect of misclassified samples on the resulting decision boundary.

4 Experiments

We evaluate A3T on 12 datasets from 4 domains and compare its performance against 6 baselines from 4 families of AT methods: (1) LDS, (2) MM, (3) LSR, and (4) MA.

4.1 Results on synthetic data

First, we evaluate A3T on a synthetic dataset for binary classification introduced by Gan et al. (2020). The dataset consists of 1,016 samples from two classes in a two-dimensional space over real numbers. The samples of each class lie on a trajectory following the shape of a crescent moon, as shown in Fig. 4. Each data point is then projected onto a 100-dimensional vector space.

We train a deepnet with 100 neurons using 16 randomly selected training samples per class. With such a small training sample size, the model is prone to overfitting, thereby allowing a better evaluation of its generalization capability. The model is trained for 400 epochs with a learning rate (\(\tau\)) of 0.01 using the Adam optimizer. For the first 100 epochs, we use standard training, and for the remaining epochs we switch to AT. Adversarial examples are generated using a 5-step PGD attack [Eq. (3)]. The perturbation is randomly initialized from \({\mathcal {N}}(0,0.05)\) and updated with a step size \(\alpha =0.1\) while limiting \(\epsilon\) to 0.4.
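Under these settings, the training loop can be sketched roughly as follows; the architecture details and the data loader are assumptions (the text only specifies a 100-neuron network over 100-dimensional inputs), and the earlier `a3t_training_step` helper uses a uniform rather than Gaussian perturbation initialization:

```python
import torch
import torch.nn.functional as F

# Assumed architecture: one hidden layer of 100 neurons over 100-d inputs.
model = torch.nn.Sequential(torch.nn.Linear(100, 100), torch.nn.ReLU(),
                            torch.nn.Linear(100, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(400):
    for x, y in train_loader:                    # 16 samples per class (assumed loader)
        if epoch < 100:                          # standard training phase
            optimizer.zero_grad()
            F.cross_entropy(model(x), y).backward()
            optimizer.step()
        else:                                    # adversarial phase: 5-step PGD, eps = 0.4
            a3t_training_step(model, F.cross_entropy, optimizer, x, y,
                              eps=0.4, alpha=0.1, steps=5)
```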

Overall, 20 models are generated for each training approach using the same set of training samples and fixed parameter values. Following the standard training phase (first 100 epochs), all models were found to correctly classify all but one (out of 32) training samples, generally exhibiting a good fitting capability. The decision boundaries are then drawn based on the positions of the data points classified with the least confidence, i.e., by averaging the locations of all samples across all runs that were classified with a confidence in the 0.49–0.51 probability range. Figure 4 shows, for both AT and A3T, the decision boundaries obtained after the first 100 epochs (gray) and at the end of the training (black). All adversarial examples created beyond epoch 100 are also displayed as orange and turquoise dots in the figure.

Fig. 4

The orange and turquoise dots correspond to adversarial examples generated from the red- and blue-class samples, respectively. The original samples are shown as \(\blacktriangle\) and \(\bullet\). The gray decision boundary is obtained by standard training (end of epoch 100) and the black boundary corresponds to the adversarially trained one (end of epoch 400). A3T pushes adversarial examples of misclassified samples (orange) closer to the decision boundary

The difference between standard AT and A3T can best be seen around the misclassified samples of the red class, at the upper end of the crescent. In the case of A3T, the adversarial examples, i.e., the orange points around the misclassified samples, are generated between the misclassified samples and the decision boundary, whereas for AT, adversarial examples are created further away from the boundary. For AT, this further results in intermixing of red-class adversarial examples with blue-class data points, which may force the model to learn a highly non-smooth decision boundary. Otherwise, it can be seen that the final decision boundaries (black) for AT (Fig. 4a) and A3T (Fig. 4b) are very similar, and therefore the robustness behavior may be expected to be on par with each other.

To further examine the generalization behavior, we repeated the same experiment by randomly selecting the training points for each run and averaging results over 50 runs. Figure 5 shows all training points used as well as the corresponding decision boundaries. Note that training samples remain the same for both AT and A3T within a run, but they change across runs. Here, the generalization capability of A3T is more visible as it is able to correctly classify most of the blue-class points. In the case of red-class data points, both methods perform similarly. These observations are also reflected in the measured classification accuracies: A3T achieved 83.6% accuracy, i.e., 1.2% better than AT.

Fig. 5

Decision boundaries obtained after 50 runs when training samples are selected at random for each run. The improved generalization capability of A3T is reflected in the measured accuracy

4.2 Results on real data

We now assess the robustness-generalization trade-off yielded by A3T on computer vision and natural language processing tasks as well as on tabular data. Models obtained by standard training and AT are evaluated under attack-free and under-attack setups. The models are attacked using adversarial examples produced by PGD (Madry et al., 2018) and AutoAttack (Croce et al., 2021).

4.2.1 Computer vision experiments

Most AT methods have been proposed to improve the robustness of image classifiers. The effectiveness of these methods has been evaluated on the CIFAR-10 dataset using architectures such as WideResNet-34-10 and WideResNet-28-10 with varying attack and training parameters. In our evaluations, we consider a similar setting.

As a baseline, we use MART, since it is the only other method that explicitly formulates a treatment for misclassified samples during AT. In their paper, the authors of MART use a boosted cross-entropy loss and report a significant improvement in accuracy (\(\sim\) 3%) compared to the standard cross-entropy loss.Footnote 1 Since the boosted loss can be generically incorporated into all AT approaches, we only use the cross-entropy loss to ensure a comparison on an equal footing.

In line with previous methods, we use WideResNet-34-10 and train all models for a total of 100 epochs using SGD with momentum 0.9, weight decay \(2\times 10^{-4}\), and an initial learning rate of 0.1. We decay the learning rate by 90% at the 75th and 90th epochs.
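Under these stated settings, the optimization setup could be configured roughly as follows (a PyTorch-style sketch; `model` stands for any WideResNet-34-10 implementation and `train_loader` for a CIFAR-10 loader, both assumed):

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=2e-4)
# Decaying the learning rate by 90% means multiplying it by 0.1 at epochs 75 and 90.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[75, 90], gamma=0.1)

for epoch in range(100):
    for x, y in train_loader:
        a3t_training_step(model, F.cross_entropy, optimizer, x, y)
    scheduler.step()
```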

Table 2 provides accuracy results for the standard and under-attack settings (FGSM, PGD). A3T outperforms MART by 4.7% on natural accuracy. This implies better generalization, i.e., less overfitting. On robust accuracy, A3T outperforms MART in 3 out of 4 experiments. When compared to conventional AT, A3T improves both natural and robust accuracy by around 2%.

Table 2 Comparison of misclassification-aware AT methods on the CIFAR-10 dataset using the WideResNet-34-10 model

Loss landscape In Fig. 6, we visualize the loss landscapes of these three approaches following the method proposed by Engstrom et al. (2018). In particular, we are interested in investigating the effect of the A3T loss on the smoothness of the loss landscape. The loss landscapes are generated by computing the loss around each test data point after adding two perturbations defined by the sign of the input gradient and a random Rademacher matrix (\({\mathfrak {R}}\)). The x- and y-axes represent the perturbation magnitudes, \(\lambda _r\) and \(\lambda _g\), and the z-axis represents the loss values calculated as:

$$\begin{aligned} \ell \left( f_\theta (x + \gamma ),y \right) \text{, } \text{ where } \gamma = \lambda _r{\mathfrak {R}} + \lambda _g sgn\left( \nabla _{x}\ell \left( f_\theta (x),y\right) \right) . \end{aligned}$$

Here, \(\gamma\) is the total perturbation added to the input. The figure shows that, while all the AT losses are robust against random perturbations as expected, the A3T loss is the least affected by perturbations in the gradient direction.
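A minimal sketch of this visualization procedure for a single test point is given below; the function name, grid range, and use of a single input are our assumptions:

```python
import torch

def loss_landscape(model, loss_fn, x, y, grid=None):
    """Loss surface around an input x: one axis follows the sign of the input
    gradient, the other a random Rademacher direction (Engstrom et al., 2018)."""
    grid = torch.linspace(-0.03, 0.03, 21) if grid is None else grid
    x = x.clone().requires_grad_(True)
    grad = torch.autograd.grad(loss_fn(model(x), y), x)[0]
    g_dir = grad.sign()                              # gradient-sign direction
    r_dir = torch.randint_like(x, 0, 2) * 2 - 1      # Rademacher entries in {-1, +1}
    surface = torch.zeros(len(grid), len(grid))
    with torch.no_grad():
        for i, lam_r in enumerate(grid):
            for j, lam_g in enumerate(grid):
                gamma = lam_r * r_dir + lam_g * g_dir   # total perturbation
                surface[i, j] = loss_fn(model(x + gamma), y)
    return surface
```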

Fig. 6

Loss landscapes of AT, MART, and A3T on the CIFAR-10 dataset. The A3T loss generates the smoothest landscape, indicating better robustness

Next, we compare A3T and A3T\(^+\) with several AT methods. Here, we want to highlight that the results reported by these methods include some inconsistencies, making a fair comparison challenging. This includes the learning-rate schedule (a larger number of training epochs is expected to yield lower final accuracy values) and the PGD step size used during attacks (larger steps imply a stronger attack). To circumvent these ambiguities, we decided to use the robustness results reported by the AutoAttack (AA) benchmarkFootnote 2 and evaluated A3T and A3T\(^+\) accordingly. We use the same hyperparameters as in the previous experiments for A3T and A3T\(^+\). Table 3 provides the corresponding results for several AT methods.

Table 3 The clean and AutoAttack (AA) (Croce et al., 2021) accuracy values of adversarially trained WideResNet models

A3T and A3T\(^+\) yield the highest natural accuracy after the Bilateral AT method (Wang & Zhang, 2019), which exhibits very limited robustness under the AA attack. In terms of robust accuracy, CAT performs best, only marginally better than A3T\(^+\) (\(+0.2\)%) and noticeably better than A3T (\(+3.7\)%). However, both A3T and A3T\(^+\) yield a higher natural accuracy than CAT (by 1.6–2.8%). Overall, these results present a more granular view of the generalization vs. robustness trade-off: MM and LSR shift the balance in favor of robustness, while MA and LDS tilt it more towards improved generalization. When the average accuracy is evaluated, A3T\(^+\) offers a better trade-off between natural and AA accuracy, as it allows an increase in the former while keeping the robust accuracy on par with other methods. This finding further strengthens the idea that MM, LSR, and MA methods address different root causes of overfitting and are indeed complementary.

4.2.2 Natural language processing experiments

For evaluation, we used the GLUE benchmark (Wang et al., 2019).Footnote 3 GLUE contains seven datasets covering two sentiment analysis tasks, two similarity tasks, and three inference tasks. Due to the discrete nature of text, generating adversarial examples in the input space is nontrivial, as it involves projecting the continuous perturbation computed in the embedding space back to the input space. Although several input-domain approaches have been proposed to create adversarial examples, they are not efficient (Altinisik et al., 2022). An alternative approach is to perform AT in the latent space without creating adversarial inputs (Gan et al., 2020). In our experiments, we adopt this approach both during AT and when performing adversarial attacks. It must be noted that, for launching an adversarial attack, this approach is impractical to implement, but it nevertheless corresponds to a worst-case attack setting and allows a better evaluation of the model's robustness.

Implementation In our experiments, we fine-tuned the deBERTa-base model (He et al., 2021) from the HuggingFace library. For each dataset in the GLUE benchmark, we created a fine-tuned, task-specific model by first training the model for three epochs at the suggested learning rate of \(2\textrm{e}^{-5}\) using the AdamW optimizer (He et al., 2021). Then, we freeze the first five layers and fine-tune for another three epochs with either standard training or AT. In all cases, adversarial examples are generated using 3-step PGD with \(\epsilon = 0.01\), and the initial perturbation is sampled from a normal distribution \({\mathcal {N}}(0,0.005)\).
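A rough sketch of this embedding-space perturbation with the HuggingFace API is shown below; the model identifier, function name, and step size are assumptions, the Gaussian scale is taken as a standard deviation, and A3T would additionally substitute the predicted label for misclassified samples as in Eq. (7):

```python
import torch
from transformers import AutoModelForSequenceClassification

# Assumed checkpoint; in practice this would be the task-specific fine-tuned model.
model = AutoModelForSequenceClassification.from_pretrained("microsoft/deberta-base")

def embedding_space_perturbation(model, input_ids, attention_mask, labels,
                                 eps=0.01, alpha=0.005, steps=3):
    """Latent-space AT: perturb the token embeddings instead of the discrete
    tokens and feed them to the model via `inputs_embeds`."""
    embeds = model.get_input_embeddings()(input_ids).detach()
    delta = torch.randn_like(embeds) * 0.005          # init ~ N(0, 0.005)
    for _ in range(steps):
        delta.requires_grad_(True)
        out = model(inputs_embeds=embeds + delta,
                    attention_mask=attention_mask, labels=labels)
        grad = torch.autograd.grad(out.loss, delta)[0]
        with torch.no_grad():
            delta = (delta + alpha * grad).clamp(-eps, eps)
    return embeds + delta.detach()
```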

Table 4 AT results on GLUE Benchmark

Table 4 reports the classification performance yielded by the various models under both normal and adversarial attack scenarios. Compared to the baseline accuracy of the model, i.e., the attack-free test setting, A3T yields a small drop in average performance (\(-2.5\)%) but performs noticeably better (\(+3.6\)%) than AT, as displayed in the last line of the table. Similarly, under the adversarial attack test setting, A3T yields an improvement of 1.5% over conventional AT. The improvement due to A3T is more noticeable when the results of the CoLA task are removed from the average, as it uses a different metric than accuracy (third-to-last line). In that case, A3T additionally results in a performance improvement of 2.8% over AT.

4.2.3 Tabular tasks

We also evaluate our approach on tabular data. Specifically, we consider two classification tasks in finance: (1) predicting the overall yearly default rate for subjects in Matlab's Retail Credit Panel dataset (Matlab, 2022) and (2) a binary classification task on two other datasets.

To predict the percentage of subjects that defaulted in a given year, a neural network with two hidden layers of 256 and 128 neurons and a logistic output is trained for 500 epochs with the Adam optimizer using a learning rate of 0.001. AT is only applied between epochs 100 and 500 using the attack parameters obtained from a grid search.Footnote 4 The average result over the grid-search configurations is used for comparisons.

Fig. 7

Yearly default rates predicted by robust and standard models averaged over 50 runs for each model. MSE values are computed between model predictions and realized, ground-truth rates. Computed values are normalized by \(\textrm{e}^{-5}\) for ease of viewing

Since the input data is very low-dimensional, no adversarial attack is performed. Instead, we investigate how AT impacts the model's performance in the attack-free setting. Figure 7 shows the average default rates predicted by the different models over 50 runs in comparison to the ground-truth rates. The results show that A3T's predictions follow the realized default rates more closely than those of the other approaches. To better evaluate the fit of each model, the mean squared error (MSE) of the predictions with respect to the actual rates is computed across all years. A3T yields the lowest MSE among all approaches, even outperforming the non-robust model.

For the binary classification task, we evaluate A3T on two datasets, namely the European card dataset (ECD) (Pozzolo et al., 2015) and the Adult dataset (Kohavi & Becker, 1996). The former contains 492 fraudulent and 10K genuine transactions with 31 features, randomly downsampled from more than 280K transactions; the goal is to identify the type of transaction. The latter involves 32K samples with 9 features, and the objective is to predict whether the income of a subject is higher than $50K.

All categorical values are represented by one-hot encodings. Models with the same architecture as above are trained. Rather than using a fixed AT setting, multiple robust models are trained with the same grid-search parameters used for Matlab's Retail Credit Panel dataset. Each model is then tested under 12 adversarial attack scenarios parameterized by the same parameter configurations considered in the grid search, and the average of the resulting accuracy values is taken as the model's under-attack prediction accuracy. The standard and under-attack accuracies achievable by a model are finally determined by averaging the corresponding values over five runs (see Fig. 8).

Fig. 8

Model accuracy for the normal (left) and under-attack (right) settings obtained by averaging over 5 runs. For both AT and A3T, 12 models are created by varying the parameter values that govern the adversarial example generation process. A3T yields the most favorable performance, exhibiting high accuracy in both the attack-free and under-attack scenarios

In Fig. 8, the upper right corner corresponds to a performance regime where a model performs well both in attack-free and under-attack settings. Thus, the best performing model can be identified based on how close it gets to that corner. Notice that several robust versions of A3T exhibit high under-attack performance with only a slight drop in performance in the attack-free setting. This result also demonstrates the importance of the choice of AT parameters on model accuracy.

5 Discussion and conclusions

AT uses adversarial examples to train machine learning models in order to achieve better robustness. However, improving robustness via AT comes at the cost of generalization. In this work, we propose a new AT method that yields a more favorable robustness-generalization trade-off. Next, we summarize the main insights from our paper.

Generation of non-adversarial samples The underlying idea of our misclassification-aware adversarial training approach (A3T) is to prevent the creation of ‘hard’ examples from misclassified samples, thereby reducing the risk of overfitting. A3T effectively achieves this by generating samples that are non-adversarial in nature. That is, instead of applying a perturbation that maximizes the loss, A3T computes a perturbation that minimizes it, defeating the purpose of an adversarial example. Hence, the samples newly generated by A3T are maximally close to the decision boundary, as opposed to adversarial examples that are maximally away from it. From this perspective, imposing a bound on the extent of loss reduction may not seem meaningful. However, since samples with arbitrarily low loss values constitute ‘easy’ examples, they will likely be less informative for the model.

Incorporation of LDS with A3T Since MM, LSR, LDS, and MA were proposed to address different aspects of AT, a question to be answered is whether and to what extent these design choices interfere with each other. In fact, our A3T\(^+\) results show that combining the MM, LSR, and MA losses helps achieve a better generalization-robustness trade-off. However, incorporating LDS with A3T, i.e., using A3T during the inner maximization and LDS as an additional regularization, is ineffective. Crucially, LDS applies label smoothing by placing samples of the predicted class, which for a misclassified sample is an opposing class, around the misclassified sample. In contrast, A3T generates examples of the same class as the misclassified sample closer to the decision boundary. Hence, jointly optimizing these two losses can potentially contribute to the model's overfitting, as samples of opposing classes will be placed in the same vicinity.

Influence on robust overfitting Robust overfitting is another artifact of AT wherein a model's test accuracy remains similar but its robust accuracy drops as the number of training epochs increases (Dong et al., 2022; Rice et al., 2020). Our observations show that A3T also suffers from robust overfitting. This is in agreement with the findings of MART, where a larger variation between the best and last accuracies is observed compared to A3T. Hence, we can deduce that the root cause of this phenomenon does not mainly relate to the adversarial examples generated from misclassified samples.