- This repository contains the source code and documentation for the paper "Stochastic Weight Sharing for Bayesian Neural Networks" submitted to the AISTATS conference.
.
├── bnn/ # Bayesian Neural Network related code and files
├── checkpoint/ # Directory for storing checkpoints during training
├── checkpoint_eval/ # Directory for storing evaluation checkpoints
├── conf/ # Configuration files for experiments
├── data/ # Directory for datasets used in experiments
├── models/ # Model definitions and architecture files
├── networks/ # Contains different neural network architecture files
├── Visualization/ # Scripts and tools for data visualization
├── .gitignore # Git ignore file to specify untracked files
├── 2DGBNNs_train.py # Training script for 2D Gaussian Bayesian Neural Networks
├── exp_stochastic_nn.py # Experiment script for stochastic neural networks
├── gmm_train.py # Script to train Gaussian Mixture Models
├── init_gmm.py # Script to initialize Gaussian Mixture Models
├── kmeans_init.py # Script to initialize K-means clustering
├── predict_acc_nll_ece.py # Script to predict accuracy, negative log-likelihood, and expected calibration error
├── predict_acc_nll_ece_resnet18_cifar10.py # Prediction script for ResNet-18 on CIFAR-10
├── predict_acc_nll_ece_resnet18_cifar100.py # Prediction script for ResNet-18 on CIFAR-100
├── predict_acc_nll_ece_resnet18_imagenet.py # Prediction script for ResNet-18 on ImageNet
├── predict_acc_nll_ece_wrn_cifar100.py # Prediction script for Wide Residual Network on CIFAR-100
├── README.md # Markdown file with project overview and instructions
├── requirements.txt # Required libraries and dependencies for the project
├── run.bat # Batch file for running scripts on Windows
├── train_deterministic_network.py # Script for training deterministic neural networks
├── utils.py # Utility functions used across the project
└── visualize_stochastic_weights_scatter.py # Script to visualize weight scatter in stochastic models
To get started, clone the repository and install the required dependencies:
git clone [email protected]:gfhvbjk/Anonymous.git
cd Anonymous
pip install -r requirements.txt
Note: we use Pyro as the variational inference (VI) tool, so please install it as follows:
pip3 install pyro-ppl
- CIFAR-10, CIFAR-100, and MNIST are downloaded automatically.
- ImageNet1k must be placed in the `data` folder:
  - Recommended: store the splits as "data/train" and "data/val".
  - Alternatively (not recommended), specify the ImageNet path by adding `-dataset_path your_path`.
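Before starting an ImageNet1k run, it can help to confirm that the recommended layout is in place. The following is a minimal sketch; `check_imagenet_layout` is a hypothetical helper written for illustration and is not part of this repository.

```python
from pathlib import Path


def check_imagenet_layout(root="data"):
    """Return the split folders missing under `root` (hypothetical helper)."""
    root = Path(root)
    # The recommended layout places the two splits at data/train and data/val.
    return [split for split in ("train", "val") if not (root / split).is_dir()]


if __name__ == "__main__":
    missing = check_imagenet_layout("data")
    if missing:
        print("Missing ImageNet1k folders:", ", ".join(missing))
    else:
        print("ImageNet1k layout looks as expected.")
```

An empty list means both "data/train" and "data/val" were found.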
We provide the following evaluation scripts in this repository.
- ImageNet1k by ResNet-18 script: `predict_acc_nll_ece_resnet18_imagenet.py`
  - Download our pre-trained model here and put it under `checkpoint/resnet18/imagenet/2DGBNNs/`, then run:
    `python predict_acc_nll_ece_resnet18_imagenet.py`
  - Or put the pre-trained model anywhere and pass its path:
    `python predict_acc_nll_ece_resnet18_imagenet.py -weights_path your_path`
  - The outputs include Outliers, Ellipses, Gaussians, Accuracy, NLL, and ECE:

| Metric | Value |
|---|---|
| Outliers | 23013 |
| Ellipses | 10885 |
| Gaussians | 2217 |
| Accuracy | 68.11 ± 0.03 |
| NLL | 1.250 ± 0.005 |
| ECE | 0.019 ± 0.003 |
- CIFAR-100 by ResNet-18 script: `predict_acc_nll_ece_resnet18_cifar100.py`
  - Download our pre-trained model here and put it under `checkpoint/resnet18/cifar100/2DGBNNs/`, then run:
    `python predict_acc_nll_ece_resnet18_cifar100.py`
  - Or put the pre-trained model anywhere and pass its path:
    `python predict_acc_nll_ece_resnet18_cifar100.py -weights_path your_path`
  - The outputs include Outliers, Ellipses, Gaussians, Accuracy, NLL, and ECE:

| Metric | Value |
|---|---|
| Outliers | 14624 |
| Ellipses | 260 |
| Gaussians | 2387 |
| Accuracy | 74.7 ± 0.1 |
| NLL | 1.049 ± 0.003 |
| ECE | 0.042 ± 0.003 |
- CIFAR-10 by ResNet-18 script: `predict_acc_nll_ece_resnet18_cifar10.py`
  - Download our pre-trained model here and put it under `checkpoint/resnet18/cifar10/2DGBNNs/`, then run:
    `python predict_acc_nll_ece_resnet18_cifar10.py`
  - Or put the pre-trained model anywhere and pass its path:
    `python predict_acc_nll_ece_resnet18_cifar10.py -weights_path your_path`
  - The outputs include Outliers, Ellipses, Gaussians, Accuracy, NLL, and ECE:

| Metric | Value |
|---|---|
| Outliers | 123310 |
| Ellipses | 57 |
| Gaussians | 1569 |
| Accuracy | 91.8 ± 0.1 |
| NLL | 0.313 ± 0.003 |
| ECE | 0.034 ± 0.003 |
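
For readers unfamiliar with the three predictive metrics in the tables above, here is a minimal sketch of how accuracy, negative log-likelihood (NLL), and expected calibration error (ECE) are commonly computed from predicted class probabilities. It is plain Python for portability and is not taken from the evaluation scripts. The function name `evaluate` and the choice of 15 equal-width bins are assumptions.

```python
import math


def evaluate(probs, labels, n_bins=15):
    """Return (accuracy, NLL, ECE) for predicted class probabilities.

    probs  : list of per-sample probability lists (each sums to 1)
    labels : list of integer class indices
    n_bins : number of equal-width confidence bins for ECE (assumed value)
    """
    n = len(labels)
    correct = 0
    nll = 0.0
    # Per-bin accumulators: [sample count, summed confidence, summed hits].
    bins = [[0, 0.0, 0] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=p.__getitem__)
        conf = p[pred]
        hit = int(pred == y)
        correct += hit
        # Clamp to avoid log(0) when the true class gets zero probability.
        nll -= math.log(max(p[y], 1e-12))
        # A confidence of exactly 1.0 falls into the last bin.
        b = min(int(conf * n_bins), n_bins - 1)
        bins[b][0] += 1
        bins[b][1] += conf
        bins[b][2] += hit
    # ECE: sample-weighted mean of |accuracy - confidence| over the bins,
    # which simplifies to |hits - summed confidence| / n per bin.
    ece = sum(abs(hits - conf_sum) for _, conf_sum, hits in bins) / n
    return correct / n, nll / n, ece
```

The function returns accuracy as a fraction, whereas the tables report it as a percentage.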
Training consists of three stages: weight initialization, stochastic network training, and finally 2DGBNNs training.
- We provide the scripts and step-by-step instructions for every stage of training WRN-28-10 on CIFAR-100. Where a checkpoint (pre-trained weights) exists, it is also provided; placing it in the corresponding folder lets you skip that stage and saves considerable time.
- All pre-trained models can be downloaded from this folder: here
- However, we highly recommend downloading them one by one as you follow the steps.
Input:
- NN architecture $f^{\mathbf{w}}$
- Training data $\mathcal{D} = \{(\mathbf{X}, \mathbf{y})\}$
- Algorithm thresholds: $\tau_w$, $\tau_d$, $\tau_g$, $\tau_v$
- BNN prior $p(\mathbf{w})$

Output:
- Stochastic weight-sharing trained BNN
- Train deterministic neural network
Run the following script:
python train_deterministic_network.py -net wrn -dataset cifar100
You can download our pre-trained model here and put it under "checkpoint/wrn/cifar100/origin".
- Initialize GMM Parameters
  Initialize $\mu$, $\sigma$ according to $p(\mathbf{w})$.
- Pre-training Loop
  For each epoch:
  - Sample weights: $$\mathbf{w} = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$
  - Update $\mu$, $\sigma$ by training on $\mathcal{D}$.
- Identify Outliers
  For each weight $w_i$:
  - If $|w_i| > \tau_w$ or $|\nabla_{w_i}|$ is in the top $1\%$, then $w_i$ is an outlier.
  - Else, $w_i$ is an inlier.
- Learn GMM on Inliers
  Learn GMM on inlier parameters $\Theta_{in}$.
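
The sampling and outlier rules above can be sketched in a few lines. This is an illustration in plain Python, not the code in `exp_stochastic_nn.py`. The function names are hypothetical, and two details are assumptions: at least one weight is always kept in the "top 1%" set, and gradient ties are broken by sort order.

```python
import random


def sample_weights(mu, sigma, rng=random):
    """Reparameterized draw: w = mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def find_outliers(weights, grads, tau_w, top_frac=0.01):
    """Return a boolean list that is True where a weight is an outlier.

    A weight is flagged if |w_i| > tau_w or if |grad_i| is among the
    largest `top_frac` fraction of gradient magnitudes.
    """
    n = len(weights)
    # Size of the top fraction by gradient magnitude (assumed: at least one).
    k = max(1, int(round(top_frac * n)))
    by_grad = sorted(range(n), key=lambda i: abs(grads[i]), reverse=True)
    top_grad = set(by_grad[:k])
    return [abs(w) > tau_w or i in top_grad for i, w in enumerate(weights)]
```

Weights not flagged here are the inliers on which the GMM is then learned.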
Pick a deterministic pre-trained model and note its path.
Run the following script (please replace the path with your own):
python exp_stochastic_nn.py -net wrn -dataset cifar100 -weights_path "checkpoint/wrn/cifar100/origin/wrn.pth"
This pre-trained model can be downloaded here.
Next, initialize the GMM: pick a stochastic neural network and note its path.
Note: cuml and cudf are required for GMM weight initialization. Install them as follows:
pip install cudf-cu11 cuml-cu11
Then run the following command (please replace the path with your own):
python init_gmm.py -weights_path "checkpoint/wrn/cifar100/stochastic/wrn_stochastic.pth"
This pre-trained model can be downloaded here.
The initial GMM configuration will be stored under "checkpoint/wrn/cifar100/init_gmm".
- Mahalanobis Distance Check
  For each inlier weight $w_i$:
  - Perform Mahalanobis distance check.
  - If $w_i$ is outside the 95th percentile:
    - Assign $w_i$ to multiple clusters.
  - Else:
    - Assign $w_i$ to the closest Gaussian.
- Alpha-Blending for Ellipse Points
  Apply alpha-blending for ellipse points.
- Merge Gaussians
  Repeat until no more Gaussians can be merged:
  - For each pair $(\mathcal{N}_1, \mathcal{N}_2)$ in GMM:
    - If $W(\mathcal{N}_1, \mathcal{N}_2) < \tau_d$, $\Delta_g < \tau_g$, and $\Delta_v < \tau_v$:
      - Merge $\mathcal{N}_1$ and $\mathcal{N}_2$.
- Final Training Loop
  For each epoch:
  - For each weight $w_i$:
    - If $w_i$ is an inlier:
      - Sample $w_i \sim \sum_{k=1}^{K} \pi_k \mathcal{N}(\mu_k, \Sigma_k)$.
    - Else:
      - Use $w_i \sim \mathcal{N}(\mu_{w_i}, \sigma^2_{w_i})$.
  - Perform minimizing step for $\hat{\mathcal{L}}(\mathcal{D}, q)$.
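
To make the Mahalanobis check and the merge test above concrete, here is a hedged sketch in plain Python. It rests on assumptions the steps do not state: each weight is a 2D point, the "95th percentile" is the chi-square quantile with 2 degrees of freedom, and $W$ is the 2-Wasserstein distance. $\Delta_g$ and $\Delta_v$ are not defined in this section, so the caller passes them in. All function names are hypothetical.

```python
import math

# 95th percentile of the chi-square distribution with 2 degrees of freedom.
# For 2 d.o.f. the CDF is 1 - exp(-x / 2), so the quantile is -2 * ln(0.05).
CHI2_95_2D = -2.0 * math.log(0.05)


def mahalanobis_sq(x, mean, cov):
    """Squared Mahalanobis distance of a 2D point from N(mean, cov)."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    dx, dy = x[0] - mean[0], x[1] - mean[1]
    # Apply the inverse of the 2x2 covariance in closed form.
    return (d * dx * dx - (b + c) * dx * dy + a * dy * dy) / det


def outside_95(x, mean, cov):
    """True if x lies outside the 95% ellipse of the Gaussian."""
    return mahalanobis_sq(x, mean, cov) > CHI2_95_2D


def wasserstein2(mean1, cov1, mean2, cov2):
    """2-Wasserstein distance between two 2D Gaussians (closed form)."""
    (a1, b1), (c1, d1) = cov1
    (a2, b2), (c2, d2) = cov2
    mean_term = (mean1[0] - mean2[0]) ** 2 + (mean1[1] - mean2[1]) ** 2
    trace_prod = a1 * a2 + b1 * c2 + c1 * b2 + d1 * d2  # tr(cov1 @ cov2)
    det_prod = (a1 * d1 - b1 * c1) * (a2 * d2 - b2 * c2)
    # For a 2x2 PSD matrix M, tr(sqrt(M)) = sqrt(tr(M) + 2 * sqrt(det(M))).
    cross = math.sqrt(trace_prod + 2.0 * math.sqrt(max(det_prod, 0.0)))
    total = mean_term + a1 + d1 + a2 + d2 - 2.0 * cross
    # Clamp tiny negative values caused by floating-point rounding.
    return math.sqrt(max(total, 0.0))


def can_merge(w_dist, delta_g, delta_v, tau_d, tau_g, tau_v):
    """Merge rule from the steps above: all three values under threshold."""
    return w_dist < tau_d and delta_g < tau_g and delta_v < tau_v
```

Points for which `outside_95` is true are the ones assigned to multiple clusters; all other inliers go to the closest Gaussian.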
- To train the model, run the following command (please replace each path with your own):
python 2DGBNNs_train.py -weights_stochastic_path "checkpoint/wrn/cifar100/stochastic/stochastic_wrn.pth" \
-weights_path_origin "checkpoint/wrn/cifar100/origin/wrn.pth" \
-init_gmm "checkpoint/wrn/cifar100/init_gmm/clusters_gpu.json" \
-output "checkpoint/wrn/cifar100/2DGBNNs"
This pre-trained model can be downloaded here.
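
As a rough illustration of the sampling rule used in the final training loop (inliers draw from the shared mixture, outliers from their own Gaussian), consider the sketch below. It is not the code in `2DGBNNs_train.py`. For brevity it assumes scalar mixture components with standard deviations in place of the covariances $\Sigma_k$, and the function names are hypothetical.

```python
import random


def sample_mixture(pis, mus, sigmas, rng=random):
    """Draw one value from sum_k pi_k * N(mu_k, sigma_k^2)."""
    # Pick a component with probability pi_k, then sample from it.
    k = rng.choices(range(len(pis)), weights=pis, k=1)[0]
    return rng.gauss(mus[k], sigmas[k])


def sample_weight(is_inlier, gmm, own_mu, own_sigma, rng=random):
    """Inliers share the mixture; outliers keep their own Gaussian.

    gmm is a (pis, mus, sigmas) tuple describing the shared mixture.
    """
    if is_inlier:
        return sample_mixture(*gmm, rng=rng)
    return rng.gauss(own_mu, own_sigma)
```

Because all inliers draw from the same $K$ components, only the mixture parameters and the outliers' own parameters need to be stored.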
This project is licensed under the MIT License. See the `LICENSE` file for details.