xm.all_reduce groups argument is ignored when lowered to HLO #1107

Environment
Trainium 32 instance
Problem statement
While working on a bug in the training loss, I noticed an issue with xm.all_reduce. I use this function to reduce the losses across the different DP ranks. The resulting loss is far too large: it appears to scale with the TP size, which should not happen.
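For context, the reduction call in my training loop looks roughly like the sketch below; the helper name and the DP group layout are illustrative assumptions, not the exact code from my script. The relevant part is simply that a groups list is passed to xm.all_reduce.

```python
import torch_xla.core.xla_model as xm

def reduce_loss_over_dp(loss, dp_groups):
    # dp_groups: one list of global ordinals per data-parallel group,
    # e.g. [[0, 1], [2, 3]] for 4 ranks with DP size 2 (illustrative layout).
    # The loss should be averaged only over the ranks of this rank's group.
    return xm.all_reduce(
        xm.REDUCE_SUM, loss, scale=1.0 / len(dp_groups[0]), groups=dp_groups
    )
```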
So I made a small script to reproduce what I think could be a bug:
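The sketch below is a reconstruction of that kind of standalone reproduction; the group layout, the per-rank values, and the xmp.spawn launcher are illustrative assumptions rather than the exact script. The point is only that each rank should see the sum over its own small replica group, not over the whole world.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()

    # Split the world into groups of two consecutive ranks: [[0, 1], [2, 3], ...]
    # (illustrative layout; assumes an even world size).
    groups = [[i, i + 1] for i in range(0, world_size, 2)]

    # Each rank contributes its own ordinal so the per-group sums are easy to check.
    value = torch.full((2,), float(rank), device=device)
    reduced = xm.all_reduce(xm.REDUCE_SUM, value, groups=groups)
    xm.mark_step()

    # Expected: the sum over this rank's two-rank group only.
    # Reported behaviour: the sum over all ranks, i.e. groups is ignored.
    print(f"rank {rank}: {reduced.cpu()}")


if __name__ == "__main__":
    # One common way to launch one process per device; the command actually
    # used for this report may differ.
    xmp.spawn(_mp_fn)
```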
I then run the code with the following command:
The output log is:
We would expect the output to be tensor([0., 0.], device='xla:0'), but we observe tensor([28., 28.], device='xla:0') instead.

When analyzing the HLO code associated with the reduction part, we get:
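One way to inspect the lowered HLO for the reduction, using torch_xla's debugging hooks (a sketch; any of the usual dump methods works), is:

```python
import torch_xla

def dump_pending_hlo(tensor):
    # Print the HLO text for the pending lazy computation that produces
    # `tensor`. _get_xla_tensors_hlo is torch_xla's internal debugging helper;
    # call this before xm.mark_step() so the all-reduce is still in the graph.
    print(torch_xla._XLAC._get_xla_tensors_hlo([tensor]))

# Alternatively, the documented debug environment variables
# XLA_SAVE_TENSORS_FILE=/tmp/graphs.txt and XLA_SAVE_TENSORS_FMT=hlo
# save every executed graph to a file.
```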
It seems that the groups argument of the xm.all_reduce function is completely ignored, since we can see that the all-reduce op carries replica_groups={}, which explains why we observe tensor([28., 28.], device='xla:0'): the reduction happens across every rank.

It seems to be a torch_xla issue, but I am posting it here since I am observing it on a Trainium instance.

cc @aws-rhsoln @jeffhataws
Comments

Thanks @michaelbenayoun for filing the issue. We will take a look and get back to you.

I was not able to reproduce the issue. It seems that having "torch_neuronx.xla.configure_pjrt_environment()" in the code was causing errors for me, so I removed it. I tested the code without it and see:

PT2.5:

PT2.1:

When I reverted neuronx-cc to the version you have, 2.15.128.0+56dc5a86, I was able to reproduce your bad result.