Skip to content

Commit

Permalink
Eac/com cam (#14)
Browse files Browse the repository at this point in the history
* fix up extract_datasets

* fix up plot_group_factory

* improve name_utils.update_include_dict

* remove output_catalog_basename

* fix up flavor pipeline override code
  • Loading branch information
eacharles authored Feb 4, 2025
1 parent f601b26 commit 64313a6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 35 deletions.
10 changes: 6 additions & 4 deletions src/rail/plotting/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,12 @@ def extract_datasets(
"""
extractor_cls = load_extractor_class(extractor_class)
project = RailProject.load_config(config_file)
output_data = extractor_cls.generate_dataset_dict(
project=project,
**kwargs,
)
output_data = {
'Data': extractor_cls.generate_dataset_dict(
project=project,
**kwargs,
)
}
with open(output_yaml, "w", encoding="utf-8") as fout:
yaml.dump(output_data, fout)

Expand Down
12 changes: 8 additions & 4 deletions src/rail/plotting/plot_group_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,22 @@ def make_instance_yaml(
dataset_path = re.sub(
".*rail_project_config", "${RAIL_PROJECT_CONFIG_DIR}", dataset_yaml_path
)
output: list[dict[str, Any]] = []
output.append(dict(Includes=[plotter_path, dataset_path]))
plot_groups: list[dict] = []
for ds_name in dataset_list_name:
group_name = f"{output_prefix}{ds_name}_{plotter_list_name}"
output.append(
plot_groups.append(
dict(
PlotGroup=dict(
name=group_name,
plotter_list_name=plotter_list_name,
dataset_dict_name=ds_name,
dataset_list_name=ds_name,
)
)
)

output: dict[str, Any] = dict(
Includes=[plotter_path, dataset_path],
PlotGroups=plot_groups,
)
with open(output_yaml, "w", encoding="utf-8") as fout:
yaml.dump(output, fout)
7 changes: 6 additions & 1 deletion src/rail/projects/name_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def update_include_dict(
"""
for key, val in include_dict.items():
if isinstance(val, dict) and key in orig_dict:
update_include_dict(orig_dict[key], val)
sub_dict = orig_dict[key]
if isinstance(sub_dict, dict):
update_include_dict(orig_dict[key], val)
continue
if isinstance(val, (dict, list)):
orig_dict[key] = val.copy()
else:
orig_dict[key] = val

Expand Down
45 changes: 24 additions & 21 deletions src/rail/projects/pipeline_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,13 @@ class RailPipelineTemplate(Configurable):
),
input_catalog_template=StageParameter(
str,
None,
"",
fmt="%s",
msg="Template to use for input catalog",
),
output_catalog_template=StageParameter(
str,
None,
"",
fmt="%s",
msg="Template to use for output catalog",
),
Expand All @@ -390,12 +390,6 @@ class RailPipelineTemplate(Configurable):
fmt="%s",
msg="Basename to use for input catalog",
),
output_catalog_basename=StageParameter(
str,
"",
fmt="%s",
msg="Basename to use for output catalog",
),
input_file_templates=StageParameter(
dict,
{},
Expand Down Expand Up @@ -509,18 +503,6 @@ def build(
pass

pipeline_kwargs = self.config.kwargs.copy()
for key, val in pipeline_kwargs.items():
if isinstance(val, list) and "all" in val:
if key == "selectors":
pipeline_kwargs[key] = project.get_spec_selections()
elif key == "algorithms":
pipeline_kwargs[key] = project.get_pzalgorithms()
elif key == "classifiers":
pipeline_kwargs[key] = project.get_classifiers()
elif key == "summarizers":
pipeline_kwargs[key] = project.get_summarizers()
elif key == "error_models":
pipeline_kwargs[key] = project.get_error_models()

if self.config.pipeline_overrides:
copy_overrides = self.config.pipeline_overrides.copy()
Expand All @@ -529,11 +511,32 @@ def build(
pipe_out_dir, f"{self.config.name}_overrides.yml"
)

kwarg_overrides = copy_overrides.pop('kwargs', {})
pipeline_kwargs.update(**kwarg_overrides)

with open(stages_config, "w", encoding="utf-8") as fout:
yaml.dump(copy_overrides, fout)
else:
stages_config = None

for key, val in pipeline_kwargs.items():
if key == "selectors":
temp_dict = project.get_spec_selections()
elif key == "algorithms":
temp_dict= project.get_pzalgorithms()
elif key == "classifiers":
temp_dict = project.get_classifiers()
elif key == "summarizers":
temp_dict = project.get_summarizers()
elif key == "error_models":
temp_dict = project.get_error_models()
else:
continue
if 'all' in val:
pipeline_kwargs[key] = temp_dict
else:
pipeline_kwargs[key] = {algo_name_: temp_dict[algo_name_] for algo_name_ in val}

catalog_tag = project.get_flavor(self.config.flavor).get("catalog_tag", None)
if catalog_tag:
catalog_utils.apply_defaults(catalog_tag)
Expand Down Expand Up @@ -636,7 +639,7 @@ def make_pipeline_catalog_commands(
)
sink_catalog_files = project.get_catalog_files(
pipeline_info.config.output_catalog_template,
basename=pipeline_info.config.output_catalog_basename,
basename='output.hdf5',
flavor=self.config.flavor,
**kwargs,
)
Expand Down
10 changes: 5 additions & 5 deletions src/rail/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def build_pipelines(
"""
flavor_dict = self.get_flavor(flavor)
pipelines_to_build = flavor_dict["pipelines"]
pipeline_overrides = flavor_dict.get("pipeline_overrides", {})
all_flavor_overrides = flavor_dict.get("pipeline_overrides", {}).copy()
do_all = "all" in pipelines_to_build

ok = 0
Expand All @@ -496,11 +496,11 @@ def build_pipelines(
print(f"Skipping existing pipeline {output_yaml}")
continue

overrides = pipeline_overrides.get("default", {}).copy()
overrides.update(**pipeline_overrides.get(pipeline_name, {}))

overrides = all_flavor_overrides.get("default", {}).copy()
pipeline_overrides = all_flavor_overrides.get(pipeline_name, {}).copy()
overrides.update(**pipeline_overrides)
pipeline_instance = pipeline_info.make_instance(
self, flavor, pipeline_overrides
self, flavor, overrides
)
ok |= pipeline_instance.build(self)

Expand Down

0 comments on commit 64313a6

Please sign in to comment.