Inference
The inference functionality in mlGeNN allows you to run trained networks efficiently. The compilers.InferenceCompiler class provides the core functionality for performing inference on trained networks:
- class ml_genn.compilers.InferenceCompiler(evaluate_timesteps, dt=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, prefer_in_memory_connect=True, reset_time_between_batches=True, reset_vars_between_batches=True, reset_in_syn_between_batches=False, communicator=None, **genn_kwargs)
Compiler for performing inference on trained networks
- Parameters:
evaluate_timesteps (int) – How many timesteps each example will be presented to the network for
dt (float) – Simulation timestep [ms]
batch_size (int) – What batch size should be used for inference?
rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation
kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model, which can be accessed via the genn_model property of the compiled model.
prefer_in_memory_connect (bool) – Should in-memory connectivity strategies such as TOEPLITZ be used rather than converting all connectivity into matrices?
reset_time_between_batches (bool) – Should time be reset to zero at the start of each example or allowed to run continuously?
reset_vars_between_batches (bool) – Should neuron variables be reset to their initial values at the start of each example or allowed to run continuously?
reset_in_syn_between_batches (bool) – Should synaptic input variables be reset to their initial values at the start of each example or allowed to run continuously?
communicator (Communicator) – Communicator used for inter-process communications when running inference across multiple GPUs.
Here’s a basic example:

from ml_genn.compilers import InferenceCompiler

# Create compiler
compiler = InferenceCompiler(evaluate_timesteps=500)

# Compile network
compiled_net = compiler.compile(network)

# Run evaluation with a mean squared error metric
with compiled_net:
    metrics, _ = compiled_net.evaluate(
        {input_layer: x_test},
        {output_layer: y_test},
        metrics="mean_square_error")

print(f"Error: {metrics[output_layer].result}")
The compilers.CompiledInferenceNetwork objects produced by the compilers.InferenceCompiler provide methods for evaluation on datasets specified as sequences (typically numpy arrays or lists of utils.data.PreprocessedSpikes objects):
- CompiledInferenceNetwork.evaluate(x, y, metrics='sparse_categorical_accuracy', callbacks=[BatchProgressBar()])
Evaluate metrics on a numpy dataset
- Parameters:
x (dict) – Dictionary of testing inputs
y (dict) – Dictionary of testing labels to compare predictions against
metrics (dict | Metric | str) – Metrics to calculate.
callbacks – List of callbacks to run during inference.
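The metrics argument also accepts a dict, so different metrics can be calculated for different outputs. The following is a minimal sketch; it is an assumption, not confirmed by the signature above, that the dict form is keyed by output population:

# Sketch: selecting a metric per output population
# (assumption: the metrics dict is keyed by output population)
with compiled_net:
    metrics, _ = compiled_net.evaluate(
        {input_layer: x_test},
        {output_layer: y_test},
        metrics={output_layer: "mean_square_error"})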
Alternatively, you can evaluate a model on an iterator which produces batches of a dataset:
- CompiledInferenceNetwork.evaluate_batch_iter(inputs, outputs, data, num_batches=None, metrics='sparse_categorical_accuracy', callbacks=[BatchProgressBar()])
Evaluate metrics on an iterator that provides batches of a dataset
- Parameters:
inputs – Input population(s)
outputs – Output population(s)
data (Iterator) – Iterator which produces batches of inputs and labels
num_batches (int | None) – Number of batches iterator will produce
metrics (dict | Metric | str) – Metrics to calculate.
callbacks – List of callbacks to run during inference.
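Because any iterator yielding batches will do, a plain generator over numpy arrays is enough for a sketch. The batch_iterator helper below is hypothetical, and the exact (x_batch, y_batch) tuple format it yields is an assumption to check against your data pipeline:

import numpy as np

# Hypothetical helper: yield successive (x_batch, y_batch) tuples
# from a pair of numpy arrays (the tuple format is an assumption)
def batch_iterator(x, y, batch_size):
    for start in range(0, len(x), batch_size):
        yield (x[start:start + batch_size],
               y[start:start + batch_size])

batch_size = 32
with compiled_net:
    metrics, _ = compiled_net.evaluate_batch_iter(
        input_layer, output_layer,
        batch_iterator(x_test, y_test, batch_size),
        num_batches=int(np.ceil(len(x_test) / batch_size)))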
Finally, raw predictions (i.e. the output of your model’s readouts) can be obtained on a dataset:
- CompiledInferenceNetwork.predict(x, outputs, callbacks=[BatchProgressBar()])
Generate predictions from a numpy dataset
- Parameters:
x (dict) – Dictionary of testing inputs
outputs (Sequence | InputLayer | Layer | Population) – Output population(s) to extract predictions from
callbacks – List of callbacks to run during inference.
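For example, a minimal sketch, assuming predict returns the predictions alongside callback data and that the predictions are keyed by output population:

# Obtain raw readout values for the test set
with compiled_net:
    predictions, _ = compiled_net.predict(
        {input_layer: x_test}, output_layer)

# Assumption: predictions are keyed by output population
y_pred = predictions[output_layer]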
Like in Keras, additional logic such as checkpointing and recording of state variables can be added to any of these standard inference loops using callbacks as described in the Callbacks and recording section.
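For example, spikes from a hidden population can be recorded while evaluating. This is a sketch assuming a population named hidden and the SpikeRecorder callback; the "hidden_spikes" key and the (spike_times, spike_ids) unpacking follow the key argument given to the recorder:

from ml_genn.callbacks import SpikeRecorder

with compiled_net:
    metrics, cb_data = compiled_net.evaluate(
        {input_layer: x_test},
        {output_layer: y_test},
        callbacks=["batch_progress_bar",
                   SpikeRecorder(hidden, key="hidden_spikes")])

# Recorded spikes are returned in the callback data dictionary
# under the key passed to the recorder
spike_times, spike_ids = cb_data["hidden_spikes"]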