
xm.all_reduce groups argument is ignored when lowered to HLO #1107

Open
michaelbenayoun opened this issue Feb 5, 2025 · 4 comments

@michaelbenayoun

Environment

Trainium 32 instance

torch==2.1.2
torch-neuronx==2.1.2.2.3.0
torch-xla==2.1.4

Problem statement

While working on a bug in the training loss, I noticed an issue with xm.all_reduce. I use this function to reduce the losses across the different DP ranks:

tr_loss_div = tr_loss / dp_size
reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True))

The loss is far too large; it appears to scale with the TP size, which should not happen.
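
For context, the groups argument is a list of replica groups, each itself a list of global ordinals, and the reduction is supposed to run only within each group. Below is a minimal sketch of the kind of DP-group layout a helper like get_data_parallel_group(as_list=True) is assumed to produce (the tp_size/dp_size values and the TP-major rank layout are illustrative assumptions, not Optimum Neuron's actual implementation):

# Illustrative only: build one DP group per TP slot for a tp_size x dp_size mesh,
# assuming a TP-major rank layout.
tp_size, dp_size = 8, 4  # assumed values for illustration

dp_groups = [
    [tp_rank + dp_rank * tp_size for dp_rank in range(dp_size)]
    for tp_rank in range(tp_size)
]
# dp_groups == [[0, 8, 16, 24], [1, 9, 17, 25], ..., [7, 15, 23, 31]]

If the groups are honored, each loss is summed over dp_size ranks only; if they are dropped, the sum runs over all tp_size * dp_size ranks, inflating the reduced loss by exactly a factor of the TP size, which matches the observation above.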

So I wrote a small script to reproduce what I think could be a bug:

import os

import torch
import torch.distributed as dist
import torch_neuronx
import torch_xla.core.xla_model as xm

torch_neuronx.xla.configure_pjrt_environment()

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="xla")

# Process group containing a single rank (not used below; xm.all_reduce takes `groups` directly).
group = dist.new_group(ranks=[rank], backend="xla")

# Each rank holds a tensor filled with its own rank index.
tensor = torch.ones(2).to(xm.xla_device()) * rank
xm.mark_step()

# With one-rank groups [[0], [1], ..., [world_size - 1]], the all-reduce should be a
# per-rank no-op, so rank 0 should print tensor([0., 0.]).
result = xm.all_reduce("sum", tensor, groups=[[i] for i in range(world_size)])
xm.mark_step()
xm.master_print("Result", result)

I then run the code with the following command:

torchrun --nproc_per_node=8 all_reduce_bug.py

The output log is:

2025-02-05 11:24:33.000445:  232433  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3192192757194213551+e30acd3a/model.neff
2025-02-05 11:24:33.000445:  232436  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:33.000445:  232437  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:33.000453:  232433  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:33.000453:  232436  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:33.000454:  232437  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:33.000460:  232439  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:33.000461:  232434  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:33.000470:  232439  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:33.000470:  232434  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-Feb-05 11:24:33.0688 232439:234943 [7] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:33.0695 232439:234943 [7] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:33.0702 232439:234943 [7] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:33.0709 232439:234943 [7] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:33.0763 232437:234617 [5] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:33.0770 232437:234617 [5] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:33.0778 232437:234617 [5] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:33.0784 232437:234617 [5] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:33.0864 232434:234974 [2] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:33.0865 232436:234582 [4] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:33.0865 232433:234559 [1] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:33.0865 232433:234559 [1] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:33.0865 232433:234559 [1] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:33.0865 232433:234559 [1] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:33.0872 232434:234974 [2] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:33.0879 232436:234582 [4] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:33.0886 232434:234974 [2] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:33.0893 232436:234582 [4] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:33.0900 232434:234974 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:33.0906 232436:234582 [4] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-02-05 11:24:41.000636:  232435  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:41.000638:  232432  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3027922214684590202+e30acd3a/model.neff
2025-02-05 11:24:41.000645:  232435  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:41.000646:  232432  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
2025-02-05 11:24:41.000656:  232438  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_3313400438540464710+e30acd3a/model.neff
2025-02-05 11:24:41.000665:  232438  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_10559991267090666722+e30acd3a/model.neff
Result 2025-Feb-05 11:24:41.0926 232435:235260 [3] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:41.0933 232435:235260 [3] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:41.0939 232432:235305 [0] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:41.0940 232435:235260 [3] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:41.0947 232432:235305 [0] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:41.0954 232435:235260 [3] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:41.0961 232432:235305 [0] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:41.0974 232432:235305 [0] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
2025-Feb-05 11:24:42.0036 232438:235516 [6] nccl_net_ofi_create_plugin:201 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2025-Feb-05 11:24:42.0043 232438:235516 [6] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Feb-05 11:24:42.0050 232438:235516 [6] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Feb-05 11:24:42.0057 232438:235516 [6] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
tensor([28., 28.], device='xla:0')

We would expect the output to be tensor([0., 0.], device='xla:0'), but we observe tensor([28., 28.], device='xla:0') instead: 28 = 0 + 1 + ... + 7, i.e. the sum over all 8 ranks rather than within the single-rank groups.

Analyzing the HLO associated with the reduction, we get:

HloModule SyncTensorsGraph.14, entry_computation_layout={(f32[], f32[2]{0})->(f32[2]{0})}

%AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
  %x.7 = f32[] parameter(0)
  %y.8 = f32[] parameter(1)
  ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
}

ENTRY %SyncTensorsGraph.14 (p0.1: f32[], p1.2: f32[2]) -> (f32[2]) {
  %p1.2 = f32[2]{0} parameter(1), frontend_attributes={neff_input_names="input1"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}
  %all-reduce.10 = (f32[2]{0}, f32[]) all-reduce(f32[2]{0} %p1.2, f32[] %p0.1), replica_groups={}, to_apply=%AddComputation.6
  %get-tuple-element.11 = f32[2]{0} get-tuple-element((f32[2]{0}, f32[]) %all-reduce.10), index=0
  ROOT %tuple.13 = (f32[2]{0}) tuple(f32[2]{0} %get-tuple-element.11), frontend_attributes={neff_output_names="output0"}
}

It seems that the groups argument of xm.all_reduce is completely ignored: the lowered op has replica_groups={}, so the reduction runs across every rank, which explains the tensor([28., 28.], device='xla:0') result. Had the per-rank groups been honored, the all-reduce would be expected to carry replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}.
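
For anyone who wants to reproduce the dump, the pending HLO for result can be printed from the repro script before the final mark_step using torch_xla's internal debug helper (a sketch; _get_xla_tensors_hlo is an internal API and may change between torch_xla versions):

import torch_xla

# Print the not-yet-executed HLO for `result`, so the replica_groups attribute
# of the all-reduce can be inspected directly.
result = xm.all_reduce("sum", tensor, groups=[[i] for i in range(world_size)])
print(torch_xla._XLAC._get_xla_tensors_hlo([result]))
xm.mark_step()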

It seems to be a torch_xla issue, but I am posting it here since I am observing it on a Trainium instance.

cc @aws-rhsoln @jeffhataws

@aws-satyajith

Thanks @michaelbenayoun for filing the issue. We will take a look and get back to you.

@jeffhataws
Contributor

jeffhataws commented Feb 10, 2025

I was not able to reproduce the issue. Having "torch_neuronx.xla.configure_pjrt_environment()" in the code was causing errors for me, so I removed it. Testing the code without it, I see:

PT2.5:

(aws_neuron_venv) ubuntu@ip-10-3-190-82:~/ktest2/pytorch/examples/dp_bert_hf_pretrain$ pip list | grep neuron
libneuronxla                 2.1.335.0
neuronx-cc                   2.0.103529.0a0+aee70f1c
torch-neuronx                2.5.1.2.4.0
(aws_neuron_venv) ubuntu@ip-10-3-190-82:~/ktest2/pytorch/examples/dp_bert_hf_pretrain$ torchrun --nproc_per_node=8 ~/test_allreduce_hf.py 
...
tensor([0., 0.], device='xla:0')

PT2.1:

(aws_neuron_venv_pt21) ubuntu@ip-10-3-190-82:~$ pip list | grep neuron                                                                        
libneuronxla                  2.1.374.0
neuronx-cc                    2.16.345.0+69131dd3
torch-neuronx                 2.1.2.2.4.0
(aws_neuron_venv_pt21) ubuntu@ip-10-3-190-82:~$ torchrun --nproc_per_node=8 ~/test_allreduce_hf.py                                                                                       
...
tensor([0., 0.], device='xla:0')

@michaelbenayoun
Author

I just tried again in a clean environment that I created today on a new instance, without calling torch_neuronx.xla.configure_pjrt_environment(), and I still get the issue.

Environment:

(neuron) ➜  ~ pip list | grep neuron
libneuronxla             2.1.714.0
neuronx-cc               2.15.128.0+56dc5a86
neuronx-distributed      0.9.0
neuronx-hwm              2.11.0.2+e34678757
optimum-neuron           0.0.28.dev1         /home/ubuntu/optimum-neuron
torch-neuronx            2.1.2.2.3.0

Result:

[Screenshot of the output, still showing tensor([28., 28.], device='xla:0')]

I see our environments do not match, so could that be the cause?

@jeffhataws
Contributor

When I reverted neuronx-cc to the version you have (2.15.128.0+56dc5a86), I was able to reproduce your bad result (tensor([28., 28.], device='xla:0')). That compiler version is from SDK 2.20.0; I also see the problem with the compilers from the patch releases SDK 2.20.1 and 2.20.2. Please switch to SDK 2.21 (neuronx-cc==2.16.345.0), where I see the good result (tensor([0., 0.], device='xla:0')).
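
A possible upgrade command (assuming the standard Neuron pip repository; adjust the other Neuron package pins to the matching SDK 2.21 release as needed):

pip install --upgrade "neuronx-cc==2.16.345.0" --extra-index-url https://pip.repos.neuron.amazonaws.com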
