
Tkurth/extended distributed primitives #273

Merged

Conversation

azrael417
Contributor

Modulus Pull Request

Description

This PR enables gathering of tensors of uneven shapes. This is necessary for integrating Modulus into newer versions of Makani. Some of the routines could be merged with the V-routines for the graph NN code; I haven't done that yet, but I am happy to discuss it.
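
For context, a minimal sketch of the general pattern for gathering unevenly shaped tensors (exchange per-rank sizes, pad to the largest size, all_gather, then slice back). The function name and signature here are illustrative only and may not match the utilities actually added in modulus/distributed/utils.py:

```python
# Illustrative sketch only; the actual helpers in modulus/distributed/utils.py
# may use different names and signatures.
import torch
import torch.distributed as dist


def gather_uneven(tensor: torch.Tensor, dim: int = 0, group=None) -> list:
    """Gather tensors whose size along `dim` differs across ranks."""
    world_size = dist.get_world_size(group=group)

    # 1. Exchange the per-rank sizes along the gather dimension.
    local_size = torch.tensor([tensor.shape[dim]], device=tensor.device)
    size_list = [torch.empty_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group=group)
    sizes = [int(s.item()) for s in size_list]

    # 2. Pad the local tensor to the maximum size so all_gather sees equal shapes.
    max_size = max(sizes)
    pad_shape = list(tensor.shape)
    pad_shape[dim] = max_size
    padded = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
    padded.narrow(dim, 0, tensor.shape[dim]).copy_(tensor)

    # 3. Gather the padded tensors and slice each one back to its true size.
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)
    return [g.narrow(dim, 0, s) for g, s in zip(gathered, sizes)]
```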

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

No new dependencies necessary

@azrael417
Contributor Author

Concerning the changelog, how does it need to be updated? Doesn't that also depend on what was merged between this PR and other PRs that came before it but after I forked the branch?

@azrael417 azrael417 requested a review from akshaysubr December 7, 2023 14:34
Signed-off-by: Thorsten Kurth <[email protected]>
@azrael417 azrael417 force-pushed the tkurth/extended-distributed-primitives branch from d7bdb57 to f4fbd15 on December 7, 2023 14:45
Collaborator

@akshaysubr akshaysubr left a comment


Looks good to me. Only had a couple of relatively minor comments.

It would be good to also ensure that all the existing distributed tests pass locally by running `pytest -m multigpu` in the test/ directory, since these tests are not currently covered by CI.

Review threads (all resolved):

  • modulus/distributed/utils.py
  • modulus/distributed/utils.py
  • modulus/distributed/mappings.py
@akshaysubr akshaysubr requested a review from stadlmax December 11, 2023 19:45
@akshaysubr akshaysubr added the "distributed: Distributed and model parallel tools" and "4 - In Review: Currently Under Review" labels on Dec 11, 2023
Collaborator

@stadlmax stadlmax left a comment


Sorry, forgot to approve this earlier. Thanks for addressing the comment w.r.t. unified utilities. LGTM now.

@NickGeneva
Collaborator

/blossom-ci

@azrael417
Contributor Author

azrael417 commented Dec 13, 2023

I ran the multigpu tests but ran into some issues. First, there is an assert that num_gpus == 2 (not >= 2), so these tests fail on my DGX Station with 4 GPUs. Can we relax that criterion a bit?

Working around it with CUDA_VISIBLE_DEVICES, I can run some of the tests, but the meshgraphnet one fails; I think this is not related to this MR, though:

```
models/meshgraphnet/test_meshgraphnet_snmg.py FFF                        [100%]

=================================== FAILURES ===================================
____________________ test_distributed_meshgraphnet[dtype0] _____________________

dtype = torch.float32

    @pytest.mark.multigpu
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_distributed_meshgraphnet(dtype):
        num_gpus = torch.cuda.device_count()
        assert num_gpus >= 2, "Not enough GPUs available for test"
        world_size = num_gpus

>       torch.multiprocessing.spawn(
            run_test_distributed_meshgraphnet,
            args=(world_size, dtype),
            nprocs=world_size,
            start_method="spawn",
        )

models/meshgraphnet/test_meshgraphnet_snmg.py:193:

../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:246: in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:202: in start_processes
    while not context.join():

self = <torch.multiprocessing.spawn.ProcessContext object at 0x7f7f41a258a0>
timeout = None

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()
```

These look good:

```
test_multi_gpu_sample.py .                                               [  8%]
distributed/test_autograd.py ....                                        [ 41%]
distributed/test_config.py .                                             [ 50%]
distributed/test_distributed_fft.py .                                    [ 58%]
distributed/test_manager.py ..                                           [ 75%]
```

@akshaysubr
Collaborator

Yeah, the meshgraphnet failure is not related to this MR and is an independent issue. I created a separate issue to track it: #278

@akshaysubr
Collaborator

/blossom-ci

@stadlmax
Collaborator

> I ran the multigpu tests but ran into some issues. First, there is an assert that num_gpus == 2 (not >= 2), so these tests fail on my DGX Station with 4 GPUs. Can we relax that criterion a bit? Working around it with CUDA_VISIBLE_DEVICES, I can run some of the tests, but the meshgraphnet one fails; I think this is not related to this MR, though:

I actually started using things like >= 2 and setting world_size = num_gpus in a few places where I changed things. @akshaysubr, we really should get rid of the assert ... == 2 checks eventually.
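
For reference, the relaxed guard being discussed looks roughly like the sketch below, mirroring the pattern already visible in the test_meshgraphnet_snmg.py traceback above. The test and worker names are hypothetical, not the actual test suite code:

```python
# Sketch of the relaxed multi-GPU test guard; names are illustrative only.
import pytest
import torch
import torch.distributed as dist


def run_test_example(rank: int, world_size: int):
    # Hypothetical per-rank worker: set up a process group, do work, tear down.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    dist.destroy_process_group()


@pytest.mark.multigpu
def test_example_distributed():
    num_gpus = torch.cuda.device_count()
    # Accept any machine with at least two GPUs instead of requiring exactly two.
    assert num_gpus >= 2, "Not enough GPUs available for test"
    world_size = num_gpus  # use every visible GPU rather than a hard-coded pair

    torch.multiprocessing.spawn(
        run_test_example,
        args=(world_size,),
        nprocs=world_size,
        start_method="spawn",
    )
```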

@azrael417 azrael417 merged commit fd80783 into NVIDIA:main Dec 13, 2023