ml_genn.metrics package
Metrics are used to calculate the performance of models based on labels and the predictions obtained from a model using an ml_genn.readouts.Readout.
- class ml_genn.metrics.MeanSquareError
Bases: Metric
Computes the mean squared error between labels and predictions.
- reset()
Resets the metric.
- property result
Quantity calculated by the metric.
- update(y_true, y_pred, communicator)
Updates the metric based on a batch of true and predicted values.
- Parameters:
y_true (ndarray) – ‘true’ values provided to the compiled network's evaluate/train method
y_pred (ndarray) – predicted values provided by the model's readout
communicator (Communicator | None) – communicator to use to synchronise metrics across GPUs when doing multi-GPU training.
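As a rough illustration of the reset/update/result cycle described above, the following standalone NumPy sketch accumulates squared errors across batches and averages them over every element seen. It is not mlGeNN's implementation: the class name MeanSquareErrorSketch is hypothetical, averaging over all elements is an assumption, and the communicator argument is accepted but ignored, so no multi-GPU synchronisation takes place.

```python
import numpy as np


class MeanSquareErrorSketch:
    """Hypothetical stand-alone sketch, not ml_genn's MeanSquareError.

    Averages the squared error over every element of every batch seen since
    the last reset. The communicator argument is accepted for interface
    compatibility but ignored (no multi-GPU reduction).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear running totals before a new evaluation/training epoch
        self.sum_squared_error = 0.0
        self.count = 0

    def update(self, y_true, y_pred, communicator=None):
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        # Accumulate the squared difference of every element in the batch
        self.sum_squared_error += float(np.sum((y_pred - y_true) ** 2))
        self.count += y_true.size

    @property
    def result(self):
        # No data seen yet, so there is nothing to report
        if self.count == 0:
            return None
        return self.sum_squared_error / self.count


# Usage: two batches of two values each
mse = MeanSquareErrorSketch()
mse.update(np.array([[0.0, 1.0]]), np.array([[0.5, 1.0]]))
mse.update(np.array([[1.0, 1.0]]), np.array([[1.0, 0.0]]))
print(mse.result)  # 0.3125
```

Keeping a running sum and count, rather than averaging per-batch means, keeps the result correct when batches have different sizes.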
- class ml_genn.metrics.Metric
Bases: ABC
Base class for all metrics.
- abstract reset()
Resets the metric.
- abstract property result: ndarray | None
Quantity calculated by the metric.
- abstract update(y_true, y_pred, communicator)
Updates the metric based on a batch of true and predicted values.
- Parameters:
y_true (ndarray) – ‘true’ values provided to the compiled network's evaluate/train method
y_pred (ndarray) – predicted values provided by the model's readout
communicator (Communicator | None) – communicator to use to synchronise metrics across GPUs when doing multi-GPU training.
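A custom metric would implement the three abstract members listed above. The sketch below is hypothetical: to stay self-contained it defines a local MetricStandIn ABC mirroring the documented interface instead of importing ml_genn.metrics.Metric, and ExampleCount is an invented metric used only for illustration. It also shows that a subclass missing any abstract member cannot be instantiated.

```python
from abc import ABC, abstractmethod

import numpy as np


class MetricStandIn(ABC):
    """Local stand-in mirroring the documented ml_genn.metrics.Metric
    interface (an assumption; not imported from ml_genn)."""

    @abstractmethod
    def reset(self):
        """Reset any accumulated state."""

    @property
    @abstractmethod
    def result(self):
        """Quantity calculated by the metric (ndarray or None)."""

    @abstractmethod
    def update(self, y_true, y_pred, communicator):
        """Accumulate one batch of true and predicted values."""


class ExampleCount(MetricStandIn):
    """Hypothetical metric: number of examples seen since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.count = 0

    @property
    def result(self):
        # Follow the documented "ndarray | None" contract
        return None if self.count == 0 else np.asarray(self.count)

    def update(self, y_true, y_pred, communicator):
        # communicator is ignored: no cross-GPU synchronisation in this sketch
        self.count += len(y_true)


class Incomplete(MetricStandIn):
    """Does not implement update, so instantiating it raises TypeError."""

    def reset(self):
        pass

    @property
    def result(self):
        return None
```

Because result is declared as an abstract property, subclasses must override it as a property too; a plain method would still satisfy the ABC check but break callers that read metric.result as an attribute.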
- class ml_genn.metrics.SparseCategoricalAccuracy
Bases: Metric
Computes the accuracy of predictions against labels (the fraction of examples whose highest-valued predicted class matches the label) when there are two or more label classes, specified as integers.
- reset()
Resets the metric.
- property result
Quantity calculated by the metric.
- update(y_true, y_pred, communicator)
Updates the metric based on a batch of true and predicted values.
- Parameters:
y_true (ndarray) – ‘true’ values provided to the compiled network's evaluate/train method
y_pred (ndarray) – predicted values provided by the model's readout
communicator (Communicator | None) – communicator to use to synchronise metrics across GPUs when doing multi-GPU training.
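Sparse categorical accuracy is commonly computed by taking the highest-scoring class for each example and comparing it with the integer label. The following NumPy sketch does that and accumulates counts across batches. It assumes y_pred has shape (batch, num_classes) and y_true holds integer labels of shape (batch,); the class name SparseCategoricalAccuracySketch is hypothetical, communicator is ignored, and this is not the library's implementation.

```python
import numpy as np


class SparseCategoricalAccuracySketch:
    """Hypothetical stand-alone sketch of sparse categorical accuracy.

    y_true: integer class labels, shape (batch,)
    y_pred: one score per class, shape (batch, num_classes)
    The communicator argument is accepted but ignored (no multi-GPU reduction).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, y_true, y_pred, communicator=None):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        # The predicted class is the index of the highest score in each row
        predicted = np.argmax(y_pred, axis=1)
        self.correct += int(np.sum(predicted == y_true))
        self.total += len(y_true)

    @property
    def result(self):
        if self.total == 0:
            return None
        return self.correct / self.total


# Usage: 2 of the 3 examples are classified correctly
acc = SparseCategoricalAccuracySketch()
acc.update([2, 0], [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
acc.update([1], [[0.2, 0.5, 0.3]])
print(acc.result)
```

Storing integer labels instead of one-hot vectors is what makes the metric "sparse": the label is compared directly with the argmax index.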