diff --git a/docs/how-to/write-plans.md b/docs/how-to/write-plans.md index afd616c16..582c42f8c 100644 --- a/docs/how-to/write-plans.md +++ b/docs/how-to/write-plans.md @@ -28,7 +28,9 @@ def my_plan( ... ``` -The type annotations (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API. +## Detection + +The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API. If there is an [`__all__` dunder](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) present in the module, blueapi will read that and import anything within that qualifies as a plan, per its type annotations. If not it will read everything in the module that hasn't been imported, for example it will ignore a plan imported from another module. **Input annotations should be as broad as possible**, the least specific implementation that is sufficient to accomplish the requirements of the plan. For example, if a plan is written to drive a specific motor (`MyMotor`), but only uses the general methods on the [`Movable` protocol](https://blueskyproject.io/bluesky/main/hardware.html#bluesky.protocols.Movable), it should take `Movable` as a parameter annotation rather than `MyMotor`. diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index fcc63227d..fd24f3ab4 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -16,7 +16,11 @@ from blueapi import utils from blueapi.config import EnvironmentConfig, SourceKind -from blueapi.utils import BlueapiPlanModelConfig, load_module_all +from blueapi.utils import ( + BlueapiPlanModelConfig, + is_function_sourced_from_module, + load_module_all, +) from .bluesky_types import ( BLUESKY_PROTOCOLS, @@ -99,7 +103,15 @@ def plan_2(...) -> MsgGenerator: """ for obj in load_module_all(module): - if is_bluesky_plan_generator(obj): + # The rule here is that we only inspect objects defined in the module + # (as opposed to objects imported from other modules) to determine if + # they are valid plans, unless there is an __all__ defined in the module, + # in which case we only inspect objects listed there, regardless of their + # original source module. + if is_bluesky_plan_generator(obj) and ( + hasattr(module, "__all__") + or is_function_sourced_from_module(obj, module) + ): self.register_plan(obj) def with_device_module(self, module: ModuleType) -> None: diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 602cea9bb..38288853f 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -2,7 +2,7 @@ from .connect_devices import connect_devices from .file_permissions import get_owner_gid, is_sgid_set from .invalid_config_error import InvalidConfigError -from .modules import load_module_all +from .modules import is_function_sourced_from_module, load_module_all from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -17,4 +17,5 @@ "connect_devices", "is_sgid_set", "get_owner_gid", + "is_function_sourced_from_module", ] diff --git a/src/blueapi/utils/modules.py b/src/blueapi/utils/modules.py index e7c28e6f9..1ee58d505 100644 --- a/src/blueapi/utils/modules.py +++ b/src/blueapi/utils/modules.py @@ -1,4 +1,5 @@ -from collections.abc import Iterable +import importlib +from collections.abc import Callable, Iterable from types import ModuleType from typing import Any @@ -34,3 +35,17 @@ def get_named_subset(names: list[str]): for name, value in mod.__dict__.items(): if not name.startswith("_"): yield value + + +def is_function_sourced_from_module( + func: Callable[..., Any], module: ModuleType +) -> bool: + """ + Check if a function is originally from a particular module, useful to detect + whether it actually comes from a nested import. + + Args: + func: Object to check + module: Module to check against object + """ + return importlib.import_module(func.__module__) is module diff --git a/tests/unit_tests/core/fake_plan_module_with_all.py b/tests/unit_tests/core/fake_plan_module_with_all.py new file mode 100644 index 000000000..84bafe0ef --- /dev/null +++ b/tests/unit_tests/core/fake_plan_module_with_all.py @@ -0,0 +1,9 @@ +from bluesky.utils import MsgGenerator +from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401 + + +def plan_c(c: bool) -> MsgGenerator[None]: ... +def plan_d(d: int) -> MsgGenerator[int]: ... + + +__all__ = ["plan_a", "plan_d"] diff --git a/tests/unit_tests/core/fake_plan_module_with_imports.py b/tests/unit_tests/core/fake_plan_module_with_imports.py new file mode 100644 index 000000000..f2e99cbe4 --- /dev/null +++ b/tests/unit_tests/core/fake_plan_module_with_imports.py @@ -0,0 +1,6 @@ +from bluesky.utils import MsgGenerator +from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401 + + +def plan_c(c: bool) -> MsgGenerator[None]: ... +def plan_d(d: int) -> MsgGenerator[int]: ... diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index b28088544..72598d551 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -170,6 +170,20 @@ def test_add_plan_from_module(empty_context: BlueskyContext) -> None: assert EXPECTED_PLANS == empty_context.plans.keys() +def test_only_plans_from_source_module_detected(empty_context: BlueskyContext) -> None: + import tests.unit_tests.core.fake_plan_module_with_imports as plan_module + + empty_context.with_plan_module(plan_module) + assert {"plan_c", "plan_d"} == empty_context.plans.keys() + + +def test_only_plans_from_all_in_module_detected(empty_context: BlueskyContext) -> None: + import tests.unit_tests.core.fake_plan_module_with_all as plan_module + + empty_context.with_plan_module(plan_module) + assert {"plan_a", "plan_d"} == empty_context.plans.keys() + + def test_add_named_device(empty_context: BlueskyContext, sim_motor: SynAxis) -> None: empty_context.register_device(sim_motor) assert empty_context.devices[SIM_MOTOR_NAME] is sim_motor diff --git a/tests/unit_tests/utils/functions_a.py b/tests/unit_tests/utils/functions_a.py new file mode 100644 index 000000000..fad3c27e6 --- /dev/null +++ b/tests/unit_tests/utils/functions_a.py @@ -0,0 +1,4 @@ +def a(): ... + + +def b(): ... diff --git a/tests/unit_tests/utils/functions_b.py b/tests/unit_tests/utils/functions_b.py new file mode 100644 index 000000000..78e72d560 --- /dev/null +++ b/tests/unit_tests/utils/functions_b.py @@ -0,0 +1,7 @@ +from .functions_a import a, b # noqa: F401 + + +def c(): ... + + +def d(): ... diff --git a/tests/unit_tests/utils/test_modules.py b/tests/unit_tests/utils/test_modules.py index c48c13cec..5a0cf3bcf 100644 --- a/tests/unit_tests/utils/test_modules.py +++ b/tests/unit_tests/utils/test_modules.py @@ -1,6 +1,6 @@ from importlib import import_module -from blueapi.utils import load_module_all +from blueapi.utils import is_function_sourced_from_module, load_module_all def test_imports_all(): @@ -11,3 +11,17 @@ def test_imports_all(): def test_imports_everything_without_all(): module = import_module(".lacksall", package="tests.unit_tests.utils") assert list(load_module_all(module)) == [3, "hello", 9] + + +def test_source_is_in_module(): + module = import_module(".functions_b", package="tests.unit_tests.utils") + c = module.__dict__["c"] + assert callable(c) + assert is_function_sourced_from_module(c, module) + + +def test_source_is_not_in_module(): + module = import_module(".functions_b", package="tests.unit_tests.utils") + a = module.__dict__["a"] + assert callable(a) + assert not is_function_sourced_from_module(a, module)