1 Introduction

Few-shot learning (FSL) is of great significance for at least the following two scenarios [1]: First, FSL can relieve the heavy burden of data gathering and labelling, which can broaden the use of deep learning techniques, especially for users without sufficient resources. Second, FSL is an important solution for applications in which rare cases matter or image acquisition is costly due to high operational difficulty or ethical issues. Typical examples of such applications include computer-assisted diagnosis with medical imaging and classification of endangered species.

An FSL task is typically formulated as follows: given support images with corresponding labels and a query image without any label, the task is to find the label of the query image based on the labels of the support images. With this formulation, most FSL methods train the model on base (seen) classes and evaluate it on novel (unseen) classes, assuming that knowledge can be extracted from the base classes and transferred to the novel classes. However, this is not always the case. The knowledge in a pre-trained backbone convolutional neural network (CNN), which computes the features of an input image, may be useless when the novel classes show significant visual differences from the base class images [2]. For example, with sheep always on grass and cats mostly in indoor environments, an FSL model may classify an image showing a cat on grass as “sheep”, because “cat” is visually very different from all base classes while sharing a similar background with one of them. What makes matters worse is that we have no way to see whether the visual differences between the base and novel classes are significant for an FSL model. This raises an essential question: Is there any way to see what is transferred from the base classes to the novel classes? Most research on FSL does not pay attention to what is extracted by the backbone CNN.

In this study, we redesign the mechanism of knowledge transfer for FSL tasks, offering an answer to the above question. Our approach is inspired by what humans seemingly do when trying to recognize a rarely seen object. That is, we usually try to find some patterns in the object and match them in a small number of previously seen examples in our memory. We mimic this process by designing a self-explainable attention module, and propose a new FSL method, named a match-them-up network (MTUNet), which consists of a pattern extractor (PE) and pairwise matching (PM).

The PE is designed to find discriminative patterns for image representation. The knowledge transferred from the base classes to the novel classes is thus the learned patterns. Owing to the explainability of the PE, the extracted patterns can be easily visualized by highlighting them in the images, as shown in Fig. 1(a). This directly means that we have a way to see what is transferred in our FSL pipeline. The patterns extracted from each of the support and query images are aggregated to form a discriminative image representation, which is shown as the overall attention in Fig. 1(b) and is used for matching. In Fig. 1(b), the visualization of the aggregated patterns collectively shows a consistent and meaningful clue for images of the same class. For example, the PE shows strong attention on the neck of the goose in the second column, which is consistent in both support and query images (even for sub-images in the latter). Image representation based on the patterns learned from the base classes makes matching between a pair of images much easier by incorporating only a small number of regions to pay attention to.

Fig. 1

Few-shot learning using pair-matching with the pattern extractor (PE). Images are from the mini-ImageNet dataset [3]

On top of the PE, PM is adopted to determine whether image pairs belong to the same class. Each pair consists of one image from the support set and one image from the query set. The category of the support image that has the highest similarity score is regarded as the query image’s category. Together with the PE, MTUNet can provide a matching score to further relate the visualization and model decision.

The main contributions of our work include:

  • We propose a new explainable FSL model that achieves high classification accuracy, qualitatively and quantitatively showing its explainability.

  • We design the PE module to spatially filter the original image’s features provided by a backbone CNN, keeping only informative regions of specific patterns that contribute to better FSL classification performance. Visualization of these regions plays a central role in MTUNet’s explainability as it presents the model’s basis of prediction.

  • We introduce a PM mechanism that relates the visual explanations to the model decision via matching scores, which may help find potential prediction failures.

  • Our method combines several techniques and concepts, e.g., FSL, attention, feature representation, and explainable AI, which can inspire future research.

This paper is an extension of a four-page CVPR 2021 workshop paper [4]. It adds a more detailed description of our method (Section 3), an extensive literature review (Section 2), a discussion based on our experimental results (Section 4.5), and limitations and future work (Section 4.5.3). The extension includes the following technical contributions: (1) We introduce PE pre-training, which leads to better FSL classification performance. We also redesigned and detailed our methodology (e.g., Section 3), re-ran all experiments with the redesigned method (Section 4.3), and added figures (e.g., Fig. 1) for easier understanding. (2) We add new experimental results on another dataset, CIFAR-FS, which show better classification performance than existing methods and validate the generalizability of our method to different datasets. (3) To compare with previous XAI methods, we design an experiment using existing XAI metrics in Section 4.4.2. The results quantitatively demonstrate the explainability of the proposed method. (4) We add a discussion based on our experimental results (Sections 4.5.1 and 4.5.2).

2 Related work

2.1 Few-shot learning

Recently, owing to the availability of a sufficient number of labelled images, deep neural networks have achieved outstanding performance on various classification tasks. Such large datasets usually require a large amount of effort to create, and some tasks, such as medical tasks [5, 6], may not inherently have enough supervising signals. For these tasks, we require a new paradigm that allows training a model with a small number of labelled images. The popular FSL settings [3, 7, 8] serve as a testbed for certain aspects of such data-scarce tasks. Recent efforts toward FSL are summarized as follows.

Image embedding and metric learning

Many works focus on transforming images into vectors in an embedding space, in which the distance between a pair of vectors represents their conceptual dissimilarity. A Siamese network [9] uses a shared feature extractor to produce image embeddings for both support and query images, and the weighted \(\ell _{1}\) distance is used as the classification criterion. Metric learning [3, 7] can offer a better way to train the mapping into the embedding space. Some works try to improve the discriminative power of image embeddings. SimpleShot [10] applies \(\ell _{2}\) normalization and centring to make the distance calculation easier. Instead of computing a predefined distance, some works use a multi-layer perceptron (MLP) to parameterize and train similarity metrics [11,12,13]. A recent work [14] uses a two-stream network for better feature representation, which improves the FSL performance.

Meta-learning

Another major approach to FSL is to optimize models so that they can rapidly adapt to novel classes. The method in [15] fine-tunes the feature extractor using support images of novel classes; however, with very few support samples, overfitting limits its success. MAML [16] and its extensions [17, 18] train initial parameters that can be adapted to a target task with only a small amount of data through one or a few gradient steps. Besides training good initial parameters, Meta-SGD [19] also trains the update direction and step size. UDS [20] adopts an unsupervised meta-learning algorithm to localize and select semantically meaningful regions in feature maps, which enables better FSL performance. A recent work [21] extends FSL to a multi-label scenario, which is meaningful for real-world applications.

Data augmentation

Solving an FSL problem by augmenting the training data is straightforward and easy to understand. Data augmentation aims at introducing invariance into models so that they capture information at both the image and feature levels [22, 23]. Some works also try to use samples that are weakly labelled or unlabelled [24, 25]. ICI [26] introduces a judgment mechanism to enhance the training set by utilizing unlabelled data with confidently predicted labels.

Transductive or Semi-supervised Paradigm

Transductive or semi-supervised approaches [27, 28] have made great progress in the past few years. They use the statistics of query examples or statistics across FSL tasks, assuming that all novel images for classification are accessible. We only employ the original inductive paradigm to explore explainable feature extraction, but our idea can be easily adapted to a transductive paradigm.

2.2 Zero-shot learning

Zero-shot learning (ZSL) is another challenging task, as no samples are available for the unseen classes. An early attempt [29] proposed attribute-based classification using human-specified high-level attributes: unseen classes can be predicted from the combination of detected attributes, without training on those classes. Some methods utilize inter-class relationships through graph neural networks [30, 31]. Wang et al. [30] adopt a graph to use both semantic embeddings and categorical relationships to generate classifiers. OCITN [32] deals with the situation where the training data contain only one class; the goal is to determine whether the input belongs to the seen class or an unseen class. Recently, a cluster-based ZSL method [33] was proposed, which recasts ZSL as a multivariate binary classification problem.

Our method employs an idea similar to attribute-based classification: the PE is designed to learn and extract a set of patterns that can describe all possible classes in episodes of the FSL classification task.

2.3 Explainable AI

Deep neural networks are often considered black boxes, and explainable artificial intelligence (XAI) is a series of attempts to unveil them. Most XAI methods for classification tasks are based on back-propagation [34,35,36] or perturbation [37]. These methods are post hoc and can only provide explanations outside model training. There are also intrinsic methods that aim to explain the model decision as part of the model itself. A new type of intrinsic XAI, coined SCOUTER [38], applies a self-attention mechanism [39] to the classifier; it extracts the attention for each class during training, which makes the classification results explainable.

XAI methods have been widely applied to many deep learning tasks [40]; however, only a few works [4, 41,42,43] have tried XAI for FSL tasks. Geng et al. [42] use a knowledge graph to explain zero-shot tasks. Sun et al. [41] adopt layer-wise relevance propagation (LRP) [44] to explain the output of a classifier. StarNet [43] realizes visualization through heat maps derived from back-projection. These methods are based on XAI ideas for general classification tasks and are not well suited to the training protocol of FSL (sampling support and query sets [3]). Most of them are not evaluated on FSL benchmark datasets, which makes them hard to compare. Thus, an FSL model that offers both high classification accuracy and interpretability is important.

In this study, we adopt the intrinsic approach to XAI to explore a new explainable FSL paradigm. Compared to previous FSL methods, MTUNet has the PE, based on the self-attention mechanism [45], which extracts informative regions to improve FSL classification performance. Another difference from previous FSL methods is MTUNet’s explainability: through the combination of the PE and PM, MTUNet can provide insight into why the model classified a query image into a certain unseen class (refer to Section 4.4). Our experiments show that the explanations given by MTUNet can help find potential prediction failures, which is important for risk-sensitive domains such as medical applications.

3 Material and methods

3.1 Problem definition

This study addresses an inductive FSL task (cf. the transductive task [27, 28]), in which we are given two disjoint sets \(\mathcal {D}_{\text {base}}\) and \(\mathcal {D}_{\text {novel}}\) of samples. The former is a base set of many labelled base class images, whereas the latter is a novel set of a few labelled novel class images; the disjoint sets of base and novel classes are denoted by \(\mathcal {C}_{\text {base}}\) and \(\mathcal {C}_{\text {novel}}\), respectively. The FSL task is to find a mapping from a novel image \(x \in \mathcal {D}_{\text {novel}}\) to the corresponding class \(y \in \mathcal {C}_{\text {novel}}\), with the images in \(\mathcal {D}_{\text {base}}\) and their labels available for training.

The literature typically uses the K-way N-shot episodic paradigm for training and evaluating FSL models. For each episode in training, we sample a support set \(\mathcal {S} = \{(x_{kn}, y_{kn}) \mid k=1,\dots ,K,\ n = 1, 2, \dots , N\}\) and a query image \(x^{\text {q}}\) from a query set \(\mathcal {Q}\). The support set contains N images for each of K classes in \(\mathcal {C}_{\text {base}}\) and serves as the basis for classifying a query image into the same K classes.

Our FSL model is trained to find a match between a query image and a support image in \(\mathcal {S}\), i.e., the query image is classified with the class of the matched image in \(\mathcal {S}\). Evaluation can be performed within the same paradigm by sampling query and support sets from \(\mathcal {D}_{\text {novel}}\).
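To make the episodic paradigm concrete, the following is a minimal sketch of how one K-way N-shot episode with M queries per class could be sampled. The data layout (a dictionary `images_by_class` from class labels to image identifiers) and the helper name are illustrative assumptions, not part of our implementation.

```python
import random

def sample_episode(images_by_class, K=5, N=1, M=15):
    """Sample one K-way N-shot episode with M query images per class.

    images_by_class: dict mapping each class label to a list of image ids
    (an assumed layout; the paper only requires labelled base images).
    Returns a support list and a query list of (image_id, episode_label) pairs.
    """
    classes = random.sample(list(images_by_class), K)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = random.sample(images_by_class[c], N + M)
        support += [(img, episode_label) for img in picked[:N]]
        query += [(img, episode_label) for img in picked[N:]]
    return support, query
```

The same routine can be reused for evaluation by sampling the classes from the novel set instead of the base set.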

3.2 Overview

The overall process is illustrated in Fig. 2. In each episode, we extract a feature map \(F=f_{\theta }(x)\in \mathbb {R}^{c \times h\times w}\) from each image x in \(\mathcal {S}\) and from the query image using the CNN backbone \(f_{\theta }\), where 𝜃 is the set of learnable parameters. F is then fed into the pattern extractor (PE) module \(f_{\phi }\) with learnable parameter set ϕ. This module provides attention \(A = f_{\phi }(F) \in \mathbb {R}^{z \times l}\) over F. Our pairwise matching (PM) module uses an MLP to compute a score that indicates how likely query image \(x^{\text {q}}\) is to belong to each of the K classes in \(\mathcal {S}\).

Fig. 2

Overall structure of MTUNet. One query is processed by the CNN backbone and pattern extractor (PE) to provide exclusive patterns and then turned into overall attention. The query is concatenated to each support to make a pair for final discrimination through pairwise matching (PM). The dotted line represents each support image undergoing the same calculation as the query

The PE plays a major role in the learning of FSL tasks. It is designed to learn a transferable attention mechanism that finds common patterns shared among different episodes sampled from \(\mathcal {D}_{\text {base}}\). Consequently, the patterns are also likely to be shared by \(\mathcal {D}_{\text {novel}}\), given that \(\mathcal {D}_{\text {base}}\) and \(\mathcal {D}_{\text {novel}}\) are from similar domains.

3.3 Pattern extractor

Figure 3 shows the structure of our PE module. The input feature map F is first fed into a 1 × 1 convolution layer followed by a ReLU nonlinearity to squeeze the dimensionality of F from c to d. The spatial dimensions of the squeezed features are flattened to form \(F^{\prime } \in \mathbb {R}^{d \times l}\), where l = hw. To maintain the spatial information, position embedding P [38, 46, 47] is added to the features, i.e., \(\tilde {F} = F^{\prime } + P\).

Fig. 3

The structure of our pattern extractor module

The self-attention mechanism [45] provides the attention over F along the spatial dimension using the dot-product similarity between a set of z patterns and \(\tilde {F}\) after nonlinear transformations. The PE repeats this process T times, updating the patterns with a gated recurrent unit (GRU) to refine the attention. That is, let \(W^{(t)} \in \mathbb {R}^{z \times d}\) denote the patterns in the t-th repetition, where \(t = 1, 2, \dots , T\) and \(W^{(1)} = W\) is the set of learnable initial patterns. The nonlinear transformations of \(W^{(t)}\) and \(\tilde {F}\) are given by

$$ g_{\text{Q}}(W^{(t)}) \in \mathbb{R}^{z \times d}, \quad g_{\text{K}}(\tilde{F}) \in \mathbb{R}^{d \times l}. $$
(1)

The attention is given using a normalization function ξ as

$$ \begin{array}{@{}rcl@{}} \bar{A}^{(t)} &=& g_{\text{Q}}(W^{(t)}) g_{\text{K}}(\tilde{F}) \end{array} $$
(2)
$$ \begin{array}{@{}rcl@{}} A^{(t)} &=& \xi(\bar{A}^{(t)}) \quad \in (0,1)^{z \times l} , \end{array} $$
(3)

where the patterns \(W^{(t)}\) are updated by

$$ \begin{array}{@{}rcl@{}} U^{(t)} &=& A^{(t)} {F^{\prime}}^{\top} \end{array} $$
(4)
$$ \begin{array}{@{}rcl@{}} W^{(t+1)}&=&\text{GRU}{(U^{(t)}, W^{(t)})}. \end{array} $$
(5)

Let \(\text {Softmax}_{\text {R}}(X)\) denote the softmax function applied to the respective row vectors of matrix X, and let σ denote the sigmoid function. MTUNet modulates this map by

$$ A^{(t)} = \xi(\bar{A}^{(t)}) = \sigma(\bar A^{(t)}) \odot \text{Softmax}_{\text{R}}(\bar A^{(t)}), $$
(6)

which suppresses weak attention over different patterns at the same spatial location, where ⊙ is the Hadamard product. This function forces the network to find more specific yet discriminative patterns with less redundancy among them, thus giving more pinpointed attention and ensuring that the learned patterns are exclusive. As shown in Fig. 1(a), each attention map responds to a single pattern and rarely includes its peripheral region.

The input feature F is finally described by the overall attention \(A^{\prime }\) corresponding to the extracted patterns, i.e.,

$$ A^{\prime} = \frac{1}{z} \mathbf{1}_{z} A^{(T)}, $$
(7)

where \(\mathbf {1}_{z}\) is a row vector whose z elements are all 1. \(A^{\prime }\) is reshaped from length l into the same spatial structure as F. Then the features corresponding to the overall attention are extracted and average-pooled over the spatial dimensions as

$$ V = \frac{1}{hw} \sum\limits_{ij} A_{ij}^{\prime}F_{ij}, $$
(8)

where \(A_{ij}^{\prime } \in \mathbb {R}\) and \(F_{ij} \in \mathbb {R}^{c}\) are the elements of \(A^{\prime }\) and F at the (i,j)-th spatial location (\(i = 1,2,\dots ,h\) and \(j = 1, 2,\dots ,w\)).
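Putting (1)–(8) together, the following is a minimal PyTorch sketch of one PE forward pass for a single image. It is a simplified illustration under several assumptions (no batching, \(g_{\text {Q}}\) and \(g_{\text {K}}\) reduced to single linear layers instead of the three FC layers used in our experiments, and the position embedding omitted), not our exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as TF

class PatternExtractorSketch(nn.Module):
    """Simplified pattern extractor following (1)-(8)."""
    def __init__(self, c=512, d=64, z=7, T=3):
        super().__init__()
        self.squeeze = nn.Conv2d(c, d, kernel_size=1)  # 1x1 conv reducing c -> d
        self.W = nn.Parameter(torch.randn(z, d))       # initial patterns W^(1)
        self.g_q = nn.Linear(d, d)                     # stand-in for g_Q
        self.g_k = nn.Linear(d, d)                     # stand-in for g_K
        self.gru = nn.GRUCell(d, d)                    # pattern update, (5)
        self.T = T

    def forward(self, feat):                           # feat: (c, h, w) backbone feature map F
        c, h, w = feat.shape
        squeezed = TF.relu(self.squeeze(feat.unsqueeze(0))).squeeze(0)
        F_prime = squeezed.flatten(1)                  # (d, l) with l = h * w
        F_tilde = F_prime                              # position embedding P omitted in this sketch
        W = self.W
        for _ in range(self.T):
            A_bar = self.g_q(W) @ self.g_k(F_tilde.t()).t()          # (z, l), dot-product attention, (2)
            A = torch.sigmoid(A_bar) * torch.softmax(A_bar, dim=1)   # xi in (6), row-wise softmax
            U = A @ F_prime.t()                                      # (z, d), (4)
            W = self.gru(U, W)                                       # (5)
        A_overall = A.mean(dim=0).view(h, w)                         # overall attention A', (7)
        V = (A_overall.unsqueeze(0) * feat).sum(dim=(1, 2)) / (h * w)  # pooled feature V, (8)
        return V, A, A_overall

# Example: V, A, A_prime = PatternExtractorSketch()(torch.randn(512, 10, 10))
```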

3.4 Pairwise matching

An FSL classification task can be solved by finding the membership of a query in one of the given support classes. Some FSL methods use metric learning [3, 7] to find matches between a query and the supports, with the cosine similarity or the \(\ell _{2}\) distance being typical choices [10, 48]. Learnable distances are another popular choice for metric learning-based FSL methods [11,12,13]. We use a learnable distance implemented with an MLP (refer to Section 4.5.2).

Let \(V^{\text {q}}\) and \(\{V_{kn}\}\) be the features obtained by applying the PE to query image \(x^{\text {q}} \in \mathcal {Q}\) and support images \(\{x_{kn}\}\) in \(\mathcal {S}\), respectively, where the subscripts \(k = 1, 2, \dots , K\) and \(n = 1, 2, \dots , N\) stand for the n-th image of class k in the K-way N-shot episodic paradigm. For N > 1, the average over the N images is taken to generate a representative feature \(\bar {V}_{k}\); otherwise (i.e., N = 1), \(\bar {V}_{k} = V_{k1}\). To compute the similarity score s between \(V^{\text {q}}\) and \(\bar {V}_{k}\), we use an MLP \(f_{\gamma }\) with learnable parameters γ:

$$ s(V^{\text{q}}, \bar{V}_{k}) = \sigma(f_{\gamma}([V^{\text{q}}, \bar{V}_{k}])), $$
(9)

where [⋅,⋅] is concatenation. xq is classified into class k with maximum s over k, i.e.,

$$ k^{*} = \underset{k}{\arg \max} s(V^{\text{q}}, \bar{V}_{k}). $$
(10)

For a K-way task, our pairwise matching runs the similarity computation K times per query image, which is a typical computational complexity for similarity-based methods, such as [7].
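A minimal sketch of the matching step in (9) and (10) is given below, assuming the PE features of the query and support images have already been computed; the two-layer MLP and its hidden size are illustrative choices rather than the exact \(f_{\gamma }\) used in our experiments.

```python
import torch
import torch.nn as nn

class PairwiseMatchingSketch(nn.Module):
    """Scores a query feature against each class-representative support feature."""
    def __init__(self, c=512, hidden=256):
        super().__init__()
        # f_gamma: an MLP over the concatenated (query, support) feature pair
        self.f_gamma = nn.Sequential(nn.Linear(2 * c, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, v_query, v_support):
        # v_query: (c,) PE feature of the query; v_support: (K, N, c) PE features of the support set
        v_bar = v_support.mean(dim=1)                                # representative feature per class
        pairs = torch.cat([v_query.expand_as(v_bar), v_bar], dim=1)  # (K, 2c) concatenated pairs
        scores = torch.sigmoid(self.f_gamma(pairs)).squeeze(1)       # s(V_q, V_bar_k) in (0, 1), (9)
        return scores, int(scores.argmax())                          # scores and predicted class k*, (10)
```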

3.5 Training

For training, we sample a set \(\mathcal {Q} = \{(x_{km}^{\text {q}}, y^{\text {q}}_{km}) \mid k=1,\dots ,K,\ m=1,\dots ,M\}\) of M query images for each of the K classes, as well as the set \(\mathcal {S}\) of support images, from \(\mathcal {D}_{\text {base}}\) for each episode, following the K-way N-shot episodic paradigm. We train the model with the cross-entropy loss:

$$ L = -\sum\limits_{(x^{\text{q}}, y^{\text{q}}) \in \mathcal{Q}} \sum\limits_{k=1}^{K} {y_{k}^{\text{q}}} \log(\bar{s}(V^{\text{q}}, \bar{V}_{k})), $$
(11)

where \(y_{k}^{\text {q}}\) is the k-th element of the one-hot vector \(y^{\text {q}}\) representing the label of image \(x^{\text {q}}\).
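For illustration, the loss in (11) can be computed from the K matching scores of each query image as in the sketch below; normalizing the scores over the K classes to obtain \(\bar {s}\) is an assumption of this sketch rather than a definition from the paper.

```python
import torch
import torch.nn.functional as TF

def episode_loss(scores, labels):
    """Cross-entropy over one episode, a sketch of (11).

    scores: (Q, K) matching scores s for the Q = K*M query images of an episode;
    labels: (Q,) integer class indices of the queries. The scores are normalized
    over the K classes before taking the log (our reading of s-bar).
    """
    s_bar = scores / scores.sum(dim=1, keepdim=True).clamp_min(1e-8)
    return TF.nll_loss(torch.log(s_bar.clamp_min(1e-8)), labels)
```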

4 Experiments

4.1 Datasets

We evaluate our approach on three commonly used datasets: mini-ImageNet [3], tiered-ImageNet [22], and CIFAR-FS [49]. Mini-ImageNet consists of 100 classes sampled from ImageNet with 600 images per class. These images are divided into the base \(\mathcal {D}_{\text {base}}\), novel validation \(\mathcal {D}_{\text {val}}\), and novel test \(\mathcal {D}_{\text {test}}\) sets with 64, 16, and 20 classes, respectively, where both \(\mathcal {D}_{\text {val}}\) and \(\mathcal {D}_{\text {test}}\) correspond to \(\mathcal {D}_{\text {novel}}\) in Section 3.1. The images in mini-ImageNet are of size 84 × 84. As in most recent work, we adopt the splits of [3]. Tiered-ImageNet consists of 608 ImageNet classes divided into 351 base classes, 97 novel validation classes, and 160 novel test classes; it contains 779,165 images of size 84 × 84. CIFAR-FS is a dataset with images sampled from CIFAR-100 [50]. It contains 100 classes with 600 images each. We follow the split given in [49], with 64, 16, and 20 classes for the base, novel validation, and novel test sets.

4.2 Experimental setup

Following most of the literature, we evaluate MTUNet on 10,000 episodes of 5-way classification, each created by first randomly sampling 5 classes from the novel set and then sampling support and query images of these classes with N = 1 or 5 and M = 15 per class. We report the average accuracy over the K × M = 75 queries of each of the 10,000 episodes, together with the 95% confidence interval. We employ three CNN architectures as our backbone \(f_{\theta }\), which are often used for FSL tasks, namely Conv-4 [7], WRN-28-10 [51], and ResNet-18 [52]. For ResNet-18, we remove the first two down-sampling layers and change the kernel of the first 7 × 7 convolutional layer to 3 × 3. We use the hidden vector of the last convolutional layer after ReLU as the feature map F, where the numbers of feature maps are 512 and 640 for ResNet-18 and WRN-28-10, respectively. There are three steps for training MTUNet.

Pre-training of backbone

The pre-training of the backbone CNNs is important for our PE module. We adopt a distance-based strategy similar to SimpleShot [10] and train the backbone CNNs with all images in \(\mathcal {D}_{\text {base}}\). The performance of a simple nearest-neighbour-based method is then evaluated over \(\mathcal {D}_{\text {val}}\) with 2,000 episodes of 5-way FSL tasks, and the best model is adopted. The learning rate starts at \(10^{-3}\) and is divided by 10 every 20 epochs. We train the models for 50 epochs.

Pre-training of PE

For the PE module pre-training, we set d to 64 and the number T of updates to 3. The number z of patterns is empirically set to 1/10 of the number of classes in the base set, which is 7, 36, and 7 for the mini-ImageNet, tiered-ImageNet, and CIFAR-FS datasets, respectively. Images of the corresponding number of classes (a subset of \(\mathcal {C}_{\text {base}}\)) are selected to pre-train the module as a normal classification task, similar to [38]. The importance of this choice is discussed in Section 4.5.1. Both \(g_{\text {Q}}\) and \(g_{\text {K}}\) have three FC layers with ReLU nonlinearities between them. All parameters in the backbone \(f_{\theta }\) are fixed. The learning rate starts at \(10^{-4}\) and is divided by 10 at the 40th epoch; the total number of epochs is 60.

Training the whole network

For training the whole MTUNet, the learnable parameters in the backbone CNN and the PE are optimized with a small learning rate of \(10^{-5}\). We train for 20 epochs in total; in each epoch, we sample 1,000 episodes of 5-way tasks. The other learnable parts of the model are trained with an initial learning rate of \(10^{-4}\), which is divided by 10 at the 10th epoch. We save the model with the best performance on 2,000 evaluation episodes sampled from \(\mathcal {D}_{\text {val}}\).

Our model is implemented in PyTorch, and AdaBelief [53] is adopted as the optimizer. Input images are resized to 80 × 80, and we apply data augmentation including random flips and affine transformations, following [10]. A GPU workstation with two NVIDIA Quadro GV100 (32 GB memory) GPUs is used for all experiments. Training 20 epochs on the mini-ImageNet dataset took approximately 19 minutes on a single NVIDIA V100 GPU. This computational cost is not high, and a consumer-grade GPU can easily reproduce our results.

4.3 Few-shot classification results

MTUNet is compared with several popular FSL methods. We exclude methods in the semi-supervised and transductive paradigms, which use the statistics of the novel set across different FSL episodes. Besides the classification accuracy, we also consider the explainability of the raw image features from the backbone CNNs; thus, we do not adopt any post-processing such as the \(\ell _{2}\) normalization in [10]. For testing, we take our best model on \(\mathcal {D}_{\text {val}}\) and report its accuracy on 10,000 randomly sampled 1-shot and 5-shot tasks from \(\mathcal {D}_{\text {test}}\) in Tables 1, 2, and 3 over the three datasets. During testing, taking a 1-shot task as an example, our model assigns the query image to one of the classes of the support images. This is realized by (i) extracting regions from each of the query and support images and extracting features from these regions with the PE, and (ii) matching the features with PM. MTUNet (w/o PE) denotes the model trained without the PE module; it has a structure similar to ProtoNet [7] and is used to evaluate the impact of the PE.
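The two-step inference described above can be summarized by the sketch below, which glues together the PE and PM sketches from Section 3; the callables `backbone`, `pe`, and `pm` and their interfaces are assumptions of this illustration.

```python
import torch

def classify_query(backbone, pe, pm, support_images, query_image):
    """One 1-shot inference step: backbone features -> PE features -> pairwise matching.

    support_images: (K, 3, H, W) tensor with one support image per class;
    query_image: (3, H, W) tensor. backbone is assumed to map one image to a
    (c, h, w) feature map, and pe and pm to follow the sketches in Section 3.
    """
    v_support = torch.stack([pe(backbone(img))[0] for img in support_images])  # (K, c)
    v_query = pe(backbone(query_image))[0]                                     # (c,)
    scores, k_star = pm(v_query, v_support.unsqueeze(1))                       # (K,) scores, predicted class
    return scores, k_star
```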

Table 1 Average accuracy of 10000 episodes of 5-way tasks on the mini-ImageNet dataset test set
Table 2 Average accuracy of 10000 episodes of 5-way tasks on the tiered-ImageNet dataset test set
Table 3 Average accuracy of 10000 episodes of 5-way tasks on the CIFAR-FS dataset test set

As seen in the tables, the prediction accuracy of MTUNet outperforms most existing FSL methods in both the 1-shot and 5-shot settings, which shows that our model can achieve high prediction accuracy for FSL tasks. We also find that the architecture of the backbone CNN affects the performance. With its simple structure, Conv-4 tends to produce lower performance, and the variants with WRN always perform better than those with Conv-4 and ResNet-18. Aside from the difference in network architecture, the size of the feature maps may be one factor. On the mini-ImageNet dataset, the WRN variants have 20 × 20 feature maps, while the ResNet-18 variants have 10 × 10. Larger feature maps not only provide more information to the PM module but also give a better basis for the patterns, as higher resolutions may help find more specific patterns. The results also demonstrate the learning ability of the PE: for all experimental settings, the PE improves accuracy by approximately 2%-4% over the variant without it, filtering out useless features and focusing on informative regions as it is designed to do. We further analyse the importance of the pattern number z and the selection of PE pre-training classes in Section 4.5.1.

4.4 Explainability

In this section we will qualitatively and quantitatively evaluate the explainability of MTUNet.

4.4.1 Qualitative evaluation

In addition to its classification performance, MTUNet is designed to be explainable in two different aspects. First, pattern-based visual explanation: MTUNet’s decision is based on certain combinations of learned patterns. These patterns are localized in both query and support images through \(A^{(T)}\), which can easily be visualized. This visualization offers intuition on the learned patterns and on how much these patterns are shared between the query and support images. Second, visualization of pairwise matching scores: thanks to the one-to-one matching strategy formulated as a binary classification problem in (9), the distributions (or appearances) of learned patterns in query and support images give a strong clue about MTUNet’s matching score s. Combining the two, we may find potential failure reasons by observing the matching matrix.

Pattern-based visual explanation

MTUNet’s decision is based on learned patterns, i.e., it is solely based on how much shared patterns (or features) appear in both query and support images. This design in turn means that, by pinpointing each pattern in the images, we can obtain an intuition behind the decision made by the model. This can be done by merely visualizing A(T).

Figures 4(a) and (b) show a pair of support and query images in the mini-ImageNet dataset for a 5-way task. The pairs (a) and (b) are of classes lock and horizontal bar, respectively. The second column shows the visualization of the aggregated overall attention, given by \(A^{\prime }\). The third to ninth columns are the visualization of the regions corresponding to the learned patterns in A(T) (i.e., the i-th row vector of A(T) represents the appearance of the i-th learned pattern at the respective spatial location).

Fig. 4

Visualization of each pattern and the average features for a sampled task in the mini-ImageNet dataset. (a) is the lock class and (b) is the horizontal bar class. Overall is the overall attention among all patterns. The third to ninth columns are the visualization of the regions corresponding to the learned patterns

For (a) with class lock, the support image is a small gold combination lock used for storage cabinets or post boxes. Among all 7 patterns, only pattern 5 shows a strong response, whereas the others are not observed. We can see that pattern 5 pays attention to the discs of the lock in the support image. It also responds strongly to the words on the left, which show similar morphological characteristics. The query image in (a) is a black combination lock often used for bicycles. The attention maps show almost the same distributions as for the support, that is, only pattern 5 responds on the discs. From these visualizations, we can infer that pattern 5 represents the character of the discs. MTUNet successfully finds a shared pattern although these two locks have different appearances.

For (b), the support image is a gymnast wearing red. Multiple patterns are observed in the image. We can see that the visualization of pattern 1 identifies part of the human body (the head), and pattern 3 appears around the hands grabbing the horizontal bar. The query image is a gymnast in blue. Patterns 1 and 3 respond in a similar way to the support image. Patterns 4 and 5 appear in the background and around other parts of the body; however, their responses are relatively weak compared to patterns 1 and 3. Patterns 1 and 3 may be responsible for human heads and hands grabbing the horizontal bar, leading to the successful classification of the unseen classes.

Visualization of pairwise matching scores

Figure 5 shows the visualized overall attentions \(A^{\prime }\) and the corresponding original support and query images (a 5-way 1-shot task on the mini-ImageNet dataset). Through the pairwise matching module, the FSL task is cast into a binary classification problem. The output for each pair is a value between 0 and 1 due to the sigmoid function, and the scores are shown as percentages in the figure. The support images are marked with different colors to represent the classes, and the thickness of the coloured lines indicates the matching score between each support image and the query. Only pairs with a score over 0 are shown in the figure.

Fig. 5

Matching point of one sampled task in the mini-ImageNet dataset. We only show the connection between pairs with a score over 0, and the scores are shown as percentages

Among all pairwise combinations, the combination of the support and query images of the catamaran obtains a full score (100%). The visualization of the overall attention covers the hulls, and especially the masts, in both images, which are the main characteristics of this class. The class goose gets a low matching score. The query is a close-up of a goose on the ground from its front side, in which the attention captures the goose’s black head and beak. The support image is an overall view of a goose about to fly, and the visualization of the overall attention captures the leg. With this combination, finding a shared pattern may not be easy, although these two extracted patterns are both representative parts of a bird. This problem stems from differences in viewing angles, which can be relieved in 5-shot tasks, where more support images from different viewing angles are available. Surprisingly, the query image of goose obtains 81% for the support image of beetle. This may suggest that one of the patterns responds to black regions and that this pattern is solely used as the clue for goose. This is a negative result for the FSL task but clearly demonstrates MTUNet’s explainability regarding the relationship between visual patterns and the matching scores.

We also provide more visualization samples in Appendix A.

4.4.2 Quantitative evaluation

Our method is designed to interpret FSL tasks, and we think it is necessary to compare the explainability of MTUNet with that of previous XAI methods using existing metrics. We adopt MTUNet without the PE (with ResNet-18) as the baseline model and use existing XAI methods to generate its explanations (we consider our PE module to be the explainable component; after removing the PE, our model has a structure similar to ProtoNet). We conduct 10,000 episodes of 5-way 1-shot tasks, obtain the visual explanations for each task using several types of XAI methods, and compare these explanations to the overall attention map \(A^{\prime }\) generated by our method (MTUNet ResNet-18).

We adopt three evaluation metrics for comparison. (i) Precision: we denote an input image by x and the foreground bounding box by \(\bar {x}\) (provided by ImageNet [61]). We then compute the proportion of the explanation that falls within the bounding box as \(\text {Precision} = {\sum }_{p \in \bar {x}} A^{\prime }(p)/{\sum }_{p \in x} A^{\prime }(p)\), where \(A^{\prime }(p)\) is the attention value of \(A^{\prime }\) at pixel p and \(A^{\prime }\) is resized to the same size as the input image. (ii) Insertion area under the curve (IAUC) [62]: this metric measures the accuracy gain when image pixels are gradually added in the order of importance given by the explanation. (iii) Deletion area under the curve (DAUC) [62]: this metric measures the accuracy drop when important pixels are gradually removed from the input image. As shown in Table 4, the explanation of MTUNet outperforms existing XAI methods on all three evaluation metrics, which demonstrates the strong explainability of the proposed method. We believe our intrinsic method has an advantage for the interpretation of FSL tasks: due to the episodic sampling strategy of FSL, both back-propagation and perturbation methods may struggle in such complex scenarios, whereas our method provides an explanation within a single inference step.
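As an illustration, the precision metric above could be computed as in the following sketch; the pre-resized attention map and the binary bounding-box mask are assumed inputs.

```python
import numpy as np

def attention_precision(attention, bbox_mask):
    """Fraction of the overall attention A' that falls inside the foreground bounding box.

    attention: (H, W) non-negative array, A' resized to the input image size;
    bbox_mask: (H, W) binary array, 1 inside the ImageNet bounding box and 0 elsewhere.
    """
    attention = np.asarray(attention, dtype=float)
    bbox_mask = np.asarray(bbox_mask, dtype=float)
    total = attention.sum()
    return float((attention * bbox_mask).sum() / total) if total > 0 else 0.0
```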

Table 4 Evaluation of MTUNet and existing XAI methods using explainability metrics

4.5 Discussion

4.5.1 Pattern setting

The pattern number z and the classes selected for PE pre-training are important elements for training the whole MTUNet. In this section, we analyse these two aspects.

The number z of patterns

The number of patterns can be another crucial factor for MTUNet. Intuitively, a larger z makes the model more discriminative. To show the impact of z, we uniformly sample classes in \(\mathcal {C}_{\text {base}}\) (i.e., sampling every I-th class from the class list, where I = 10, 8, 7, 5, 4, 3, 2, and 1); thus, I = 1 uses all classes in \(\mathcal {C}_{\text {base}}\).

The test accuracies are shown in Fig. 6 for 5-way 1-shot and 5-way 5-shot tasks on 10,000 sampled episodes over \(\mathcal {D}_{\text {test}}\) of the three datasets. The horizontal axis represents the number of patterns and the vertical axis the average accuracy. The performance shows no obvious change on the CIFAR-FS dataset as z changes, whereas it decreases slightly on the mini-ImageNet dataset (approximately 1% for 1-shot and 2% for 5-shot). For the tiered-ImageNet dataset, when the pattern number is set to 51, an obvious performance drop is observed for the WRN backbone (approximately 3.5% for 1-shot), while this does not happen in the 5-shot setting. In general, tuning z may help gain performance, but its impact is not significant, and it requires tuning z for each backbone and dataset. Since a small value of z can provide both high classification accuracy and convincing visualization of each pattern (e.g., Fig. 4), we recommend setting z to a small value according to the number of classes in the dataset. It might nevertheless be an interesting research direction to estimate z, e.g., based on the number of classes in a given FSL task.

Fig. 6

Results of pattern number settings for the mini-ImageNet, tiered-ImageNet, and CIFAR-FS dataset. The horizontal axis represents the number of patterns and the vertical axis represents the average accuracy. We report the results with 10,000 sampled 5-way episodes in the novel test set

Selection of classes for PE pre-training

Our PE module is supposed to learn common visual patterns. We use images of a certain subset of classes in \(\mathcal {C}_{\text {base}}\) to learn the initialization of such patterns in our experiments. The selection of this subset thus affects the performance of downstream FSL tasks. To clarify the impact of this choice, we randomly sample 7 classes 50 times from \(\mathcal {C}_{\text {base}}\) of the mini-ImageNet dataset and 36 classes 20 times from the tiered-ImageNet dataset, and use the corresponding images to train the PE on top of ResNet-18. The trained PE is then used for training MTUNet, which is evaluated over 2,000 episodes of FSL tasks on both the validation and test sets.

Figure 7 (left) shows a scatter plot of the validation accuracies against the corresponding test accuracies. The mean and the 95% confidence interval over the 50 test accuracies for the mini-ImageNet dataset are 56.83% and 0.18%, respectively. This implies that our model benefits from a better choice of classes for PE pre-training. For this choice, we only have access to the validation set; since the validation and test sets have disjoint classes, the best choice for the validation set is not necessarily the best choice for the test set. However, the plot empirically shows that the validation and test accuracies are highly correlated, with a Pearson’s correlation coefficient of 0.71. We also ran the experiments on the tiered-ImageNet dataset with 20 random samplings of 36 classes, which showed similar results. These results lead to the conclusion that MTUNet is sensitive to PE pre-training; however, we can use the validation set to find a good choice of classes.

Fig. 7

Performance of random classes sampling for PE pre-training of patterns. All experiments are implemented on the mini-ImageNet and tiered-ImageNet dataset using ResNet-18 as the backbone

4.5.2 Selection of metric learning methods

In our experiments, we find that a learnable metric implemented by an MLP achieves better FSL classification performance than commonly used predefined metrics, such as the Euclidean distance and the cosine similarity. As shown in Table 5, the MLP performs best for all backbone settings on the mini-ImageNet dataset. The accuracy difference is small for Conv-4 but noticeable for ResNet-18 and WRN. We can infer that the MLP better handles features extracted from larger backbones.

Table 5 Performance of different metric learning methods

4.5.3 Limitations and future work

Our experiments have shown that training MTUNet from scratch (i.e., without pre-training) is infeasible and that it even requires two pre-training steps. The number z of patterns to be learned is a hyperparameter to be tuned for the given dataset, and PE pre-training is sensitive to the selection of classes. To address these drawbacks, we will study the relationship between the numbers of classes and patterns, which also requires investigating the impact of different datasets in the training process. We will keep working on improving the training strategy to make it more agnostic to class selection in pre-training.

The core of MTUNet’s explainability lies in observing the combination of pattern-based visualization and matching scores. However, evaluating this aspect is not straightforward because the patterns are learned in the course of training without supervision, so there is no ground truth. For this reason, we relied on quantitative evaluation to demonstrate the usability of MTUNet’s explainability. This problem may be mitigated by using or building a dataset with fine-grained annotations of possible patterns.

An interesting future direction for MTUNet is to extend it to types of real-world data other than images, such as videos and 3D medical images. Research in this direction has already been explored for some tasks, such as shot boundary and key frame detection [65, 66] and lesion localization [67, 68]. MTUNet’s extracted patterns may offer better explanations for tasks in these domains.

5 Conclusion

In this paper, we proposed MTUNet, designed for explainable FSL classification tasks. Our model achieves higher classification performance than existing FSL methods on three benchmark datasets. The PE module serves to keep only the informative regions of the image features extracted by the CNN backbone; it learns better representations and proves to be a necessary component for improving prediction accuracy.

Our experimental results also quantitatively and qualitatively demonstrate MTUNet’s strong explainability through patterns in images. Compared to the heatmap-only explanations provided by existing methods, our explanation combines pattern-based visual explanations with pairwise matching scores, which offers a better basis for analysing model decisions. With this combination, we can further manually analyse the reasons for failure cases, which is important in some high-risk areas (e.g., medical tasks). In addition, the approach taken by our model may be analogous to that of humans, who usually try to find shared patterns when matching images of an object they have never seen before. This is advantageous because the explanation given by MTUNet provides an intuitive, intrinsic interpretation of what the model does.