Add the option to tie_word_embeddings (pytorch#1260)
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart authored Oct 9, 2024
1 parent 286527c commit 438ebb1
Showing 1 changed file with 10 additions and 0 deletions.
torchchat/model.py (10 additions, 0 deletions)
@@ -281,6 +281,8 @@ class TransformerArgs:
     # Optional biases
     attention_bias: bool = False
     feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
         if config.stage_idx == config.n_stages - 1:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
         else:
             self.norm = None
             self.output = None
 
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])
 
     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
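For context on what the new option does, a minimal, self-contained sketch of weight tying follows; TinyLM, its sizes, and its layer names are illustrative assumptions, not torchchat code. Tying makes the token embedding and the output projection share one parameter, so the matrix is stored once and both paths update the same weights.

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy model used only to illustrate tying the input and output embeddings."""

    def __init__(self, vocab_size: int = 100, dim: int = 16, tie_word_embeddings: bool = True):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # nn.Linear(dim, vocab_size) stores its weight as (vocab_size, dim), the same
        # shape as the embedding table, so the two modules can share storage.
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if tie_word_embeddings:
            # Both modules now hold the same Parameter object; gradients from
            # either path accumulate into the single shared tensor.
            self.output.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))

model = TinyLM()
assert model.output.weight is model.tok_embeddings.weight
logits = model(torch.randint(0, 100, (2, 5)))  # shape (2, 5, 100)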

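The load-time hook presumably exists because a checkpoint saved with tied weights can omit the output projection entry. A small sketch of the same setdefault pattern, using made-up tensor sizes and the key names from the hook above:

import torch

# A tied-weights checkpoint may carry only the embedding entry.
state_dict = {
    "model.tok_embeddings.weight": torch.randn(100, 16),
    # "model.output.weight" is intentionally absent.
}

# Fill in the missing output key from the embedding weight, but leave it
# untouched if the checkpoint already provides its own output weight.
state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])

assert state_dict["model.output.weight"] is state_dict["model.tok_embeddings.weight"]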