From 701c8ef97ba91d347f321a00c9136a60b62c534b Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Wed, 22 Jan 2025 12:15:48 -0500
Subject: [PATCH 1/4] Nail in edge case of torch dtype

---
 src/transformers/modeling_utils.py | 24 +++++++++++++++++
 src/transformers/utils/__init__.py |  7 +++++
 tests/utils/test_modeling_utils.py | 43 ++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a1e2db6c0883..4bf5bdf5b19c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -98,6 +98,7 @@
     logging,
     replace_return_docstrings,
     strtobool,
+    test_injection,
 )
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
@@ -245,6 +246,25 @@ def set_zero3_state():
         _is_ds_init_called = False


+def restore_default_torch_dtype(func):
+    """
+    Decorator to restore the default torch dtype
+    at the end of the function. Serves
+    as a backup in case loading a model in raises
+    an error.
+    """
+
+    @wraps(func)
+    def _wrapper(*args, **kwargs):
+        old_dtype = torch.get_default_dtype()
+        try:
+            return func(*args, **kwargs)
+        finally:
+            torch.set_default_dtype(old_dtype)
+
+    return _wrapper
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
@@ -1401,6 +1421,7 @@ def add_model_tags(self, tags: Union[List[str], str]) -> None:
                 self.model_tags.append(tag)

     @classmethod
+    @restore_default_torch_dtype
     def _from_config(cls, config, **kwargs):
         """
         All context managers that the model should be initialized under go here.
@@ -1422,6 +1443,7 @@ def _from_config(cls, config, **kwargs):
         dtype_orig = None
         if torch_dtype is not None:
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+            test_injection()

         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
@@ -3138,6 +3160,7 @@ def float(self, *args):
         return super().float(*args)

     @classmethod
+    @restore_default_torch_dtype
     def from_pretrained(
         cls: Type[SpecificPreTrainedModelType],
         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -4059,6 +4082,7 @@ def from_pretrained(
                 for key in config.sub_configs.keys():
                     value = getattr(config, key)
                     value.torch_dtype = default_dtype
+        test_injection()

         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index e5aedf5916fa..4f07971456ac 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -320,3 +320,10 @@ def get_available_devices() -> FrozenSet[str]:
         devices.add("musa")

     return frozenset(devices)
+
+
+def test_injection():
+    """
+    An injection point for testing hard-to-mock functions
+    """
+    return True
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 84b5ebbb24ce..33a224be2baf 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -316,6 +316,10 @@ def check_models_equal(model1, model2):
     return models_are_equal


+def fake_from_config(*args, **kwargs):
+    raise RuntimeError()
+
+
 @require_torch
 class ModelUtilsTest(TestCasePlus):
     @slow
@@ -1818,6 +1822,45 @@ def test_cache_when_needed_at_train_time(self):
         self.assertIsNone(model_outputs.past_key_values)
         self.assertTrue(model.training)

+    def test_restore_default_torch_dtype_from_pretrained(self):
+        """
+        Tests that the default torch dtype is restored
+        when an error happens during the loading of a model.
+        """
+        old_dtype = torch.get_default_dtype()
+        # set default type to float32
+        torch.set_default_dtype(torch.float32)
+        # Mock injection point which is right after the call to `_set_default_torch_dtype`
+        with mock.patch("transformers.modeling_utils.test_injection", side_effect=RuntimeError()):
+            with self.assertRaises(RuntimeError):
+                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
+        # default should still be float32
+        assert torch.get_default_dtype() == torch.float32
+        torch.set_default_dtype(old_dtype)
+
+    def test_restore_default_torch_dtype_from_config(self):
+        """
+        Tests that the default torch dtype is restored
+        when an error happens during the loading of a model.
+ """ + old_dtype = torch.get_default_dtype() + # set default type to float32 + torch.set_default_dtype(torch.float32) + + config = AutoConfig.from_pretrained( + TINY_MISTRAL, + ) + # Mock injection point which is right after the call to `_set_default_torch_dtype` + with mock.patch("transformers.modeling_utils.test_injection", side_effect=RuntimeError()): + with self.assertRaises(RuntimeError): + config.torch_dtype = torch.float16 + _ = AutoModelForCausalLM.from_config( + config, + ) + # default should still be float32 + assert torch.get_default_dtype() == torch.float32 + torch.set_default_dtype(old_dtype) + @slow @require_torch From ee9900f222c52877326db4fa63414ca121fba67d Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 22 Jan 2025 12:42:22 -0500 Subject: [PATCH 2/4] Rm unused func --- tests/utils/test_modeling_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 33a224be2baf..18c9a6f32aa0 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -316,10 +316,6 @@ def check_models_equal(model1, model2): return models_are_equal -def fake_from_config(*args, **kwargs): - raise RuntimeError() - - @require_torch class ModelUtilsTest(TestCasePlus): @slow From 3c1a799b83f120d2c6ec1f6ff481acc932f5fde2 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 23 Jan 2025 15:03:40 -0500 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4bf5bdf5b19c..7f58a87b3d9d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -250,8 +250,8 @@ def restore_default_torch_dtype(func): """ Decorator to restore the default torch dtype at the end of the function. Serves - as a backup in case loading a model in raises - an error. + as a backup in case calling the function raises + an error after the function has changed the default dtype but before it could restore it. """ @wraps(func) From 86eac8f1c81a1af6216de7f27aa72365e20503a4 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 23 Jan 2025 15:12:31 -0500 Subject: [PATCH 4/4] Refactor tests to only mock what we need, don't introduce injection functions --- src/transformers/modeling_utils.py | 3 --- src/transformers/utils/__init__.py | 7 ------- tests/utils/test_modeling_utils.py | 31 ++++++++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7f58a87b3d9d..c7c6d0520c26 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -98,7 +98,6 @@ logging, replace_return_docstrings, strtobool, - test_injection, ) from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files from .utils.import_utils import ( @@ -1443,7 +1442,6 @@ def _from_config(cls, config, **kwargs): dtype_orig = None if torch_dtype is not None: dtype_orig = cls._set_default_torch_dtype(torch_dtype) - test_injection() config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. 
From 86eac8f1c81a1af6216de7f27aa72365e20503a4 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 23 Jan 2025 15:12:31 -0500
Subject: [PATCH 4/4] Refactor tests to only mock what we need, don't introduce injection functions

---
 src/transformers/modeling_utils.py |  3 ---
 src/transformers/utils/__init__.py |  7 -------
 tests/utils/test_modeling_utils.py | 31 ++++++++++++++++++++++++++++--
 3 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7f58a87b3d9d..c7c6d0520c26 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -98,7 +98,6 @@
     logging,
     replace_return_docstrings,
     strtobool,
-    test_injection,
 )
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
@@ -1443,7 +1442,6 @@ def _from_config(cls, config, **kwargs):
         dtype_orig = None
         if torch_dtype is not None:
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
-            test_injection()

         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
@@ -4082,7 +4080,6 @@ def from_pretrained(
                 for key in config.sub_configs.keys():
                     value = getattr(config, key)
                     value.torch_dtype = default_dtype
-        test_injection()

         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 4f07971456ac..e5aedf5916fa 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -320,10 +320,3 @@ def get_available_devices() -> FrozenSet[str]:
         devices.add("musa")

     return frozenset(devices)
-
-
-def test_injection():
-    """
-    An injection point for testing hard-to-mock functions
-    """
-    return True
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 18c9a6f32aa0..f833724ebff7 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -39,6 +39,7 @@
     AutoModelForSequenceClassification,
     DynamicCache,
     LlavaForConditionalGeneration,
+    MistralForCausalLM,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -1826,8 +1827,21 @@ def test_restore_default_torch_dtype_from_pretrained(self):
         old_dtype = torch.get_default_dtype()
         # set default type to float32
         torch.set_default_dtype(torch.float32)
+
         # Mock injection point which is right after the call to `_set_default_torch_dtype`
-        with mock.patch("transformers.modeling_utils.test_injection", side_effect=RuntimeError()):
+        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
+
+        def debug(*args, **kwargs):
+            # call the method as usual, then raise a RuntimeError
+            original_set_default_torch_dtype(*args, **kwargs)
+            raise RuntimeError
+
+        AutoModelForCausalLM._set_default_torch_dtype = debug
+
+        with mock.patch(
+            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
+            side_effect=debug,
+        ):
             with self.assertRaises(RuntimeError):
                 _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
         # default should still be float32
@@ -1846,8 +1860,21 @@ def test_restore_default_torch_dtype_from_config(self):
         config = AutoConfig.from_pretrained(
             TINY_MISTRAL,
         )
+
         # Mock injection point which is right after the call to `_set_default_torch_dtype`
-        with mock.patch("transformers.modeling_utils.test_injection", side_effect=RuntimeError()):
+        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
+
+        def debug(*args, **kwargs):
+            # call the method as usual, then raise a RuntimeError
+            original_set_default_torch_dtype(*args, **kwargs)
+            raise RuntimeError
+
+        AutoModelForCausalLM._set_default_torch_dtype = debug
+
+        with mock.patch(
+            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
+            side_effect=debug,
+        ):
             with self.assertRaises(RuntimeError):
                 config.torch_dtype = torch.float16
                 _ = AutoModelForCausalLM.from_config(