ignite.engine
- class ignite.engine.Engine(process_function)
Runs a given process_function over each batch of a dataset, emitting events as it goes.
- Parameters
process_function (Callable) – A function that receives a handle to the engine and the current batch in each iteration, and returns data to be stored in the engine’s state (as engine.state.output)
Example usage:
```python
from ignite.engine import Engine

# model, optimizer, loss_fn and data_loader are assumed to be defined elsewhere


def train_and_store_loss(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()


engine = Engine(train_and_store_loss)
engine.run(data_loader)
# The loss of the last processed batch is now stored in `engine.state.output`.
```
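To make the event sequence concrete, the following is a deliberately simplified, hypothetical model of such a run loop in plain Python. `MiniEngine`, its `_fire_event` helper and the `max_epochs` argument are illustrative assumptions, not ignite's implementation; the event names mirror the `Events` members documented below, and the exact behaviour of the real Engine may differ in detail.

```python
from enum import Enum
from types import SimpleNamespace


class Events(Enum):
    # Same names and values as the members listed for ignite.engine.Events
    STARTED = "started"
    EPOCH_STARTED = "epoch_started"
    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"
    EPOCH_COMPLETED = "epoch_completed"
    COMPLETED = "completed"
    EXCEPTION_RAISED = "exception_raised"


class MiniEngine:
    """Hypothetical, simplified model of an event-driven run loop (not ignite's code)."""

    def __init__(self, process_function):
        self._process_function = process_function
        self._handlers = {event: [] for event in Events}
        self.state = None

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        self._handlers[event_name].append((handler, args, kwargs))

    def _fire_event(self, event_name, *event_args):
        # The engine is always passed to a handler as its first argument
        for handler, args, kwargs in self._handlers[event_name]:
            handler(self, *event_args, *args, **kwargs)

    def run(self, data, max_epochs=1):
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)
        try:
            self._fire_event(Events.STARTED)
            while self.state.epoch < max_epochs:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)
                for batch in data:
                    self.state.iteration += 1
                    self._fire_event(Events.ITERATION_STARTED)
                    # Whatever process_function returns is stored as state.output
                    self.state.output = self._process_function(self, batch)
                    self._fire_event(Events.ITERATION_COMPLETED)
                self._fire_event(Events.EPOCH_COMPLETED)
            self._fire_event(Events.COMPLETED)
        except Exception as exc:
            # Re-raise unless at least one handler is registered for the error
            if self._handlers[Events.EXCEPTION_RAISED]:
                self._fire_event(Events.EXCEPTION_RAISED, exc)
            else:
                raise
        return self.state
```

For example, `MiniEngine(lambda engine, batch: batch * 2).run([1, 2, 3])` returns a state whose `output` is `6`, the value produced for the last batch.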
- add_event_handler(event_name, handler, *args, **kwargs)
Add an event handler to be executed when the specified event is fired.
- Parameters
event_name (Events) – event from ignite.engine.Events to attach the handler to
handler (Callable) – the callable event handler that should be invoked
*args – optional args to be passed to handler
**kwargs – optional keyword args to be passed to handler
Notes
The handler function’s first argument will be the Engine object it was bound to.
Other arguments may be passed to the handler in addition to the *args and **kwargs given here; for example, handlers for Events.EXCEPTION_RAISED also receive the raised exception.
Example usage:
```python
from ignite.engine import Engine, Events

engine = Engine(process_function)  # process_function as in the Engine example


def print_epoch(engine):
    print("Epoch: {}".format(engine.state.epoch))


engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
```
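The argument order described in the notes can be illustrated with a small, hypothetical dispatcher. It assumes each handler is called as `handler(engine, *event_args, *args, **kwargs)`, so event-specific values (such as an exception) come right after the engine and before the extras supplied at registration. `fire_event`, `on_error` and the `"engine"` placeholder are illustrative names only.

```python
def fire_event(handlers, engine, *event_args):
    """Hypothetical dispatcher: call every registered (handler, args, kwargs) entry.

    The engine comes first, then any event-specific values supplied when the
    event fires (for example an exception), then the *args/**kwargs that were
    given when the handler was registered.
    """
    for handler, args, kwargs in handlers:
        handler(engine, *event_args, *args, **kwargs)


calls = []


def on_error(engine, exc, tag, verbose=False):
    # Record what was received so the argument order is easy to inspect
    calls.append((engine, type(exc).__name__, tag, verbose))


# Registered with one extra positional argument and one keyword argument
handlers = [(on_error, ("run-1",), {"verbose": True})]

# "engine" is a placeholder string standing in for a real Engine object
fire_event(handlers, "engine", ValueError("boom"))
```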
- on(event_name, *args, **kwargs)
Decorator shortcut for add_event_handler.
- Parameters
event_name (Events) – event to attach the handler to
*args – optional args to be passed to handler
**kwargs – optional keyword args to be passed to handler
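A decorator shortcut like `on` can be built directly on top of `add_event_handler`. The sketch below uses a hypothetical `HandlerRegistry` class in place of an engine and plain strings in place of `Events` members; with the real class, the equivalent usage would be `@engine.on(Events.EPOCH_COMPLETED)` placed above a handler definition.

```python
class HandlerRegistry:
    """Hypothetical stand-in for an engine, used only to illustrate `on`."""

    def __init__(self):
        self.handlers = {}

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        self.handlers.setdefault(event_name, []).append((handler, args, kwargs))

    def on(self, event_name, *args, **kwargs):
        # Return a decorator that registers the function and hands it back unchanged
        def decorator(f):
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f

        return decorator


registry = HandlerRegistry()


@registry.on("epoch_completed", "checkpoint-dir")
def save_checkpoint(engine, directory):
    return "saving to {}".format(directory)
```

Returning `f` unchanged means the decorated name still refers to an ordinary function that can be called or registered elsewhere.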
- ignite.engine.create_supervised_evaluator(model, metrics={}, device=None)
Factory function for creating an evaluator for supervised models.
- Parameters
model (torch.nn.Module) – the model to evaluate
metrics (dict of str to ignite.metrics.Metric) – a map of metric names to Metrics
device (str, optional) – device type specification (default: None). Applies to both model and batches.
- Returns
an evaluator engine with supervised inference function
- Return type
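The supervised inference function can be pictured as a closure over the model. The sketch below is a hedged guess at its general shape, not ignite's source: `make_inference_function` is a hypothetical name, and returning the `(y_pred, y)` pair for metrics to consume is an assumption. It switches the model to evaluation mode, disables gradient tracking, and moves the model and each batch to `device` when one is given.

```python
import torch
from torch import nn


def make_inference_function(model, device=None):
    """Hypothetical sketch of the per-batch function a supervised evaluator could run."""
    if device is not None:
        model.to(device)

    def inference(engine, batch):
        model.eval()  # evaluation mode: dropout off, batch-norm uses running statistics
        with torch.no_grad():  # no gradients are needed during evaluation
            x, y = batch
            if device is not None:
                x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # Assumed output format: (prediction, target) for metrics to consume
            return y_pred, y

    return inference


model = nn.Linear(3, 2)
inference = make_inference_function(model)
y_pred, y = inference(None, (torch.randn(4, 3), torch.tensor([0, 1, 0, 1])))
```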
- ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=None)
Factory function for creating a trainer for supervised models.
- Parameters
model (torch.nn.Module) – the model to train
optimizer (torch.optim.Optimizer) – the optimizer to use
loss_fn (torch.nn loss function) – the loss function to use
device (str, optional) – device type specification (default: None). Applies to both model and batches.
- Returns
a trainer engine with supervised update function
- Return type
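The supervised update function plausibly follows the same steps as `train_and_store_loss` in the Engine example. The sketch below wraps those steps in a closure over `model`, `optimizer` and `loss_fn`; `make_update_function` is a hypothetical name, and the handling of `device` is an assumption based on the parameter description above.

```python
import torch
from torch import nn


def make_update_function(model, optimizer, loss_fn, device=None):
    """Hypothetical sketch of the per-batch update a supervised trainer could run."""
    if device is not None:
        model.to(device)

    def update(engine, batch):
        model.train()  # training mode: dropout on, batch-norm statistics updated
        optimizer.zero_grad()
        x, y = batch
        if device is not None:
            x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        # A plain float, so it can be stored in engine.state.output
        return loss.item()

    return update


torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update = make_update_function(model, optimizer, nn.MSELoss())

# Toy regression target: the sum of the two input features
x = torch.randn(16, 2)
y = x.sum(dim=1, keepdim=True)
losses = [update(None, (x, y)) for _ in range(50)]
```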
- class ignite.engine.Events(value)
Events that are fired by the ignite.engine.Engine during execution.
- COMPLETED = 'completed'
- EPOCH_COMPLETED = 'epoch_completed'
- EPOCH_STARTED = 'epoch_started'
- EXCEPTION_RAISED = 'exception_raised'
- ITERATION_COMPLETED = 'iteration_completed'
- ITERATION_STARTED = 'iteration_started'
- STARTED = 'started'
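The `Events(value)` signature matches Python's by-value lookup on an `enum.Enum`. Assuming `Events` is a standard `Enum` subclass, a stand-in with the same members and values behaves as follows; this class is a sketch for illustration, not ignite's definition.

```python
from enum import Enum


class Events(Enum):
    # Stand-in with the members and values listed above
    COMPLETED = "completed"
    EPOCH_COMPLETED = "epoch_completed"
    EPOCH_STARTED = "epoch_started"
    EXCEPTION_RAISED = "exception_raised"
    ITERATION_COMPLETED = "iteration_completed"
    ITERATION_STARTED = "iteration_started"
    STARTED = "started"


# By-value lookup, which is what the Events(value) signature refers to
event = Events("epoch_completed")
print(event, event.value)
```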