
Switch to shared function to check torch version #5368

Closed
wants to merge 2 commits
28 changes: 4 additions & 24 deletions deepspeed/comm/torch.py
@@ -9,6 +9,7 @@
from .backend import *
from .comm import *
from ..runtime import compiler
+from ..runtime.utils import required_torch_version
Collaborator (Author): Evaluate moving this to utils to resolve the circular import issues.

import os

DS_COMM_ALL_GATHER_OFF = False
@@ -18,40 +19,19 @@
DS_COMM_REDUCE_OFF = False


-def is_torch_ver_eq_2_0():
-    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
-    if TORCH_MAJOR == 2 and TORCH_MINOR == 0:
-        return True
-    return False
-
-
-def is_torch_ver_ge_2_1():
-    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
-    if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1:
-        return True
-    return False
-
-
-def torch_ver_ge_1_13():
-    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
-    if TORCH_MAJOR >= 1 and TORCH_MINOR >= 13:
-        return True
-    return False
-
-
def has_coalescing_manager():
    has_c10d = hasattr(torch.distributed, 'distributed_c10d')
    return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')


def has_all_reduce_coalesced():
-    return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()
+    return hasattr(torch.distributed, "all_reduce_coalesced") and required_torch_version(min_version=1.13)


def get_coalescing_manager(group, device, reqs, async_op):
-    if is_torch_ver_eq_2_0():
+    if required_torch_version(min_version=2.0, max_version=2.0):
        return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
-    elif is_torch_ver_ge_2_1():
+    elif required_torch_version(min_version=2.1):
        return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, async_ops=async_op)
    else:
        return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
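
The diff only shows the call sites; the shared `required_torch_version` helper itself lives in `deepspeed/runtime/utils.py` and is not part of this file. As a rough illustration of the semantics the call sites rely on (optional `min_version`/`max_version` bounds on the installed torch release), a minimal sketch might look like the code below. This is an assumption for illustration, not the actual DeepSpeed implementation.

```python
import torch


def required_torch_version(min_version=None, max_version=None):
    """Illustrative sketch: return True if the installed torch falls within the
    given (major, minor) bounds. NOT the real deepspeed/runtime/utils.py code."""
    assert min_version is not None or max_version is not None, \
        "Must provide a min_version or max_version argument"

    # torch.__version__ can look like "2.0.1+cu118"; keep only major.minor.
    current = tuple(int(x) for x in torch.__version__.split('.')[:2])

    def as_tuple(version):
        # Accept floats like 2.0 or strings like "1.13" and normalize to (major, minor).
        parts = str(version).split('.')
        return int(parts[0]), (int(parts[1]) if len(parts) > 1 else 0)

    if min_version is not None and current < as_tuple(min_version):
        return False
    if max_version is not None and current > as_tuple(max_version):
        return False
    return True
```

Comparing `(major, minor)` tuples also sidesteps a corner case in the removed helpers: `is_torch_ver_ge_2_1()` required `TORCH_MINOR >= 1` regardless of the major version, so it would have returned False on a hypothetical torch 3.0. Passing versions as strings (e.g. `"1.13"`) rather than floats avoids float-formatting pitfalls for minor versions ending in zero (e.g. `1.10` stringifies to `"1.1"`).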