1 Introduction

Machine learning problems with large output spaces are common in a variety of domains such as face recognition (Schroff et al., 2015), open-domain question answering (Lee et al., 2019), and language models (Bengio & Senécal, 2008). Multilabel classification problems with hundreds of thousands or even millions of labels are referred to as Extreme Multilabel Classification (XMC) (Bhatia et al., 2016). Beyond automatic tagging of web-scale corpora, such as Wikipedia (Partalas et al., 2015), the framework of XMC can be leveraged to address search, recommendation and web-advertising (Dahiya et al., 2021a; Jain et al., 2019).

The main challenge in XMC is the enormous computation needed to train a model over such a large label space. To partially mitigate this issue, early approaches for XMC trained binary linear classifiers in parallel, exploiting the fact that under a one-vs-rest setting, the binary classifier for each label can be learned independently of the rest (Babbar & Schölkopf, 2017, 2019; Schultheis & Babbar, 2022). Furthermore, several methods employed label trees to achieve computational complexity that is logarithmic in the number of labels (Khandagale et al., 2020; Prabhu et al., 2018).

For training with deep neural networks, computing the loss and its gradient over a mini-batch requires a pass over all the labels, meaning that the computation grows linearly with the number of labels at each iteration of the training algorithm. A common approach to make training feasible in deep XMC models is to employ negative sampling methods, discussed next, to approximate the loss and its gradient.

1.1 Negative sampling in XMC

Negative sampling methods exploit the fact that each training point has only a few positive labels, while the set of remaining labels, also called negative labels, is extremely large. Therefore, if only a small fraction of the negative labels, along with all the positive labels, is involved in computing the loss, the computational complexity of training the model decreases significantly. There are three main approaches to negative sampling: static, meta-classifier-based, and adaptive methods.

In static negative sampling, the negative labels are drawn from a fixed distribution independent of the model (Mikolov et al., 2013). While static methods are fast, their approximation of the full loss is known to be biased (Blanc & Rendle, 2018; Rawat et al., 2021), and hence the resulting models perform poorly unless the number of sampled labels is large.

Meta-classifier-based methods have recently become popular for training deep XMC models. These methods train a separate classifier, known as the meta classifier, alongside the extreme classifier; the meta classifier is responsible for suggesting confusing negative labels to the extreme classifier (Dahiya et al., 2021b; Jiang et al., 2021; Kharbanda et al., 2021, 2022; You et al., 2019; Zhang et al., 2021).

While meta-classifier-based methods can lead to high accuracy, this comes at the cost of training and storing a meta classifier that can be as large as the extreme classifier. Moreover, no matter how accurate the meta classifier is, the suggested negative labels may not be the most confusing ones for the extreme classifier, since the meta classifier is usually fully (You et al., 2019; Zhang et al., 2021) or partially (Jiang et al., 2021; Kharbanda et al., 2021) independent of the extreme classifier.

Adaptive negative sampling methods leverage the fact that sampling from a distribution proportional to the scores of the labels leads to an unbiased estimate of the full loss (Blanc & Rendle, 2018; Rawat et al., 2019). In models where the label scores are computed by an inner product, such as neural networks, a close approximation to this is to use Maximum Inner Product Search (MIPS) to find the labels with the highest scores and restrict the negative part of the loss to these labels.

However, an exact MIPS requires an exhaustive search over all the labels, which is expensive. An alternative is to employ approximate MIPS methods, which were originally designed to accelerate inference in machine learning models. They build an index over the weights of the classifier as a pre-processing step, and during inference, the index is used to approximately identify the labels with high scores (Auvolat et al., 2015; Johnson et al., 2019; Shrivastava & Li, 2014).

Applying approximate MIPS for negative sampling during training, on the other hand, is not straightforward due to its high computational cost. This is mainly because, unlike inference where a fixed index is used, the weights of the classifier change constantly during training, so the index built by the MIPS module needs to be updated frequently. To mitigate this issue, a common practice is to perform the pre-processing only a few times during training. Furthermore, two studies aimed to speed up approximate MIPS methods at query time when employed for negative sampling (Daghaghi et al., 2021; Yen et al., 2018). Daghaghi et al. (2021) achieve this by using locality-sensitive hashing (LSH) as a sampling mechanism instead of a top-k search, and Yen et al. (2018) proposed to replace the high-dimensional MIPS with several lower-dimensional ones.

In contrast to meta-classifier-based methods, adaptive negative sampling methods do not need to train or store any extra model. However, to the best of our knowledge, the existing adaptive methods are inferior to meta-classifier-based methods in terms of prediction accuracy, and their performance in deep neural networks, such as transformer-based models, has remained under-explored.

1.2 Meta classifier free negative sampling

In this paper, we highlight difficulties in training deep XMC models using MIPS-based adaptive negative sampling methods when the negative part of the loss is restricted to only the labels suggested by the MIPS module. We hypothesise two reasons for this.

First, when negative labels are restricted to the hardest ones from the beginning of training, the gradient with respect to the embedding vector is dominated by the term related to the negative labels, which prevents the embeddings from becoming similar enough to the weights corresponding to the positive labels. To mitigate this issue, one can incorporate a warm-up phase into training by using only random negative labels for a few iterations and switching to the labels found by the MIPS module for the rest of the training (Yen et al., 2018).

Second, we show that MIPS-based adaptive negative sampling methods are highly sensitive to the length of the intervals between successive pre-processing steps of the weights. If these intervals are long, the weights of a few negative labels that are frequently suggested by the MIPS module will be penalized over and over. This turns those negative labels from hard negatives into easy negatives, yet they continue to be among the labels suggested by the MIPS module until the next pre-processing step.

To avoid the aforementioned issues in training deep models using MIPS-based methods, in this paper, we propose to combine the top candidate labels found by MIPS with several labels drawn from a static distribution. Specifically, to build the set of negative labels for an instance, we pick approximately \(\#positives\) labels through the MIPS module and select the rest (\(\approx 0.01 \times L\)) from a uniform distribution, where \(\#positives\) denotes the average number of positive labels per sample and L is the number of labels in the dataset. Such a combination not only stabilizes the gradient in the initial steps of training but also mitigates the problem of having only easy negatives, even when the weights in the MIPS module are pre-processed infrequently. In the experimental section, we compare the proposed method with a MIPS-based negative sampling method in which all the negative labels used in approximating the loss are those predicted by the MIPS module. The results on three extreme text classification datasets show that the proposed method achieves significantly higher precision, particularly in the presence of large pre-processing intervals.

Finally, we apply the proposed method to a transformer-based model, demonstrating that the architecture of this model, which utilizes a clustering-based MIPS module, is very similar to that of LightXML (Jiang et al., 2021), a negative sampling approach based on meta classifiers. Specifically, in our method, the MIPS module can be seen as the counterpart to LightXML’s meta classifier. However, since the MIPS module works directly on the weights of the extreme classifier, it does not require extra training and does not add additional parameters to the model.

The results on two extreme classification datasets show that the proposed negative sampling method is similar to LightXML in terms of prediction performance, while being up to 33% smaller in terms of model size. This suggests that if the negative labels are chosen well in adaptive negative sampling methods, these methods can match meta-classifier-based ones in prediction performance, while having smaller model sizes and lower training time, as there is no need to train and store a meta classifier.

1.3 Contributions

To summarize, our key contributions are as follows:

  • We highlight difficulties in training deep models using MIPS-based negative sampling methods. We show that when the negative labels are restricted to only hard negatives from the beginning of training, the deep model fails to train properly. Also, when the intervals for pre-processing the weights are large, the labels used for approximating the loss may not contain informative negative labels.

  • We propose to pick only a few labels from the MIPS module and select the rest of the labels from a uniform distribution, which mitigates the aforementioned issues in training deep models using MIPS-based negative sampling.

  • We compare the proposed method with LightXML, a meta-classifier-based negative sampling approach, using a very similar architecture. The results show that our proposed MIPS-based negative sampling can achieve accuracy similar to LightXML, without the need to train and store a relatively large meta classifier.

2 Negative sampling

Assume a training set \(\{\text {x}_i, \text {y}_i\}_{i=1}^N\) is given, where \(\text {x}_i\in \mathbb {R}^D\) are the features of the i-th sample and \(\text {y}_i\in \{1,0\}^L\) are the labels. We denote the embedded features for sample \(\text {x}\) by \(E_\text {x} \in \mathbb {R}^d\). In extreme multilabel classification, the goal is to learn a score function \(f:\mathbb {R}^d \rightarrow \mathbb {R}^L\) over embeddings \(E_{\text {x}_i}\) such that for a pair \((\text {x}, \text {y})\), \(f_j(E_\text {x})\) has a high value when \(y_j=1\). The score function f(.) is usually a linear function, which can be formulated as follows:

$$\begin{aligned} f(E_\text {x}) = W^T E_\text {x} \end{aligned}$$
(1)

where \(W\in \mathbb {R}^{d\times L}\) are the parameters of the score function learned during training, and \(E_{\text {x}}\) is produced by an encoder, which can be, for instance, a deep neural network.

In XMC, most loss functions are sums of binary losses, such as the BCE or hinge loss, computed as follows:

$$\begin{aligned} l(f(E_\text {x}), \text {y}) = \sum _{j:y_j=1} l_+(f_j(E_\text {x})) + \sum _{j:y_j=0} l_-(f_j(E_\text {x})) \end{aligned}$$
(2)

where \(l_+, l_-: \mathbb {R}\rightarrow \mathbb {R_+}\) are the positive and negative parts of the binary loss.

In extreme classification problems, computing the negative part of the loss in Eq. (2) becomes intractable as the set \(\mathcal {N} = \{j |y_j=0\}\) is extremely large. To overcome this problem, negative sampling methods compute the negative part of the loss over a subset of all possible negative labels, \(\hat{\mathcal {N}} \subset \mathcal {N}\), where \(|\hat{\mathcal {N}} |\ll |\mathcal {N} |\approx L\). This leads to the following formulation for the loss function:

$$\begin{aligned} \hat{l}(f(E_\text {x}), \text {y}) = \sum _{j: y_j=1} l_+(f_j(E_\text {x})) + \sum _{j\in \hat{\mathcal {N}}} l_-(f_j(E_\text {x})) \end{aligned}$$
(3)

Negative sampling methods differ in the choice of \(\hat{\mathcal {N}}\). To minimize \(|l-\hat{l} |\) as much as possible, several works suggested choosing \(\hat{\mathcal {N}}\) such that it contains hard negatives, meaning that \(f_j(E_\text {x})\) is higher than a specific threshold for all \(j\in \hat{\mathcal {N}}\). One way to efficiently find these hard negatives during training is to use an approximate Maximum Inner Product Search (MIPS). However, this may lead to difficulties in training deep models, which are discussed in the next section.
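
To make Eq. (3) concrete, the following PyTorch-style sketch computes the sampled loss over a shortlist of labels only (the binary loss is taken to be BCE, and all tensor names and sizes are illustrative, not taken from any released implementation):

```python
import torch
import torch.nn.functional as F

def sampled_bce_loss(pos_scores, neg_scores):
    """Eq. (3) with BCE as the binary loss: pos_scores are f_j(E_x) for the
    positive labels of one sample, neg_scores for the sampled negatives."""
    loss_pos = F.binary_cross_entropy_with_logits(
        pos_scores, torch.ones_like(pos_scores), reduction="sum")
    loss_neg = F.binary_cross_entropy_with_logits(
        neg_scores, torch.zeros_like(neg_scores), reduction="sum")
    return loss_pos + loss_neg

# Illustrative usage on one sample with d = 128 and L = 4000 labels.
W = torch.randn(128, 4000, requires_grad=True)   # extreme classifier, d x L
E_x = torch.randn(128)                           # embedded features of the sample
pos = torch.tensor([3, 17])                      # labels with y_j = 1
neg_hat = torch.randint(0, 4000, (40,))          # sampled subset of negative labels
loss = sampled_bce_loss(E_x @ W[:, pos], E_x @ W[:, neg_hat])
loss.backward()   # only the shortlisted columns of W receive non-zero gradient
```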

3 Difficulties in training by MIPS-based negative sampling methods

In this section, we hypothesise two reasons that can lead to poor results in ordinary MIPS-based negative sampling methods. All the experiments in this section are conducted on Eurlex, a document classification dataset with \(\approx\) 4000 labels (Bhatia et al., 2016), using either a linear classifier or a neural network with one hidden layer.

3.1 Exact hard negative mining from scratch

Fig. 1 A comparison of precision at 5 (\(\text {P}@5\)) for a linear classifier and a neural network trained using hard negatives on the Eurlex dataset. The figure shows that the linear classifier can be trained properly using the hardest negatives, while the output of the neural network (NN-top) is just a random guess. Using higher learning rates for the positive part of the loss (NN-top-re) or picking a few hard negatives and combining them with randomly sampled labels (NN-top-uniform) fixes the issue

Figure 1 illustrates a comparison of a linear model and a neural network when the negatives are restricted to be the top-100 labels returned by an exact MIPS method. The results show that, with these as the negatives comprising the set \(\hat{\mathcal {N}}\), the linear model (Linear-top) can be trained properly, while the output of the neural network (NN-top) is not better than a random guess.

We hypothesise that the reason for stable training in linear models is the independence of the parameters learned for each label. Specifically, the only learnable parameters in the linear model are the weights of the classifier, and no matter which other negative labels are included in computing the loss, the gradient for the weights connected to a particular label remains the same. However, when the embeddings are learned during training, as in neural networks, the gradient with respect to an embedding is affected by the negative labels used in computing the loss. In this case, the negative sampling method has a large impact on learning the embeddings.

More specifically, without loss of generality, assume the loss in Eq. 2 is the BCE loss:

$$\begin{aligned}&l_+(f_i(E_\text{x})) = -\log (\sigma (f_i(E_\text{x}))), \\&l_-(f_i(E_\text{x})) = -\log (1 - \sigma (f_i(E_\text{x}))) \end{aligned}$$

Then, the gradient of the loss with respect to \(E_\text{x}\) is as follows:

$$\begin{aligned} \nabla _{E_\text{x}}l = \sum _{j:y_j=1} (o_j - 1)\text{w}_j + \sum _{j\in \mathcal {N}} o_j \text{w}_j \end{aligned}$$
(4)

where \(o_j :=\sigma (f_j(E_\text{x}))\) and \(\text{w}_j\) is the parameter vector of the extreme classifier corresponding to label j. Then a single step of the stochastic gradient descent algorithm for updating the embedding will be as follows:

$$\begin{aligned} E_\text{x} \leftarrow E_\text{x} + \eta ^+ \sum _{j:y_j=1} (1 - o_j)\text{w}_j - \eta ^- \sum _{j\in \mathcal {N}} o_j \text{w}_j \end{aligned}$$
(5)

where \(\eta ^+\) and \(\eta ^-\) are the learning rates, and typically \(\eta ^+ = \eta ^-\). This equation suggests that the optimal embedding for a sample should be similar to the weights of the extreme classifier for the positive labels and dissimilar to those for the negative labels. However, if the summation over the negative weights is restricted to only a few hardest negatives (those with the highest \(o_j\)), this term will have a high magnitude and the total gradient will be determined almost entirely by it. This prevents the embedding from becoming close enough to the weights corresponding to the positive labels, while it may also not become dissimilar to the other negative weights that are not included in computing the loss. To mitigate this problem, one can force the optimizer to take a larger step towards the positive weights by choosing \(\eta ^+ > \eta ^-\). Figure 1 shows that with \(\frac{\eta ^+}{\eta ^-}=50\), the model trains properly even when all the negative labels are the hardest ones. However, the ratio of \(\eta ^+\) to \(\eta ^-\) adds another hyperparameter to the model.
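
The following numpy sketch of the update in Eq. (5) makes the role of the ratio \(\eta^+/\eta^-\) explicit (a minimal illustration; the function and argument names are ours):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sgd_embedding_step(E_x, W, pos_idx, neg_idx, lr_pos=0.5, lr_neg=0.01):
    """One step of Eq. (5): pull E_x towards the positive weight vectors and
    push it away from the (sampled) negative ones, with separate learning rates."""
    o = sigmoid(W.T @ E_x)                                   # o_j = sigma(f_j(E_x))
    pull = ((1.0 - o[pos_idx]) * W[:, pos_idx]).sum(axis=1)  # positive term of Eq. (5)
    push = (o[neg_idx] * W[:, neg_idx]).sum(axis=1)          # negative term of Eq. (5)
    return E_x + lr_pos * pull - lr_neg * push

# With lr_pos / lr_neg = 50 (as in NN-top-re in Fig. 1), the positive term is not
# drowned out even when neg_idx contains only the hardest negatives (largest o_j).
```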

Another possible way to stabilize training when hard negatives are used, which has been employed by Xiong et al. (2020) and Yen et al. (2018), is to perform negative sampling from a static distribution for a few iterations and then switch to hard negatives for the rest of training. The initial steps of negative sampling from a static distribution allow the embeddings to become similar to the positive weights of the extreme classifier.

In this paper, to retain the benefits of hard negatives while avoiding an unstable gradient, we propose to combine only a few hard negatives with several negative labels sampled from a static distribution. The randomly sampled negative labels prevent the optimizer from taking overly large steps in the direction suggested by the hard negatives alone.

3.2 Impact of the length of pre-processing intervals

Most adaptive negative sampling methods using MIPS pre-process the weights at specific intervals during training, since the ideal case of pre-processing the weights at every iteration is highly time-consuming. Assuming an exact MIPS method is used, this is equivalent to storing the weights of the classifier at specific checkpoints and using them for querying until the next checkpoint.

Fig. 2 a A comparison of P@5 for different numbers of negative labels selected by MIPS (#MIPS) as well as different numbers of checkpoints for pre-processing the weights per epoch (PPE). When all the selected negative labels are determined by MIPS and the pre-processing intervals are large (#MIPS=100, PPE=1), P@5 shows unstable behaviour, while restricting the negative labels from MIPS to 15 and sampling the rest from a uniform distribution (#MIPS=15, PPE=1) mitigates the problem. b Distribution of the hardest negative labels after switching from uniform negative sampling to hard negative mining

Fig. 3 Frequency of the true ranks (based on the model’s predictions) of the negative labels used for approximating the loss, for the case that all 100 negative labels are selected by the MIPS method (MIPS-a) as well as the case that the labels from the MIPS method are restricted to 15 and the rest are sampled from a uniform distribution (MIPS-s). Histograms were computed in the first iteration after pre-processing the weights (a), in the middle iteration (b), and in the last iteration before the next checkpoint for pre-processing the weights (c). The histograms show that when the distance from the pre-processing checkpoint is large, only labels with high ranks (easy negatives) remain among the labels of MIPS-a, while the negative labels of MIPS-s stay close to uniformly distributed

In this subsection, we investigate the effect of relatively large intervals for pre-processing the weights on training a neural network. Figure 2a shows precision at 5 (P@5) for a neural network trained with 100 negative labels for approximating the loss, when the pre-processing is done only once versus 8 times per epoch. To mitigate the issues related to using hard negatives from scratch mentioned in the previous subsection, we train the models with only uniform negative sampling for one initial epoch. The figure shows that when there is only one pre-processing checkpoint per epoch, P@5 behaves unstably, especially in the first few epochs after switching to hard negative mining. However, P@5 is considerably more stable when the negative labels suggested by MIPS are restricted to 15 and the rest are sampled from a uniform distribution.

To understand why the performance declines after switching to hard negative mining when the pre-processing intervals are large and all the negative labels are selected through MIPS, we investigate the negative labels retrieved by the MIPS. Figure 2b shows that, after training the model with uniform negative sampling for one epoch, the negative labels retrieved by the MIPS for a mini-batch of size 50 randomly drawn from the training set have an imbalanced distribution. Since the weights in the MIPS module are fixed until the next pre-processing checkpoint, if the embeddings do not change much during training, most of the labels suggested by the MIPS module will be the same for all the training samples, and the weights corresponding to these labels will be penalized frequently within that epoch of training.

Figure 3 illustrates the true ranks of the negative labels used for approximating the loss when only one pre-processing is done per epoch, for the case that all the labels are selected by the MIPS method (MIPS-a) as well as the case that the labels from the MIPS method are restricted to 15 and the rest are sampled uniformly (MIPS-s). The histograms were computed three times after switching to the MIPS method: right after the checkpoint for pre-processing the weights, in the middle, and right before the next checkpoint. Figure 3c shows that when all the negative labels are selected by the MIPS method, in the last iteration before the next checkpoint, contrary to what is expected, the labels retrieved by the MIPS method are only easy negatives and therefore do not contribute much to training the model. For the MIPS-s method, in contrast, the ranks of the negative labels stay close to a uniform distribution.

4 Proposed method

Our proposed method for performing hard negative mining while avoiding unstable training in deep models is to find and select a few labels with high scores and sample the rest of the needed labels from a static distribution. To find the labels with high scores in a reasonable time during training, we use a clustering-based approximate MIPS (Auvolat et al., 2015; Johnson et al., 2019). More precisely, as the pre-processing step, we cluster the weights of the extreme classifier using spherical K-means, which partitions the weights according to their directions, at specific checkpoints. Then, for each training sample, to approximate the labels with high scores at the query time, we search a few clusters whose centroids maximize the inner product with the embedded features obtained by the encoder for that sample. The details of the pre-processing and querying steps of the approximate MIPS used in our method are described in the following.

Pre-processing The pre-processing is done by clustering the weights of the extreme classifier using spherical K-means (Auvolat et al., 2015; Zhong, 2005). More precisely, we randomly select K unit-length vectors, \(C = \{\text {c}_1,...,\text {c}_K\}\), and repeat the following two steps until the convergence criteria are met (Auvolat et al., 2015):

  1. For each weight vector \(\text {w}_j\), we compute the index of the cluster to which this weight vector belongs by \(a_j = {{\,\textrm{argmax}\,}}_{i\in \{1,...,K\}} \text {w}_j^T \text {c}_i\).

  2. For each cluster i, we estimate its centroid by \(\text {c}_i = \frac{\sum _{j|a_j=i} \text {w}_j}{\Vert \sum _{j|a_j=i} \text {w}_j\Vert }\).

Since the pre-processing step is time-consuming, we perform it only at specific checkpoints during training.
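
For illustration, the spherical K-means step described above could be implemented as follows (a minimal numpy sketch, not the authors' code; empty clusters are simply re-seeded):

```python
import numpy as np

def spherical_kmeans(W, K, n_iter=20, seed=0):
    """Cluster the columns of W (the d x L extreme-classifier weights) by direction.
    Returns unit-length centroids C (d x K) and the assignment a (length L)."""
    rng = np.random.default_rng(seed)
    d, L = W.shape
    C = rng.standard_normal((d, K))
    C /= np.linalg.norm(C, axis=0, keepdims=True)   # random unit-length centroids
    for _ in range(n_iter):
        a = np.argmax(W.T @ C, axis=1)              # step 1: a_j = argmax_i w_j^T c_i
        for i in range(K):
            members = W[:, a == i]
            if members.shape[1] == 0:               # re-seed an empty cluster
                C[:, i] = rng.standard_normal(d)
            else:
                C[:, i] = members.sum(axis=1)       # step 2: sum of member weights ...
            C[:, i] /= np.linalg.norm(C[:, i])      # ... normalised to unit length
    a = np.argmax(W.T @ C, axis=1)                  # final assignment
    return C, a
```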

Querying Let the weights of the extreme classifier at checkpoint h be \(W^h_{d \times L}\), and \(C^h = \{\text {c}^h_1,...,\text {c}^h_K\}\) be the set of centroids obtained by clustering these weights. Also, let \(I_j\) be the label ids of the weights inside the j-th cluster. As the first step of querying, a search over the centroids is done to find the \(\tau\) centroids that maximize the inner product with the encoded sample \(E_\text {x}\):

$$\begin{aligned} C^* = \mathop {\tau \hbox {-argmax}}\limits _{\text {c} \in C^h} E_\text {x}^T \text {c} \end{aligned}$$
(6)

where \(\tau\) is a hyperparameter indicating how many clusters are searched. Let \(I^* = \bigcup \nolimits _{i \in C^*} I_i\); the next step is to search the weights indexed by \(I^*\) as follows:

$$\begin{aligned} L_M = \mathop {k\hbox {-argmax}}\limits _{i \in I^{*}} E_\text {x}^T \text {w}^h_i \end{aligned}$$
(7)

where \(\text {w}^h_i\) is the i-th column of \(W^h\) and k is the number of labels retrieved by the approximate MIPS. The last step for using the labels suggested by MIPS as negative labels is to exclude positive labels from \(L_M\):

$$\begin{aligned} L_M \leftarrow L_M \setminus \{j |y_j=1\} \end{aligned}$$
(8)

Finally, we combine the hard negatives suggested by the MIPS with several labels drawn from a static distribution. The static distribution should be model-independent and fixed during training so that it is easy to sample from; in our experiments we use a uniform distribution. Let q be a uniform distribution defined over the set of negative labels, and let \(L_S\) denote the labels drawn from q. The negative part of the loss in Eq. 2 is then computed only over the union of \(L_M\) and \(L_S\):

$$\begin{aligned} \hat{l}(f(E_\text {x}), \text {y}) = \sum _{j:y_j=1} l_+(f_j(E_\text {x})) + \sum _{j\in L_M \cup L_S} l_-(f_j(E_\text {x})) \end{aligned}$$
(9)
Fig. 4 A forward pass and the loss computation in our proposed negative sampling method, which restricts the negative part of the loss to hard negatives suggested by a clustering-based MIPS as well as negative labels sampled from a static distribution

The proposed negative sampling method is summarized in Algorithm 1 as well as Fig. 4.

Algorithm 1 Proposed negative sampling method
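
Since the algorithm figure itself is not reproduced here, the sketch below shows how the pieces described above might be assembled in a training loop, reusing the earlier sketches of `spherical_kmeans`, `query_hard_negatives`, and `sampled_bce_loss`. This is a simplified reconstruction from the text, not the authors' code: a batch size of one is assumed, `positives` is a list of label ids, and names such as `preprocess_every` and `num_uniform` are ours.

```python
import numpy as np
import torch

def train_epoch(encoder, W, loader, optimizer, L, K, tau=64, k=5,
                num_uniform=None, preprocess_every=1000):
    num_uniform = num_uniform or int(0.01 * L)           # ~1% of the labels, as in Sect. 1.2
    for step, (x, positives) in enumerate(loader):
        if step % preprocess_every == 0:                 # checkpoint: re-cluster the weights
            W_h = W.detach().cpu().numpy()
            C_h, assign = spherical_kmeans(W_h, K)
        E_x = encoder(x)                                 # embedded features of the sample
        L_M = query_hard_negatives(E_x.detach().cpu().numpy(),
                                   W_h, C_h, assign, positives, tau, k)
        L_S = np.random.randint(0, L, size=num_uniform)  # static (uniform) negatives
        negatives = np.setdiff1d(np.union1d(L_M, L_S), positives)
        pos_scores = E_x @ W[:, positives]
        neg_scores = E_x @ W[:, torch.as_tensor(negatives)]
        loss = sampled_bce_loss(pos_scores, neg_scores)  # Eq. (9)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```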

(Dis)similarities with meta-classifier-based methods Looking at Fig. 4, one may notice similarities between the proposed method and meta-classifier-based methods such as LightXML (Jiang et al., 2021) and AttentionXML (You et al., 2019). Both approaches use clustering to find hard negatives for the extreme classifier. More precisely, first, the clusters are scored based on their similarities with the encoded samples. Then, the union of the labels in the clusters with the highest scores is used as the set of negative labels for the extreme classifier. The module responsible for finding hard negatives is called a meta classifier in meta-classifier-based methods, while in our method it is denoted as the clustering-based MIPS in Fig. 4.

Despite similarities between the forward pass of the two approaches, there are three main differences between our method and meta-classifier-based methods. (I) In meta-classifier-based methods, the representations of the labels used for clustering are independent of the extreme classifier, which means that no matter how accurate the meta classifier is, the labels it suggests may not be the hardest ones for the extreme classifier. In our method, in contrast, the representations of the labels are completely aligned with the extreme classifier, as they are the columns of the weight matrix \(W^h\). (II) As pointed out in Kharbanda et al. (2021), since the meta and extreme classifiers operate at different resolutions, it may be difficult in meta-classifier-based methods to learn data representations that are simultaneously suitable for both tasks. Our scheme, on the other hand, only requires intermittent re-clustering of the weight vectors. (III) In meta-classifier-based methods, the meta classifier needs to be trained and stored in addition to the extreme classifier, which can add millions of parameters to the overall model, whereas our method requires no storage beyond the extreme classifier.

Complexity and hyperparameters: There are three hyperparameters involved in the MIPS part of our method: K, \(\tau\), and k, which are the total number of clusters, the number of clusters to be searched, and the number of labels retrieved by the MIPS, respectively. Among these, K and \(\tau\) play a significant role in the amount of computation. Assuming the clusters are balanced, a large K leads to centroids that are better representatives of the parameters, but it also increases the cost of comparing the embeddings with the centroids. A small K decreases that cost but leads to large clusters, which again results in high computation when the clusters are searched. As stated in Zhang et al. (2021), \(K=\sqrt{L}\) is a good choice, which we use in all the experiments in our work. In this case, assuming balanced clusters, the computational complexity of querying the MIPS is \(\mathcal {O}(\tau \sqrt{L} d)\). If \(\tau =\sqrt{L}\), i.e., equal to the number of clusters, the computational complexity of the MIPS is the same as that of an exhaustive search, while the lowest complexity is obtained for \(\tau =1\). The rest of the computation in the forward pass consists only of drawing labels from the predefined distribution and computing the loss over them, which is the same as in static negative sampling methods.
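
As a rough numerical illustration (assuming balanced clusters, \(L \approx 6.7\times 10^{5}\) as in Amazon-670K, \(K=\sqrt{L}\approx 819\), and \(\tau =64\) as used in Sect. 5.1), the number of d-dimensional inner products per query is approximately

$$\begin{aligned} \underbrace{K}_{\text {centroid scan}} + \underbrace{\tau \tfrac{L}{K}}_{\text {shortlisted clusters}} \approx 819 + 64 \cdot 819 \approx 5.3\times 10^{4}, \end{aligned}$$

which is more than a 12-fold reduction compared to the \(\approx 6.7\times 10^{5}\) inner products of an exhaustive search.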

Since the preprocessing step of the MIPS is done only within large intervals, the computational complexity of this step is negligible compared to the whole training time.

In the backward pass, the number of extreme-classifier parameters that need to be updated is proportional to the number of negative labels used in computing the loss. This is contrary to meta-classifier-based methods, like LightXML, in which, in addition to the extreme classifier, all the parameters of a large meta classifier must be updated at each iteration of the training algorithm.

5 Experiments

In this section, we evaluate the proposed negative sampling method of Sect. 4, which combines a few negative labels suggested by a clustering-based MIPS with several labels sampled from a uniform distribution. The purpose of the experiments is twofold: first, to compare the proposed method with the case that all the negative labels are selected by MIPS, and second, to compare it with meta-classifier-based negative sampling methods, for which we use the LightXML framework of Jiang et al. (2021), and see whether the proposed adaptive negative sampling can reach the performance of meta-classifier-based methods.

The results show that the proposed method achieves significantly higher precision than the case that all the negative labels are selected by MIPS or the case that the labels are sampled only from a uniform distribution. Moreover, the proposed method consistently achieves high precision even in the presence of large intervals for pre-processing the weights, while in this setup it is difficult to train the model using only the negative labels suggested by the MIPS module. Compared to LightXML, while both methods have very similar architectures, the proposed method achieves precision close to that of LightXML with a significantly smaller model size and less training time.

5.1 Setup

Architectures and hyperparameters: We use two architectures as the encoders for the proposed method (and the corresponding baselines): a shallow neural network with a single hidden layer and a BERT model.

In the case of a shallow neural network as the encoder, dropout with \(p=0.2\) is applied to the input features, the number of neurons in the hidden layer is set to 128 with ReLU as the activation function, the optimizer is Adam, and the learning rate is chosen according to a validation set split from the training data. The number of negative labels for approximating the loss is around \(1\%\) of the total number of labels. In the MIPS-based negative sampling methods, we perform pre-processing of the weights only once per epoch.
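
For reference, the shallow encoder described above could be written as follows (a minimal PyTorch sketch under the stated hyperparameters; the class name is ours):

```python
import torch.nn as nn

class ShallowEncoder(nn.Module):
    """Dropout(0.2) on the TF-IDF input followed by one hidden layer of 128 ReLU units."""
    def __init__(self, input_dim, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)   # E_x, fed to the extreme classifier

# The extreme classifier itself is a bias-free linear map over E_x,
# e.g. nn.Linear(hidden_dim, L, bias=False), matching f(E_x) = W^T E_x in Eq. (1).
```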

For the BERT model as the encoder, all the hyperparameters, including the number of negative labels for approximating the loss in the extreme classifier and the learning rate, are the same as those used by Jiang et al. (2021). In MIPS-based negative sampling methods for training the BERT model, we perform pre-processing of the weights every 1000 iterations.

For both the shallow neural network and the BERT model, the number of negative labels retrieved by the MIPS procedure in the MIPS-based methods is 5, which is approximately equal to the average number of positive labels per data point in the datasets used in our experiments (Table 1). Also, for both models, the BCE loss is used as the loss function.

We use the GPU implementation of MIPS by Johnson et al. (2019) for performing maximum inner product search in the MIPS-based negative sampling methods. Two hyperparameters related to the MIPS are K, the number of clusters, and \(\tau\), the number of clusters searched during querying. In our experiments, K is always set to the square root of the number of labels, as discussed in Sect. 4, and \(\tau\) is set to 64.
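
Concretely, the clustering-based search maps onto an inverted-file (IVF) index in the faiss library of Johnson et al. (2019). The snippet below only illustrates how such an index could be configured (K clusters, \(\tau\) of them probed per query); it is not the exact setup used in our experiments, and the sizes are illustrative:

```python
import faiss
import numpy as np

d, L, K, tau, k = 128, 4000, 64, 8, 5           # illustrative sizes (K ~ sqrt(L))
W = np.random.randn(L, d).astype("float32")      # classifier weights, one row per label

quantizer = faiss.IndexFlatIP(d)                 # inner-product metric for the centroids
index = faiss.IndexIVFFlat(quantizer, d, K, faiss.METRIC_INNER_PRODUCT)
index.train(W)                                   # pre-processing: cluster the weights
index.add(W)
index.nprobe = tau                               # number of clusters searched per query

E_x = np.random.randn(1, d).astype("float32")    # embedded features of one sample
scores, labels = index.search(E_x, k)            # approximate top-k labels by inner product
```

If GPU search is desired, the same index can be moved to a GPU with faiss.index_cpu_to_gpu.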

For evaluating the BERT models trained by MIPS-based methods, to avoid performing an exhaustive search over all the labels, we approximate the labels with the highest scores using an approximate MIPS. Contrary to the training phase, for evaluation the pre-processing of the weights needs to be done only once, as the weights are fixed. As in the training phase, K is set to the square root of the number of labels and \(\tau\) is chosen such that the evaluation time is the same as that of LightXML.

Baselines: We compare the proposed method with the following negative sampling strategies:

  • Uniform: This method samples negative labels uniformly from the set of all possible negative labels.

  • MIPS-s: This method refers to the proposed method, in which only a few negative labels are selected by the clustering-based MIPS and the rest are sampled uniformly.

  • MIPS-a: In this method, after training with only uniform negative sampling for a couple of iterations, all the needed negative labels are selected by MIPS.

  • LightXML: This method refers to the meta-classifier-based negative sampling proposed by Jiang et al. (2021).

Datasets: We use three textual multilabel datasets from the extreme classification repository (Bhatia et al., 2016). The statistics of these datasets are given in Table 1. For the shallow neural network, we use the TF-IDF representations of the data, while for the BERT model, the raw texts are used. Since the raw text is not available for WikiLSHTC-325K, the experiments using BERT were only done on the two other datasets.

Table 1 The statistics of three datasets used in our experiments (Bhatia et al., 2016)

Evaluation metrics: In the multilabel setting of our experiments, the goal is to predict the correct labels among the top-k labels. Therefore, we evaluate the models using precision at k (P@k) and its unbiased counterpart, propensity scored precision at k (PSP@k). Formally, these two metrics are as follows:

$$\begin{aligned} \text {P@k}(\text {y},\hat{\text {y}}) :=\frac{1}{k} \sum _{l \in topk(\hat{\text {y}})} y_l \end{aligned}$$
(10)
$$\begin{aligned} \text {PSP@k}(\text {y},\hat{\text {y}}) :=\frac{1}{k} \sum _{l \in topk(\hat{\text {y}})} \frac{y_l}{p_l} \end{aligned}$$
(11)

where \(topk(\hat{\text {y}})\) is the set of top-k labels predicted by the model, and \(p_l\) is the propensity score of label l, which indicates the probability that label l is present (Jain et al., 2016). The propensity scores for each dataset are set according to Jain et al. (2016).
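
For reference, a direct numpy implementation of the two metrics could look as follows (argument names are ours):

```python
import numpy as np

def precision_at_k(y_true, y_score, k=5):
    """P@k of Eq. (10): fraction of the top-k predicted labels that are relevant."""
    topk = np.argsort(-y_score)[:k]
    return y_true[topk].sum() / k

def psp_at_k(y_true, y_score, propensities, k=5):
    """PSP@k of Eq. (11): top-k hits reweighted by the inverse propensity scores."""
    topk = np.argsort(-y_score)[:k]
    return (y_true[topk] / propensities[topk]).sum() / k
```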

5.2 Results on the shallow neural network as the encoder

In this subsection, we compare the proposed method with the baselines using the shallow neural network. The results are given in Table 2. All the metrics for MIPS-s are significantly higher than those of MIPS-a and Uniform, which shows that combining a few labels suggested by MIPS with randomly sampled labels boosts the performance significantly compared to using only labels from MIPS or only uniformly sampled labels.

We should also note that the results for the MIPS-a method without the initial training by uniform negative sampling are no better than a random guess, even with very high values of \(\tau\) and very short intervals for pre-processing the weights, which makes it close to the setting discussed in Sect. 3.1.

Table 2 A comparison of the proposed method (MIPS-s) with uniform negative sampling (Uniform) and the case that all the needed negative labels are defined by MIPS (MIPS-a) for training a neural network with one hidden layer on three extreme classification datasets

5.3 Results on BERT as the encoder

Table 3 compares different negative sampling methods for training BERT on Amazon-670K and Wikipedia-500K. The MIPS-s method outperforms Uniform on both datasets. Compared to MIPS-a, the MIPS-s method achieves superior performance on the majority of the metrics on Wikipedia-500K. While both methods exhibit similar efficacy on Amazon-670K, in the next subsection we show the sensitivity of the MIPS-a method to the pre-processing intervals of the MIPS module. Specifically, we show that MIPS-a fails to train properly in the presence of large intervals for pre-processing the weights, highlighting the limitations of this approach.

Compared to LightXML, the precision and propensity-scored precision achieved by MIPS-s are within 1% of those obtained by LightXML. Given the large scale of our datasets, it is difficult to verify the statistical significance of the results. However, we ran MIPS-s and LightXML on Amazon-670K using four different sets of initial weights, the details of which are given in Table 5 in Appendix A. The standard deviations of all the metrics for MIPS-s and LightXML are less than 0.07% and 0.08%, respectively, indicating that the performance of both models on the Amazon-670K dataset is consistent.

We should note that similar to the experiments on the shallow neural network, when the negative labels are restricted to the hardest ones from the beginning of training in the BERT model, the output is not better than a random guess.

Table 3 A comparison of the proposed method (MIPS-s) with uniform negative sampling (Uniform), the case that all the needed negative labels are defined by the MIPS (MIPS-a), and LightXML for training a BERT model on two extreme text classification datasets

The main difference between the architecture of LightXML and that of the proposed method is that the latter does not need to train and store a meta classifier. Table 4 compares the model size and training time of MIPS-s with those of LightXML. The results show that by excluding the meta classifier of LightXML, we can save up to 34% of the space on disk and train the model a few hours faster.

Table 4 A comparison of the model size and training time of MIPS-s with those of LightXML on Amazon-670K and Wikipedia-500K
Fig. 5 The effect of the number of negative labels retrieved by the MIPS in MIPS-s on P@5. In both datasets, P@5 peaks when the number of negative labels retrieved by the MIPS is around 10. Furthermore, the figures show that the MIPS-s method at its peak can achieve P@5 close to that of LightXML. Also, there is a large gap between MIPS-s and Uniform even when only 1 negative label is taken from the MIPS

5.4 Hyperparameters analysis

In this subsection, we investigate the effects of three hyperparameters, namely the number of negative labels taken from the MIPS module, the number of clusters to be searched (\(\tau\)), and the pre-processing intervals of the MIPS module. All the experiments are done using BERT as the encoder. All hyperparameters other than the ones being analysed are the same as in Sect. 5.3.

Number of negatives taken from the MIPS: Fig. 5 compares P@5 for the MIPS-s method with different numbers of negative labels retrieved by the MIPS. P@5 reaches its peak when the number of negative labels obtained by the MIPS module is around 10. It can also be seen that the maximum P@5 achieved by MIPS-s is very close to that of LightXML. Moreover, using even one label found by the MIPS boosts the performance significantly compared with Uniform.

We should note that since the number of labels taken from the MIPS module has minimal impact on the training time in MIPS-s, the training times of all the experiments with MIPS-s in Fig. 5 are approximately the same as that mentioned for MIPS-s in Table 4.

The number of clusters to be searched (\(\tau\)): Fig. 6 shows P@5 and training time for different values of \(\tau\) during training of MIPS-s. Smaller values of \(\tau\) mean a smaller search space and therefore faster training, while larger values approach an exact search, which is computationally costly. However, the figure shows that even with small values of \(\tau\), MIPS-s achieves high P@5, within about 1% of LightXML, while increasing \(\tau\) only marginally improves P@5.

Fig. 6 The effect of \(\tau\) (the number of clusters to be searched in the MIPS method during training) on P@5 (top row) and training time (bottom row) in the MIPS-s method. Even small values for \(\tau\) can achieve high P@5. Increasing \(\tau\) will increase the training time, while there is no significant gain in P@5

Pre-processing intervals: Fig. 7 shows the effect of the length of pre-processing intervals of the weights in the MIPS module in MIPS-based methods on P@5 and training time. Consistent with the analysis of Sect. 3.2, while MIPS-s achieves high P@5 even with large pre-processing intervals (more than 30K on Amazon-670K and 50K on Wikipedia-500K), the MIPS-a method can be trained properly only in the high training time region with short intervals for pre-processing.

Fig. 7 The effect of pre-processing intervals of the MIPS module in MIPS-based methods on P@5 (top row) and training time (bottom row). Since both MIPS-based methods have approximately similar training times, only the training time of MIPS-s is reported. Both MIPS-s and MIPS-a achieve high P@5 when the intervals for pre-processing are short and therefore the training time is high. However, the MIPS-a method fails to train properly when the pre-processing intervals are large, while MIPS-s consistently achieves high P@5

In addition to the above hyperparameters, we also investigated the effect of using conventional K-means with Euclidean distance instead of spherical K-means in our proposed method. The results in Table 6 in Appendix A show that both clustering methods achieve similar performance, highlighting the robustness of our approach to the choice of clustering method.

6 Other related work

In this section, we review some related work on adaptive and meta-classifier-based negative sampling methods as well as the work on the drawbacks of hard negative mining.

MIPS-based adaptive negative sampling—In MIPS-based extreme classification, maximum inner product search methods are utilized to approximate the labels that maximize the inner product between an embedding and the labels’ weights, and thereby find hard negative labels (Chen et al., 2020; Daghaghi et al., 2021; Vijayanarasimhan et al., 2014; Yen et al., 2018). Among these works, Daghaghi et al. (2021) proposed to use Locality-Sensitive Hashing (LSH) for performing the maximum inner product search. Yen et al. (2018), to reduce the computational complexity, replaced the MIPS over the high-dimensional embeddings with several lower-dimensional ones by decoupling the dimensions using dual decomposition. Chen et al. (2020) proposed to approximate the neurons with high responses in every layer of a neural network using an approximate MIPS and backpropagate the error only through those neurons.

Meta-classifier-based negative sampling—In meta-classifier-based methods, a meta classifier is trained in addition to the extreme classifier to suggest confusing labels to the extreme classifier. In recent state-of-the-art meta-classifier-based methods, the labels are first grouped into clusters based on their similarities. Then, a meta classifier is trained in which, for any sample, any cluster containing at least one positive label is considered positive and the remaining clusters are considered negative. Finally, the meta classifier is queried for each sample, and the union of the labels in the few clusters with the highest scores is used as the set of negative labels for approximating the loss of the extreme classifier.

Among these works, You et al. (2019) and Zhang et al. (2021) use hierarchical architectures in which one model is trained at each level and acts as the meta classifier for the model at the next level. In these methods, the models are trained sequentially, which means the hard negatives retrieved by the meta classifier are fixed during training. To make the process of learning the embeddings in the meta classifier more aligned with the extreme task, Kharbanda et al. (2021) propose to increase the number of clusters for the meta classifier.

Two-stage negative sampling—Two-stage negative sampling methods have recently become the state of the art in short-text extreme classification (Dahiya et al., 2021a, 2021b, 2022; Mittal et al., 2021). In most of these methods, a model is first trained to learn representations for the documents and the labels. Then, in the second stage, an approximate nearest neighbour search over the learned embeddings determines the labels with small distances to each instance, which are used as hard negatives for training the extreme classifier.

Two-stage negative sampling methods lie at the intersection of adaptive and meta-classifier-based methods, since they are partially dependent on the model and use the first-stage model as a means of finding confusing labels for the extreme classifier.

Issues in training using hard negatives—A few works have discussed issues in training using hard negative labels (Schroff et al., 2015; Wu et al., 2017). Specifically, in deep embedding learning using triplet losses, Schroff et al. (2015) argued that selecting the hardest negative, i.e., the sample in the mini-batch closest to the anchor, leads to a bad local minimum in the early stages of training. Wu et al. (2017) showed that in these methods the gradient with respect to hard negatives has high variance and is therefore susceptible to noise in the training data. To this end, they proposed distance-weighted sampling, in which the probability of selecting a hard sample is a constant, and for the remaining samples it is proportional to the inverse of the distance to the anchor. Our work is in the same vein as these works, investigating the issues related to using hard negative labels in MIPS-based adaptive methods for training feedforward neural networks.

7 Conclusion

In this paper, we highlighted two difficulties in training deep neural networks using MIPS-based adaptive negative sampling methods. We argued that when only hard negatives are employed from the beginning of training, the gradients of the loss with respect to the embeddings are unstable. Also, we showed that when the labels needed for approximating the loss are only determined by MIPS, training is sensitive to the pre-processing intervals of the weights in the MIPS method.

To overcome these problems, we proposed to pick only a few labels from those found by MIPS and to select the rest from a uniform distribution. The results show that the proposed method leads to significantly higher precision in both a shallow and a deep neural network. Furthermore, we highlighted the similarities between the architecture of our method and that of LightXML, a meta-classifier-based negative sampling method. We showed that our approach, which only requires re-clustering the weights a few times during training, can reach the performance of LightXML on a BERT model, while there is no need to train and store any additional classifier. We hope that our work will spur further research into meta-classifier-free negative sampling methods in extreme multilabel classification.