Skip to content

Commit

Permalink
Update relative_transformer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yhcc authored Jul 6, 2020
1 parent 00a1894 commit d2614d5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions modules/relative_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, d_model, n_head, dropout, r_w_bias=None, r_r_bias=None, scale
:param rel_pos_embed:
"""
super().__init__()
self.qv_linear = nn.Linear(d_model, d_model * 2, bias=False)
self.qkv_linear = nn.Linear(d_model, d_model * 3, bias=False)
self.n_head = n_head
self.head_dim = d_model // n_head
self.dropout_layer = nn.Dropout(dropout)
Expand Down Expand Up @@ -113,10 +113,10 @@ def forward(self, x, mask):
batch_size, max_len, d_model = x.size()
pos_embed = self.pos_embed(mask) # l x head_dim

qv = self.qv_linear(x) # batch_size x max_len x d_model2
q, v = torch.chunk(qv, chunks=2, dim=-1)
qkv = self.qkv_linear(x) # batch_size x max_len x d_model3
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
q = q.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
k = x.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
k = k.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
v = v.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) # b x n x l x d

rw_head_q = q + self.r_r_bias[:, None]
Expand Down

0 comments on commit d2614d5

Please sign in to comment.