diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 55c8a2fcf2..f07c663e64 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1586,7 +1586,7 @@ def flash_attn_fwd_softmax_lse_correction(
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
-    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)
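For context, here is a minimal standalone sketch (not part of the patch) of the log-sum-exp merge that the changed line performs; the helper name `merge_lse` is illustrative only. It shows that `max + log1p(exp(min - max))` equals `log(exp(a) + exp(b))` (compare `torch.logaddexp`), and why `torch.log1p` is preferred: when the two LSE values are far apart, `exp(min - max)` is close to zero and `1 + x` rounds away the small term before `log` is applied, whereas `log1p(x)` keeps it.

```python
import torch


def merge_lse(lse_a: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    """Return log(exp(lse_a) + exp(lse_b)) in a numerically stable way.

    This mirrors the merge done in flash_attn_fwd_softmax_lse_correction:
    factor out the larger statistic so the exponential never overflows,
    and use log1p for accuracy when the remaining term is tiny.
    """
    max_scale = torch.max(lse_a, lse_b)
    min_scale = torch.min(lse_a, lse_b)
    return max_scale + torch.log1p(torch.exp(min_scale - max_scale))


if __name__ == "__main__":
    a = torch.tensor([10.0, 50.0, -30.0], dtype=torch.float32)
    b = torch.tensor([12.0, -20.0, -31.0], dtype=torch.float32)
    merged = merge_lse(a, b)
    # torch.logaddexp computes the same quantity and serves as a reference.
    reference = torch.logaddexp(a, b)
    print(torch.allclose(merged, reference))  # True
```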