Nail in edge case of torch dtype being overridden permanently in the case of an error #35845
@@ -39,6 +39,7 @@
    AutoModelForSequenceClassification,
    DynamicCache,
    LlavaForConditionalGeneration,
    MistralForCausalLM,
    OwlViTForObjectDetection,
    PretrainedConfig,
    is_torch_available,
@@ -1818,6 +1819,71 @@ def test_cache_when_needed_at_train_time(self):
        self.assertIsNone(model_outputs.past_key_values)
        self.assertTrue(model.training)

    def test_restore_default_torch_dtype_from_pretrained(self):
""" | ||
Tests that the default torch dtype is restored | ||
when an error happens during the loading of a model. | ||
""" | ||
old_dtype = torch.get_default_dtype() | ||
# set default type to float32 | ||
torch.set_default_dtype(torch.float32) | ||
|
||
# Mock injection point which is right after the call to `_set_default_torch_dtype` | ||
original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype | ||
|
||
        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        AutoModelForCausalLM._set_default_torch_dtype = debug
Review comment: Is this line needed, given the patch in the next line?
        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)
Review comment: I wonder if the test itself has the same issue as the original bug: if the test fails, this line is never called, which could result in subsequent tests failing because the default dtype is permanently changed.
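A minimal sketch of one way to address that concern (not part of this diff; the test class and method names below are made up for illustration): register the dtype restore as a unittest cleanup before changing it, so the restore runs even when the test body raises or an assertion fails.

import unittest

import torch


class DefaultDtypeRestoreExample(unittest.TestCase):
    def test_example(self):
        old_dtype = torch.get_default_dtype()
        # Register the restore first: unittest runs cleanups even when the test fails,
        # so the process-wide default dtype cannot leak into later tests.
        self.addCleanup(torch.set_default_dtype, old_dtype)
        torch.set_default_dtype(torch.float32)
        # ... test body that may raise or fail ...
        self.assertEqual(torch.get_default_dtype(), torch.float32)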
    def test_restore_default_torch_dtype_from_config(self):
        """
        Tests that the default torch dtype is restored
        when an error happens during the loading of a model.
        """
        old_dtype = torch.get_default_dtype()
        # set default type to float32
        torch.set_default_dtype(torch.float32)

        config = AutoConfig.from_pretrained(
            TINY_MISTRAL,
        )

        # Mock injection point which is right after the call to `_set_default_torch_dtype`
        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype

        def debug(*args, **kwargs):
            # call the method as usual, then raise a RuntimeError
            original_set_default_torch_dtype(*args, **kwargs)
            raise RuntimeError

        AutoModelForCausalLM._set_default_torch_dtype = debug
        with mock.patch(
            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
            side_effect=debug,
        ):
            with self.assertRaises(RuntimeError):
                config.torch_dtype = torch.float16
                _ = AutoModelForCausalLM.from_config(
                    config,
                )
        # default should still be float32
        assert torch.get_default_dtype() == torch.float32
        torch.set_default_dtype(old_dtype)

@slow
@require_torch
Review comment: I think this approach should generally work. It would fail if there was ever a case where calling from_pretrained should purposefully change the default dtype for the rest of the process. I'm not sure if such a use case exists, just wanted to highlight it.
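For reference, the restore-on-error pattern that the PR title describes can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual transformers implementation; load_fn is a hypothetical stand-in for the real loading path.

import torch


def load_with_temporary_default_dtype(load_fn, torch_dtype):
    # Remember the process-wide default before touching it.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch_dtype)
    try:
        return load_fn()
    finally:
        # Restore even if load_fn raises, so an error during loading never
        # leaves the temporary dtype in place for the rest of the process.
        torch.set_default_dtype(old_dtype)

As the comment above notes, this shape assumes no caller relies on from_pretrained intentionally changing the default dtype for the rest of the process.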