Commit

#0: Fix dummy weight issue
yieldthought committed Jan 31, 2025
1 parent c22eb6a commit 62f1830
Showing 1 changed file with 15 additions and 7 deletions.
models/demos/llama3/tt/model_config.py: 22 changes (15 additions, 7 deletions)
@@ -161,6 +161,20 @@ def __init__(
             self._set_model_params(self.DEFAULT_CKPT_DIR)
         else:  # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders.
             self.checkpoint_type = CheckpointType.Meta
+            if "3.2-1B" in self.DEFAULT_CKPT_DIR:
+                local_params = "LLAMA3_2_1B_PARAMS"
+            elif "3.2-3B" in self.DEFAULT_CKPT_DIR:
+                local_params = "LLAMA3_2_3B_PARAMS"
+            elif "3.1-8B" in self.DEFAULT_CKPT_DIR:
+                local_params = "LLAMA3_1_8B_PARAMS"
+            elif "3.2-11B" in self.DEFAULT_CKPT_DIR:
+                local_params = "LLAMA3_2_11B_PARAMS"
+            elif "3.1-70B" in self.DEFAULT_CKPT_DIR:
+                local_params = "LLAMA3_1_70B_PARAMS"
+            else:
+                raise ValueError(
+                    f"No local params found for {self.DEFAULT_CKPT_DIR}, dummy weights are not supported for this model"
+                )
             self._set_model_params(self.LOCAL_LLAMA_PARAMS[local_params])

         # Set the max number of tokens for each prefill chunk based on the model and device
@@ -1069,31 +1083,25 @@ def _set_llama_params(self, checkpoint_dir):
         # Set the model name based on the checkpoint directory being loaded
         # FIXME: add a llama prefix to all llama-specific models and names
         if "3.2-1B" in checkpoint_dir:
-            local_params = "LLAMA3_2_1B_PARAMS"
             self.model_name = "Llama3.2-1B" + "-Instruct" if self.instruct else ""
             self.rope_scaling_factor = 32
         elif "3.2-3B" in checkpoint_dir:
-            local_params = "LLAMA3_2_3B_PARAMS"
             self.model_name = "Llama3.2-3B" + "-Instruct" if self.instruct else ""
             self.rope_scaling_factor = 32
         elif "3.1-8B" in checkpoint_dir:
-            local_params = "LLAMA3_1_8B_PARAMS"
             self.model_name = "Llama3.1-8B" + "-Instruct" if self.instruct else ""
             self.rope_scaling_factor = 8
         elif "3.2-11B" in checkpoint_dir:
-            local_params = "LLAMA3_2_11B_PARAMS"
             self.model_name = "Llama3.2-11B" + "-Instruct" if self.instruct else ""
             self.rope_scaling_factor = 8  # shared with 3.1-8B
         elif "3.1-70B" in checkpoint_dir:
-            local_params = "LLAMA3_1_70B_PARAMS"
             self.model_name = "Llama3.1-70B" + "-Instruct" if self.instruct else ""
             self.rope_scaling_factor = 8
             self.is_70b = True  # self.dim == 8192 and self.n_layers == 80
         else:
-            local_params = "UNKNOWN"
             self.model_name = "Unknown"
             self.rope_scaling_factor = 4
-            logger.warning(f"Unknown model: {LLAMA_DIR}")
+            logger.warning(f"Unknown Meta-style model: {checkpoint_dir}")
         self.orig_context_len = 8192

     def _set_hf_params(self, checkpoint_dir):
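For reference, a minimal stand-alone sketch of the checkpoint-directory to params-key selection that this commit moves into __init__ for dummy-weight runs. Only the substring-to-key mapping and the error message come from the diff; the function name and the example path are illustrative assumptions, not part of the repository.

# Hypothetical helper mirroring the selection logic added by this commit.
def pick_local_params_key(ckpt_dir: str) -> str:
    """Map a Meta-style checkpoint directory name to a LOCAL_LLAMA_PARAMS key."""
    mapping = {
        "3.2-1B": "LLAMA3_2_1B_PARAMS",
        "3.2-3B": "LLAMA3_2_3B_PARAMS",
        "3.1-8B": "LLAMA3_1_8B_PARAMS",
        "3.2-11B": "LLAMA3_2_11B_PARAMS",
        "3.1-70B": "LLAMA3_1_70B_PARAMS",
    }
    for marker, key in mapping.items():
        if marker in ckpt_dir:
            return key
    # Same failure mode as the commit: unsupported models cannot use dummy weights.
    raise ValueError(f"No local params found for {ckpt_dir}, dummy weights are not supported for this model")

# Example (hypothetical path): pick_local_params_key("/data/Llama3.2-1B-Instruct") returns "LLAMA3_2_1B_PARAMS"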
