Skip to content

Commit

Permalink
Eac/colors (#23)
Browse files Browse the repository at this point in the history
* added better colors to tomo plots

* fix typo & arguement order

* coverage
  • Loading branch information
eacharles authored Feb 5, 2025
1 parent 2fe3385 commit b5c282c
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/rail/plotting/nz_data_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def generate_dataset_dict(
project_block = dict(
Project=dict(
name=project_name,
yaml_file="dummy",
yaml_file=project_file,
)
)

Expand Down Expand Up @@ -158,7 +158,7 @@ def generate_dataset_dict(
classifier=classifier_,
)

if not nz_true_paths:
if not nz_true_paths: # pragma: no cover
continue

for summarizer_ in summarizers:
Expand All @@ -171,7 +171,7 @@ def generate_dataset_dict(
summarizer=summarizer_,
)

if not nz_paths:
if not nz_paths: # pragma: no cover
continue

dataset_name = f"{selection_}_{key}_{algo_}_{classifier_}_{summarizer_}"
Expand Down
9 changes: 7 additions & 2 deletions src/rail/plotting/nz_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import qp
from ceci.config import StageParameter
from matplotlib import pyplot as plt
import matplotlib as mpl

from .dataset_holder import RailDatasetHolder
from .plot_holder import RailPlotHolder
Expand Down Expand Up @@ -43,9 +44,13 @@ def _make_plot(
nz_vals = nz_estimates.pdf(bin_edges)
n_pdf = truth.npdf

cmap = mpl.colormaps['plasma']
colors = cmap(np.linspace(0, 1, n_pdf))

for i in range(n_pdf):
axes.plot(bin_edges, truth_vals[i], "-")
axes.plot(bin_edges, nz_vals[i])
color=colors[i]
axes.plot(bin_edges, truth_vals[i], "-", color=color)
axes.plot(bin_edges, nz_vals[i], "--", color=color)
plt.xlabel("z")
plt.ylabel("n(z)")
plot_name = self._make_full_plot_name(prefix, "")
Expand Down
2 changes: 1 addition & 1 deletion src/rail/projects/pipeline_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def build(

pipeline_kwargs = self.config.kwargs.copy()

if self.config.pipeline_overrides:
if self.config.pipeline_overrides: # pragma: no cover
copy_overrides = self.config.pipeline_overrides.copy()

stages_config = os.path.join(
Expand Down
2 changes: 1 addition & 1 deletion src/rail/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def add_flavor(self, name: str, **kwargs: Any) -> RailFlavor:
"""Add a new flavor to the Project"""
if self._flavors is None: # pragma: no cover
self.get_flavors()

assert self._flavors is not None
if name in self._flavors:
raise KeyError(f"Flavor {name} already in RailProject {self.name}")
flavor_params = self.config.Baseline.copy()
Expand Down
49 changes: 46 additions & 3 deletions tests/cli/plot/test_plot_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,35 @@ def test_cli_inspect(setup_project_area: int) -> None:
check_result(result)


def test_cli_extract_datasets(setup_project_area: int) -> None:
def test_cli_extract_nz_datasets(setup_project_area: int) -> None:
assert setup_project_area == 0
runner = CliRunner()

# run with split by flavor
result = runner.invoke(
plot_cli,
"extract-datasets "
"--extractor_class rail.plotting.nz_data_extractor.NZTomoBinDataExtractor "
"--flavor all "
"--split_by_flavor "
"--output_yaml tests/temp_data/dataset_nz_out.yaml "
"tests/ci_project.yaml",
)
check_result(result)

# run without split by flavor
result = runner.invoke(
plot_cli,
"extract-datasets "
"--extractor_class rail.plotting.nz_data_extractor.NZTomoBinDataExtractor "
"--flavor all "
"--output_yaml tests/temp_data/dataset_nz_out.yaml "
"tests/ci_project.yaml",
)
check_result(result)


def test_cli_extract_pz_datasets(setup_project_area: int) -> None:
assert setup_project_area == 0
runner = CliRunner()

Expand All @@ -45,7 +73,6 @@ def test_cli_extract_datasets(setup_project_area: int) -> None:
"extract-datasets "
"--extractor_class rail.plotting.pz_data_extractor.PZPointEstimateDataExtractor "
"--flavor all "
"--selection all "
"--split_by_flavor "
"--output_yaml tests/temp_data/dataset_out.yaml "
"tests/ci_project.yaml",
Expand All @@ -58,13 +85,13 @@ def test_cli_extract_datasets(setup_project_area: int) -> None:
"extract-datasets "
"--extractor_class rail.plotting.pz_data_extractor.PZPointEstimateDataExtractor "
"--flavor all "
"--selection all "
"--output_yaml tests/temp_data/dataset_out.yaml "
"tests/ci_project.yaml",
)
check_result(result)



def test_cli_make_plot_groups(setup_project_area: int) -> None:
assert setup_project_area == 0
runner = CliRunner()
Expand All @@ -79,3 +106,19 @@ def test_cli_make_plot_groups(setup_project_area: int) -> None:
"--dataset_list_name blend_baseline_all",
)
check_result(result)


def test_cli_make_nz_plot_groups(setup_project_area: int) -> None:
assert setup_project_area == 0
runner = CliRunner()

result = runner.invoke(
plot_cli,
"make-plot-groups "
"--output_yaml tests/temp_data/check_nz_plot_group.yaml "
"--plotter_yaml_path tests/ci_plots.yaml "
"--dataset_yaml_path tests/ci_datasets.yaml "
"--plotter_list_name tomo_bins "
"--dataset_list_name blend_baseline_tomo_knn",
)
check_result(result)
22 changes: 22 additions & 0 deletions tests/plotting/test_plot_group_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,25 @@ def test_load_yaml(setup_project_area: int) -> None:

check_plot_group = RailPlotGroupFactory.get_plot_group("accuracy_v_ztrue")
assert isinstance(check_plot_group, RailPlotGroup)



def test_make_instance_yaml(setup_project_area: int):
assert setup_project_area == 0

RailPlotGroupFactory.make_yaml(
output_yaml='tests/temp_data/check_pz_plot_groups.yaml',
plotter_yaml_path='tests/ci_plots.yaml',
dataset_yaml_path='tests/ci_datasets.yaml',
plotter_list_name='zestimate_v_ztrue',
output_prefix="",
dataset_list_name='baseline_test',
)
RailPlotGroupFactory.make_yaml(
output_yaml='tests/temp_data/check_nz_plot_groups.yaml',
plotter_yaml_path='tests/ci_plots.yaml',
dataset_yaml_path='tests/ci_datasets.yaml',
plotter_list_name='tomo_bins',
output_prefix="",
dataset_list_name='blend_baseline_tomo_knn',
)
8 changes: 7 additions & 1 deletion tests/projects/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from rail.projects.project import RailFlavor, RailProject
from rail.projects.project import RailProject


def check_get_func(func: Callable, check_dict: dict[str, Any]) -> None:
Expand All @@ -16,6 +16,12 @@ def check_get_func(func: Callable, check_dict: dict[str, Any]) -> None:
func("does_not_exist")


def test_project_doc() -> None:

RailProject.functionality_help()
RailProject.configuration_help()


def test_project_class(setup_project_area: int) -> None:
assert setup_project_area == 0

Expand Down

0 comments on commit b5c282c

Please sign in to comment.