CasReg is a deep learning registration model based on cascaded networks: each network produces a small displacement, and together they progressively warp the moving image towards the fixed image.
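The idea of progressive warping can be illustrated with a toy 1D analogue. This is a minimal sketch, not CasReg's implementation: `toy_stage` is a hypothetical stand-in for one trained network, and it simply predicts a uniform displacement covering half of the remaining misalignment (estimated from the centre of mass). Each stage receives the output of the previous one, so the image is moved a little further towards the fixed image at every cascade.

```python
import numpy as np

def warp_1d(signal, displacement):
    """Backward warping: out[x] = signal[x + displacement[x]] (linear interpolation, zero outside)."""
    x = np.arange(signal.size, dtype=float)
    return np.interp(x + displacement, x, signal, left=0.0, right=0.0)

def center_of_mass(signal):
    x = np.arange(signal.size, dtype=float)
    return float((x * signal).sum() / signal.sum())

def toy_stage(warped, fixed, step=0.5):
    """Hypothetical stand-in for one network of the cascade: it predicts a small,
    uniform displacement covering only a fraction `step` of the remaining offset."""
    offset = center_of_mass(warped) - center_of_mass(fixed)
    return np.full(warped.shape, step * offset)

def cascade_register(moving, fixed, nb_cascades=5):
    """Apply the stages one after another, each warping the previous output a bit further."""
    warped = moving.astype(float)
    for _ in range(nb_cascades):
        displacement = toy_stage(warped, fixed)
        warped = warp_1d(warped, displacement)
    return warped

x = np.arange(128, dtype=float)
fixed = np.exp(-((x - 40.0) ** 2) / (2 * 4.0 ** 2))   # structure centred at 40
moving = np.exp(-((x - 60.0) ** 2) / (2 * 4.0 ** 2))  # same structure centred at 60
warped = cascade_register(moving, fixed, nb_cascades=5)
print(center_of_mass(warped))  # close to 40: the offset of 20 is halved at each of the 5 stages
```

Re-warping the image at every stage, as done here, accumulates interpolation blur; composing the displacement fields and warping the original image once is a common alternative.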
The trained registration model can then be used to perform Multi-Atlas Segmentation (MAS): multiple annotated images (atlases) and their labels are registered to the image to segment, and the resulting warped labels are combined into a refined segmentation.
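This README does not say how the warped labels are combined. Majority voting is one common label-fusion rule, so the sketch below assumes it for illustration only; the function name `majority_vote` is hypothetical and the label maps are assumed to be integer arrays already warped onto the target image grid.

```python
import numpy as np

def majority_vote(warped_labels, nb_labels):
    """Fuse atlas labels warped onto the target: each voxel takes the label
    chosen by the most atlases (ties go to the smallest label index)."""
    stack = np.stack(warped_labels)  # shape: (n_atlases, *image_shape)
    # votes[k] counts, per voxel, how many atlases assign label k
    votes = np.stack([(stack == k).sum(axis=0) for k in range(nb_labels)])
    return votes.argmax(axis=0)

atlas_labels = [np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]), np.array([1, 1, 2, 0])]
print(majority_vote(atlas_labels, nb_labels=3))  # → [0 1 2 2]
```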
This repository includes:
- A preprocessing script that converts NIfTI images (and optionally labels) to .npz format. The images are then cropped, resized and normalized.
- Training and testing scripts for cascaded registration.
- A multi-atlas segmentation script using the trained weights of the cascaded registration.
For more information about CasReg, please read the following paper:
Unsupervised fetal brain MR segmentation using multi-atlas deep learning registration
Valentin Comte¹, Mireia Alenya¹, Andrea Urru¹, Ayako Nakaki², Francesca Crovetto², Oscar Camara¹, Elisenda Eixarch², Fàtima Crispi², Gemma Piella Fenoy¹, Mario Ceresa¹, and Miguel A. González Ballester¹ ³
¹ BCN MedTech, Department of Information and Communication Technologies, Universitat Pompeu Fabra, Barcelona, Spain
² Maternal Fetal Medicine, BCNatal, Center for Maternal Fetal and Neonatal Medicine (Hospital Clínic and Hospital Sant Joan de Déu), Barcelona, Spain
³ ICREA, Barcelona, Spain
CasReg requires a GPU with at least 10 GB of memory for training and inference.
CasReg uses the following packages:
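Before launching a long training run, it can help to check that PyTorch actually sees a suitable GPU. The helper below is a hypothetical convenience script, not part of this repository; it only uses `torch.cuda.is_available()` and `torch.cuda.get_device_properties(...).total_memory`, and it returns `False` when PyTorch is not installed.

```python
def has_enough_gpu_memory(min_gb=10.0, device=0):
    """Return True if PyTorch sees a CUDA device with at least `min_gb` GB of memory."""
    try:
        import torch
    except ImportError:  # PyTorch is not installed in this environment
        return False
    if not torch.cuda.is_available():
        return False
    total_bytes = torch.cuda.get_device_properties(device).total_memory
    return total_bytes / 1e9 >= min_gb  # decimal gigabytes

if __name__ == "__main__":
    print("GPU OK" if has_enough_gpu_memory() else "No CUDA GPU with >= 10 GB found")
```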
- pytorch-gpu 1.9.0
- torchvision 0.2.2
- torchsummary 1.5.1
- cudatoolkit 10.2
- numpy 1.21
- nibabel 4.0.1
- simpleitk 2.1.1
- scipy 1.7.3
We highly recommend using a Conda environment to install the package dependencies. To install all the required packages, run:
conda env create -f casreg_env.yml
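Once the environment is created, a quick sanity check is to compare the installed versions with the list above. This sketch uses only the standard library (`importlib.metadata`, Python 3.8+); the Python distribution names (for example `torch` for `pytorch-gpu`) are an assumption about how the conda packages are registered, and `cudatoolkit` is skipped because it ships no Python metadata.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed Python distribution names for the packages listed above (cudatoolkit omitted).
EXPECTED = {
    "torch": "1.9.0",
    "torchvision": "0.2.2",
    "torchsummary": "1.5.1",
    "numpy": "1.21",
    "nibabel": "4.0.1",
    "SimpleITK": "2.1.1",
    "scipy": "1.7.3",
}

def check_versions(expected=EXPECTED):
    """Map each distribution to (installed_version_or_None, matches_expected_prefix)."""
    report = {}
    for dist, prefix in expected.items():
        try:
            found = version(dist)
        except PackageNotFoundError:
            found = None
        report[dist] = (found, found is not None and found.startswith(prefix))
    return report

if __name__ == "__main__":
    for dist, (found, ok) in check_versions().items():
        print(f"{dist:14s} {found or 'missing':12s} {'ok' if ok else 'CHECK'}")
```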
To run the preprocessing, use the following command:
python preprocessing.py --img_path /path/to/nifti/images/folder/ --label_path /path/to/nifti/labels/folder/ --prep_dir /path/to/preprocessed/folder/ --img_size 128 128 128
img_size: defines the new shape of the input images (and labels), default is (128,128,128).
label_path (optional): if given, the images will be cropped around the labelled regions, unless --crop_zero is used.
crop_zero (optional): maximum intensity of the background (zero) regions; the image is cropped to remove these regions.
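The cropping logic described by these options can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the code of `preprocessing.py`: the helper names are hypothetical, resizing uses simple nearest-neighbour index sampling (the actual script may interpolate differently), normalization is assumed to be min-max scaling to [0, 1], and the `.npz` keys `image` and `label` are hypothetical.

```python
import numpy as np

def bounding_box(mask):
    """Slices of the smallest box containing all True voxels of `mask`."""
    coords = np.nonzero(mask)
    return tuple(slice(int(c.min()), int(c.max()) + 1) for c in coords)

def resize_nearest(volume, new_shape):
    """Nearest-neighbour resize by index sampling (keeps label values intact)."""
    idx = [np.minimum((np.arange(n) * s / n).astype(int), s - 1)
           for n, s in zip(new_shape, volume.shape)]
    return volume[np.ix_(*idx)]

def preprocess(image, label=None, img_size=(128, 128, 128), crop_zero=None):
    # crop_zero takes precedence: drop background voxels with intensity <= crop_zero.
    # Otherwise crop around the labelled region when labels are given.
    if crop_zero is not None:
        box = bounding_box(image > crop_zero)
    elif label is not None:
        box = bounding_box(label > 0)
    else:
        box = tuple(slice(None) for _ in image.shape)
    image = resize_nearest(image[box], img_size).astype(np.float32)
    # Min-max normalization to [0, 1]
    lo, hi = image.min(), image.max()
    image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
    if label is not None:
        label = resize_nearest(label[box], img_size)
    return image, label

def save_npz(path, image, label=None):
    """Store the preprocessed arrays (hypothetical key names)."""
    arrays = {"image": image} if label is None else {"image": image, "label": label}
    np.savez_compressed(path, **arrays)
```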
To run the training, use the following command:
python train.py --save_dir /path/to/the/weights/folder/ --npz_dir /path/to/preprocessed/folder/ --nb_labels 8 --nb_cascades 5 --contracted 1
save_dir: folder where the weights will be saved.
npz_dir: folder containing the preprocessed .npz files.
nb_labels (optional): number of labels in the segmentations used for validation.
nb_cascades: number of cascaded networks.
img_size: shape of the input images (and labels), default is (128,128,128).
contracted: 0 for the original architecture (uses more memory), 1 for the contracted architecture.
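This README only states that `nb_labels` refers to the segmentations used for validation. A common way to validate registration with labels is the per-label Dice overlap between the warped moving labels and the fixed labels; the sketch below assumes that metric for illustration. Because it is not stated whether the background label is counted in `nb_labels`, the labels to score are passed explicitly, and `dice_per_label` is a hypothetical name.

```python
import numpy as np

def dice_per_label(seg_a, seg_b, labels):
    """Dice overlap 2|A∩B| / (|A| + |B|) for each label value in `labels`."""
    scores = {}
    for k in labels:
        a, b = seg_a == k, seg_b == k
        denom = a.sum() + b.sum()
        # NaN when the label is absent from both segmentations
        scores[k] = float(2.0 * np.logical_and(a, b).sum() / denom) if denom else float("nan")
    return scores

warped_atlas_label = np.array([0, 1, 1, 2])
target_label = np.array([0, 1, 2, 2])
print(dice_per_label(warped_atlas_label, target_label, labels=range(1, 3)))  # both labels: 2/3
```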
