Training networks
A key design goal of mlGeNN is that once a network topology has been defined using the API described in
Building networks, it can be trained using any supported training algorithm.
Network
and SequentialNetwork
objects need to be compiled into
GeNN models for training using a training algorithm compiler class. Currently the
following compilers for training networks are provided:
e-prop
- class ml_genn.compilers.EPropCompiler(example_timesteps, losses, optimiser='adam', tau_reg=500.0, c_reg=0.001, f_target=10.0, train_output_bias=True, dt=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, reset_time_between_batches=True, communicator=None, **genn_kwargs)
Compiler for training models using e-prop [Bellec2020].
The e-prop compiler supports
ml_genn.neurons.LeakyIntegrateFire
and ml_genn.neurons.AdaptiveLeakyIntegrateFire
hidden neuron models; and ml_genn.losses.SparseCategoricalCrossentropy
loss functions for classification and ml_genn.losses.MeanSquareError
for regression.
e-prop is derived from Real-Time Recurrent Learning (RTRL) so it does not require a backward pass, meaning that its memory overhead does not scale with sequence length. However, e-prop requires a per-connection eligibility trace, meaning that it is incompatible with connectivity such as convolutions with shared weights. Furthermore, because every connection has to be updated every timestep, training performance is not improved by sparse activations.
- Parameters:
example_timesteps (int) – How many timesteps each example will be presented to the network for
losses – Either a dictionary mapping loss functions to output populations or a single loss function to apply to all outputs
optimiser – Optimiser to use when applying weights
tau_reg (float) – Time constant with which hidden neuron spike trains are filtered to obtain the firing rate used for regularisation [ms]
c_reg (float) – Regularisation strength
f_target (float) – Target hidden neuron firing rate used for regularisation [Hz]
train_output_bias (bool) – Should output neuron biases be trained?
dt (float) – Simulation timestep [ms]
batch_size (int) – What batch size should be used for training? In our experience, e-prop works well with very large batch sizes (512)
rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation
kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model which can be accessed via the genn_model property of the compiled model.
reset_time_between_batches (bool) – Should time be reset to zero at the start of each example or allowed to run continuously?
communicator (Communicator) – Communicator used for inter-process communications when training across multiple GPUs.
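For example, an e-prop compiler for a classification task could be constructed like this. This is a minimal sketch: the hyperparameter values are purely illustrative and the loss could equally be given as a dictionary mapping loss functions to output populations, as described above.

from ml_genn.compilers import EPropCompiler
from ml_genn.losses import SparseCategoricalCrossentropy

# Present each example for 500 timesteps and use a large batch size,
# which e-prop tends to work well with (values are illustrative only)
compiler = EPropCompiler(example_timesteps=500,
                         losses=SparseCategoricalCrossentropy(),
                         optimiser="adam",
                         batch_size=512)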
EventProp
- class ml_genn.compilers.EventPropCompiler(example_timesteps, losses, optimiser='adam', reg_lambda_upper=0.0, reg_lambda_lower=0.0, reg_nu_upper=0.0, max_spikes=500, strict_buffer_checking=False, per_timestep_loss=False, dt=1.0, ttfs_alpha=0.01, softmax_temperature=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, communicator=None, delay_optimiser=None, delay_learn_conns=[], **genn_kwargs)
Compiler for training models using EventProp [Wunderlich2021].
The EventProp compiler supports
ml_genn.neurons.LeakyIntegrateFire
hidden neuron models; and ml_genn.losses.SparseCategoricalCrossentropy
loss functions for classification and ml_genn.losses.MeanSquareError
for regression.
EventProp implements a fully event-driven backward pass, meaning that its memory overhead scales with the number of spikes per trial rather than with sequence length.
In the original paper, [Wunderlich2021] derived EventProp to support loss functions of the form:
\[{\cal L} = l_p(t^{\text{post}}) + \int_0^T l_V(V(t),t) \, dt\]
such as
\[l_V= -\frac{1}{N_{\text{batch}}} \sum_{m=1}^{N_{\text{batch}}} \log \left( \frac{\exp\left(V_{l(m)}^m(t)\right)}{\sum_{k=1}^{N_{\text{class}}} \exp\left(V_{k}^m(t) \right)} \right)\]
where a function of output neuron membrane voltage is calculated each timestep; in mlGeNN, we refer to these as per-timestep loss functions. However, [Nowotny2024] showed that tasks with more complex temporal structure cannot be learned using these loss functions and extended the framework to support loss functions of the form:
\[{\cal L}_F = F\left(\textstyle \int_0^T l_V(V(t),t) \, dt\right)\]
such as:
\[{\cal L}_{\text{sum}} = - \frac{1}{N_{\text{batch}}} \sum_{m=1}^{N_{\text{batch}}} \log \left( \frac{\exp\left(\int_0^T V_{l(m)}^m(t) \, dt\right)}{\sum_{k=1}^{N_{\text{out}}} \exp\left(\int_0^T V_{k}^m(t) \, dt\right)} \right)\]
where a function of the integral of voltage is calculated once per trial.
- Parameters:
example_timesteps (int) – How many timesteps each example will be presented to the network for
losses – Either a dictionary mapping loss functions to output populations or a single loss function to apply to all outputs
optimiser – Optimiser to use when applying weights
reg_lambda_upper (float) – Regularisation strength, should typically be the same as reg_lambda_lower.
reg_lambda_lower (float) – Regularisation strength, should typically be the same as reg_lambda_upper.
reg_nu_upper (float) – Target number of hidden neuron spikes used for regularisation
max_spikes (int) – What is the maximum number of spikes each neuron (input and hidden) can emit each trial? This is used to allocate memory for the backward pass.
strict_buffer_checking (bool) – For performance reasons, if neurons emit more than max_spikes, the additional spikes are normally ignored but, if this flag is set, exceeding max_spikes will cause an error.
per_timestep_loss (bool) – Should we use the per-timestep or per-trial loss functions described above?
dt (float) – Simulation timestep [ms]
ttfs_alpha (float) – TODO
softmax_temperature (float) – TODO
batch_size (int) – What batch size should be used for training? In our experience, EventProp works best with modest batch sizes (32-128)
rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation
kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model which can be accessed via the genn_model property of the compiled model.
reset_time_between_batches – Should time be reset to zero at the start of each example or allowed to run continuously?
communicator (Communicator) – Communicator used for inter-process communications when training across multiple GPUs.
delay_optimiser – Optimiser to use when applying delays. If None, optimiser will be used for delays.
delay_learn_conns (Sequence) – Connections for which delays should be learned as well as weights
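For example, an EventProp compiler for a classification task with firing-rate regularisation could be constructed like this. Again, this is a minimal sketch: all hyperparameter values are illustrative and typically need tuning for a given dataset.

from ml_genn.compilers import EventPropCompiler
from ml_genn.losses import SparseCategoricalCrossentropy

# Modest batch size and matching upper/lower regularisation strengths,
# as recommended above (values are illustrative only)
compiler = EventPropCompiler(example_timesteps=500,
                             losses=SparseCategoricalCrossentropy(),
                             optimiser="adam",
                             reg_lambda_upper=1e-9,
                             reg_lambda_lower=1e-9,
                             reg_nu_upper=14.0,
                             max_spikes=1500,
                             batch_size=32)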
Once either compiler has been constructed, it can be used to compile a network with:
compiled_net = compiler.compile(network)
and this can be trained on a dataset, prepared as described in Datasets, using a standard training loop with:
with compiled_net:
    # Train model on numpy dataset
    metrics, _ = compiled_net.train({input: spikes},
                                    {output: labels},
                                    num_epochs=50, shuffle=True)
Like in Keras, additional logic such as checkpointing and recording of state variables can be added to the standard training loop using callbacks as described in the Callbacks and recording section.
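As a sketch, assuming the Checkpoint callback and Numpy serialiser covered in the Callbacks and recording section (they are not defined in this section), weights could be checkpointed at the end of each epoch by passing a callbacks list to train:

from ml_genn.callbacks import Checkpoint
from ml_genn.serialisers import Numpy

# Serialise checkpoints to a "checkpoints" directory (sketch only)
serialiser = Numpy("checkpoints")

with compiled_net:
    metrics, _ = compiled_net.train({input: spikes},
                                    {output: labels},
                                    num_epochs=50, shuffle=True,
                                    callbacks=[Checkpoint(serialiser)])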
Augmentation
As when using Keras, sometimes merely adding callbacks to the standard training loop is insufficient and you want to perform additional manual processing. One common case of this is augmentation, where you want to modify the data being trained on each epoch. This can be implemented by manually looping over epochs and providing new data each time like this:
with compiled_net:
    for e in range(50):
        aug_spikes = augment(spikes)
        metrics, _ = compiled_net.train({input: aug_spikes},
                                        {output: labels},
                                        start_epoch=e, num_epochs=1,
                                        shuffle=True)
where augment
is a function that returns an augmented version of a spike dataset (see Datasets).
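The exact implementation of augment depends on the dataset representation chosen in Datasets. As a hypothetical sketch, assuming each example is held as parallel NumPy arrays of spike times and neuron IDs before being preprocessed into mlGeNN's input format, a simple time-jitter augmentation for a single example could look like this:

import numpy as np

def jitter_example(times, ids, sigma=1.0, t_max=None, rng=None):
    # Add Gaussian jitter (standard deviation sigma, in ms) to the spike
    # times of one example, clip to the trial duration and re-sort so the
    # result remains a valid spike train
    rng = np.random.default_rng() if rng is None else rng
    jittered = times + rng.normal(0.0, sigma, size=times.shape)
    if t_max is not None:
        jittered = np.clip(jittered, 0.0, t_max)
    order = np.argsort(jittered)
    return jittered[order], ids[order]

The jittered arrays would then need to be preprocessed into the same structure as the original spikes dataset (see Datasets) before being passed to train.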
Default parameters
Sadly, the mathematical derivations of the different training algorithms make different
assumptions about the detailed implementation of various neuron models. For example,
the e-prop learning rule assumes that neurons have a ‘relative reset’,
where a fixed value is subtracted from the membrane voltage after a spike, whereas
EventProp assumes that the membrane voltage is reset to a fixed value after a spike.
To avoid users having to remember these details, mlGeNN compilers provide a dictionary of
default parameters which can be passed to the constructors of Network
and SequentialNetwork. For example, here the e-prop defaults are applied to
a sequential network and hence to the leaky integrate-and-fire layer within it:
from ml_genn import Layer, SequentialNetwork
from ml_genn.connectivity import Dense
from ml_genn.neurons import LeakyIntegrateFire
from ml_genn.compilers.eprop_compiler import default_params

network = SequentialNetwork(default_params)
with network:
    ...
    hidden = Layer(Dense(1.0), LeakyIntegrateFire(v_thresh=0.61, tau_mem=20.0,
                                                  tau_refrac=5.0),
                   128)
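The EventProp compiler provides its defaults in the same way. Assuming its module path mirrors the e-prop one above (an assumption, not stated in this section), they would be imported with:

from ml_genn.compilers.event_prop_compiler import default_params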