Implementation of the paper Efficient Trajectory Inference in Wasserstein Space Using Consecutive Averaging (AISTATS 2025).
Authors: Amartya Banerjee, Harlin Lee, Nir Sharon, Caroline Moosmüller
Cross-sectional measurements of dynamic processes arise in many fields, such as computational biology. Trajectory inference deals with the challenge of reconstructing continuous processes from such observations. In this work, we propose methods for B-spline approximation and interpolation of point clouds through consecutive averaging that is intrinsic to the Wasserstein space. Combining subdivision schemes with optimal transport-based geodesics, our methods carry out trajectory inference at a chosen level of precision and smoothness, and can automatically handle scenarios where particles undergo division over time. The implementation supports both B-spline and 4-point subdivision schemes for trajectory inference.
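To make the idea of consecutive averaging concrete, here is a minimal, dependency-free sketch of Lane-Riesenfeld refinement in which the midpoint is taken along a Wasserstein-2 geodesic. It is an illustration under simplifying assumptions, not the code in pc_traj_inference.py: the point clouds are 1D, have equal size and uniform weights (so matching sorted points is an optimal transport plan), and the names `w2_geodesic_1d`, `lane_riesenfeld_step`, and `lane_riesenfeld` are hypothetical.

```python
from typing import Callable, List, Sequence

Cloud = List[float]  # a 1D point cloud with uniform weights (illustrative type)


def w2_geodesic_1d(a: Sequence[float], b: Sequence[float], t: float) -> Cloud:
    """Point at time t on the Wasserstein-2 geodesic between two 1D clouds.

    Assumes equal sizes and uniform weights; in that case pairing the
    sorted points of `a` and `b` is an optimal transport plan.
    """
    if len(a) != len(b):
        raise ValueError("this sketch needs clouds of equal size")
    return [(1 - t) * x + t * y for x, y in zip(sorted(a), sorted(b))]


def lane_riesenfeld_step(
    clouds: List[Cloud],
    degree: int,
    average: Callable[[Cloud, Cloud, float], Cloud],
) -> List[Cloud]:
    """One refinement level: duplicate every cloud, then average neighbours
    `degree` times (consecutive averaging)."""
    refined = [c for cloud in clouds for c in (cloud, cloud)]  # doubling step
    for _ in range(degree):
        refined = [average(p, q, 0.5) for p, q in zip(refined, refined[1:])]
    return refined


def lane_riesenfeld(clouds, degree=2, levels=3, average=w2_geodesic_1d):
    """Apply `levels` refinement steps to a time-ordered list of clouds."""
    for _ in range(levels):
        clouds = lane_riesenfeld_step(clouds, degree, average)
    return clouds


if __name__ == "__main__":
    snapshots = [[0.0, 1.0], [3.0, 2.0], [4.0, 6.0]]
    for cloud in lane_riesenfeld(snapshots, degree=2, levels=1):
        print(cloud)
```

With `degree=2` each level reduces to Chaikin-style corner cutting, so the refined sequence approximates the input snapshots rather than passing through them. In higher dimensions or with non-uniform mass, the sorted matching would have to be replaced by a general optimal transport solver.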
You can install the package using either conda or pip:
# Create and activate conda environment
conda create -n wti python=3.10
conda activate wti
# Clone the repository
git clone https://github.com/yourusername/Wasserstein-Trajectory-Inference.git
cd Wasserstein-Trajectory-Inference
# Install basic dependencies
conda install numpy scipy matplotlib pandas seaborn jupyter ipykernel scikit-learn
# Quote the requirement so the shell does not treat ">" as an output redirect
pip install "pot>=0.9.3"
# Set up Jupyter kernel
python -m ipykernel install --user --name=wti

# Create and activate virtual environment
python -m venv wti-env
source wti-env/bin/activate # On Windows, use: wti-env\Scripts\activate
# Clone the repository
git clone https://github.com/yourusername/Wasserstein-Trajectory-Inference.git
cd Wasserstein-Trajectory-Inference
# Install dependencies
pip install -r requirements.txt
# Set up Jupyter kernel
python -m ipykernel install --user --name=wti-env

The repository is organized as follows:
Wasserstein-Trajectory-Inference/
├── Data/ # Dataset directory
├── gifs/ # demo gif(s)
├── data_gen.py # Dataset generation utilities
├── pc_traj_inference.py # Main trajectory inference implementation
├── utils.py # Plotting and helper functions
├── wlr-demo.ipynb # Demo notebook with examples
├── requirements.txt # Python dependencies
├── MIT-LICENSE.txt # License file
└── README.md # This file
import numpy as np
from data_gen import get_dataset
from pc_traj_inference import PointCloudTrajInference
from utils import plot_figure
# Generate example data
data = get_dataset('diverging_gaussians')
# Initialize trajectory inference
pci = PointCloudTrajInference(
    initial_data=data,
    degree=2,
    refinement_levels=8,  # Using 8 levels for better results
    method='OT',
    balanced=True
)
# Perform trajectory inference
pci.traj_inference(traj_inference_type='Bspline')
# Visualize results
plot_figure(pci.pc_out, data, plot_initial_data=True, title='Trajectory Inference')

- Point cloud trajectory inference using optimal transport
- Multiple interpolation schemes:
  - Lane-Riesenfeld (B-spline) subdivision
  - 4-point subdivision scheme
- Support for point cloud data with both uniform and non-uniform mass
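
For readers unfamiliar with the 4-point scheme listed above, the following is a minimal sketch of one refinement step of the classical interpolatory 4-point rule, with weights (-1/16, 9/16, 9/16, -1/16), applied to ordinary points in Euclidean space. It is not the repository's Wasserstein-space implementation: the function name `four_point_step` is hypothetical, and clamping the stencil at the two ends is a simple boundary choice made only for this illustration.

```python
from typing import List, Tuple

Point = Tuple[float, ...]

# Classical 4-point weights for the neighbours p[i-1], p[i], p[i+1], p[i+2].
WEIGHTS = (-1 / 16, 9 / 16, 9 / 16, -1 / 16)


def four_point_step(points: List[Point]) -> List[Point]:
    """One level of 4-point subdivision on an open sequence of points."""
    n = len(points)
    if n < 2:
        raise ValueError("need at least two points")
    dim = len(points[0])
    refined: List[Point] = []
    for i in range(n - 1):
        refined.append(points[i])  # interpolatory: existing points are kept
        # Stencil indices are clamped at the ends (an illustrative choice).
        stencil = (
            points[max(i - 1, 0)],
            points[i],
            points[i + 1],
            points[min(i + 2, n - 1)],
        )
        refined.append(
            tuple(sum(w * p[d] for w, p in zip(WEIGHTS, stencil)) for d in range(dim))
        )
    refined.append(points[-1])
    return refined


if __name__ == "__main__":
    curve = [(0.0, 0.0), (1.0, 2.0), (2.0, 0.0), (3.0, 2.0)]
    for _ in range(2):  # two refinement levels
        curve = four_point_step(curve)
    print(len(curve), curve[:3])
```

Unlike the B-spline (Lane-Riesenfeld) scheme, this rule keeps every input point, so the refined curve interpolates the data instead of approximating it.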
Main class for performing trajectory inference:
pci = PointCloudTrajInference(
    initial_data,          # List of point clouds (positions, weights)
    degree=2,              # Polynomial degree for B-spline
    refinement_levels=8,   # Number of subdivision steps (typically 5-8)
    method='OT',           # Transport method ('OT')
    balanced=True          # Whether to reorder points across time
)

The get_dataset function provides several predefined datasets:
# Available dataset options
datasets = [
'diverging_gaussians', # Multiple Gaussian clusters that diverge
'weighted_gaussian', # Gaussian with non-uniform weights
'converging_gaussian', # Converging Gaussian clusters
'diamonds', # Diamond-shaped (Petal) pattern
'dyngen_tree', # Tree-like branching trajectories
'dyngen_cycle', # Cyclic trajectory data
'citeseq', # Processed CITE-seq single-cell data
'supercells', # Processed supercell data
'circular_gaussian' # Circular arrangement of Gaussians
]
data = get_dataset('diverging_gaussians', seed=0)  # Optional seed for reproducibility

The plot_figure function in utils.py provides visualization capabilities:
plot_figure(
    PC,                  # List of point clouds
    initial_data,        # Original data for comparison
    plot_initial_data=True,
    figsize=(15, 10),
    alpha=0.5,
    title='',
    tick_threshold=100,
    weighted=True        # Scale marker sizes by weights
)

The repository includes a Jupyter notebook wti-demo.ipynb that demonstrates:
- Loading and generating different datasets
- Applying trajectory inference with various parameters
- Visualizing results and comparing different methods
If you use this code in your research, please cite:
@inproceedings{banerjee2025efficient,
title={Efficient Trajectory Inference in Wasserstein Space Using Consecutive Averaging},
author={Banerjee, Amartya and Lee, Harlin and Sharon, Nir and Moosm{\"u}ller, Caroline},
booktitle={International Conference on Artificial Intelligence and Statistics (AISTATS)},
year={2025},
url={https://arxiv.org/abs/2405.19679}
}

This project is licensed under the MIT License - see the MIT-LICENSE.txt file for details.
