
Counterfactual based reinforcement learning for graph neural networks

  • Original Research
  • Published in: Annals of Operations Research

Abstract

Many models have been developed to achieve strong results on classification tasks. We present a novel framework that augments these models to reach even higher classification accuracy. Our framework sits flexibly on top of an existing model and uses a reinforcement learning approach to learn to generate new, difficult training samples that further refine the classification model. By producing harder and more meaningful data samples, the framework helps the model learn meaningful relationships in the data for its classification task. As a result, our framework augments models during training rather than operating on pre-trained classifiers. Our experiments show that the framework improves models' classification accuracy, and our ablation studies demonstrate the effect of tuning its components. Finally, we discuss possible improvements to the framework and directions for future work.
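The training scheme described in the abstract, in which classifier training alternates with RL-driven generation of hard counterfactual samples that are fed back into the training set, could be sketched roughly as follows. This is a minimal illustrative sketch only: the function names, the scheduling rule, and the number of samples generated per round are assumptions for illustration, not the paper's implementation.

```python
import random

def train_with_counterfactuals(classifier_step, generator_step, dataset,
                               total_epochs=10, e=2):
    """Alternate classifier training with counterfactual generation.

    For every `e` classifier epochs, one generator epoch produces hard
    counterfactual samples that are appended to the training data.
    All names and the e:1 schedule are illustrative assumptions.
    """
    data = list(dataset)
    for epoch in range(total_epochs):
        if epoch % (e + 1) < e:
            # Classifier epoch: train the (stand-in) classifier on all data.
            for sample in data:
                classifier_step(sample)
        else:
            # Generator epoch: create hard samples from a few existing ones
            # (3 per round is an arbitrary illustrative choice).
            seeds = random.sample(data, k=min(3, len(data)))
            data.extend(generator_step(s) for s in seeds)
    return data
```

With `total_epochs=10` and `e=2`, epochs 2, 5, and 8 are generator epochs, so three counterfactuals are appended in each of those rounds and the training set grows accordingly.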



Corresponding author

Correspondence to David Pham.


Appendix


For ease of reference, we list all key notation and symbols used in this paper.

  • \(G\): Graph. A graph is made up of vertices and edges.

  • \(V\): Vertex. A node of the graph, represented as a vector.

  • \(E\): Edge. An edge of the graph, which can be directed or undirected.

  • \(\phi \): Deep Graph Network (DGN). A deep graph neural network used for classification. This is not model specific; any DGN can be used in the CRL framework.

  • \(y=\phi (G)\): DGN result. The classification result of the DGN, in the form of a vector.

  • \(\psi \): Generator. The generator creates counterfactuals to augment the data and aid in training the DGN.

  • \(G'\): Counterfactual graph. The output of the generator, a counterfactual \(G'\) that is similar to the original graph \(G\) but has a different classification result.

  • \({\mathcal {D}}\): Domain knowledge. Input constraints for the RL agent that restrict exploration and allow only realistic, domain-compliant output.

  • \({\mathcal {L}}\): Predictive disagreement. The measure of predicted disagreement between the classifications of two graphs.

  • \({\mathcal {K}}\): Similarity measure. The measure of similarity between two graphs.

  • \(\mathcal {W(A)}\): Weight measure. The weight and corresponding probability that the RL agent will take action \({\mathcal {A}}\).

  • \(\alpha \): Exploration hyperparameter. Tuning parameter corresponding to the likelihood that the RL agent explores more unknown and different counterfactuals, versus creating counterfactuals that more closely resemble those in the dataset.

  • \(e\): Epoch hyperparameter. Tuning parameter that controls the number of epochs the DGN trains relative to the number of epochs the RL agent trains.


Cite this article

Pham, D., Zhang, Y. Counterfactual based reinforcement learning for graph neural networks. Ann Oper Res (2022). https://doi.org/10.1007/s10479-022-04978-9
