Tutorial 3

In this tutorial, we are going to directly train a simple SNN with a single hidden layer using EventProp on the MNIST dataset, converted to a latency spike code.

Clearly, this is far from a state-of-the-art architecture, but it still achieves 96% accuracy on MNIST.

Install

Download and install the PyGeNN wheel file and mlGeNN (only when running on Google Colab)

[1]:
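if "google.colab" in str(get_ipython()):
    # Download and install a prebuilt PyGeNN 5.1.0 wheel (built for CPython 3.11)
    !gdown 1wUeynMCgEOl2oK2LAd4E0s0iT_OiNOfl
    !pip install pygenn-5.1.0-cp311-cp311-linux_x86_64.whl
    %env CUDA_PATH=/usr/local/cuda

    # Download the mlGeNN 2.3.0 source (overwriting any earlier download) and install it
    !rm -rf /content/ml_genn-ml_genn_2_3_0
    !wget -O ml_genn_2_3_0.zip https://github.com/genn-team/ml_genn/archive/refs/tags/ml_genn_2_3_0.zip
    !unzip -q ml_genn_2_3_0.zip
    !pip install ./ml_genn-ml_genn_2_3_0/ml_genn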

Install MNIST package

[2]:
!pip install mnist

Build model

Import standard modules and required mlGeNN classes

[3]:
import mnist
import numpy as np
import matplotlib.pyplot as plt

from ml_genn import InputLayer, Layer, SequentialNetwork
from ml_genn.callbacks import Checkpoint
from ml_genn.compilers import EventPropCompiler, InferenceCompiler
from ml_genn.connectivity import Dense
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
from ml_genn.optimisers import Adam
from ml_genn.serialisers import Numpy
from ml_genn.synapses import Exponential

from ml_genn.utils.data import (calc_latest_spike_time, linear_latency_encode_data)

from ml_genn.compilers.event_prop_compiler import default_params

Parameters

Define some model parameters

[4]:
NUM_INPUT = 28 * 28
NUM_HIDDEN = 128
NUM_OUTPUT = 10
BATCH_SIZE = 128
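These parameters also fix several sizes used later in the network definition. The quick arithmetic below assumes, as latency coding guarantees, that each input neuron spikes at most once per example; the derived variable names are illustrative and do not appear in the tutorial.

```python
# Same values as the parameter cell
NUM_INPUT = 28 * 28
NUM_HIDDEN = 128
NUM_OUTPUT = 10
BATCH_SIZE = 128

# Each input neuron emits at most one latency-coded spike per example,
# so a whole batch needs room for at most this many input spikes
max_input_spikes_per_batch = BATCH_SIZE * NUM_INPUT

# Number of trainable weights in the two all-to-all (Dense) connections
num_hidden_weights = NUM_INPUT * NUM_HIDDEN
num_output_weights = NUM_HIDDEN * NUM_OUTPUT

print(max_input_spikes_per_batch, num_hidden_weights, num_output_weights)
```

The first value is why the `SpikeInput` in the network definition is created with `max_spikes=BATCH_SIZE * NUM_INPUT`.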

Latency encoding

There are numerous ways to encode images using spikes, but here we are going to emit at most a single spike for each input neuron (one per pixel), at a time that depends linearly on that pixel's grayscale intensity, with brighter pixels spiking earlier.

[5]:
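# Download MNIST from a mirror hosted on Google Cloud Storage
mnist.datasets_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
# Latency-encode each training image with a maximum spike time of 20.0
train_spikes = linear_latency_encode_data(mnist.train_images(), 20.0)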
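The idea behind this encoding can be sketched in a few lines of NumPy. The helper below is hypothetical and is not mlGeNN's implementation: the threshold below which pixels stay silent and the exact linear mapping (intensity 255 spiking at `min_time`, intensity 0 at `max_time`) are assumptions chosen to illustrate the idea, with `max_time=20.0` mirroring the value passed to `linear_latency_encode_data`.

```python
import numpy as np

def sketch_linear_latency_encode(image, max_time=20.0, min_time=0.0, thresh=1):
    """Hypothetical helper (not part of mlGeNN) showing linear latency
    coding of one 8-bit grayscale image.

    Returns (spike_times, neuron_ids), with at most one spike per pixel.
    """
    # Flatten the 2D image so that each pixel maps to one input neuron
    pixels = np.asarray(image, dtype=np.float64).ravel()
    # Assumption: pixels at or below the threshold emit no spike at all
    spiking = pixels > thresh
    neuron_ids = np.flatnonzero(spiking)
    # Assumption: intensity 255 spikes at min_time, intensity 0 would spike
    # at max_time, with a linear mapping in between
    spike_times = min_time + ((255.0 - pixels[spiking]) / 255.0) * (max_time - min_time)
    return spike_times, neuron_ids

# A tiny 2x2 "image": black, white, dark grey, light grey
image = np.array([[0, 255], [51, 204]], dtype=np.uint8)
times, ids = sketch_linear_latency_encode(image)
print(ids)    # [1 2 3]  (the black pixel stays silent)
print(times)  # approximately [0, 16, 4]
```

The white pixel spikes immediately and the dark grey one latest, so the image content is carried entirely by spike timing rather than spike counts.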

Network definition

Because our network is entirely feedforward, we can define it as a SequentialNetwork, in which each layer is automatically connected to the previous layer. As we have already converted the MNIST dataset to spikes, we use a SpikeInput to inject these directly into the network. For our hidden layer we use standard leaky integrate-and-fire neurons. Finally, we use a non-spiking output layer and read classifications out of it by determining which output neuron has the highest membrane voltage averaged over the example.
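The averaged-voltage readout (`readout="avg_var"` in the network definition) can be illustrated with a small NumPy sketch. The function name and the (timesteps, batch, outputs) layout of the voltage trace are assumptions for illustration only; mlGeNN computes its readout internally.

```python
import numpy as np

def sketch_avg_var_readout(v_trace):
    """Hypothetical illustration (not mlGeNN's code) of classifying from
    time-averaged output voltages.

    v_trace: shape (timesteps, batch, num_output), membrane voltages of
    the non-spiking output neurons recorded at every timestep.
    Returns one predicted class index per example.
    """
    # Average each output neuron's voltage over the whole example
    mean_v = v_trace.mean(axis=0)
    # Predict the class whose output neuron had the highest average voltage
    return np.argmax(mean_v, axis=-1)

rng = np.random.default_rng(0)
# Random voltages for 20 timesteps, a batch of 4 examples and 10 classes
v_trace = rng.normal(size=(20, 4, 10))
# Raise output neuron 3 throughout every example
v_trace[:, :, 3] += 5.0
predictions = sketch_avg_var_readout(v_trace)
print(predictions)  # [3 3 3 3]
```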

[6]:
# Create sequential model
serialiser = Numpy("latency_mnist_checkpoints")
network = SequentialNetwork(default_params)
with network:
    # Populations
    input = InputLayer(SpikeInput(max_spikes=BATCH_SIZE * NUM_INPUT),
                       NUM_INPUT)
    hidden = Layer(Dense(Normal(mean=0.078, sd=0.045)),
                   LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                      tau_refrac=None),
                   NUM_HIDDEN, Exponential(5.0))
    output = Layer(Dense(Normal(mean=0.2, sd=0.37)),
                   LeakyIntegrate(tau_mem=20.0, readout="avg_var"),
                   NUM_OUTPUT, Exponential(5.0))
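The hidden layer combines LeakyIntegrateFire neurons (`v_thresh=1.0`, `tau_mem=20.0`, no refractory period) with Exponential synapses (time constant 5.0). To build intuition for these parameters, here is a deliberately simplified forward-Euler sketch of one such neuron. The function name is hypothetical, and the reset to zero and the exact form of the equations are assumptions of the sketch; it is not the update code that mlGeNN/GeNN actually generates.

```python
import numpy as np

def sketch_lif_exponential(weighted_input, tau_mem=20.0, tau_syn=5.0,
                           v_thresh=1.0, dt=1.0):
    """Simplified forward-Euler simulation of a single leaky
    integrate-and-fire neuron fed through an exponential synapse.
    Illustrative only; not the update equations GeNN generates.

    weighted_input: summed weights of the spikes arriving at each timestep.
    Returns (voltage trace, list of timesteps at which the neuron spiked).
    """
    i_syn = 0.0
    v = 0.0
    voltages, spike_steps = [], []
    for step, x in enumerate(weighted_input):
        # Exponential synapse: each input jumps the current up by its
        # weight, after which the current decays with time constant tau_syn
        i_syn += x
        i_syn -= dt * i_syn / tau_syn
        # Leaky integration: the voltage relaxes towards the synaptic current
        v += dt * (i_syn - v) / tau_mem
        if v >= v_thresh:
            spike_steps.append(step)
            # Assumed reset to 0 with no refractory period (cf. tau_refrac=None)
            v = 0.0
        voltages.append(v)
    return np.array(voltages), spike_steps

# No input: the neuron stays at rest and never spikes
v_rest, spikes_rest = sketch_lif_exponential(np.zeros(20))
# One strong input at the first timestep drives it over threshold
strong = np.zeros(20)
strong[0] = 100.0
v_strong, spikes_strong = sketch_lif_exponential(strong)
print(spikes_rest, spikes_strong[:1])  # [] [0]
```

The longer membrane time constant means the voltage integrates the synaptic current over several timesteps, which is what lets the timing of input spikes influence when hidden neurons fire.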

Compilation

In mlGeNN, to turn an abstract network description into something that can actually be used for training or inference, you use a compiler class. Here, we use the EventPropCompiler to train with EventProp, specifying the batch size and how many timesteps to evaluate each example for, as well as choosing our optimiser and loss function. Because this is a classification task, we want to use cross-entropy loss and, because our labels are specified as integer class indices (rather than e.g. one-hot encoded vectors), we use the sparse categorical variant.

[7]:
compiler = EventPropCompiler(example_timesteps=20,
                             losses="sparse_categorical_crossentropy",
                             optimiser=Adam(1e-2), batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)
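To see why the sparse variant suits integer labels, the sketch below computes cross-entropy both ways in plain NumPy, treating the output readout values as logits. This is an assumption made for illustration: mlGeNN computes its EventProp loss internally, and the function names here are hypothetical.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_categorical_crossentropy(logits, labels):
    # labels are integer class indices, e.g. 7 for the digit "7"
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def categorical_crossentropy(logits, one_hot):
    # The same loss, but with labels supplied as one-hot vectors
    probs = softmax(logits)
    return -np.mean(np.sum(one_hot * np.log(probs), axis=-1))

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
one_hot = np.eye(3)[labels]

loss_sparse = sparse_categorical_crossentropy(logits, labels)
loss_one_hot = categorical_crossentropy(logits, one_hot)
print(loss_sparse, loss_one_hot)  # the same value
```

The sparse form simply indexes the probability of the correct class instead of multiplying by a mostly-zero one-hot vector, so integer labels such as `mnist.train_labels()` can be used directly.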

Training

Now we will train the model for 15 epochs using our compiled network. To monitor its performance on unseen data, we hold out 10% of the training data as a validation split and add an additional callback to checkpoint the weights every epoch.

[8]:
with compiled_net:
    # Train model on numpy dataset
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    compiled_net.train({input: train_spikes},
                       {output: mnist.train_labels()},
                       num_epochs=15, shuffle=True,
                       validation_split=0.1,
                       callbacks=callbacks)
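As a rough sanity check on this training setup, the arithmetic below works out how the 60,000 MNIST training examples divide between training and validation and how many batches that gives per epoch. It assumes the validation set is exactly 10% of the data and that a final, partially filled batch is still run; how mlGeNN rounds the split and handles the last batch is not specified here.

```python
import math

NUM_TRAIN_EXAMPLES = 60000   # size of the MNIST training set
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.1

# Assumption: the validation set is 10% of the data, rounded to a whole example
num_validation = round(NUM_TRAIN_EXAMPLES * VALIDATION_SPLIT)
num_training = NUM_TRAIN_EXAMPLES - num_validation
# Assumption: the final batch of each epoch may be only partially filled
batches_per_epoch = math.ceil(num_training / BATCH_SIZE)

print(num_training, num_validation, batches_per_epoch)  # 54000 6000 422
```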

Evaluate

Load the weights checkpointed at the end of the last epoch (epochs are numbered from 0, so the last of the 15 is epoch 14):

[9]:
network.load((14,), serialiser)

Create an InferenceCompiler and compile network for inference:

[10]:
compiler = InferenceCompiler(evaluate_timesteps=20,
                             reset_in_syn_between_batches=True,
                             batch_size=BATCH_SIZE)
compiled_net = compiler.compile(network)

Encode the test set using the same linear latency encoding and evaluate the network on it:

[11]:
test_spikes = linear_latency_encode_data(mnist.test_images(), 20.0)
with compiled_net:
    compiled_net.evaluate({input: test_spikes},
                          {output: mnist.test_labels()})