Tutorial 2
In this tutorial, we are going to directly train a simple SNN with a single hidden layer using e-prop on the MNIST dataset, converted to a latency spike code.
This is clearly far from a state-of-the-art architecture, but it still achieves 94% accuracy on MNIST.
Install
Download wheel file
[1]:
if "google.colab" in str(get_ipython()):
    !gdown 1qlNCa_WT7sOmifYjPgGyggqUGsPiHn5y
    !pip install pygenn-5.3.0-cp312-cp312-linux_x86_64.whl
    %env CUDA_PATH=/usr/local/cuda
    !rm -rf /content/ml_genn-ml_genn_2_4_0
    !wget https://github.com/genn-team/ml_genn/archive/refs/tags/ml_genn_2_4_0.zip
    !unzip -q ml_genn_2_4_0.zip
    !pip install ./ml_genn-ml_genn_2_4_0/ml_genn
Install MNIST package
[2]:
!pip install mnist
Build model
Import standard modules and required mlGeNN classes
[3]:
import mnist
import numpy as np
import matplotlib.pyplot as plt
from ml_genn import InputLayer, Layer, SequentialNetwork
from ml_genn.callbacks import Checkpoint
from ml_genn.compilers import EPropCompiler, InferenceCompiler
from ml_genn.connectivity import Dense,FixedProbability
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
from ml_genn.serialisers import Numpy
from ml_genn.utils.data import (calc_latest_spike_time, log_latency_encode_data)
from ml_genn.compilers.eprop_compiler import default_params
Parameters
Define some model parameters
[4]:
NUM_INPUT = 28 * 28
NUM_HIDDEN = 128
NUM_OUTPUT = 10
BATCH_SIZE = 128
Latency encoding
There are numerous ways to encode images using spikes, but here we are going to emit a single spike for each neuron at a time calculated as follows from the pixel grayscale value \(x\):
\[
T(x) = \begin{cases}
\tau_\text{eff} \log\left(\frac{x}{x-\theta}\right) & x > \theta \\
\infty & \text{otherwise}
\end{cases}
\]
where \(\tau_\text{eff}=20\text{ms}\) and \(\theta=51\).
[5]:
mnist.datasets_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
train_spikes = log_latency_encode_data(mnist.train_images(), 20.0, 51)
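To make the encoding formula concrete, the following is a minimal NumPy sketch of \(T(x)\) evaluated on a handful of pixel values. The helper name `log_latency` is hypothetical and this is not the implementation of `log_latency_encode_data`, which additionally packs the spike times into the format mlGeNN expects; it only illustrates the formula above.

```python
import numpy as np

def log_latency(pixels, tau_eff=20.0, theta=51):
    """Spike time T(x) in ms for each pixel value (hypothetical helper).

    Pixels at or below theta never spike, which is represented by np.inf.
    """
    x = np.asarray(pixels, dtype=np.float64)
    times = np.full(x.shape, np.inf)
    above = x > theta
    # T(x) = tau_eff * log(x / (x - theta)) for x > theta
    times[above] = tau_eff * np.log(x[above] / (x[above] - theta))
    return times

# Dark pixels (<= 51) never spike; brighter pixels spike earlier
example_times = log_latency([0, 51, 52, 128, 255])
print(example_times)
```

Under this encoding the brightest pixels produce the earliest spikes, so most of the information about a digit arrives in the first few milliseconds of each example.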
Network definition
Because our network is entirely feedforward, we can define it as a SequentialNetwork where each layer is automatically connected to the previous layer. As we have converted the MNIST dataset to spikes, we will use a SpikeInput to inject these directly into the network. For our hidden layer we are going to use standard leaky integrate-and-fire (LIF) neurons, as this task does not require more computationally expensive adaptive LIF neurons. Finally, we are going to use a non-spiking output layer and read classifications out of it by picking the output neuron with the largest summed membrane voltage.
[6]:
# Create sequential model
serialiser = Numpy("latency_mnist_checkpoints")
network = SequentialNetwork(default_params)
with network:
    # Populations
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * NUM_INPUT),
                       NUM_INPUT)
    hidden = Layer(Dense(Normal(sd=1.0 / np.sqrt(NUM_INPUT))),
                   LeakyIntegrateFire(v_thresh=0.61, tau_mem=20.0,
                                      tau_refrac=5.0),
                   NUM_HIDDEN)
    output = Layer(Dense(Normal(sd=1.0 / np.sqrt(NUM_HIDDEN))),
                   LeakyIntegrate(tau_mem=20.0, readout="sum_var"),
                   NUM_OUTPUT)
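The summed-voltage readout described above can be sketched in plain NumPy. This is only an illustration under stated assumptions: the function `sum_var_readout`, the 1 ms timestep and the exact form of the leaky-integrator update are hypothetical choices made here, not mlGeNN's generated code.

```python
import numpy as np

def sum_var_readout(input_current, tau_mem=20.0, dt=1.0):
    """Classify by the largest membrane voltage summed over time.

    input_current: array of shape (timesteps, num_output).
    The update rule below is an assumed simple leaky integrator.
    """
    alpha = np.exp(-dt / tau_mem)          # per-timestep voltage decay
    v = np.zeros(input_current.shape[1])   # membrane voltages
    v_sum = np.zeros_like(v)               # running sum used for readout
    for i_t in input_current:
        v = alpha * v + (1.0 - alpha) * i_t
        v_sum += v
    return int(np.argmax(v_sum)), v_sum

# Toy input: neuron 3 gets weak sustained drive, neuron 7 a brief strong burst
current = np.zeros((50, 10))
current[:, 3] = 1.0
current[:10, 7] = 2.0
predicted, v_sum = sum_var_readout(current)
print(predicted)
```

Summing the voltage over the whole example rewards output neurons that receive sustained drive, rather than whichever neuron happens to peak at a single timestep.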
Compilation
In mlGeNN, you use a compiler class to turn an abstract network description into something that can actually be used for training or inference. Here, we use the EPropCompiler to train with e-prop, specifying the batch size, how many timesteps to evaluate each example for, the optimiser and the loss function. Because this is a classification task, we want to use cross-entropy loss and, because our labels are specified as integer class indices (rather than e.g. one-hot encoded), we use the sparse categorical variant.
[7]:
max_example_timesteps = int(np.ceil(calc_latest_spike_time(train_spikes)))
compiler = EPropCompiler(example_timesteps=max_example_timesteps,
                         losses="sparse_categorical_crossentropy",
                         optimiser="adam", batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)
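The difference between the sparse and one-hot forms of categorical cross-entropy is only in how the labels are supplied. The NumPy sketch below uses the standard textbook definitions with made-up logits; the function names are hypothetical and it does not show how mlGeNN evaluates the loss inside its e-prop kernels.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy(logits, labels):
    """Labels are integer class indices, e.g. [0, 2]."""
    probs = softmax(logits)
    return -np.log(probs[np.arange(len(labels)), labels])

def categorical_crossentropy(logits, one_hot):
    """Labels are one-hot rows, e.g. [[1, 0, 0], [0, 0, 1]]."""
    probs = softmax(logits)
    return -(one_hot * np.log(probs)).sum(axis=1)

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])      # integer labels, as returned for MNIST
one_hot = np.eye(3)[labels]    # equivalent one-hot encoding

sparse_loss = sparse_categorical_crossentropy(logits, labels)
dense_loss = categorical_crossentropy(logits, one_hot)
print(np.allclose(sparse_loss, dense_loss))
```

Both functions give the same per-example loss; the sparse form simply avoids building the one-hot matrix.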
Training
Now we will train the model for 10 epochs using our compiled network. To monitor its performance, we hold out 10% of the training data as a validation split and add a callback to checkpoint the weights every epoch.
[8]:
with compiled_net:
    # Train model on numpy dataset
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    compiled_net.train({input: train_spikes},
                       {output: mnist.train_labels()},
                       num_epochs=10, shuffle=True,
                       validation_split=0.1,
                       callbacks=callbacks)
Evaluate
Load the weights checkpointed after the last epoch (epochs are indexed from zero, so the tenth epoch has index 9):
[9]:
network.load((9,), serialiser)
Create an InferenceCompiler and compile network for inference:
[10]:
compiler = InferenceCompiler(evaluate_timesteps=max_example_timesteps,
                             batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)
Encode test set using the same log-latency encoding and evaluate it:
[11]:
test_spikes = log_latency_encode_data(mnist.test_images(), 20.0, 51)
with compiled_net:
    compiled_net.evaluate({input: test_spikes},
                          {output: mnist.test_labels()})