Print model's frozen status after freezing
JimChienTW committed Nov 19, 2024
1 parent d1195a6 commit d31ee18
Showing 2 changed files with 52 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/llama_recipes/finetuning.py
@@ -41,6 +41,7 @@
    freeze_LLM_only,
    get_policies,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
@@ -194,6 +195,8 @@ def main(**kwargs):
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
@@ -234,12 +237,14 @@ def main(**kwargs):

    if not train_config.use_peft and train_config.freeze_layers:
        freeze_transformer_layers(model, train_config.num_freeze_layers)
        # Print the model's frozen status after freezing the specified layers
        print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

    if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
        freeze_LLM_only(model)
        # Print the model's frozen status after freezing the language model
        print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
    # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaVisionEncoderLayer in vision models
    if is_vision:
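Note on placement: print_frozen_model_status is called immediately after each freezing helper so the report reflects the parameters that were just frozen. For context, the freezing helpers follow the usual PyTorch pattern of disabling gradients; below is a minimal sketch of that pattern (the function name and the attribute path model.model.layers are assumptions for illustration, not part of this commit; the real helpers live in llama_recipes.utils.train_utils):

import torch.nn as nn

def freeze_first_layers(model: nn.Module, num_layers: int) -> None:
    # Hypothetical sketch: disable gradients on the first `num_layers` decoder layers
    # so they stay frozen during training. `model.model.layers` is assumed here.
    for i, layer in enumerate(model.model.layers):
        if i < num_layers:
            for param in layer.parameters():
                param.requires_grad = False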
46 changes: 45 additions & 1 deletion src/llama_recipes/utils/train_utils.py
@@ -486,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")

def print_frozen_model_status(model, config, rank: int = 0) -> None:
    """
    Print the frozen status of the model's modules and the number of trainable parameters after freezing.

    Args:
        model: The PyTorch model.
        config: The training config (provides model_name).
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank == 0:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("After freezing the model:")
        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")

        module_states = {}
        # Iterate over all parameters
        for name, param in model.named_parameters():
            # Extract the top-level module name (e.g., "vision_model", "language_model")
            top_module = name.split(".")[0]

            # Initialize a record for the top-level module
            if top_module not in module_states:
                module_states[top_module] = {"frozen": [], "unfrozen": []}

            # Group parameters into frozen or unfrozen
            if param.requires_grad:
                module_states[top_module]["unfrozen"].append(name)
            else:
                module_states[top_module]["frozen"].append(name)

        print("--> Model state after freezing:")
        # Analyze and print the results
        for module, states in module_states.items():
            frozen_params = states["frozen"]
            unfrozen_params = states["unfrozen"]

            if frozen_params and unfrozen_params:
                # Mixed state: both frozen and unfrozen parameters
                print(f" {module}: Mixed")
            elif frozen_params:
                # All parameters are frozen
                print(f" {module}: Frozen")
            else:
                # All parameters are unfrozen
                print(f" {module}: Unfrozen")
        print("")

def get_policies(cfg, rank):
"""Get the policies for mixed precision and fsdp wrapping"""
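To sanity-check the new helper outside the training script, it can be called on a toy model with one submodule frozen. A minimal sketch, assuming llama_recipes is importable; the ToyModel class, its submodule names, and the SimpleNamespace config are illustrative only and not part of this commit:

from types import SimpleNamespace

import torch.nn as nn

from llama_recipes.utils.train_utils import print_frozen_model_status

class ToyModel(nn.Module):
    """Two top-level submodules, loosely mirroring a vision/language split."""
    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(8, 8)
        self.language_model = nn.Linear(8, 8)

model = ToyModel()
for param in model.vision_model.parameters():
    param.requires_grad = False  # freeze the vision tower only

config = SimpleNamespace(model_name="toy-model")
print_frozen_model_status(model, config, rank=0)
# Expected report: vision_model -> Frozen, language_model -> Unfrozen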
