import socket
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
import torch
from ignite.distributed.comp_models import (
    _SerialModel,
    has_hvd_support,
    has_native_dist_support,
    has_xla_support,
    registered_computation_models,
)
from ignite.utils import setup_logger
# Public names of this module (what ``from <module> import *`` exposes).
# It includes the helpers defined below as well as a few names re-exported from
# ``ignite.distributed.comp_models`` (``has_*_support``, ``registered_computation_models``).
__all__ = [
    "backend",
    "broadcast",
    "device",
    "available_backends",
    "model_name",
    "get_world_size",
    "get_rank",
    "get_local_rank",
    "get_nproc_per_node",
    "get_node_rank",
    "get_nnodes",
    "spawn",
    "initialize",
    "finalize",
    "show_config",
    "set_local_rank",
    "all_reduce",
    "all_gather",
    "barrier",
    "hostname",
    "has_xla_support",
    "has_native_dist_support",
    "has_hvd_support",
    "sync",
    "registered_computation_models",
    "one_rank_only",
]
# Currently active computation model. It starts as the serial (non-distributed) model and is
# replaced by ``sync``, ``initialize`` and ``finalize``.
_model = _SerialModel()
# While this flag is True and ``_model`` is the serial model, the public helpers below call
# ``sync(temporary=True)`` before delegating to ``_model``. ``_set_model`` clears the flag when a
# non-serial model is installed non-temporarily.
_need_to_sync = True
def sync(temporary: bool = False) -> None:
    """Helper method to force this module to synchronize with current distributed context.

    This method should be used when distributed context is manually created or destroyed.

    Every registered computation model except the serial one is asked to build itself from the
    existing context (``create_from_context``). The first model that succeeds becomes the active
    model. If none succeeds, the module falls back to the serial model.

    Args:
        temporary: If True, distributed model synchronization is done every call of ``idist.get_*`` methods.
            This may have a negative performance impact.
    """
    for comp_model_cls in registered_computation_models:
        if comp_model_cls == _SerialModel:
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
            _set_model(model, temporary=temporary)
            return

    # No distributed context found: fall back to the serial model. Going through ``_set_model``
    # (as ``finalize`` does) also re-arms ``_need_to_sync``; assigning ``_model`` directly would
    # leave lazy synchronization disabled if a non-temporary model had been installed before,
    # so a context created later would never be picked up automatically.
    _set_model(_SerialModel(), temporary=temporary)
def device() -> torch.device:
    """Returns current device according to current distributed configuration.

    - `torch.device("cpu")` if no distributed configuration or torch native gloo distributed configuration
    - `torch.device("cuda:local_rank")` if torch native nccl or horovod distributed configuration
    - `torch.device("xla:index")` if XLA distributed configuration

    Returns:
        torch.device

    .. versionchanged:: 0.4.2
        Added Horovod distributed framework.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.device()
def backend() -> Optional[str]:
    """Returns computation model's backend.

    - `None` for no distributed configuration
    - "nccl" or "gloo" or "mpi" for native torch distributed configuration
    - "xla-tpu" for XLA distributed configuration
    - "horovod" for Horovod distributed framework

    Returns:
        str or None

    .. versionchanged:: 0.4.2
        Added Horovod distributed framework.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.backend()
def available_backends() -> Tuple[str, ...]:
    """Returns available backends.

    The result is the concatenation of the ``available_backends`` tuples of all registered
    computation models, in registration order.
    """
    out: Tuple[str, ...] = ()
    for m in registered_computation_models:
        out += m.available_backends
    return out
def model_name() -> str:
    """Returns distributed configuration name (given by ignite)

    - `serial` for no distributed configuration
    - `native-dist` for native torch distributed configuration
    - `xla-dist` for XLA distributed configuration
    - `horovod-dist` for Horovod distributed framework

    .. versionchanged:: 0.4.2
        `horovod-dist` will be returned for Horovod distributed framework.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.name
def get_world_size() -> int:
    """Returns world size of current distributed configuration. Returns 1 if no distributed configuration."""
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_world_size()
def get_rank() -> int:
    """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration."""
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_rank()
def get_local_rank() -> int:
    """Returns local process rank within current distributed configuration.

    Returns 0 if no distributed configuration.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_local_rank()
def get_nproc_per_node() -> int:
    """Returns number of processes (or tasks) per node within current distributed configuration.

    Returns 1 if no distributed configuration.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_nproc_per_node()
def get_nnodes() -> int:
    """Returns number of nodes within current distributed configuration.

    Returns 1 if no distributed configuration.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_nnodes()
def get_node_rank() -> int:
    """Returns node rank within current distributed configuration.

    Returns 0 if no distributed configuration.
    """
    # While still on the serial model, re-check for a distributed context before answering.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_node_rank()
def hostname() -> str:
    """Returns host name for current process within current distributed configuration.

    The value comes straight from ``socket.gethostname()``; the active computation model is not
    consulted.
    """
    return socket.gethostname()
def spawn(
    backend: str,
    fn: Callable,
    args: Tuple,
    kwargs_dict: Optional[Mapping] = None,
    nproc_per_node: int = 1,
    **kwargs: Any,
) -> None:
    """Spawns ``nproc_per_node`` processes that run ``fn`` with ``args``/``kwargs_dict`` and initialize
    distributed configuration defined by ``backend``.

    Args:
        backend: backend to use: `nccl`, `gloo`, `xla-tpu`, `horovod`
        fn: function to called as the entrypoint of the spawned process.
            This function must be defined at the top level of a module so it can be pickled and spawned.
            This is a requirement imposed by multiprocessing. The function is called as ``fn(i, *args, **kwargs_dict)``,
            where `i` is the process index and args is the passed through tuple of arguments.
        args: arguments passed to `fn`.
        kwargs_dict: kwargs passed to `fn`.
        nproc_per_node: number of processes to spawn on a single node. Default, 1.
        kwargs: acceptable kwargs according to provided backend:

            - | "nccl" or "gloo" : ``nnodes`` (default, 1), ``node_rank`` (default, 0), ``master_addr``
              | (default, "127.0.0.1"), ``master_port`` (default, 2222), ``init_method`` (default, "env://"),
              | `timeout` to `dist.init_process_group`_ function
              | and kwargs for `mp.start_processes`_ function.

            - | "xla-tpu" : ``nnodes`` (default, 1), ``node_rank`` (default, 0) and kwargs to `xmp.spawn`_ function.

            - | "horovod": ``hosts`` (default, None) and other kwargs to `hvd_run`_ function. Arguments ``nnodes=1``
              | and ``node_rank=0`` are tolerated and ignored, otherwise an exception is raised.

    Raises:
        ValueError: if ``backend`` is not one of the available backends.

    Examples:
        1) Launch single node multi-GPU training using torch native distributed framework

        .. code-block:: python

            # >>> python main.py

            # main.py

            import ignite.distributed as idist

            def train_fn(local_rank, a, b, c, d=12):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == 4

                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")


            idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)

        2) Launch multi-node multi-GPU training using torch native distributed framework

        .. code-block:: python

            # >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
            # >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
            # >>> ...
            # >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222

            # main.py

            import torch
            import ignite.distributed as idist

            def train_fn(local_rank, nnodes, nproc_per_node):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == nnodes * nproc_per_node

                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")

            idist.spawn(
                "nccl",
                train_fn,
                args=(nnodes, nproc_per_node),
                nproc_per_node=nproc_per_node,
                nnodes=nnodes,
                node_rank=node_rank,
                master_addr=master_addr,
                master_port=master_port
            )

        3) Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA

        .. code-block:: python

            # >>> python main.py

            # main.py

            import ignite.distributed as idist

            def train_fn(local_rank, a, b, c, d=12):
                import torch_xla.core.xla_model as xm
                assert xm.get_world_size() == 8

                device = idist.device()
                assert "xla" in device.type


            idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)

    .. _dist.init_process_group: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
    .. _mp.start_processes: https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
    .. _xmp.spawn: http://pytorch.org/xla/release/1.6/index.html#torch_xla.distributed.xla_multiprocessing.spawn
    .. _hvd_run: https://horovod.readthedocs.io/en/latest/api.html#module-horovod.run

    .. versionchanged:: 0.4.2
        ``backend`` now accepts `horovod` distributed framework.
    """
    _assert_backend(backend)

    if kwargs_dict is None:
        kwargs_dict = {}

    for comp_model_cls in registered_computation_models:
        if backend not in comp_model_cls.available_backends:
            continue
        comp_model_cls.spawn(
            fn, args=args, kwargs_dict=kwargs_dict, nproc_per_node=nproc_per_node, backend=backend, **kwargs
        )
        # A backend is served by a single computation model: stop here so that one call can never
        # launch the processes twice if two registered models ever advertise the same backend.
        return
def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
    """Helper method to perform all reduce operation.

    Args:
        tensor: tensor or number to collect across participating processes.
        op: reduction operation, "SUM" by default. Possible values: "SUM", "PRODUCT", "MIN", "MAX", "AND", "OR".
            Horovod backend supports only "SUM", "AVERAGE", "ADASUM", "MIN", "MAX", "PRODUCT".

    Returns:
        torch.Tensor or number
    """
    # While still on the serial model, re-check for a distributed context before delegating.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.all_reduce(tensor, op)
def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
    """Helper method to perform all gather operation.

    Args:
        tensor: tensor or number or str to collect across participating processes.

    Returns:
        torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
        torch.Tensor of shape ``(world_size, )`` if input is a number or
        List of strings if input is a string
    """
    # While still on the serial model, re-check for a distributed context before delegating.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.all_gather(tensor)
def broadcast(
    tensor: Union[torch.Tensor, float, str, None], src: int = 0, safe_mode: bool = False
) -> Union[torch.Tensor, float, str]:
    """Helper method to perform broadcast operation.

    Args:
        tensor: tensor or number or str to broadcast to participating processes.
            Make sure to respect data type of torch tensor input for all processes, otherwise execution will crash.
            Can use None for non-source data with ``safe_mode=True``.
        src: source rank. Default, 0.
        safe_mode: if True, non source input data can be ``None`` or anything (will be discarded), otherwise data
            type of the input ``tensor`` should be respected for all processes. Please, keep in mind, this mode is
            working only for dense tensors as source input if a tensor is provided. There are additional collective
            ops are performed before doing the broadcast and, thus, can be slower than without using this mode.
            Default, False.

    Returns:
        torch.Tensor or string or number

    Examples:

        .. code-block:: python

            y = None
            if idist.get_rank() == 0:
                t1 = torch.rand(4, 5, 6, device=idist.device())
                s1 = "abc"
                x = 12.3456
                y = torch.rand(1, 2, 3, device=idist.device())
            else:
                t1 = torch.empty(4, 5, 6, device=idist.device())
                s1 = ""
                x = 0.0

            # Broadcast tensor t1 from rank 0 to all processes
            t1 = idist.broadcast(t1, src=0)
            assert isinstance(t1, torch.Tensor)

            # Broadcast string s1 from rank 0 to all processes
            s1 = idist.broadcast(s1, src=0)
            # >>> s1 = "abc"

            # Broadcast float number x from rank 0 to all processes
            x = idist.broadcast(x, src=0)
            # >>> x = 12.3456

            # Broadcast any of those types from rank 0,
            # but other ranks do not define the placeholder
            y = idist.broadcast(y, src=0, safe_mode=True)
            assert isinstance(y, torch.Tensor)

    .. versionadded:: 0.4.2

    .. versionchanged:: 0.4.5
        added ``safe_mode``
    """
    # While still on the serial model, re-check for a distributed context before delegating.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.broadcast(tensor, src=src, safe_mode=safe_mode)
def barrier() -> None:
    """Helper method to synchronize all processes."""
    # While still on the serial model, re-check for a distributed context before delegating.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    _model.barrier()
def set_local_rank(index: int) -> None:
    """Method to hint the local rank in case if torch native distributed context is created by user
    without using :meth:`~ignite.distributed.utils.initialize` or :meth:`~ignite.distributed.utils.spawn`.

    The value is stored on the class attribute ``ComputationModel._ext_local_rank``.

    Args:
        index: local rank or current process index

    Examples:
        User set up torch native distributed process group

        .. code-block:: python

            import ignite.distributed as idist

            def run(local_rank, *args, **kwargs):

                idist.set_local_rank(local_rank)
                # ...
                dist.init_process_group(**dist_info)
                # ...

    """
    # NOTE(review): imported at call time rather than at module level, presumably to avoid an
    # import cycle -- confirm before hoisting.
    from ignite.distributed.comp_models.base import ComputationModel

    ComputationModel._ext_local_rank = index
def _set_model(model: Any, temporary: bool = False) -> None:
    """Install ``model`` as the active computation model and update the lazy-sync flag.

    Lazy synchronization is switched off only when a non-serial model is installed
    non-temporarily; in every other case it stays (or becomes) enabled.

    Args:
        model: computation model instance to make active.
        temporary: if True, keep re-synchronizing on later helper calls.
    """
    global _model, _need_to_sync

    _model = model
    installed_for_good = not isinstance(model, _SerialModel) and not temporary
    _need_to_sync = not installed_for_good
def _assert_backend(backend: str) -> None:
    """Validate ``backend`` against the tuple returned by ``available_backends()``.

    Args:
        backend: backend name to check.

    Raises:
        ValueError: if ``backend`` is not among the available backends.
    """
    known = available_backends()
    if backend in known:
        return
    raise ValueError(f"Backend should be one of '{known}'")
def initialize(backend: str, **kwargs: Any) -> None:
    """Initializes distributed configuration according to provided ``backend``

    Args:
        backend: backend: `nccl`, `gloo`, `xla-tpu`, `horovod`.
        kwargs: acceptable kwargs according to provided backend:

            - | "nccl" or "gloo" : ``timeout(=timedelta(minutes=30))``, ``init_method(=None)``,
              | ``rank(=None)``, ``world_size(=None)``.
              | By default, ``init_method`` will be "env://". See more info about parameters: `torch_init`_.

            - | "horovod" : comm(=None), more info: `hvd_init`_.

    Raises:
        ValueError: if at least one distributed framework is supported and ``backend`` is not one
            of the available backends.

    Examples:
        Launch single node multi-GPU training with ``torchrun`` utility.

        .. code-block:: python

            # >>> torchrun --nproc_per_node=4 main.py

            # main.py

            import ignite.distributed as idist

            def train_fn(local_rank, a, b, c):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == 4

                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")

            backend = "nccl"  # or "gloo" or "horovod" or "xla-tpu"
            idist.initialize(backend)
            # or for torch native distributed on Windows:
            # idist.initialize("nccl", init_method="file://tmp/shared")
            local_rank = idist.get_local_rank()
            train_fn(local_rank, a, b, c)
            idist.finalize()

    .. _torch_init: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
    .. _hvd_init: https://horovod.readthedocs.io/en/latest/api.html#module-horovod.torch

    .. versionchanged:: 0.4.2
        ``backend`` now accepts `horovod` distributed framework.

    .. versionchanged:: 0.4.5
        ``kwargs`` now accepts ``init_method``, ``rank``, ``world_size`` for PyTorch native distributed backend.
    """
    if not (has_xla_support or has_native_dist_support or has_hvd_support):
        # nothing to do => serial model
        # maybe warn about this
        return

    _assert_backend(backend)

    for comp_model_cls in registered_computation_models:
        if backend not in comp_model_cls.available_backends:
            continue
        _set_model(comp_model_cls(backend, **kwargs))
        # A backend is served by a single computation model: stop here so that one call can never
        # build a second model and overwrite the first if two registered models ever advertise
        # the same backend.
        return
def finalize() -> None:
    """Finalizes distributed configuration. For example, in case of native pytorch distributed configuration,
    it calls ``dist.destroy_process_group()``.

    Afterwards the module switches back to the serial model, which also re-enables lazy
    synchronization (see ``_set_model``).
    """
    _model.finalize()
    _set_model(_SerialModel())
def show_config() -> None:
    """Helper method to display distributed configuration via ``logging``.

    One INFO record is emitted per property: configuration name, backend, device type, hostname,
    world size, rank, local rank, processes per node, number of nodes and node rank.
    """
    # setup parallel logger
    logger = setup_logger(__name__)

    logger.info(f"distributed configuration: {model_name()}")
    logger.info(f"backend: {backend()}")
    logger.info(f"device: {device().type}")
    logger.info(f"hostname: {hostname()}")
    logger.info(f"world size: {get_world_size()}")
    logger.info(f"rank: {get_rank()}")
    logger.info(f"local rank: {get_local_rank()}")
    logger.info(f"num processes per_node: {get_nproc_per_node()}")
    logger.info(f"num nodes: {get_nnodes()}")
    logger.info(f"node rank: {get_node_rank()}")
def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable:
    """Decorator to filter handlers wrt a rank number

    The decorated function runs only on the process whose ``get_rank()`` equals ``rank``; on every
    other process the wrapper returns ``None`` without calling it.

    Args:
        rank: rank number of the handler (default: 0).
        with_barrier: synchronisation with a barrier (default: False).

    Returns:
        A decorator that wraps the handler.

    Examples:
        .. code-block:: python

            engine = ...

            @engine.on(...)
            @one_rank_only() # means @one_rank_only(rank=0)
            def some_handler(_):
                ...

            @engine.on(...)
            @one_rank_only(rank=1)
            def some_handler(_):
                ...
    """

    def _one_rank_only(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            ret = None
            if get_rank() == rank:
                ret = func(*args, **kwargs)
            # The barrier is reached by every process, whether or not it ran ``func``.
            if with_barrier:
                barrier()
            return ret

        return wrapper

    return _one_rank_only