ml_genn.callbacks package
Callbacks are used to run custom logic mid-simulation, including recording state.
- class ml_genn.callbacks.BatchProgressBar
Bases: Callback
Callback to display a tqdm progress bar which gets updated every batch.
- on_batch_end(batch, metrics)
- on_epoch_begin(epoch)
- on_test_begin()
- on_test_end(metrics)
- on_train_begin()
- on_train_end(metrics)
- set_params(num_batches, **kwargs)
- class ml_genn.callbacks.Callback
Bases: object
Base class for all callbacks.
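The following self-contained Python sketch makes the hook methods listed in this reference concrete. It shows a hypothetical training loop that calls those hooks on a list of callbacks. The stand-in Callback class, the run_training driver and the exact call order are assumptions for illustration and are not mlGeNN's implementation. A real custom callback would presumably subclass ml_genn.callbacks.Callback instead.

```python
# Hypothetical sketch: NOT mlGeNN's training loop. The hook names mirror
# those listed in this reference; the driver and call order are assumptions.

class Callback:
    """Stand-in base class in which every hook does nothing by default."""
    def on_train_begin(self): pass
    def on_epoch_begin(self, epoch): pass
    def on_batch_begin(self, batch): pass
    def on_timestep_begin(self, timestep): pass
    def on_timestep_end(self, timestep): pass
    def on_batch_end(self, batch, metrics): pass
    def on_epoch_end(self, epoch, metrics): pass
    def on_train_end(self, metrics): pass


def fire(callbacks, hook, *args):
    """Call the named hook on every callback, in list order."""
    for callback in callbacks:
        getattr(callback, hook)(*args)


def run_training(callbacks, num_epochs, num_batches, num_timesteps):
    """Hypothetical driver showing where each hook could be invoked."""
    metrics = {}  # placeholder for real metric objects
    fire(callbacks, "on_train_begin")
    for epoch in range(num_epochs):
        fire(callbacks, "on_epoch_begin", epoch)
        for batch in range(num_batches):
            fire(callbacks, "on_batch_begin", batch)
            for timestep in range(num_timesteps):
                fire(callbacks, "on_timestep_begin", timestep)
                # ... one simulation timestep would run here ...
                fire(callbacks, "on_timestep_end", timestep)
            fire(callbacks, "on_batch_end", batch, metrics)
        fire(callbacks, "on_epoch_end", epoch, metrics)
    fire(callbacks, "on_train_end", metrics)


class TimestepCounter(Callback):
    """Example user callback: overrides one hook to count timesteps."""
    def __init__(self):
        self.count = 0

    def on_timestep_end(self, timestep):
        self.count += 1


counter = TimestepCounter()
run_training([counter], num_epochs=2, num_batches=3, num_timesteps=10)
print(counter.count)  # 60
```

Because the base class supplies no-op hooks, a callback only needs to override the hooks it cares about.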
- class ml_genn.callbacks.Checkpoint(serialiser='numpy', epoch_interval=1)
Bases: Callback
Callback which serialises network state after a specified number of epochs.
- Parameters:
serialiser (Serialiser | str) – Serialiser to use
epoch_interval (int) – After how many epochs should checkpoints be saved?
- on_epoch_end(epoch, metrics)
- set_params(compiled_network, **kwargs)
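As a rough illustration of epoch_interval, the following Python sketch lists the epochs at which a checkpoint would be written. It assumes that epochs are counted from 0 and that a checkpoint is saved at the end of every epoch whose index is a multiple of epoch_interval. That rule and the checkpoint_epochs helper are hypothetical, not taken from mlGeNN's source.

```python
# Hypothetical sketch: assumes a checkpoint is written at the end of every
# epoch whose (0-based) index is a multiple of epoch_interval.
# This is NOT mlGeNN's Checkpoint implementation.
def checkpoint_epochs(num_epochs, epoch_interval=1):
    """Return the epoch indices at which a checkpoint would be saved."""
    if epoch_interval < 1:
        raise ValueError("epoch_interval must be >= 1")
    return [epoch for epoch in range(num_epochs)
            if epoch % epoch_interval == 0]


print(checkpoint_epochs(5))       # [0, 1, 2, 3, 4]
print(checkpoint_epochs(10, 3))   # [0, 3, 6, 9]
```

With the real class, the second configuration corresponds to constructing Checkpoint(serialiser="numpy", epoch_interval=3).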
- class ml_genn.callbacks.ConnVarRecorder(conn, genn_var, key=None, example_filter=None, src_neuron_filter=None, trg_neuron_filter=None)
Bases: Callback
Callback used for recording connection state variables during simulation. Variables are specified using their GeNN state variable names. By convention, g is always the weight and d is the per-synapse delay, if one is used. Other variables are compiler-specific, e.g. Gradient accumulates gradients when using ml_genn.compilers.EventPropCompiler.
- Parameters:
conn (Connection) – Synapse population to record from
genn_var (str) – Internal name of variable to record
key – Key under which the data recorded by this callback is stored in the dictionary returned by the compiled network's evaluation/training methods
example_filter (Sequence | ndarray | integer | int | None) – Filter used to select which examples to record from (see Callbacks and recording for more information).
src_neuron_filter (Sequence | slice | ndarray | integer | int | None) – Filter applied to source (presynaptic) neurons to select which synapses to record from (see Callbacks and recording for more information).
trg_neuron_filter (Sequence | slice | ndarray | integer | int | None) – Filter applied to target (postsynaptic) neurons to select which synapses to record from (see Callbacks and recording for more information).
- get_data()
- on_batch_begin(batch)
- Parameters:
batch (int)
- on_timestep_end(timestep)
- Parameters:
timestep (int)
- set_params(data, compiled_network, **kwargs)
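One plausible reading of src_neuron_filter and trg_neuron_filter is sketched below in Python. It assumes a dense connection stored as weights[src][trg], with each filter selecting indices along one axis, so that a synapse is kept only when both its source and its target pass. The helpers to_indices and select_synapses are hypothetical. They handle only None, int, slice and list filters, which is a subset of the documented types, and the semantics are an assumption rather than mlGeNN's documented behaviour.

```python
# Hypothetical sketch: assumes a dense connection stored as weights[src][trg]
# and that each neuron filter picks indices along one axis. The filter
# handling (None, int, slice, list) is an illustrative assumption.
def to_indices(neuron_filter, size):
    """Turn a simple filter into an explicit list of neuron indices."""
    if neuron_filter is None:
        return list(range(size))        # no filter: every neuron
    if isinstance(neuron_filter, int):
        return [neuron_filter]          # a single neuron
    if isinstance(neuron_filter, slice):
        return list(range(size))[neuron_filter]
    return list(neuron_filter)          # explicit sequence of indices


def select_synapses(weights, src_neuron_filter=None, trg_neuron_filter=None):
    """Keep synapses whose source AND target both pass their filters."""
    num_src = len(weights)
    num_trg = len(weights[0]) if weights else 0
    src = to_indices(src_neuron_filter, num_src)
    trg = to_indices(trg_neuron_filter, num_trg)
    return [[weights[i][j] for j in trg] for i in src]


g = [[0.1, 0.2, 0.3],
     [0.4, 0.5, 0.6],
     [0.7, 0.8, 0.9]]
print(select_synapses(g, src_neuron_filter=slice(0, 2), trg_neuron_filter=[2]))
# [[0.3], [0.6]]
```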
- class ml_genn.callbacks.CustomUpdateOnBatchBegin(name)
Bases: CustomUpdate
Callback that triggers a GeNN custom update at the beginning of every batch.
- Parameters:
name (str)
- on_batch_begin(batch)
- class ml_genn.callbacks.CustomUpdateOnBatchEnd(name)
Bases: CustomUpdate
Callback that triggers a GeNN custom update at the end of every batch.
- Parameters:
name (str)
- on_batch_end(batch, metrics)
- class ml_genn.callbacks.CustomUpdateOnTimestepBegin(name)
Bases: CustomUpdate
Callback that triggers a GeNN custom update at the beginning of every timestep.
- Parameters:
name (str)
- on_timestep_begin(timestep)
- class ml_genn.callbacks.CustomUpdateOnTimestepEnd(name)
Bases: CustomUpdate
Callback that triggers a GeNN custom update at the end of every timestep.
- Parameters:
name (str)
- on_timestep_end(timestep)
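The four CustomUpdateOn* callbacks differ only in when they fire. The Python sketch below counts how often each custom update would run during one batch of num_timesteps timesteps. The update names "Reset" and "Accumulate" are hypothetical, and the local trigger function is a placeholder for launching a GeNN custom update, not a real GeNN or mlGeNN call.

```python
from collections import Counter

# Hypothetical sketch of WHEN each CustomUpdateOn* callback would fire within
# one batch. trigger() is a placeholder for launching a GeNN custom update.
def simulate_batch(num_timesteps, updates):
    """updates maps a hook name to the (hypothetical) custom update it runs."""
    fired = Counter()

    def trigger(hook):
        if hook in updates:
            fired[updates[hook]] += 1

    trigger("on_batch_begin")          # CustomUpdateOnBatchBegin
    for _ in range(num_timesteps):
        trigger("on_timestep_begin")   # CustomUpdateOnTimestepBegin
        # ... one simulation timestep would run here ...
        trigger("on_timestep_end")     # CustomUpdateOnTimestepEnd
    trigger("on_batch_end")            # CustomUpdateOnBatchEnd
    return fired


fired = simulate_batch(
    num_timesteps=100,
    updates={"on_batch_begin": "Reset", "on_timestep_end": "Accumulate"})
print(fired)  # Counter({'Accumulate': 100, 'Reset': 1})
```

A batch-level update therefore runs once per batch, while a timestep-level update runs once per simulated timestep.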
- class ml_genn.callbacks.OptimiserParamSchedule(param_name, func)
Bases: Callback
Callback which updates a parameter on an optimisers.Optimiser every epoch based on a callable.
- Parameters:
param_name (str) – Name of parameter to update. Not all optimiser parameters can be changed at runtime
func (Callable[[int, Number], Number]) – Callable called every epoch to determine new parameter value
- on_epoch_begin(epoch)
- set_params(compiled_network, **kwargs)
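The func argument has the documented type Callable[[int, Number], Number]. The Python sketch below shows two schedule functions of that shape and emulates applying one at the start of each epoch. It assumes that the second argument is the parameter's current value and that the returned value replaces it. The apply_schedule helper is hypothetical and only stands in for what the callback does each epoch.

```python
# Hypothetical schedule functions matching Callable[[int, Number], Number].
# Assumption: func(epoch, current_value) returns the new parameter value.
def exponential_decay(epoch, value):
    """Shrink the parameter by 5% every epoch."""
    return value * 0.95


def step_decay(epoch, value):
    """Halve the parameter every 10 epochs (epoch 0 leaves it unchanged)."""
    return value * 0.5 if epoch > 0 and epoch % 10 == 0 else value


def apply_schedule(func, initial, num_epochs):
    """Emulate calling func at the start of each epoch; return the values."""
    value = initial
    history = []
    for epoch in range(num_epochs):
        value = func(epoch, value)
        history.append(value)
    return history


history = apply_schedule(step_decay, 1e-3, 21)
print(history[9], history[10], history[20])  # 0.001 0.0005 0.00025
```

With the real class, such a function would be passed as the func argument, e.g. OptimiserParamSchedule("alpha", step_decay). Here "alpha" is only an assumed parameter name, and as noted above, not every optimiser parameter can be changed at runtime.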
- class ml_genn.callbacks.SpikeRecorder(pop, key=None, example_filter=None, neuron_filter=None, record_counts=False)
Bases: Callback
Callback used for recording spikes during simulation.
- Parameters:
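pop (InputLayer | Layer | Population) – Population to record from
key – Key under which the data recorded by this callback is stored in the dictionary returned by the compiled network's evaluation/training methods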
example_filter (Sequence | ndarray | integer | int | None) – Filter used to select which examples to record from (see Callbacks and recording for more information).
neuron_filter (Sequence | slice | ndarray | integer | int | None) – Filter used to select which neurons to record from (see Callbacks and recording for more information).
record_counts (bool) – Should only the (per-neuron) spike count be recorded rather than all the spikes?
- get_data()
- on_batch_begin(batch)
- on_timestep_end(timestep)
- set_first()
- set_params(data, compiled_network, **kwargs)
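The record_counts option trades detail for memory. The Python sketch below contrasts full spike recording with count-only recording for a single example. Spikes are given as hypothetical (timestep, neuron) pairs, and both output layouts are illustrative assumptions, not the exact structures SpikeRecorder returns.

```python
# Hypothetical sketch of the two recording modes for one example.
# Spikes are (timestep, neuron) pairs; output layouts are assumptions.
def spike_times(spikes, num_neurons):
    """Full recording: the timesteps at which each neuron spiked."""
    times = [[] for _ in range(num_neurons)]
    for t, n in spikes:
        times[n].append(t)
    return times


def spike_counts(spikes, num_neurons):
    """record_counts=True style: only a per-neuron spike count."""
    counts = [0] * num_neurons
    for _, n in spikes:
        counts[n] += 1
    return counts


spikes = [(1, 0), (3, 2), (4, 0), (9, 0)]
print(spike_times(spikes, 3))   # [[1, 4, 9], [], [3]]
print(spike_counts(spikes, 3))  # [3, 0, 1]
```

Count-only recording stores one number per neuron regardless of how long the simulation runs, whereas full recording grows with the number of spikes.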
- class ml_genn.callbacks.VarRecorder(pop, var=None, key=None, example_filter=None, neuron_filter=None, genn_var=None)
Bases: Callback
Callback used for recording state variables during simulation. Variables can be specified either by the name of the mlGeNN ml_genn.utils.value.ValueDescriptor class attribute corresponding to the variable, e.g. v for the membrane voltage of an ml_genn.neurons.LeakyIntegrateFire neuron, or by the internal name of a GeNN state variable, e.g. LambdaV, which is a state variable added by ml_genn.compilers.EventPropCompiler to track gradients.
- Parameters:
pop (InputLayer | Layer | Population) – Population to record from
var (str | None) – Name of variable to record
key – Key under which the data recorded by this callback is stored in the dictionary returned by the compiled network's evaluation/training methods
example_filter (Sequence | ndarray | integer | int | None) – Filter used to select which examples to record from (see Callbacks and recording for more information).
neuron_filter (Sequence | slice | ndarray | integer | int | None) – Filter used to select which neurons to record from (see Callbacks and recording for more information).
genn_var (str | None) – Internal name of variable to record
- get_data()
- on_batch_begin(batch)
- Parameters:
batch (int)
- on_timestep_end(timestep)
- Parameters:
timestep (int)
- set_params(data, compiled_network, **kwargs)
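The recorders receive the batch index in on_batch_begin, which suggests they decide per batch which examples to keep. The Python sketch below shows one way an example_filter could be mapped onto the rows of the current batch. It assumes that the filter holds dataset-wide example indices and that every batch is full. The examples_to_record helper is hypothetical and handles only None, int and sequence filters.

```python
# Hypothetical sketch: assumes example_filter holds dataset-wide example
# indices and every batch is full. Only None, int and sequences are handled.
def examples_to_record(batch, batch_size, example_filter=None):
    """Return in-batch indices of the examples selected by example_filter."""
    if example_filter is None:
        return list(range(batch_size))  # no filter: record every example
    if isinstance(example_filter, int):
        wanted = {example_filter}
    else:
        wanted = set(example_filter)
    first = batch * batch_size  # dataset index of this batch's first example
    return [i for i in range(batch_size) if first + i in wanted]


print(examples_to_record(0, 4))                 # [0, 1, 2, 3]
print(examples_to_record(1, 4, [0, 5, 6, 9]))   # [1, 2]
print(examples_to_record(2, 4, 9))              # [1]
```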