Erroneous pixel prediction for semantic image segmentation

We consider semantic image segmentation. Our method is inspired by Bayesian deep learning, which improves image segmentation accuracy by modeling the uncertainty of the network output. Instead of modeling uncertainty, our method directly learns to predict the erroneous pixels of a segmentation network, which is modeled as a binary classification problem. It can speed up training compared with the Monte Carlo integration often used in Bayesian deep learning. It also allows us to train a branch to correct the labels of erroneous pixels. Our method consists of three stages: (i) predict the pixel-wise error probability of the initial result, (ii) redetermine new labels for pixels with high error probability, and (iii) fuse the initial result and the redetermined result with respect to the error probability. We formulate the error-pixel prediction problem as a classification task and employ an error-prediction branch in the network to predict pixel-wise error probabilities. We also introduce a detail branch to focus the training process on the erroneous pixels. We have experimentally validated our method on the Cityscapes and ADE20K datasets. Our model can be easily added to various advanced segmentation networks to improve their performance. Taking DeepLabv3+ as an example, our network can achieve an mIoU of 82.88% on the Cityscapes testing dataset and 45.73% on the ADE20K validation dataset, improving corresponding DeepLabv3+ results by 0.74% and 0.13% respectively.


Introduction
The goal of semantic image segmentation is to obtain a high-level representation of an image by assigning each pixel a semantic class label. Semantic image segmentation can be used in video surveillance, medical imaging, autonomous driving, etc. Recently, deep convolutional neural networks (DCNN) trained on large scale image segmentation datasets such as PASCAL VOC 2012 [1], Cityscapes [2], and ADE20K [3] have significantly improved the accuracy of image segmentation.
While end-to-end training of a DCNN can effectively learn multi-scale features for various vision tasks, the down-sampling operations in the encoder, designed to enlarge the receptive field, are likely to lose detailed information required for pixel-level image segmentation [4]. Thus, atrous convolution and skip-connections are used to balance down-sampling operations and learning of multi-scale features [5,6]. It has also been shown that fusing global context and multi-scale features can effectively improve the accuracy of image segmentation [7-9]. However, even with state-of-the-art image segmentation algorithms, we can still see a large number of pixels with wrong labels in regions with indistinct RGB information, at object boundaries, and in small-scale objects. We call these erroneous pixels. While hard-mining methods exist that train the network using gradient information back-propagated from the erroneous pixels, these methods rely on ground-truth data to detect erroneous pixels, which is not available during inference. The difficulty-aware method in Ref. [10] is a layer-cascading method (LC) that focuses on those pixels whose largest label probabilities are less than a threshold in a layer-by-layer manner. However, the erroneous pixels whose largest label probabilities in one layer are greater than the threshold, which we refer to as hard erroneous pixels, are simply accepted as part of the result and overlooked in subsequent layers.
In this paper, we study how to learn to predict the erroneous pixels for a segmentation network, so that a cascaded detail branch can be used to handle erroneous pixels to improve segmentation accuracy. It runs as a model cascading strategy during inference: using an existing image segmentation network as a front-end semantic branch, we first predict error pixels in its segmentation result, then redetermine semantic labels for those error pixels, and finally fuse them to obtain the final segmentation result. Our strategy differs from that of Ref. [10] in that we add an error-prediction branch to the network to improve the accuracy of error pixel prediction. Thus, it is possible, in our method, to predict the overlooked hard erroneous pixels as erroneous pixels to be corrected. The error-pixel prediction is similar to uncertainty modeling in Bayesian deep learning for computer vision. Our method can speed up training by modeling error-pixel prediction as a binary classification problem, as an alternative to the Monte Carlo integration used to evaluate the objective function in Ref. [11]. It implicitly assumes that the aleatoric uncertainty can be learned through the difference between the segmentation result and the ground-truth labeling in training. To correct the detected erroneous pixels, we employ another independent sub-network, the detail branch, trained to focus on the segmentation of such pixels.
Because an independent branch learns to predict the erroneous pixels without affecting the pixels that the front-end segmentation network already handles well, the error-prediction branch and detail branch can, thanks to this cascading design, be used to improve the accuracy of a variety of segmentation networks. Our network trained on Cityscapes can achieve an mIoU of 82.88% on the testing dataset when using DeepLabv3+ as the semantic branch [12], which is 0.74% higher than the original network.

Related work
In the following, we mainly review image segmentation methods using deep neural networks, which are mostly related to our work. Please see Ref. [13] for a comprehensive survey.
The encoder-decoder is the fully convolutional neural network structure most used for pixel-wise segmentation of high-resolution images [4,14]. A common technique in DNN-based image segmentation algorithms is to fuse multi-scale features to improve segmentation accuracy. U-Net [5] exploits skip connections to augment high-level features with low-level features in the decoder so as to improve the accuracy of localization, and is widely used in many works [9,15-17]. ParseNet [18] adopts a simple global branch to add global context, while Refs. [8,19] use the global feature to guide feature fusion. PSP-Net [7] proposes a pyramid pooling module to aggregate representative context features. Atrous spatial pyramid pooling (ASPP) in Ref. [20] uses atrous convolution filters [6,21] at multiple dilation rates to capture multi-scale image contexts. In order to handle small objects in the image, EncNet [22] utilizes a context encoding module to explicitly enforce learning of global scene context. A recent contribution [23] proposed HRNet to improve segmentation accuracy; it gradually adds high-to-low resolution subnetworks and fuses the learned multi-scale features in parallel. Neural architecture search (NAS) is a new method which aims to find the optimal neural architecture and weights simultaneously. Ref. [24] explores the construction of meta-learning techniques for recurrently searching. Ref. [25] introduces auxiliary cells that provide an intermediate supervisory signal for architecture parameterization. Auto-DeepLab [26] proposes a hierarchical architecture search, searching at cell level and network level.
Our work is also related to the popular cascading structure used in computer vision. In object detection, successive classifiers are combined in a cascading structure, which allows the background regions of an image to be quickly discarded while spending more computation on promising regions [27-30]. The cascading structure can also be applied to segmentation. A layer-cascading (LC) method is introduced in Ref. [10], but our network can capture hard erroneous pixels overlooked in LC to further improve the segmentation accuracy.

Approach
In the following, we first introduce the overall framework of our method, and then provide details of the error-prediction branch and the detail branch of our network. The training strategy is also described. Figure 1 provides an overview of our method, which consists of three modules: (i) a pre-trained segmentation network, the semantic branch, which is used to obtain initial segmentation results and semantic features (see Section 3.2), (ii) error-prediction and detail branches to find erroneous pixels and predict new labels for them respectively (see Sections 3.3 and 3.4), and (iii) a module to combine the initial segmentation result and the newly predicted labels, providing a more accurate segmentation result (see Section 3.5).

Overview
More concretely, given an input image $I$ and a segmentation network $f_{sb}(\cdot)$, we obtain the initial segmentation probability map $P_{sb} = f_{sb}(I)$. For the $i$-th pixel, $P^i_{sb} \in \mathbb{R}^{C \times 1}$ gives the probabilities of this pixel belonging to each of $C$ categories. The error-prediction branch $f_{ep}(\cdot)$ yields a probability map $P_{ep} = f_{ep}(\cdot)$ with the same size as the initial result $P_{sb}$. A pixel with a high probability in $P_{ep}$ is likely to be wrongly labelled in the initial segmentation. After error prediction, those erroneous pixels should be relabelled. The detail branch, denoted $f_{db}(\cdot)$, is responsible for predicting new labels for erroneous pixels and predicts a new probability map $P_{db} = f_{db}(\cdot)$. Finally, labels of erroneous pixels in the initial label map are replaced by the new labels generated by the detail branch, giving a more reliable semantic segmentation result.

Semantic branch
We directly use a pre-trained segmentation network as the semantic branch. More concretely, we mainly used DeepLabv3+ [12], PSP-Net [7], and the DPC network [24] as our semantic branch in the following experiments. The pre-trained segmentation network gives the initial segmentation probability map $P_{sb}$ and the corresponding low-level and high-level features that are used in the training of the error-prediction branch and the detail branch.

Error-prediction branch
The error-prediction branch aims to predict whether the initial labels given by the semantic branch are erroneous. Specifically, this branch predicts a probability map $P_{ep}$ in which each pixel value represents the probability that the semantic branch prediction is mislabeled. The inputs of this branch consist of (i) the probability map $P_{sb}$ generated by the semantic branch, (ii) the feature maps from the direct convolution of the input RGB image, and (iii) the feature maps from the semantic branch. We exploit the global attention upsampling (GAU) module from Ref. [8], as illustrated in Fig. 2, to provide channel-wise attention in this branch.
In detail, we first apply convolutions to $P_{sb}$, the probability map output by the semantic branch, and to the input RGB image $I$ separately. The obtained features are then concatenated as the input low-level features. Afterwards, we use the high-level features from the semantic branch as the input to GAU. For example, the features generated by ASPP in DeepLabv3+ and the pyramid pooling module in PSP-Net are used as the high-level features input to the error-prediction branch.
Fig. 1 Architecture of our network. We use a pre-trained segmentation network (for example, DeepLabv3+ [12]) as the semantic branch. We add two branches: the error-prediction branch predicts an error probability map to find error pixels, and the detail branch predicts the correct labels for the mislabeled pixels.
The loss function for this branch is formulated as a pixel-wise cross-entropy loss to classify each pixel as mislabelled or not, which is a binary classification problem. The ground-truth error map $M_{err}$ for training is obtained by checking whether the initial segmentation from the semantic branch is inconsistent with the ground truth. We use 1 to denote a mislabelled pixel and 0 otherwise. Specifically, the error map is defined as
$$M^i_{err} = \begin{cases} 1, & S^i_{sb} \neq S^i_{gt} \\ 0, & \text{otherwise} \end{cases} \tag{1}$$
where $S^i_{sb}$ and $S^i_{gt}$ are the predicted semantic label and the ground-truth label of pixel $i$, respectively.
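The construction of the ground-truth error map can be sketched in a few lines. This is only an illustration on flat lists of per-pixel labels; the function name `error_map` is hypothetical and not taken from the paper.

```python
def error_map(pred_labels, gt_labels):
    """Return a binary error map: 1 where the semantic-branch label
    differs from the ground-truth label, 0 where they agree."""
    if len(pred_labels) != len(gt_labels):
        raise ValueError("label lists must have the same length")
    return [int(p != g) for p, g in zip(pred_labels, gt_labels)]


# Four pixels; the second and fourth are mislabelled by the semantic branch.
m_err = error_map([0, 1, 2, 2], [0, 2, 2, 1])
```

Because the map depends on the frozen semantic branch's predictions, it can be computed on the fly for each training image.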
Since the number of erroneous pixels is usually much smaller than the number of correct pixels, we adopt a balanced version of the cross-entropy loss to deal with the imbalance in training data. In addition, the erroneous pixels are categorized into two types with different weights in the cross-entropy loss: (1) "easy" erroneous pixels. Inspired by Ref. [10], we define the erroneous pixels with classification scores no larger than a threshold $\rho$ as the "easy" erroneous pixels (i.e., $\max(P^i_{sb}) \le \rho$). These "easy" erroneous pixels are easy to detect from the initial prediction $P_{sb}$ given as input; (2) "hard" erroneous pixels. The remaining erroneous pixels, with classification scores larger than $\rho$, are defined as "hard" erroneous pixels (i.e., $\max(P^i_{sb}) > \rho$). These pixels are misclassified with high confidence and are therefore hard to detect. Hence we assign a larger loss weight to the "hard" erroneous pixels. In summary, the balanced cross-entropy loss is formulated as
$$L_{ep} = -w_1 \sum_{i \in M^{+e}_{err}} \log P^i_{ep} - w_2 \sum_{i \in M^{+h}_{err}} \log P^i_{ep} - w_3 \sum_{i \in M^{-}_{err}} \log\left(1 - P^i_{ep}\right) \tag{2}$$
where $M^{+e}_{err}$ and $M^{+h}_{err}$ are the "easy" erroneous pixels and "hard" erroneous pixels respectively, and $M^{-}_{err}$ are the negatively labelled pixels. We set $w_1 = 1.0$ and $w_2 = 1.5$. The value of the weight $w_3$ is 0.04 on average; it can be computed according to the proportion of erroneous pixels in an image.
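A minimal sketch of this weighted loss on flat per-pixel lists may help make the three weights concrete. Several details here are assumptions rather than facts from the paper: the function name, the use of a sum rather than a mean over pixels, the clamping constant `eps`, and the default rule for `w3` (the fraction of erroneous pixels in the image), which is only one plausible reading of "computed according to the proportion of the erroneous pixels".

```python
import math


def balanced_error_loss(p_ep, p_sb_max, m_err, rho=0.95,
                        w1=1.0, w2=1.5, w3=None, eps=1e-7):
    """Weighted binary cross-entropy over pixels.

    p_ep:     predicted error probability per pixel
    p_sb_max: largest class probability of the semantic branch per pixel
    m_err:    ground-truth error map (1 = mislabelled, 0 = correct)
    """
    if w3 is None:
        # Assumption: weight correct pixels by the share of erroneous pixels.
        w3 = sum(m_err) / len(m_err)
    loss = 0.0
    for p, conf, m in zip(p_ep, p_sb_max, m_err):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        if m == 1:
            # Easy erroneous pixel if the semantic branch was unsure,
            # hard erroneous pixel if it was confidently wrong.
            weight = w1 if conf <= rho else w2
            loss -= weight * math.log(p)
        else:
            loss -= w3 * math.log(1.0 - p)
    return loss


# One easy error, one hard error, one correct pixel, all predicted at 0.5.
loss = balanced_error_loss([0.5, 0.5, 0.5], [0.6, 0.99, 0.99], [1, 1, 0], w3=0.04)
```

With equal predictions, the hard erroneous pixel contributes 1.5 times as much as the easy one, while the correct pixel contributes very little, which is the intended rebalancing.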

Detail branch
Once we know which pixels are likely to be mislabeled by the semantic branch, we wish to correct the errors with the detail branch. Thus, the detail branch is trained to predict the correct labels for the mislabeled pixels. This branch is designed to be a decoder branch to obtain a pixel-wise segmentation result using features from the semantic branch as input, where the low-level features are fed into the corresponding decoder stages using skip connections. Specifically, we use 3 successive decoder blocks as shown in Fig. 2 to build the decoder with GAU.
During training, we require the detail branch to achieve higher accuracy for the erroneous pixels so that it can correct errors in the initial segmentation results from the semantic branch. To this end, we design the loss function to make training focus on the erroneous pixels captured by the error-prediction branch. Specifically, a pixel-wise weight $E_{ep}$ derived from $P_{ep}$ is used in the loss function:
$$L_{db} = -\sum_{i} E^i_{ep} \sum_{c=1}^{C} S^{i,c}_{gt} \log P^{i,c}_{db} \tag{3}$$
where $P^{i,c}_{db}$ is the probability of the $i$-th pixel belonging to the $c$-th category, and $S^{i,c}_{gt}$ equals 1 if the $i$-th pixel belongs to the $c$-th category, and 0 otherwise. The pixel-wise loss weight $E_{ep}$ is a binary map generated from the probability map $P_{ep}$, which is predicted by the error-prediction branch, using a binarization threshold $t$:
$$E^i_{ep} = \begin{cases} 1, & P^i_{ep} > t \\ 0, & \text{otherwise} \end{cases}$$
With this binary loss weight, only the pixels that are classified as mislabeled, i.e., with probabilities larger than $t$ in $P_{ep}$, contribute to the loss. Thus, our network is also designed as a cascading architecture: the semantic branch is able to classify most of the easy pixels correctly, and the other pixels, which are highly likely to be mislabeled, are passed to the detail branch.
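The masked cross-entropy used for the detail branch can be illustrated as follows. This is a sketch on flat per-pixel lists, not the authors' implementation; the function name, the unnormalized sum, and the clamping constant `eps` are assumptions.

```python
import math


def detail_branch_loss(p_db, gt_labels, p_ep, t=0.7, eps=1e-7):
    """Cross-entropy restricted to pixels flagged as erroneous.

    p_db:      per-pixel class probabilities from the detail branch
    gt_labels: ground-truth class index per pixel
    p_ep:      predicted error probability per pixel
    t:         binarization threshold for the error mask
    """
    loss = 0.0
    for probs, label, pe in zip(p_db, gt_labels, p_ep):
        flagged = pe > t  # binary mask value for this pixel
        if flagged:
            # With a one-hot ground truth only the true class term survives.
            loss -= math.log(max(probs[label], eps))
    return loss


# Only the first pixel exceeds the threshold, so only it is penalized.
loss = detail_branch_loss([[0.5, 0.5], [0.25, 0.75]], [0, 1], [0.9, 0.1])
```

Pixels below the threshold contribute nothing, so gradient updates concentrate on the regions the error-prediction branch considers unreliable.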

Fusion
During the fusion stage, we combine the segmentation results from the semantic branch and the detail branch. The final segmentation result is computed as a pixel-wise linear combination according to the binary error mask $E_{ep}$:
$$P^i = \left(1 - E^i_{ep}\right) P^i_{sb} + E^i_{ep} P^i_{db}$$
where $P^i$ denotes the final class probabilities of pixel $i$. Since the hard erroneous pixels are also trained as erroneous pixels in the error-prediction branch, they too can be corrected by the detail branch in the fusion step, provided that they are correctly classified as erroneous pixels during inference. This is superior to the LC method [10], in which hard erroneous pixels are simply ignored.
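As an illustration of the fusion rule, the sketch below selects, per pixel, the class probabilities of the detail branch where the error mask is 1 and those of the semantic branch elsewhere, and then takes the most probable class. The function name `fuse` and the flat-list data layout are hypothetical.

```python
def fuse(p_sb, p_db, p_ep, t=0.7):
    """Fuse semantic-branch and detail-branch predictions per pixel.

    p_sb, p_db: per-pixel class probability lists from the two branches
    p_ep:       predicted error probability per pixel
    Returns the final class index for each pixel.
    """
    labels = []
    for sb, db, pe in zip(p_sb, p_db, p_ep):
        e = 1 if pe > t else 0  # binary error mask value
        probs = [(1 - e) * a + e * b for a, b in zip(sb, db)]
        labels.append(max(range(len(probs)), key=probs.__getitem__))
    return labels


# Pixel 0 keeps the semantic-branch label; pixel 1 takes the detail-branch label.
final = fuse([[0.9, 0.1], [0.6, 0.4]], [[0.2, 0.8], [0.3, 0.7]], [0.1, 0.8])
```

Because the mask is binary, the linear combination reduces to a per-pixel selection between the two branches.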

Branch training
Because our method aims to improve a given segmentation network, we keep the semantic branch fixed during the whole training procedure, i.e., the parameters are frozen and both batch normalization [31] layers and dropout [32] layers in the semantic branch are always in inference mode. We first train the error-prediction branch with the loss function defined in Eq. (2) for 60k iterations. After the error-prediction branch has converged, we fix it and update the detail branch using Eq. (3) for 90k iterations.

Optimizer and learning rate
We adopt a poly learning rate policy similar to Ref. [21], where the initial learning rate is multiplied by $(1 - \mathrm{iter}/\mathrm{max\_iter})^{\mathrm{power}}$ with $\mathrm{power} = 0.9$. We employ Adam [33] as the optimizer during training.
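The poly policy is simple enough to write down directly. In this sketch the base learning rate of 0.01 is a placeholder chosen for illustration, since the text does not state the initial value; 60k is the iteration count given for the error-prediction branch.

```python
def poly_lr(base_lr, iteration, max_iter, power=0.9):
    """Poly schedule: base_lr * (1 - iteration / max_iter) ** power."""
    return base_lr * (1.0 - iteration / max_iter) ** power


# Hypothetical base learning rate, decayed over 60k iterations.
start = poly_lr(0.01, 0, 60000)
halfway = poly_lr(0.01, 30000, 60000)
end = poly_lr(0.01, 60000, 60000)
```

The rate starts at the base value, decays slightly faster than linearly in the early iterations, and reaches zero at the final iteration.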

Group normalization
In general, the performance of the batch normalization layer is related to the batch size. However, in practice, the batch size is constrained by the limited GPU memory. To improve stability during optimization, we adopt group normalization [34] in both the error-prediction and detail branches; the channels are divided into 32 groups in our implementation.

Data augmentation
Following the training protocol of Refs. [7,12], we randomly crop patches from the image during training, with a crop size of 769 (DeepLabv3+ based model) or 713 (PSP-Net based model) for the Cityscapes dataset, and 513 for the ADE20K dataset. For data augmentation, random scaling (from 0.5 to 2 with a step size of 0.25), random left-right flipping, and random rotation between −10° and 10° are applied.
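The augmentation parameters described above could be sampled as in the following sketch. Only the ranges come from the text; the function name, the dictionary layout, and the flip probability of 0.5 are assumptions made for illustration.

```python
import random

# Scale factors from 0.5 to 2.0 in steps of 0.25.
SCALES = [0.5 + 0.25 * k for k in range(7)]


def sample_augmentation(rng):
    """Draw one set of augmentation parameters for a training image."""
    return {
        "scale": rng.choice(SCALES),
        "flip": rng.random() < 0.5,  # assumed flip probability
        "rotation_deg": rng.uniform(-10.0, 10.0),
    }


params = sample_augmentation(random.Random(0))
```

The same sampled parameters would have to be applied to both the image and its label map so that pixels and labels stay aligned.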

Datasets
We evaluated our network on an urban scene dataset, Cityscapes [2], and a diverse scenes dataset, ADE20K [3]. These two datasets provide densely annotated images, which are important for recovering segmentation details when training our method. The Cityscapes dataset contains high-quality dense annotations with 19 object classes for 5000 images (2975, 500, and 1525 for the training, validation, and testing sets, respectively) and 20,000 coarsely annotated images. ADE20K is a more challenging dataset with 150 object classes, with 20,210, 2000, and 3000 images for the training, validation, and testing sets, respectively.

Evaluation of branches
In this section, we consider experiments to analyze the performance of the proposed branches in our network. We employed DeepLabv3+ as our semantic branch and kept it fixed in the experiments for ease of interpretation. The network was trained using the Cityscapes training set and all outcomes are reported for the validation set.

Error-prediction branch
The error-prediction branch is simply a classifier that predicts the pixel-wise error probability. We illustrate a predicted error probability map and ground-truth error map in Fig. 3. The ground-truth error map is computed as the difference between the semantic label map output by the semantic branch and the ground-truth. Given the error probability map, we consider pixels with error probability larger than the threshold $t$ to be erroneous pixels, from $E_{ep}$ in Eq. (3). Thus, the mean intersection over union (mIoU) between the predicted error mask $E_{ep}$ and ground-truth error mask $M_{err}$, defined as error-pixels mIoU, can be computed, as reported in the 2nd column in Table 1 for different values of the threshold $t$. In order to show how many hard erroneous pixels are captured by the error-prediction branch, we also compute the recall values, the percentage of hard erroneous pixels classified as erroneous pixels, and report them in the 3rd column in Table 1 as hard recall. The hard erroneous pixels are those mislabelled pixels whose largest class probability is larger than $\rho = 0.95$, which is consistent with the definition in the LC.
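The two evaluation quantities can be sketched as follows. The text does not spell out how the error-pixels mIoU is averaged; this sketch assumes the mean of the IoU over the two classes of the binary mask (erroneous and correct), and all function names are hypothetical.

```python
def binary_iou(pred, gt, cls):
    """IoU of one class between two binary masks given as flat lists."""
    inter = sum(1 for p, g in zip(pred, gt) if p == cls and g == cls)
    union = sum(1 for p, g in zip(pred, gt) if p == cls or g == cls)
    return inter / union if union else float("nan")


def error_pixel_miou(p_ep, m_err, t=0.7):
    """Assumed definition: mean IoU over the 'correct' and 'erroneous' classes."""
    e_ep = [1 if p > t else 0 for p in p_ep]
    return (binary_iou(e_ep, m_err, 0) + binary_iou(e_ep, m_err, 1)) / 2


def hard_recall(p_ep, m_err, p_sb_max, t=0.7, rho=0.95):
    """Fraction of hard erroneous pixels that are flagged as erroneous."""
    hard = [i for i, (m, c) in enumerate(zip(m_err, p_sb_max))
            if m == 1 and c > rho]
    if not hard:
        return float("nan")
    return sum(1 for i in hard if p_ep[i] > t) / len(hard)


p_ep = [0.9, 0.8, 0.2, 0.1]
m_err = [1, 0, 1, 0]
p_sb_max = [0.99, 0.5, 0.99, 0.5]
miou = error_pixel_miou(p_ep, m_err)
recall = hard_recall(p_ep, m_err, p_sb_max)
```

Raising `t` flags fewer pixels, which tends to reduce false alarms but also lowers the chance that a hard erroneous pixel is caught.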
It can be seen that with increasing value of binarization threshold t, the mIoU of the erroneous pixels increases while the hard recall drops. Since a small mIoU indicates that a large number of correct pixels are classified as erroneous pixels, which will distract the subsequent training of the detail branch, we need to balance the mIoU and hard recall so as to achieve high segmentation accuracy. The influence of the choice of different threshold values on the final segmentation accuracy on the Cityscapes validation set is reported in Table 2. As a result of the experiments, we set the binarization threshold to 0.7 for training the detail branch.

Detail branch
The detail branch is trained using the cross-entropy loss for the erroneous pixels predicted by the error-prediction branch. As reported in Table 3, fusion of the segmentation results from the detail branch and the semantic branch can improve the mIoU of DeepLabv3+, used as the semantic branch, by 1.11%. The designed detail branch has 3 decoder stages requiring an additional 11.6 MB of memory for its parameters, and is more complex than the lightweight decoder of DeepLabv3+. It is thus worthwhile to verify whether the performance gains come from the additional parameters or from the cascading of the error-prediction and detail branches. We thus conducted an experiment to directly replace the original 1-stage decoder in DeepLabv3+ with our detail branch. We trained this network variant, called DeepLabv3+-GAU-Decoder, with the same training strategy as DeepLabv3+, in which the cross-entropy loss is equally weighted over all pixels. Its mIoU (78.89%) is reported in the row for DeepLabv3+-GAU-Decoder in Table 3: it is slightly higher than that of the original DeepLabv3+ network, but still inferior to that of our overall network. We also report the mIoU of the network proposed by Tian et al. [35], which improves the decoder of DeepLabv3+ by data-dependent upsampling and improves the mIoU by 0.27% (the 3rd row in Table 3). Thus, increasing the complexity of the decoder is not as effective as our method of allocating resources to erroneous pixels.

Running time
The average running time of our full network is 0.71 s for an input image with resolution 1025 × 2049 (the output segmentation result is 1/4 of the input resolution), the semantic branch taking 0.38 s, the error-prediction branch 0.05 s, and the detail branch 0.28 s.

Layer-cascading
To compare with LC [10], we adopt layer-wise cascading in the decoder of the DeepLabv3+-GAU-Decoder network discussed above; we denote this variant model as LC-GAU-Decoder and illustrate it in Fig. 4. As in LC, stage 1 predicts a segmentation result, and pixels with classification score smaller than a threshold $\rho$ are propagated to stage 2. Stage 2 follows the same propagation procedure. We set $\rho = 0.95$, consistent with the definition of the hard erroneous pixels, to test how simply discarding hard erroneous pixels in LC influences the segmentation results. As reported in Table 3, our method outperforms LC-GAU-Decoder by 0.17%.
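The thresholding rule of layer cascading can be sketched as below. This is a simplified illustration of the propagation logic only, not the LC implementation from Ref. [10]: it assumes the per-stage probabilities are already available for every pixel, whereas a real cascade would compute later stages only on the propagated pixels. The function name is hypothetical.

```python
def layer_cascade(stage_probs, rho=0.95):
    """Assign labels stage by stage.

    stage_probs: list over stages; each entry holds per-pixel class
                 probability lists produced by that stage.
    A pixel is accepted at the first stage whose top score reaches rho;
    otherwise it is passed on. The last stage accepts everything left.
    """
    n_pixels = len(stage_probs[0])
    labels = [None] * n_pixels
    active = list(range(n_pixels))
    for s, probs in enumerate(stage_probs):
        is_last = s == len(stage_probs) - 1
        remaining = []
        for i in active:
            score = max(probs[i])
            if score >= rho or is_last:
                labels[i] = probs[i].index(score)
            else:
                remaining.append(i)
        active = remaining
    return labels


# Pixel 0 is accepted at stage 1 (score 0.97); pixel 1 is decided at stage 2.
labels = layer_cascade([
    [[0.97, 0.03], [0.6, 0.4]],
    [[0.10, 0.90], [0.2, 0.8]],
])
```

In the example, pixel 0 keeps its stage-1 label even though stage 2 would have disagreed, which mirrors how a confidently wrong pixel is never revisited under this rule.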

Hard-mining
For a fair comparison, we employed loss-rank mining from Ref. [36] as a hard-mining method to train the DeepLabv3+-GAU-Decoder network, where the decoder of DeepLabv3+ is replaced by our proposed decoder in the detail branch. In this method, the cross-entropy loss is calculated for each pixel and then all pixels are ranked in order of descending loss. Only a proportion of pixels with the highest loss (20% in our experiment) contribute to the training process. Although the hard-mining method enhances the training of hard examples, it is still inferior to our network in terms of mIoU: see the 5th row of Table 3.
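The selection step of loss-rank mining can be sketched as a mask over per-pixel losses. The function name and the tie-breaking by pixel index are assumptions; the 20% ratio is the value used in the experiment described above.

```python
def loss_rank_mask(pixel_losses, keep_ratio=0.2):
    """Return a 0/1 mask keeping the pixels with the highest losses."""
    n_keep = max(1, int(len(pixel_losses) * keep_ratio))
    # Rank pixel indices by descending loss (stable for ties).
    order = sorted(range(len(pixel_losses)),
                   key=lambda i: pixel_losses[i], reverse=True)
    keep = set(order[:n_keep])
    return [1 if i in keep else 0 for i in range(len(pixel_losses))]


losses = [0.1, 2.0, 0.3, 0.05, 1.5, 0.2, 0.4, 0.0, 0.6, 0.7]
mask = loss_rank_mask(losses)  # keeps the two largest losses out of ten
```

Unlike the error-prediction branch, this selection needs the ground-truth loss and is therefore only available during training.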

Error-pixel based fusion vs. bagging
A bagging result was obtained by training the detail branch by setting every pixel as erroneous and then averaging its result with the result from the initial DeepLabv3+ network, which leads to an mIoU of 79.38% (the 6th row in Table 3). Instead of directly averaging the results of the semantic branch and the detail branch, we combine the two branch results guided by error probability, which gives a 0.52% gain with respect to average bagging. Additional visual comparisons using the methods introduced above are shown in the Electronic Supplementary material (ESM).

Integration with other segmentation networks
In this section, we report how the error-prediction and detail branches can be cascaded with PSP-Net [7] and the DPC network [24] to improve segmentation accuracy. Specifically, for PSP-Net, we concatenate features generated by pyramid pooling as high-level features to provide global attention for the error-prediction branch. For DPC, we use the features generated by the dense prediction cell as the input to GAU. The error-pixel mIoU, hard recall, and the final mIoU of segmentation results are reported in Table 4. The binarization threshold is again set to $t = 0.7$, and the threshold for hard erroneous pixels is set to 0.95. The results suggest that our approach can correct errors and boost mIoU for various advanced segmentation models. Various visual results are shown in Fig. 5. Our method achieves more detailed segmentation results for some difficult classes like "pole".

Comparison with other state-of-the-art methods
In this section, we further evaluate our method on the Cityscapes benchmark testing dataset (1525 images) and the ADE20K validation dataset (2000 images), which have more images than the Cityscapes validation dataset (500 images) used in the experiments reported so far. The binarization threshold is set to 0.7 below.

Cityscapes benchmark testing dataset
We used the state-of-the-art method Xception71-DPC as the semantic branch, and trained the detail branch on the trainval fine set because the finely annotated images in this set can provide valid training data for segmentation details. Our proposed method achieves an mIoU of 82.88% on the test set, as reported in Table 5. It improves upon DeepLabv3+ by 1.69% and the original DPC network by 0.22%. More visual results are illustrated in Fig. 6.

ADE20K
We selected Xception65-DeepLabv3+ as the semantic branch and trained our network using the ADE20K training set. Our network improves the accuracy of Xception65-DeepLabv3+ as reported in Table 6. A qualitative comparison of the segmentation results is shown in Fig. 7.

Fig. 6 Visual results selected from the Cityscapes testing dataset. Semantic branch: DPC network [24].

Conclusions
We have proposed a method to improve semantic image segmentation results by predicting erroneous pixels and re-estimating the semantic labels of these pixels. Our method can improve the segmentation mIoU of state-of-the-art segmentation networks. The experimental results have demonstrated that cascading error-prediction and detail branches can improve segmentation results. In the future, we would like to investigate how to improve the mIoU of the erroneous pixels with attention techniques and layer-wise cascading of error prediction and image segmentation.