Tutorial 4

In this tutorial, we are going to train a small Convolutional Neural Network using TensorFlow and convert it to an SNN using the few-spike encoding scheme.

The ANN and the converted SNN both achieve around 99% accuracy on the MNIST test set.

Install

Download wheel file

[1]:
if "google.colab" in str(get_ipython()):
    !gdown 1wUeynMCgEOl2oK2LAd4E0s0iT_OiNOfl
    !pip install pygenn-5.1.0-cp311-cp311-linux_x86_64.whl
    %env CUDA_PATH=/usr/local/cuda
    !pip uninstall -y tf-keras
    !rm -rf /content/ml_genn-ml_genn_2_3_0
    !wget -O ml_genn_2_3_0.zip https://github.com/genn-team/ml_genn/archive/refs/tags/ml_genn_2_3_0.zip
    !unzip -q ml_genn_2_3_0.zip
    !pip install ./ml_genn-ml_genn_2_3_0/ml_genn
    !pip install ./ml_genn-ml_genn_2_3_0/ml_genn_tf

Train ANN

First, we define a simple ANN in Keras with two convolutional layers (each followed by average pooling) and three dense layers, and train it:

[2]:
from tensorflow.keras import models, layers, datasets
from tensorflow.config import experimental

# Irritatingly, TF's default GPU memory allocator allocates all
# available GPU memory. This can't be freed and would leave none
# for mlGeNN, so we turn this behaviour off by enabling memory growth
for gpu in experimental.list_physical_devices("GPU"):
    experimental.set_memory_growth(gpu, True)

# Load MNIST data and normalise to [0,1]
(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data()
train_x = train_x.reshape((-1, 28, 28, 1)) / 255.0
test_x = test_x.reshape((-1, 28, 28, 1)) / 255.0

# Create and compile TF model
tf_model = models.Sequential([
    layers.Conv2D(16, 5, padding="valid", activation="relu", use_bias=False, input_shape=train_x.shape[1:]),
    layers.AveragePooling2D(2),
    layers.Conv2D(8, 5, padding="valid", activation="relu", use_bias=False),
    layers.AveragePooling2D(2),
    layers.Flatten(),
    layers.Dense(128, activation="relu", use_bias=False),
    layers.Dense(64, activation="relu", use_bias=False),
    layers.Dense(train_y.max() + 1, activation="softmax", use_bias=False),
], name="simple_cnn")
tf_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Fit TF model
tf_model.fit(train_x, train_y, epochs=10)
Epoch 1/10
1875/1875 [==============================] - 27s 14ms/step - loss: 0.2373 - accuracy: 0.9262
Epoch 2/10
1875/1875 [==============================] - 26s 14ms/step - loss: 0.0759 - accuracy: 0.9759
Epoch 3/10
1875/1875 [==============================] - 23s 13ms/step - loss: 0.0554 - accuracy: 0.9828
Epoch 4/10
1875/1875 [==============================] - 25s 13ms/step - loss: 0.0436 - accuracy: 0.9863
Epoch 5/10
1875/1875 [==============================] - 25s 13ms/step - loss: 0.0357 - accuracy: 0.9887
Epoch 6/10
1875/1875 [==============================] - 25s 14ms/step - loss: 0.0312 - accuracy: 0.9900
Epoch 7/10
1875/1875 [==============================] - 25s 13ms/step - loss: 0.0255 - accuracy: 0.9914
Epoch 8/10
1875/1875 [==============================] - 24s 13ms/step - loss: 0.0225 - accuracy: 0.9929
Epoch 9/10
1875/1875 [==============================] - 24s 13ms/step - loss: 0.0205 - accuracy: 0.9930
Epoch 10/10
1875/1875 [==============================] - 25s 13ms/step - loss: 0.0177 - accuracy: 0.9942
[2]:
<keras.src.callbacks.History at 0x7e4f80696850>
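As a quick sanity check on the architecture, the sketch below works through the layer sizes and weight counts by hand. It assumes the Keras defaults of stride 1 for `Conv2D` and a stride equal to the pool size (with "valid" padding) for `AveragePooling2D`, and 10 output classes for MNIST. The helper names are hypothetical and not part of Keras or mlGeNN.

```python
def conv_valid(size, kernel):
    # "valid" padding with stride 1: the output shrinks by kernel - 1
    return size - kernel + 1


def avg_pool(size, pool):
    # Assumed Keras pooling defaults: stride = pool size, "valid" padding
    return (size - pool) // pool + 1


def simple_cnn_sizes(in_size=28, in_channels=1, n_classes=10):
    # Spatial size after each conv + pool stage
    size = avg_pool(conv_valid(in_size, 5), 2)   # 28 -> 24 -> 12
    size = avg_pool(conv_valid(size, 5), 2)      # 12 -> 8 -> 4
    flat = size * size * 8                       # 8 filters in the second conv layer

    # Weight counts per layer (every layer uses use_bias=False)
    weights = [
        5 * 5 * in_channels * 16,  # Conv2D(16, 5)
        5 * 5 * 16 * 8,            # Conv2D(8, 5)
        flat * 128,                # Dense(128)
        128 * 64,                  # Dense(64)
        64 * n_classes,            # Dense(n_classes), softmax output
    ]
    return flat, sum(weights)


flat, n_weights = simple_cnn_sizes()
print(flat, n_weights)  # 128 28816
```

So the flattened feature vector feeding the first dense layer has 128 elements, and the whole network has 28,816 weights; `tf_model.summary()` should report the same total.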

Evaluate ANN model

Now we evaluate the ANN on the MNIST test set:

[3]:
tf_model.evaluate(test_x, test_y)
313/313 [==============================] - 2s 7ms/step - loss: 0.0404 - accuracy: 0.9876
[3]:
[0.04040272533893585, 0.9876000285148621]

Build normalization dataset

To configure the conversion algorithm correctly, we need the range of activations in each layer. We estimate this from a single batch of 128 randomly selected training examples. Slightly awkwardly, mlGeNN takes this data as an iterator, so we wrap it in a TF dataset:

[4]:
import numpy as np
from tensorflow.data import Dataset, AUTOTUNE

# Randomly select 128 training examples (without replacement) for the mlGeNN normalisation dataset
norm_i = np.random.choice(train_x.shape[0], 128, replace=False)

norm_ds = Dataset.from_tensor_slices((train_x[norm_i], train_y[norm_i]))
norm_ds = norm_ds.batch(128)
norm_ds = norm_ds.prefetch(AUTOTUNE)
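The idea behind the normalisation dataset can be sketched in plain Python: push a batch through a layer and record the largest activation, which can then serve as that layer's activation range when setting up the spiking neurons. This is a simplified, hypothetical illustration using a toy dense layer; the helper names are made up, and mlGeNN's `FewSpike` converter may use a different statistic and also handles convolutional layers.

```python
def relu(x):
    return x if x > 0.0 else 0.0


def dense_relu(inputs, weights):
    # One output per row of weights (no bias, matching the model above)
    return [relu(sum(w * x for w, x in zip(row, inputs))) for row in weights]


def max_activation(batch, weights):
    # Largest ReLU activation this layer produces anywhere in the batch
    return max(max(dense_relu(x, weights)) for x in batch)


# Toy layer with 2 inputs and 2 outputs, and a batch of 3 examples
weights = [[1.0, -1.0],
           [0.5, 0.5]]
batch = [[1.0, 0.0],
         [0.0, 2.0],
         [3.0, 1.0]]
print(max_activation(batch, weights))  # 2.0
```

In a real network this would be repeated layer by layer, feeding each layer's activations into the next, which is why a representative batch of training data is needed rather than a single example.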

Convert model

We are going to use the few-spike conversion scheme to convert the ANN to an SNN with \(k=8\) timesteps per example:

Stöckl, Christoph, and Wolfgang Maass. 2021. “Optimized Spiking Neurons Can Classify Images with High Accuracy through Temporal Coding with Two Spikes.” Nature Machine Intelligence 3(3): 230–38 (doi)
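To give some intuition for few-spike coding, the sketch below simulates a single few-spike (FS) neuron approximating ReLU over \(k\) timesteps. Following the general form described by Stöckl and Maass, at each timestep \(t\) the neuron spikes if its membrane potential reaches a threshold \(T(t)\), is reset by subtracting \(h(t)\), and each spike adds \(d(t)\) to the output. Here we *assume* the simple binary-weighted choice \(T(t)=h(t)=d(t)=\alpha 2^{-(t+1)}\), which is not the optimised parameter set from the paper and not necessarily what mlGeNN uses; \(\alpha\) plays the role of the activation range estimated from the normalisation data.

```python
def few_spike_relu(x, k=8, alpha=1.0):
    """Approximate relu(x) for x in [0, alpha) using at most k spikes.

    Assumption: binary-weighted parameters T(t) = h(t) = d(t) = alpha * 2**-(t+1).
    """
    v = x          # membrane potential starts at the input value
    out = 0.0      # decoded output: sum of d(t) over the timesteps with a spike
    spikes = []
    for t in range(k):
        step = alpha * 2.0 ** -(t + 1)
        if v >= step:        # threshold T(t) reached
            spikes.append(1)
            v -= step        # reset by subtraction, h(t)
            out += step      # spike contributes d(t) to the output
        else:
            spikes.append(0)
    return out, spikes


out, spikes = few_spike_relu(0.7)
print(out, spikes)  # 0.69921875 [1, 0, 1, 1, 0, 0, 1, 1]
```

With these parameters the spike train is simply the first \(k\) bits of the binary expansion of \(x/\alpha\), so for inputs in \([0, \alpha)\) the error is below \(\alpha 2^{-k}\). Negative inputs never reach threshold and decode to zero, as ReLU requires, while inputs at or above \(\alpha\) saturate at \(\alpha(1-2^{-k})\).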

[5]:
from ml_genn_tf.converters import FewSpike

# Build few-spike converter
converter = FewSpike(k=8, norm_data=[norm_ds])

# Convert TF model into an mlGeNN network (compiled in the next step)
net, net_inputs, net_outputs, tf_layer_pops = converter.convert(tf_model)

Compilation

In mlGeNN, to turn an abstract network description into something that can actually be used for training or inference, you use a compiler class. Here, we ask the converter to build a suitable compiler, specifying the batch size and that we don’t want connectivity expanded into sparse connectivity.

[6]:
compiler = converter.create_compiler(prefer_in_memory_connect=False, batch_size=128)
compiled_net = compiler.compile(net, inputs=net_inputs, outputs=net_outputs)
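With `batch_size=128`, the 10,000 MNIST test images are processed in batches. The plain-Python arithmetic below (no mlGeNN calls) shows how many batches that is and how large the final, partial batch is; how mlGeNN handles that final batch internally is not covered here.

```python
import math

n_test = 10_000     # size of the MNIST test set
batch_size = 128    # matches create_compiler(..., batch_size=128)

n_batches = math.ceil(n_test / batch_size)
last_batch = n_test - (n_batches - 1) * batch_size
print(n_batches, last_batch)  # 79 16
```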

Evaluate SNN model

Finally, we evaluate the SNN model on the MNIST test set:

[7]:
with compiled_net:
    compiled_net.evaluate({net_inputs[0]: test_x},
                          {net_outputs[0]: test_y})
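The accuracy reported for a classifier like this is the fraction of test examples whose predicted class matches the label. The sketch below shows that calculation in plain Python, assuming the prediction is the output neuron with the largest readout value; how mlGeNN derives the readout from the SNN's output population is not shown, and `accuracy` is a hypothetical helper rather than an mlGeNN function.

```python
def accuracy(outputs, labels):
    # outputs: one list of output-layer values per example
    # labels: integer class label per example
    correct = 0
    for out, label in zip(outputs, labels):
        pred = max(range(len(out)), key=out.__getitem__)  # argmax
        correct += int(pred == label)
    return correct / len(labels)


outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print(accuracy(outputs, labels))  # 0.6666666666666666
```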