From 668437aff7bce0a475e55879bbe29e314dfcd8f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Manuel=20Dom=C3=ADnguez?= Date: Thu, 14 Nov 2024 17:57:09 +0100 Subject: [PATCH 1/2] Support listing file sources asynchronously The method `list()` from `galaxy.files.sources.BaseFilesSource` lists the directories and files within a file source. An optional keyword argument `recursive` (`False` by default) lets it recursively retrieve directories and files within a specific directory. This operation is very cheap in terms of CPU and expensive in IO terms, be it network or filesystem IO. Depending on how the underlying system is built, it may support retrieving directories and files recursively or not. If it does not, then every time a directory is listed, it is necessary to make another request to list each subdirectory. This may end up involving hundreds of requests. Done sequentially, this can be extremely slow, especially if each one involves network access. This commit makes the `list()` method asynchronous, which enables Galaxy to wait for the underlying system to complete the requests concurrently, resulting in a massive speedup. The price to pay is the extra complexity of using the async primitives. Since this change implies that all functions in the chain up to the API endpoints and the test functions must also be made asynchronous, this commit also takes care of it. --- lib/galaxy/files/sources/__init__.py | 8 +- lib/galaxy/files/sources/_pyfilesystem2.py | 2 +- lib/galaxy/files/sources/invenio.py | 2 +- lib/galaxy/files/sources/posix.py | 2 +- lib/galaxy/files/sources/s3fs.py | 2 +- lib/galaxy/managers/file_source_instances.py | 34 ++++----- lib/galaxy/managers/remote_files.py | 4 +- lib/galaxy/util/unittest.py | 23 ++++++ lib/galaxy/webapps/galaxy/api/file_sources.py | 12 +-- lib/galaxy/webapps/galaxy/api/remote_files.py | 4 +- packages/app/test-requirements.txt | 1 + packages/files/test-requirements.txt | 1 + test/unit/app/managers/base.py | 15 +++- .../app/managers/test_user_file_sources.py | 44 +++++------ test/unit/files/_util.py | 12 +-- test/unit/files/test_dropbox.py | 5 +- test/unit/files/test_gcsfs.py | 5 +- test/unit/files/test_googledrive.py | 5 +- test/unit/files/test_onedata.py | 5 +- test/unit/files/test_posix.py | 75 +++++++++++-------- test/unit/files/test_s3.py | 5 +- test/unit/files/test_temp.py | 66 +++++++++------- test/unit/files/test_webdav.py | 19 +++-- 23 files changed, 207 insertions(+), 144 deletions(-) diff --git a/lib/galaxy/files/sources/__init__.py b/lib/galaxy/files/sources/__init__.py index fd1e84670173..ed3fcdd45e06 100644 --- a/lib/galaxy/files/sources/__init__.py +++ b/lib/galaxy/files/sources/__init__.py @@ -296,7 +296,7 @@ def get_uri_root(self) -> str: """Return a prefix for the root (e.g. gxfiles://prefix/).""" @abc.abstractmethod - def list( + async def list( self, path="/", recursive=False, @@ -443,7 +443,7 @@ def _serialization_props(self, user_context: "OptionalUserContext" = None) -> Fi Used in to_dict method if for_serialization is True. """ - def list( + async def list( self, path="/", recursive=False, @@ -467,9 +467,9 @@ def list( if offset is not None and offset < 0: raise RequestParameterInvalidException("Offset must be greater than or equal to 0.") - return self._list(path, recursive, user_context, opts, limit, offset, query) + return await self._list(path, recursive, user_context, opts, limit, offset, query) - def _list( + async def _list( self, path="/", recursive=False, diff --git a/lib/galaxy/files/sources/_pyfilesystem2.py b/lib/galaxy/files/sources/_pyfilesystem2.py index 5b3add6e6119..76993ededbe9 100644 --- a/lib/galaxy/files/sources/_pyfilesystem2.py +++ b/lib/galaxy/files/sources/_pyfilesystem2.py @@ -50,7 +50,7 @@ def __init__(self, **kwd: Unpack[FilesSourceProperties]): def _open_fs(self, user_context: OptionalUserContext = None, opts: Optional[FilesSourceOptions] = None) -> FS: """Subclasses must instantiate a PyFilesystem2 handle for this file system.""" - def _list( + async def _list( self, path="/", recursive=False, diff --git a/lib/galaxy/files/sources/invenio.py b/lib/galaxy/files/sources/invenio.py index 146d63d0b641..7ba366a9b440 100644 --- a/lib/galaxy/files/sources/invenio.py +++ b/lib/galaxy/files/sources/invenio.py @@ -146,7 +146,7 @@ def to_relative_path(self, url: str) -> str: def get_repository_interactor(self, repository_url: str) -> RDMRepositoryInteractor: return InvenioRepositoryInteractor(repository_url, self) - def _list( + async def _list( self, path="/", recursive=True, diff --git a/lib/galaxy/files/sources/posix.py b/lib/galaxy/files/sources/posix.py index 258d5ef20a69..cac3a8fbee47 100644 --- a/lib/galaxy/files/sources/posix.py +++ b/lib/galaxy/files/sources/posix.py @@ -60,7 +60,7 @@ def __init__(self, **kwd: Unpack[PosixFilesSourceProperties]): def prefer_links(self) -> bool: return self._prefer_links - def _list( + async def _list( self, path="/", recursive=True, diff --git a/lib/galaxy/files/sources/s3fs.py b/lib/galaxy/files/sources/s3fs.py index c6a7dedc96d3..d1453965bf5e 100644 --- a/lib/galaxy/files/sources/s3fs.py +++ b/lib/galaxy/files/sources/s3fs.py @@ -68,7 +68,7 @@ def __init__(self, **kwd: Unpack[S3FsFilesSourceProperties]): if self._endpoint_url: self._props.update({"client_kwargs": {"endpoint_url": self._endpoint_url}}) - def _list( + async def _list( self, path="/", recursive=True, diff --git a/lib/galaxy/managers/file_source_instances.py b/lib/galaxy/managers/file_source_instances.py index 4572eafc8c66..5936cecd7acd 100644 --- a/lib/galaxy/managers/file_source_instances.py +++ b/lib/galaxy/managers/file_source_instances.py @@ -334,45 +334,45 @@ def create_instance(self, trans: ProvidesUserContext, payload: CreateInstancePay self._save(persisted_file_source) return self._to_model(trans, persisted_file_source) - def test_modify_instance( + async def test_modify_instance( self, trans: ProvidesUserContext, id: UUID4, payload: TestModifyInstancePayload ) -> PluginStatus: persisted_file_source = self._get(trans, id) if isinstance(payload, TestUpgradeInstancePayload): - return self._plugin_status_for_upgrade(trans, payload, persisted_file_source) + return await self._plugin_status_for_upgrade(trans, payload, persisted_file_source) else: assert isinstance(payload, TestUpdateInstancePayload) - return self._plugin_status_for_update(trans, payload, persisted_file_source) + return await self._plugin_status_for_update(trans, payload, persisted_file_source) - def _plugin_status_for_update( + async def _plugin_status_for_update( self, trans: ProvidesUserContext, payload: TestUpdateInstancePayload, persisted_file_source: UserFileSource ) -> PluginStatus: template = self._get_template(persisted_file_source) target = UpdateTestTarget(persisted_file_source, payload) - return self._plugin_status_for_template(trans, target, template) + return await self._plugin_status_for_template(trans, target, template) - def _plugin_status_for_upgrade( + async def _plugin_status_for_upgrade( self, trans: ProvidesUserContext, payload: TestUpgradeInstancePayload, persisted_file_source: UserFileSource ) -> PluginStatus: template = self._get_and_validate_target_upgrade_template(persisted_file_source, payload) target = UpgradeTestTarget(persisted_file_source, payload) - return self._plugin_status_for_template(trans, target, template) + return await self._plugin_status_for_template(trans, target, template) - def plugin_status_for_instance(self, trans: ProvidesUserContext, id: UUID4): + async def plugin_status_for_instance(self, trans: ProvidesUserContext, id: UUID4): persisted_file_source = self._get(trans, id) - return self._plugin_status(trans, persisted_file_source, to_template_reference(persisted_file_source)) + return await self._plugin_status(trans, persisted_file_source, to_template_reference(persisted_file_source)) - def plugin_status(self, trans: ProvidesUserContext, payload: CreateInstancePayload) -> PluginStatus: + async def plugin_status(self, trans: ProvidesUserContext, payload: CreateInstancePayload) -> PluginStatus: target = CreateTestTarget(payload, UserFileSource) - return self._plugin_status(trans, target, payload) + return await self._plugin_status(trans, target, payload) - def _plugin_status( + async def _plugin_status( self, trans: ProvidesUserContext, target: CanTestPluginStatus, template_reference: TemplateReference ): template = self._catalog.find_template(template_reference) - return self._plugin_status_for_template(trans, target, template) + return await self._plugin_status_for_template(trans, target, template) - def _plugin_status_for_template( + async def _plugin_status_for_template( self, trans: ProvidesUserContext, payload: CanTestPluginStatus, template: FileSourceTemplate ): template_definition_status = status_template_definition(template) @@ -396,7 +396,7 @@ def _plugin_status_for_template( if template_settings_status.is_not_ok: return PluginStatus(**status_kwds) assert configuration - file_source, connection_status = self._connection_status(trans, payload, configuration) + file_source, connection_status = await self._connection_status(trans, payload, configuration) status_kwds["connection"] = connection_status if connection_status.is_not_ok: return PluginStatus(**status_kwds) @@ -443,7 +443,7 @@ def _template_settings_status( exception = e return configuration, settings_exception_to_status(exception) - def _connection_status( + async def _connection_status( self, trans: ProvidesUserContext, target: CanTestPluginStatus, configuration: FileSourceConfiguration ) -> Tuple[Optional[BaseFilesSource], PluginAspectStatus]: file_source = None @@ -471,7 +471,7 @@ def _connection_status( # a connection problem if we cannot browsable_file_source = cast(SupportsBrowsing, file_source) user_context = ProvidesFileSourcesUserContext(trans) - browsable_file_source.list("/", recursive=False, user_context=user_context) + await browsable_file_source.list("/", recursive=False, user_context=user_context) except Exception as e: exception = e return file_source, connection_exception_to_status("file source", exception) diff --git a/lib/galaxy/managers/remote_files.py b/lib/galaxy/managers/remote_files.py index 165efe3c95b2..2f5dd0779d0e 100644 --- a/lib/galaxy/managers/remote_files.py +++ b/lib/galaxy/managers/remote_files.py @@ -43,7 +43,7 @@ class RemoteFilesManager: def __init__(self, app: MinimalManagerApp): self._app = app - def index( + async def index( self, user_ctx: ProvidesUserContext, target: str, @@ -93,7 +93,7 @@ def index( opts = FilesSourceOptions() opts.writeable = writeable or False try: - index, count = file_source.list( + index, count = await file_source.list( file_source_path.path, recursive=recursive, user_context=user_file_source_context, diff --git a/lib/galaxy/util/unittest.py b/lib/galaxy/util/unittest.py index 8794f34a6daf..9c2252281db0 100644 --- a/lib/galaxy/util/unittest.py +++ b/lib/galaxy/util/unittest.py @@ -1,3 +1,5 @@ +import inspect + import pytest @@ -43,3 +45,24 @@ def assertRaises(self, exception): def assertRaisesRegex(self, exception, regex): return pytest.raises(exception, match=regex) + + +class MarkAsyncMeta(type): + """ + Metaclass that marks all asynchronous methods of a class as async tests. + + Methods that are not recognized by pytest as tests will simply be ignored, despite having been marked as async + tests. + """ + + def __new__(cls, name, bases, dict_): + for attribute_name, attribute_value in dict_.items(): + if inspect.iscoroutinefunction(attribute_value): + dict_[attribute_name] = pytest.mark.asyncio(attribute_value) + return super().__new__(cls, name, bases, dict_) + + +class IsolatedAsyncioTestCase(TestCase, metaclass=MarkAsyncMeta): + """ + Partial re-implementation of standard library `unittest.IsolatedAsyncioTestCase` using pytest methods. + """ diff --git a/lib/galaxy/webapps/galaxy/api/file_sources.py b/lib/galaxy/webapps/galaxy/api/file_sources.py index 4546e7b88ebc..62f7d63a5165 100644 --- a/lib/galaxy/webapps/galaxy/api/file_sources.py +++ b/lib/galaxy/webapps/galaxy/api/file_sources.py @@ -96,12 +96,12 @@ def create( summary="Test payload for creating user-bound file source.", operation_id="file_sources__test_new_instance_configuration", ) - def test_instance_configuration( + async def test_instance_configuration( self, trans: ProvidesUserContext = DependsOnTrans, payload: CreateInstancePayload = Body(...), ) -> PluginStatus: - return self.file_source_instances_manager.plugin_status(trans, payload) + return await self.file_source_instances_manager.plugin_status(trans, payload) @router.get( "/api/file_source_instances", @@ -131,12 +131,12 @@ def instances_show( summary="Test a file source instance and return status.", operation_id="file_sources__instances_test_instance", ) - def instance_test( + async def instance_test( self, trans: ProvidesUserContext = DependsOnTrans, uuid: UUID4 = UserFileSourceIdPathParam, ) -> PluginStatus: - return self.file_source_instances_manager.plugin_status_for_instance(trans, uuid) + return await self.file_source_instances_manager.plugin_status_for_instance(trans, uuid) @router.put( "/api/file_source_instances/{uuid}", @@ -156,13 +156,13 @@ def update_instance( summary="Test updating or upgrading user file source instance.", operation_id="file_sources__test_instances_update", ) - def test_update_instance( + async def test_update_instance( self, trans: ProvidesUserContext = DependsOnTrans, uuid: UUID4 = UserFileSourceIdPathParam, payload: TestModifyInstancePayload = Body(...), ) -> PluginStatus: - return self.file_source_instances_manager.test_modify_instance(trans, uuid, payload) + return await self.file_source_instances_manager.test_modify_instance(trans, uuid, payload) @router.delete( "/api/file_source_instances/{uuid}", diff --git a/lib/galaxy/webapps/galaxy/api/remote_files.py b/lib/galaxy/webapps/galaxy/api/remote_files.py index c5a708d748bf..ef0c9e5a2238 100644 --- a/lib/galaxy/webapps/galaxy/api/remote_files.py +++ b/lib/galaxy/webapps/galaxy/api/remote_files.py @@ -128,7 +128,7 @@ class FastAPIRemoteFiles: deprecated=True, summary="Displays remote files available to the user. Please use /api/remote_files instead.", ) - def index( + async def index( self, response: Response, user_ctx: ProvidesUserContext = DependsOnTrans, @@ -146,7 +146,7 @@ def index( The total count of files and directories is returned in the 'total_matches' header. """ - result, count = self.manager.index( + result, count = await self.manager.index( user_ctx, target, format, recursive, disable, writeable, limit, offset, query, sort_by ) response.headers["total_matches"] = str(count) diff --git a/packages/app/test-requirements.txt b/packages/app/test-requirements.txt index 7075b49702b9..da161030b1c0 100644 --- a/packages/app/test-requirements.txt +++ b/packages/app/test-requirements.txt @@ -2,4 +2,5 @@ mock-ssh-server pkce pykwalify pytest +pytest-asyncio testfixtures diff --git a/packages/files/test-requirements.txt b/packages/files/test-requirements.txt index ad572b378cec..e5abfc9b7a80 100644 --- a/packages/files/test-requirements.txt +++ b/packages/files/test-requirements.txt @@ -1,3 +1,4 @@ pytest +pytest-asyncio fs-gcsfs s3fs>=2023.1.0,<2024 diff --git a/test/unit/app/managers/base.py b/test/unit/app/managers/base.py index 5eee85e4e022..57aa1a41e7c7 100644 --- a/test/unit/app/managers/base.py +++ b/test/unit/app/managers/base.py @@ -8,9 +8,14 @@ from galaxy.app_unittest_utils import galaxy_mock from galaxy.managers.users import UserManager -from galaxy.util.unittest import TestCase +from galaxy.util.unittest import ( + IsolatedAsyncioTestCase, + TestCase, +) from galaxy.work.context import SessionRequestContext +__all__ = ("BaseIsolatedAsyncioTestCase", "BaseTestCase", "CreatesCollectionsMixin") + # ============================================================================= admin_email = "admin@admin.admin" admin_users = admin_email @@ -104,6 +109,14 @@ def assertIsJsonifyable(self, item): assert isinstance(json.dumps(item), str) +class BaseIsolatedAsyncioTestCase(BaseTestCase, IsolatedAsyncioTestCase): + """ + Asynchronous version of `BaseTestCase`. + + Can run sync tests too. + """ + + class CreatesCollectionsMixin: trans: SessionRequestContext diff --git a/test/unit/app/managers/test_user_file_sources.py b/test/unit/app/managers/test_user_file_sources.py index bb30feed8184..429b7b0627b6 100644 --- a/test/unit/app/managers/test_user_file_sources.py +++ b/test/unit/app/managers/test_user_file_sources.py @@ -46,7 +46,7 @@ from galaxy.schema.schema import OAuth2State from galaxy.util import config_templates from galaxy.util.config_templates import RawTemplateConfig -from .base import BaseTestCase +from .base import BaseIsolatedAsyncioTestCase SIMPLE_FILE_SOURCE_NAME = "myfilesource" SIMPLE_FILE_SOURCE_DESCRIPTION = "a description of my file source" @@ -205,7 +205,7 @@ def simple_vault_template(tmp_path): ) -class TestFileSourcesTestCase(BaseTestCase): +class TestFileSourcesTestCase(BaseIsolatedAsyncioTestCase): manager: FileSourceInstancesManager file_sources: UserDefinedFileSourcesImpl @@ -282,7 +282,7 @@ def mock_get_token_from_code_raw( user_object = self._create_instance(create_payload) assert get_uuid(user_object.uuid) == get_uuid(uuid) - def test_oauth2_access_token_injection_during_verify(self, tmp_path, monkeypatch): + async def test_oauth2_access_token_injection_during_verify(self, tmp_path, monkeypatch): if DropboxFS is None: raise SkipTest("Optional dropbpox dependency not available") self._init_dropbox_env(tmp_path, monkeypatch) @@ -320,14 +320,14 @@ def __init__(self, **kwd): monkeypatch.setattr(config_templates, "get_token_from_refresh_raw", mock_get_token_from_refresh_raw) monkeypatch.setattr(dropbox, "DropboxFS", MockDropboxFS) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert status.oauth2_access_token_generation assert not status.oauth2_access_token_generation.is_not_ok assert status.connection assert not status.connection.is_not_ok assert pyfilesystem_fs_init_kwd["access_token"] == "my_test_access_token" - def test_report_oauth2_access_token_generation_failure(self, tmp_path, monkeypatch): + async def test_report_oauth2_access_token_generation_failure(self, tmp_path, monkeypatch): self._init_dropbox_env(tmp_path, monkeypatch) uuid = uuid4().hex @@ -350,7 +350,7 @@ def mock_get_token_from_refresh_raw(refresh_token, client_pair, config): return MockExceptionResponse(excepton_message) monkeypatch.setattr(config_templates, "get_token_from_refresh_raw", mock_get_token_from_refresh_raw) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert status.oauth2_access_token_generation assert status.oauth2_access_token_generation.is_not_ok assert excepton_message in status.oauth2_access_token_generation.message @@ -602,7 +602,7 @@ def test_upgrade_fails_if_new_secrets_absent(self, tmp_path): self._assert_modify_throws_exception(user_object_store, upgrade_to_1, RequestParameterMissingException) - def test_status_valid(self, tmp_path): + async def test_status_valid(self, tmp_path): self.init_user_in_database() self._init_managers(tmp_path) (tmp_path / self.trans.user.username).mkdir() @@ -614,14 +614,14 @@ def test_status_valid(self, tmp_path): variables={}, secrets={}, ) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert status.connection assert not status.connection.is_not_ok assert not status.template_definition.is_not_ok assert status.template_settings assert not status.template_settings.is_not_ok - def test_status_invalid_connection(self, tmp_path): + async def test_status_invalid_connection(self, tmp_path): self.init_user_in_database() self._init_managers(tmp_path) # We don't make the directory like above so it doesn't exist @@ -634,7 +634,7 @@ def test_status_invalid_connection(self, tmp_path): variables={}, secrets={}, ) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert not status.template_definition.is_not_ok assert status.template_settings assert not status.template_settings.is_not_ok @@ -644,7 +644,7 @@ def test_status_invalid_connection(self, tmp_path): assert status.connection assert status.connection.is_not_ok - def test_status_invalid_settings_undefined_variable(self, tmp_path): + async def test_status_invalid_settings_undefined_variable(self, tmp_path): self.init_user_in_database() self._init_managers(tmp_path, config_dict=invalid_home_directory_template(tmp_path)) create_payload = CreateInstancePayload( @@ -655,7 +655,7 @@ def test_status_invalid_settings_undefined_variable(self, tmp_path): variables={}, secrets={}, ) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert not status.template_definition.is_not_ok assert status.template_settings assert status.template_settings.is_not_ok @@ -664,7 +664,7 @@ def test_status_invalid_settings_undefined_variable(self, tmp_path): ) assert status.connection is None - def test_status_invalid_settings_configuration_validation(self, tmp_path): + async def test_status_invalid_settings_configuration_validation(self, tmp_path): self.init_user_in_database() self._init_managers(tmp_path, config_dict=invalid_home_directory_template_type_error(tmp_path)) create_payload = CreateInstancePayload( @@ -675,14 +675,14 @@ def test_status_invalid_settings_configuration_validation(self, tmp_path): variables={}, secrets={}, ) - status = self.manager.plugin_status(self.trans, create_payload) + status = await self.manager.plugin_status(self.trans, create_payload) assert not status.template_definition.is_not_ok assert status.template_settings assert status.template_settings.is_not_ok assert "Input should be a valid boolean" in status.template_settings.message assert status.connection is None - def test_status_existing_valid(self, tmp_path): + async def test_status_existing_valid(self, tmp_path): self.init_user_in_database() self._init_managers(tmp_path) (tmp_path / self.trans.user.username).mkdir() @@ -695,14 +695,14 @@ def test_status_existing_valid(self, tmp_path): secrets={}, ) user_file_source = self._create_instance(create_payload) - status = self.manager.plugin_status_for_instance(self.trans, user_file_source.uuid) + status = await self.manager.plugin_status_for_instance(self.trans, user_file_source.uuid) assert status.connection assert not status.connection.is_not_ok assert not status.template_definition.is_not_ok assert status.template_settings assert not status.template_settings.is_not_ok - def test_status_update_valid(self, tmp_path): + async def test_status_update_valid(self, tmp_path): self._init_managers(tmp_path, config_dict=simple_variable_template(tmp_path)) create_payload = CreateInstancePayload( name=SIMPLE_FILE_SOURCE_NAME, @@ -720,14 +720,14 @@ def test_status_update_valid(self, tmp_path): } ) - status = self.manager.test_modify_instance(self.trans, user_file_source.uuid, update) + status = await self.manager.test_modify_instance(self.trans, user_file_source.uuid, update) assert not status.template_definition.is_not_ok assert status.template_settings assert not status.template_settings.is_not_ok assert status.connection assert status.connection.is_not_ok - def test_status_upgrade_valid(self, tmp_path): + async def test_status_upgrade_valid(self, tmp_path): user_file_source = self._init_upgrade_test_case(tmp_path) assert "sec1" in user_file_source.secrets assert "sec2" not in user_file_source.secrets @@ -742,21 +742,21 @@ def test_status_upgrade_valid(self, tmp_path): }, variables={}, ) - status = self.manager.test_modify_instance(self.trans, user_file_source.uuid, upgrade_to_1) + status = await self.manager.test_modify_instance(self.trans, user_file_source.uuid, upgrade_to_1) assert not status.template_definition.is_not_ok assert status.template_settings assert not status.template_settings.is_not_ok assert status.connection assert status.connection.is_not_ok - def test_status_upgrade_invalid(self, tmp_path): + async def test_status_upgrade_invalid(self, tmp_path): user_file_source = self._init_invalid_upgrade_test_case(tmp_path) upgrade_to_1 = TestUpgradeInstancePayload( template_version=1, secrets={}, variables={}, ) - status = self.manager.test_modify_instance(self.trans, user_file_source.uuid, upgrade_to_1) + status = await self.manager.test_modify_instance(self.trans, user_file_source.uuid, upgrade_to_1) assert not status.template_definition.is_not_ok assert status.template_settings assert status.template_settings.is_not_ok diff --git a/test/unit/files/_util.py b/test/unit/files/_util.py index 2bf1274fcee5..4e57e6a50efc 100644 --- a/test/unit/files/_util.py +++ b/test/unit/files/_util.py @@ -36,7 +36,7 @@ def find(dir_list, class_=None, name=None): return None -def list_root( +async def list_root( file_sources: ConfiguredFileSources, uri: str, recursive: bool, @@ -44,11 +44,11 @@ def list_root( ): file_source_pair = file_sources.get_file_source_path(uri) file_source = file_source_pair.file_source - res, _ = file_source.list("/", recursive=recursive, user_context=user_context) + res, _ = await file_source.list("/", recursive=recursive, user_context=user_context) return res -def list_dir( +async def list_dir( file_sources: ConfiguredFileSources, uri: str, recursive: bool, @@ -58,7 +58,7 @@ def list_dir( file_source = file_source_pair.file_source print(file_source_pair.path) print(uri) - res, _ = file_source.list(file_source_pair.path, recursive=recursive, user_context=user_context) + res, _ = await file_source.list(file_source_pair.path, recursive=recursive, user_context=user_context) return res @@ -186,14 +186,14 @@ def assert_can_write_and_read_to_conf(conf: dict): ) -def assert_simple_file_realize(conf_file, recursive=False, filename="a", contents="a\n", contains=False): +async def assert_simple_file_realize(conf_file, recursive=False, filename="a", contents="a\n", contains=False): user_context = user_context_fixture() file_sources = configured_file_sources(conf_file) file_source_pair = file_sources.get_file_source_path("gxfiles://test1") assert file_source_pair.path == "/" file_source = file_source_pair.file_source - res, _ = file_source.list("/", recursive=recursive, user_context=user_context) + res, _ = await file_source.list("/", recursive=recursive, user_context=user_context) a_file = find(res, class_="File", name=filename) assert a_file diff --git a/test/unit/files/test_dropbox.py b/test/unit/files/test_dropbox.py index 111612ea4956..8605f4c12530 100644 --- a/test/unit/files/test_dropbox.py +++ b/test/unit/files/test_dropbox.py @@ -13,5 +13,6 @@ @skip_if_no_dropbox_access_token -def test_file_source(): - assert_simple_file_realize(FILE_SOURCES_CONF, recursive=True) +@pytest.mark.asyncio +async def test_file_source(): + await assert_simple_file_realize(FILE_SOURCES_CONF, recursive=True) diff --git a/test/unit/files/test_gcsfs.py b/test/unit/files/test_gcsfs.py index 94a113717835..215c79dd8044 100644 --- a/test/unit/files/test_gcsfs.py +++ b/test/unit/files/test_gcsfs.py @@ -19,7 +19,8 @@ @skip_if_no_gcsfs_libs -def test_file_source(): - assert_simple_file_realize( +@pytest.mark.asyncio +async def test_file_source(): + await assert_simple_file_realize( FILE_SOURCES_CONF, recursive=False, filename="README", contents="1000genomes", contains=True ) diff --git a/test/unit/files/test_googledrive.py b/test/unit/files/test_googledrive.py index 2661713ee629..c177d11fbcda 100644 --- a/test/unit/files/test_googledrive.py +++ b/test/unit/files/test_googledrive.py @@ -15,5 +15,6 @@ @skip_if_no_google_drive_access_token -def test_file_source(): - assert_simple_file_realize(FILE_SOURCES_CONF) +@pytest.mark.asyncio +async def test_file_source(): + await assert_simple_file_realize(FILE_SOURCES_CONF) diff --git a/test/unit/files/test_onedata.py b/test/unit/files/test_onedata.py index b4a5813699fb..524027a244f0 100644 --- a/test/unit/files/test_onedata.py +++ b/test/unit/files/test_onedata.py @@ -14,5 +14,6 @@ @skip_if_no_onedata_access_token -def test_file_source(): - assert_simple_file_realize(FILE_SOURCES_CONF) +@pytest.mark.asyncio +async def test_file_source(): + await assert_simple_file_realize(FILE_SOURCES_CONF) diff --git a/test/unit/files/test_posix.py b/test/unit/files/test_posix.py index a5749500c597..153a2415555e 100644 --- a/test/unit/files/test_posix.py +++ b/test/unit/files/test_posix.py @@ -37,7 +37,8 @@ EMAIL = "alice@galaxyproject.org" -def test_posix(): +@pytest.mark.asyncio +async def test_posix(): file_sources = _configured_file_sources() as_dict = file_sources.to_dict() assert len(as_dict["file_sources"]) == 1 @@ -46,7 +47,7 @@ def test_posix(): _download_and_check_file(file_sources) - res = list_root(file_sources, "gxfiles://test1", recursive=False) + res = await list_root(file_sources, "gxfiles://test1", recursive=False) file_a = find_file_a(res) assert file_a assert file_a["uri"] == "gxfiles://test1/a" @@ -56,7 +57,7 @@ def test_posix(): assert subdir1["class"] == "Directory" assert subdir1["uri"] == "gxfiles://test1/subdir1" - res = list_dir(file_sources, "gxfiles://test1/subdir1", recursive=False) + res = await list_dir(file_sources, "gxfiles://test1/subdir1", recursive=False) subdir2 = find(res, name="subdir2") assert subdir2, res assert subdir2["uri"] == "gxfiles://test1/subdir1/subdir2" @@ -65,7 +66,7 @@ def test_posix(): assert file_c, res assert file_c["uri"] == "gxfiles://test1/subdir1/c" - res = list_root(file_sources, "gxfiles://test1", recursive=True) + res = await list_root(file_sources, "gxfiles://test1", recursive=True) subdir1 = find(res, name="subdir1") subdir2 = find(res, name="subdir2") assert subdir1["class"] == "Directory" @@ -120,26 +121,28 @@ def test_posix_nonexistent_parent_write(): assert "Parent" in str(e) -def test_posix_per_user(): +@pytest.mark.asyncio +async def test_posix_per_user(): file_sources = _configured_file_sources(per_user=True) user_context = user_context_fixture() assert_realizes_as(file_sources, "gxfiles://test1/a", "a\n", user_context=user_context) - res = list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) + res = await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) assert find_file_a(res) -def test_posix_per_user_writable(): +@pytest.mark.asyncio +async def test_posix_per_user_writable(): file_sources = _configured_file_sources(per_user=True, writable=True) user_context = user_context_fixture() - res = list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) + res = await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) b = find(res, name="b") assert b is None write_from(file_sources, "gxfiles://test1/b", "my test content", user_context=user_context) - res = list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) + res = await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) b = find(res, name="b") assert b is not None, b @@ -269,7 +272,8 @@ def test_user_import_dir_implicit_config(): assert_realizes_as(file_sources, "gxuserimport://a", "a\n", user_context=user_context) -def test_posix_user_access_requires_role(): +@pytest.mark.asyncio +async def test_posix_user_access_requires_role(): allowed_role_name = "role1" plugin_extra_config = { "requires_roles": allowed_role_name, @@ -277,13 +281,13 @@ def test_posix_user_access_requires_role(): file_sources = _configured_file_sources(writable=True, plugin_extra_config=plugin_extra_config) user_context = user_context_fixture() - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={allowed_role_name}) - _assert_user_access_granted(file_sources, user_context) + await _assert_user_access_granted(file_sources, user_context) -def test_posix_user_access_requires_group(): +async def test_posix_user_access_requires_group(): allowed_group_name = "group1" plugin_extra_config = { "requires_groups": allowed_group_name, @@ -291,13 +295,14 @@ def test_posix_user_access_requires_group(): file_sources = _configured_file_sources(writable=True, plugin_extra_config=plugin_extra_config) user_context = user_context_fixture() - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(group_names={allowed_group_name}) - _assert_user_access_granted(file_sources, user_context) + await _assert_user_access_granted(file_sources, user_context) -def test_posix_admin_user_has_access(): +@pytest.mark.asyncio +async def test_posix_admin_user_has_access(): plugin_extra_config = { "requires_roles": "role1", "requires_groups": "group1", @@ -305,13 +310,14 @@ def test_posix_admin_user_has_access(): file_sources = _configured_file_sources(writable=True, plugin_extra_config=plugin_extra_config) user_context = user_context_fixture() - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(is_admin=True) - _assert_user_access_granted(file_sources, user_context) + await _assert_user_access_granted(file_sources, user_context) -def test_posix_user_access_requires_role_and_group(): +@pytest.mark.asyncio +async def test_posix_user_access_requires_role_and_group(): allowed_group_name = "group1" allowed_role_name = "role1" plugin_extra_config = { @@ -321,16 +327,17 @@ def test_posix_user_access_requires_role_and_group(): file_sources = _configured_file_sources(writable=True, plugin_extra_config=plugin_extra_config) user_context = user_context_fixture(group_names={allowed_group_name}) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={allowed_role_name}) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={allowed_role_name}, group_names={allowed_group_name}) - _assert_user_access_granted(file_sources, user_context) + await _assert_user_access_granted(file_sources, user_context) -def test_posix_user_access_using_boolean_rules(): +@pytest.mark.asyncio +async def test_posix_user_access_using_boolean_rules(): plugin_extra_config = { "requires_roles": "role1 and (role2 or role3)", "requires_groups": "group1 and group2 and not group3", @@ -338,19 +345,19 @@ def test_posix_user_access_using_boolean_rules(): file_sources = _configured_file_sources(writable=True, plugin_extra_config=plugin_extra_config) user_context = user_context_fixture(role_names=set(), group_names=set()) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={"role1"}, group_names={"group1", "group2"}) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={"role1", "role3"}, group_names={"group1", "group2", "group3"}) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={"role1", "role2"}, group_names={"group3", "group5"}) - _assert_user_access_prohibited(file_sources, user_context) + await _assert_user_access_prohibited(file_sources, user_context) user_context = user_context_fixture(role_names={"role1", "role3"}, group_names={"group1", "group2"}) - _assert_user_access_granted(file_sources, user_context) + await _assert_user_access_granted(file_sources, user_context) def test_posix_file_url_only_mode_non_admin_cannot_retrieve(): @@ -447,9 +454,10 @@ def test_posix_file_url_disallowed_root(): assert_realizes_as(file_sources, test_url, "some content\n") -def _assert_user_access_prohibited(file_sources, user_context): +@pytest.mark.asyncio +async def _assert_user_access_prohibited(file_sources, user_context): with pytest.raises(ItemAccessibilityException): - list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) + await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) with pytest.raises(ItemAccessibilityException): write_from(file_sources, "gxfiles://test1/b", "my test content", user_context=user_context) @@ -458,13 +466,14 @@ def _assert_user_access_prohibited(file_sources, user_context): assert_realizes_as(file_sources, "gxfiles://test1/a", "a\n", user_context=user_context) -def _assert_user_access_granted(file_sources, user_context): - res = list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) +@pytest.mark.asyncio +async def _assert_user_access_granted(file_sources, user_context): + res = await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) assert res write_from(file_sources, "gxfiles://test1/b", "my test content", user_context=user_context) - res = list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) + res = await list_root(file_sources, "gxfiles://test1", recursive=False, user_context=user_context) b = find(res, name="b") assert b is not None, b diff --git a/test/unit/files/test_s3.py b/test/unit/files/test_s3.py index 7529959a5488..c1b12ebf8d8f 100644 --- a/test/unit/files/test_s3.py +++ b/test/unit/files/test_s3.py @@ -17,8 +17,9 @@ FILE_SOURCES_CONF = os.path.join(SCRIPT_DIRECTORY, "s3_file_sources_conf.yml") -def test_file_source(): - assert_simple_file_realize( +@pytest.mark.asyncio +async def test_file_source(): + await assert_simple_file_realize( FILE_SOURCES_CONF, recursive=False, filename="data_use_policies.txt", diff --git a/test/unit/files/test_temp.py b/test/unit/files/test_temp.py index 2d51a359916b..eae44e90c57b 100644 --- a/test/unit/files/test_temp.py +++ b/test/unit/files/test_temp.py @@ -56,34 +56,35 @@ def test_list_recursive(temp_file_source: TempFilesSource): assert_list_names(temp_file_source, "/", recursive=True, expected_names=expected_names) -def test_pagination(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_pagination(temp_file_source: TempFilesSource): # Pagination is only supported for non-recursive listings. recursive = False - root_lvl_entries, count = temp_file_source.list("/", recursive=recursive) + root_lvl_entries, count = await temp_file_source.list("/", recursive=recursive) assert count == 4 assert len(root_lvl_entries) == 4 # Get first entry - result, count = temp_file_source.list("/", recursive=recursive, limit=1, offset=0) + result, count = await temp_file_source.list("/", recursive=recursive, limit=1, offset=0) assert count == 4 assert len(result) == 1 assert result[0] == root_lvl_entries[0] # Get second entry - result, count = temp_file_source.list("/", recursive=recursive, limit=1, offset=1) + result, count = await temp_file_source.list("/", recursive=recursive, limit=1, offset=1) assert count == 4 assert len(result) == 1 assert result[0] == root_lvl_entries[1] # Get second and third entry - result, count = temp_file_source.list("/", recursive=recursive, limit=2, offset=1) + result, count = await temp_file_source.list("/", recursive=recursive, limit=2, offset=1) assert count == 4 assert len(result) == 2 assert result[0] == root_lvl_entries[1] assert result[1] == root_lvl_entries[2] # Get last three entries - result, count = temp_file_source.list("/", recursive=recursive, limit=3, offset=1) + result, count = await temp_file_source.list("/", recursive=recursive, limit=3, offset=1) assert count == 4 assert len(result) == 3 assert result[0] == root_lvl_entries[1] @@ -91,95 +92,101 @@ def test_pagination(temp_file_source: TempFilesSource): assert result[2] == root_lvl_entries[3] -def test_search(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_search(temp_file_source: TempFilesSource): # Search is only supported for non-recursive listings. recursive = False - root_lvl_entries, count = temp_file_source.list("/", recursive=recursive) + root_lvl_entries, count = await temp_file_source.list("/", recursive=recursive) assert count == 4 assert len(root_lvl_entries) == 4 - result, count = temp_file_source.list("/", recursive=recursive, query="a") + result, count = await temp_file_source.list("/", recursive=recursive, query="a") assert count == 1 assert len(result) == 1 assert result[0]["name"] == "a" - result, count = temp_file_source.list("/", recursive=recursive, query="b") + result, count = await temp_file_source.list("/", recursive=recursive, query="b") assert count == 1 assert len(result) == 1 assert result[0]["name"] == "b" - result, count = temp_file_source.list("/", recursive=recursive, query="c") + result, count = await temp_file_source.list("/", recursive=recursive, query="c") assert count == 1 assert len(result) == 1 assert result[0]["name"] == "c" # Searching for 'd' at root level should return the directory 'dir1' but not the file 'd' # as it is not a direct child of the root. - result, count = temp_file_source.list("/", recursive=recursive, query="d") + result, count = await temp_file_source.list("/", recursive=recursive, query="d") assert count == 1 assert len(result) == 1 assert result[0]["name"] == "dir1" # Searching for 'e' at root level should not return anything. - result, count = temp_file_source.list("/", recursive=recursive, query="e") + result, count = await temp_file_source.list("/", recursive=recursive, query="e") assert count == 0 assert len(result) == 0 - result, count = temp_file_source.list("/dir1", recursive=recursive, query="e") + result, count = await temp_file_source.list("/dir1", recursive=recursive, query="e") assert count == 1 assert len(result) == 1 assert result[0]["name"] == "e" -def test_query_with_empty_string(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_query_with_empty_string(temp_file_source: TempFilesSource): recursive = False - root_lvl_entries, count = temp_file_source.list("/", recursive=recursive) + root_lvl_entries, count = await temp_file_source.list("/", recursive=recursive) assert count == 4 assert len(root_lvl_entries) == 4 - result, count = temp_file_source.list("/", recursive=recursive, query="") + result, count = await temp_file_source.list("/", recursive=recursive, query="") assert count == 4 assert len(result) == 4 assert result == root_lvl_entries -def test_pagination_not_supported_raises(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_pagination_not_supported_raises(temp_file_source: TempFilesSource): TempFilesSource.supports_pagination = False recursive = False with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, limit=1, offset=0) + await temp_file_source.list("/", recursive=recursive, limit=1, offset=0) assert "Pagination is not supported" in str(exc_info.value) TempFilesSource.supports_pagination = True -def test_pagination_parameters_non_negative(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_pagination_parameters_non_negative(temp_file_source: TempFilesSource): recursive = False with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, limit=-1, offset=0) + await temp_file_source.list("/", recursive=recursive, limit=-1, offset=0) assert "Limit must be greater than 0" in str(exc_info.value) with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, limit=0, offset=0) + await temp_file_source.list("/", recursive=recursive, limit=0, offset=0) assert "Limit must be greater than 0" in str(exc_info.value) with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, limit=1, offset=-1) + await temp_file_source.list("/", recursive=recursive, limit=1, offset=-1) assert "Offset must be greater than or equal to 0" in str(exc_info.value) -def test_search_not_supported_raises(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_search_not_supported_raises(temp_file_source: TempFilesSource): TempFilesSource.supports_search = False recursive = False with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, query="a") + await temp_file_source.list("/", recursive=recursive, query="a") assert "Server-side search is not supported by this file source" in str(exc_info.value) TempFilesSource.supports_search = True -def test_sorting_not_supported_raises(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_sorting_not_supported_raises(temp_file_source: TempFilesSource): recursive = False with pytest.raises(RequestParameterInvalidException) as exc_info: - temp_file_source.list("/", recursive=recursive, sort_by="name") + await temp_file_source.list("/", recursive=recursive, sort_by="name") assert "Server-side sorting is not supported by this file source" in str(exc_info.value) @@ -202,8 +209,9 @@ def _upload_to(file_source: TempFilesSource, target_uri: str, content: str, user file_source.write_from(target_uri, f.name, user_context=user_context) -def assert_list_names(file_source: TempFilesSource, uri: str, recursive: bool, expected_names: List[str]): - result, count = file_source.list(uri, recursive=recursive) +@pytest.mark.asyncio +async def assert_list_names(file_source: TempFilesSource, uri: str, recursive: bool, expected_names: List[str]): + result, count = await file_source.list(uri, recursive=recursive) assert count == len(expected_names) assert sorted([entry["name"] for entry in result]) == sorted(expected_names) return result diff --git a/test/unit/files/test_webdav.py b/test/unit/files/test_webdav.py index 9294881ec247..771f54989d75 100644 --- a/test/unit/files/test_webdav.py +++ b/test/unit/files/test_webdav.py @@ -25,7 +25,8 @@ @skip_if_no_webdav -def test_file_source(): +@pytest.mark.asyncio +async def test_file_source(): file_sources = configured_file_sources(FILE_SOURCES_CONF) file_source_pair = file_sources.get_file_source_path("gxfiles://test1") @@ -46,7 +47,7 @@ def test_file_source(): assert subdir1["class"] == "Directory" assert subdir1["uri"] == "gxfiles://test1/subdir1" - res = list_dir(file_sources, "gxfiles://test1/subdir1", recursive=False) + res = await list_dir(file_sources, "gxfiles://test1/subdir1", recursive=False) subdir2 = find(res, name="subdir2") assert subdir2, res assert subdir2["uri"] == "gxfiles://test1/subdir1/subdir2" @@ -63,17 +64,18 @@ def test_sniff_to_tmp(): @skip_if_no_webdav -def test_serialization(): +@pytest.mark.asyncio +async def test_serialization(): configs = [FILE_SOURCES_CONF_NO_USE_TEMP_FILES, FILE_SOURCES_CONF] for config in configs: # serialize the configured file sources and rematerialize them, # ensure they still function. This is needed for uploading files. file_sources = serialize_and_recover(configured_file_sources(config)) - res = list_root(file_sources, "gxfiles://test1", recursive=True) + res = await list_root(file_sources, "gxfiles://test1", recursive=True) assert find_file_a(res) - res = list_root(file_sources, "gxfiles://test1", recursive=False) + res = await list_root(file_sources, "gxfiles://test1", recursive=False) assert find_file_a(res) _download_and_check_file(file_sources) @@ -101,13 +103,14 @@ def test_config_options(): @skip_if_no_webdav -def test_serialization_user(): +@pytest.mark.asyncio +async def test_serialization_user(): file_sources_o = configured_file_sources(USER_FILE_SOURCES_CONF) user_context = user_context_fixture() - res = list_root(file_sources_o, "gxfiles://test1", recursive=True, user_context=user_context) + res = await list_root(file_sources_o, "gxfiles://test1", recursive=True, user_context=user_context) assert find_file_a(res) file_sources = serialize_and_recover(file_sources_o, user_context=user_context) - res = list_root(file_sources, "gxfiles://test1", recursive=True, user_context=None) + res = await list_root(file_sources, "gxfiles://test1", recursive=True, user_context=None) assert find_file_a(res) From a05d5cc5ec53dcf8d696c51b47ca320eb8071c6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Manuel=20Dom=C3=ADnguez?= Date: Tue, 10 Dec 2024 14:48:18 +0100 Subject: [PATCH 2/2] Mark tests using `assert_list_names` as async --- test/unit/files/test_temp.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/unit/files/test_temp.py b/test/unit/files/test_temp.py index eae44e90c57b..1ec381926498 100644 --- a/test/unit/files/test_temp.py +++ b/test/unit/files/test_temp.py @@ -46,14 +46,16 @@ def test_file_source(file_sources: TestConfiguredFileSources): assert_realizes_contains(file_sources, f"{ROOT_URI}/dir1/sub1/f", "f") -def test_list(temp_file_source: TempFilesSource): - assert_list_names(temp_file_source, "/", recursive=False, expected_names=["a", "b", "c", "dir1"]) - assert_list_names(temp_file_source, "/dir1", recursive=False, expected_names=["d", "e", "sub1"]) +@pytest.mark.asyncio +async def test_list(temp_file_source: TempFilesSource): + await assert_list_names(temp_file_source, "/", recursive=False, expected_names=["a", "b", "c", "dir1"]) + await assert_list_names(temp_file_source, "/dir1", recursive=False, expected_names=["d", "e", "sub1"]) -def test_list_recursive(temp_file_source: TempFilesSource): +@pytest.mark.asyncio +async def test_list_recursive(temp_file_source: TempFilesSource): expected_names = ["a", "b", "c", "dir1", "d", "e", "sub1", "f"] - assert_list_names(temp_file_source, "/", recursive=True, expected_names=expected_names) + await assert_list_names(temp_file_source, "/", recursive=True, expected_names=expected_names) @pytest.mark.asyncio @@ -209,7 +211,6 @@ def _upload_to(file_source: TempFilesSource, target_uri: str, content: str, user file_source.write_from(target_uri, f.name, user_context=user_context) -@pytest.mark.asyncio async def assert_list_names(file_source: TempFilesSource, uri: str, recursive: bool, expected_names: List[str]): result, count = await file_source.list(uri, recursive=recursive) assert count == len(expected_names)