From 2278eea4da005b18e210e00c13015f96d6391deb Mon Sep 17 00:00:00 2001 From: Paul Pinchuk Date: Mon, 19 Aug 2024 11:43:44 -0600 Subject: [PATCH] Minor refactor --- gaps/cli/config.py | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/gaps/cli/config.py b/gaps/cli/config.py index 782cd6b..69d0674 100644 --- a/gaps/cli/config.py +++ b/gaps/cli/config.py @@ -244,22 +244,12 @@ def kickoff_jobs(self): keys_to_run, lists_to_run = self._keys_and_lists_to_run() jobs = sorted(product(*lists_to_run)) - num_jobs_submit = len(jobs) - self._warn_about_excessive_au_usage(num_jobs_submit) + self._warn_about_excessive_au_usage(len(jobs)) extra_exec_args = self._extract_extra_exec_args_for_command() - exec_kwargs = deepcopy(self.exec_kwargs) - num_test_nodes = exec_kwargs.pop("num_test_nodes", None) - if num_test_nodes is None: - num_test_nodes = float("inf") + for tag, values, exec_kwargs in self._with_tagged_context(jobs): - for node_index, values in enumerate(jobs): - if node_index >= num_test_nodes: - return self - - tag = _tag(node_index, num_jobs_submit) - self.ctx.obj["NAME"] = job_name = f"{self.job_name}{tag}" - node_specific_config = self._compile_node_config(tag, job_name) + node_specific_config = self._compile_node_config(tag) node_specific_config.update(extra_exec_args) for key, val in zip(keys_to_run, values): @@ -268,13 +258,31 @@ def kickoff_jobs(self): else: node_specific_config.update(dict(zip(key, val))) - cmd = self._compile_run_command(job_name, node_specific_config) + cmd = self._compile_run_command(node_specific_config) kickoff_job(self.ctx, cmd, exec_kwargs) return self - def _compile_node_config(self, tag, job_name): + def _with_tagged_context(self, jobs): + """Iterate over jobs and populate context with job name. """ + num_jobs_submit = len(jobs) + + exec_kwargs = deepcopy(self.exec_kwargs) + num_test_nodes = exec_kwargs.pop("num_test_nodes", None) + if num_test_nodes is None: + num_test_nodes = float("inf") + + for node_index, values in enumerate(jobs): + if node_index >= num_test_nodes: + return + + tag = _tag(node_index, num_jobs_submit) + self.ctx.obj["NAME"] = f"{self.job_name}{tag}" + yield tag, values, exec_kwargs + + def _compile_node_config(self, tag): """Compile initial node-specific config. """ + job_name = self.ctx.obj["NAME"] node_specific_config = deepcopy(self.config) node_specific_config.pop("execution_control", None) node_specific_config.update( @@ -287,15 +295,14 @@ def _compile_node_config(self, tag, job_name): "job_name": job_name, "out_dir": self.project_dir.as_posix(), "out_fpath": self._suggested_stem(job_name).as_posix(), - "run_method": getattr( - self.command_config, "run_method", None - ), + "run_method": getattr(self.command_config, "run_method", None), } ) return node_specific_config - def _compile_run_command(self, job_name, node_specific_config): + def _compile_run_command(self, node_specific_config): """Create run command from config and job name. """ + job_name = self.ctx.obj["NAME"] cmd = "; ".join(_CMD_LIST).format( run_func_module=self.command_config.runner.__module__, run_func_name=self.command_config.runner.__name__,