diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 41ae3a5ba..e00e58de1 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1245,8 +1245,10 @@ def forward( cfg = self.config if positions is not None and max_seq_len is not None: if max_seq_len != positions.shape[-1]: - raise ValueError("Both `positions` and `max_seq_len` are provided and they " - "do not match. You only need to provide one of them.") + raise ValueError( + "Both `positions` and `max_seq_len` are provided and they " + "do not match. You only need to provide one of them." + ) if positions is None: if max_seq_len is None: raise ValueError(