Introduction

Machine learning (ML) in materials informatics (MI) has received significant attention in the academic research world and is gaining widespread adoption [1,2,3,4,5]. More specifically, it has recently been extensively studied for its use in the research and design of novel inorganic materials [6,7,8,9,10]. This is enabled by three major developments: (1) the increasing number of material property datasets as well as the improvement in dataset quality and variety, (2) the rapid pace and development of new ML models tailored to addressing different challenges in materials science (e.g., regression, classification), supplemented by (3) the increase in available computing power and accessibility to ML and deep learning tools. The combination of these developments led to improved capabilities in the exploration and modeling of material properties in the academic world.

Classical ML methods (e.g., linear regression, random forest, support vector machines) have successfully been used for the regression and classification of many material properties [11,12,13,14,15,16,17]. These methods usually rely on the featurization of the input chemical formulae into numerical features that are usable by the models. Typically, this is achieved through the use of a composition-based feature vector (CBFV), which uses descriptive statistics of the properties of constituent atoms in each compound to uniquely represent it [18]. Some common CBFV feature sets are Oliynyk, Magpie, Jarvis and mat2vec [11, 12, 19, 20]. Here, a distinction is made between physically derived CBFVs (with features based on measurable element properties) like Oliynyk, Magpie and Jarvis, and computationally derived CBFVs (with features obtained from computational or deep learning models) like mat2vec. For some properties, additional features such as structural information and processing or measurement conditions are included to further improve model performance [2, 16, 21, 22].

In more recent years, deep learning (DL) models have gained widespread popularity in MI due to numerous advantages compared to classical ML methods. Some examples are ElemNet, CGCNN, MEGNet, DimeNet++, and ALIGNN [23,24,25,26,27]. More recently, graph neural network (GNN) models incorporating attention-based mechanisms such as CrabNet, Roost and H-CLMP have gained increasing popularity [28,29,30]. GNNs have shown improved performance compared to other DL models, particularly in the absence of structural information as model inputs. Another advantage of GNNs is that the inductive biases built into the model and the input data structure are more suited to the learning of material properties, since the interactions between the atoms in the compound can be modeled as weighted interactions between nodes in a graph. In CrabNet, for example, the atom representations are either based on a CBFV feature (mat2vec element vectors) or a non-CBFV feature (onehot element vectors) [28]. For the sake of clarity, the remaining text will use the acronym DL to refer to both deep learning (DL) and graph neural network (GNN) models and methods.

Unfortunately, while DL methods show superb performance in modeling material properties, the element features used by these models typically do not represent any measurable physical property of the elements themselves. Instead, the element representations are learned from the data during the model training process. Therefore, they do not directly provide useful information or insights that can be interpreted by humans. This is different from the CBFV representation typically used in classical ML, where the features represent properties of the elements which are known a priori, such as the atomic mass, first ionization energy, or number of valence electrons.

Despite the high performance of the DL models, there is a disparity between their extensive study in academic research and their limited adoption in the industry for the exploration of materials. We term this disparity the “interpretability-adoption gap”. One significant hurdle to the widespread adoption of the often “black-box” models is the lack of built-in methods for model interpretation. While there are established methods of evaluating model performance in academia [14, 31,32,33], those who are less familiar with DL typically require more intuition into how the models function before they can fully trust the results. Particularly in industry, where there is usually a lower risk tolerance compared to academia, findings based on black-box models and vague model evaluation criteria are not enough to justify making high-stakes decisions such as investing in new research [5, 34,35,36,37,38]. Tangible methods of investigating and understanding model decision-making processes are therefore required to facilitate their adoption in an industrial setting [39].

This led to the development of explainable AI (XAI), which aims to introduce methods for deciphering the internal workings of black-box models, thereby enabling users to understand the modeling processes and results [39, 40]. Examples of XAI in research fields outside of MI include visualizing word embeddings in natural language processing [41,42,43], inspecting decision-making processes in reinforcement learning [44,45,46], and visualizing pixel importances [47, 48] or segmentation [49, 50] in computer vision. To date, however, XAI techniques have (with the exception of a few works employing classical ML) largely been underexplored for DL in the MI field [10, 51, 52].

Two common post-hoc model-agnostic methods for obtaining explainable models in classical ML are SHAP and LIME [39, 53,54,55]. Both of these methods are built on top of existing black-box models and use local feature perturbation to estimate the contributions from input features towards the predictions. Other models such as random forest, gradient boosting, and lasso regression inherently provide model interpretability via the use of internal feature importance metrics and (in some models) through bootstrap sampling and feature sampling [39, 51, 56]. Nonetheless, these techniques require that the individual features of the input data are meaningful and represent a measurable feature or physical property. This requirement is met in classical ML when a physically derived CBFV is used to featurize compounds; however, it is not met for DL methods, where the features typically do not reflect measurable values. Thus, these traditional interpretability methods fall short when applied to DL models.

Therefore, it is the goal of this work to explore how to increase model interpretability in DL models specifically for applications in MI. Here, we demonstrate how parts of the typically black-box modeling process can be communicated visually and in an interpretable way, using our attention-based model, CrabNet [28]. We have extended CrabNet’s architecture to enable intrinsic interpretability using several methods to be discussed below. In this regard, we lay the first bricks in the bridge spanning the interpretability-adoption gap between academia and industry. This will not only aid researchers in further developing complex models with interpretability in focus, but also promote the adoption of these modeling methods in the materials science industry.

Results and Discussion

The results of this study are described in five subsections. We first compare the element embeddings learned by CrabNet against other CBFV feature sets from the literature, and show how chemical behavior and patterns in element properties can be learned entirely from the training data for each material property. We also show that the learned element representations are comparable to physically derived CBFVs. Second, as part of this analysis, we characterize the element prevalence imbalance in the datasets using the Shannon equitability index and relate that to the quality of the learned element embeddings. Third, we further examine how the element representations are successively updated using information about their chemical environment in the compounds, and how they may be used to gain additional insights about element behaviors in different environments. Fourth, we inspect how entire chemical compounds can also be adequately captured using the EDMs and subsequently visualized. We identify interesting trends in the compound representations relating the bond character and number of elements in the compounds to the material property and prediction error, and discuss how such visualizations can lead to additional understanding about the modeling process and the underlying materials chemistry. Lastly, we explore how the self-attention mechanism in CrabNet can be visualized in the form of videos and used to further examine the modeling process, leading to potential new insights about the chemical interactions within a compound. While we use the OQMD_Bandgap dataset to demonstrate the analyses, we note that similar analyses can also be carried out with any of the 28 materials datasets presented in this work.

Learning Meaningful and Per-Property Element Representations

Element representations were obtained as featurized CBFVs, which are fixed-length vectors where each element is uniquely described by the same set of features [12, 18]. For the Oliynyk, Magpie and mat2vec element property feature sets, we use the published vectors to represent the elements [18, 20]. For the CrabNet element representations, we extract the element vectors from the element-derived matrices (EDMs) at the output of the embedding layer (please refer to the CrabNet publication for architecture details [28]). We then examine the similarity between two element vectors x and y by computing the Pearson correlation coefficient r using Equation 1:

$$\begin{aligned} r = \frac{\sum _{i=1}^{n} (x_i - {\bar{x}})(y_i - {\bar{y}})}{\sqrt{\sum _{i=1}^{n} (x_i - {\bar{x}})^2} \sqrt{\sum _{i=1}^{n} (y_i - {\bar{y}})^2}} \end{aligned}$$
(1)

where n is the number of features, \(x_{i}\) and \(y_{i}\) are the values of the ith feature, and \({\bar{x}}\) and \({\bar{y}}\) are the mean values of x and y, respectively.

The correlation r ranges from -1 to 1; the higher or lower the value of r is, the more correlated or anticorrelated are the features that describe the elements, respectively. A value of zero means that there are no correlations between the features of the elements. We compute the pairwise correlation coefficients between the element vectors for all elements and for all element property representations, and show these as heatmaps in Fig. 1. Note that the plots are cropped to the range of elements of the Oliynyk heatmap to aid comparison; please refer to supplementary Fig. S-1 in the supplementary information (SI) for the full heatmaps. In addition, interactive versions of the plots are provided in the SI.
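As an illustration of Equation 1, the pairwise correlation computation can be sketched in a few lines of Python. This is a minimal sketch using only the standard library; the function names `pearson_r` and `correlation_matrix` and the toy vectors are hypothetical and are not taken from the CrabNet code base or from any published CBFV.

```python
import math


def pearson_r(x, y):
    """Pearson correlation coefficient between two equal-length vectors."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("x and y must have the same length (at least 2)")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    dx = [xi - mean_x for xi in x]
    dy = [yi - mean_y for yi in y]
    # Numerator: sum of products of the centered values.
    cov = sum(a * b for a, b in zip(dx, dy))
    # Denominator: product of the two root-sum-of-squares terms.
    norm = math.sqrt(sum(a * a for a in dx)) * math.sqrt(sum(b * b for b in dy))
    if norm == 0.0:
        raise ValueError("correlation is undefined for a constant vector")
    return cov / norm


def correlation_matrix(vectors):
    """Pairwise correlations for a dict mapping a label to a feature vector."""
    labels = list(vectors)
    return {a: {b: pearson_r(vectors[a], vectors[b]) for b in labels} for a in labels}


# Toy, made-up 4-feature vectors (illustrative only, not real element features).
toy_vectors = {
    "A": [1.0, 2.0, 3.0, 4.0],
    "B": [2.0, 4.0, 6.0, 8.0],  # scaled copy of A, perfectly correlated
    "C": [4.0, 3.0, 2.0, 1.0],  # reversed A, perfectly anticorrelated
}
corr = correlation_matrix(toy_vectors)
```

In this toy case, `corr["A"]["B"]` is close to 1 and `corr["A"]["C"]` is close to -1, which corresponds to the two extremes of the color scale in the heatmaps.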

Fig. 1

Heatmaps of Pearson correlation matrices between element vectors featurized using (a) Oliynyk, (b) Magpie, and (c) mat2vec element property feature sets. The x- and y-axes are labeled with the atomic numbers. Each cell at coordinate (x, y) represents the correlation between the corresponding elements with atomic numbers x and y. Blue represents a high correlation and red represents a high anticorrelation. In the interest of comparison, the heatmaps are truncated to the dimensions of the Oliynyk heatmap. Empty rows indicate that no element vector is available

Here, we can observe that element vectors based on the Oliynyk and Magpie CBFVs contain large regions of similar color in the heatmap. The regions of similar color indicate that the element representations are either highly correlated or highly anticorrelated with each other. Furthermore, these regions are very similar between the two CBFVs. This is expected, since the CBFV features are based on physical properties of the elements. Thus, elements with similar physical properties will be more correlated while dissimilar elements will be more anticorrelated. Accordingly, the large colored regions typically correspond to similarities and dissimilarities between elements from families in the periodic table, such as alkali metals, alkaline earth metals, transition metals, metalloids and reactive nonmetals.

On the other hand, the element vectors from a DL model such as mat2vec do not exhibit such prominent behavior. Overall, the elements show less correlation with each other, and—with the exception of a few areas (to be discussed in later sections)—do not show large continuous regions of similar color. This is due to the fact that the starting element representations in DL models are randomly initialized and are not based on physical properties of the elements. These vector representations of the elements are only updated by the model throughout the training process using the training data. Thus, the correlation patterns that can be observed in this figure represent distinct patterns that the DL model has learned solely from the provided data.

We also note that a different number of element vectors are recorded in the feature sets. For the Oliynyk and Magpie CBFVs, only the elements up to uranium and berkelium are reported, respectively, while vectors up to the element oganesson are provided by mat2vec (please refer to supplementary Fig. S-1 in the SI for the uncropped heatmaps). Particularly for the Oliynyk CBFV, some element vectors are missing, as visible by the empty rows in the heatmap. This disparity in the availability of element vectors between different CBFVs can be caused by reasons such as the instability or rarity of elements, lack of adequate information about the elements, or the inability to measure properties about the elements. The lack of element vectors in some material property feature sets can limit their applicability for certain tasks (such as when studying rare elements) and will be discussed in more detail in later sections.

In addition to learning element representations for a general purpose in materials science, such as in the case of mat2vec, DL methods can also learn to relate element characteristics on a material property-specific basis. For example, element embeddings were extracted from the CrabNet and HotCrab models which were reproduced using the supplied model weights and the source code [57, 58]. The CrabNet and HotCrab models use mat2vec and onehot-encoded element features as the starting element representations, respectively. These features are then fine-tuned by the models for each of the 28 reported datasets. We extract one set of element embeddings from each layer of the models. Then, the Pearson correlations between the element vectors are calculated and shown in Fig. 2.

In this work, we use the OQMD_Bandgap dataset to demonstrate our findings. Additional example plots for other properties can be found in the SI. The OQMD datasets are widely used by researchers to evaluate model performance. For detailed information about the OQMD_Bandgap dataset as well as information and discussion about the calculated values, please see the literature [59,60,61].

Fig. 2

Heatmaps of Pearson correlation matrices between element vectors extracted from CrabNet and HotCrab. These element representations are learned entirely from data. The x- and y-axes are labeled with the atomic numbers. Each cell at coordinate (x, y) represents the correlation between the corresponding elements with atomic numbers x and y. The top row (a and b) shows the correlations between embeddings from CrabNet and the bottom row (c and d) from HotCrab. The left and right columns represent the embeddings extracted from the first and last layer of the models, respectively. Blue represents a high correlation and red represents a high anticorrelation. In d, some regions of interest are annotated

Here, we can observe that both CrabNet and HotCrab are able to learn embeddings for each element of the periodic table, and that the correlations between the elements have a similar pattern, irrespective of the starting element representation (mat2vec or onehot). The observed correlation patterns are also similar to the mat2vec patterns as seen in Fig. 1c. The ability of both CrabNet and HotCrab models to learn similar element embeddings despite having drastically different starting representations is encouraging, and further suggests that domain knowledge is not necessarily required for element featurization if a sufficient quantity and quality of training data is available [18]. This finding is corroborated by the similarly good performance of both models across a wide range of material properties [28]. Interestingly, for deeper layers of the models (Fig. 2b and d), more intense correlation patterns between the elements emerge. This is likely attributable to the self-attention-based learning mechanism of the underlying CrabNet models. At each successive layer within the model, information about additional element-element interactions within the compound (i.e., the chemical environment) is taken into account when updating the identity of an element within that compound. As a result, the deeper the layer within the model, the more complex the element interactions, and hence the element representations, become.

It is also interesting to note the diagonal and horizontal patterns which can be observed in all of the correlation matrices. For example, in Fig. 2d there is a 45-degree diagonal, blue line that can be seen in the correlation matrix starting at the coordinates (13, 31) (corresponding to the element pair (Al, Ga)) and continuing until (40, 58) (corresponding to (Zr, Ce)). This line highlights the well-known periodic law which states that elements with similar chemical properties fall into recurring periodic groups. Please refer to supplementary Fig. S-2 for the enlarged version of the annotated heatmap and for correlation plots for other material properties. Another observation is the triangular region of high correlation between (57, 57) and (71, 71), which indicates that the first-row elements of the f-block are highly similar to each other. A similar triangular region can be observed between (23, 23) and (29, 29), indicating similarities between some first-row elements of the d-block. Lastly, the vertical blue line starting at the coordinates (39, 57) and continuing to (39, 71) indicates the chemical similarities between yttrium and the first-row elements of the f-block. These and other patterns can also be observed in the Oliynyk and Magpie CBFVs in Fig. 1. The ability of the CrabNet and HotCrab models to learn, based solely on the chemical formulae, chemical relationships which are comparable to those in hand-curated CBFVs is exciting, and further reaffirms the finding that hand-engineering of features is not needed when training on big data [18].

Moreover, in Fig. 2c we observe a distinct “border” at the element plutonium (with atomic number 94), where the correlation coefficients between the elements suddenly decrease and the patterns become less pronounced. Additional analysis of the OQMD_Bandgap dataset showed that it does not contain any compounds with elements past plutonium. Because the element representations are learned purely by the model from the dataset, their quality depends heavily on the quality of the dataset. Since the model performance depends on the quality of the element representations, it also, by extension, depends on the dataset quality [32].

We define element prevalence as the number of times a certain element has appeared as part of the compounds in a given dataset. When examining the OQMD_Bandgap dataset, we note that there is an imbalance in element prevalence: oxygen and copper appear roughly 1.5 to 2 times as often as the majority of the other elements in the dataset, whereas fluorine, chlorine, bromine and iodine appear less than 0.1 times as often. This imbalance in element prevalence is even stronger for other datasets such as the aflow__Egap, castelli, CritExam, mp_e_form and phonons datasets (see supplementary Fig. S-3 in the SI for some example element prevalence plots).

Quantifying Dataset Imbalance

The degree to which a dataset is imbalanced (otherwise referred to as its “evenness”) can be measured using the Shannon equitability index, which is a function of the Shannon entropy of the dataset [62,63,64]. Shannon entropy is widely used in information theory and can be used to characterize the degree of imbalance in a dataset [65, 66]. The Shannon entropy H is defined in Equation 2 as:

$$\begin{aligned} H(X) = -\sum _{i=1}^k {\mathrm {P}(x_i) \log \mathrm {P}(x_i)} \end{aligned}$$
(2)

where X is the set of discrete variables \(x_i \in \{x_1,\ \ldots , \ x_k\}\), i indexes the classes, \(\mathrm {P}(x_i)\) is the proportional abundance of \(x_i\) and k is the total number of classes in the dataset.

For a dataset \({\mathcal {D}}\) of n data occurrences and k distinct chemical elements (classes), each with counts \(c_i\), \(\mathrm {P}(x_i) = \frac{c_i}{n}\) and the Shannon entropy can thus also be written as Equation 3:

$$\begin{aligned} H({\mathcal {D}}) = -\sum _{i=1}^k {\frac{c_i}{n} \log \left( \frac{c_i}{n}\right) } \end{aligned}$$
(3)

For continuity, we note that when \(c_i = 0\), it means that no data sample is related to class i in the dataset, and therefore the multiplicand within the summation is defined to be 0. Mathematically, \(\lim _{p \rightarrow 0^+} p\log (p) = 0\). The maximum value of \(H({\mathcal {D}})\) is \(\log (k)\). This value occurs when all element classes in the dataset are observed at the same frequency (i.e., the dataset is completely balanced). Therefore, the Shannon entropy \(H({\mathcal {D}})\) is scaled by \(\log (k)\) to finally obtain the Shannon equitability index \(E({\mathcal {D}})\), which is defined in Equation 4 as:

$$\begin{aligned} E({\mathcal {D}}) = \frac{H({\mathcal {D}})}{\log (k)} \end{aligned}$$
(4)

\(E({\mathcal {D}})\) ranges between 0 for a maximally imbalanced dataset and 1 for a maximally balanced dataset. The Shannon equitability indices are calculated for the 28 datasets examined in this work and are presented in Table 1. A plot showing the same information can be found in the SI (supplementary Figure S-4). For more information about the datasets, please refer to the CrabNet publication [28].
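Equations 3 and 4 translate directly into code. The following is a minimal Python sketch under stated assumptions: the function names are hypothetical, the natural logarithm is used (the base cancels in the ratio of Equation 4), and k is taken to be the length of the count list, so that classes with zero counts still contribute to k while their summation term is set to 0 as described above.

```python
import math


def shannon_entropy(counts):
    """Shannon entropy H of a list of per-class counts (Equation 3)."""
    n = sum(counts)
    if n <= 0:
        raise ValueError("at least one occurrence is required")
    # Classes with c_i = 0 contribute 0 by the limit p*log(p) -> 0.
    return -sum((c / n) * math.log(c / n) for c in counts if c > 0)


def shannon_equitability(counts):
    """Shannon equitability index E = H / log(k) (Equation 4)."""
    k = len(counts)
    if k < 2:
        raise ValueError("at least two classes are required")
    return shannon_entropy(counts) / math.log(k)


# Toy element counts (illustrative only).
balanced = shannon_equitability([10, 10, 10, 10])
skewed = shannon_equitability([97, 1, 1, 1])
degenerate = shannon_equitability([100, 0, 0, 0])
```

Here `balanced` evaluates to 1 (up to floating-point error), `degenerate` to 0, and `skewed` to a value in between, matching the stated range of the index.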

Table 1 Shannon equitability indices calculated from the training data splits of the 28 reported datasets. Datasets were taken from [28]

As can be seen in the table, the datasets studied in this work are not equally balanced in terms of element diversity. The more imbalanced a dataset is in terms of the element prevalence in the chemical compounds, the less likely the models will be able to adequately learn about the elements and their environments. The element embeddings learned for the infrequent elements will therefore be weaker and will not be able to capture as much information about these elements as compared to more frequently occurring elements. This leads to the observed weak correlation patterns between the less frequently seen elements beyond a certain cutoff atomic number in the datasets, as discussed earlier for Fig. 2.

If the weakly learned elements are then encountered during inference time, the model will not be able to make an adequate prediction using the elements’ representations. Additionally, if certain elements or element combinations appear more frequently (majority classes) in the datasets as compared to other elements or combinations (minority classes), the model may be biased to better capture the behavior of majority classes at the expense of sacrificing performance on the minority classes. Such a dataset bias may appear in computational or experimental datasets due to the fact that some elements are more commonly studied for certain material applications. On the other hand, certain elements (e.g., rare or unstable elements) naturally occur less frequently and therefore are also contained in fewer compounds and datasets. Certain elements such as noble gases also rarely form compounds with other elements and are therefore rarely reported in materials datasets.

It is therefore important to implement data processing and modeling techniques to address biases as a result of dataset imbalance. Some example techniques include dataset re-sampling, generating synthetic data for imbalanced classes, implementing weighted loss functions that penalize errors for minority classes more, or using alternative loss functions and metrics to evaluate model performance [64, 67, 68]. Additionally, the model architecture can also be tailored to address dataset bias, and certain types of models (such as those based on self-attention or guided attention architectures) have an increased robustness against dataset bias [69, 70].
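As one concrete example of the weighting idea mentioned above, inverse-frequency class weights can be computed as follows. This is only a sketch of a common heuristic and is not the procedure used in this work: the function name is hypothetical, and the formula \(w_i = n / (k \, c_i)\) is an assumption chosen so that a perfectly balanced dataset yields a weight of 1 for every class (the "balanced" class-weight option in scikit-learn has the same form).

```python
def inverse_frequency_weights(counts):
    """Map each class to n / (k * c_i), so that rarer classes get larger weights.

    counts: dict mapping a class label to its (positive) number of occurrences.
    """
    if not counts or any(c <= 0 for c in counts.values()):
        raise ValueError("all classes must have a positive count")
    n = sum(counts.values())  # total number of occurrences
    k = len(counts)           # number of classes
    return {label: n / (k * c) for label, c in counts.items()}


# Toy element counts (illustrative only).
toy_counts = {"O": 60, "Si": 30, "Xe": 10}
weights = inverse_frequency_weights(toy_counts)
```

Such weights could then be used to scale per-sample errors in a loss function; how to combine per-element weights into a per-compound weight is a separate design choice that is not addressed in this sketch.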

Lastly, it is worth noting that while most DL models learn element representations from structured materials datasets, methods such as word2vec and mat2vec use text mining and other natural language processing (NLP) techniques to learn the element embeddings from academic publications [20, 71, 72]. The data present in publications covers a much longer time period and contains a higher diversity in terms of types of compounds, material properties and applications studied. These data are in unstructured form and therefore cannot be used as training data for DL methods such as CrabNet; however, they can easily be used for word2vec and mat2vec. Therefore, these text mining methods are able to learn from a much larger corpus of materials data and are not restricted by the availability of structured datasets. Accordingly, DL models such as CrabNet can benefit by using the pre-trained mat2vec element embeddings and fine-tuning them to new tasks, thereby minimizing the impact of missing elements in the training dataset.

Capturing the Influence of Chemical Environments on Element Representations

In addition to learning the representations of each element, CrabNet and HotCrab can also capture the behavior of the elements when they are present in different chemical environments. Figure 3 shows the two-dimensional projections of the element vectors corresponding to the silicon atom from 2374 different silicon-containing compounds within the OQMD_Bandgap test dataset. The silicon vectors are extracted from the transformed EDM tensors from HotCrab (a onehot-featurized version of CrabNet) and show the transformation of the silicon representations after they are passed through the three successive self-attention layers. For visualization, the vectors are projected down to two dimensions using the uniform manifold approximation and projection (UMAP) method [73]. The resulting points are plotted and colored by three parameters: (1) the fractional abundance of the element silicon in the compound, (2) the predicted property value of the compound (in this case, band gap), and (3) the oxidation state of silicon as predicted by Pymatgen [74]. For more information, please see the Methods.

Fig. 3

Vector representations of the silicon element in 2374 different chemical environments and at different layers of the HotCrab model. Each point shows the model-internal representation of the silicon atom, after the information regarding the other atoms in the chemical environment has been introduced via HotCrab through the three attention layers (top row to bottom row). The points are colored by: (left column) the fractional abundance of silicon, (center column) the predicted value of the compound, and (right column) the predicted oxidation state of silicon, where gray points indicate that the oxidation state could not be predicted. Four clusters are outlined in the bottom-left plot

As can be seen in the plots from the first layer (first row), there is a large number of distinct point clusters, with one major cluster near the center, two medium clusters above and below the center cluster, and many smaller clusters consisting of a few points. The larger clusters are formed because the initial representations of the silicon atoms are very similar to one another (due to the learned element embedding of silicon). The similar silicon vectors are thus projected through UMAP into coordinates that lie close together, even though the silicon atoms are present in different chemical environments.

We can observe as well that the clustering in layer one is mostly attributable to the fractional amount, since each cluster consists primarily of points with the same fractional silicon amount. After the second layer, we observe that the points start to become separated into different and recognizable clusters. The clusters are no longer identifiable entirely based on the fractional amount of silicon, and clusters based on the predicted band gap value of the compound and oxidation state of silicon start to emerge. By the end of the third and last layer, we can observe four clusters that are distinguishable by the fractional amount of silicon, the predicted band gap, and the oxidation state of silicon (the clusters are outlined in Fig. 3, bottom left).

More specifically, we observe that the cluster at the bottom-right side of the plot consists mainly of silicon with a fractional amount of around 0.15 to 0.3 (with a few points reaching 0.5), whereas the cluster near the bottom-left consists almost exclusively of silicon with fractional amounts of 0.5 plus a few points above 0.5. The cluster near the top contains regions of silicon with fractional amounts between 0.3 and 0.4 near the left and right, and around 0.2 to 0.3 in the middle. Near the top of this cluster, a smaller cluster is highlighted which consists mainly of silicon instances with low abundance, between 0 and 0.2. Please note that interactive versions of these plots can be found in the SI together with another example visualization plotted for the element chromium (supplementary Fig. S-5).

In the predicted value plot of the last layer, we observe that only the small cluster near the top contains the silicon element in compounds with a nonzero band gap. Similarly, when examining the oxidation state plot, we note that while most clusters contain a mixture of silicon atoms in several oxidation states, the same cluster near the top consists almost exclusively of silicon atoms in the \(+4\) oxidation state and very few atoms in other oxidation states. Closer examination reveals that this cluster consists primarily of silicate materials such as Ca2SiO4, CaMgSiO4, MgMnSiO4, Li4SiO4, Sr3MgSi2O8, Li2MgSiO4, and others. Interestingly, while some compounds with silicon in the \(+4\) state are visible in other clusters, these compounds have a zero band gap. This suggests that additional interactions between the elements were captured by HotCrab which lead to these compounds being correctly clustered together with other compounds with zero band gap.

These element behavior plots suggest that for silicon-containing compounds in the OQMD_Bandgap dataset, the fractional amount and the oxidation state of the silicon atoms are important factors that together determine the band gap of the compounds. By cross-referencing the three plots, we can identify trends between the fractional amount and oxidation state of silicon and relate this information to the predicted band gap of the compounds. On the other hand, the clustering also suggests that there are other interactions between the elements in a compound which are currently not highlighted by the selected properties in Fig. 3. It is our expectation that by examining these interactions, additional insight about the modeling process and element representations can be gained. Moreover, the findings from examining internal representations of elements in this way may suggest additional studies to further improve the understanding of the underlying phenomena governing materials behaviors. Note that while these visualizations were generated using HotCrab, similar results can be obtained using the CrabNet model.

Capturing Globally Unique Representations of Chemical Compounds

In addition to examining the behavior of individual elements in different chemical environments, we can also visualize all of the compounds in a given dataset to uncover additional insights. We extract the internal vector representations of all 51242 compounds in the OQMD_Bandgap test dataset from the last self-attention layer of HotCrab, perform dimensionality reduction using UMAP, and finally visualize the compounds as shown in Fig. 4. In addition to coloring the plots by the predicted value, prediction error, and number of distinct elements for the compounds, we also highlight the chemical trend from ionic to covalent bonding character within the compounds. This trend is revealed by calculating and visualizing the standard deviation of the Pauling electronegativities of the constituent atoms \(\sigma _{\chi }\) in a given compound [75] according to Equation 5:

$$\begin{aligned} \sigma _{\chi } = \sqrt{ \frac{\sum _{i=1}^n (\chi _{i} - {\bar{\chi }})^2}{n - 1} } \end{aligned}$$
(5)

where \(\chi _{i}\) is the Pauling electronegativity of each element i in the compound (totaling n elements), and \({\bar{\chi }}\) is the average electronegativity of all elements in the compound. A higher \(\sigma _{\chi }\) signifies a more ionic bonding character, and a lower value signifies a more covalent bonding character.
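As an illustration, Equation 5 can be evaluated with a few lines of Python. This is a minimal sketch rather than the code used in this work: it assumes that each distinct element in the formula is counted once (not weighted by stoichiometry), it uses a small hand-entered table of Pauling electronegativities, and the function name `sigma_chi` is hypothetical.

```python
from statistics import stdev

# Pauling electronegativities for a few elements (hand-entered for illustration).
PAULING = {"Na": 0.93, "Cl": 3.16, "Ca": 1.00, "Si": 1.90, "O": 3.44, "C": 2.55}


def sigma_chi(elements):
    """Sample standard deviation (n - 1 denominator) of the electronegativities.

    `elements` lists the distinct element symbols of one compound.
    Assumption: each distinct element counts once, regardless of its amount.
    """
    values = [PAULING[symbol] for symbol in elements]
    if len(values) < 2:
        # Equation 5 is undefined for n = 1; returning 0.0 is a convention
        # chosen for this sketch only.
        return 0.0
    return stdev(values)


# A more ionic compound (NaCl) versus a more covalent one (SiC).
print("NaCl:", round(sigma_chi(["Na", "Cl"]), 3))
print("SiC: ", round(sigma_chi(["Si", "C"]), 3))
print("Ca2SiO4:", round(sigma_chi(["Ca", "Si", "O"]), 3))
```

In this sketch, NaCl yields a larger \(\sigma _{\chi }\) than SiC, in line with the interpretation that higher values indicate a more ionic bonding character.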

Fig. 4

Global representations of the 51242 compounds in the OQMD_Bandgap test dataset, extracted from layer three of HotCrab, embedded down to two dimensions using UMAP and colored by the parameters: a the predicted value of the compound (band gap); b the prediction error (\({\hat{y}} - y\)); c the bond character of the compounds ranging from more covalent (blue) to more ionic (red) as measured by the standard deviations in the Pauling electronegativities of the constituent elements; and d the number of distinct elements in the compound. A cluster of interest is outlined in the plot at the top-right

Many clusters of varying sizes are visible in the figure. Some clusters are placed further apart, while others lie close to, or overlap with, other clusters. The outlined cluster near the right of the figure is of particular interest. This is the only cluster where the compounds with a nonzero band gap are located, as is visible from Fig. 4a. Additionally, it is also within this cluster that HotCrab makes the largest errors when predicting the band gap value, as seen in Fig. 4b. For the other compounds, the prediction errors of HotCrab are close to zero. Even though a small proportion of model predictions have larger errors, the overall model performance is very good and is comparable with, or better than, other state-of-the-art models [28]. This superior performance of the CrabNet and HotCrab models when predicting properties with a defined cutoff (such as the cutoff of 0 eV for the band gap in this case) is likely attributable to the prediction of element-logits in the modeling process. These element-logits are used to weight the final model predictions in CrabNet and HotCrab to improve the model accuracy [28].

Notably, we also observe from Fig. 4a, c and d that the band gap only partially depends on the bond nature of the compound and on the number of unique elements in the compound. While most of the compounds in the cluster of interest exhibit more ionic bond characters, there are also other clusters with similar bond character that do not have a nonzero band gap. Similarly, it appears that the compounds with a nonzero band gap mainly contain four or five unique elements; however, there are also other compounds with these numbers of unique elements which have a zero band gap.

We note that while UMAP can reveal structures and patterns within high-dimensional data, it generally emphasizes local structure at the expense of global structure. Therefore, for the UMAP visualizations shown in this work, it is more appropriate to interpret the local structure (e.g., the elements or compounds present within individual clusters in Figs. 3 and 4) than the global structure. While the number of local neighbors considered can be specified as a hyperparameter in UMAP, a trade-off is made between preserving local versus global structure. Therefore, the distances between elements and compounds within a single cluster are more meaningful than inter-cluster distances in the UMAP visualizations. Lastly, we note that while these visualizations were generated based on the test dataset using HotCrab, similar results can be obtained using CrabNet or the training dataset.

Visualizing the Training Progress

Beyond visualizing the element and compound representations from CrabNet after training, it is also possible to access the self-attention matrices of the CrabNet encoding layers to observe the model learning process during training. The attention matrices (commonly referred to as the attention maps) contain information regarding how each element (rows) is influenced by all other elements in the compound as well as itself (columns). The values in the attention maps are the attention scores and are used in the encoder to update the element representations. An attention score of zero means that the element in the column is completely ignored when updating the element’s representation in that row. Conversely, a score of one means that the entire update is based solely on that column’s element.
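The meaning of the attention scores can be illustrated with a small, self-contained sketch. It assumes standard scaled dot-product attention with a row-wise softmax and omits the learned projections, multiple heads, and residual connections of a full encoder layer. The two-dimensional element vectors and the function names are hypothetical and are not taken from CrabNet.

```python
import math


def softmax(values):
    """Normalize raw scores so that they are positive and sum to one."""
    largest = max(values)
    exps = [math.exp(v - largest) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def attention_map(queries, keys):
    """Row i holds the attention that element i pays to every element j."""
    scale = math.sqrt(len(keys[0]))
    return [
        softmax([sum(q * k for q, k in zip(query, key)) / scale for key in keys])
        for query in queries
    ]


def apply_attention(scores, values):
    """Update each row as the attention-weighted sum of the value vectors."""
    dim = len(values[0])
    return [
        [sum(row[j] * values[j][d] for j in range(len(values))) for d in range(dim)]
        for row in scores
    ]


# Hypothetical two-dimensional vectors for a three-element compound.
elements = ["Gd", "Mn", "Si"]
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

scores = attention_map(vectors, vectors)
for symbol, row in zip(elements, scores):
    print(symbol, [round(a, 3) for a in row])

# Limiting case: a score of one in the "Si" column and zero elsewhere means
# that the updated row is exactly the "Si" vector.
one_hot = [[0.0, 0.0, 1.0]]
print(apply_attention(one_hot, vectors))
```

Each row of the resulting map sums to one, so the scores can be read as the relative contribution of every element (columns) to the update of a given element (row).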

In the CrabNet publication [28], example attention maps were shown for compounds after the model had finished training. Here, we extend this approach by visualizing the CrabNet attention maps during the model training process in the form of attention video clips (see SI files for examples). This is achieved by saving the attention matrices from the model encoder layers after every mini-step in the training process, plotting them, and merging the resulting plots into a video clip that shows the learning progress of the model throughout training. Fig. 5 shows a snapshot of two example attention videos obtained at the end of model training. The attention maps from the first encoding layer of CrabNet are plotted as heatmaps in the left column, while the right column shows the predicted values from the model against the target value at every mini-step.

Fig. 5

Snapshots of attention videos for observing the training progress of CrabNet using two example compounds a \({\text{Gd}_{1}\text{Mn}_{1}\text{Si}_{1}}\) and b \({\text{C}_{5}\text{Ca}_{1}\text{Fe}_{1}\text{H}_{8}\text{N}_{6}\text{O}_{5}}\) from the validation data split of the aflow__Egap dataset. The left plots show the attention maps of the four attention heads at the first attention layer, where the x axis of each heatmap is labeled with the fractional amount of the elements and the other axes are labeled with the element symbol. The right plots show the model predictions (blue) for the compounds, evaluated after each training mini-step throughout the whole training process. The true property value (target) is represented with the red “X” and the dotted line

From the attention maps, we can observe that some elements are considered less relevant in the determination of the material property, whereas others are considered very relevant. We also note that individual attention heads pay attention to different element-element interactions in the compound, as is visible from the significantly different attention patterns in the plots. Throughout the training process, the attention pattern for each head remains relatively fixed after a few mini-steps, indicating that the model discovers a pattern for recognizing inter-element interactions early on in the training process, which it then continues to refine as more training steps are taken.

For the top compound, we can observe that while the model initially over- and underestimates the property value early on in the training, it learns to correct the error and finally achieves a low prediction error towards the end of training. Conversely, for the bottom compound, we observe that while the model initially correctly estimates the property value of the compound, the predicted value decreases and the estimation error increases throughout training, with the error finally plateauing towards the end of the training. By examining the attention heatmaps for this compound, we notice that attention head 1 shows a significantly different behavior as compared to the other attention heads. It dedicates almost all of its attention to the element iron, while the other attention heads capture many more inter-element interactions. It may be interesting to investigate further to find out if CrabNet is misrepresenting the interactions from the iron element with the other elements and thus making the prediction error, or if another phenomenon is contributing to the prediction error on this compound.

By observing the element groups and inter-elemental interactions that CrabNet pays attention to for each material property throughout the training process, we may be able to gain additional insight about which relevant elements and interactions contribute significantly to the material property. Similarly, in the case where the model does not make a good property prediction or fails to learn a specific material property, these attention videos can be informative in showing when, where, and how the model fails. Additionally, since the element representations in a compound are updated according to the attention scores, it would be interesting to train CrabNet on material properties where the property has a high sensitivity to changes in elemental prevalence. An example of this is in the case of dopants, where a small change in the dopant amount can significantly influence a material’s electrical [15, 76, 77], mechanical [17, 78,79,80], and thermal properties [81,82,83,84]. Finally, it may be interesting to expand the studied materials to include co-doped materials and use the attention videos to visualize the complex inter-elemental interactions between the co-dopants and the host elements.

Conclusion

In this work, we examined the CrabNet model through the use of several built-in model interpretability methods in order to visualize the data featurization and modeling process. We demonstrated that CrabNet can adequately capture the chemical behavior of compounds in a dataset by using the vector representations of their constituent elements. The element representations can be learned entirely from the training data on a per-property basis, and contain rich information about the elements and their chemical trends. Additionally, we examined dataset imbalance, its relation to the quality of learned representations, and the limitations that imbalanced datasets may ultimately impose on the modeling processes.

The element and compound vectors can be projected using UMAP into distinguishable clusters which can then be visualized and characterized by the element stoichiometry, local chemical environment and oxidation state of the elements, or by the bond behavior of the compounds. Lastly, the examination of the self-attention matrices during model training in the form of attention videos can be used to further understand the modeling process, debug potential model or dataset errors, or gain additional insights about chemical interactions within a given compound.

The model interpretability techniques presented in this work will enable materials science practitioners to not only visualize a specific element’s behavior within different chemical environments, but also to obtain a global view of the chemical compounds, behaviors and trends within a larger dataset. The ability of CrabNet to adequately model and express the complex chemical behaviors and interactions of elements and compounds based solely on learning from data is encouraging. With the addition of model interpretability methods to CrabNet, the findings and intuitions presented in this work may lead to further insightful and interesting research. Specifically, we believe that follow-up works may fall into one of these three general directions:

  1.

    Learning and representing elements and compounds. Our work has shown that it is possible to visualize CrabNet’s internal representations of elements and compounds via techniques such as UMAP. However, it would be interesting to further investigate why CrabNet’s representations of some of these elements or compounds lead to them being placed into the same cluster or not, despite the fact that these elements and compounds are similar to each other in terms of identity and/or chemical environment. This may also be combined with a more detailed examination of the attention videos and how the attention mechanism in CrabNet leads to the updating of the element representations for each compound.

  2.

    Examination of individual attention head behaviors. This work used the EDM (element-derived matrix) data from CrabNet to examine the element and compound representations within CrabNet. CrabNet utilizes four self-attention heads to model element-element interactions, the results of which are then concatenated and transformed back to an updated EDM. As such, the EDM is a pooled representation of the compounds. It would be interesting to further examine the per-head modeling of the compounds, as it has been shown that each head can capture different types of inter-element interactions and thus may give additional insight into the modeling process within CrabNet.

  3.

    Discovery of additional inter-element interactions. From the analyses presented in this study, it is clear that while some changes in the material property (e.g., band gap) can be explained by certain properties of the compounds (such as element stoichiometry, number of unique elements, and/or bond character), there are additional behaviors that govern the material property. These additional interactions are also adequately modeled by CrabNet, since it can predict a wide range of material properties with low errors. Examining the modeling process of these behaviors within CrabNet may lead to an improved understanding of the complex phenomena underlying material properties.

Further research to answer these and subsequent questions may allow us to gain additional insights about the behaviors and properties of elements and materials, improve our understanding of models such as CrabNet, increase our confidence in the use of data-driven methods, and ultimately, accelerate the adoption of deep learning and machine learning in materials science.

Methods

Adaptation of CrabNet Model

The CrabNet model and material property datasets as originally reported were used as the basis for this study [28]. Fully trained model weights for both CrabNet and HotCrab were obtained from [57]. In order to obtain the EDMs containing the elements and compounds data used in this study, custom function hooks were implemented in PyTorch. These hooks were attached to the CrabNet model architecture to allow access to the model-internal data during training and inference.

The source code as well as the data that were used and generated in this study can be found on the updated CrabNet GitHub repository [58]. In addition, we provide detailed instructions for the use and reproduction of our reported results. Please note that due to the prohibitively large size of the stored attention matrices used in the attention videos, it is not possible to provide these for download. However, instructions and scripts are provided for generating these matrices and videos.

All experiments, unless otherwise noted, were performed on a workstation equipped with an Intel i7-8700K CPU, 32 GB of DDR4 RAM, and one Nvidia RTX 2080 GPU.

Element Embeddings

Element embeddings for pure elements were generated on a per-property basis. To do this, an EDM consisting of all of the elements from hydrogen to oganesson was generated (with each row representing one element). Then, for each material property, the corresponding CrabNet or HotCrab model was loaded and the model hooks attached. The EDM was then passed through the network and the modified EDM at the output of the element embedding layer was obtained and detached from the model graph. This resulting EDM contains the property-specific element embeddings of all of the elements. Thus, each element was represented by a vector with the shape \((1, d_\mathrm {model})\), where \(d_\mathrm {model}\) is the size of the embedding. Element embeddings for Oliynyk, Magpie, and mat2vec were obtained from the original publications [18].
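The hook-based extraction described above can be sketched in PyTorch as follows. This is a hedged illustration only: the two-layer `model` is a hypothetical stand-in and not the CrabNet architecture, the embedding size of 8 is arbitrary, and the helper `save_output` is not part of the published code. Only `register_forward_hook` and the `detach` call are standard PyTorch functionality.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a trained model: an element embedding layer
# followed by an output layer. This is NOT the CrabNet architecture.
model = nn.Sequential(
    nn.Embedding(num_embeddings=119, embedding_dim=8),  # index 0 left for padding
    nn.Linear(8, 1),
)

captured = {}


def save_output(name):
    """Build a forward hook that stores a detached copy of a layer's output."""
    def hook(module, inputs, output):
        captured[name] = output.detach().clone()
    return hook


# Attach the hook to the embedding layer only.
handle = model[0].register_forward_hook(save_output("embedding"))

# One row per element, hydrogen (Z = 1) to oganesson (Z = 118).
atomic_numbers = torch.arange(1, 119).unsqueeze(0)  # shape (1, 118)

model.eval()
with torch.no_grad():
    model(atomic_numbers)
handle.remove()  # remove the hook once the data has been collected

element_embeddings = captured["embedding"].squeeze(0)
print(element_embeddings.shape)  # one vector of size d_model per element
```

Storing a detached copy keeps the captured embeddings independent of the model graph, so they can be passed directly to downstream analysis such as dimensionality reduction.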

Compound Embeddings

Compound embeddings were obtained in a similar fashion to element embeddings. Instead of generating an EDM from pure elements, the EDMs were generated from the actual chemical formulae from the datasets and collated in batches using the model data loader. Model hooks were then attached to the CrabNet and HotCrab models and enabled during model inference. The transformed EDMs after each of the three self-attention layers of the CrabNet models were then collected.

The obtained compound EDMs have the shape of \((n_\mathrm {compounds}, n_\mathrm {elements}, d_\mathrm {model})\), where \(n_\mathrm {compounds}\) is the total number of compounds in the dataset, \(n_\mathrm {elements}\) is the maximum number of elements per compound, and \(d_\mathrm {model}\) is the size of the embedding. Thus, each compound in the EDM is represented by one tensor slice with the dimensions \((1, n_\mathrm {elements}, d_\mathrm {model})\). Because different compounds within the same dataset may contain different numbers of elements, the extra rows of the EDMs were zero-filled to indicate that no elements are present. In order to ensure that the compound embeddings are comparable with each other using UMAP, the three-dimensional compound EDMs were collapsed to the effectively two-dimensional shape \((n_\mathrm {compounds}, 1, d_\mathrm {model})\) by calculating summary statistics (such as sum, range, variance) of the EDM columns across the elements dimension.
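The collapsing step can be sketched with NumPy on a small synthetic EDM. The array sizes, the random values, and the function name `collapse` are illustrative assumptions. In particular, this sketch includes the zero-filled padding rows when computing the statistics, which may differ from the handling in the published code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_compounds, n_elements, d_model = 2, 4, 6

# Synthetic compound EDM; rows beyond the real elements are zero-filled.
edm = rng.normal(size=(n_compounds, n_elements, d_model))
edm[0, 2:] = 0.0  # first compound has two elements
edm[1, 3:] = 0.0  # second compound has three elements


def collapse(edm, statistic="sum"):
    """Collapse (n_compounds, n_elements, d_model) along the elements axis.

    Assumption: padding rows take part in the statistic as zeros.
    """
    if statistic == "sum":
        return edm.sum(axis=1, keepdims=True)
    if statistic == "range":
        return edm.max(axis=1, keepdims=True) - edm.min(axis=1, keepdims=True)
    if statistic == "variance":
        return edm.var(axis=1, keepdims=True)
    raise ValueError(f"unknown statistic: {statistic}")


collapsed = collapse(edm, "sum")
print(collapsed.shape)  # (n_compounds, 1, d_model)

# One row per compound, ready for a dimensionality reduction method.
compound_vectors = collapsed.reshape(n_compounds, d_model)
print(compound_vectors.shape)
```

With the sum statistic, zero-filled rows do not change the result, whereas the range and variance are affected by padding, so the choice of statistic and masking should be made consistently across all compounds.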

Dimensionality Reduction

CrabNet uses vectors with a \(d_\mathrm {model}\) dimension of 512 to represent chemical elements and compounds in the input data. It would be infeasible to try to visualize all 512 dimensions. Therefore, dimensionality reduction was applied to the vector representations to transform the vectors into two-dimensional space for visualization.

Three common methods for dimensionality reduction were tested: principal component analysis (PCA), t-distributed stochastic neighbor embedding (t-SNE), and uniform manifold approximation and projection (UMAP) [73, 85, 86]. Compared to t-SNE and PCA, UMAP revealed more visually distinct clusters for the data presented in this work. Therefore, UMAP was chosen as the dimensionality reduction method. The random seed was fixed so that each initialization of the UMAP method produces the same results. For element embeddings, the rows of the EDMs with dimensions \((1, d_\mathrm {model})\) are transformed using UMAP. For the compound embeddings, the matrices corresponding to each compound were first collapsed as described above, and the resulting representations with dimensions \((1, d_\mathrm {model})\) for each compound were transformed using UMAP.

Oxidation State Estimation

Oxidation states for elements in the compounds were estimated from the chemical formulae of the compounds using the Pymatgen package (version 2022.0.8). The built-in functions for assigning oxidation states were used, which are based on charge-balancing heuristics and use the most probable oxidation states as determined from the compounds in the Inorganic Crystal Structure Database [74].

Attention Video Generation

Custom function hooks were programmed and attached to a newly initialized CrabNet model. During training of CrabNet, the attention matrices of every CrabNet encoder layer were extracted from the model and saved into a compressed Zarr array on disk. The model predictions for the properties were also generated and saved. This procedure was performed after every mini-step during the training process (corresponding to each mini-batch of data). The plots were then generated for each mini-step and merged together using the software FFmpeg to create the attention videos. Due to the large amount of storage and computing power required to store and process the attention matrices, these tasks were performed on a high-performance computing cluster.
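The per-step recording of attention matrices and predictions can be sketched as follows. This is a simplified illustration under several assumptions: a single `nn.MultiheadAttention` layer with a hypothetical linear readout stands in for CrabNet, the attention weights are read from the layer's return value instead of through hooks, the data are random, and the results are kept in an in-memory NumPy array rather than a compressed Zarr array. The `average_attn_weights` argument requires a reasonably recent PyTorch version.

```python
import numpy as np
import torch
from torch import nn

torch.manual_seed(0)

n_heads, d_model, n_elements = 4, 16, 3

# Hypothetical stand-in model: one self-attention layer and a linear readout.
attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
readout = nn.Linear(d_model, 1)
params = list(attention.parameters()) + list(readout.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

x = torch.randn(1, n_elements, d_model)  # one compound with three elements
target = torch.tensor([[1.0]])           # synthetic property value

attention_log = []   # one (heads, elements, elements) array per mini-step
prediction_log = []  # one predicted value per mini-step

for step in range(5):
    # Per-head attention weights have shape (batch, heads, elements, elements).
    out, weights = attention(
        x, x, x, need_weights=True, average_attn_weights=False
    )
    prediction = readout(out.mean(dim=1))
    loss = nn.functional.mse_loss(prediction, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Store detached copies so the logs do not keep the graph alive.
    attention_log.append(weights.detach().squeeze(0).numpy().copy())
    prediction_log.append(prediction.item())

attention_log = np.stack(attention_log)
print(attention_log.shape)  # (steps, heads, elements, elements)
```

Each stored slice corresponds to one frame of an attention video: the heatmaps are drawn from `attention_log[step]`, and the prediction curve is drawn from `prediction_log[: step + 1]`.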