
Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects


BaharanKh/GDML


GDML: Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects

This repository contains the code for the paper Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects, published at AISTATS 2025.

Overview

We propose a novel methodology for estimating causal effects in social networks under interference and network-induced confounding. Our framework combines double machine learning (DML) with graph machine learning techniques, enabling accurate and efficient estimation of both direct and peer effects from a single observational social network. Under mild regularity conditions, the proposed estimator achieves semiparametric efficiency, ensuring consistent uncertainty quantification.
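To illustrate only the double machine learning half of this idea, here is a minimal sketch of the generic cross-fitted partialling-out estimator for a partially linear model with two "treatments": a node's own treatment T and its peer exposure E (for example, the share of treated neighbours). Assumptions: ordinary least squares on node features plus neighbour-averaged features stands in for the GNN nuisance models, folds are random node splits (which ignores dependence between neighbouring nodes), and the names `dml_direct_and_peer` and `ols_fit_predict` are hypothetical. This is not the paper's estimator.

```python
import numpy as np


def ols_fit_predict(Z_train, y_train, Z_test):
    """Least-squares regression with an intercept.
    Hypothetical stand-in for a GNN nuisance model: fit on one split, predict on another."""
    A_train = np.column_stack([np.ones(len(Z_train)), Z_train])
    A_test = np.column_stack([np.ones(len(Z_test)), Z_test])
    coef, *_ = np.linalg.lstsq(A_train, y_train, rcond=None)
    return A_test @ coef


def dml_direct_and_peer(Y, T, E, Z, K=3, seed=0):
    """Cross-fitted partialling-out estimate of (theta, alpha) in the partially
    linear model Y = theta * T + alpha * E + g(Z) + noise.

    Y: outcomes, T: own treatment, E: peer exposure (e.g. share of treated
    neighbours), Z: node features plus neighbour-aggregated features.
    """
    n = len(Y)
    rng = np.random.default_rng(seed)
    # Random node splits; this ignores dependence between neighbouring nodes.
    folds = np.array_split(rng.permutation(n), K)
    rY, rT, rE = np.empty(n), np.empty(n), np.empty(n)
    for test_idx in folds:
        train_idx = np.setdiff1d(np.arange(n), test_idx)
        # Nuisance models are fit on the other K-1 folds and evaluated on this fold.
        for target, resid in ((Y, rY), (T, rT), (E, rE)):
            resid[test_idx] = target[test_idx] - ols_fit_predict(
                Z[train_idx], target[train_idx], Z[test_idx])
    # Final stage: regress outcome residuals on treatment and exposure residuals.
    R = np.column_stack([rT, rE])
    theta_hat, alpha_hat = np.linalg.lstsq(R, rY, rcond=None)[0]
    return theta_hat, alpha_hat
```

Replacing `ols_fit_predict` with a GNN's fit/predict routine is where graph machine learning would enter. The last step regresses outcome residuals on treatment and exposure residuals, so each coefficient is estimated after the features' influence has been partialled out.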

Project Structure

network-causal-inference/
├── README.md                 # This file
├── requirements.txt          # Package dependencies
├── src/                      # Source code
│   ├── __init__.py           # Package initialization
│   ├── data_generators.py    # Data generating processes (DGPs)
│   ├── data_processors.py    # Dataset loading/processing utilities
│   ├── dataset.py            # Custom PyTorch Geometric dataset
│   ├── models.py             # GNN model implementations
│   ├── experiment.py         # Experiment runner
│   ├── analysis.py           # Result analysis functions
│   └── main.py               # Main script
├── data/                     # Dataset directory (create this directory and place your data here)
├── results/                  # Experimental results (CSV)
└── plots/                    # Generated visualizations

Installation

  1. Clone this repository:
git clone https://github.com/username/network-causal-inference.git
cd network-causal-inference
  2. Create a virtual environment (optional but recommended):
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install dependencies:
pip install -r requirements.txt

Usage

First, place your data (including the adjacency matrix and node features) in the data/ folder, following the project structure outlined above. If you plan to use the data generating processes (DGPs) provided in our framework, you can find them in src/data_generators.py, where you can also add your own.
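The README does not specify the on-disk format. Assuming, hypothetically, that the adjacency matrix and node features are stored as NumPy `.npy` files named `adjacency.npy` and `features.npy`, a loader with basic shape checks might look like this. The file names, the `.npy` format, and the symmetry check (i.e. an undirected network) are all assumptions, and `load_network` is a hypothetical helper rather than part of `src/data_processors.py`.

```python
import os

import numpy as np


def load_network(data_dir="data", adj_file="adjacency.npy", feat_file="features.npy"):
    """Hypothetical loader: reads an adjacency matrix and a node-feature matrix
    stored as .npy files (file names and format are assumptions)."""
    A = np.load(os.path.join(data_dir, adj_file))
    X = np.load(os.path.join(data_dir, feat_file))
    # Sanity checks before handing the data to an experiment.
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("adjacency matrix must be square")
    if not np.allclose(A, A.T):
        # Assumes an undirected network.
        raise ValueError("expected a symmetric adjacency matrix")
    if X.shape[0] != A.shape[0]:
        raise ValueError("features must have one row per node")
    return A, X
```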

To run an experiment with default settings:

python -m src.main

This will:

  1. Load or generate the specified dataset
  2. Run the specified number of trials
  3. Train GNN models and estimate causal effects
  4. Save results and generate plots
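Simulation studies that repeat an experiment over many trials are commonly summarized by the bias, root-mean-squared error (RMSE), and empirical coverage of 95% confidence intervals around the true effect (here, the value passed as `--theta` or `--alpha`). The NumPy sketch below computes these; `summarize_trials` is hypothetical, and it assumes each trial yields a point estimate and a standard error, which may not match what `src/analysis.py` records.

```python
import numpy as np


def summarize_trials(estimates, std_errors, true_value, z=1.96):
    """Hypothetical summary of one effect across trials: mean estimate, bias,
    RMSE, and empirical coverage of normal-approximation intervals."""
    est = np.asarray(estimates, dtype=float)
    se = np.asarray(std_errors, dtype=float)
    lower, upper = est - z * se, est + z * se
    return {
        "mean": est.mean(),
        "bias": est.mean() - true_value,
        "rmse": np.sqrt(np.mean((est - true_value) ** 2)),
        # Fraction of trials whose interval contains the true value.
        "coverage": np.mean((lower <= true_value) & (true_value <= upper)),
    }
```

Coverage close to 0.95 is the usual sign that the standard errors are well calibrated.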

Command-line Arguments

The following arguments can be used to customize the experiment:

--trials INT         Number of trials (default: 100)
--dataset STRING     Dataset name: 'cora', 'pubmed', 'SBM' (default: 'cora')
--model STRING       GNN model: 'GCN', 'GraphSage', 'GIN' (default: 'GIN')
--gamma FLOAT        Network effect strength (default: 0.25)
--theta INT          True direct treatment effect (default: 10)
--alpha INT          True peer effect (default: 5)
--K INT              Number of folds for cross-validation (default: 3)
--device INT         GPU device ID (default: 0)
--seed INT           Random seed (default: 0)
--save_plots         Save plots to disk

Example:

python -m src.main --dataset cora --model GIN --trials 50 --gamma 0.5 --save_plots
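For readers extending the runner, a minimal argparse sketch that mirrors the flags and defaults listed above is shown here. It is an assumption about how such a parser could look, not the contents of `src/main.py`; `build_parser` is a hypothetical name, and the help strings and the absence of value validation are illustrative choices.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="GDML experiment runner (sketch)")
    p.add_argument("--trials", type=int, default=100, help="number of trials")
    p.add_argument("--dataset", type=str, default="cora", help="dataset name")
    p.add_argument("--model", type=str, default="GIN", help="GNN model")
    p.add_argument("--gamma", type=float, default=0.25, help="network effect strength")
    p.add_argument("--theta", type=int, default=10, help="true direct treatment effect")
    p.add_argument("--alpha", type=int, default=5, help="true peer effect")
    p.add_argument("--K", type=int, default=3, help="number of folds")
    p.add_argument("--device", type=int, default=0, help="GPU device ID")
    p.add_argument("--seed", type=int, default=0, help="random seed")
    # Boolean switch: False unless the flag is present.
    p.add_argument("--save_plots", action="store_true", help="save plots to disk")
    return p
```

Calling `build_parser().parse_args()` reads `sys.argv`; passing an explicit list of arguments instead is convenient for testing.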

Data Generating Processes

The code offers two data generating processes (DGPs) for simulating individual features within a network: one with non-linearity and one without. You can use either of these DGPs or add your own in src/data_generators.py.
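To make the linear/nonlinear distinction concrete, the sketch below simulates one covariate per node, a treatment whose probability depends on the node's own covariate and its neighbours' average covariate, and an outcome with direct effect theta and peer effect alpha acting through the fraction of treated neighbours. By assumption, gamma is used as the weight on neighbours' covariates (a form of network-induced confounding). The functional forms, including `np.sin` as the nonlinearity and the logistic treatment model, are hypothetical choices and not the DGPs in `src/data_generators.py`; `simulate_network_data` is a hypothetical name.

```python
import numpy as np


def simulate_network_data(A, theta=10.0, alpha=5.0, gamma=0.25, nonlinear=False, seed=0):
    """Hypothetical toy DGP on adjacency matrix A with a direct effect (theta),
    a peer effect (alpha), and neighbour-covariate confounding weighted by gamma."""
    rng = np.random.default_rng(seed)
    n = A.shape[0]
    W = A / np.maximum(A.sum(axis=1), 1)[:, None]  # row-normalised adjacency
    X = rng.normal(size=n)                         # one covariate per node
    X_nbr = W @ X                                  # neighbours' average covariate

    def identity(v):
        return v

    f = np.sin if nonlinear else identity          # nonlinear vs linear variant
    confounder = f(X) + gamma * f(X_nbr)
    p = 1.0 / (1.0 + np.exp(-confounder))          # logistic treatment probability
    T = rng.binomial(1, p).astype(float)
    E = W @ T                                      # fraction of treated neighbours
    Y = theta * T + alpha * E + confounder + rng.normal(size=n)
    return X, T, E, Y
```

Because the same covariate terms drive both treatment and outcome, a naive comparison of treated and untreated nodes is confounded, which is what the doubly robust estimator is meant to handle.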

Models

The implementation supports three Graph Neural Network architectures:

  1. GIN: Graph Isomorphism Network
  2. GCN: Graph Convolutional Network
  3. GraphSAGE: Graph Sample and Aggregate

These models can be used for both treatment and outcome predictions.
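For reference, the standard per-layer update rules of the three architectures differ mainly in how neighbour information is aggregated: GCN uses symmetric degree normalization with self-loops, GraphSAGE (mean aggregator) combines a node's own state with its neighbours' mean, and GIN adds the neighbour sum to (1 + eps) times the node's own state before an MLP. The NumPy sketch below implements one layer of each for a dense adjacency matrix; it is illustrative only and does not reproduce `src/models.py` (no training, bias terms, or neighbour sampling).

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def gcn_layer(A, H, W):
    # GCN: symmetric normalisation D^-1/2 (A + I) D^-1/2, then a linear map.
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return relu(A_norm @ H @ W)


def sage_layer(A, H, W_self, W_nbr):
    # GraphSAGE (mean aggregator): own state plus the mean of neighbours' states.
    deg = np.maximum(A.sum(axis=1, keepdims=True), 1.0)
    return relu(H @ W_self + (A @ H / deg) @ W_nbr)


def gin_layer(A, H, W1, W2, eps=0.0):
    # GIN: (1 + eps) * own state plus the neighbour sum, then a two-layer MLP.
    return relu(relu(((1.0 + eps) * H + A @ H) @ W1) @ W2)
```

The sum aggregation in GIN keeps information about neighbourhood size that mean aggregation discards, which is the usual motivation for its greater expressive power.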

Datasets

The framework supports several standard network datasets:

  • Cora: Citation network of computer science papers
  • Pubmed: Citation network of medical papers
  • Flickr: Network of images shared on Flickr
  • SBM: Synthetic network from a Stochastic Block Model
  • Indian Village: Survey data from villages in Karnataka, India, used to investigate the impact of Self-Help Group participation on financial risk tolerance, measured through outstanding loans
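The SBM dataset is synthetic. In a stochastic block model, each node is assigned to a block and each pair of nodes (u, v) is connected independently with probability P[block(u), block(v)]. The function below is a hypothetical NumPy illustration of that sampling step; the block sizes and probabilities used in the repository are not stated in this README. NetworkX also provides `networkx.stochastic_block_model(sizes, p, seed=...)` for the same purpose.

```python
import numpy as np


def sample_sbm(block_sizes, P, seed=0):
    """Hypothetical sampler: undirected SBM adjacency matrix without self-loops."""
    rng = np.random.default_rng(seed)
    labels = np.repeat(np.arange(len(block_sizes)), block_sizes)
    n = labels.size
    # Edge probability for every ordered pair, looked up from the block matrix.
    probs = np.asarray(P, dtype=float)[labels[:, None], labels[None, :]]
    # Sample each unordered pair once (strict upper triangle), then mirror it.
    upper = np.triu(rng.uniform(size=(n, n)) < probs, k=1)
    A = (upper | upper.T).astype(float)
    return A, labels
```

For example, `sample_sbm([50, 50], [[0.3, 0.01], [0.01, 0.3]])` yields two dense communities joined by a few edges.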

Contact

Seyedeh Baharan Khatami skhatami@ucsd.edu

Bibtex

If you use this work in your research, please cite the following:

@inproceedings{khatamigraph,
  title={Graph Machine Learning based Doubly Robust Estimator for Network Causal Effects},
  author={Khatami, Seyedeh Baharan and Parikh, Harsh and Chen, Haowei and Roy, Sudeepa and Salimi, Babak},
  booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
  year={2025}
}

License

MIT License
