RISE: Radius of Influence based Subgraph Extraction for 3D Molecular Graph Explanation
Jingxiang Qu†, Wenhan Gao†, Jiaxing Zhang, Xufeng Liu, Hua Wei, Haibin Ling, Yi Liu*
(* Corresponding Author)
(† Equal Contribution)
ICML 2025
Introduction: 3D Geometric Graph Neural Networks (GNNs) have emerged as transformative tools for modeling molecular data. Despite their predictive power, these models often suffer from limited interpretability, raising concerns for scientific applications that require reliable and transparent insights. While existing methods have primarily focused on explaining molecular substructures in 2D GNNs, the transition to 3D GNNs introduces unique challenges, such as handling the implicit dense edge structures created by a cutoff radius. To tackle this, we introduce a novel explanation method specifically designed for 3D GNNs, which localizes the explanation to the immediate neighborhood of each node within 3D space. Each node is assigned a radius of influence that defines the localized region within which message passing captures the spatial and structural interactions crucial for the model's predictions. This method leverages the spatial and geometric characteristics inherent in 3D graphs. By constraining the subgraph to a localized radius of influence, the approach not only enhances interpretability but also aligns with the physical and structural dependencies typical of 3D graph applications, such as molecular learning.
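To make the idea concrete, here is a minimal pure-Python sketch, assuming each node i has already been given a radius r_i: it builds the dense cutoff-radius graph that 3D GNNs pass messages over, then keeps an edge (i, j) only when j lies inside node i's radius of influence. The function names, the hard/soft switch, and the sigmoid temperature tau are illustrative assumptions; this is not the RISE implementation, and how RISE assigns or optimizes the radii is not shown here.

```python
import math
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float, float]
Edge = Tuple[int, int]


def cutoff_edges(pos: Sequence[Point], cutoff: float) -> List[Edge]:
    """All directed pairs (i, j), i != j, within `cutoff`: the dense
    implicit graph that a cutoff-based 3D GNN passes messages over."""
    return [
        (i, j)
        for i, p in enumerate(pos)
        for j, q in enumerate(pos)
        if i != j and math.dist(p, q) <= cutoff
    ]


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids math.exp overflow)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def radius_of_influence_mask(
    pos: Sequence[Point],
    edges: Sequence[Edge],
    radii: Sequence[float],
    tau: Optional[float] = None,
) -> List[float]:
    """Mask each edge (i, j) by whether j lies within node i's radius r_i.

    tau=None returns a hard 0/1 mask. A positive tau returns the soft mask
    sigmoid((r_i - d_ij) / tau), a generic differentiable relaxation
    (an assumption for illustration, not necessarily what RISE uses).
    """
    mask = []
    for i, j in edges:
        d = math.dist(pos[i], pos[j])
        if tau is None:
            mask.append(1.0 if d <= radii[i] else 0.0)
        else:
            mask.append(_sigmoid((radii[i] - d) / tau))
    return mask


if __name__ == "__main__":
    # Toy 3-atom chain along the x-axis (arbitrary units).
    pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.5, 0.0, 0.0)]
    edges = cutoff_edges(pos, cutoff=3.0)  # every pair lies within the cutoff
    radii = [1.2, 2.0, 1.0]                # hypothetical per-node radii
    print(edges)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
    print(radius_of_influence_mask(pos, edges, radii))  # [1.0, 0.0, 1.0, 1.0, 0.0, 0.0]
```

In the toy example, node 2 has the smallest radius and keeps no neighbors, even though both other atoms lie inside the 3.0 cutoff; this is the kind of localization the radius of influence is meant to express.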
Before running the experiments, please remove or comment out line 250 in torch_geometric.explain.explainer.
Because RISE does not apply a threshold to the edge_mask, there is no need to validate the size of the mask.
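Line numbers inside installed packages can shift between PyTorch Geometric versions, so it may help to print the exact statement before editing it. The sketch below uses only the standard library (importlib.util.find_spec and linecache) to locate a module's source file and read one line; the helper name locate_source_line is hypothetical.

```python
import importlib.util
import linecache
from typing import Tuple


def locate_source_line(module_name: str, lineno: int) -> Tuple[str, str]:
    """Return (source_path, text of line `lineno`) for an importable module.

    Note: find_spec on a dotted name imports its parent packages first.
    """
    spec = importlib.util.find_spec(module_name)
    if spec is None or not spec.has_location or spec.origin is None:
        raise ModuleNotFoundError(f"no source file found for {module_name!r}")
    # linecache.getline returns '' when the line does not exist.
    return spec.origin, linecache.getline(spec.origin, lineno)
```

For example, locate_source_line("torch_geometric.explain.explainer", 250) returns the path of the installed file and the statement currently at line 250, so you can confirm it is the mask-size check before commenting it out. Keep in mind that editing the installed file affects every project that uses the same Python environment.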
To run the QM9 experiments, set explained_model_name, target_attr, epoch, budget, and checkpoint.
(When explaining SchNet or DimeNet, the checkpoint argument can be omitted.)
To test the explainer on SEGNN, first train SEGNN and save its model state_dict.
The official implementation of SEGNN can be found here.
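Once SEGNN is trained, its weights can be saved with the standard PyTorch state_dict pattern. In this minimal sketch a small nn.Sequential stands in for the SEGNN model and the helper names are hypothetical; we also assume, without having checked main_qm9.py, that the file passed via --checkpoint is a state_dict saved this way and loaded into a model of the same architecture.

```python
import torch
from torch import nn


def save_checkpoint(model: nn.Module, path: str) -> None:
    """Save only the parameters and buffers (the state_dict), not the module object."""
    torch.save(model.state_dict(), path)


def load_checkpoint(model: nn.Module, path: str) -> nn.Module:
    """Load a saved state_dict into a freshly constructed model of the same architecture."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model.eval()  # eval() returns the module itself
```

For instance, after training, save_checkpoint(segnn_model, "segnn_qm9.pt") writes the weights (the file name is illustrative), and that path is what you would pass as --checkpoint. Saving the state_dict rather than the whole module keeps the checkpoint independent of the exact Python class location.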
Note that the chemical properties in the QM9 dataset are indexed by target_attr as follows:
{0: 'mu', 1: 'alpha', 2: 'homo', 3: 'lumo', 4: 'gap',
5: 'electronic_spatial_extent', 6: 'zpve', 7: 'energy_U0',
8: 'energy_U', 9: 'enthalpy_H', 10: 'free_energy', 11: 'heat_capacity'}.
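To avoid looking up indices by hand, the mapping above can be kept as a small Python dict with a reverse lookup. QM9_TARGET_ATTR and target_attr_for are hypothetical helper names, not part of the repository.

```python
from typing import Dict

# Index -> property name, copied from the target_attr mapping above.
QM9_TARGET_ATTR: Dict[int, str] = {
    0: "mu", 1: "alpha", 2: "homo", 3: "lumo", 4: "gap",
    5: "electronic_spatial_extent", 6: "zpve", 7: "energy_U0",
    8: "energy_U", 9: "enthalpy_H", 10: "free_energy", 11: "heat_capacity",
}
# Reverse mapping: property name -> index.
NAME_TO_TARGET_ATTR: Dict[str, int] = {name: idx for idx, name in QM9_TARGET_ATTR.items()}


def target_attr_for(name: str) -> int:
    """Return the --target_attr index for a QM9 property name, e.g. 'gap' -> 4."""
    try:
        return NAME_TO_TARGET_ATTR[name]
    except KeyError:
        raise ValueError(
            f"unknown QM9 property {name!r}; expected one of {sorted(NAME_TO_TARGET_ATTR)}"
        ) from None
```

For example, f"--target_attr={target_attr_for('gap')}" produces "--target_attr=4".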
python main_qm9.py --explained_model_name=SchNet --target_attr=0 --epoch=200 --budget=0.5
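To sweep over several QM9 targets, the command line can be generated programmatically. This sketch uses only the flags documented above; the defaults mirror the example command rather than the script's own defaults, and qm9_command is a hypothetical helper.

```python
import shlex
from typing import List, Optional


def qm9_command(
    model: str,
    target_attr: int,
    epoch: int = 200,
    budget: float = 0.5,
    checkpoint: Optional[str] = None,
) -> List[str]:
    """Build the argv for one main_qm9.py run using only the documented flags."""
    cmd = [
        "python", "main_qm9.py",
        f"--explained_model_name={model}",
        f"--target_attr={target_attr}",
        f"--epoch={epoch}",
        f"--budget={budget}",
    ]
    if checkpoint is not None:  # can be omitted for SchNet or DimeNet
        cmd.append(f"--checkpoint={checkpoint}")
    return cmd


if __name__ == "__main__":
    # Print one shell command per QM9 target (0..11) for SchNet.
    for t in range(12):
        print(shlex.join(qm9_command("SchNet", t)))
```

Each returned list can also be passed directly to subprocess.run(cmd, check=True), which avoids shell quoting issues entirely.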
To run the GEOM experiments, set explained_model_name, epoch, budget, and checkpoint. (Please train the backbone model on the GEOM dataset first.)
python main_geom.py --explained_model_name=SchNet --epoch=200 --budget=0.5 --checkpoint='your_checkpoint_path.pt'
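Because the GEOM runs depend on a backbone trained beforehand, a small pre-flight check can fail early with a clear message instead of partway through a run. This is a minimal sketch; geom_command is a hypothetical helper, and its defaults mirror the example command above.

```python
import os
from typing import List


def geom_command(checkpoint: str, model: str = "SchNet", epoch: int = 200, budget: float = 0.5) -> List[str]:
    """Build the argv for main_geom.py, raising early if the backbone checkpoint is missing."""
    if not os.path.isfile(checkpoint):
        raise FileNotFoundError(
            f"backbone checkpoint {checkpoint!r} not found; train the backbone on GEOM first"
        )
    return [
        "python", "main_geom.py",
        f"--explained_model_name={model}",
        f"--epoch={epoch}",
        f"--budget={budget}",
        f"--checkpoint={checkpoint}",
    ]
```

The returned list can be passed to subprocess.run(cmd, check=True).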