Skip to content

Commit

Permalink
Update Flamingo Builder to use Llama3ScaledRoPE instead of RotaryPositionalEmbeddings (pytorch#1202)
Browse files Browse the repository at this point in the history

* update rope class

* remove old rope class
  • Loading branch information
Gasoonjia authored Sep 25, 2024
1 parent 021fd32 commit 0b8ca05
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from torchchat.model import Model, ModelArgs, ModelType

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
Expand Down Expand Up @@ -402,7 +402,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
max_seq_len = decoder_config['max_seq_len']
rope_base = decoder_config['rope_base']
for submodule in model.modules():
if isinstance(submodule, RotaryPositionalEmbeddings):
if isinstance(submodule, Llama3ScaledRoPE):
submodule.__init__(head_dim, max_seq_len, rope_base)
state_dict = flamingo_meta_to_tune(checkpoint)
model.model.load_state_dict(state_dict, assign=True, strict=False)
Expand Down

0 comments on commit 0b8ca05

Please sign in to comment.