ignite.contrib.engines
Contribution module of engines
- class ignite.contrib.engines.Tbptt_Events(value)
- Additional events for the trainer dedicated to truncated backpropagation through time (TBPTT).
 - TIME_ITERATION_COMPLETED = 'time_iteration_completed'
 - TIME_ITERATION_STARTED = 'time_iteration_started'
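The following is a hypothetical pure-Python sketch of when these two events could fire. It assumes one time iteration per `tbtt_step`-long chunk of each batch, so a batch of length 10 with `tbtt_step=4` gives 3 time iterations. `TbpttEvents`, `MiniEngine` and `run_batch` are illustrative stand-ins, not ignite classes; with the real trainer, handlers would be attached to the engine returned by `create_supervised_tbptt_trainer`.

```python
import math
from collections import defaultdict
from enum import Enum
from typing import Callable, List


class TbpttEvents(Enum):
    # Hypothetical stand-in mirroring the two documented event values
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


class MiniEngine:
    """Minimal illustrative event dispatcher (not ignite's Engine)."""

    def __init__(self) -> None:
        self._handlers = defaultdict(list)  # event -> list of handlers
        self.time_iteration = 0  # number of completed time chunks

    def on(self, event: TbpttEvents) -> Callable:
        # Decorator that registers a handler for the given event
        def decorator(fn: Callable) -> Callable:
            self._handlers[event].append(fn)
            return fn
        return decorator

    def fire(self, event: TbpttEvents) -> None:
        for fn in self._handlers[event]:
            fn(self)


def run_batch(engine: MiniEngine, seq_len: int, tbtt_step: int) -> int:
    """Fire one STARTED/COMPLETED pair per time chunk and return the chunk count."""
    n_chunks = math.ceil(seq_len / tbtt_step)
    for _ in range(n_chunks):
        engine.fire(TbpttEvents.TIME_ITERATION_STARTED)
        # ... forward pass, loss, backward pass and optimizer step on this chunk ...
        engine.time_iteration += 1
        engine.fire(TbpttEvents.TIME_ITERATION_COMPLETED)
    return n_chunks


engine = MiniEngine()
log: List[str] = []


@engine.on(TbpttEvents.TIME_ITERATION_COMPLETED)
def record(e: MiniEngine) -> None:
    log.append(f"chunk {e.time_iteration} done")


print(run_batch(engine, seq_len=10, tbtt_step=4))  # 3 (chunks of 4, 4 and 2)
print(log)  # ['chunk 1 done', 'chunk 2 done', 'chunk 3 done']
```

Handlers on TIME_ITERATION_COMPLETED run once per chunk, which makes them a natural place for per-chunk logging, as opposed to per-batch handlers that run once per whole sequence.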
 
- ignite.contrib.engines.create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>)
- Create a supervised trainer for models trained with truncated backpropagation through time.
- Training a recurrent model on long sequences is computationally intensive, as it requires processing the whole sequence before getting a gradient. However, when the training loss is computed over many outputs (X-to-many), there is an opportunity to compute a gradient over a subsequence. This is known as truncated backpropagation through time. This supervised trainer applies a gradient optimization step every tbtt_step time steps of the sequence, while backpropagating through the same tbtt_step time steps.
- Parameters
- model (torch.nn.Module) – the model to train 
- optimizer (torch.optim.Optimizer) – the optimizer to use 
- loss_fn (torch.nn loss function) – the loss function to use 
- tbtt_step (int) – the length of time chunks (last one may be smaller) 
- dim (int) – axis representing the time dimension 
- device (str, optional) – device type specification (default: None). Applies to both model and batches. 
- non_blocking (bool, optional) – if True and the copy of the batch to device is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. 
- prepare_batch (Callable, optional) – function that receives batch, device, non_blocking and outputs a tuple of tensors (batch_x, batch_y). 
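To make the description concrete without depending on PyTorch, here is a pure-Python sketch of truncated BPTT on a hypothetical scalar linear RNN, h_t = w * h_{t-1} + x_t, with a squared-error loss on every output (the X-to-many case). The sequence is split into `tbtt_step`-long chunks (the last one may be smaller); within each chunk the gradient with respect to w flows only through that chunk's time steps, the incoming hidden state is treated as a constant (the analogue of detaching it), and one SGD step is taken per chunk. The model, learning rate and the helpers `split_time` and `tbptt_epoch` are assumptions for illustration, not ignite's implementation; with tensors, the equivalent split along the time axis could use `torch.split(x, tbtt_step, dim=dim)`.

```python
from typing import List, Sequence, Tuple


def split_time(seq: Sequence[float], tbtt_step: int) -> List[Sequence[float]]:
    """Split a sequence into consecutive chunks of length tbtt_step (last may be shorter)."""
    return [seq[i:i + tbtt_step] for i in range(0, len(seq), tbtt_step)]


def tbptt_epoch(w: float, xs: Sequence[float], ys: Sequence[float],
                tbtt_step: int, lr: float) -> Tuple[float, List[float]]:
    """One pass of truncated BPTT over a single sequence.

    Hypothetical model: h_t = w * h_{t-1} + x_t, chunk loss = sum_t (h_t - y_t)^2.
    """
    h = 0.0  # hidden state, carried forward across chunks
    chunk_losses = []
    for x_chunk, y_chunk in zip(split_time(xs, tbtt_step), split_time(ys, tbtt_step)):
        dh_dw = 0.0  # truncation: incoming h is a constant, so dh/dw restarts at 0
        loss = 0.0
        grad = 0.0
        for x_t, y_t in zip(x_chunk, y_chunk):
            # d(w * h_prev + x)/dw = h_prev + w * dh_prev/dw (computed before h is updated)
            dh_dw = h + w * dh_dw
            h = w * h + x_t
            err = h - y_t
            loss += err ** 2
            grad += 2.0 * err * dh_dw
        w -= lr * grad  # one optimizer step per chunk
        chunk_losses.append(loss)
    return w, chunk_losses


xs = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0]
ys = [1.0, 0.5, 0.25, 1.125, 0.5625, 0.28125, 1.140625]  # generated with w = 0.5
print(split_time(xs, 3))  # [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0]]

w = 0.0
for _ in range(200):
    w, losses = tbptt_epoch(w, xs, ys, tbtt_step=3, lr=0.05)
print(round(w, 3))  # expected to approach 0.5
```

Each chunk costs time proportional to `tbtt_step` rather than to the full sequence length, at the price of ignoring gradient contributions that would flow back past a chunk boundary.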
 
- Returns
- a trainer engine with a supervised update function 
- Return type