Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

795 Care about where a plan is sourced #807

Merged
merged 8 commits into from
Jan 31, 2025
4 changes: 3 additions & 1 deletion docs/how-to/write-plans.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def my_plan(
...
```

The type annotations (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API.
## Detection

The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API. If there is an [`__all__` dunder](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) present in the module, blueapi will read that and import anything within that qualifies as a plan, per its type annotations. If not it will read everything in the module that hasn't been imported, for example it will ignore a plan imported from another module.

**Input annotations should be as broad as possible**, the least specific implementation that is sufficient to accomplish the requirements of the plan. For example, if a plan is written to drive a specific motor (`MyMotor`), but only uses the general methods on the [`Movable` protocol](https://blueskyproject.io/bluesky/main/hardware.html#bluesky.protocols.Movable), it should take `Movable` as a parameter annotation rather than `MyMotor`.

Expand Down
16 changes: 14 additions & 2 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from blueapi import utils
from blueapi.config import EnvironmentConfig, SourceKind
from blueapi.utils import BlueapiPlanModelConfig, load_module_all
from blueapi.utils import (
BlueapiPlanModelConfig,
is_function_sourced_from_module,
load_module_all,
)

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Expand Down Expand Up @@ -99,7 +103,15 @@ def plan_2(...) -> MsgGenerator:
"""

for obj in load_module_all(module):
if is_bluesky_plan_generator(obj):
# The rule here is that we only inspect objects defined in the module
# (as opposed to objects imported from other modules) to determine if
# they are valid plans, unless there is an __all__ defined in the module,
# in which case we only inspect objects listed there, regardless of their
# original source module.
if is_bluesky_plan_generator(obj) and (
hasattr(module, "__all__")
or is_function_sourced_from_module(obj, module)
):
self.register_plan(obj)

def with_device_module(self, module: ModuleType) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .connect_devices import connect_devices
from .file_permissions import get_owner_gid, is_sgid_set
from .invalid_config_error import InvalidConfigError
from .modules import load_module_all
from .modules import is_function_sourced_from_module, load_module_all
from .serialization import serialize
from .thread_exception import handle_all_exceptions

Expand All @@ -17,4 +17,5 @@
"connect_devices",
"is_sgid_set",
"get_owner_gid",
"is_function_sourced_from_module",
]
17 changes: 16 additions & 1 deletion src/blueapi/utils/modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
import importlib
from collections.abc import Callable, Iterable
from types import ModuleType
from typing import Any

Expand Down Expand Up @@ -34,3 +35,17 @@ def get_named_subset(names: list[str]):
for name, value in mod.__dict__.items():
if not name.startswith("_"):
yield value


def is_function_sourced_from_module(
func: Callable[..., Any], module: ModuleType
) -> bool:
"""
Check if a function is originally from a particular module, useful to detect
whether it actually comes from a nested import.

Args:
func: Object to check
module: Module to check against object
"""
return importlib.import_module(func.__module__) is module
9 changes: 9 additions & 0 deletions tests/unit_tests/core/fake_plan_module_with_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from bluesky.utils import MsgGenerator
from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401


def plan_c(c: bool) -> MsgGenerator[None]: ...
def plan_d(d: int) -> MsgGenerator[int]: ...


__all__ = ["plan_a", "plan_d"]
6 changes: 6 additions & 0 deletions tests/unit_tests/core/fake_plan_module_with_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from bluesky.utils import MsgGenerator
from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401


def plan_c(c: bool) -> MsgGenerator[None]: ...
def plan_d(d: int) -> MsgGenerator[int]: ...
14 changes: 14 additions & 0 deletions tests/unit_tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,20 @@ def test_add_plan_from_module(empty_context: BlueskyContext) -> None:
assert EXPECTED_PLANS == empty_context.plans.keys()


def test_only_plans_from_source_module_detected(empty_context: BlueskyContext) -> None:
import tests.unit_tests.core.fake_plan_module_with_imports as plan_module

empty_context.with_plan_module(plan_module)
assert {"plan_c", "plan_d"} == empty_context.plans.keys()


def test_only_plans_from_all_in_module_detected(empty_context: BlueskyContext) -> None:
import tests.unit_tests.core.fake_plan_module_with_all as plan_module

empty_context.with_plan_module(plan_module)
assert {"plan_a", "plan_d"} == empty_context.plans.keys()


def test_add_named_device(empty_context: BlueskyContext, sim_motor: SynAxis) -> None:
empty_context.register_device(sim_motor)
assert empty_context.devices[SIM_MOTOR_NAME] is sim_motor
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/utils/functions_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def a(): ...


def b(): ...
7 changes: 7 additions & 0 deletions tests/unit_tests/utils/functions_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .functions_a import a, b # noqa: F401


def c(): ...


def d(): ...
16 changes: 15 additions & 1 deletion tests/unit_tests/utils/test_modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from importlib import import_module

from blueapi.utils import load_module_all
from blueapi.utils import is_function_sourced_from_module, load_module_all


def test_imports_all():
Expand All @@ -11,3 +11,17 @@ def test_imports_all():
def test_imports_everything_without_all():
module = import_module(".lacksall", package="tests.unit_tests.utils")
assert list(load_module_all(module)) == [3, "hello", 9]


def test_source_is_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
c = module.__dict__["c"]
assert callable(c)
assert is_function_sourced_from_module(c, module)


def test_source_is_not_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
a = module.__dict__["a"]
assert callable(a)
assert not is_function_sourced_from_module(a, module)
Loading