This repository contains the code for the paper "Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects", published at AISTATS'25.
We propose a novel methodology for estimating causal effects in social networks under interference and network-induced confounding. Our approach combines double machine learning with graph machine learning techniques, enabling accurate and efficient estimation of both direct and peer effects within a single observational social network. The proposed estimator achieves semiparametric efficiency under mild regularity conditions, which supports consistent uncertainty quantification.
network-causal-inference/
├── README.md # This file
├── requirements.txt # Package dependencies
├── src/ # Source code
│ ├── __init__.py # Package initialization
│ ├── data_generators.py # Data generating processes (DGPs)
│ ├── data_processors.py # Dataset loading/processing utilities
│ ├── dataset.py # Custom PyTorch Geometric dataset
│ ├── models.py # GNN model implementations
│ ├── experiment.py # Experiment runner
│ ├── analysis.py # Result analysis functions
│ └── main.py # Main script
├── data/ # Dataset directory (create this directory and place your data here)
├── results/ # Experimental results (CSV)
└── plots/ # Generated visualizations
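The tree above notes that data/ has to be created by hand. A minimal sketch of doing that from Python, assuming results/ and plots/ may also need to exist before a run; the helper name `ensure_project_dirs` is hypothetical and not part of the repository:

```python
from pathlib import Path


def ensure_project_dirs(root="."):
    """Create data/, results/ and plots/ under `root` if they are missing.

    Hypothetical helper for illustration; not part of the repository.
    """
    created = []
    for name in ("data", "results", "plots"):
        path = Path(root) / name
        # exist_ok=True makes the call safe to repeat
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created
```

Calling `ensure_project_dirs()` from the repository root leaves existing directories untouched.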
- Clone this repository:
git clone https://github.com/username/network-causal-inference.git
cd network-causal-inference
- Create a virtual environment (optional but recommended):
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
pip install -r requirements.txt
First, place your data, including the adjacency matrix and node features, in the data folder, following the project structure outlined above. If you plan to use the data generating processes (DGPs) in our framework, you can find them in data_generators.py (listed under src/ in the project structure above), or add your own there.
To run an experiment with default settings:
python -m src.main
This will:
- Load or generate the specified dataset
- Run the specified number of trials
- Train GNN models and estimate causal effects
- Save results and generate plots
The following arguments can be used to customize the experiment:
--trials INT Number of trials (default: 100)
--dataset STRING Dataset name: 'cora', 'pubmed', 'SBM' (default: 'cora')
--model STRING GNN model: 'GCN', 'GraphSage', 'GIN' (default: 'GIN')
--gamma FLOAT Network effect strength (default: 0.25)
--theta INT True direct treatment effect (default: 10)
--alpha INT True peer effect (default: 5)
--K INT Number of folds for cross-validation (default: 3)
--device INT GPU device ID (default: 0)
--seed INT Random seed (default: 0)
--save_plots Save plots to disk
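As a rough illustration of how the flags above fit together, the following sketch rebuilds them with Python's standard `argparse` module. The names, types, and defaults are copied from the list above; the function `build_parser` is an assumption for illustration, and the actual parser in src/main.py may be organized differently.

```python
import argparse


def build_parser():
    """Illustrative parser mirroring the documented flags (not the repo's own code)."""
    parser = argparse.ArgumentParser(prog="python -m src.main")
    parser.add_argument("--trials", type=int, default=100, help="Number of trials")
    parser.add_argument("--dataset", type=str, default="cora", help="Dataset name")
    parser.add_argument("--model", type=str, default="GIN", help="GNN model")
    parser.add_argument("--gamma", type=float, default=0.25, help="Network effect strength")
    parser.add_argument("--theta", type=int, default=10, help="True direct treatment effect")
    parser.add_argument("--alpha", type=int, default=5, help="True peer effect")
    parser.add_argument("--K", type=int, default=3, help="Number of folds")
    parser.add_argument("--device", type=int, default=0, help="GPU device ID")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    # A boolean switch: present means True, absent means False
    parser.add_argument("--save_plots", action="store_true", help="Save plots to disk")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere
args = build_parser().parse_args(
    ["--dataset", "cora", "--model", "GIN", "--trials", "50", "--gamma", "0.5", "--save_plots"]
)
print(vars(args))
```

Flags that are not passed keep the defaults listed above, so `--K`, `--theta`, `--alpha`, `--device`, and `--seed` remain at 3, 10, 5, 0, and 0 in this call.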
Example:
python -m src.main --dataset cora --model GIN --trials 50 --gamma 0.5 --save_plots
The code offers two data generating processes (DGPs) for simulating individual features within a network: one with non-linearity and one without. You can use either of these DGPs or add your own in data_generators.py (listed under src/ in the project structure above).
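To make the linear versus non-linear distinction concrete, here is a toy DGP written with the standard library only. It is a sketch under stated assumptions, not the DGP shipped in the repository: it assumes one scalar feature per node, a treatment whose probability depends on that feature, a peer exposure equal to the fraction of treated neighbours, and an outcome of the form theta * T + alpha * exposure + g(X) + noise. The function name `simulate_outcomes`, the `noise_sd` parameter, and the choice of g are all hypothetical, and the `--gamma` network effect strength is not modelled.

```python
import math
import random


def simulate_outcomes(adjacency, theta=10.0, alpha=5.0, nonlinear=False,
                      noise_sd=1.0, seed=0):
    """Toy DGP on a graph given as {node: [neighbours]} (illustrative only)."""
    rng = random.Random(seed)
    nodes = sorted(adjacency)
    # One standard-normal feature per node
    x = {i: rng.gauss(0.0, 1.0) for i in nodes}
    # Feature effect on the outcome: identity (linear) or sin plus square (non-linear)
    g = (lambda v: math.sin(v) + v ** 2) if nonlinear else (lambda v: v)
    # Confounded binary treatment: P(T=1) is a logistic function of the feature
    t = {i: int(rng.random() < 1.0 / (1.0 + math.exp(-x[i]))) for i in nodes}
    # Peer exposure: fraction of treated neighbours (0.0 for isolated nodes)
    z = {}
    for i in nodes:
        nbrs = adjacency[i]
        z[i] = sum(t[j] for j in nbrs) / len(nbrs) if nbrs else 0.0
    # Outcome with direct effect theta and peer effect alpha
    y = {i: theta * t[i] + alpha * z[i] + g(x[i]) + rng.gauss(0.0, noise_sd)
         for i in nodes}
    return x, t, z, y


graph = {0: [1, 2], 1: [0], 2: [0], 3: []}
x, t, z, y = simulate_outcomes(graph, nonlinear=True)
```

Switching `nonlinear` changes only how the feature enters the outcome, so the true direct and peer effects stay at `theta` and `alpha` in both variants.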
The implementation supports three Graph Neural Network architectures:
- GIN: Graph Isomorphism Network
- GCN: Graph Convolutional Network
- GraphSAGE: Graph Sample and Aggregate
These models can be used for both treatment and outcome predictions.
The framework supports several standard network datasets:
- Cora: Citation network of computer science papers
- Pubmed: Citation network of medical papers
- Flickr: Network of images shared on Flickr
- SBM: Synthetic network from a Stochastic Block Model
- Indian Village: Survey data from villages in Karnataka, India, used to investigate the impact of Self-Help Group participation on financial risk tolerance, as reflected in outstanding loans
Seyedeh Baharan Khatami skhatami@ucsd.edu
If you use this work in your research, please cite the following:
@inproceedings{khatamigraph,
title={Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects},
author={Khatami, Seyedeh Baharan and Parikh, Harsh and Chen, Haowei and Roy, Sudeepa and Salimi, Babak},
booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
year={2025}
}