
Commit

refactor: uses codeocean sdk
jtyoung84 committed Jan 14, 2025
1 parent aa2ac2e commit f5d3fa3
Showing 7 changed files with 57 additions and 93 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -1,4 +1,5 @@
"""Configuration file for the Sphinx documentation builder."""

#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"pymongo==4.3.3",
"dask==2023.5.0",
"aind-data-schema==1.2.0",
"aind-codeocean-api==0.5.0",
"codeocean==0.3.0",
]

[project.optional-dependencies]
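This dependency swap replaces the aind-codeocean-api client with the codeocean SDK. As a rough sketch of what the swap means for callers, assuming the codeocean 0.3.0 API used elsewhere in this commit (the domain and token below are placeholders, not values from this repository):

```python
# Old client (aind-codeocean-api), removed by this commit:
# from aind_codeocean_api.codeocean import CodeOceanClient
# co_client = CodeOceanClient(domain="https://example.codeocean.com", token="api-token")

# New client (codeocean SDK), as constructed in codeocean_bucket_indexer.py below:
from codeocean import CodeOcean

co_client = CodeOcean(
    domain="https://example.codeocean.com",  # placeholder Code Ocean deployment URL
    token="api-token",  # placeholder API token
)
```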
20 changes: 9 additions & 11 deletions src/aind_data_asset_indexer/aind_bucket_indexer.py
@@ -272,14 +272,12 @@ def _resolve_schema_information(
object_key = create_object_key(
prefix=prefix, filename=core_schema_file_name
)
common_kwargs[
"core_schema_info_in_root"
] = get_dict_of_file_info(
s3_client=s3_client,
bucket=self.job_settings.s3_bucket,
keys=[object_key],
).get(
object_key
common_kwargs["core_schema_info_in_root"] = (
get_dict_of_file_info(
s3_client=s3_client,
bucket=self.job_settings.s3_bucket,
keys=[object_key],
).get(object_key)
)
self._copy_file_from_root_to_subdir(**common_kwargs)
# If field is null, a file exists in the root folder, and
@@ -424,9 +422,9 @@ def _process_docdb_record(
)
db = docdb_client[self.job_settings.doc_db_db_name]
collection = db[self.job_settings.doc_db_collection_name]
fields_to_update[
"last_modified"
] = datetime.utcnow().isoformat()
fields_to_update["last_modified"] = (
datetime.utcnow().isoformat()
)
response = collection.update_one(
{"_id": docdb_record["_id"]},
{"$set": fields_to_update},
12 changes: 10 additions & 2 deletions src/aind_data_asset_indexer/codeocean_bucket_indexer.py
@@ -13,12 +13,13 @@
import boto3
import dask.bag as dask_bag
import requests
from aind_codeocean_api.codeocean import CodeOceanClient
from aind_data_schema.core.metadata import ExternalPlatforms
from codeocean import CodeOcean
from mypy_boto3_s3 import S3Client
from pymongo import MongoClient
from pymongo.operations import UpdateOne
from requests.exceptions import ReadTimeout
from urllib3.util import Retry

from aind_data_asset_indexer.models import CodeOceanIndexBucketJobSettings
from aind_data_asset_indexer.utils import (
@@ -394,9 +395,16 @@ def _delete_records_from_docdb(self, record_list: List[str]):
def run_job(self):
"""Main method to run."""
logging.info("Starting to scan through CodeOcean.")
co_client = CodeOceanClient(
retry = Retry(
total=5,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["GET", "POST"],
)
co_client = CodeOcean(
domain=self.job_settings.codeocean_domain,
token=self.job_settings.codeocean_token.get_secret_value(),
retries=retry,
)
code_ocean_records = get_all_processed_codeocean_asset_records(
co_client=co_client,
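The new client is constructed with an explicit urllib3 retry policy. A minimal standalone sketch of the construction introduced above, with placeholder credentials (in the indexer they come from the job settings):

```python
from codeocean import CodeOcean
from urllib3.util import Retry

# Retry throttling (429) and transient server errors (5xx) up to 5 times,
# with exponential backoff, for both GET and POST requests.
retry = Retry(
    total=5,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=["GET", "POST"],
)

co_client = CodeOcean(
    domain="https://example.codeocean.com",  # placeholder deployment URL
    token="api-token",  # placeholder API token
    retries=retry,
)
```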
43 changes: 22 additions & 21 deletions src/aind_data_asset_indexer/utils.py
@@ -9,7 +9,6 @@
from typing import Dict, Iterator, List, Optional
from urllib.parse import urlparse

from aind_codeocean_api.codeocean import CodeOceanClient
from aind_data_schema.core.data_description import DataLevel, DataRegex
from aind_data_schema.core.metadata import CORE_FILES as CORE_SCHEMAS
from aind_data_schema.core.metadata import (
@@ -18,6 +17,12 @@
create_metadata_json,
)
from botocore.exceptions import ClientError
from codeocean import CodeOcean
from codeocean.data_asset import (
DataAssetSearchParams,
DataAssetState,
DataAssetType,
)
from mypy_boto3_s3 import S3Client
from mypy_boto3_s3.type_defs import (
PaginatorConfigTypeDef,
@@ -934,7 +939,7 @@ def build_docdb_location_to_id_map(


def get_all_processed_codeocean_asset_records(
co_client: CodeOceanClient, co_data_asset_bucket: str
co_client: CodeOcean, co_data_asset_bucket: str
) -> Dict[str, dict]:
"""
Gets all the data asset records we're interested in indexing. The location
@@ -943,7 +948,7 @@ def get_all_processed_codeocean_asset_records(
Parameters
----------
co_client : CodeOceanClient
co_client : CodeOcean
co_data_asset_bucket : str
Name of Code Ocean's data asset bucket
Returns
@@ -966,31 +971,27 @@ def get_all_processed_codeocean_asset_records(
all_responses = dict()

for tag in {DataLevel.DERIVED.value, "processed"}:
response = co_client.search_all_data_assets(
type="result", query=f"tag:{tag}"
search_params = DataAssetSearchParams(
type=DataAssetType.Result, query=f"tag:{tag}"
)
iter_response = co_client.data_assets.search_data_assets_iterator(
search_params=search_params
)
# There is a bug with the codeocean api that caps the number of
# results in a single request to 10000.
if len(response.json()["results"]) >= 10000:
logging.warning(
"Number of records exceeds 10,000! This can lead to "
"possible data loss."
)
# Extract relevant information
extracted_info = dict()
for data_asset_info in response.json()["results"]:
data_asset_id = data_asset_info["id"]
data_asset_name = data_asset_info["name"]
created_timestamp = data_asset_info["created"]
for data_asset_info in iter_response:
data_asset_id = data_asset_info.id
data_asset_name = data_asset_info.name
created_timestamp = data_asset_info.created
created_datetime = datetime.fromtimestamp(
created_timestamp, tz=timezone.utc
)
# Results hosted externally have a source_bucket field
is_external = (
data_asset_info.get("sourceBucket") is not None
or data_asset_info.get("source_bucket") is not None
)
if not is_external and data_asset_info.get("state") == "ready":
is_external = data_asset_info.source_bucket is not None
if (
not is_external
and data_asset_info.state == DataAssetState.Ready
):
location = f"s3://{co_data_asset_bucket}/{data_asset_id}"
extracted_info[location] = {
"name": data_asset_name,
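The one-shot search_all_data_assets call, which was capped at 10,000 results (hence the warning removed above), is replaced by the SDK's search iterator, which pages through results internally. A minimal sketch of the new search pattern, using a placeholder client and a single tag:

```python
from codeocean import CodeOcean
from codeocean.data_asset import (
    DataAssetSearchParams,
    DataAssetState,
    DataAssetType,
)

# Placeholder client; in the indexer the domain and token come from job settings.
co_client = CodeOcean(domain="https://example.codeocean.com", token="api-token")

# Search all Code Ocean results tagged "processed"; the iterator handles paging.
search_params = DataAssetSearchParams(
    type=DataAssetType.Result, query="tag:processed"
)
for asset in co_client.data_assets.search_data_assets_iterator(
    search_params=search_params
):
    # Internally hosted, ready assets are the ones the indexer keeps.
    if asset.source_bucket is None and asset.state == DataAssetState.Ready:
        print(asset.id, asset.name, asset.created)
```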
6 changes: 3 additions & 3 deletions tests/test_aind_bucket_indexer.py
@@ -920,9 +920,9 @@ def test_process_docdb_record_valid_metadata_nd_json_file(
]
self.assertEqual(expected_log_messages, captured.output)
expected_docdb_record_to_write = deepcopy(mock_docdb_record)
expected_docdb_record_to_write[
"last_modified"
] = "2024-08-25T17:41:28+00:00"
expected_docdb_record_to_write["last_modified"] = (
"2024-08-25T17:41:28+00:00"
)
expected_docdb_record_to_write["subject"] = self.example_md_record.get(
"subject"
)
66 changes: 11 additions & 55 deletions tests/test_utils.py
@@ -8,9 +8,9 @@
from pathlib import Path
from unittest.mock import MagicMock, call, patch

from aind_codeocean_api.codeocean import CodeOceanClient
from botocore.exceptions import ClientError
from requests import Response
from codeocean import CodeOcean
from codeocean.data_asset import DataAsset

from aind_data_asset_indexer.utils import (
build_docdb_location_to_id_map,
@@ -139,7 +139,9 @@ def load_json_file(filename: str) -> dict:
cls.example_put_object_response1 = load_json_file(
"example_put_object_response1.json"
)
cls.example_co_search_data_assets = example_co_search_data_assets
cls.example_co_search_data_assets = [
DataAsset(**r) for r in example_co_search_data_assets["results"]
]

def test_compute_md5_hash(self):
"""Tests compute_md5_hash method"""
@@ -1394,26 +1396,16 @@ def test_build_docdb_location_to_id_map(
}
self.assertEqual(expected_map, actual_map)

@patch(
"aind_codeocean_api.codeocean.CodeOceanClient.search_all_data_assets"
)
@patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
def test_get_all_processed_codeocean_asset_records(
self, mock_search_all_data_assets: MagicMock
):
"""Tests get_all_processed_codeocean_asset_records method"""
mock_response1 = Response()
mock_response1.status_code = 200
body = json.dumps(self.example_co_search_data_assets)
mock_response1._content = body.encode("utf-8")
mock_response2 = Response()
mock_response2.status_code = 200
body = json.dumps({"results": []})
mock_response2._content = body.encode("utf-8")
mock_search_all_data_assets.side_effect = [
mock_response1,
mock_response2,
]
co_client = CodeOceanClient(domain="some_domain", token="some_token")

mock_search_all_data_assets.return_value = (
self.example_co_search_data_assets
)
co_client = CodeOcean(domain="some_domain", token="some_token")
records = get_all_processed_codeocean_asset_records(
co_client=co_client,
co_data_asset_bucket="some_co_bucket",
Expand Down Expand Up @@ -1453,42 +1445,6 @@ def test_get_all_processed_codeocean_asset_records(

self.assertEqual(expected_records, records)

@patch(
"aind_codeocean_api.codeocean.CodeOceanClient.search_all_data_assets"
)
def test_get_all_processed_codeocean_asset_records_warning(
self, mock_search_all_data_assets: MagicMock
):
"""Tests get_all_processed_codeocean_asset_records method when 10,000
records are returned"""
# Fake a response with 10,000 records
search_result_copy = deepcopy(self.example_co_search_data_assets)
result0 = search_result_copy["results"][0]
search_result_copy["results"] = [result0 for _ in range(0, 10000)]
mock_response1 = Response()
mock_response1.status_code = 200
body = json.dumps(search_result_copy)
mock_response1._content = body.encode("utf-8")
mock_response2 = Response()
mock_response2.status_code = 200
body = json.dumps({"results": []})
mock_response2._content = body.encode("utf-8")
mock_search_all_data_assets.side_effect = [
mock_response1,
mock_response2,
]
co_client = CodeOceanClient(domain="some_domain", token="some_token")
with self.assertLogs(level="DEBUG") as captured:
get_all_processed_codeocean_asset_records(
co_client=co_client,
co_data_asset_bucket="some_co_bucket",
)
expected_log_messages = [
"WARNING:root:Number of records exceeds 10,000! "
"This can lead to possible data loss."
]
self.assertEqual(expected_log_messages, captured.output)


if __name__ == "__main__":
unittest.main()
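The tests now patch the SDK's iterator and feed it codeocean DataAsset models instead of raw requests Response objects. A simplified sketch of that pattern (the real tests build DataAsset models from JSON fixtures; the spec'd MagicMock and the field values below are stand-ins):

```python
from unittest.mock import MagicMock, patch

from codeocean import CodeOcean
from codeocean.data_asset import DataAsset, DataAssetState

from aind_data_asset_indexer.utils import (
    get_all_processed_codeocean_asset_records,
)

# Stand-in for a DataAsset model; the real tests construct DataAsset(**record).
fake_asset = MagicMock(spec=DataAsset)
fake_asset.id = "abc-123"
fake_asset.name = "example_processed_asset"
fake_asset.created = 1736812800  # unix timestamp
fake_asset.source_bucket = None
fake_asset.state = DataAssetState.Ready

with patch(
    "codeocean.data_asset.DataAssets.search_data_assets_iterator",
    return_value=[fake_asset],
):
    co_client = CodeOcean(domain="some_domain", token="some_token")
    records = get_all_processed_codeocean_asset_records(
        co_client=co_client,
        co_data_asset_bucket="some_co_bucket",
    )
    # Expect a single record keyed by s3://some_co_bucket/abc-123
```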
