
ignite.metrics

class ignite.metrics.BinaryAccuracy(output_transform=<function Metric.<lambda>>)

Calculates the binary accuracy.

  • update must receive output of the form (y_pred, y).

  • y_pred must have shape (batch_size, …), and its elements must be between 0 and 1.

  • y must have shape (batch_size, …).
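The binary accuracy computation can be sketched in plain Python. This is an illustrative re-implementation, not Ignite's code. The hypothetical class `BinaryAccuracySketch` mirrors the reset/update/compute contract described under `Metric`. Inputs are assumed to be flattened to plain lists, and a probability above 0.5 is assumed to count as a positive prediction.

```python
class BinaryAccuracySketch:
    """Hypothetical plain-Python stand-in for BinaryAccuracy (not Ignite's code)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Counters accumulated across all batches of an epoch.
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        y_pred, y = output  # flattened probabilities in [0, 1] and 0/1 targets
        for p, target in zip(y_pred, y):
            label = 1 if p > 0.5 else 0  # assumed decision threshold of 0.5
            self._num_correct += int(label == target)
        self._num_examples += len(y)

    def compute(self):
        if self._num_examples == 0:
            raise RuntimeError("accuracy needs at least one example")
        return self._num_correct / self._num_examples


metric = BinaryAccuracySketch()
metric.update(([0.9, 0.2, 0.6], [1, 0, 0]))  # 2 of 3 correct
metric.update(([0.1], [0]))                  # 1 of 1 correct
print(metric.compute())  # 0.75
```

Because the counts are accumulated rather than the per-batch accuracies, batches of different sizes are weighted correctly.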

class ignite.metrics.CategoricalAccuracy(output_transform=<function Metric.<lambda>>)

Calculates the categorical accuracy.

  • update must receive output of the form (y_pred, y).

  • y_pred must have shape (batch_size, num_categories, …).

  • y must have shape (batch_size, …).
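A plain-Python sketch of the categorical accuracy computation, assuming the predicted class is the argmax over the categories dimension. It only handles the (batch_size, num_categories) case, not the extra trailing dimensions, and `categorical_accuracy` is a hypothetical helper rather than part of Ignite.

```python
def categorical_accuracy(batches):
    """Hypothetical sketch: accuracy over an epoch of (y_pred, y) batches.

    y_pred: one list of per-category scores per example.
    y: one integer class index per example.
    """
    correct = total = 0
    for y_pred, y in batches:
        for scores, target in zip(y_pred, y):
            # Predicted class = index of the highest score (argmax over categories).
            predicted = max(range(len(scores)), key=scores.__getitem__)
            correct += int(predicted == target)
            total += 1
    return correct / total


batches = [
    ([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 2]),  # 1 of 2 correct
    ([[0.2, 0.2, 0.6]], [2]),                      # 1 of 1 correct
]
print(categorical_accuracy(batches))  # 0.6666666666666666 (2 of 3 correct)
```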

class ignite.metrics.Loss(loss_fn, output_transform=<function Loss.<lambda>>)

Calculates the average loss according to the passed loss_fn.

  • loss_fn must return the average loss over all observations in the batch.

  • update must receive output of the form (y_pred, y).
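Because loss_fn returns a per-batch average, turning it into an epoch average requires weighting each batch by its size. The sketch below assumes the metric does this; `average_loss` and the example loss `mean_squared_loss` are hypothetical helpers written in plain Python.

```python
def average_loss(batches, loss_fn):
    """Hypothetical sketch of epoch-level averaging for a batch-mean loss_fn."""
    total = 0.0
    count = 0
    for y_pred, y in batches:
        n = len(y)
        # loss_fn returns the mean over the batch, so multiply back by the batch
        # size to recover the sum before averaging over the whole epoch.
        total += loss_fn(y_pred, y) * n
        count += n
    return total / count


def mean_squared_loss(y_pred, y):
    # Example loss_fn: average squared error over the batch.
    return sum((p - t) ** 2 for p, t in zip(y_pred, y)) / len(y)


batches = [([1.0, 2.0], [1.0, 4.0]), ([0.0], [1.0])]
print(average_loss(batches, mean_squared_loss))  # 1.6666666666666667
```

Averaging the two batch means directly would give 1.5, which over-weights the smaller batch.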

class ignite.metrics.MeanAbsoluteError(output_transform=<function Metric.<lambda>>)

Calculates the mean absolute error.

  • update must receive output of the form (y_pred, y).
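As an illustration, the mean absolute error over an epoch can be sketched in plain Python. This assumes one scalar target per example. `mean_absolute_error` is a hypothetical helper, not Ignite's implementation.

```python
def mean_absolute_error(batches):
    """Hypothetical sketch: MAE over all (y_pred, y) batches of an epoch."""
    abs_sum = 0.0
    n = 0
    for y_pred, y in batches:
        for p, t in zip(y_pred, y):
            abs_sum += abs(p - t)  # accumulate |prediction - target|
            n += 1
    return abs_sum / n


batches = [([2.5, 0.0], [3.0, -1.0]), ([2.0], [2.0])]
print(mean_absolute_error(batches))  # 0.5
```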

class ignite.metrics.MeanPairwiseDistance(p=2, eps=1e-06, output_transform=<function MeanPairwiseDistance.<lambda>>)

Calculates the mean pairwise distance.

  • update must receive output of the form (y_pred, y).
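The roles of p and eps can be illustrated with a plain-Python sketch that averages the p-norm distance between each predicted row and its target row. It assumes, as `torch.nn.functional.pairwise_distance` does, that eps is added to the element-wise difference before the norm is taken. `mean_pairwise_distance` is a hypothetical helper.

```python
def mean_pairwise_distance(batches, p=2, eps=1e-6):
    """Hypothetical sketch: mean p-norm distance between paired rows.

    Assumes eps is added to the element-wise difference before taking the
    norm, as torch.nn.functional.pairwise_distance does.
    """
    total = 0.0
    n = 0
    for y_pred, y in batches:
        for row_pred, row_true in zip(y_pred, y):
            diff = [a - b + eps for a, b in zip(row_pred, row_true)]
            total += sum(abs(d) ** p for d in diff) ** (1.0 / p)
            n += 1
    return total / n


# Distances are roughly 5.0 and 0.0, so the mean is approximately 2.5.
batches = [([[0.0, 0.0], [1.0, 1.0]], [[3.0, 4.0], [1.0, 1.0]])]
print(mean_pairwise_distance(batches))
```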

class ignite.metrics.MeanSquaredError(output_transform=<function Metric.<lambda>>)

Calculates the mean squared error.

  • update must receive output of the form (y_pred, y).
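A plain-Python sketch of the mean squared error over an epoch, assuming one scalar target per example. `mean_squared_error` is a hypothetical helper, not Ignite's implementation.

```python
def mean_squared_error(batches):
    """Hypothetical sketch: MSE over all (y_pred, y) batches of an epoch."""
    sq_sum = 0.0
    n = 0
    for y_pred, y in batches:
        for p, t in zip(y_pred, y):
            sq_sum += (p - t) ** 2  # accumulate the squared error
            n += 1
    return sq_sum / n


batches = [([1.0, 2.0], [0.0, 0.0]), ([0.0], [2.0])]
print(mean_squared_error(batches))  # 3.0
```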

class ignite.metrics.Metric(output_transform=<function Metric.<lambda>>)
abstract compute()

Computes the metric based on its accumulated state.

This is called at the end of each epoch.

Returns

the actual quantity of interest

Return type

Any

Raises

NotComputableError – raised when the metric cannot be computed

abstract reset()

Resets the metric to its initial state.

This is called at the start of each epoch.

abstract update(output)

Updates the metric’s state using the passed batch output.

This is called once for each batch.

Parameters

output – the output of the engine’s process function
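The lifecycle above can be sketched without importing Ignite so that it runs standalone. `MetricSketch`, `RunningMean`, and `run_epoch` are hypothetical names that mirror the documented contract (reset at the start of each epoch, update once per batch, compute at the end). In real code you would subclass `ignite.metrics.Metric` instead. The local `NotComputableError` stands in for the exception documented for compute(), and the sketch assumes output_transform is applied to each engine output before update() sees it.

```python
from abc import ABC, abstractmethod


class NotComputableError(RuntimeError):
    """Local stand-in for the error compute() raises when it has no data."""


class MetricSketch(ABC):
    """Hypothetical mirror of the documented Metric contract, not Ignite's class."""

    def __init__(self, output_transform=lambda x: x):
        # Assumed to be applied to each engine output before update() sees it.
        self.output_transform = output_transform
        self.reset()

    @abstractmethod
    def reset(self):
        """Return to the initial state (start of each epoch)."""

    @abstractmethod
    def update(self, output):
        """Fold one batch output into the state (once per batch)."""

    @abstractmethod
    def compute(self):
        """Produce the final value from the state (end of each epoch)."""


class RunningMean(MetricSketch):
    """Averages one scalar taken from each batch output."""

    def reset(self):
        self._sum = 0.0
        self._count = 0

    def update(self, output):
        self._sum += output
        self._count += 1

    def compute(self):
        if self._count == 0:
            raise NotComputableError("no batches seen since the last reset")
        return self._sum / self._count


def run_epoch(metric, batch_outputs):
    """Hypothetical driver reproducing the call order described above."""
    metric.reset()
    for output in batch_outputs:
        metric.update(metric.output_transform(output))
    return metric.compute()


# The process function might return a dict; output_transform picks the loss.
metric = RunningMean(output_transform=lambda out: out["loss"])
print(run_epoch(metric, [{"loss": 1.0}, {"loss": 3.0}]))  # 2.0
```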

class ignite.metrics.Precision(average=False, output_transform=<function Precision.<lambda>>)

Calculates precision.

  • update must receive output of the form (y_pred, y).

If average is True, returns the unweighted average across all classes. Otherwise, returns a tensor with the precision for each class.
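The per-class precision and the effect of average can be sketched in plain Python. The sketch assumes y_pred holds per-class scores whose argmax is the predicted class, and that y holds integer class indices. `precision` and its num_classes parameter are hypothetical. Classes that are never predicted get precision 0.0 here, which is a choice made for this sketch and not necessarily Ignite's behavior.

```python
def precision(batches, num_classes, average=False):
    """Hypothetical sketch of per-class precision from argmax predictions."""
    true_pos = [0] * num_classes
    predicted_pos = [0] * num_classes
    for y_pred, y in batches:
        for scores, target in zip(y_pred, y):
            pred = max(range(num_classes), key=scores.__getitem__)
            predicted_pos[pred] += 1
            true_pos[pred] += int(pred == target)
    # Precision = TP / (TP + FP); classes never predicted get 0.0 in this sketch.
    per_class = [tp / pp if pp else 0.0 for tp, pp in zip(true_pos, predicted_pos)]
    if average:
        # Unweighted average: every class counts equally, regardless of frequency.
        return sum(per_class) / num_classes
    return per_class


batches = [([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], [0, 1, 1])]
print(precision(batches, 2))                # [0.5, 1.0]
print(precision(batches, 2, average=True))  # 0.75
```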

class ignite.metrics.Recall(average=False, output_transform=<function Recall.<lambda>>)

Calculates recall.

  • update must receive output of the form (y_pred, y).

If average is True, returns the unweighted average across all classes. Otherwise, returns a tensor with the recall for each class.
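Recall differs from precision only in its denominator: true positives are divided by the number of examples that actually belong to each class. The sketch uses the same assumptions as for precision (argmax over per-class scores, integer targets). `recall` and num_classes are hypothetical, and classes absent from y get 0.0 by this sketch's own choice.

```python
def recall(batches, num_classes, average=False):
    """Hypothetical sketch of per-class recall from argmax predictions."""
    true_pos = [0] * num_classes
    actual_pos = [0] * num_classes
    for y_pred, y in batches:
        for scores, target in zip(y_pred, y):
            pred = max(range(num_classes), key=scores.__getitem__)
            actual_pos[target] += 1
            true_pos[target] += int(pred == target)
    # Recall = TP / (TP + FN); classes absent from y get 0.0 in this sketch.
    per_class = [tp / ap if ap else 0.0 for tp, ap in zip(true_pos, actual_pos)]
    if average:
        # Unweighted average across classes.
        return sum(per_class) / num_classes
    return per_class


batches = [([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], [0, 1, 1])]
print(recall(batches, 2))                # [1.0, 0.5]
print(recall(batches, 2, average=True))  # 0.75
```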

class ignite.metrics.RootMeanSquaredError(output_transform=<function Metric.<lambda>>)

Calculates the root mean squared error.

  • update must receive output of the form (y_pred, y).
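A plain-Python sketch of the root mean squared error, assuming one scalar target per example and taking the square root of the epoch-level mean squared error. `root_mean_squared_error` is a hypothetical helper.

```python
import math


def root_mean_squared_error(batches):
    """Hypothetical sketch: RMSE over all (y_pred, y) batches of an epoch."""
    sq_sum = 0.0
    n = 0
    for y_pred, y in batches:
        for p, t in zip(y_pred, y):
            sq_sum += (p - t) ** 2
            n += 1
    # Square root of the epoch-level mean, not an average of per-batch RMSEs.
    return math.sqrt(sq_sum / n)


batches = [([2.0, 0.0], [0.0, 2.0]), ([0.0], [2.0])]
print(root_mean_squared_error(batches))  # 2.0
```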

class ignite.metrics.TopKCategoricalAccuracy(k=5, output_transform=<function TopKCategoricalAccuracy.<lambda>>)

Calculates the top-k categorical accuracy.

  • update must receive output of the form (y_pred, y).
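In top-k accuracy, an example counts as correct when its target class is among the k highest-scoring categories. The plain-Python sketch below assumes y_pred holds per-category scores and y holds integer class indices. `top_k_categorical_accuracy` is a hypothetical helper, not Ignite's code.

```python
def top_k_categorical_accuracy(batches, k=5):
    """Hypothetical sketch: fraction of examples whose target is in the top k."""
    correct = total = 0
    for y_pred, y in batches:
        for scores, target in zip(y_pred, y):
            # Indices of the k highest scores; the stable sort keeps the lower
            # index first on ties.
            ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
            correct += int(target in ranked[:k])
            total += 1
    return correct / total


batches = [([[0.1, 0.5, 0.4], [0.6, 0.3, 0.1]], [2, 2])]
print(top_k_categorical_accuracy(batches, k=2))  # 0.5
```

With k=1 this reduces to ordinary categorical accuracy.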