diff --git a/.github/workflows/publish_dev.yml b/.github/workflows/publish_dev.yml
index c54cf59..f716b3e 100644
--- a/.github/workflows/publish_dev.yml
+++ b/.github/workflows/publish_dev.yml
@@ -8,7 +8,7 @@ jobs:
   publish:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Docker Buildx
         id: buildx
         uses: docker/setup-buildx-action@v2
diff --git a/.github/workflows/publish_main.yml b/.github/workflows/publish_main.yml
index f960f40..340dff1 100644
--- a/.github/workflows/publish_main.yml
+++ b/.github/workflows/publish_main.yml
@@ -11,7 +11,7 @@ jobs:
     outputs:
       pkg_version: ${{ steps.output_version.outputs.pkg_version }}
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Get version from file
         run: |
           pkg_name=$(grep -P 'version = \{attr = .*\}' pyproject.toml | grep -oP '\w+.__version__')
diff --git a/.github/workflows/run_dev_tests.yml b/.github/workflows/run_dev_tests.yml
index 47d3c5e..475cd44 100644
--- a/.github/workflows/run_dev_tests.yml
+++ b/.github/workflows/run_dev_tests.yml
@@ -10,11 +10,11 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.8', '3.9', '3.10' ]
+        python-version: [ '3.9', '3.10', '3.11' ]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
       - name: Install dependencies
diff --git a/.github/workflows/run_main_tests.yml b/.github/workflows/run_main_tests.yml
index 53bf8c7..8ce1002 100644
--- a/.github/workflows/run_main_tests.yml
+++ b/.github/workflows/run_main_tests.yml
@@ -11,11 +11,11 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.8', '3.9', '3.10' ]
+        python-version: [ '3.9', '3.10', '3.11' ]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
       - name: Install dependencies
@@ -28,7 +28,7 @@ jobs:
   verify_version:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Check version incremented
        run: |
          pkg_name=$(grep -P 'version = \{attr = .*\}' pyproject.toml | grep -oP '\w+.__version__')
diff --git a/README.md b/README.md
index 5d9bdc0..a114923 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ kept in sync:
 1. **S3 buckets** store raw metadata files, including the ``metadata.nd.json``.
 2. A **document database (DocDB)** contains unstructured json documents
    describing the ``metadata.nd.json`` for a data asset.
-3. **Code Ocean**: data assets are mounted as CodeOcean data asssets.
+3. **Code Ocean**: data assets are mounted as CodeOcean data assets.
    Processed results are also stored in an internal Code Ocean bucket.

 We have automated jobs to keep changes in DocDB and S3 in sync.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 528ac66..3a3c915 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,5 @@
 """Configuration file for the Sphinx documentation builder."""
+
 #
 # For the full list of built-in configuration values, see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
diff --git a/pyproject.toml b/pyproject.toml
index 7bb2df0..f1eef27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "aind-data-asset-indexer"
 description = "Service Capsule to write data asset metadata to document store"
 license = {text = "MIT"}
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 authors = [
     {name = "AIND"}
 ]
@@ -24,7 +24,7 @@ dependencies = [
     "pymongo==4.3.3",
     "dask==2023.5.0",
     "aind-data-schema==1.2.0",
-    "aind-codeocean-api==0.5.0",
+    "codeocean==0.3.0",
 ]

 [project.optional-dependencies]
diff --git a/src/aind_data_asset_indexer/__init__.py b/src/aind_data_asset_indexer/__init__.py
index 24dc33a..296d3b5 100644
--- a/src/aind_data_asset_indexer/__init__.py
+++ b/src/aind_data_asset_indexer/__init__.py
@@ -1,3 +1,3 @@
 """Package"""

-__version__ = "0.13.0"
+__version__ = "0.14.0"
diff --git a/src/aind_data_asset_indexer/aind_bucket_indexer.py b/src/aind_data_asset_indexer/aind_bucket_indexer.py
index 48899a4..da94c72 100644
--- a/src/aind_data_asset_indexer/aind_bucket_indexer.py
+++ b/src/aind_data_asset_indexer/aind_bucket_indexer.py
@@ -272,14 +272,12 @@ def _resolve_schema_information(
             object_key = create_object_key(
                 prefix=prefix, filename=core_schema_file_name
             )
-            common_kwargs[
-                "core_schema_info_in_root"
-            ] = get_dict_of_file_info(
-                s3_client=s3_client,
-                bucket=self.job_settings.s3_bucket,
-                keys=[object_key],
-            ).get(
-                object_key
+            common_kwargs["core_schema_info_in_root"] = (
+                get_dict_of_file_info(
+                    s3_client=s3_client,
+                    bucket=self.job_settings.s3_bucket,
+                    keys=[object_key],
+                ).get(object_key)
             )
             self._copy_file_from_root_to_subdir(**common_kwargs)
         # If field is null, a file exists in the root folder, and
@@ -424,9 +422,9 @@ def _process_docdb_record(
             )
             db = docdb_client[self.job_settings.doc_db_db_name]
             collection = db[self.job_settings.doc_db_collection_name]
-            fields_to_update[
-                "last_modified"
-            ] = datetime.utcnow().isoformat()
+            fields_to_update["last_modified"] = (
+                datetime.utcnow().isoformat()
+            )
             response = collection.update_one(
                 {"_id": docdb_record["_id"]},
                 {"$set": fields_to_update},
diff --git a/src/aind_data_asset_indexer/codeocean_bucket_indexer.py b/src/aind_data_asset_indexer/codeocean_bucket_indexer.py
index 982df2d..c27fc70 100644
--- a/src/aind_data_asset_indexer/codeocean_bucket_indexer.py
+++ b/src/aind_data_asset_indexer/codeocean_bucket_indexer.py
@@ -12,13 +12,13 @@

 import boto3
 import dask.bag as dask_bag
-import requests
-from aind_codeocean_api.codeocean import CodeOceanClient
 from aind_data_schema.core.metadata import ExternalPlatforms
+from codeocean import CodeOcean
+from codeocean.data_asset import DataAssetSearchOrigin, DataAssetSearchParams
 from mypy_boto3_s3 import S3Client
 from pymongo import MongoClient
 from pymongo.operations import UpdateOne
-from requests.exceptions import ReadTimeout
+from urllib3.util import Retry

 from aind_data_asset_indexer.models import CodeOceanIndexBucketJobSettings
 from aind_data_asset_indexer.utils import (
@@ -52,30 +52,51 @@ def __init__(self, job_settings: CodeOceanIndexBucketJobSettings):
         """Class constructor."""
         self.job_settings = job_settings

-    def _get_external_data_asset_records(self) -> Optional[List[dict]]:
+    @staticmethod
+    def _get_external_data_asset_records(
+        co_client: CodeOcean,
+    ) -> Optional[List[dict]]:
         """
         Retrieves list of code ocean ids and locations for external data
         assets. The timeout is set to 600 seconds.
+
+        Parameters
+        ----------
+        co_client : CodeOcean
+
         Returns
         -------
         List[dict] | None
           List items have shape {"id": str, "location": str}. If error
           occurs, return None.
+
         """
         try:
-            response = requests.get(
-                self.job_settings.temp_codeocean_endpoint,
-                timeout=600,
+            search_params = DataAssetSearchParams(
+                archived=False,
+                origin=DataAssetSearchOrigin.External,
+                limit=1000,
             )
-            if response.status_code == 200:
-                return response.json()
-            else:
-                return None
-        except ReadTimeout:
-            logging.error(
-                f"Read timed out at "
-                f"{self.job_settings.temp_codeocean_endpoint}"
+            data_assets = co_client.data_assets.search_data_assets_iterator(
+                search_params=search_params
             )
+            external_records = []
+            for data_asset in data_assets:
+                data_asset_source = data_asset.source_bucket
+                if (
+                    data_asset_source is not None
+                    and data_asset_source.bucket is not None
+                    and data_asset_source.prefix is not None
+                ):
+                    bucket = data_asset_source.bucket
+                    prefix = data_asset_source.prefix
+                    location = f"s3://{bucket}/{prefix}"
+                    external_records.append(
+                        {"id": data_asset.id, "location": location}
+                    )
+            return external_records
+        except Exception as e:
+            logging.exception(e)
             return None

     @staticmethod
@@ -97,7 +118,7 @@ def _map_external_list_to_dict(external_recs: List[dict]) -> dict:
         """
         new_records = dict()
         for r in external_recs:
-            location = r.get("source")
+            location = r.get("location")
             rec_id = r["id"]
             if location is not None and new_records.get(location) is not None:
                 old_id_set = new_records.get(location)
@@ -140,7 +161,7 @@ def _get_co_links_from_record(
         return external_links

     def _update_external_links_in_docdb(
-        self, docdb_client: MongoClient
+        self, docdb_client: MongoClient, co_client: CodeOcean
     ) -> None:
         """
         This method will:
@@ -159,7 +180,9 @@ def _update_external_links_in_docdb(
         """

         # Should return a list like [{"id": co_id, "location": "s3://..."},]
-        list_of_co_ids_and_locations = self._get_external_data_asset_records()
+        list_of_co_ids_and_locations = self._get_external_data_asset_records(
+            co_client=co_client
+        )
         db = docdb_client[self.job_settings.doc_db_db_name]
         collection = db[self.job_settings.doc_db_collection_name]
         if list_of_co_ids_and_locations is not None:
@@ -394,9 +417,16 @@ def _delete_records_from_docdb(self, record_list: List[str]):
     def run_job(self):
         """Main method to run."""
         logging.info("Starting to scan through CodeOcean.")
-        co_client = CodeOceanClient(
+        retry = Retry(
+            total=5,
+            backoff_factor=1,
+            status_forcelist=[429, 500, 502, 503, 504],
+            allowed_methods=["GET", "POST"],
+        )
+        co_client = CodeOcean(
             domain=self.job_settings.codeocean_domain,
             token=self.job_settings.codeocean_token.get_secret_value(),
+            retries=retry,
         )
         code_ocean_records = get_all_processed_codeocean_asset_records(
             co_client=co_client,
@@ -416,7 +446,7 @@ def run_job(self):
         # Use existing client to add external links to fields
         logging.info("Adding links to records.")
         self._update_external_links_in_docdb(
-            docdb_client=iterator_docdb_client
+            docdb_client=iterator_docdb_client, co_client=co_client
         )
         logging.info("Finished adding links to records")
         all_docdb_records = dict()
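Reviewer note: the indexer now talks to Code Ocean through the `codeocean` SDK rather than raw `requests` calls against the temporary analytics proxy. A minimal sketch of the client construction and external-asset search pattern the patch adopts is below; the domain, token, and logging destination are placeholders, not values from this repository:

    import logging

    from codeocean import CodeOcean
    from codeocean.data_asset import (
        DataAssetSearchOrigin,
        DataAssetSearchParams,
    )
    from urllib3.util import Retry

    # Retry transient HTTP failures instead of relying on a long read timeout.
    retry = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["GET", "POST"],
    )
    client = CodeOcean(
        domain="https://codeocean.example.org", token="<token>", retries=retry
    )

    # The iterator pages through search results, so no manual pagination is needed.
    params = DataAssetSearchParams(
        archived=False, origin=DataAssetSearchOrigin.External, limit=1000
    )
    for asset in client.data_assets.search_data_assets_iterator(search_params=params):
        source = asset.source_bucket
        if source is not None and source.bucket and source.prefix:
            logging.info("%s -> s3://%s/%s", asset.id, source.bucket, source.prefix)

With this client, failures surface as exceptions rather than non-200 status codes, which is why `_get_external_data_asset_records` now logs the exception and returns None instead of checking `response.status_code`.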
diff --git a/src/aind_data_asset_indexer/models.py b/src/aind_data_asset_indexer/models.py
index e7c91d6..c72017f 100644
--- a/src/aind_data_asset_indexer/models.py
+++ b/src/aind_data_asset_indexer/models.py
@@ -124,11 +124,12 @@ class CodeOceanIndexBucketJobSettings(IndexJobSettings):
     doc_db_collection_name: str
     codeocean_domain: str
     codeocean_token: SecretStr
-    temp_codeocean_endpoint: str = Field(
+    temp_codeocean_endpoint: Optional[str] = Field(
+        default=None,
         description=(
-            "Temp proxy to access code ocean information from their analytics "
-            "databases."
-        )
+            "(deprecated) Temp proxy to access code ocean information from "
+            "their analytics databases. Will be removed in a future release."
+        ),
     )

     @classmethod
diff --git a/src/aind_data_asset_indexer/utils.py b/src/aind_data_asset_indexer/utils.py
index 16f5ac1..09b3c63 100644
--- a/src/aind_data_asset_indexer/utils.py
+++ b/src/aind_data_asset_indexer/utils.py
@@ -9,7 +9,6 @@
 from typing import Dict, Iterator, List, Optional
 from urllib.parse import urlparse

-from aind_codeocean_api.codeocean import CodeOceanClient
 from aind_data_schema.core.data_description import DataLevel, DataRegex
 from aind_data_schema.core.metadata import CORE_FILES as CORE_SCHEMAS
 from aind_data_schema.core.metadata import (
@@ -18,6 +17,12 @@
     create_metadata_json,
 )
 from botocore.exceptions import ClientError
+from codeocean import CodeOcean
+from codeocean.data_asset import (
+    DataAssetSearchParams,
+    DataAssetState,
+    DataAssetType,
+)
 from mypy_boto3_s3 import S3Client
 from mypy_boto3_s3.type_defs import (
     PaginatorConfigTypeDef,
@@ -934,7 +939,7 @@ def build_docdb_location_to_id_map(


 def get_all_processed_codeocean_asset_records(
-    co_client: CodeOceanClient, co_data_asset_bucket: str
+    co_client: CodeOcean, co_data_asset_bucket: str
 ) -> Dict[str, dict]:
     """
     Gets all the data asset records we're interested in indexing. The location
@@ -943,7 +948,7 @@ def get_all_processed_codeocean_asset_records(
     Parameters
     ----------
-    co_client : CodeOceanClient
+    co_client : CodeOcean
     co_data_asset_bucket : str
       Name of Code Ocean's data asset bucket
     Returns
     -------
@@ -966,31 +971,27 @@ def get_all_processed_codeocean_asset_records(
     all_responses = dict()
     for tag in {DataLevel.DERIVED.value, "processed"}:
-        response = co_client.search_all_data_assets(
-            type="result", query=f"tag:{tag}"
+        search_params = DataAssetSearchParams(
+            type=DataAssetType.Result, query=f"tag:{tag}"
+        )
+        iter_response = co_client.data_assets.search_data_assets_iterator(
+            search_params=search_params
         )
-        # There is a bug with the codeocean api that caps the number of
-        # results in a single request to 10000.
-        if len(response.json()["results"]) >= 10000:
-            logging.warning(
-                "Number of records exceeds 10,000! This can lead to "
-                "possible data loss."
-            )

         # Extract relevant information
         extracted_info = dict()
-        for data_asset_info in response.json()["results"]:
-            data_asset_id = data_asset_info["id"]
-            data_asset_name = data_asset_info["name"]
-            created_timestamp = data_asset_info["created"]
+        for data_asset_info in iter_response:
+            data_asset_id = data_asset_info.id
+            data_asset_name = data_asset_info.name
+            created_timestamp = data_asset_info.created
             created_datetime = datetime.fromtimestamp(
                 created_timestamp, tz=timezone.utc
             )
             # Results hosted externally have a source_bucket field
-            is_external = (
-                data_asset_info.get("sourceBucket") is not None
-                or data_asset_info.get("source_bucket") is not None
-            )
-            if not is_external and data_asset_info.get("state") == "ready":
+            is_external = data_asset_info.source_bucket is not None
+            if (
+                not is_external
+                and data_asset_info.state == DataAssetState.Ready
+            ):
                 location = f"s3://{co_data_asset_bucket}/{data_asset_id}"
                 extracted_info[location] = {
                     "name": data_asset_name,
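Reviewer note: `get_all_processed_codeocean_asset_records` can drop the 10,000-record warning because `search_data_assets_iterator` pages through the full result set lazily. A short sketch of the replacement query, under the same assumptions as above (the helper name and bucket value are illustrative only, not part of the patch):

    from typing import Iterator

    from codeocean import CodeOcean
    from codeocean.data_asset import (
        DataAssetSearchParams,
        DataAssetState,
        DataAssetType,
    )

    def processed_result_locations(client: CodeOcean, co_bucket: str) -> Iterator[str]:
        """Yield S3 locations of ready, internally hosted Result assets tagged 'processed'."""
        params = DataAssetSearchParams(type=DataAssetType.Result, query="tag:processed")
        for asset in client.data_assets.search_data_assets_iterator(search_params=params):
            # Externally hosted results carry a source_bucket; the indexer skips those here.
            if asset.source_bucket is None and asset.state == DataAssetState.Ready:
                yield f"s3://{co_bucket}/{asset.id}"

Because the search now yields typed `DataAsset` models, the external/ready checks compare enum members such as `DataAssetState.Ready` instead of the raw `sourceBucket`/`source_bucket` dictionary keys and the string "ready".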
diff --git a/tests/test_aind_bucket_indexer.py b/tests/test_aind_bucket_indexer.py
index 631fd1e..7d68e82 100644
--- a/tests/test_aind_bucket_indexer.py
+++ b/tests/test_aind_bucket_indexer.py
@@ -920,9 +920,9 @@ def test_process_docdb_record_valid_metadata_nd_json_file(
         ]
         self.assertEqual(expected_log_messages, captured.output)
         expected_docdb_record_to_write = deepcopy(mock_docdb_record)
-        expected_docdb_record_to_write[
-            "last_modified"
-        ] = "2024-08-25T17:41:28+00:00"
+        expected_docdb_record_to_write["last_modified"] = (
+            "2024-08-25T17:41:28+00:00"
+        )
         expected_docdb_record_to_write["subject"] = self.example_md_record.get(
             "subject"
         )
diff --git a/tests/test_codeocean_bucket_indexer.py b/tests/test_codeocean_bucket_indexer.py
index 5b2f2cc..41eb57f 100644
--- a/tests/test_codeocean_bucket_indexer.py
+++ b/tests/test_codeocean_bucket_indexer.py
@@ -1,15 +1,20 @@
 """Tests methods in codeocean_bucket_indexer module"""

-import json
 import os
 import unittest
 from datetime import datetime, timezone
 from pathlib import Path
 from unittest.mock import MagicMock, call, patch

+from codeocean import CodeOcean
+from codeocean.data_asset import (
+    DataAsset,
+    DataAssetOrigin,
+    DataAssetState,
+    DataAssetType,
+    SourceBucket,
+)
 from pymongo.operations import UpdateOne
-from requests import Response
-from requests.exceptions import ReadTimeout

 from aind_data_asset_indexer.codeocean_bucket_indexer import (
     CodeOceanIndexBucketJob,
@@ -102,54 +107,85 @@ def setUpClass(cls) -> None:
             },
         ]

-        cls.example_temp_endpoint_response = [
-            {"id": "abc-123", "source": "s3://bucket/prefix1"},
-            {"id": "def-456", "source": "s3://bucket/prefix1"},
-            {"id": "ghi-789", "source": "s3://bucket/prefix2"},
+        cls.example_search_iterator_response = [
+            DataAsset(
+                id="abc-123",
+                created=0,
+                name="prefix1",
+                mount="prefix1",
+                state=DataAssetState.Ready,
+                type=DataAssetType.Dataset,
+                last_used=0,
+                source_bucket=SourceBucket(
+                    bucket="bucket",
+                    prefix="prefix1",
+                    origin=DataAssetOrigin.AWS,
+                ),
+            ),
+            DataAsset(
+                id="def-456",
+                created=0,
+                name="prefix1",
+                mount="prefix1",
+                state=DataAssetState.Ready,
+                type=DataAssetType.Dataset,
+                last_used=0,
+                source_bucket=SourceBucket(
+                    bucket="bucket",
+                    prefix="prefix1",
+                    origin=DataAssetOrigin.AWS,
+                ),
+            ),
+            DataAsset(
+                id="ghi-789",
+                created=0,
+                name="prefix2",
+                mount="prefix2",
+                state=DataAssetState.Ready,
+                type=DataAssetType.Dataset,
+                last_used=0,
+                source_bucket=SourceBucket(
+                    bucket="bucket",
+                    prefix="prefix2",
+                    origin=DataAssetOrigin.AWS,
+                ),
+            ),
         ]

-    @patch("requests.get")
-    def test_get_external_data_asset_records(self, mock_get: MagicMock):
+    @patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
+    def test_get_external_data_asset_records(self, mock_search: MagicMock):
         """Tests the _get_external_data_asset_records method"""
-        example_response = self.example_temp_endpoint_response
-        mock_get_response = Response()
-        mock_get_response.status_code = 200
-        mock_get_response._content = json.dumps(example_response).encode(
-            "utf-8"
+        mock_search.return_value = self.example_search_iterator_response
+        response = self.basic_job._get_external_data_asset_records(
+            co_client=CodeOcean(domain="www.example.com", token="")
         )
-        mock_get.return_value = mock_get_response
-        response = self.basic_job._get_external_data_asset_records()
-        self.assertEqual(example_response, response)
+        expected_response = [
+            {"id": "abc-123", "location": "s3://bucket/prefix1"},
+            {"id": "def-456", "location": "s3://bucket/prefix1"},
+            {"id": "ghi-789", "location": "s3://bucket/prefix2"},
+        ]
+        self.assertEqual(expected_response, response)

-    @patch("requests.get")
-    def test_get_external_data_asset_records_error(self, mock_get: MagicMock):
+    @patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
+    def test_get_external_data_asset_records_err(self, mock_search: MagicMock):
         """Tests the _get_external_data_asset_records method when an error
         response is returned"""
-        mock_get_response = Response()
-        mock_get_response.status_code = 500
-        mock_get.return_value = mock_get_response
-        response = self.basic_job._get_external_data_asset_records()
-        self.assertIsNone(response)
-
-    @patch("requests.get")
-    def test_get_external_data_asset_records_read_timeout(
-        self, mock_get: MagicMock
-    ):
-        """Tests the _get_external_data_asset_records method when the read
-        times out."""
-        mock_get.side_effect = ReadTimeout()
+        mock_search.side_effect = Exception("Something went wrong!")
         with self.assertLogs(level="DEBUG") as captured:
-            response = self.basic_job._get_external_data_asset_records()
-        expected_log_messages = [
-            "ERROR:root:Read timed out at http://some_url:8080/created_after/0"
-        ]
-        self.assertEqual(expected_log_messages, captured.output)
+            response = self.basic_job._get_external_data_asset_records(
+                co_client=CodeOcean(domain="www.example.com", token="")
+            )
         self.assertIsNone(response)
+        self.assertIsNotNone(captured.output)

     def test_map_external_list_to_dict(self):
         """Tests _map_external_list_to_dict method"""
         mapped_response = self.basic_job._map_external_list_to_dict(
-            self.example_temp_endpoint_response
+            [
+                {"id": "abc-123", "location": "s3://bucket/prefix1"},
+                {"id": "def-456", "location": "s3://bucket/prefix1"},
+                {"id": "ghi-789", "location": "s3://bucket/prefix2"},
+            ]
         )
         expected_response = {
             "s3://bucket/prefix1": {"abc-123", "def-456"},
@@ -185,27 +221,21 @@ def test_get_co_links_from_record_legacy(self):
         self.assertEqual(["abc-123", "def-456"], output)

     @patch("aind_data_asset_indexer.codeocean_bucket_indexer.MongoClient")
-    @patch("requests.get")
+    @patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
     @patch("aind_data_asset_indexer.codeocean_bucket_indexer.paginate_docdb")
     @patch("aind_data_asset_indexer.codeocean_bucket_indexer.datetime")
     def test_update_external_links_in_docdb(
         self,
         mock_datetime: MagicMock,
         mock_paginate: MagicMock,
-        mock_get: MagicMock,
+        mock_search: MagicMock,
         mock_docdb_client: MagicMock,
     ):
         """Tests _update_external_links_in_docdb method."""
         mock_datetime.utcnow.return_value = datetime(2024, 9, 5)

-        # Mock requests get response
-        example_response = self.example_temp_endpoint_response
-        mock_get_response = Response()
-        mock_get_response.status_code = 200
-        mock_get_response._content = json.dumps(example_response).encode(
-            "utf-8"
-        )
-        mock_get.return_value = mock_get_response
+        # Mock code ocean search response
+        mock_search.return_value = self.example_search_iterator_response

         # Mock bulk_write
         mock_db = MagicMock()
@@ -237,7 +267,8 @@ def test_update_external_links_in_docdb(

         with self.assertLogs(level="DEBUG") as captured:
             self.basic_job._update_external_links_in_docdb(
-                docdb_client=mock_docdb_client
+                docdb_client=mock_docdb_client,
+                co_client=CodeOcean(domain="www.example.com", token=""),
             )
         expected_log_messages = [
             "INFO:root:No code ocean data asset ids found for "
@@ -284,31 +315,31 @@ def test_update_external_links_in_docdb(
         mock_collection.bulk_write.assert_has_calls(expected_bulk_write_calls)

     @patch("aind_data_asset_indexer.codeocean_bucket_indexer.MongoClient")
-    @patch("requests.get")
+    @patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
     @patch("aind_data_asset_indexer.codeocean_bucket_indexer.paginate_docdb")
     def test_update_external_links_in_docdb_error(
         self,
         mock_paginate: MagicMock,
-        mock_get: MagicMock,
+        mock_search: MagicMock,
         mock_docdb_client: MagicMock,
     ):
         """Tests _update_external_links_in_docdb method when there is an
         error retrieving info from the temp endpoint."""

-        # Mock requests get response
-        mock_get_response = Response()
-        mock_get_response.status_code = 500
-        mock_get.return_value = mock_get_response
+        # Mock search response
+        mock_search.side_effect = Exception("Something went wrong!")

         mock_db = MagicMock()
         mock_docdb_client.__getitem__.return_value = mock_db

         with self.assertLogs(level="DEBUG") as captured:
             self.basic_job._update_external_links_in_docdb(
-                docdb_client=mock_docdb_client
+                docdb_client=mock_docdb_client,
+                co_client=CodeOcean(domain="www.example.com", token=""),
             )
-        expected_log_messages = [
+        expected_log_message = (
             "ERROR:root:There was an error retrieving external links!"
-        ]
-        self.assertEqual(expected_log_messages, captured.output)
+        )
+        self.assertEqual(2, len(captured.output))
+        self.assertEqual(expected_log_message, captured.output[1])
         mock_paginate.assert_not_called()
@@ -568,8 +599,10 @@ def test_delete_records_from_docdb(
         "aind_data_asset_indexer.codeocean_bucket_indexer."
         "get_all_processed_codeocean_asset_records"
     )
+    @patch("aind_data_asset_indexer.codeocean_bucket_indexer.CodeOcean")
     def test_run_job(
         self,
+        mock_codeocean_client: MagicMock,
         mock_get_all_co_records: MagicMock,
         mock_docdb_client: MagicMock,
         mock_paginate_docdb: MagicMock,
@@ -581,6 +614,8 @@ def test_run_job(
         one record, add one record, and delete one record."""
         mock_mongo_client = MagicMock()
         mock_docdb_client.return_value = mock_mongo_client
+        mock_co_client = MagicMock()
+        mock_codeocean_client.return_value = mock_co_client
         mock_get_all_co_records.return_value = dict(
             [(r["location"], r) for r in self.example_codeocean_records]
         )
@@ -602,7 +637,7 @@ def test_run_job(
         self.assertEqual(expected_log_messages, captured.output)

         mock_update_external_links_in_docdb.assert_called_once_with(
-            docdb_client=mock_mongo_client
+            docdb_client=mock_mongo_client, co_client=mock_co_client
         )
         mock_process_codeocean_records.assert_called_once_with(
             records=[self.example_codeocean_records[0]]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c2f41dd..039595d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -8,9 +8,9 @@
 from pathlib import Path
 from unittest.mock import MagicMock, call, patch

-from aind_codeocean_api.codeocean import CodeOceanClient
 from botocore.exceptions import ClientError
-from requests import Response
+from codeocean import CodeOcean
+from codeocean.data_asset import DataAsset

 from aind_data_asset_indexer.utils import (
     build_docdb_location_to_id_map,
@@ -139,7 +139,9 @@ def load_json_file(filename: str) -> dict:
         cls.example_put_object_response1 = load_json_file(
             "example_put_object_response1.json"
         )
-        cls.example_co_search_data_assets = example_co_search_data_assets
+        cls.example_co_search_data_assets = [
+            DataAsset(**r) for r in example_co_search_data_assets["results"]
+        ]

     def test_compute_md5_hash(self):
         """Tests compute_md5_hash method"""
@@ -1394,26 +1396,16 @@ def test_build_docdb_location_to_id_map(
         }
         self.assertEqual(expected_map, actual_map)

-    @patch(
-        "aind_codeocean_api.codeocean.CodeOceanClient.search_all_data_assets"
-    )
+    @patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
     def test_get_all_processed_codeocean_asset_records(
         self, mock_search_all_data_assets: MagicMock
     ):
         """Tests get_all_processed_codeocean_asset_records method"""
-        mock_response1 = Response()
-        mock_response1.status_code = 200
-        body = json.dumps(self.example_co_search_data_assets)
-        mock_response1._content = body.encode("utf-8")
-        mock_response2 = Response()
-        mock_response2.status_code = 200
-        body = json.dumps({"results": []})
-        mock_response2._content = body.encode("utf-8")
-        mock_search_all_data_assets.side_effect = [
-            mock_response1,
-            mock_response2,
-        ]
-        co_client = CodeOceanClient(domain="some_domain", token="some_token")
+
+        mock_search_all_data_assets.return_value = (
+            self.example_co_search_data_assets
+        )
+        co_client = CodeOcean(domain="some_domain", token="some_token")
         records = get_all_processed_codeocean_asset_records(
             co_client=co_client,
             co_data_asset_bucket="some_co_bucket",
@@ -1453,42 +1445,6 @@ def test_get_all_processed_codeocean_asset_records(

         self.assertEqual(expected_records, records)

-    @patch(
-        "aind_codeocean_api.codeocean.CodeOceanClient.search_all_data_assets"
-    )
-    def test_get_all_processed_codeocean_asset_records_warning(
-        self, mock_search_all_data_assets: MagicMock
-    ):
-        """Tests get_all_processed_codeocean_asset_records method when 10,000
-        records are returned"""
-        # Fake a response with 10,000 records
-        search_result_copy = deepcopy(self.example_co_search_data_assets)
-        result0 = search_result_copy["results"][0]
-        search_result_copy["results"] = [result0 for _ in range(0, 10000)]
-        mock_response1 = Response()
-        mock_response1.status_code = 200
-        body = json.dumps(search_result_copy)
-        mock_response1._content = body.encode("utf-8")
-        mock_response2 = Response()
-        mock_response2.status_code = 200
-        body = json.dumps({"results": []})
-        mock_response2._content = body.encode("utf-8")
-        mock_search_all_data_assets.side_effect = [
-            mock_response1,
-            mock_response2,
-        ]
-        co_client = CodeOceanClient(domain="some_domain", token="some_token")
-        with self.assertLogs(level="DEBUG") as captured:
-            get_all_processed_codeocean_asset_records(
-                co_client=co_client,
-                co_data_asset_bucket="some_co_bucket",
-            )
-        expected_log_messages = [
-            "WARNING:root:Number of records exceeds 10,000! "
-            "This can lead to possible data loss."
-        ]
-        self.assertEqual(expected_log_messages, captured.output)
-

 if __name__ == "__main__":
     unittest.main()