
Two-Stage Training of Graph Neural Networks for Graph Classification


Abstract

Graph neural networks (GNNs) have received massive attention in the field of machine learning on graphs. Inspired by the success of neural networks, a line of research has been conducted to train GNNs for various tasks, such as node classification, graph classification, and link prediction. In this work, our task of interest is graph classification. Several GNN models have been proposed and have shown great accuracy on this task. However, it remains unclear whether the usual training methods fully realize the capacity of these GNN models. In this work, we propose a two-stage training framework based on triplet loss. In the first stage, a GNN is trained to map each graph to a Euclidean-space vector so that graphs of the same class are close together while those of different classes are mapped far apart. Once graphs are well separated by label, a classifier is trained to distinguish between the classes. This method is generic in the sense that it is compatible with any GNN model. By adapting five GNN models to our method, we demonstrate consistent improvements in accuracy, of up to \(5.4\%\) points on 12 datasets, and better utilization of each GNN’s allocated capacity compared with the original training method of each model.
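To make the two-stage procedure concrete, the following minimal Python (PyTorch) sketch illustrates the idea; it is not the authors' released implementation. Here, gnn_encoder stands for any GNN model that maps a graph to a 64-dimensional embedding, and train_graphs and sample_triplet are hypothetical helpers for loading labeled graphs and sampling (anchor, positive, negative) triplets.

    # Illustrative sketch only; not the authors' code. Assumes PyTorch and a
    # user-supplied `gnn_encoder` (any GNN mapping a graph to a 64-dim vector);
    # `train_graphs` and `sample_triplet` are hypothetical helpers.
    import torch
    import torch.nn as nn

    embed_dim, num_classes = 64, 2

    # Stage 1: train the GNN with a triplet loss so that graphs of the same class
    # are embedded close together and graphs of different classes far apart.
    triplet_loss = nn.TripletMarginLoss(margin=1.0)
    enc_opt = torch.optim.Adam(gnn_encoder.parameters(), lr=1e-3)
    for step in range(1000):
        anchor, pos, neg = sample_triplet(train_graphs)  # pos: same class, neg: different class
        loss = triplet_loss(gnn_encoder(anchor), gnn_encoder(pos), gnn_encoder(neg))
        enc_opt.zero_grad()
        loss.backward()
        enc_opt.step()

    # Stage 2: freeze the encoder and train a classifier on the resulting embeddings.
    classifier = nn.Linear(embed_dim, num_classes)
    clf_opt = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()
    for step in range(1000):
        for graph, label in train_graphs:
            with torch.no_grad():
                z = gnn_encoder(graph)  # fixed embedding from stage 1
            loss = ce(classifier(z).unsqueeze(0), torch.tensor([label]))
            clf_opt.zero_grad()
            loss.backward()
            clf_opt.step()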


Data Availability

The datasets used in the study are available at http://github.com/manhtuando97/two-stage-gnn.

Code Availability

The source code used in the study is released as open source under the MIT license at http://github.com/manhtuando97/two-stage-gnn.

Notes

  1. https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page.

  2. For example, when the accuracy of the original setting is \(63\%\) and that of our method is \(67\%\), our method improves the accuracy by \(4\%\) points. The absolute improvement is \(67\% - 63\% = 4\%\); relative to the original setting’s accuracy of \(63\%\), this corresponds to \(4/63 \approx 6.3\%\). Since our focus is on the absolute classification accuracy, we report “\(4\%\) points” rather than “\(6.3\%\)”.
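The arithmetic in this note can be checked directly; the snippet below only illustrates the distinction between absolute percentage points and relative improvement.

    # Absolute "% points" vs. relative "%" improvement (illustration of the note above).
    baseline, ours = 0.63, 0.67
    abs_points = round((ours - baseline) * 100, 1)              # 4.0 percentage points
    rel_percent = round((ours - baseline) / baseline * 100, 1)  # ~6.3% relative to 63%
    print(abs_points, rel_percent)  # 4.0 6.3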


Funding

This work was supported by Agency for Defense Development (ADD) (No. UI2100072D, Technique Analysis and Model Prototyping for the Elements Identification of Enemy Behavior and Threat) as a part of AI - Command Decision Support for Future Ground Operations (AICDS), Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST) and No. 2020-0-01361, Artificial Intelligence Graduate School Program (Yonsei University)), and the Yonsei University Research Fund of 2021.

Author information

Authors and Affiliations

Authors

Contributions

MD: Methodology, Validation, Investigation, Software, Writing - Original Draft, Visualization. NP: Conceptualization, Writing - Review & Editing. KS: Supervision, Conceptualization, Funding acquisition, Writing - Review & Editing.

Corresponding author

Correspondence to Kijung Shin.

Ethics declarations

Conflicts of interest

None.

Additional information

Publisher's Note

Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Appendix: Effect of Classification Layers

Table 12 Classification accuracy when the classifier consists of one layer and when the classifier is tuned using up to three layers on the benchmark datasets. 2STG \(=1\) and 2STG \(\le 3\) represent 2STG with one classification layer and up to three classification layers (as detailed in Sect. 4.1.4), respectively.
Table 13 Classification accuracy when the classifier consists of one layer and when the classifier is tuned using up to three layers on the New York City Taxi datasets. 2STG \(=1\) and 2STG \(\le 3\) represent 2STG with one classification layer and up to three classification layers (as detailed in Sect. 4.1.4), respectively.

We report in Tables 12 and 13 the classification accuracy when the classifier consists of only one layer and when the classifier is tuned using up to three layers (as described in Sect. 4.1.4). In some cases, the 1-layer classifier achieves an accuracy close to, or slightly higher than, that of the classifier tuned with up to three layers. In most cases, however, tuning with up to three classification layers yields better accuracy than using a single layer; the few exceptions are highlighted in bold italic in Tables 12 and 13. These results indicate that using a stronger classifier is generally helpful for improving classification accuracy. We also visualize in Fig. 7 the cases in which the difference between using one classification layer and using up to three classification layers is the largest. In particular, we show the t-SNE visualization of the embeddings generated by GraphSAGE 2STG+ for the datasets PTC-FM and JAN. G. when using one classification layer and up to three classification layers, respectively.
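For illustration, the sketch below shows one plausible way the two classifier variants could be built; the hidden dimension is a hypothetical choice, and the actual tuning protocol is described in Sect. 4.1.4.

    # Illustrative sketch of the two classifier variants compared in Tables 12 and 13.
    # The hidden dimension is a hypothetical choice; the number of layers (1, 2, or 3)
    # is treated as a hyperparameter (see Sect. 4.1.4).
    import torch.nn as nn

    embed_dim, hidden_dim, num_classes = 64, 64, 2

    # 2STG =1: a single classification layer on top of the graph embeddings.
    clf_one_layer = nn.Linear(embed_dim, num_classes)

    # 2STG <=3: build classifiers with 1, 2, or 3 layers and pick the best on validation data.
    def make_classifier(num_layers):
        layers, in_dim = [], embed_dim
        for _ in range(num_layers - 1):
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, num_classes))
        return nn.Sequential(*layers)

    candidates = {k: make_classifier(k) for k in (1, 2, 3)}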

Fig. 7

Graph embeddings generated by GraphSAGE for the datasets PTC-FM and JAN. G. in the two cases of 2STG+ with one classification layer (left) and 2STG+ with up to three classification layers (right). These are the graph embeddings (i.e., \({\hat{f}}(g_i)\) of each graph \(g_i\)), which are the input of the classification part in Fig. 2. t-SNE is applied to reduce the embedding dimension from 64 to 2 for visualization. Due to the capacity difference of the classifiers (one layer vs. up to three layers), there is a gap in the final classification accuracies of the two cases reported in Tables 12 and 13
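As a rough sketch of the visualization step mentioned in the caption above (assuming scikit-learn and Matplotlib; embeddings and labels are placeholder arrays, not released data loaders):

    # Reduce 64-dimensional graph embeddings to 2-D with t-SNE and plot them by class.
    # `embeddings` (shape [num_graphs, 64]) and `labels` are assumed to be given.
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    points_2d = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(embeddings)
    plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, s=10)
    plt.title("t-SNE of graph embeddings")
    plt.show()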

Rights and permissions

Springer Nature or its licensor holds exclusive rights to this article under a publishing agreement with the author(s) or other rightsholder(s); author self-archiving of the accepted manuscript version of this article is solely governed by the terms of such publishing agreement and applicable law.


About this article


Cite this article

Do, M.T., Park, N. & Shin, K. Two-Stage Training of Graph Neural Networks for Graph Classification. Neural Process Lett 55, 2799–2823 (2023). https://doi.org/10.1007/s11063-022-10985-5

