From 47890016ca6e525b90c4ff557b203c03ec602ce8 Mon Sep 17 00:00:00 2001 From: lgarg26 <154364140+lgarg26@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:22:26 -0500 Subject: [PATCH] Fixup Optional runopt cfg values handling during cfg_from_json_repr deserialization Differential Revision: D68445341 Pull Request resolved: https://github.com/pytorch/torchx/pull/1000 --- torchx/specs/api.py | 7 ++++++- torchx/specs/test/api_test.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 835dea53d..377ef7e55 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -948,7 +948,12 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: for key, val in cfg_dict.items(): runopt_ = self.get(key) if runopt_: - if runopt_.opt_type == List[str]: + # Optional runopt cfg values default their value to None, + # but use `_type` to specify their type when provided. + # Make sure not to treat None's as lists/dictionaries + if val is None: + cfg[key] = val + elif runopt_.opt_type == List[str]: cfg[key] = [str(v) for v in val] elif runopt_.opt_type == Dict[str, str]: cfg[key] = {str(k): str(v) for k, v in val.items()} diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 389102547..60dfc0dc6 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -555,6 +555,7 @@ def test_config_from_json_repr(self) -> None: opts.add("disable", type_=bool, default=True, help="") opts.add("complex_list", type_=List[str], default=[], help="") opts.add("complex_dict", type_=Dict[str, str], default={}, help="") + opts.add("default_none", type_=List[str], help="") self.assertDictEqual( { @@ -565,6 +566,7 @@ def test_config_from_json_repr(self) -> None: "disable": False, "complex_list": ["v1", "v2", "v3"], "complex_dict": {"k1": "v1", "k2": "v2"}, + "default_none": None, }, opts.resolve( opts.cfg_from_json_repr( @@ -575,7 +577,8 @@ def test_config_from_json_repr(self) -> None: "enable": true, "disable": false, "complex_list": ["v1", "v2", "v3"], - "complex_dict": {"k1": "v1", "k2": "v2"} + "complex_dict": {"k1": "v1", "k2": "v2"}, + "default_none": null }""" ) ),