Skip to content

Commit

Permalink
feat: removes temp workaround for co assets (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtyoung84 authored Jan 15, 2025
1 parent c7a9e11 commit 1ef85d8
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 82 deletions.
58 changes: 40 additions & 18 deletions src/aind_data_asset_indexer/codeocean_bucket_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@

import boto3
import dask.bag as dask_bag
import requests
from aind_data_schema.core.metadata import ExternalPlatforms
from codeocean import CodeOcean
from codeocean.data_asset import DataAssetSearchOrigin, DataAssetSearchParams
from mypy_boto3_s3 import S3Client
from pymongo import MongoClient
from pymongo.operations import UpdateOne
from requests.exceptions import ReadTimeout
from urllib3.util import Retry

from aind_data_asset_indexer.models import CodeOceanIndexBucketJobSettings
Expand Down Expand Up @@ -53,30 +52,51 @@ def __init__(self, job_settings: CodeOceanIndexBucketJobSettings):
"""Class constructor."""
self.job_settings = job_settings

def _get_external_data_asset_records(self) -> Optional[List[dict]]:
@staticmethod
def _get_external_data_asset_records(
co_client: CodeOcean,
) -> Optional[List[dict]]:
"""
Retrieves list of code ocean ids and locations for external data
assets by searching with the provided Code Ocean client.
Parameters
----------
co_client : CodeOcean
Returns
-------
List[dict] | None
List items have shape {"id": str, "location": str}. If error occurs,
return None.
"""
try:
response = requests.get(
self.job_settings.temp_codeocean_endpoint,
timeout=600,
search_params = DataAssetSearchParams(
archived=False,
origin=DataAssetSearchOrigin.External,
limit=1000,
)
if response.status_code == 200:
return response.json()
else:
return None
except ReadTimeout:
logging.error(
f"Read timed out at "
f"{self.job_settings.temp_codeocean_endpoint}"
data_assets = co_client.data_assets.search_data_assets_iterator(
search_params=search_params
)
external_records = []
for data_asset in data_assets:
data_asset_source = data_asset.source_bucket
if (
data_asset_source is not None
and data_asset_source.bucket is not None
and data_asset_source.prefix is not None
):
bucket = data_asset_source.bucket
prefix = data_asset_source.prefix
location = f"s3://{bucket}/{prefix}"
external_records.append(
{"id": data_asset.id, "location": location}
)
return external_records
except Exception as e:
logging.exception(e)
return None

@staticmethod
Expand All @@ -98,7 +118,7 @@ def _map_external_list_to_dict(external_recs: List[dict]) -> dict:
"""
new_records = dict()
for r in external_recs:
location = r.get("source")
location = r.get("location")
rec_id = r["id"]
if location is not None and new_records.get(location) is not None:
old_id_set = new_records.get(location)
Expand Down Expand Up @@ -141,7 +161,7 @@ def _get_co_links_from_record(
return external_links

def _update_external_links_in_docdb(
self, docdb_client: MongoClient
self, docdb_client: MongoClient, co_client: CodeOcean
) -> None:
"""
This method will:
Expand All @@ -160,7 +180,9 @@ def _update_external_links_in_docdb(
"""
# Should return a list like [{"id": co_id, "location": "s3://..."},]
list_of_co_ids_and_locations = self._get_external_data_asset_records()
list_of_co_ids_and_locations = self._get_external_data_asset_records(
co_client=co_client
)
db = docdb_client[self.job_settings.doc_db_db_name]
collection = db[self.job_settings.doc_db_collection_name]
if list_of_co_ids_and_locations is not None:
Expand Down Expand Up @@ -424,7 +446,7 @@ def run_job(self):
# Use existing client to add external links to fields
logging.info("Adding links to records.")
self._update_external_links_in_docdb(
docdb_client=iterator_docdb_client
docdb_client=iterator_docdb_client, co_client=co_client
)
logging.info("Finished adding links to records")
all_docdb_records = dict()
Expand Down
9 changes: 5 additions & 4 deletions src/aind_data_asset_indexer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ class CodeOceanIndexBucketJobSettings(IndexJobSettings):
doc_db_collection_name: str
codeocean_domain: str
codeocean_token: SecretStr
temp_codeocean_endpoint: str = Field(
temp_codeocean_endpoint: Optional[str] = Field(
default=None,
description=(
"Temp proxy to access code ocean information from their analytics "
"databases."
)
"(deprecated) Temp proxy to access code ocean information from "
"their analytics databases. Will be removed in a future release."
),
)

@classmethod
Expand Down
155 changes: 95 additions & 60 deletions tests/test_codeocean_bucket_indexer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""Tests methods in codeocean_bucket_indexer module"""

import json
import os
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, call, patch

from codeocean import CodeOcean
from codeocean.data_asset import (
DataAsset,
DataAssetOrigin,
DataAssetState,
DataAssetType,
SourceBucket,
)
from pymongo.operations import UpdateOne
from requests import Response
from requests.exceptions import ReadTimeout

from aind_data_asset_indexer.codeocean_bucket_indexer import (
CodeOceanIndexBucketJob,
Expand Down Expand Up @@ -102,54 +107,85 @@ def setUpClass(cls) -> None:
},
]

cls.example_temp_endpoint_response = [
{"id": "abc-123", "source": "s3://bucket/prefix1"},
{"id": "def-456", "source": "s3://bucket/prefix1"},
{"id": "ghi-789", "source": "s3://bucket/prefix2"},
cls.example_search_iterator_response = [
DataAsset(
id="abc-123",
created=0,
name="prefix1",
mount="prefix1",
state=DataAssetState.Ready,
type=DataAssetType.Dataset,
last_used=0,
source_bucket=SourceBucket(
bucket="bucket",
prefix="prefix1",
origin=DataAssetOrigin.AWS,
),
),
DataAsset(
id="def-456",
created=0,
name="prefix1",
mount="prefix1",
state=DataAssetState.Ready,
type=DataAssetType.Dataset,
last_used=0,
source_bucket=SourceBucket(
bucket="bucket",
prefix="prefix1",
origin=DataAssetOrigin.AWS,
),
),
DataAsset(
id="ghi-789",
created=0,
name="prefix2",
mount="prefix2",
state=DataAssetState.Ready,
type=DataAssetType.Dataset,
last_used=0,
source_bucket=SourceBucket(
bucket="bucket",
prefix="prefix2",
origin=DataAssetOrigin.AWS,
),
),
]

@patch("requests.get")
def test_get_external_data_asset_records(self, mock_get: MagicMock):
@patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
def test_get_external_data_asset_records(self, mock_search: MagicMock):
"""Tests the _get_external_data_asset_records method"""
example_response = self.example_temp_endpoint_response
mock_get_response = Response()
mock_get_response.status_code = 200
mock_get_response._content = json.dumps(example_response).encode(
"utf-8"
mock_search.return_value = self.example_search_iterator_response
response = self.basic_job._get_external_data_asset_records(
co_client=CodeOcean(domain="www.example.com", token="")
)
mock_get.return_value = mock_get_response
response = self.basic_job._get_external_data_asset_records()
self.assertEqual(example_response, response)
expected_response = [
{"id": "abc-123", "location": "s3://bucket/prefix1"},
{"id": "def-456", "location": "s3://bucket/prefix1"},
{"id": "ghi-789", "location": "s3://bucket/prefix2"},
]
self.assertEqual(expected_response, response)

@patch("requests.get")
def test_get_external_data_asset_records_error(self, mock_get: MagicMock):
@patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
def test_get_external_data_asset_records_err(self, mock_search: MagicMock):
"""Tests the _get_external_data_asset_records method when an error
is raised by the Code Ocean data asset search"""
mock_get_response = Response()
mock_get_response.status_code = 500
mock_get.return_value = mock_get_response
response = self.basic_job._get_external_data_asset_records()
self.assertIsNone(response)

@patch("requests.get")
def test_get_external_data_asset_records_read_timeout(
self, mock_get: MagicMock
):
"""Tests the _get_external_data_asset_records method when the read
times out."""
mock_get.side_effect = ReadTimeout()
mock_search.side_effect = Exception("Something went wrong!")
with self.assertLogs(level="DEBUG") as captured:
response = self.basic_job._get_external_data_asset_records()
expected_log_messages = [
"ERROR:root:Read timed out at http://some_url:8080/created_after/0"
]
self.assertEqual(expected_log_messages, captured.output)
response = self.basic_job._get_external_data_asset_records(
co_client=CodeOcean(domain="www.example.com", token="")
)
self.assertIsNone(response)
self.assertIsNotNone(captured.output)

def test_map_external_list_to_dict(self):
"""Tests _map_external_list_to_dict method"""
mapped_response = self.basic_job._map_external_list_to_dict(
self.example_temp_endpoint_response
[
{"id": "abc-123", "location": "s3://bucket/prefix1"},
{"id": "def-456", "location": "s3://bucket/prefix1"},
{"id": "ghi-789", "location": "s3://bucket/prefix2"},
]
)
expected_response = {
"s3://bucket/prefix1": {"abc-123", "def-456"},
Expand Down Expand Up @@ -185,27 +221,21 @@ def test_get_co_links_from_record_legacy(self):
self.assertEqual(["abc-123", "def-456"], output)

@patch("aind_data_asset_indexer.codeocean_bucket_indexer.MongoClient")
@patch("requests.get")
@patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
@patch("aind_data_asset_indexer.codeocean_bucket_indexer.paginate_docdb")
@patch("aind_data_asset_indexer.codeocean_bucket_indexer.datetime")
def test_update_external_links_in_docdb(
self,
mock_datetime: MagicMock,
mock_paginate: MagicMock,
mock_get: MagicMock,
mock_search: MagicMock,
mock_docdb_client: MagicMock,
):
"""Tests _update_external_links_in_docdb method."""
mock_datetime.utcnow.return_value = datetime(2024, 9, 5)

# Mock requests get response
example_response = self.example_temp_endpoint_response
mock_get_response = Response()
mock_get_response.status_code = 200
mock_get_response._content = json.dumps(example_response).encode(
"utf-8"
)
mock_get.return_value = mock_get_response
# Mock code ocean search response
mock_search.return_value = self.example_search_iterator_response

# Mock bulk_write
mock_db = MagicMock()
Expand Down Expand Up @@ -237,7 +267,8 @@ def test_update_external_links_in_docdb(

with self.assertLogs(level="DEBUG") as captured:
self.basic_job._update_external_links_in_docdb(
docdb_client=mock_docdb_client
docdb_client=mock_docdb_client,
co_client=CodeOcean(domain="www.example.com", token=""),
)
expected_log_messages = [
"INFO:root:No code ocean data asset ids found for "
Expand Down Expand Up @@ -284,31 +315,31 @@ def test_update_external_links_in_docdb(
mock_collection.bulk_write.assert_has_calls(expected_bulk_write_calls)

@patch("aind_data_asset_indexer.codeocean_bucket_indexer.MongoClient")
@patch("requests.get")
@patch("codeocean.data_asset.DataAssets.search_data_assets_iterator")
@patch("aind_data_asset_indexer.codeocean_bucket_indexer.paginate_docdb")
def test_update_external_links_in_docdb_error(
self,
mock_paginate: MagicMock,
mock_get: MagicMock,
mock_search: MagicMock,
mock_docdb_client: MagicMock,
):
"""Tests _update_external_links_in_docdb method when there is an
error retrieving info from the Code Ocean data asset search."""
# Mock requests get response
mock_get_response = Response()
mock_get_response.status_code = 500
mock_get.return_value = mock_get_response
# Mock search response
mock_search.side_effect = Exception("Something went wrong!")

mock_db = MagicMock()
mock_docdb_client.__getitem__.return_value = mock_db
with self.assertLogs(level="DEBUG") as captured:
self.basic_job._update_external_links_in_docdb(
docdb_client=mock_docdb_client
docdb_client=mock_docdb_client,
co_client=CodeOcean(domain="www.example.com", token=""),
)
expected_log_messages = [
expected_log_message = (
"ERROR:root:There was an error retrieving external links!"
]
self.assertEqual(expected_log_messages, captured.output)
)
self.assertEqual(2, len(captured.output))
self.assertEqual(expected_log_message, captured.output[1])
mock_paginate.assert_not_called()

@patch("aind_data_asset_indexer.codeocean_bucket_indexer.MongoClient")
Expand Down Expand Up @@ -568,8 +599,10 @@ def test_delete_records_from_docdb(
"aind_data_asset_indexer.codeocean_bucket_indexer."
"get_all_processed_codeocean_asset_records"
)
@patch("aind_data_asset_indexer.codeocean_bucket_indexer.CodeOcean")
def test_run_job(
self,
mock_codeocean_client: MagicMock,
mock_get_all_co_records: MagicMock,
mock_docdb_client: MagicMock,
mock_paginate_docdb: MagicMock,
Expand All @@ -581,6 +614,8 @@ def test_run_job(
one record, add one record, and delete one record."""
mock_mongo_client = MagicMock()
mock_docdb_client.return_value = mock_mongo_client
mock_co_client = MagicMock()
mock_codeocean_client.return_value = mock_co_client
mock_get_all_co_records.return_value = dict(
[(r["location"], r) for r in self.example_codeocean_records]
)
Expand All @@ -602,7 +637,7 @@ def test_run_job(
self.assertEqual(expected_log_messages, captured.output)

mock_update_external_links_in_docdb.assert_called_once_with(
docdb_client=mock_mongo_client
docdb_client=mock_mongo_client, co_client=mock_co_client
)
mock_process_codeocean_records.assert_called_once_with(
records=[self.example_codeocean_records[0]]
Expand Down

0 comments on commit 1ef85d8

Please sign in to comment.