diff --git a/pyproject.toml b/pyproject.toml index d154432..1b79d06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aind-data-transformation" -description = "Generated from aind-library-template" +description = "Generic Etl Job template that can be imported" license = {text = "MIT"} requires-python = ">=3.8" authors = [ diff --git a/src/aind_data_transformation/core.py b/src/aind_data_transformation/core.py index 868b4fc..fbba0d9 100644 --- a/src/aind_data_transformation/core.py +++ b/src/aind_data_transformation/core.py @@ -9,6 +9,8 @@ from pydantic import BaseModel, ConfigDict, Field from pydantic_settings import BaseSettings, SettingsConfigDict +PathLike = TypeVar("PathLike", str, Path) + def get_parser() -> argparse.ArgumentParser: """ @@ -49,8 +51,8 @@ class BasicJobSettings(BaseSettings): """Model to define Transformation Job Configs""" model_config = SettingsConfigDict(env_prefix="TRANSFORMATION_JOB_") - input_source: Path - output_directory: Path + input_source: PathLike + output_directory: PathLike @classmethod def from_config_file(cls, config_file_location: Path): @@ -92,7 +94,16 @@ def __init__(self, job_settings: _T): job_settings : _T Generic type that is bound by the BaseSettings class. """ - self.job_settings = job_settings + self.job_settings = job_settings.model_copy(deep=True) + # Parse str into Paths + if isinstance(self.job_settings.input_source, str): + self.job_settings.input_source = Path( + self.job_settings.input_source + ) + if isinstance(self.job_settings.output_directory, str): + self.job_settings.output_directory = Path( + self.job_settings.output_directory + ) @abstractmethod def run_job(self) -> JobResponse: diff --git a/tests/test_core.py b/tests/test_core.py index 8893bdc..5d8b3de 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -48,12 +48,22 @@ def setUpClass(cls) -> None: """Set up tests with basic job settings and etl job""" basic_settings = ExampleJobSettings( param=2, - input_source=Path("some_input_dir"), - output_directory=Path("some_output_dir"), + input_source="some_input_dir", + output_directory="some_output_dir", ) cls.basic_settings = basic_settings cls.basic_job = ExampleJob(job_settings=basic_settings) + def test_settings_with_paths(self): + """Tests JobSettings can be set with Path types if desired.""" + basic_settings = ExampleJobSettings( + param=2, + input_source=Path("some_input_dir"), + output_directory=Path("some_out_dir"), + ) + self.assertEqual(Path("some_input_dir"), basic_settings.input_source) + self.assertEqual(Path("some_out_dir"), basic_settings.output_directory) + def test_load_cli_args_json_str(self): """Tests loading json string defined in command line args""" job_settings_json = self.basic_settings.model_dump_json()