Skip to content

Commit

Permalink
Make source check more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Jan 31, 2025
1 parent 30acc5b commit 21d7352
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def plan_2(...) -> MsgGenerator:
# they are valid plans, unless there is an __all__ defined in the module,
# in which case we only inspect objects listed there, regardless of their
# original source module.
if (
if is_bluesky_plan_generator(obj) and (
hasattr(module, "__all__")
or is_function_sourced_from_module(obj, module)
) and is_bluesky_plan_generator(obj):
):
self.register_plan(obj)

def with_device_module(self, module: ModuleType) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/blueapi/utils/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import importlib
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from types import ModuleType
from typing import Any

Expand Down Expand Up @@ -37,7 +37,9 @@ def get_named_subset(names: list[str]):
yield value


def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool:
def is_function_sourced_from_module(
obj: Callable[..., Any], module: ModuleType
) -> bool:
"""
Check if an object is originally from a particular module, useful to detect
whether it actually comes from a nested import.
Expand All @@ -46,6 +48,4 @@ def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool:
obj: Object to check
module: Module to check against object
"""
return (
hasattr(obj, "__module__") and importlib.import_module(obj.__module__) is module
)
return importlib.import_module(obj.__module__) is module
3 changes: 0 additions & 3 deletions tests/unit_tests/utils/functions_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@ def c(): ...


def d(): ...


e = 1
13 changes: 6 additions & 7 deletions tests/unit_tests/utils/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ def test_imports_everything_without_all():

def test_source_is_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
assert is_function_sourced_from_module(module.__dict__["c"], module)
c = module.__dict__["c"]
assert callable(c)
assert is_function_sourced_from_module(c, module)


def test_source_is_not_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
assert not is_function_sourced_from_module(module.__dict__["a"], module)


def test_source_check_on_non_function():
module = import_module(".functions_b", package="tests.unit_tests.utils")
assert not is_function_sourced_from_module(module.__dict__["e"], module)
a = module.__dict__["a"]
assert callable(a)
assert not is_function_sourced_from_module(a, module)

0 comments on commit 21d7352

Please sign in to comment.