ignite.metrics#
Metrics provide a way to compute various quantities of interest in an online fashion without having to store the entire output history of a model.
In practice, a user needs to attach the metric instance to an engine. The metric value is then computed using the output of the engine’s process_function:
```python
def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy()
metric.attach(engine, "accuracy")
```
If the engine’s output is not in the format (y_pred, y), the user can use the output_transform argument to transform it:
```python
def process_function(engine, batch):
    # ...
    return {'y_pred': y_pred, 'y_true': y, ...}

engine = Engine(process_function)

def output_transform(output):
    # `output` variable is returned by above `process_function`
    y_pred = output['y_pred']
    y = output['y_true']
    return y_pred, y  # output format is according to `Accuracy` docs

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")
```
Metrics can be combined through arithmetic operations to form new metrics, for example:
```python
precision = Precision(average=False)
recall = Recall(average=False)
F1 = precision * recall * 2 / (precision + recall)
```

Note
This example computes F1 for each class separately, rather than the mean F1 across classes. To combine precision and recall into F1 or another F metric, be careful to pass average=False, i.e. to use the unaveraged precision and recall; otherwise the result will not be a valid F metric.
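As a sketch of how such a combined metric is used (the evaluation step below is illustrative), attaching the arithmetic expression behaves like attaching any other metric:

```python
from ignite.engine import Engine
from ignite.metrics import Precision, Recall

def eval_step(engine, batch):
    # Hypothetical evaluation step: the batch already holds (y_pred, y).
    y_pred, y = batch
    return y_pred, y

evaluator = Engine(eval_step)

precision = Precision(average=False)
recall = Recall(average=False)
F1 = precision * recall * 2 / (precision + recall)  # per-class F1 tensor

# Attaching the combined metric also attaches precision and recall.
F1.attach(evaluator, "f1")
```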
- class ignite.metrics.Accuracy(output_transform=<function _BaseClassification.<lambda>>)[source]#
Calculates the accuracy for binary and multiclass data.
- update must receive output of the form (y_pred, y).
- y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).
- y must have shape (batch_size, …).
In the binary case, when y has 0 or 1 values, the elements of y_pred must be between 0 and 1.
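For illustration, a minimal sketch of both input formats; the tensors and expected values are assumptions, with binary predictions thresholded at 0.5:

```python
import torch
from ignite.metrics import Accuracy

acc = Accuracy()

# Multiclass: y_pred holds per-class scores, y holds class indices.
y_pred = torch.tensor([[0.1, 0.8, 0.1],
                       [0.6, 0.3, 0.1]])  # (batch_size, num_categories)
y = torch.tensor([1, 0])                  # (batch_size,)
acc.update((y_pred, y))
print(acc.compute())  # 1.0

# Binary: y holds 0/1 labels, y_pred holds probabilities in [0, 1].
acc.reset()
acc.update((torch.tensor([0.9, 0.2, 0.7]), torch.tensor([1, 0, 0])))
print(acc.compute())  # 2/3: the 0.7 prediction rounds to 1 but the label is 0
```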
- class ignite.metrics.BinaryAccuracy(*args, **kwargs)[source]#
Note: This metric is deprecated in favor of Accuracy.
- class ignite.metrics.CategoricalAccuracy(*args, **kwargs)[source]#
Note: This metric is deprecated in favor of Accuracy.
- class ignite.metrics.Loss(loss_fn, output_transform=<function Loss.<lambda>>, batch_size=<function Loss.<lambda>>)[source]#
Calculates the average loss according to the passed loss_fn.
- Parameters
loss_fn (callable) – a callable that takes a prediction tensor, a target tensor, and optionally other arguments, and returns the average loss over all observations in the batch.
output_transform (callable) – a callable that is used to transform the ignite.engine.Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. The output is expected to be a tuple (prediction, target) or (prediction, target, kwargs), where kwargs is a dictionary of extra keyword arguments.
batch_size (callable) – a callable taking a target tensor that returns the first dimension size (usually the batch size).
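For instance, a sketch of attaching Loss with a standard PyTorch criterion (the model and evaluation step are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F
from ignite.engine import Engine
from ignite.metrics import Loss

model = nn.Linear(10, 3)  # illustrative model

def eval_step(engine, batch):
    # Hypothetical evaluation step: run the model and return (y_pred, y).
    x, y = batch
    return model(x), y

evaluator = Engine(eval_step)

# F.cross_entropy returns the batch-averaged loss, as loss_fn requires.
Loss(F.cross_entropy).attach(evaluator, "nll")
```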
- class ignite.metrics.MeanAbsoluteError(output_transform=<function Metric.<lambda>>)[source]#
Calculates the mean absolute error.
update must receive output of the form (y_pred, y).
- class ignite.metrics.MeanPairwiseDistance(p=2, eps=1e-06, output_transform=<function MeanPairwiseDistance.<lambda>>)[source]#
Calculates the mean pairwise distance.
update must receive output of the form (y_pred, y).
- class ignite.metrics.MeanSquaredError(output_transform=<function Metric.<lambda>>)[source]#
Calculates the mean squared error.
update must receive output of the form (y_pred, y).
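These regression metrics share the same (y_pred, y) interface; a minimal sketch with illustrative tensors:

```python
import torch
from ignite.metrics import MeanAbsoluteError

mae = MeanAbsoluteError()
mae.update((torch.tensor([2.0, 3.0]), torch.tensor([1.0, 5.0])))
print(mae.compute())  # (|2 - 1| + |3 - 5|) / 2 = 1.5
```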
- class ignite.metrics.Metric(output_transform=<function Metric.<lambda>>)[source]#
Base class for all Metrics.
- Parameters
output_transform (callable, optional) – a callable that is used to transform the ignite.engine.Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
- abstract compute()[source]#
Computes the metric based on its accumulated state.
This is called at the end of each epoch.
- Returns
the actual quantity of interest
- Return type
Any
- Raises
NotComputableError – raised when the metric cannot be computed
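A minimal sketch of a custom metric built on this base class; the metric itself is a toy example:

```python
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric

class MeanPrediction(Metric):
    """Toy metric: the mean of all y_pred values seen this epoch."""

    def reset(self):
        self._sum = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output
        self._sum += torch.sum(y_pred).item()
        self._num_examples += y_pred.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError("MeanPrediction must have at least one example before it can be computed.")
        return self._sum / self._num_examples
```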
- class ignite.metrics.Precision(output_transform=<function _BasePrecisionRecall.<lambda>>, average=False)[source]#
Calculates precision for binary and multiclass data.
- update must receive output of the form (y_pred, y).
- y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).
- y must have shape (batch_size, …).
In the binary case, when y has 0 or 1 values, the elements of y_pred must be between 0 and 1. Precision is computed over the positive class, which is assumed to be 1.
- Parameters
average (bool, optional) – if True, precision is computed as the unweighted average (across all classes in the multiclass case); otherwise, a tensor with the precision (for each class in the multiclass case) is returned.
- class ignite.metrics.Recall(output_transform=<function _BasePrecisionRecall.<lambda>>, average=False)[source]#
Calculates recall for binary and multiclass data.
- update must receive output of the form (y_pred, y).
- y_pred must have shape (batch_size, num_categories, …) or (batch_size, …).
- y must have shape (batch_size, …).
In the binary case, when y has 0 or 1 values, the elements of y_pred must be between 0 and 1. Recall is computed over the positive class, which is assumed to be 1.
- Parameters
average (bool, optional) – if True, recall is computed as the unweighted average (across all classes in the multiclass case); otherwise, a tensor with the recall (for each class in the multiclass case) is returned.
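A sketch of the two average modes, using illustrative tensors (Precision behaves analogously):

```python
import torch
from ignite.metrics import Recall

y_pred = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.6, 0.2, 0.2]])  # predicted classes: 0, 1, 0
y = torch.tensor([0, 1, 2])

per_class = Recall(average=False)
per_class.update((y_pred, y))
print(per_class.compute())  # per-class recall tensor: [1.0, 1.0, 0.0]

averaged = Recall(average=True)
averaged.update((y_pred, y))
print(averaged.compute())  # unweighted mean: ~0.667
```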
- class ignite.metrics.RootMeanSquaredError(output_transform=<function Metric.<lambda>>)[source]#
Calculates the root mean squared error.
update must receive output of the form (y_pred, y).
- class ignite.metrics.TopKCategoricalAccuracy(k=5, output_transform=<function TopKCategoricalAccuracy.<lambda>>)[source]#
Calculates the top-k categorical accuracy.
update must receive output of the form (y_pred, y).
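A minimal sketch with k=2 and illustrative scores:

```python
import torch
from ignite.metrics import TopKCategoricalAccuracy

top2 = TopKCategoricalAccuracy(k=2)
y_pred = torch.tensor([[0.1, 0.5, 0.4],
                       [0.6, 0.3, 0.1]])
y = torch.tensor([2, 2])
top2.update((y_pred, y))
# Sample 0: top-2 classes are {1, 2} -> hit; sample 1: top-2 are {0, 1} -> miss.
print(top2.compute())  # 0.5
```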
- class ignite.metrics.EpochMetric(compute_fn, output_transform=<function EpochMetric.<lambda>>)[source]#
Class for metrics that should be computed on the entire output history of a model. Model’s output and targets are restricted to be of shape (batch_size, n_classes). Output datatype should be float32. Target datatype should be long.
Warning
The current implementation stores all input data (output and target) as tensors before computing the metric. This can potentially lead to a memory error if the input data is larger than available RAM.
update must receive output of the form (y_pred, y).
If the target shape is (batch_size, n_classes) and n_classes > 1, then it should be binary: e.g. [[0, 1, 0, 1], ]
- Parameters
compute_fn (callable) – a callable with the signature (torch.tensor, torch.tensor) that takes predictions and targets as input and returns a scalar.
output_transform (callable, optional) – a callable that is used to transform the ignite.engine.Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
- class ignite.metrics.RunningAverage(src=None, alpha=0.98, output_transform=None)[source]#
Computes the running average of a metric or of the output of the process function.
- Parameters
src (Metric or None) – input source: an instance of ignite.metrics.Metric or None. The latter corresponds to engine.state.output, which holds the output of the process function.
alpha (float, optional) – running average decay factor, default 0.98
output_transform (callable, optional) – a function used to transform engine.state.output if src is None; it then corresponds to the output of the process function. Otherwise it should be None.
Examples:
```python
alpha = 0.98
acc_metric = RunningAverage(CategoricalAccuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
acc_metric.attach(trainer, 'running_avg_accuracy')

avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
avg_output.attach(trainer, 'running_avg_loss')

@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
    print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
    print("running avg loss:", engine.state.metrics['running_avg_loss'])
```
- class ignite.metrics.MetricsLambda(f, *args)[source]#
Apply a function to other metrics to obtain a new metric. The result of the new metric is defined to be the result of applying the function to the result of argument metrics.
On update, this metric does not recursively update the metrics it depends on. On reset, all of its dependency metrics are reset. On attach, all of its dependencies are automatically attached.
- Parameters
f (callable) – the function that defines the computation
args (sequence) – sequence of other metrics or anything else that will be fed to f as arguments.
Examples
```python
>>> precision = Precision(average=False)
>>> recall = Recall(average=False)
>>> def Fbeta(r, p, beta):
...     return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()
>>> F1 = MetricsLambda(Fbeta, recall, precision, 1)
>>> F2 = MetricsLambda(Fbeta, recall, precision, 2)
>>> F3 = MetricsLambda(Fbeta, recall, precision, 3)
>>> F4 = MetricsLambda(Fbeta, recall, precision, 4)
```