Fix encoder and encoder-decoder LayerNorm in PyTorch WMT model
runame committed Dec 12, 2023
1 parent d229b98 commit 76d0f44
Showing 1 changed file with 4 additions and 2 deletions.
algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py (6 changes: 4 additions & 2 deletions)
@@ -283,7 +283,8 @@ def __init__(self,
         layer_norm_eps=layer_norm_eps,
         attention_temp=attention_temp,
         norm_first=norm_first)
-    encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
+    encoder_norm = (
+        nn.LayerNorm(d_model, eps=layer_norm_eps) if norm_first else None)
     self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
 
   def forward(self,
@@ -577,7 +578,8 @@ def __init__(self,
             norm_first=norm_first) for _ in range(num_layers)
     ])
     self.num_layers = num_layers
-    self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
+    self.norm = (
+        nn.LayerNorm(d_model, eps=layer_norm_eps) if norm_first else None)
 
   def forward(self,
               tgt: Tensor,
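
Both hunks make the same change: the final LayerNorm applied after the stack of encoder or decoder layers is now created only when norm_first is set. In a pre-LN ("norm_first") Transformer the residual stream is not normalized at the end of each sub-layer, so one LayerNorm after the last layer is required; in a post-LN Transformer every sub-layer already ends with a LayerNorm, so an unconditional final norm would normalize the output a second time. Below is a minimal sketch of the same pattern using torch.nn's built-in encoder classes rather than the workload's custom TransformerEncoder and TransformerEncoderLayer (which additionally take an attention_temp argument); the sizes are illustrative, not taken from the commit.

import torch
import torch.nn as nn

d_model, nhead, nlayers, layer_norm_eps = 16, 4, 2, 1e-6

for norm_first in (True, False):
  encoder_layer = nn.TransformerEncoderLayer(
      d_model,
      nhead,
      layer_norm_eps=layer_norm_eps,
      batch_first=True,
      norm_first=norm_first)
  # Mirror of the fix: only the pre-LN stack gets a final LayerNorm; for
  # post-LN, pass None so the encoder output is not normalized twice.
  encoder_norm = (
      nn.LayerNorm(d_model, eps=layer_norm_eps) if norm_first else None)
  encoder = nn.TransformerEncoder(encoder_layer, nlayers, norm=encoder_norm)
  out = encoder(torch.randn(2, 5, d_model))  # (batch, seq_len, d_model)
  print(norm_first, out.shape)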
