Inference

mlGeNN can run trained networks efficiently for inference. The core of this functionality is provided by compilers.InferenceCompiler:

class ml_genn.compilers.InferenceCompiler(evaluate_timesteps, dt=1.0, batch_size=1, rng_seed=0, kernel_profiling=False, prefer_in_memory_connect=True, reset_time_between_batches=True, reset_vars_between_batches=True, reset_in_syn_between_batches=False, communicator=None, **genn_kwargs)

Compiler for performing inference on trained networks

Parameters:
  • evaluate_timesteps (int) – How many timesteps each example will be presented to the network for

  • dt (float) – Simulation timestep [ms]

  • batch_size (int) – What batch size should be used for inference?

  • rng_seed (int) – What value should GeNN’s GPU RNG be seeded with? This is used for all GPU randomness e.g. weight initialisation and Poisson spike train generation

  • kernel_profiling (bool) – Should GeNN record the time spent in each GPU kernel? These values can be extracted directly from the GeNN model which can be accessed via the genn_model property of the compiled model.

  • prefer_in_memory_connect (bool) – Should in-memory connectivity strategies such as TOEPLITZ be used rather than converting all connectivity into matrices?

  • reset_time_between_batches (bool) – Should time be reset to zero at the start of each example or allowed to run continuously?

  • reset_vars_between_batches (bool) – Should neuron variables be reset to their initial values at the start of each example or allowed to run continuously?

  • reset_in_syn_between_batches (bool) – Should synaptic input variables be reset to their initial values at the start of each example or allowed to run continuously?

  • communicator (Communicator) – Communicator used for inter-process communications when running across multiple GPUs.
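
As a rough guide to how the timing and batching parameters above combine, the following sketch computes the simulated time per example and the number of batches needed for a dataset. The helper names are hypothetical and not part of mlGeNN, and counting a trailing partial batch as one batch is an assumption.

```python
import math


def presentation_ms(evaluate_timesteps: int, dt: float = 1.0) -> float:
    """Simulated time per example in ms: number of timesteps times dt."""
    return evaluate_timesteps * dt


def num_batches(num_examples: int, batch_size: int = 1) -> int:
    """Batches needed to cover the dataset (assumption: a trailing
    partial batch still counts as one batch)."""
    return math.ceil(num_examples / batch_size)


# 500 timesteps at dt = 1.0 ms correspond to 500 ms of simulated time
print(presentation_ms(500, 1.0))

# 10000 examples with batch_size = 128
print(num_batches(10000, 128))
```

Halving dt while keeping the same simulated duration means doubling evaluate_timesteps.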

Here’s a basic example:

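from ml_genn.compilers import InferenceCompiler

# Create compiler (the argument is named evaluate_timesteps in the signature above)
compiler = InferenceCompiler(evaluate_timesteps=500)

# Compile network (assumes `network` is an already-built, trained network)
compiled_net = compiler.compile(network)

# Run evaluation; metrics are an argument of evaluate(), not of the compiler
with compiled_net:
    metrics, _ = compiled_net.evaluate(
        {input_layer: x_test},
        {output_layer: y_test},
        metrics="mean_square_error")

    # Assumes the returned metrics are keyed by output population
    print(f"Error: {metrics[output_layer].result}")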

The compilers.CompiledInferenceNetwork objects produced by compilers.InferenceCompiler provide methods for evaluation on datasets specified as sequences (typically numpy arrays or lists of utils.data.PreprocessedSpikes objects):

CompiledInferenceNetwork.evaluate(x, y, metrics='sparse_categorical_accuracy', callbacks=[<ml_genn.callbacks.progress_bar.BatchProgressBar object>])

Evaluate metrics on a numpy dataset

Parameters:
  • x (dict) – Dictionary of testing inputs

  • y (dict) – Dictionary of testing labels to compare predictions against

  • metrics (dict | Metric | str) – Metrics to calculate.

  • callbacks – List of callbacks to run during inference.

Alternatively, you can evaluate a model on a dataset iterator:

CompiledInferenceNetwork.evaluate_batch_iter(inputs, outputs, data, num_batches=None, metrics='sparse_categorical_accuracy', callbacks=[<ml_genn.callbacks.progress_bar.BatchProgressBar object>])

Evaluate metrics on an iterator that provides batches of a dataset

Parameters:
  • inputs – Input population(s)

  • outputs – Output population(s)

  • data (Iterator) – Iterator which produces batches of inputs and labels

  • num_batches (int | None) – Number of batches iterator will produce

  • metrics (dict | Metric | str) – Metrics to calculate.

  • callbacks – List of callbacks to run during inference.
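
A minimal sketch of an iterator that could feed evaluate_batch_iter is shown below. It assumes each item the iterator produces is an (inputs, labels) pair for one batch, which is one reading of "batches of inputs and labels" rather than a confirmed contract; batch_iter is a hypothetical helper, not an mlGeNN function.

```python
from typing import Iterator, Sequence, Tuple


def batch_iter(x: Sequence, y: Sequence,
               batch_size: int) -> Iterator[Tuple[Sequence, Sequence]]:
    """Yield consecutive (inputs, labels) slices of at most batch_size items."""
    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of examples")
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


# Small stand-in dataset: 10 examples split into batches of 4
batches = list(batch_iter(list(range(10)), list(range(10, 20)), 4))
print(len(batches))

# Hypothetical use with a compiled network (not executed here):
# metrics, _ = compiled_net.evaluate_batch_iter(
#     input_layer, output_layer,
#     batch_iter(x_test, y_test, 128),
#     num_batches=len(batches))
```

Because a generator has no length, passing num_batches lets the caller state how many batches to expect.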

Finally, raw predictions (i.e. the output of your model’s readouts) can be obtained on a dataset:

CompiledInferenceNetwork.predict(x, outputs, callbacks=[<ml_genn.callbacks.progress_bar.BatchProgressBar object>])

Generate predictions from a numpy dataset

Parameters:
  • x (dict) – Dictionary of testing inputs

  • outputs (Sequence | InputLayer | Layer | Population) – Output population(s) to extract predictions from

  • callbacks – List of callbacks to run during inference.
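
Raw predictions often need a small post-processing step before they can be compared with labels. The sketch below assumes the predictions for one output population have already been extracted as one row of per-class values per example, which may not match the exact structure predict returns; both helper names are hypothetical.

```python
from typing import List, Sequence


def argmax_labels(preds: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest value in each row (first index wins on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in preds]


def accuracy(pred_labels: Sequence[int], true_labels: Sequence[int]) -> float:
    """Fraction of predicted labels that match the true labels."""
    if len(pred_labels) != len(true_labels):
        raise ValueError("label sequences must have the same length")
    correct = sum(p == t for p, t in zip(pred_labels, true_labels))
    return correct / len(true_labels)


# Stand-in readout values for three examples and three classes
example_preds = [[0.1, 0.7, 0.2],
                 [0.9, 0.05, 0.05],
                 [0.2, 0.3, 0.5]]
labels = argmax_labels(example_preds)
print(labels, accuracy(labels, [1, 0, 1]))
```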

As in Keras, additional logic such as checkpointing and recording of state variables can be added to any of these standard inference loops using callbacks, as described in the Callbacks and recording section.