diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index cd3a0ede6..fd24f3ab4 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -108,10 +108,10 @@ def plan_2(...) -> MsgGenerator: # they are valid plans, unless there is an __all__ defined in the module, # in which case we only inspect objects listed there, regardless of their # original source module. - if ( + if is_bluesky_plan_generator(obj) and ( hasattr(module, "__all__") or is_function_sourced_from_module(obj, module) - ) and is_bluesky_plan_generator(obj): + ): self.register_plan(obj) def with_device_module(self, module: ModuleType) -> None: diff --git a/src/blueapi/utils/modules.py b/src/blueapi/utils/modules.py index ca04237d3..932f86985 100644 --- a/src/blueapi/utils/modules.py +++ b/src/blueapi/utils/modules.py @@ -1,5 +1,5 @@ import importlib -from collections.abc import Iterable +from collections.abc import Callable, Iterable from types import ModuleType from typing import Any @@ -37,7 +37,9 @@ def get_named_subset(names: list[str]): yield value -def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool: +def is_function_sourced_from_module( + obj: Callable[..., Any], module: ModuleType +) -> bool: """ Check if an object is originally from a particular module, useful to detect whether it actually comes from a nested import. @@ -46,6 +48,4 @@ def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool: obj: Object to check module: Module to check against object """ - return ( - hasattr(obj, "__module__") and importlib.import_module(obj.__module__) is module - ) + return importlib.import_module(obj.__module__) is module diff --git a/tests/unit_tests/utils/functions_b.py b/tests/unit_tests/utils/functions_b.py index 90c6714d6..78e72d560 100644 --- a/tests/unit_tests/utils/functions_b.py +++ b/tests/unit_tests/utils/functions_b.py @@ -5,6 +5,3 @@ def c(): ... def d(): ... - - -e = 1 diff --git a/tests/unit_tests/utils/test_modules.py b/tests/unit_tests/utils/test_modules.py index d9b47434b..5a0cf3bcf 100644 --- a/tests/unit_tests/utils/test_modules.py +++ b/tests/unit_tests/utils/test_modules.py @@ -15,14 +15,13 @@ def test_imports_everything_without_all(): def test_source_is_in_module(): module = import_module(".functions_b", package="tests.unit_tests.utils") - assert is_function_sourced_from_module(module.__dict__["c"], module) + c = module.__dict__["c"] + assert callable(c) + assert is_function_sourced_from_module(c, module) def test_source_is_not_in_module(): module = import_module(".functions_b", package="tests.unit_tests.utils") - assert not is_function_sourced_from_module(module.__dict__["a"], module) - - -def test_source_check_on_non_function(): - module = import_module(".functions_b", package="tests.unit_tests.utils") - assert not is_function_sourced_from_module(module.__dict__["e"], module) + a = module.__dict__["a"] + assert callable(a) + assert not is_function_sourced_from_module(a, module)