Skip to content

Commit

Permalink
implement custom_style_config plotting parameter (columnflow#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm authored Dec 8, 2023
1 parent 250c6ec commit f901c9e
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions columnflow/tasks/framework/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import luigi

from columnflow.types import Any, Callable
from columnflow.tasks.framework.base import ConfigTask
from columnflow.tasks.framework.base import ConfigTask, RESOLVE_DEFAULT
from columnflow.tasks.framework.mixins import DatasetsProcessesMixin, VariablesMixin
from columnflow.tasks.framework.parameters import SettingsParameter, MultiSettingsParameter
from columnflow.util import DotDict, dict_add_strict
Expand Down Expand Up @@ -47,6 +47,13 @@ class PlotBase(ConfigTask):
description="Parameter to set a list of custom plotting parameters. Format: "
"'option1=val1,option2=val2,...'",
)
custom_style_config = luigi.Parameter(
default=RESOLVE_DEFAULT,
significant=False,
description="Parameter to overwrite the *style_config* that is passed to the plot function"
"via a dictionary in the `custom_style_config_groups` auxiliary in the config;"
"defaults to the `default_custom_style_config` aux",
)
skip_legend = law.OptionalBoolParameter(
default=None,
significant=False,
Expand All @@ -68,7 +75,7 @@ def resolve_param_values(cls, params):
return params
config_inst = params["config_inst"]

# resolve variable_settings
# resolve general_settings
if "general_settings" in params:
settings = params["general_settings"]
# when empty and default general_settings are defined, use them instead
Expand All @@ -94,6 +101,7 @@ def get_plot_parameters(self) -> DotDict:
dict_add_strict(params, "skip_legend", self.skip_legend)
dict_add_strict(params, "cms_label", self.cms_label)
dict_add_strict(params, "general_settings", self.general_settings)
dict_add_strict(params, "custom_style_config", self.custom_style_config)
return params

def get_plot_names(self, name: str) -> list[str]:
Expand Down Expand Up @@ -155,6 +163,21 @@ def update_plot_kwargs(self, kwargs: dict) -> dict:
for key, value in general_settings.items():
kwargs.setdefault(key, value)

# resolve custom_style_config
custom_style_config = kwargs.get("custom_style_config", None)
if custom_style_config == RESOLVE_DEFAULT:
custom_style_config = self.config_inst.x("default_custom_style_config", RESOLVE_DEFAULT)

groups = self.config_inst.x("custom_style_config_groups", {})
if isinstance(custom_style_config, str) and custom_style_config in groups.keys():
custom_style_config = groups[custom_style_config]

# update style_config
style_config = kwargs.get("style_config", {})
if isinstance(custom_style_config, dict) and isinstance(style_config, dict):
style_config = law.util.merge_dicts(custom_style_config, style_config)
kwargs["style_config"] = style_config

return kwargs


Expand Down

0 comments on commit f901c9e

Please sign in to comment.