diff --git a/dmlcloud/util/distributed.py b/dmlcloud/util/distributed.py index 4db2e65..9a10450 100644 --- a/dmlcloud/util/distributed.py +++ b/dmlcloud/util/distributed.py @@ -1,6 +1,7 @@ import os from contextlib import contextmanager +import torch import torch.distributed as dist from .tcp import find_free_port, get_local_ips @@ -87,7 +88,7 @@ def init_process_group_dummy(**kwargs): """ backend = kwargs.get('backend', None) if backend is None: - backend = 'cpu:gloo,cuda:nccl' if dist.is_nccl_available() else 'gloo' + backend = 'cpu:gloo,cuda:nccl' if dist.is_nccl_available() and torch.cuda.is_available() else 'gloo' store = dist.HashStore() dist.init_process_group(store=store, rank=0, world_size=1, backend=backend, **kwargs)