This project combines a novel training method called Spatio-Temporal Backpropagation (STBP) with a Batch Normalization technique known as threshold-dependent Batch Normalization (tdBN) to make the training of Deep Spiking Neural Networks possible. The architecture is inspired by the 18-layer ResNet, sometimes referred to as a Deep Spiking Residual Network (DSRN); it uses tdBN and a customized Basic Block in place of the standard Batch Normalization and Basic Block found in ResNet. A new spiking activation function with a gradient approximation is also designed to replicate brain-like behavior and to give the otherwise non-differentiable spike function usable intermediate derivatives during backpropagation. tdBN balances the firing threshold against the pre-synaptic inputs, which helps address inappropriate neuron firing rates and vanishing or exploding gradients, and it also reduces internal covariate shift during training. An impressive accuracy of 98.67% was achieved on the MNIST dataset after just 11 epochs.
- Create an environment using `conda create --name <env> --file requirements.txt`
- Main packages: `torch`, `tensorboard`, `pandas`, `numpy`
- OS: macOS Ventura 13.0.1
- Python 3.10.8, Anaconda 22.9.0
- Device: GPU Tesla T4 (Google Colab)
- Run this command: `python main.py --train_batch_size 64 --val_batch_size 200 --epochs 11 --lr 0.001`
- There are many arguments that can be changed when running in the terminal (read more in the code, file `main.py`).
- Building a customized Basic Block
- According to the idea from paper [3], we're going to use this customized Basic Block (on the right in the image below) when training Deep Spiking Neural Networks; a rough PyTorch sketch is given after the figure caption below.

The customized Basic Block for SNNs.
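A rough PyTorch sketch of such a block is shown below. The names (`SpikingBasicBlock`, the inline spike placeholders) and the plain `BatchNorm2d` standing in for tdBN are illustrative assumptions, not the actual implementation in `helper_layers.py`:

```python
import torch
import torch.nn as nn

class SpikingBasicBlock(nn.Module):
    """Illustrative residual block for an SNN: conv -> norm -> spike, twice,
    with the shortcut added before the final spiking activation."""

    def __init__(self, in_ch, out_ch, stride=1, v_th=1.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)  # stand-in for tdBN (which normalizes to N(0, v_th^2))
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)  # stand-in for tdBN
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        self.v_th = v_th

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out = (out >= self.v_th).float()      # placeholder spike; training uses the surrogate gradient
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)          # residual addition before the spiking activation
        return (out >= self.v_th).float()

print(SpikingBasicBlock(64, 128, stride=2)(torch.rand(1, 64, 28, 28)).shape)  # (1, 128, 14, 14)
```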
- tdBN normalizes the pre-activations to $N(0, V^{2}_{th})$; see more in paper [3].
- The approximate derivative is a rectangular function:

$$h(u)=\frac{1}{2a}\,\mathrm{sign}(|u|<a)$$

- Deep Spiking Neural Network Structure
| 19 layers |
|---|
| conv1 |
| block 1 |
| block 2 |
| block 3 |
| average pool stride=2, 256-d FC |
| 10-d FC, softmax |
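To make the table above concrete, here is a rough sketch of how such a structure could be assembled in PyTorch. The channel widths are illustrative assumptions (not the values used in this project), and a plain conv-BN-ReLU stage stands in for each spiking residual block:

```python
import torch
import torch.nn as nn

def stage(in_ch, out_ch, stride):
    """Placeholder for a spiking residual block (see the Basic Block sketch above)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),  # stands in for the spiking activation
    )

# Channel widths below are illustrative assumptions, not the project's settings.
net = nn.Sequential(
    nn.Conv2d(1, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),  # conv1
    stage(64, 128, stride=2),     # block 1
    stage(128, 256, stride=2),    # block 2
    stage(256, 512, stride=2),    # block 3
    nn.AdaptiveAvgPool2d(1),      # the global average pooling described below
    nn.Flatten(),
    nn.Linear(512, 256),          # 256-d FC
    nn.Linear(256, 10),           # 10-d FC; softmax is applied by the loss / at inference
)

print(net(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 10])
```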
The global average pooling layer takes the average over each channel (i.e. across the spatial dimensions); see the image below:

The Global Average Pooling Layer.
- The size of each image in the MNIST dataset is $(28, 28)$, and I'm going to use these transforms (a sketch of the pipeline follows this list):
  - `RandomHorizontalFlip`: horizontally flip the given image randomly with a given probability (default: $0.5$).
  - `RandomCrop`: crop the given image at a random location.
  - `Normalize`: normalize a tensor image with mean and standard deviation.
- The customized MNIST has a transformation `ToTensor` already.
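A sketch of how these transforms might be composed with `torchvision` is shown below; the crop padding and the normalization statistics are assumptions for illustration, not values taken from this project:

```python
from PIL import Image
import torchvision.transforms as T

# Illustrative pipeline; the crop padding and normalization statistics are assumptions.
train_transform = T.Compose([
    T.RandomHorizontalFlip(p=0.5),               # flip with probability 0.5
    T.RandomCrop(28, padding=4),                 # random crop back to 28x28 after padding
    T.ToTensor(),                                # may be redundant if the customized MNIST already applies it
    T.Normalize(mean=(0.1307,), std=(0.3081,)),  # common MNIST statistics (assumption)
])

x = train_transform(Image.new("L", (28, 28)))    # dummy 28x28 grayscale image
print(x.shape)  # torch.Size([1, 28, 28])
```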
| Hyperparameters | Values |
|---|---|
| Timesteps `steps` | |
| Peak width `aa` | |
| Threshold `Vth` | |
| Decay factor `tau` | |
| Alpha `alpha` | |
| #epochs | |
| Learning rate | |
| Adam betas | |
| Adam epsilon | |

The values for these hyperparameters are set in `helper_layers.py`.
- Early Stopping is used to avoid overfitting: whenever the validation loss increases, a stopper counter is incremented by $1$; once the counter equals the parameter `patience`, the training process is stopped.
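A minimal sketch of this early-stopping logic, with illustrative names and values rather than the project's actual implementation:

```python
class EarlyStopping:
    """Stop training once the validation loss has failed to improve `patience` times in a row."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0                      # improvement: reset the counter
        else:
            self.counter += 1                     # validation loss increased
        return self.counter >= self.patience      # True -> stop training

stopper = EarlyStopping(patience=2)               # `patience` value here is an assumption
for val_loss in [0.9, 0.8, 0.85, 0.9]:            # dummy validation losses
    if stopper.step(val_loss):
        print("early stop")
        break
```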
- In this project, I just show an example experiment of training the customized $19$-layer ResNet (shown via TensorBoard), with the hyperparameters set up in the file `helper_layers.py`, `epochs=11`, and a Learning Rate Adjustment strategy (i.e. reducing the learning rate after a number of epochs; here I picked $10$ as that number; a scheduler sketch appears after the TensorBoard note below). The training time is around $2$ hours.
- If you haven't installed `tensorboard`, type `pip install -U tensorboard` to install it. To run TensorBoard, type this command in the terminal: `tensorboard --logdir runs/<folder including events file>`. By default, TensorBoard automatically writes event files to a directory named `runs` in the current working directory.
TensorBoard GUI.
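The learning-rate adjustment mentioned above can be expressed, for example, with a step scheduler. The sketch below is only an illustration of the strategy (the decay factor `gamma` is an assumption), not the exact code used in this project:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]     # dummy parameter for illustration
optimizer = torch.optim.Adam(params, lr=0.001)
# Reduce the learning rate after 10 epochs; the decay factor gamma=0.1 is an assumption.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(11):
    optimizer.step()                              # the actual training steps would go here
    scheduler.step()                              # adjust the learning rate at the end of each epoch
    print(epoch, scheduler.get_last_lr())
```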
- The image below shows the test loss and accuracy after finishing $11$ epochs:
Test loss and test accuracy at the end of the training process.
- Here are the firing rates when inferring a test sample. Run the command `python compute_firing_rates.py` to see the plot:

Average firing rates over neurons of each layer.
- Some comments:
  - Note that there are $2$ types of neurons: convolution neurons and fully connected neurons. In this case, I'll consider both kinds as neurons and compute the average firing rates over them (see the sketch after this list).
  - In layer $17$ (a fully connected layer), the average firing rate is the highest.
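As a rough illustration of what `compute_firing_rates.py` measures, the sketch below computes an average firing rate per layer from recorded binary spike tensors; the layer names, shapes, and recording mechanism are assumptions, not the script's actual implementation:

```python
import torch

# Dummy recorded spike trains: one binary tensor per layer,
# shaped (timesteps, batch, ...neurons). Names and shapes are illustrative only.
spikes_per_layer = {
    "conv_layer_1": torch.randint(0, 2, (8, 1, 64, 28, 28)).float(),
    "fc_layer_17":  torch.randint(0, 2, (8, 1, 256)).float(),
}

for name, spikes in spikes_per_layer.items():
    # Average firing rate = emitted spikes / (number of neurons * timesteps),
    # treating convolutional and fully connected units alike.
    rate = spikes.mean().item()
    print(f"{name}: average firing rate = {rate:.3f}")
```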
- The model performance is extremely sensitive to $V_{th}$ and `aa` (the width parameter of the rectangular curve), so you need to choose these $2$ hyperparameters carefully (see the surrogate-gradient sketch below).
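For reference, a spiking activation with this rectangular surrogate gradient can be written as a custom autograd function. The sketch below is a generic illustration (the threshold and width values are assumptions, and the rectangular window is centered at the firing threshold), not the project's code:

```python
import torch

class SpikeFn(torch.autograd.Function):
    """Heaviside spike in the forward pass, rectangular surrogate h(u) in the backward pass."""

    V_TH = 1.0  # threshold V_th (assumed value)
    A = 0.5     # peak width a (assumed value)

    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return (u >= SpikeFn.V_TH).float()   # fire when the membrane potential crosses V_th

    @staticmethod
    def backward(ctx, grad_output):
        (u,) = ctx.saved_tensors
        # h(u) = 1/(2a) inside the rectangular window around the threshold, 0 elsewhere
        h = (u - SpikeFn.V_TH).abs().lt(SpikeFn.A).float() / (2 * SpikeFn.A)
        return grad_output * h

u = torch.randn(4, requires_grad=True)
SpikeFn.apply(u).sum().backward()
print(u.grad)  # nonzero only where |u - V_th| < a
```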
1. https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full#B41
2. https://arxiv.org/abs/1809.05793
3. https://arxiv.org/abs/2011.05280
4. http://d2l.ai/chapter_convolutional-modern/batch-norm.html#equation-eq-batchnorm
5. https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
6. https://arxiv.org/abs/1502.03167