Fix how the denominator is computed in fast autoregressive transformer
maciejwolczyk committed Jul 4, 2024
1 parent d3ed90f commit a70ceeb
Showing 1 changed file with 1 addition and 1 deletion.
popgym/baselines/models/linear_attention.py (1 addition, 1 deletion)

@@ -90,7 +90,7 @@ def forward(
  # numerator = Q^T S
  numerator = torch.einsum("bti, btil -> btl", Q, S)
  # denominator = Q^T Z
- denominator = torch.einsum("bti, btl -> bt", Q, Z).reshape(B, T, 1) + 1e-5
+ denominator = torch.einsum("bti, bti -> bt", Q, Z).reshape(B, T, 1) + 1e-5
  # output = (Q^T S) / (Q^T Z)
  output = numerator / denominator

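Why the one-character index change matters: in the old spec "bti, btl -> bt", i and l are independent summation indices, so torch.einsum sums over each one separately and the result factorizes into (sum_i Q_i) * (sum_l Z_l) per batch and timestep. The intended denominator Q^T Z is the dot product sum_i Q_i Z_i, which the shared index in "bti, bti -> bt" produces. A minimal sketch of the difference (the shapes and random tensors below are illustrative assumptions, not code from the repository):

import torch

B, T, I = 2, 5, 8  # hypothetical batch, time, and feature sizes
Q = torch.randn(B, T, I)
Z = torch.randn(B, T, I)

# Old spec: "i" and "l" are separate dummy indices, so the einsum sums
# over both independently; the result factorizes into a product of sums.
buggy = torch.einsum("bti, btl -> bt", Q, Z)
assert torch.allclose(buggy, Q.sum(dim=-1) * Z.sum(dim=-1), atol=1e-5)

# New spec: the shared index "i" yields the intended dot product Q^T Z.
fixed = torch.einsum("bti, bti -> bt", Q, Z)
assert torch.allclose(fixed, (Q * Z).sum(dim=-1), atol=1e-5)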
