Source code for ignite.contrib.handlers.tqdm_logger
try:
from tqdm import tqdm
except ImportError:
raise RuntimeError("This contrib module requires tqdm to be installed")
from ignite.engine import Events
class ProgressBar:
    """
    TQDM progress bar handler to log training progress and computed metrics.

    Args:
        persist (bool, optional): set to ``True`` to persist the progress bar after completion
            (default = ``False``)
        bar_format (str, optional): Specify a custom bar string formatting. May impact performance.
            [default: '{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]'].
            Set to ``None`` to use ``tqdm`` default bar formatting: '{l_bar}{bar}{r_bar}', where
            l_bar='{desc}: {percentage:3.0f}%|' and
            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
            '{rate_fmt}{postfix}]'
            Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
            percentage, rate, rate_fmt, rate_noinv, rate_noinv_fmt,
            rate_inv, rate_inv_fmt, elapsed, remaining, desc, postfix.
            Note that a trailing ": " is automatically removed after {desc}
            if the latter is empty.
        **tqdm_kwargs: kwargs passed to tqdm progress bar

    Examples:

        Simple progress bar

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer)

        Attach metrics that already have been computed at `ITERATION_COMPLETED` (such as `RunningAverage`)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

            pbar = ProgressBar()
            pbar.attach(trainer, ['loss'])

        Directly attach the engine's output

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    Note:
        When attaching the progress bar to an engine, it is recommended that you replace
        every print operation in the engine's handlers triggered every iteration with
        ``pbar.log_message`` to guarantee the correct format of the stdout.
    """

    def __init__(self, persist=False,
                 bar_format='{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]',
                 **tqdm_kwargs):
        # The tqdm bar is created lazily by the first ``_update`` call of each
        # epoch and dropped again by ``_close``; ``None`` means "no active bar".
        self.pbar = None
        self.persist = persist
        self.bar_format = bar_format
        self.tqdm_kwargs = tqdm_kwargs

    def _reset(self, engine):
        """
        Creates a fresh tqdm bar sized to ``len(engine.state.dataloader)``.

        Args:
            engine (Engine): engine object whose ``state.dataloader`` gives the bar total
        """
        self.pbar = tqdm(
            total=len(engine.state.dataloader),
            leave=self.persist,
            bar_format=self.bar_format,
            **self.tqdm_kwargs
        )

    def _close(self, engine):
        """
        Closes and discards the active bar, if there is one.

        Args:
            engine (Engine): unused; present so the method can be registered as an event handler
        """
        # ``_close`` is registered on EPOCH_COMPLETED while the bar is only
        # created on ITERATION_COMPLETED, so an epoch without any completed
        # iteration reaches this point with no bar. Treat that as a no-op
        # instead of failing on ``None.close()``.
        if self.pbar is not None:
            self.pbar.close()
        self.pbar = None

    def _update(self, engine, metric_names=None, output_transform=None):
        """
        Advances the bar by one step and refreshes its description and postfix.

        Args:
            engine (Engine): engine object providing ``state`` (epoch, max_epochs, metrics, output)
            metric_names (list, optional): names looked up in ``engine.state.metrics``
            output_transform (Callable, optional): function applied to ``engine.state.output``;
                a non-dict result is displayed under the name ``output``

        Raises:
            KeyError: if any of ``metric_names`` is missing from ``engine.state.metrics``
                (the bar is closed before raising)
        """
        if self.pbar is None:
            self._reset(engine)

        # A user-supplied ``desc`` takes precedence over the epoch counter.
        if 'desc' not in self.tqdm_kwargs:
            self.pbar.set_description('Epoch [{}/{}]'.format(engine.state.epoch, engine.state.max_epochs))

        metrics = {}
        if metric_names is not None:
            if not all(metric in engine.state.metrics for metric in metric_names):
                # Release the bar before raising the error.
                self._close(engine)
                raise KeyError("metrics not found in engine.state.metrics")

            metrics.update({name: '{:.2e}'.format(engine.state.metrics[name]) for name in metric_names})

        if output_transform is not None:
            output_dict = output_transform(engine.state.output)

            if not isinstance(output_dict, dict):
                output_dict = {"output": output_dict}

            # Entries from the output override same-named metrics.
            metrics.update({name: '{:.2e}'.format(value) for name, value in output_dict.items()})

        if metrics:
            self.pbar.set_postfix(**metrics)

        self.pbar.update()

    @staticmethod
    def log_message(message):
        """
        Logs a message, preserving the progress bar correct output format

        Args:
            message (str): string you wish to log
        """
        tqdm.write(message)

    def attach(self, engine, metric_names=None, output_transform=None):
        """
        Attaches the progress bar to an engine object

        Registers ``_update`` on ``Events.ITERATION_COMPLETED`` and ``_close`` on
        ``Events.EPOCH_COMPLETED``.

        Args:
            engine (Engine): engine object
            metric_names (list, optional): list of the metrics names to log as the bar progresses
            output_transform (Callable, optional): a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.

        Raises:
            TypeError: if ``metric_names`` is not a list or ``output_transform`` is not callable
        """
        if metric_names is not None and not isinstance(metric_names, list):
            raise TypeError("metric_names should be a list, got {} instead".format(type(metric_names)))

        if output_transform is not None and not callable(output_transform):
            raise TypeError("output_transform should be a function, got {} instead"
                            .format(type(output_transform)))

        engine.add_event_handler(Events.ITERATION_COMPLETED, self._update, metric_names, output_transform)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._close)