diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 5eb6d8ca98..6fa35c4980 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -619,8 +619,7 @@ def tensor_hook( hooks = [] for _, module in state_dict_model.named_modules(): - hooks.append( - module._register_state_dict_hook(tensor_hook),) + hooks.append(module._register_state_dict_hook(tensor_hook),) state_dict = get_model_state_dict( state_dict_model,