
moulelin/2DGBNN


Stochastic Weight Sharing for Bayesian Neural Networks


Overview

  • This repository contains the source code and documentation for the paper "Stochastic Weight Sharing for Bayesian Neural Networks" submitted to the AISTATS conference.

Repository Structure

.
├── bnn/                                  # Bayesian Neural Network related code and files
├── checkpoint/                           # Directory for storing checkpoints during training
├── checkpoint_eval/                      # Directory for storing evaluation checkpoints
├── conf/                                 # Configuration files for experiments
├── data/                                 # Directory for datasets used in experiments
├── models/                               # Model definitions and architecture files
├── networks/                             # Contains different neural network architecture files
├── Visualization/                        # Scripts and tools for data visualization
├── .gitignore                            # Git ignore file to specify untracked files
├── 2DGBNNs_train.py                      # Training script for 2D Gaussian Bayesian Neural Networks
├── exp_stochastic_nn.py                  # Experiment script for stochastic neural networks
├── gmm_train.py                          # Script to train Gaussian Mixture Models
├── init_gmm.py                           # Script to initialize Gaussian Mixture Models
├── kmeans_init.py                        # Script to initialize K-means clustering
├── predict_acc_nll_ece.py                # Script to predict accuracy, negative log-likelihood, and expected calibration error
├── predict_acc_nll_ece_resnet18_cifar10.py  # Prediction script for ResNet-18 on CIFAR-10
├── predict_acc_nll_ece_resnet18_cifar100.py # Prediction script for ResNet-18 on CIFAR-100
├── predict_acc_nll_ece_resnet18_imagenet.py # Prediction script for ResNet-18 on ImageNet
├── predict_acc_nll_ece_wrn_cifar100.py   # Prediction script for Wide Residual Network on CIFAR-100
├── README.md                             # Markdown file with project overview and instructions
├── requirements.txt                      # Required libraries and dependencies for the project
├── run.bat                               # Batch file for running scripts on Windows
├── train_deterministic_network.py        # Script for training deterministic neural networks
├── utils.py                              # Utility functions used across the project
└── visualize_stochastic_weights_scatter.py # Script to visualize weight scatter in stochastic models

Installation

To get started, clone the repository and install the required dependencies:

git clone git@github.com:gfhvbjk/Anonymous.git
cd Anonymous
pip install -r requirements.txt

Note: we use Pyro for variational inference (VI), so please install it as follows:

pip3 install pyro-ppl

Usage

Data Preparation

  • CIFAR-10, CIFAR-100, and MNIST are downloaded automatically.
  • ImageNet1k must be placed in the data folder manually:
    • put the ImageNet1k splits at "data/train" and "data/val" (recommended), or
    • specify the ImageNet path by adding -dataset_path your_path (not recommended).
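
Before running the ImageNet1k evaluation, a quick check of the recommended layout can save a failed run. The helper below is hypothetical (it is not part of the repository) and assumes the torchvision ImageFolder convention of one subfolder per class under both data/train and data/val.

```python
from pathlib import Path


def check_imagenet_layout(root="data"):
    """Return the number of class subfolders in <root>/train and <root>/val.

    Hypothetical helper: assumes one subdirectory per class (ImageFolder style).
    Raises FileNotFoundError if either split directory is missing.
    """
    counts = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        # Count only directories; stray files in the split folder are ignored.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts
```

For a complete ImageNet1k copy, check_imagenet_layout("data") should report 1000 class folders for each split.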

Evaluation and Checkpoints

We provide the following evaluation scripts in this repository.

ImageNet1k by ResNet-18

  • ImageNet1k by ResNet-18 script: predict_acc_nll_ece_resnet18_imagenet.py
    • Download our pre-trained model here and put it under checkpoint/resnet18/imagenet/2DGBNNs/
    • python predict_acc_nll_ece_resnet18_imagenet.py
    • Alternatively, store the pre-trained model anywhere and pass its path:
    • python predict_acc_nll_ece_resnet18_imagenet.py -weights_path your_path
  • The outputs are the numbers of outliers, ellipse points, and Gaussians, along with accuracy (%), NLL, and ECE:

| Metric       | Value         |
|--------------|---------------|
| Outliers     | 23013         |
| Ellipses     | 10885         |
| Gaussians    | 2217          |
| Accuracy (%) | 68.11 ± 0.03  |
| NLL          | 1.250 ± 0.005 |
| ECE          | 0.019 ± 0.003 |

CIFAR-100 by ResNet-18

  • CIFAR-100 by ResNet-18 script: predict_acc_nll_ece_resnet18_cifar100.py
    • Download our pre-trained model here and put it under checkpoint/resnet18/cifar100/2DGBNNs/
    • python predict_acc_nll_ece_resnet18_cifar100.py
    • Alternatively, store the pre-trained model anywhere and pass its path:
    • python predict_acc_nll_ece_resnet18_cifar100.py -weights_path your_path
  • The outputs are the numbers of outliers, ellipse points, and Gaussians, along with accuracy (%), NLL, and ECE:

| Metric       | Value         |
|--------------|---------------|
| Outliers     | 14624         |
| Ellipses     | 260           |
| Gaussians    | 2387          |
| Accuracy (%) | 74.7 ± 0.1    |
| NLL          | 1.049 ± 0.003 |
| ECE          | 0.042 ± 0.003 |

CIFAR-10 by ResNet-18

  • CIFAR-10 by ResNet-18 script: predict_acc_nll_ece_resnet18_cifar10.py
    • Download our pre-trained model here and put it under checkpoint/resnet18/cifar10/2DGBNNs/
    • python predict_acc_nll_ece_resnet18_cifar10.py
    • Alternatively, store the pre-trained model anywhere and pass its path:
    • python predict_acc_nll_ece_resnet18_cifar10.py -weights_path your_path
  • The outputs are the numbers of outliers, ellipse points, and Gaussians, along with accuracy (%), NLL, and ECE:

| Metric       | Value         |
|--------------|---------------|
| Outliers     | 123310        |
| Ellipses     | 57            |
| Gaussians    | 1569          |
| Accuracy (%) | 91.8 ± 0.1    |
| NLL          | 0.313 ± 0.003 |
| ECE          | 0.034 ± 0.003 |
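
The evaluation scripts report accuracy, NLL, and ECE. The sketch below shows the standard definitions in NumPy; it is not the repository's code. It assumes `probs` holds predictive probabilities (for a BNN, typically the average over posterior samples) and that ECE uses equal-width confidence bins. The bin count used by the scripts is not stated in this README, so `n_bins=15` is an assumption.

```python
import numpy as np


def accuracy_nll_ece(probs, labels, n_bins=15):
    """Accuracy (%), mean negative log-likelihood, and expected calibration error.

    probs:  (N, C) predictive class probabilities.
    labels: (N,) integer class labels.
    ECE: weighted average of |accuracy - confidence| over equal-width bins.
    """
    eps = 1e-12  # guards against log(0)
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)

    acc = 100.0 * correct.mean()
    nll = -np.log(probs[np.arange(len(labels)), labels] + eps).mean()

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by fraction of samples in bin
    return acc, nll, ece
```

Lower NLL and ECE indicate better-calibrated predictions; accuracy is returned in percent to match the tables above.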

Train

Training proceeds in three stages: weight initialization, stochastic network training, and finally 2DGBNN training.

  • We provide scripts and step-by-step instructions for every stage of training WRN-28-10 on CIFAR-100. Where a checkpoint (pre-trained weights) exists, we provide it as well; placing it in the corresponding folder lets you skip that stage and can save a great deal of time.
  • All pre-trained checkpoints can be downloaded from this folder: here
  • However, we strongly recommend downloading them one by one as you follow the steps.

Algorithm: Scaling BNNs to Large Models and Datasets

2DGBNN

Input:

  • NN architecture $f^{\mathbf{w}}$
  • Training data $\mathcal{D} = \{(\mathbf{X}, \mathbf{y})\}$
  • Algorithm thresholds: $\tau_w$, $\tau_d$, $\tau_g$, $\tau_v$
  • BNN prior $p(\mathbf{w})$

Output:

  • Stochastic weight-sharing trained BNN

Stage 0: Pre-trained model

  1. Train deterministic neural network

Run the following script:

python train_deterministic_network.py -net wrn -dataset cifar100

Alternatively, download our pre-trained model here and put it under "checkpoint/wrn/cifar100/origin".


Stage 1: Initialise GMM

  1. Initialize GMM Parameters
    Initialize $\mu$, $\sigma$ according to $p(\mathbf{w})$.

  2. Pre-training Loop
    For each epoch:

    • Sample weights:
      $$\mathbf{w} = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$
    • Update $\mu$, $\sigma$ by training on $\mathcal{D}$.
  3. Identify Outliers
    For each weight $w_i$:

    • If $|w_i| > \tau_w$ or $|\nabla_{w_i}|$ is in the top $1\%$, then $w_i$ is an outlier.
    • Else, $w_i$ is an inlier.
  4. Learn GMM on Inliers
    Learn GMM on inlier parameters $\Theta_{in}$.
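
The sampling rule in step 2 and the outlier rule in step 3 can be sketched in NumPy as follows. This is a minimal illustration, not the repository's implementation: the function names are hypothetical, `grad` is assumed to hold the loss gradient for each weight, and weights tied at the top-1% gradient threshold are counted as outliers.

```python
import numpy as np


def sample_weights(mu, sigma, rng):
    """Reparameterized sample w = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps


def split_outliers(w, grad, tau_w, top_frac=0.01):
    """Boolean mask that is True for outliers.

    A weight is an outlier if |w| > tau_w or its gradient magnitude lies in
    the top `top_frac` fraction of all gradient magnitudes.
    """
    g = np.abs(grad).ravel()
    k = max(1, int(np.ceil(top_frac * g.size)))
    # The k-th largest gradient magnitude serves as the cut-off (ties included).
    grad_thresh = np.partition(g, g.size - k)[g.size - k]
    mask = (np.abs(w).ravel() > tau_w) | (g >= grad_thresh)
    return mask.reshape(w.shape)
```

The inliers, `w[~mask]`, are then the parameters on which the GMM is fitted in step 4.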

Pick a pre-trained deterministic model and note its path.

Run the following script (replace the weights path with your own):

python exp_stochastic_nn.py -net wrn -dataset cifar100 -weights_path "checkpoint/wrn/cifar100/origin/wrn.pth"

This pre-trained model can be downloaded here.

Next, initialize the GMM: pick a stochastic neural network and note its path.

Note: GMM weight initialization requires cuDF and cuML (mandatory). Install them as follows:

pip install cudf-cu11 cuml-cu11

Then run the following command (replace the weights path with your own):

python init_gmm.py -weights_path "checkpoint/wrn/cifar100/stochastic/wrn_stochastic.pth"

This pre-trained model can be downloaded here.

The initial GMM configuration will be stored under "checkpoint/wrn/cifar100/init_gmm".


Stage 2: Refine GMM

  1. Mahalanobis Distance Check
    For each inlier weight $w_i$:

    • Perform Mahalanobis distance check.
    • If $w_i$ is outside the 95th percentile:
      • Assign $w_i$ to multiple clusters.
    • Else:
      • Assign $w_i$ to the closest Gaussian.
  2. Alpha-Blending for Ellipse Points
    Apply alpha-blending for ellipse points.

  3. Merge Gaussians
    Repeat until no more Gaussians can be merged:

    • For each pair $(\mathcal{N}_1, \mathcal{N}_2)$ in GMM:
      • If $W(\mathcal{N}_1, \mathcal{N}_2) < \tau_d$, $\Delta_g < \tau_g$, and $\Delta_v < \tau_v$:
        • Merge $\mathcal{N}_1$ and $\mathcal{N}_2$.
  4. Final Training Loop
    For each epoch:

    • For each weight $w_i$:
      • If $w_i$ is an inlier:
        • Sample $w_i \sim \sum_{k=1}^{K} \pi_k \mathcal{N}(\mu_k, \Sigma_k)$.
      • Else:
        • Use $w_i \sim \mathcal{N}(\mu_{w_i}, \sigma^2_{w_i})$.
    • Take a minimization step on $\hat{\mathcal{L}}(\mathcal{D}, q)$.
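
The Mahalanobis check in step 1 and the distance test in step 3 can be sketched as below. This rests on several assumptions that the README does not state: inlier weights are represented as 2D points, "closest Gaussian" means smallest Mahalanobis distance, and $W$ is the 2-Wasserstein distance. The $\Delta_g$ and $\Delta_v$ criteria are not defined here and are left out. All names are hypothetical.

```python
import numpy as np

# 95th percentile of a chi-square distribution with 2 degrees of freedom:
# its CDF is 1 - exp(-x / 2), so the quantile is -2 * ln(0.05) ≈ 5.991.
CHI2_2D_95 = -2.0 * np.log(0.05)


def mahalanobis_assign(points, means, covs):
    """Assign 2D points to Gaussians by squared Mahalanobis distance.

    points: (N, 2), means: (K, 2), covs: (K, 2, 2).
    Returns (closest, inside): the nearest Gaussian per point, and a mask that
    is False for points outside that Gaussian's 95% region (multi-cluster candidates).
    """
    diff = points[:, None, :] - means[None, :, :]        # (N, K, 2)
    prec = np.linalg.inv(covs)                           # (K, 2, 2)
    d2 = np.einsum("nki,kij,nkj->nk", diff, prec, diff)  # (N, K)
    closest = d2.argmin(axis=1)
    inside = d2[np.arange(len(points)), closest] <= CHI2_2D_95
    return closest, inside


def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(a)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def wasserstein2(m1, s1, m2, s2):
    """2-Wasserstein distance between N(m1, s1) and N(m2, s2)."""
    r2 = _sqrtm_psd(s2)
    cross = _sqrtm_psd(r2 @ s1 @ r2)
    w2_sq = np.sum((m1 - m2) ** 2) + np.trace(s1 + s2 - 2.0 * cross)
    return float(np.sqrt(max(w2_sq, 0.0)))  # clip tiny negative round-off
```

Under these assumptions, a pair of Gaussians would be a merge candidate when `wasserstein2(...) < tau_d` and the remaining two conditions also hold.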
  • To train the model, run the following command (replace each path with your own):
python 2DGBNNs_train.py -weights_stochastic_path "checkpoint/wrn/cifar100/stochastic/stochastic_wrn.pth" \
-weights_path_origin "checkpoint/wrn/cifar100/origin/wrn.pth" \
-init_gmm "checkpoint/wrn/cifar100/init_gmm/clusters_gpu.json" \
-output "checkpoint/wrn/cifar100/2DGBNNs"

This pre-trained model can be downloaded here.

License

This project is licensed under the MIT License. See the `LICENSE` file for details.
