From 4f9ed099db24c1b1f7a96f60c5255011e6852bdf Mon Sep 17 00:00:00 2001
From: Sergii Dymchenko
Date: Fri, 10 Jan 2025 17:11:39 -0800
Subject: [PATCH] Use log1p(x) instead of log(1+x)

This function is more accurate than torch.log() for small values of input -
https://pytorch.org/docs/stable/generated/torch.log1p.html

Found with TorchFix https://github.com/pytorch-labs/torchfix/

Signed-off-by: Sergii Dymchenko
---
 transformer_engine/pytorch/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 55c8a2fcf2..f07c663e64 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1586,7 +1586,7 @@ def flash_attn_fwd_softmax_lse_correction(
     """Merge softmax stats of each step in Attention with context parallelism"""
     max_scale = torch.max(softmax_lse, softmax_lse_per_step)
     min_scale = torch.min(softmax_lse, softmax_lse_per_step)
-    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
     softmax_lse.copy_(new_scale)
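
For context (not part of the patch): a minimal sketch of why torch.log1p is preferred
here. When the argument is tiny, computing 1 + x first rounds away the low-order bits
before the log is taken, whereas log1p(x) preserves them; the merged log-sum-exp value
is otherwise unchanged. The tensor values below are illustrative, not taken from the
repository.

    import torch

    # Tiny positive value, comparable to exp(min_scale - max_scale) when one
    # step's softmax LSE is much smaller than the other's.
    x = torch.tensor(1e-10, dtype=torch.float32)

    naive = torch.log(1 + x)   # 1 + 1e-10 rounds to 1.0 in float32, so this is 0.0
    stable = torch.log1p(x)    # keeps the small contribution, ~1e-10
    print(naive.item(), stable.item())

    # The same correction as in the patch, on illustrative per-step LSE tensors.
    softmax_lse = torch.tensor([10.0, 3.0])
    softmax_lse_per_step = torch.tensor([2.0, 25.0])
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    print(new_scale)  # elementwise logsumexp of the two stats, without overflow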