Training networks

A key design goal of mlGeNN is that, once a network topology has been defined using the API described in Building networks, it can be trained using any supported training algorithm. Before training, Network and SequentialNetwork objects need to be compiled into GeNN models using a training compiler class. Currently, the following compilers for training networks are provided:

e-prop

class ml_genn.compilers.EPropCompiler(example_timesteps, losses, optimiser='adam', tau_reg=500.0, c_reg=0.001, f_target=10.0, train_output_bias=True, dt=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, reset_time_between_batches=True, communicator=None, **genn_kwargs)

Compiler for training models using e-prop [Bellec2020].

The e-prop compiler supports ml_genn.neurons.LeakyIntegrateFire and ml_genn.neurons.AdaptiveLeakyIntegrateFire hidden neuron models; and ml_genn.losses.SparseCategoricalCrossentropy loss functions for classification and ml_genn.losses.MeanSquareError for regression.

e-prop is derived from Real-Time Recurrent Learning (RTRL), so it does not require a backward pass and its memory overhead therefore does not scale with sequence length. However, e-prop requires a per-connection eligibility trace, which makes it incompatible with connectivity that shares weights, such as convolutions. Furthermore, because every connection has to be updated every timestep, training performance does not benefit from sparse activations.

Parameters:
  • example_timesteps (int) – How many timesteps each example will be presented to the network for

  • losses – Either a dictionary mapping loss functions to output populations or a single loss function to apply to all outputs

  • optimiser – Optimiser to use when applying weight updates

  • tau_reg (float) – Time constant with which hidden neuron spike trains are filtered to obtain the firing rate used for regularisation [ms]

  • c_reg (float) – Regularisation strength

  • f_target (float) – Target hidden neuron firing rate used for regularisation [Hz]

  • train_output_bias (bool) – Should output neuron biases be trained?

  • dt (float) – Simulation timestep [ms]

  • batch_size (int) – What batch size should be used for training? In our experience, e-prop works well with very large batch sizes (512)

  • rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation

  • kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model which can be accessed via the genn_model property of the compiled model.

  • reset_time_between_batches (bool) – Should time be reset to zero at the start of each example or allowed to run continuously?

  • communicator (Communicator) – Communicator used for inter-process communications when training across multiple GPUs.
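For example, a minimal sketch of constructing an e-prop compiler for a classification task might look like the following; the timestep count and batch size here are illustrative values chosen for this example rather than recommendations:

from ml_genn.compilers import EPropCompiler
from ml_genn.losses import SparseCategoricalCrossentropy

# Illustrative values: present each example for 500 timesteps and, following
# the batch size advice above, use a large batch
compiler = EPropCompiler(example_timesteps=500,
                         losses=SparseCategoricalCrossentropy(),
                         optimiser="adam",
                         batch_size=512)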

EventProp

class ml_genn.compilers.EventPropCompiler(example_timesteps, losses, optimiser='adam', reg_lambda_upper=0.0, reg_lambda_lower=0.0, reg_nu_upper=0.0, max_spikes=500, strict_buffer_checking=False, per_timestep_loss=False, dt=1.0, ttfs_alpha=0.01, softmax_temperature=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, communicator=None, delay_optimiser=None, delay_learn_conns=[], **genn_kwargs)

Compiler for training models using EventProp [Wunderlich2021].

The EventProp compiler supports ml_genn.neurons.LeakyIntegrateFire hidden neuron models; and ml_genn.losses.SparseCategoricalCrossentropy loss functions for classification and ml_genn.losses.MeanSquareError for regression.

EventProp implements a fully event-driven backward pass, meaning that its memory overhead scales with the number of spikes per trial rather than with sequence length.

In the original paper, [Wunderlich2021] derived EventProp to support loss functions of the form:

\[{\cal L} = l_p(t^{\text{post}}) + \int_0^T l_V(V(t),t) dt\]

such as

\[l_V= -\frac{1}{N_{\text{batch}}} \sum_{m=1}^{N_{\text{batch}}} \log \left( \frac{\exp\left(V_{l(m)}^m(t)\right)}{\sum_{k=1}^{N_{\text{class}}} \exp\left(V_{k}^m(t) \right)} \right)\]

where a function of output neuron membrane voltage is calculated each timestep – in mlGeNN, we refer to these as per-timestep loss functions. However, [Nowotny2024] showed that tasks with more complex temporal structure cannot be learned using these loss functions and extended the framework to support loss functions of the form:

\[{\cal L}_F = F\left(\textstyle \int_0^T l_V(V(t),t) \, dt\right)\]

such as:

\[{\mathcal L_{\text{sum}}} = - \frac{1}{N_{\text{batch}}} \sum_{m=1}^{N_{\text{batch}}} \log \left( \frac{\exp\left(\int_0^T V_{l(m)}^m(t) dt\right)}{\sum_{k=1}^{N_{\text{out}}} \exp\left(\int_0^T V_{k}^m(t) dt\right)} \right)\]

where a function of the integral of voltage is calculated once per-trial.

Parameters:
  • example_timesteps (int) – How many timesteps each example will be presented to the network for

  • losses – Either a dictionary mapping loss functions to output populations or a single loss function to apply to all outputs

  • optimiser – Optimiser to use when applying weight updates

  • reg_lambda_upper (float) – Regularisation strength, should typically be the same as reg_lambda_lower.

  • reg_lambda_lower (float) – Regularisation strength, should typically be the same as reg_lambda_upper.

  • reg_nu_upper (float) – Target number of hidden neuron spikes used for regularisation

  • max_spikes (int) – What is the maximum number of spikes each neuron (input and hidden) can emit each trial? This is used to allocate memory for the backward pass.

  • strict_buffer_checking (bool) – For performance reasons, spikes emitted beyond max_spikes are normally silently dropped but, if this flag is set, exceeding max_spikes will instead cause an error.

  • per_timestep_loss (bool) – Should we use the per-timestep or per-trial loss functions described above?

  • dt (float) – Simulation timestep [ms]

  • ttfs_alpha (float)

  • softmax_temperature (float)

  • batch_size (int) – What batch size should be used for training? In our experience, EventProp works best with modest batch sizes (32-128)

  • rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation

  • kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model which can be accessed via the genn_model property of the compiled model.

  • reset_time_between_batches – Should time be reset to zero at the start of each example or allowed to run continuously?

  • communicator (Communicator) – Communicator used for inter-process communications when training across multiple GPUs.

  • delay_optimiser – Optimiser to use when applying delays. If None, optimiser will be used for delays

  • delay_learn_conns (Sequence) – Connections for which delays should be learned as well as weights

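As a comparable sketch, an EventProp compiler for a classification task could be constructed as follows; the regularisation constants, timestep count and batch size are illustrative assumptions rather than recommended values:

from ml_genn.compilers import EventPropCompiler
from ml_genn.losses import SparseCategoricalCrossentropy

# Illustrative values: per-trial loss (the default), a modest batch size and
# a small amount of spike-count regularisation on hidden neurons
compiler = EventPropCompiler(example_timesteps=500,
                             losses=SparseCategoricalCrossentropy(),
                             optimiser="adam",
                             reg_lambda_upper=1e-9, reg_lambda_lower=1e-9,
                             reg_nu_upper=10.0,
                             batch_size=64)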

Once either compiler has been constructed, it can be used to compile a network with:

compiled_net = compiler.compile(network)

and this can be trained on a dataset, prepared as described in Datasets, using a standard training loop with:

with compiled_net:
    # Train model on dataset for 50 epochs, shuffling examples each epoch
    metrics, _  = compiled_net.train({input: spikes},
                                     {output: labels},
                                     num_epochs=50, shuffle=True)

Like in Keras, additional logic such as checkpointing and recording of state variables can be added to the standard training loop using callbacks as described in the Callbacks and recording section.
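For instance, assuming the Checkpoint callback from ml_genn.callbacks and the Numpy serialiser from ml_genn.serialisers, a rough sketch of checkpointing weights during training might look like this:

from ml_genn.callbacks import Checkpoint
from ml_genn.serialisers import Numpy

# Serialiser determines where and in what format checkpointed weights are saved
serialiser = Numpy("checkpoints")

with compiled_net:
    metrics, _  = compiled_net.train({input: spikes},
                                     {output: labels},
                                     num_epochs=50, shuffle=True,
                                     callbacks=["batch_progress_bar",
                                                Checkpoint(serialiser)])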

Augmentation

As in Keras, sometimes merely adding callbacks to the standard training loop is insufficient and you need to perform additional manual processing. One common case is augmentation, where you want to modify the data being trained on each epoch. This can be implemented by manually looping over epochs and providing new data each time like this:

with compiled_net:
    for e in range(50):
        aug_spikes = augment(spikes)
        metrics, _  = compiled_net.train({input: aug_spikes},
                                         {output: labels},
                                         start_epoch=e, num_epochs=1,
                                         shuffle=True)

where augment is a function that returns an augmented version of a spike dataset (see Datasets).
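The exact form of augment depends on how the dataset is represented. As a rough sketch, assuming the raw spike times and neuron IDs for each example are still available as NumPy arrays and that the preprocess_spikes helper from ml_genn.utils.data can be used to rebuild examples (both assumptions here, not requirements of the API), spike times could be jittered like this:

import numpy as np
from ml_genn.utils.data import preprocess_spikes

def augment(raw_events, num_inputs, sigma=1.0):
    """Jitter spike times with Gaussian noise (sigma in ms) and rebuild examples.

    raw_events is assumed to be a list of (times, ids) NumPy array pairs."""
    aug_spikes = []
    for times, ids in raw_events:
        # Add noise and clip so no spike occurs before time zero
        jittered = np.clip(times + np.random.normal(0.0, sigma, len(times)),
                           0.0, None)
        # Keep spikes ordered in time before preprocessing
        order = np.argsort(jittered)
        aug_spikes.append(preprocess_spikes(jittered[order], ids[order],
                                            num_inputs))
    return aug_spikes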

Default parameters

Sadly, the mathematical derivations of the different training algorithms make different assumptions about the detailed implementation of various neuron models. For example, the e-prop learning rule assumes that neurons use a ‘relative reset’, where a fixed value is subtracted from the membrane voltage after a spike, whereas EventProp assumes that the membrane voltage is reset to a fixed value after a spike. To avoid users having to remember these details, mlGeNN compilers provide a dictionary of default parameters which can be passed to the constructors of Network and SequentialNetwork. For example, here the e-prop defaults are applied to a sequential network and hence to the leaky integrate-and-fire layer within it:

from ml_genn import Layer, SequentialNetwork
from ml_genn.connectivity import Dense
from ml_genn.neurons import LeakyIntegrateFire

from ml_genn.compilers.eprop_compiler import default_params

# Apply the e-prop default parameters to the network and all layers within it
network = SequentialNetwork(default_params)
with network:
    ...
    hidden = Layer(Dense(1.0), LeakyIntegrateFire(v_thresh=0.61, tau_mem=20.0, tau_refrac=5.0), 128)
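The EventProp defaults can be applied in the same way; assuming the EventProp compiler module follows the same naming pattern, only the import changes:

from ml_genn.compilers.event_prop_compiler import default_params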