ignite.engine

class ignite.engine.Engine(process_function)

Runs a given process_function over each batch of a dataset, emitting events as it goes.

Parameters

process_function (callable) – A function that receives a handle to the engine and the current batch in each iteration, and returns data to be stored in the engine’s state.

Example usage:

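from ignite.engine import Engine

# Assumes model, optimizer, loss_fn and data_loader are already defined.
def train_and_store_loss(engine, batch):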
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

engine = Engine(train_and_store_loss)
engine.run(data_loader)

# Loss value is now stored in `engine.state.output`.
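A minimal, dependency-free sketch of the loop described above. It assumes a hypothetical ToyEngine class (not part of ignite, and far simpler than Engine) that only models how process_function is called per batch and how its return value lands in state.output. In this sketch the output is overwritten on every iteration, so after run() it reflects only the last batch:

```python
class ToyState:
    """Hypothetical stand-in for ignite's State: a plain attribute bag."""
    def __init__(self):
        self.epoch = 0
        self.iteration = 0
        self.output = None


class ToyEngine:
    """Hypothetical, minimal model of the Engine loop (no events)."""
    def __init__(self, process_function):
        self._process_function = process_function
        self.state = ToyState()

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.state.epoch += 1
            for batch in data:
                self.state.iteration += 1
                # Each call overwrites the previous iteration's output.
                self.state.output = self._process_function(self, batch)
        return self.state


def double_batch_sum(engine, batch):
    # Stand-in for a training step: returns one number per batch.
    return 2 * sum(batch)


engine = ToyEngine(double_batch_sum)
state = engine.run([[1, 2], [3, 4], [5]])
print(state.output)     # 10: only the last batch [5] is reflected
print(state.iteration)  # 3
```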

add_event_handler(event_name, handler, *args, **kwargs)

Add an event handler to be executed when the specified event is fired.

Parameters
  • event_name – An event to attach the handler to. Valid events are from Events or any event_name added by register_events().

  • handler (callable) – the callable event handler that should be invoked

  • *args – optional args to be passed to handler.

  • **kwargs – optional keyword args to be passed to handler.

Note

The handler function’s first argument will be the Engine object it is attached to.

The engine may also pass other arguments to the handler in addition to the *args and **kwargs given here; for example, handlers for EXCEPTION_RAISED also receive the raised exception.

Returns

RemovableEventHandle, which can be used to remove the handler.

Example usage:

from ignite.engine import Engine, Events

engine = Engine(process_function)

def print_epoch(engine):
    print("Epoch: {}".format(engine.state.epoch))

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
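The calling convention above (engine first, then the *args and **kwargs given at registration) can be modelled with a small pure-Python registry. ToyEngine and log_message are hypothetical names used only for illustration; with ignite the registration would read engine.add_event_handler(Events.EPOCH_COMPLETED, log_message, "train", sep=" | "):

```python
from collections import defaultdict


class ToyEngine:
    """Hypothetical sketch of handler registration (not ignite's code)."""
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        # Store the extra arguments alongside the handler.
        self._handlers[event_name].append((handler, args, kwargs))

    def fire_event(self, event_name):
        for handler, args, kwargs in self._handlers[event_name]:
            # The engine is always passed first, then the bound args/kwargs.
            handler(self, *args, **kwargs)


log = []

def log_message(engine, prefix, sep=": "):
    log.append(prefix + sep + "epoch done")

engine = ToyEngine()
engine.add_event_handler("epoch_completed", log_message, "train", sep=" | ")
engine.fire_event("epoch_completed")
print(log)  # ['train | epoch done']
```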

fire_event(event_name)

Execute all the handlers associated with given event.

This method executes all handlers associated with the event event_name. This is the method used in run() to call the core events found in Events.

Custom events can be fired only if they have previously been registered with register_events(). The engine state attribute should be used to exchange “dynamic” data between process_function and handlers.

This method is called automatically for core events. If no custom events are used in the engine, there is no need for the user to call the method.

Parameters

event_name – event for which the handlers should be executed. Valid events are from Events or any event_name added by register_events().
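To illustrate firing a registered custom event from inside process_function and sharing data through the state, here is a hypothetical pure-Python model. ToyEngine, the "backward_done" event and the grad_norm attribute are all invented for this sketch. In this sketch, attaching a handler to an unregistered event raises ValueError:

```python
from collections import defaultdict
from types import SimpleNamespace


class ToyEngine:
    """Hypothetical sketch: custom events must be registered before use."""
    def __init__(self, process_function):
        self._process_function = process_function
        self._allowed = {"iteration_completed"}  # stand-in for core events
        self._handlers = defaultdict(list)
        self.state = SimpleNamespace(output=None, grad_norm=None)

    def register_events(self, *event_names):
        self._allowed.update(event_names)

    def add_event_handler(self, event_name, handler):
        if event_name not in self._allowed:
            raise ValueError("unregistered event: {!r}".format(event_name))
        self._handlers[event_name].append(handler)

    def fire_event(self, event_name):
        for handler in self._handlers[event_name]:
            handler(self)

    def run(self, data):
        for batch in data:
            self.state.output = self._process_function(self, batch)
            self.fire_event("iteration_completed")  # core event, fired by run()
        return self.state


def step(engine, batch):
    # Share "dynamic" data through engine.state, then fire the custom event.
    engine.state.grad_norm = max(batch)
    engine.fire_event("backward_done")
    return sum(batch)


seen = []
engine = ToyEngine(step)
engine.register_events("backward_done")
engine.add_event_handler("backward_done", lambda e: seen.append(e.state.grad_norm))
engine.run([[1, 4], [7, 2]])
print(seen)  # [4, 7]
```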

has_event_handler(handler, event_name=None)

Check if the specified event has the specified handler.

Parameters
  • handler (callable) – the callable event handler.

  • event_name – The event the handler is attached to. Set this to None to search all events.
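The lookup described above, including the event_name=None case that searches every event, can be sketched as follows. ToyEngine is a hypothetical class, not ignite's implementation:

```python
from collections import defaultdict


class ToyEngine:
    """Hypothetical sketch of a handler lookup (not ignite's code)."""
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler):
        self._handlers[event_name].append(handler)

    def has_event_handler(self, handler, event_name=None):
        if event_name is not None:
            # .get avoids creating an empty entry for unknown events.
            return handler in self._handlers.get(event_name, [])
        # event_name=None: search every event.
        return any(handler in hs for hs in self._handlers.values())


def print_epoch(engine):
    pass

engine = ToyEngine()
engine.add_event_handler("epoch_completed", print_epoch)
print(engine.has_event_handler(print_epoch))                       # True
print(engine.has_event_handler(print_epoch, "epoch_completed"))    # True
print(engine.has_event_handler(print_epoch, "iteration_started"))  # False
```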

on(event_name, *args, **kwargs)

Decorator shortcut for add_event_handler.

Parameters
  • event_name – An event to attach the handler to. Valid events are from Events or any event_name added by register_events().

  • *args – optional args to be passed to handler.

  • **kwargs – optional keyword args to be passed to handler.
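Since on is a decorator shortcut for add_event_handler, it can be modelled as a thin wrapper that registers the function and returns it unchanged. The ToyEngine below is hypothetical; with ignite the decorator line would be @engine.on(Events.COMPLETED, "done"):

```python
from collections import defaultdict


class ToyEngine:
    """Hypothetical sketch: `on` as a thin decorator over add_event_handler."""
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        self._handlers[event_name].append((handler, args, kwargs))

    def on(self, event_name, *args, **kwargs):
        def decorator(f):
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f  # return the function unchanged so it stays callable
        return decorator

    def fire_event(self, event_name):
        for handler, args, kwargs in self._handlers[event_name]:
            handler(self, *args, **kwargs)


calls = []
engine = ToyEngine()

@engine.on("completed", "done")
def report(engine, msg):
    calls.append(msg)

engine.fire_event("completed")
print(calls)  # ['done']
```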

register_events(*event_names, **kwargs)

Add events that can be fired.

Registering an event lets the user fire it at any point, which makes the run() loop even more configurable.

By default, the events from Events are registered.

Parameters
  • *event_names – An object (ideally a string or int) to define the name of the event being supported.

  • event_to_attr (dict) – A dictionary to map an event to a state attribute.

Example usage:

from enum import Enum

from ignite.engine import Engine

class Custom_Events(Enum):
    FOO_EVENT = "foo_event"
    BAR_EVENT = "bar_event"

engine = Engine(process_function)
engine.register_events(*Custom_Events)

Example with State Attribute:

from enum import Enum

from ignite.engine import Engine

class TBPTT_Events(Enum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

TBPTT_event_to_attr = {TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
                       TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'}

engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, which can be accessed using engine.state.time_iteration
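The bookkeeping behind event_to_attr can be pictured with the hypothetical model below. It assumes that each mapped attribute is created on a fresh state with the value 0 and that process_function updates it while firing the custom events; how the real Engine initializes and updates such attributes may differ between versions:

```python
from collections import defaultdict
from enum import Enum
from types import SimpleNamespace


class TBPTT_Events(Enum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


class ToyEngine:
    """Hypothetical sketch of event_to_attr bookkeeping (not ignite's code)."""
    def __init__(self, process_function):
        self._process_function = process_function
        self._event_to_attr = {}
        self._handlers = defaultdict(list)
        self.state = None

    def register_events(self, *event_names, event_to_attr=None):
        # Assumption: only remember which state attribute each event maps to.
        for name in event_names:
            if event_to_attr and name in event_to_attr:
                self._event_to_attr[name] = event_to_attr[name]

    def fire_event(self, event_name):
        for handler in self._handlers[event_name]:
            handler(self)

    def run(self, data):
        self.state = SimpleNamespace(output=None)
        # Assumption: every mapped attribute starts at 0 on the new state.
        for attr in set(self._event_to_attr.values()):
            setattr(self.state, attr, 0)
        for batch in data:
            self.state.output = self._process_function(self, batch)
        return self.state


def process_function(engine, batch):
    # Treat each element as one time step and count steps on engine.state.
    for _ in batch:
        engine.fire_event(TBPTT_Events.TIME_ITERATION_STARTED)
        engine.state.time_iteration += 1
        engine.fire_event(TBPTT_Events.TIME_ITERATION_COMPLETED)
    return len(batch)


mapping = {TBPTT_Events.TIME_ITERATION_STARTED: "time_iteration",
           TBPTT_Events.TIME_ITERATION_COMPLETED: "time_iteration"}
engine = ToyEngine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=mapping)
state = engine.run([[0.1, 0.2, 0.3], [0.4, 0.5]])
print(state.time_iteration)  # 5
```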

remove_event_handler(handler, event_name)

Remove the event handler handler from the engine’s registered handlers.

Parameters
  • handler (callable) – the callable event handler that should be removed

  • event_name – The event the handler is attached to.
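A hypothetical pure-Python model of removal: once a handler is removed, firing the event no longer calls it. In this sketch, removing a handler that was never attached raises ValueError; check the installed ignite version for its exact behaviour:

```python
from collections import defaultdict


class ToyEngine:
    """Hypothetical sketch of handler removal (not ignite's code)."""
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler):
        self._handlers[event_name].append(handler)

    def remove_event_handler(self, handler, event_name):
        if handler not in self._handlers.get(event_name, []):
            raise ValueError("{!r} is not attached to {!r}".format(handler, event_name))
        self._handlers[event_name].remove(handler)

    def fire_event(self, event_name):
        # Iterate over a copy so handlers may safely remove themselves.
        for handler in list(self._handlers[event_name]):
            handler(self)


hits = []
def count(engine):
    hits.append(1)

engine = ToyEngine()
engine.add_event_handler("iteration_completed", count)
engine.fire_event("iteration_completed")
engine.remove_event_handler(count, "iteration_completed")
engine.fire_event("iteration_completed")  # count is no longer called
print(len(hits))  # 1
```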

run(data, max_epochs=1)

Runs the process_function over the passed data.

Parameters
  • data (Iterable) – Collection of batches allowing repeated iteration (e.g., list or DataLoader).

  • max_epochs (int, optional) – max epochs to run for (default: 1).

Returns

The engine’s state after the run.

Return type

State
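The requirement that data allow repeated iteration matters once max_epochs > 1. The hypothetical toy_run function below (not ignite's run()) makes one pass over data per epoch and keeps counting iterations across epochs, so 4 epochs over 3 batches give 12 iterations. For brevity, process_function receives the state where the real Engine passes the engine itself:

```python
from types import SimpleNamespace


def toy_run(process_function, data, max_epochs=1):
    """Hypothetical model of run(data, max_epochs): repeated passes over data."""
    state = SimpleNamespace(epoch=0, iteration=0, output=None,
                            max_epochs=max_epochs)
    while state.epoch < max_epochs:
        state.epoch += 1
        for batch in data:
            state.iteration += 1  # keeps counting across epochs
            state.output = process_function(state, batch)
    return state


list_state = toy_run(lambda s, b: b, data=["a", "b", "c"], max_epochs=4)
print(list_state.epoch, list_state.iteration)  # 4 12

# A one-shot generator is exhausted after the first epoch in this sketch,
# so later epochs process no batches at all.
gen_state = toy_run(lambda s, b: b, data=(c for c in "abc"), max_epochs=4)
print(gen_state.epoch, gen_state.iteration)  # 4 3
```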

terminate()

Sends a terminate signal to the engine so that it completely terminates the run after the current iteration.

terminate_epoch()

Sends a terminate signal to the engine so that it terminates the current epoch after the current iteration.
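The difference between the two methods can be modelled with two flags that the loop checks once the current iteration has finished. ToyEngine and its flag names are hypothetical; this is a sketch of the described behaviour, not ignite's implementation:

```python
class ToyEngine:
    """Hypothetical sketch of terminate()/terminate_epoch() (not ignite's code).

    Both methods only set a flag; the loop checks the flags after the
    current iteration has finished.
    """
    def __init__(self, process_function):
        self._process_function = process_function
        self.should_terminate = False
        self.should_terminate_single_epoch = False
        self.processed = []

    def terminate(self):
        self.should_terminate = True

    def terminate_epoch(self):
        self.should_terminate_single_epoch = True

    def run(self, data, max_epochs=1):
        for epoch in range(1, max_epochs + 1):
            for batch in data:
                self._process_function(self, batch)
                self.processed.append((epoch, batch))
                if self.should_terminate or self.should_terminate_single_epoch:
                    self.should_terminate_single_epoch = False  # epoch-only stop
                    break
            if self.should_terminate:
                break  # whole-run stop
        return self.processed


def stop_epoch_at_2(engine, batch):
    if batch == 2:
        engine.terminate_epoch()

print(ToyEngine(stop_epoch_at_2).run([1, 2, 3], max_epochs=2))
# [(1, 1), (1, 2), (2, 1), (2, 2)]

def stop_run_at_2(engine, batch):
    if batch == 2:
        engine.terminate()

print(ToyEngine(stop_run_at_2).run([1, 2, 3], max_epochs=2))
# [(1, 1), (1, 2)]
```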

ignite.engine.create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function <lambda>>)

Factory function for creating an evaluator for supervised models.

Parameters
  • model (torch.nn.Module) – the model to evaluate.

  • metrics (dict of str - Metric) – a map of metric names to Metrics.

  • device (str, optional) – device type specification (default: None). Applies to both model and batches.

  • non_blocking (bool, optional) – if True and the batch is copied between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • prepare_batch (callable, optional) – function that receives batch, device and non_blocking, and outputs a tuple of tensors (batch_x, batch_y).

  • output_transform (callable, optional) – function that receives ‘x’, ‘y’, ‘y_pred’ and returns the value to be assigned to the engine’s state.output after each iteration. Default is returning (y_pred, y,), which fits the output expected by metrics. If you change it, you should also set output_transform in the metrics accordingly.

Note: engine.state.output for this engine is defined by the output_transform parameter and is a tuple of (batch_pred, batch_y) by default.

Returns

an evaluator engine with supervised inference function.

Return type

Engine
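The per-batch work of the evaluator can be sketched in plain Python to show how prepare_batch, the model and output_transform fit together. All names below are hypothetical; the sketch omits everything PyTorch-specific (tensors, device transfer, evaluation mode, gradient handling) and uses lists in place of tensors:

```python
def default_prepare_batch(batch, device=None, non_blocking=False):
    # Hypothetical stand-in: a batch is already an (x, y) pair; device and
    # non_blocking are ignored because there are no tensors here.
    x, y = batch
    return x, y


def make_toy_evaluator_step(model, prepare_batch=default_prepare_batch,
                            output_transform=lambda x, y, y_pred: (y_pred, y)):
    """Hypothetical sketch of the evaluator's per-batch function."""
    def _inference(engine, batch):
        x, y = prepare_batch(batch)
        y_pred = model(x)
        return output_transform(x, y, y_pred)
    return _inference


def model(xs):
    # Toy "model": predicts 2 * x for every input.
    return [2 * x for x in xs]

step = make_toy_evaluator_step(model)
print(step(None, ([1, 2], [2, 5])))  # ([2, 4], [2, 5])

# A custom output_transform, e.g. to also keep the inputs:
step_xyz = make_toy_evaluator_step(
    model, output_transform=lambda x, y, y_pred: (x, y_pred, y))
print(step_xyz(None, ([3], [6])))  # ([3], [6], [6])
```

If a custom output_transform like the second one is used, the metrics would need their own output_transform to pick out (y_pred, y).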

ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function <lambda>>)

Factory function for creating a trainer for supervised models.

Parameters
  • model (torch.nn.Module) – the model to train.

  • optimizer (torch.optim.Optimizer) – the optimizer to use.

  • loss_fn (torch.nn loss function) – the loss function to use.

  • device (str, optional) – device type specification (default: None). Applies to both model and batches.

  • non_blocking (bool, optional) – if True and the batch is copied between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • prepare_batch (callable, optional) – function that receives batch, device and non_blocking, and outputs a tuple of tensors (batch_x, batch_y).

  • output_transform (callable, optional) – function that receives ‘x’, ‘y’, ‘y_pred’, ‘loss’ and returns the value to be assigned to the engine’s state.output after each iteration. Default is returning loss.item().

Note: engine.state.output for this engine is defined by the output_transform parameter and is the loss of the processed batch by default.

Returns

a trainer engine with supervised update function.

Return type

Engine
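To show where output_transform sits in the update step, here is a hypothetical pure-Python trainer step for the one-parameter model y = w * x with a mean-squared-error loss. The gradient is computed by hand instead of with autograd, plain SGD stands in for the optimizer, and prepare_batch is simplified to take only the batch; none of this is ignite's code:

```python
def make_toy_trainer_step(params, lr, prepare_batch=lambda batch: batch,
                          output_transform=lambda x, y, y_pred, loss: loss):
    """Hypothetical sketch of the trainer's per-batch update for y = w * x.

    `params` is a dict holding the single weight "w"; the loss is the mean
    squared error and its gradient is computed by hand instead of autograd.
    """
    def _update(engine, batch):
        x, y = prepare_batch(batch)
        w = params["w"]
        y_pred = [w * xi for xi in x]
        n = len(x)
        loss = sum((p - t) ** 2 for p, t in zip(y_pred, y)) / n
        grad = sum(2 * (p - t) * xi for p, t, xi in zip(y_pred, y, x)) / n
        params["w"] = w - lr * grad  # plain SGD step
        return output_transform(x, y, y_pred, loss)
    return _update


params = {"w": 0.0}
step = make_toy_trainer_step(params, lr=0.1)
print(step(None, ([1.0, 2.0], [2.0, 4.0])))  # 10.0 (loss before the update)
print(params["w"])                           # 1.0
```

Passing, say, output_transform=lambda x, y, y_pred, loss: (loss, y_pred) would make each step return the predictions as well as the loss.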

class ignite.engine.Events(value)

Events that are fired by the Engine during execution.

COMPLETED = 'completed'
EPOCH_COMPLETED = 'epoch_completed'
EPOCH_STARTED = 'epoch_started'
EXCEPTION_RAISED = 'exception_raised'
ITERATION_COMPLETED = 'iteration_completed'
ITERATION_STARTED = 'iteration_started'
STARTED = 'started'
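The following hypothetical trace function models the order in which a run fires the core events, using a local Enum with the same names and values as listed above. It is a sketch of the loop structure, not ignite's code, and it does not model EXCEPTION_RAISED, which fires only when an exception is raised during the run:

```python
from enum import Enum


class Events(Enum):
    # Local copy with the same names and values as listed above.
    COMPLETED = "completed"
    EPOCH_COMPLETED = "epoch_completed"
    EPOCH_STARTED = "epoch_started"
    EXCEPTION_RAISED = "exception_raised"
    ITERATION_COMPLETED = "iteration_completed"
    ITERATION_STARTED = "iteration_started"
    STARTED = "started"


def toy_event_trace(data, max_epochs=1):
    """Hypothetical model of the order in which a run fires core events."""
    trace = [Events.STARTED]
    for _ in range(max_epochs):
        trace.append(Events.EPOCH_STARTED)
        for _batch in data:
            trace.append(Events.ITERATION_STARTED)
            # process_function(engine, batch) would run here
            trace.append(Events.ITERATION_COMPLETED)
        trace.append(Events.EPOCH_COMPLETED)
    trace.append(Events.COMPLETED)
    return trace


print([e.value for e in toy_event_trace(data=[0, 1], max_epochs=1)])
# ['started', 'epoch_started', 'iteration_started', 'iteration_completed',
#  'iteration_started', 'iteration_completed', 'epoch_completed', 'completed']
```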

class ignite.engine.State(**kwargs)

An object that is used to pass internal and user-defined state between event handlers.
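A State-like object can be pictured as an attribute bag whose keyword arguments become attributes, and to which handlers may attach their own fields. ToyState below is hypothetical; its default iteration and output attributes are assumptions made for illustration:

```python
class ToyState:
    """Hypothetical stand-in for State: keyword arguments become attributes."""
    def __init__(self, **kwargs):
        self.iteration = 0   # assumed default
        self.output = None   # assumed default
        for key, value in kwargs.items():
            setattr(self, key, value)


state = ToyState(dataloader=[1, 2, 3], epoch=0, max_epochs=2)
state.best_accuracy = 0.0  # handlers may attach their own fields at any time
print(state.max_epochs, state.iteration, state.best_accuracy)  # 2 0 0.0
```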

class ignite.engine.engine.RemovableEventHandle(event_name, handler, engine)

A weakref handle to remove a registered event handler.

A handle that may be used to remove a registered event handler, either by calling its remove method or by using it as a context manager in a with-statement. Returned from add_event_handler().

Parameters
  • event_name – Registered event name.

  • handler – Registered event handler, stored as weakref.

  • engine – Target engine, stored as weakref.

Example usage:

from ignite.engine import Engine, Events

engine = Engine(process_function)

def print_epoch(engine):
    print("Epoch: {}".format(engine.state.epoch))

with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
    # print_epoch handler registered for a single run
    engine.run(data)

# print_epoch handler is now unregistered

remove()

Remove handler from engine.
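The weakref-and-context-manager design described above can be sketched in plain Python. ToyEngine and ToyRemovableHandle are hypothetical classes, not ignite's implementation; the handle holds only weak references, so it keeps neither the handler nor the engine alive, and remove() does nothing if either has already been garbage-collected or the handler is already gone:

```python
import weakref
from collections import defaultdict


class ToyEngine:
    """Hypothetical engine whose add_event_handler returns a removable handle."""
    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler):
        self._handlers[event_name].append(handler)
        return ToyRemovableHandle(event_name, handler, self)

    def has_event_handler(self, handler, event_name):
        return handler in self._handlers.get(event_name, [])

    def remove_event_handler(self, handler, event_name):
        self._handlers[event_name].remove(handler)


class ToyRemovableHandle:
    """Hypothetical sketch: holds weak references so it keeps nothing alive."""
    def __init__(self, event_name, handler, engine):
        self.event_name = event_name
        self.handler = weakref.ref(handler)
        self.engine = weakref.ref(engine)

    def remove(self):
        handler, engine = self.handler(), self.engine()
        if handler is None or engine is None:
            return  # already garbage-collected: nothing left to remove
        if engine.has_event_handler(handler, self.event_name):
            engine.remove_event_handler(handler, self.event_name)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove()  # returns None, so exceptions still propagate


def print_epoch(engine):
    pass

engine = ToyEngine()
with engine.add_event_handler("epoch_completed", print_epoch):
    inside = engine.has_event_handler(print_epoch, "epoch_completed")
after = engine.has_event_handler(print_epoch, "epoch_completed")
print(inside, after)  # True False
```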