diff --git a/torax/time_step_calculator/config.py b/torax/time_step_calculator/config.py new file mode 100644 index 00000000..d361a8ab --- /dev/null +++ b/torax/time_step_calculator/config.py @@ -0,0 +1,43 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic config for time step calculators.""" + +import enum +from torax.time_step_calculator import chi_time_step_calculator +from torax.time_step_calculator import fixed_time_step_calculator +from torax.time_step_calculator import time_step_calculator +from torax.torax_pydantic import torax_pydantic + + +@enum.unique +class TimeStepCalculatorType(enum.Enum): + """Types of time step calculators.""" + + CHI = 'chi' + FIXED = 'fixed' + + +class TimeStepCalculator(torax_pydantic.BaseModelMutable): + """Config for a time step calculator.""" + + calculator_type: TimeStepCalculatorType = TimeStepCalculatorType.CHI + + @property + def time_step_calculator(self) -> time_step_calculator.TimeStepCalculator: + match self.calculator_type: + case TimeStepCalculatorType.CHI: + return chi_time_step_calculator.ChiTimeStepCalculator() + case TimeStepCalculatorType.FIXED: + return fixed_time_step_calculator.FixedTimeStepCalculator() diff --git a/torax/torax_pydantic/model_config.py b/torax/torax_pydantic/model_config.py new file mode 100644 index 00000000..48f70cb1 --- /dev/null +++ b/torax/torax_pydantic/model_config.py @@ -0,0 +1,28 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic config for Torax.""" + +from torax.time_step_calculator import config as time_step_calculator_config +from torax.torax_pydantic import model_base + + +class ToraxConfig(model_base.BaseModelMutable): + """Base config class for Torax. + + Attributes: + time_step_calculator: Config for the time step calculator. + """ + + time_step_calculator: time_step_calculator_config.TimeStepCalculator diff --git a/torax/torax_pydantic/tests/model_config_test.py b/torax/torax_pydantic/tests/model_config_test.py new file mode 100644 index 00000000..32a94c27 --- /dev/null +++ b/torax/torax_pydantic/tests/model_config_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the `torax.config` module.""" + +from absl.testing import absltest +from absl.testing import parameterized +from torax.config import config_loader +from torax.torax_pydantic import model_config + + +class ConfigTest(parameterized.TestCase): + """Unit tests for the `torax.config` module.""" + + def test_full_config_construction(self): + """Test for basic config construction.""" + + module = config_loader.import_module( + ".tests.test_data.test_iterhybrid_newton", + config_package="torax", + ) + + # Test only the subset of config fields that are currently supported. + module_config = { + key: module.CONFIG[key] + for key in model_config.ToraxConfig.model_fields.keys() + } + config_pydantic = model_config.ToraxConfig.from_dict(module_config) + + self.assertEqual( + config_pydantic.time_step_calculator.calculator_type.value, + module_config["time_step_calculator"]["calculator_type"], + ) + + # The full model should always be serializable. + with self.subTest("json_serialization"): + config_json = config_pydantic.model_dump_json() + config_pydantic_roundtrip = model_config.ToraxConfig.model_validate_json( + config_json + ) + self.assertEqual(config_pydantic, config_pydantic_roundtrip) + + +if __name__ == "__main__": + absltest.main()