Eac/nz extraction (#22)
* added nz_extraction stuff

* fix nz_data_extractor.py
eacharles authored Feb 5, 2025
1 parent 4db09ac commit 2fe3385
Showing 16 changed files with 244 additions and 96 deletions.
22 changes: 11 additions & 11 deletions src/rail/cli/rail_plot/plot_commands.py
@@ -9,13 +9,13 @@

from . import plot_options

__all__=[
'plot_cli',
'run_command',
'inspect_command',
'extract_datasets_command',
'make_plot_groups',
]
__all__ = [
"plot_cli",
"run_command",
"inspect_command",
"extract_datasets_command",
"make_plot_groups",
]


@click.group()
@@ -27,7 +27,7 @@ def plot_cli() -> None:
sets of standard plots and html pages to help browse them.
The configuration file can include these yaml_tags
1. `Plots` with type of plots available
2. `Data` with specific datasets we can make those plots with
3. `PlotGroup` with combinations of the two
@@ -90,7 +90,7 @@ def extract_datasets_command(
extractor_class is able to extract and write the
results to the output_yaml file.
"""

control.clear()
control.extract_datasets(
config_file,
@@ -112,9 +112,9 @@ def make_plot_groups(output_yaml: str, **kwargs: dict[str, Any]) -> int:
"""Combine plotters with availble datsets
This will read the plotter_yaml and dataset_yaml
files, and combine all the datasets in the list
files, and combine all the datasets in the list
given by dataset_list_name with all the plots given
by plotter_list_name and write the results to the output_yaml
by plotter_list_name and write the results to the output_yaml
file.
"""
control.clear()
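
For orientation, here is a minimal sketch of driving the same extraction step from Python instead of the CLI. It assumes the `control` module used above is importable as `rail.plotting.control`; the file names, the `output_yaml` keyword, and the extractor class path are illustrative guesses, not values taken from this repository.

    from rail.plotting import control  # assumed import path for the control module used above

    control.clear()
    control.extract_datasets(
        "my_project.yaml",  # placeholder RailProject configuration file
        extractor_class="rail.plotting.nz_data_extractor.NZTomoBinDataExtractor",  # hypothetical class path
        output_yaml="extracted_datasets.yaml",  # assumed keyword for the output file
    )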
2 changes: 0 additions & 2 deletions src/rail/cli/rail_project/__init__.py
@@ -1,3 +1 @@
from .project_commands import *


24 changes: 12 additions & 12 deletions src/rail/cli/rail_project/project_commands.py
@@ -10,14 +10,14 @@


__all__ = [
'project_cli',
'inspect_command',
'build_command',
'subsample_command',
'reduce_command',
'run_group',
"project_cli",
"inspect_command",
"build_command",
"subsample_command",
"reduce_command",
"run_group",
]


@click.group()
@click.version_option(__version__)
@@ -28,7 +28,7 @@ def project_cli() -> None:
defining the RailProject.
That file can, in turn, include other yaml
configuration files that define a 'library' of
configuration files that define a 'library' of
possible analysis components
"""

@@ -65,7 +65,7 @@ def build_command(config_file: str, **kwargs: Any) -> int:
This will build all of the pipelines associated to
a particular flavor or flavors, and write them to the
the project pipelines area.
the project pipelines area.
"""
project = RailProject.load_config(config_file)
flavors = project.get_flavor_args(kwargs.pop("flavor"))
@@ -93,10 +93,10 @@ def subsample_command(
a catalog of input files
This will:
resolve a catalog of input files from the catalog_template,
resolve a catalog of input files from the catalog_template,
flavor, selection and basename parameters,
resolve a single output file from the file_template, flavor and selection
parameters,
parameters,
subsample from the catalog files and write to the output file.
"""

@@ -157,7 +157,7 @@ def reduce_command(
"""Reduce the roman rubin simulations for analysis
This will:
resolve a catalog of input files from the catalog_template,
resolve a catalog of input files from the catalog_template,
and input_selection parameters,
resolve a catalog of output files from the output_catalog_template
and selection parameters,
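
The pattern shared by these commands is visible in build_command above: load the RailProject from the config file, then resolve flavors and selections from it. A short sketch, using only calls that appear elsewhere in this commit (the config file name is a placeholder):

    from rail.projects import RailProject

    project = RailProject.load_config("my_project.yaml")  # placeholder config file name

    # Accessors used by the data extractors later in this commit.
    flavors = list(project.get_flavors().keys())
    selections = list(project.get_selections().keys())
    print(project.name, flavors, selections)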
6 changes: 2 additions & 4 deletions src/rail/plotting/control.py
@@ -7,7 +7,6 @@

import yaml

from rail.projects import RailProject
from rail.projects.factory_mixin import RailFactoryMixin

from .data_extractor import RailProjectDataExtractor
@@ -276,10 +275,9 @@
Split dataset lists by flavor
"""
extractor_cls = load_extractor_class(extractor_class)
project = RailProject.load_config(config_file)
output_data = {
'Data': extractor_cls.generate_dataset_dict(
project=project,
"Data": extractor_cls.generate_dataset_dict(
project_file=config_file,
**kwargs,
)
}
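
After this change, extract_datasets hands the config file path to the extractor class rather than a pre-loaded RailProject. A hedged sketch of the resulting flow; the import location of load_extractor_class, the class path, and the file names are assumptions, while the "Data" wrapping and the project_file keyword come from the diff above.

    import yaml

    from rail.plotting.control import load_extractor_class  # assumed location of the helper used above

    extractor_cls = load_extractor_class(
        "rail.plotting.nz_data_extractor.NZTomoBinDataExtractor"  # hypothetical extractor class path
    )
    output_data = {
        "Data": extractor_cls.generate_dataset_dict(
            project_file="my_project.yaml",  # the extractor now loads the RailProject itself
        )
    }
    with open("extracted_datasets.yaml", "w", encoding="utf-8") as fout:  # placeholder output file
        yaml.dump(output_data, fout)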
161 changes: 161 additions & 0 deletions src/rail/plotting/nz_data_extractor.py
@@ -5,6 +5,8 @@
from rail.projects import RailProject

from .data_extraction_funcs import (
get_ceci_nz_output_paths,
get_ceci_true_nz_output_paths,
get_tomo_bins_nz_estimate_data,
get_tomo_bins_true_nz_data,
)
@@ -41,3 +43,162 @@ def _get_data(self, **kwargs: Any) -> dict[str, Any]:
truth=get_tomo_bins_true_nz_data(**kwcopy),
)
return data

@classmethod
def generate_dataset_dict(
cls,
**kwargs: Any,
) -> list[dict[str, Any]]:
"""
Parameters
----------
**kwargs
Set Notes
Notes
-----
dataset_list_name: str
Name for the resulting DatasetList
dataset_holder_class: str
Class for the dataset holder
project_file: str
Config file for project to inspect
selections: list[str]
Selections to use
flavors: list[str]
Flavors to use
classifiers: list[str]
Classifiers to use
summarizers: list[str]
Summarizers to use
Returns
-------
list[dict[str, Any]]
Dictionary of the extracted datasets
"""
dataset_list_name: str | None = kwargs.get("dataset_list_name")
dataset_holder_class: str | None = kwargs.get("dataset_holder_class")
project_file = kwargs["project_file"]
project = RailProject.load_config(project_file)

selections = kwargs.get("selections")
flavors = kwargs.get("flavors")
split_by_flavor = kwargs.get("split_by_flavor", False)

output: list[dict[str, Any]] = []

flavor_dict = project.get_flavors()
if flavors is None or "all" in flavors:
flavors = list(flavor_dict.keys())
if selections is None or "all" in selections:
selections = list(project.get_selections().keys())

project_name = project.name
if not dataset_list_name:
dataset_list_name = f"{project_name}_pz_point"

project_block = dict(
Project=dict(
name=project_name,
yaml_file="dummy",
)
)

output.append(project_block)

dataset_list_dict: dict[str, list[str]] = {}
dataset_key = dataset_list_name
if not split_by_flavor:
dataset_list_dict[dataset_key] = []

for key in flavors:
val = flavor_dict[key]
pipelines = val["pipelines"]
if "all" not in pipelines and "pz" not in pipelines: # pragma: no cover
continue
try:
algos = val["pipeline_overrides"]["default"]["kwargs"]["algorithms"]
except KeyError:
algos = list(project.get_pzalgorithms().keys())

try:
classifiers = val["pipeline_overrides"]["default"]["kwargs"][
"classifiers"
]
except KeyError:
classifiers = list(project.get_classifiers().keys())

try:
summarizers = val["pipeline_overrides"]["default"]["kwargs"][
"summarizers"
]
except KeyError:
summarizers = list(project.get_summarizers().keys())

for selection_ in selections:
if split_by_flavor:
dataset_key = f"{dataset_list_name}_{selection_}_{key}"
dataset_list_dict[dataset_key] = []

for algo_ in algos:
for classifier_ in classifiers:

nz_true_paths = get_ceci_true_nz_output_paths(
project,
selection=selection_,
flavor=key,
algo=algo_,
classifier=classifier_,
)

if not nz_true_paths:
continue

for summarizer_ in summarizers:
nz_paths = get_ceci_nz_output_paths(
project,
selection=selection_,
flavor=key,
algo=algo_,
classifier=classifier_,
summarizer=summarizer_,
)

if not nz_paths:
continue

dataset_name = f"{selection_}_{key}_{algo_}_{classifier_}_{summarizer_}"

dataset_dict = dict(
name=dataset_name,
class_name=dataset_holder_class,
extractor=cls.full_class_name(),
project=project_name,
flavor=key,
algo=algo_,
selection=selection_,
classifier=classifier_,
summarizer=summarizer_,
)

dataset_list_dict[dataset_key].append(dataset_name)
output.append(dict(Dataset=dataset_dict))

for ds_name, ds_list in dataset_list_dict.items():
# Skip empty lists
if not ds_list:
continue
dataset_list = dict(
name=ds_name,
datasets=ds_list,
)
output.append(dict(DatasetList=dataset_list))

return output
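
Taken together, generate_dataset_dict returns a flat list of Project, Dataset and DatasetList blocks. A hedged sketch of what that list might look like for a single selection/flavor/algorithm/classifier/summarizer combination; every name below is invented for illustration, and only the block structure and the Dataset keys mirror the code above.

    example_output = [
        {"Project": {"name": "my_project", "yaml_file": "dummy"}},
        {
            "Dataset": {
                "name": "gold_baseline_knn_uniform_binning_naive_stack",
                "class_name": "SomeDatasetHolder",  # whatever dataset_holder_class was passed in
                "extractor": "rail.plotting.nz_data_extractor.NZTomoBinDataExtractor",  # cls.full_class_name(); hypothetical name
                "project": "my_project",
                "flavor": "baseline",
                "algo": "knn",
                "selection": "gold",
                "classifier": "uniform_binning",
                "summarizer": "naive_stack",
            }
        },
        {
            "DatasetList": {
                "name": "my_project_pz_point",  # default dataset_list_name, f"{project_name}_pz_point"
                "datasets": ["gold_baseline_knn_uniform_binning_naive_stack"],
            }
        },
    ]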
4 changes: 2 additions & 2 deletions src/rail/plotting/plotter.py
@@ -243,11 +243,11 @@ def _make_plots(
class RailPlotterList(Configurable):
"""The class collects a set of plotter that can all run on the same data.
E.g., plotters that can all run on a dict that looks like
E.g., plotters that can all run on a dict that looks like
`{truth:np.ndarray, pointEstimates: np.ndarray}` could be put into a PlotterList.
This make it easier to collect similar types of plots.
"""

config_options: dict[str, StageParameter] = dict(
name=StageParameter(str, None, fmt="%s", required=True, msg="PlotterList name"),
plotters=StageParameter(
25 changes: 15 additions & 10 deletions src/rail/plotting/pz_data_extractor.py
@@ -42,16 +42,21 @@ def generate_dataset_dict(
**kwargs: Any,
) -> list[dict[str, Any]]:
"""
Keywords
--------
Parameters
----------
**kwargs
Set Notes
Notes
-----
dataset_list_name: str
Name for the resulting DatasetList
dataset_holder_class: str
Class for the dataset holder
project: RailProject
Project to inspect
project_file: str
Config file for project to inspect
selections: list[str]
Selections to use
@@ -61,13 +66,13 @@
Returns
-------
output: list[dict[str, Any]]
list[dict[str, Any]]
Dictionary of the extracted datasets
"""
dataset_list_name: str | None = kwargs.get("dataset_list_name")
dataset_holder_class: str | None = kwargs.get("dataset_holder_class")
project = kwargs["project"]
assert isinstance(project, RailProject)
project_file = kwargs["project_file"]
project = RailProject.load_config(project_file)
selections = kwargs.get("selections")
flavors = kwargs.get("flavors")
split_by_flavor = kwargs.get("split_by_flavor", False)
@@ -82,12 +87,12 @@

project_name = project.name
if not dataset_list_name:
dataset_list_name = f"{project_name}_pz_point"
dataset_list_name = f"{project_name}_nz_tomo"

project_block = dict(
Project=dict(
name=project_name,
yaml_file="dummy",
yaml_file=project_file,
)
)

@@ -126,7 +131,7 @@
dataset_dict = dict(
name=dataset_name,
class_name=dataset_holder_class,
extractor="rail.plotting.pz_data_extractor.PZPointEstimateDataExtractor",
extractor=cls.full_class_name(),
project=project_name,
flavor=key,
algo=algo_,
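
The same calling-convention change is applied here: callers now pass project_file instead of a pre-loaded RailProject, and the extractor records its own class path via cls.full_class_name(). A hedged before/after sketch; the config file name is a placeholder and the keyword-only call is an assumption based on the kwargs handling above.

    from rail.plotting.pz_data_extractor import PZPointEstimateDataExtractor

    # Old convention (removed in this commit): the caller loaded the project itself.
    # project = RailProject.load_config("my_project.yaml")
    # data = PZPointEstimateDataExtractor.generate_dataset_dict(project=project)

    # New convention: pass the config file path and let the extractor load the project.
    data = PZPointEstimateDataExtractor.generate_dataset_dict(
        project_file="my_project.yaml",  # placeholder config file name
    )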
1 change: 0 additions & 1 deletion src/rail/projects/__init__.py
@@ -8,4 +8,3 @@
from . import library, name_utils

from .project import RailFlavor, RailProject

