ignite.metrics
- class ignite.metrics.BinaryAccuracy(output_transform=<function Metric.<lambda>>)
Calculates the binary accuracy.
update must receive output of the form (y_pred, y).
y_pred must have shape (batch_size, …) and its elements must be between 0 and 1.
y must have shape (batch_size, …).
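To make the accumulate-then-compute behaviour concrete, here is a minimal plain-Python sketch of binary accuracy. It is not ignite's implementation (which works on torch tensors). The class name BinaryAccuracySketch is hypothetical, inputs are flat lists of shape (batch_size,), and thresholding probabilities at 0.5 is an assumption of the sketch.

```python
class BinaryAccuracySketch:
    """Hypothetical plain-Python illustration of binary accuracy (not ignite's code)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # State accumulated over all batches of an epoch.
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output  # y_pred: probabilities in [0, 1]; y: 0/1 labels
        for prob, target in zip(y_pred, y):
            predicted = 1 if prob > 0.5 else 0  # assumed threshold of 0.5
            self._num_correct += int(predicted == target)
            self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            # RuntimeError stands in for ignite's NotComputableError.
            raise RuntimeError("accuracy needs at least one example")
        return self._num_correct / self._num_examples


metric = BinaryAccuracySketch()
metric.update(([0.9, 0.2, 0.6], [1, 0, 0]))  # 2 of 3 correct
metric.update(([0.1], [0]))                  # 1 of 1 correct
print(metric.compute())  # 0.75
```

Accuracy is pooled over every example seen since the last reset, not averaged per batch.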
- class ignite.metrics.CategoricalAccuracy(output_transform=<function Metric.<lambda>>)
Calculates the categorical accuracy.
update must receive output of the form (y_pred, y).
y_pred must have shape (batch_size, num_categories, …).
y must have shape (batch_size, …).
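A plain-Python sketch of categorical accuracy, assuming the predicted class is the argmax over the num_categories dimension of y_pred and that y holds class indices. For brevity it only handles y_pred of shape (batch_size, num_categories) with no trailing dimensions; CategoricalAccuracySketch is a hypothetical name, not ignite's code.

```python
class CategoricalAccuracySketch:
    """Hypothetical plain-Python illustration of categorical accuracy (not ignite's code)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output  # y_pred: one list of class scores per example; y: class indices
        for scores, target in zip(y_pred, y):
            # Predicted class = index of the highest score (argmax over categories).
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self._num_correct += int(predicted == target)
            self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise RuntimeError("accuracy needs at least one example")
        return self._num_correct / self._num_examples


metric = CategoricalAccuracySketch()
metric.update(([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 2]))  # first correct, second wrong
print(metric.compute())  # 0.5
```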
- class ignite.metrics.Loss(loss_fn, output_transform=<function Loss.<lambda>>)
Calculates the average loss according to the passed loss_fn.
loss_fn must return the average loss over all observations in the batch.
update must receive output of the form (y_pred, y).
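The requirement that loss_fn return a batch average matters when batches differ in size. The sketch below shows one way to turn per-batch averages into an exact per-example mean: multiply each average by its batch size before summing, then divide once by the total number of examples. LossSketch and mean_squared_loss are hypothetical names; this illustrates the arithmetic rather than reproducing ignite's code.

```python
def mean_squared_loss(y_pred, y):
    """Hypothetical loss_fn: returns the average loss over the batch."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y)) / len(y)


class LossSketch:
    """Hypothetical illustration of averaging a batch-mean loss over an epoch."""

    def __init__(self, loss_fn):
        self._loss_fn = loss_fn
        self.reset()

    def reset(self):
        self._sum = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output
        batch_size = len(y)
        # Undo the per-batch averaging so every example carries equal weight.
        self._sum += self._loss_fn(y_pred, y) * batch_size
        self._num_examples += batch_size

    def compute(self):
        if self._num_examples == 0:
            raise RuntimeError("loss needs at least one example")
        return self._sum / self._num_examples


metric = LossSketch(mean_squared_loss)
metric.update(([1.0, 2.0], [1.0, 0.0]))  # batch average 2.0 over 2 examples
metric.update(([3.0], [2.0]))            # batch average 1.0 over 1 example
print(metric.compute())  # 5/3 ≈ 1.667, not the mean of batch averages (1.5)
```

If loss_fn returned a batch sum instead of a mean, the multiplication by batch_size would double-count and the result would be wrong.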
- class ignite.metrics.MeanAbsoluteError(output_transform=<function Metric.<lambda>>)
Calculates the mean absolute error.
update must receive output of the form (y_pred, y).
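A plain-Python sketch of mean absolute error, assuming one scalar prediction and one scalar target per example (shape (batch_size,)). Absolute errors are summed across batches and divided by the total example count at compute time. MeanAbsoluteErrorSketch is a hypothetical name, not ignite's code.

```python
class MeanAbsoluteErrorSketch:
    """Hypothetical plain-Python illustration of mean absolute error (not ignite's code)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._sum_of_absolute_errors = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output  # one scalar prediction and target per example
        self._sum_of_absolute_errors += sum(abs(p - t) for p, t in zip(y_pred, y))
        self._num_examples += len(y)

    def compute(self):
        if self._num_examples == 0:
            raise RuntimeError("mean absolute error needs at least one example")
        return self._sum_of_absolute_errors / self._num_examples


metric = MeanAbsoluteErrorSketch()
metric.update(([2.5, 0.0, 2.0], [3.0, -0.5, 1.0]))  # absolute errors 0.5, 0.5, 1.0
metric.update(([4.0], [4.0]))                      # absolute error 0.0
print(metric.compute())  # 0.5
```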
- class ignite.metrics.MeanPairwiseDistance(p=2, eps=1e-06, output_transform=<function MeanPairwiseDistance.<lambda>>)
Calculates the mean pairwise distance.
update must receive output of the form (y_pred, y).
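A plain-Python sketch of the mean pairwise distance. It assumes each row of y_pred and y is one example's vector, and that the distance follows the convention of PyTorch's pairwise_distance: the p-norm of (y_pred − y + eps), averaged over examples. That convention is an assumption here, as is the hypothetical class name.

```python
class MeanPairwiseDistanceSketch:
    """Hypothetical plain-Python illustration of mean pairwise distance (not ignite's code)."""

    def __init__(self, p=2, eps=1e-6):
        self._p = p
        self._eps = eps
        self.reset()

    def reset(self):
        self._sum_of_distances = 0.0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output  # one vector per example in each
        for pred_row, true_row in zip(y_pred, y):
            # Assumed convention: eps is added to the difference before the p-norm.
            diff = [a - b + self._eps for a, b in zip(pred_row, true_row)]
            self._sum_of_distances += sum(abs(d) ** self._p for d in diff) ** (1.0 / self._p)
            self._num_examples += 1

    def compute(self):
        if self._num_examples == 0:
            raise RuntimeError("mean pairwise distance needs at least one example")
        return self._sum_of_distances / self._num_examples


metric = MeanPairwiseDistanceSketch(p=2, eps=0.0)  # eps=0.0 keeps the arithmetic exact
metric.update(([[3.0, 4.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]))
print(metric.compute())  # (5.0 + 0.0) / 2 = 2.5
```

The small eps (1e-06 by default in the signature above) keeps the norm away from exactly zero when a prediction matches its target.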
- class ignite.metrics.MeanSquaredError(output_transform=<function Metric.<lambda>>)
Calculates the mean squared error.
update must receive output of the form (y_pred, y).
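The same pooling idea, written as a single function over one epoch's batches: squared errors are summed across all batches and divided once by the total number of examples. epoch_mean_squared_error is a hypothetical helper, and one scalar target per example is assumed.

```python
def epoch_mean_squared_error(batches):
    """Hypothetical helper: MSE pooled over every (y_pred, y) batch of an epoch."""
    sum_of_squared_errors = 0.0
    num_examples = 0
    for y_pred, y in batches:
        sum_of_squared_errors += sum((p - t) ** 2 for p, t in zip(y_pred, y))
        num_examples += len(y)
    if num_examples == 0:
        raise RuntimeError("mean squared error needs at least one example")
    return sum_of_squared_errors / num_examples


batches = [
    ([1.0, 3.0], [0.0, 0.0]),  # squared errors 1.0, 9.0
    ([2.0, 2.0], [2.0, 0.0]),  # squared errors 0.0, 4.0
]
print(epoch_mean_squared_error(batches))  # 14 / 4 = 3.5
```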
- class ignite.metrics.Metric(output_transform=<function Metric.<lambda>>)
- abstract compute()
Computes the metric based on its accumulated state.
This is called at the end of each epoch.
- Returns
the actual quantity of interest
- Return type
Any
- Raises
NotComputableError – raised when the metric cannot be computed
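The Metric entry documents only compute, but an accumulate-then-compute metric also needs a way to clear its state and to fold in each batch. The sketch below mirrors that pattern in plain Python with a stand-in abstract base class. The names MetricSketch, MaxAbsoluteError and process, and the separate reset and update hooks, are assumptions made for illustration rather than a statement of ignite's API; NotComputableError is redefined locally as a stand-in.

```python
from abc import ABC, abstractmethod


class NotComputableError(RuntimeError):
    """Local stand-in for ignite's NotComputableError."""


class MetricSketch(ABC):
    """Hypothetical mirror of the pattern: reset, update per batch, compute at epoch end."""

    def __init__(self, output_transform=lambda x: x):
        self._output_transform = output_transform
        self.reset()

    @abstractmethod
    def reset(self):
        """Clear the accumulated state."""

    @abstractmethod
    def update(self, output):
        """Accumulate state from one batch's (y_pred, y)."""

    @abstractmethod
    def compute(self):
        """Return the metric computed from the accumulated state."""

    def process(self, raw_output):
        # Hypothetical helper: apply output_transform, then update.
        self.update(self._output_transform(raw_output))


class MaxAbsoluteError(MetricSketch):
    """Hypothetical custom metric: the largest absolute error seen since reset."""

    def reset(self):
        self._max = None

    def update(self, output):
        y_pred, y = output
        batch_max = max(abs(p - t) for p, t in zip(y_pred, y))
        self._max = batch_max if self._max is None else max(self._max, batch_max)

    def compute(self):
        if self._max is None:
            raise NotComputableError("MaxAbsoluteError needs at least one example")
        return self._max


# output_transform extracts (y_pred, y) from a richer, dict-shaped output.
metric = MaxAbsoluteError(output_transform=lambda out: (out["y_pred"], out["y"]))
metric.process({"y_pred": [1.0, 2.0], "y": [1.5, 2.0], "loss": 0.1})
metric.process({"y_pred": [0.0], "y": [2.0], "loss": 0.3})
print(metric.compute())  # 2.0
```

Raising from compute when nothing has been accumulated matches the NotComputableError contract documented above, instead of silently returning a meaningless value.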
- class ignite.metrics.Precision(average=False, output_transform=<function Precision.<lambda>>)
Calculates precision.
update must receive output of the form (y_pred, y).
If average is True, returns the unweighted average across all classes. Otherwise, returns a tensor with the precision for each class.
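A plain-Python sketch of multi-class precision with the average switch described above. It assumes y_pred holds per-class scores (the predicted class is the argmax) and y holds class indices. Precision for class c is the number of correct predictions of c divided by the number of predictions of c. Returning 0.0 for a class that was never predicted is a choice made for this sketch, and PrecisionSketch is a hypothetical name.

```python
class PrecisionSketch:
    """Hypothetical plain-Python illustration of multi-class precision (not ignite's code)."""

    def __init__(self, average=False):
        self._average = average
        self.reset()

    def reset(self):
        self._true_positives = None       # per class: predicted as c and truly c
        self._predicted_positives = None  # per class: predicted as c

    def update(self, output):
        y_pred, y = output  # y_pred: class scores per example; y: class indices
        for scores, target in zip(y_pred, y):
            if self._true_positives is None:
                # The number of classes is read from the width of y_pred.
                self._true_positives = [0] * len(scores)
                self._predicted_positives = [0] * len(scores)
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self._predicted_positives[predicted] += 1
            self._true_positives[predicted] += int(predicted == target)

    def compute(self):
        if self._true_positives is None:
            raise RuntimeError("precision needs at least one example")
        # Classes that were never predicted get 0.0 (a choice made for this sketch).
        per_class = [
            tp / pp if pp else 0.0
            for tp, pp in zip(self._true_positives, self._predicted_positives)
        ]
        if self._average:
            # Unweighted average: every class counts equally.
            return sum(per_class) / len(per_class)
        return per_class


scores = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.6, 0.3, 0.1], [0.1, 0.2, 0.7]]
targets = [0, 1, 1, 2]  # predictions are 0, 1, 0, 2
metric = PrecisionSketch()
metric.update((scores, targets))
print(metric.compute())  # [0.5, 1.0, 1.0]
```

With average=True the same data gives (0.5 + 1.0 + 1.0) / 3 ≈ 0.833.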
- class ignite.metrics.Recall(average=False, output_transform=<function Recall.<lambda>>)
Calculates recall.
update must receive output of the form (y_pred, y).
If average is True, returns the unweighted average across all classes. Otherwise, returns a tensor with the recall for each class.
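Compared with Precision, only the denominator changes: recall for class c divides the correct predictions of c by the number of examples whose target is c. The sketch below computes it directly from already-decided class indices rather than from scores. The function recall and its num_classes parameter are hypothetical, and returning 0.0 for a class absent from the targets is a choice made for this sketch.

```python
def recall(predicted, targets, num_classes, average=False):
    """Hypothetical helper: per-class recall from predicted and true class indices."""
    true_positives = [0] * num_classes
    actual_positives = [0] * num_classes
    for pred, target in zip(predicted, targets):
        actual_positives[target] += 1
        if pred == target:
            true_positives[target] += 1
    # Classes with no examples in targets get 0.0 (a choice made for this sketch).
    per_class = [tp / ap if ap else 0.0 for tp, ap in zip(true_positives, actual_positives)]
    if average:
        # Unweighted average: every class counts equally.
        return sum(per_class) / num_classes
    return per_class


predicted = [0, 1, 0, 2]
targets = [0, 1, 1, 2]
print(recall(predicted, targets, num_classes=3))                # [1.0, 0.5, 1.0]
print(recall(predicted, targets, num_classes=3, average=True))  # ≈ 0.833
```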