Commit eea9346
feat: adds aind-bucket-indexer job (#49)
* feat: adds aind bucket index job
jtyoung84 authored May 16, 2024
1 parent ff516ba commit eea9346
Showing 10 changed files with 2,118 additions and 90 deletions.
427 changes: 427 additions & 0 deletions src/aind_data_asset_indexer/aind_bucket_indexer.py

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion src/aind_data_asset_indexer/models.py
@@ -3,7 +3,7 @@
from typing import Optional

import boto3
from pydantic import Field
from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings


@@ -45,3 +45,14 @@ def from_param_store(cls, param_store_name: str):
param_store_client.close()
parameters = response["Parameter"]["Value"]
return cls.model_validate_json(parameters)


class AindIndexBucketJobSettings(IndexJobSettings):
"""Aind Index Bucket Job Settings"""

doc_db_host: str
doc_db_port: int
doc_db_user_name: str
doc_db_password: SecretStr
doc_db_db_name: str
doc_db_collection_name: str
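
# Illustrative usage sketch (not part of this commit). Field values are
# placeholders, and the inherited s3_bucket field name is an assumption
# about IndexJobSettings.
example_settings = AindIndexBucketJobSettings(
    s3_bucket="some-bucket",
    doc_db_host="docdb.example.com",
    doc_db_port=27017,
    doc_db_user_name="indexer",
    doc_db_password="not-a-real-password",  # coerced to a pydantic SecretStr
    doc_db_db_name="metadata_index",
    doc_db_collection_name="data_assets",
)
# SecretStr masks the value when printed; call .get_secret_value() to read it.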
@@ -1,6 +1,7 @@
"""Module to handle populating s3 bucket with metadata files."""
import argparse
import logging
import os
import sys
import warnings
from typing import List
@@ -16,6 +17,7 @@
upload_metadata_json_str_to_s3,
)

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
# pydantic raises too many serialization warnings
warnings.filterwarnings("ignore", category=UserWarning)

259 changes: 250 additions & 9 deletions src/aind_data_asset_indexer/utils.py
@@ -3,6 +3,7 @@
import json
from json.decoder import JSONDecodeError
from typing import Dict, Iterator, List, Optional
from urllib.parse import urlparse

from aind_data_schema.core.metadata import Metadata
from aind_data_schema.utils.json_writer import SchemaWriter
@@ -12,6 +13,7 @@
PaginatorConfigTypeDef,
PutObjectOutputTypeDef,
)
from pymongo import MongoClient

# TODO: This would be better if it was available in aind-data-schema
core_schema_file_names = [
@@ -21,6 +23,97 @@
]


def create_metadata_object_key(prefix: str) -> str:
"""
For a given s3 prefix, create the expected object key for the
metadata.nd.json file.
Parameters
----------
prefix : str
For example, ecephys_123456_2020-10-10_01-02-03
Returns
-------
str
For example, ecephys_123456_2020-10-10_01-02-03/metadata.nd.json
"""
stripped_prefix = prefix.strip("/")
return f"{stripped_prefix}/{Metadata.default_filename()}"


def is_record_location_valid(
record: dict, expected_bucket: str, expected_prefix: Optional[str] = None
) -> bool:
"""
Check if a given record has a valid location url.
Parameters
----------
record : dict
Metadata record as a dictionary
expected_bucket : str
The expected s3 bucket the location should have.
expected_prefix: Optional[str]
If provided, also check that the record location matches the expected
s3_prefix. Default is None, which won't perform the check.
Returns
-------
bool
True if there is a location field and the url in the field has a form
like 's3://{expected_bucket}/prefix'
Will return False if there is no s3 scheme, the bucket does not match
the expected bucket, the prefix contains forward slashes, or the prefix
doesn't match the expected prefix.
"""
expected_stripped_prefix = (
None if expected_prefix is None else expected_prefix.strip("/")
)
if record.get("location") is None:
return False
else:
parts = urlparse(record.get("location"), allow_fragments=False)
if parts.scheme != "s3":
return False
elif parts.netloc != expected_bucket:
return False
else:
stripped_prefix = parts.path.strip("/")
if (
stripped_prefix == ""
or len(stripped_prefix.split("/")) > 1
or (
expected_prefix is not None
and stripped_prefix != expected_stripped_prefix
)
):
return False
else:
return True
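
# Illustrative checks (not part of this commit) with placeholder bucket and
# prefix names, sketching the expected behavior of is_record_location_valid.
_example_record = {
    "location": "s3://my-bucket/ecephys_123456_2020-10-10_01-02-03"
}
assert is_record_location_valid(_example_record, expected_bucket="my-bucket")
assert is_record_location_valid(
    _example_record,
    expected_bucket="my-bucket",
    expected_prefix="ecephys_123456_2020-10-10_01-02-03/",
)
assert not is_record_location_valid(
    _example_record, expected_bucket="other-bucket"
)
assert not is_record_location_valid(
    {"location": "s3://my-bucket/a/b"}, expected_bucket="my-bucket"
)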


def get_s3_bucket_and_prefix(s3_location: str) -> Dict[str, str]:
"""
For a location url like s3://bucket/prefix, it will return the bucket
and prefix. It doesn't check the scheme is s3. It will strip the leading
and trailing forward slashes from the prefix.
Parameters
----------
s3_location : str
For example, 's3://some_bucket/some_prefix'
Returns
-------
Dict[str, str]
For example, {'bucket': 'some_bucket', 'prefix': 'some_prefix'}
"""
parts = urlparse(s3_location, allow_fragments=False)
stripped_prefix = parts.path.strip("/")
return {"bucket": parts.netloc, "prefix": stripped_prefix}


def compute_md5_hash(json_contents: str) -> str:
"""
Computes the md5 hash of the object as it would be stored in S3. Useful
@@ -37,7 +130,7 @@ def compute_md5_hash(json_contents: str) -> str:
"""
contents = json.dumps(
json.loads(json_contents), indent=3, ensure_ascii=False
json.loads(json_contents), indent=3, ensure_ascii=False, sort_keys=True
).encode("utf-8")
return hashlib.md5(contents).hexdigest()
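
# Quick illustration (not part of this commit): with sort_keys=True added in
# this diff, the hash no longer depends on the key order of the input JSON.
assert compute_md5_hash('{"b": 2, "a": 1}') == compute_md5_hash(
    '{"a": 1, "b": 2}'
)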

@@ -61,10 +154,10 @@ def upload_metadata_json_str_to_s3(
Response of the put object operation.
"""
stripped_prefix = prefix[:-1] if prefix.endswith("/") else prefix
stripped_prefix = prefix.strip("/")
object_key = f"{stripped_prefix}/{Metadata.default_filename()}"
contents = json.dumps(
json.loads(metadata_json), indent=3, ensure_ascii=False
json.loads(metadata_json), indent=3, ensure_ascii=False, sort_keys=True
).encode("utf-8")
response = s3_client.put_object(
Bucket=bucket, Key=object_key, Body=contents
@@ -225,6 +318,7 @@ def build_metadata_record_from_prefix(
metadata_nd_overwrite: bool,
prefix: str,
s3_client: S3Client,
optional_name: Optional[str] = None,
) -> Optional[str]:
"""
For a given bucket and prefix, this method will return a JSON string
@@ -240,6 +334,9 @@
metadata_nd_overwrite : bool
prefix : str
s3_client : S3Client
optional_name : Optional[str]
        If optional_name is None, then a name will be constructed from the s3 prefix.
Default is None.
Returns
-------
@@ -249,7 +346,7 @@
metadata.nd.json file and the file is corrupt.
"""
stripped_prefix = prefix[:-1] if prefix.endswith("/") else prefix
stripped_prefix = prefix.strip("/")
metadata_nd_file_key = stripped_prefix + "/" + Metadata.default_filename()
does_metadata_nd_file_exist = does_s3_object_exist(
s3_client=s3_client, bucket=bucket, key=metadata_nd_file_key
@@ -262,9 +359,11 @@
s3_file_responses = get_dict_of_file_info(
s3_client=s3_client, bucket=bucket, keys=file_keys
)
# Strip the trailing slash from the prefix
record_name = (
stripped_prefix if optional_name is None else optional_name
)
metadata_dict = {
"name": stripped_prefix,
"name": record_name,
"location": f"s3://{bucket}/{stripped_prefix}",
}
for object_key, response_data in s3_file_responses.items():
@@ -274,9 +373,6 @@
s3_client=s3_client, bucket=bucket, object_key=object_key
)
if json_contents is not None:
# Old version of pycharm highlights a warning since
# it doesn't know check above ensures json_contents is not
# None
# noinspection PyTypeChecker
is_corrupt = is_dict_corrupt(input_dict=json_contents)
if not is_corrupt:
@@ -298,3 +394,148 @@
if metadata_contents is None
else json.dumps(metadata_contents)
)
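
# Illustrative usage sketch (not part of this commit). Assumes AWS credentials
# are configured; the bucket parameter name is inferred from the docstring and
# function body, since the diff view cuts off the full signature.
import boto3

example_s3_client = boto3.client("s3")
example_metadata_json = build_metadata_record_from_prefix(
    bucket="my-bucket",
    metadata_nd_overwrite=False,
    prefix="ecephys_123456_2020-10-10_01-02-03",
    s3_client=example_s3_client,
)
if example_metadata_json is not None:
    print(example_metadata_json[:200])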


def does_metadata_record_exist_in_docdb(
docdb_client: MongoClient,
db_name: str,
collection_name: str,
bucket: str,
prefix: str,
) -> bool:
"""
For a given bucket and prefix, check if there is already a record in DocDb
Parameters
----------
docdb_client : MongoClient
db_name : str
collection_name : str
bucket : str
prefix : str
Returns
-------
    bool
        True if there is a record in DocDb. Otherwise, False.
"""
stripped_prefix = prefix.strip("/")
location = f"s3//{bucket}/{stripped_prefix}"
db = docdb_client[db_name]
collection = db[collection_name]
records = list(
collection.find(
filter={"location": location}, projection={"_id": 1}, limit=1
)
)
if len(records) == 0:
return False
else:
return True
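
# Illustrative usage sketch (not part of this commit); the connection string,
# database, and collection names below are placeholders.
from pymongo import MongoClient

example_docdb_client = MongoClient("mongodb://localhost:27017")
record_exists = does_metadata_record_exist_in_docdb(
    docdb_client=example_docdb_client,
    db_name="metadata_index",
    collection_name="data_assets",
    bucket="my-bucket",
    prefix="ecephys_123456_2020-10-10_01-02-03",
)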


def get_record_from_docdb(
docdb_client: MongoClient,
db_name: str,
collection_name: str,
record_id: str,
) -> Optional[dict]:
"""
Download a record from docdb using the record _id.
Parameters
----------
docdb_client : MongoClient
db_name : str
collection_name : str
record_id : str
Returns
-------
Optional[dict]
None if record does not exist. Otherwise, it will return the record as
a dict.
"""
db = docdb_client[db_name]
collection = db[collection_name]
records = list(collection.find(filter={"_id": record_id}, limit=1))
if len(records) > 0:
return records[0]
else:
return None
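
# Illustrative sketch (not part of this commit), reusing the placeholder
# client and names from the sketch above; the record id is made up.
example_record = get_record_from_docdb(
    docdb_client=example_docdb_client,
    db_name="metadata_index",
    collection_name="data_assets",
    record_id="abc-1234",
)
if example_record is None:
    print("No record found with that _id")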


def paginate_docdb(
db_name: str,
collection_name: str,
docdb_client: MongoClient,
page_size: int = 1000,
filter_query: Optional[dict] = None,
projection: Optional[dict] = None,
) -> Iterator[List[dict]]:
"""
Paginate through records in DocDb.
Parameters
----------
db_name : str
collection_name : str
docdb_client : MongoClient
page_size : int
Default is 1000
filter_query : Optional[dict]
projection : Optional[dict]
Returns
-------
Iterator[List[dict]]
"""
if filter_query is None:
filter_query = {}
if projection is None:
projection = {}
db = docdb_client[db_name]
collection = db[collection_name]
cursor = collection.find(filter=filter_query, projection=projection)
obj = next(cursor, None)
while obj:
page = []
while len(page) < page_size and obj:
page.append(obj)
obj = next(cursor, None)
yield page
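
# Illustrative sketch (not part of this commit) of iterating pages with the
# placeholder client and names used above.
for page in paginate_docdb(
    db_name="metadata_index",
    collection_name="data_assets",
    docdb_client=example_docdb_client,
    page_size=500,
    projection={"_id": 1, "location": 1},
):
    print(f"fetched {len(page)} records")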


def build_docdb_location_to_id_map(
db_name: str,
collection_name: str,
docdb_client: MongoClient,
bucket: str,
prefixes: List[str],
) -> Dict[str, str]:
"""
For a given s3 bucket and list of prefixes, return a dictionary that looks
like {'s3://bucket/prefix': 'abc-1234'} where the value is the id of the
record in DocDb. If the record does not exist, then there will be no key
in the dictionary.
Parameters
----------
db_name : str
    collection_name : str
docdb_client : MongoClient
bucket : str
prefixes : List[str]
Returns
-------
Dict[str, str]
"""
stripped_prefixes = [p.strip("/") for p in prefixes]
locations = [f"s3://{bucket}/{p}" for p in stripped_prefixes]
filter_query = {"location": {"$in": locations}}
projection = {"_id": 1, "location": 1}
db = docdb_client[db_name]
collection = db[collection_name]
results = collection.find(filter=filter_query, projection=projection)
location_to_id_map = {r["location"]: r["_id"] for r in results}
return location_to_id_map
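
# Illustrative sketch (not part of this commit); bucket and prefixes are
# placeholders. Keys of the returned dict look like 's3://my-bucket/<prefix>'.
location_to_id = build_docdb_location_to_id_map(
    db_name="metadata_index",
    collection_name="data_assets",
    docdb_client=example_docdb_client,
    bucket="my-bucket",
    prefixes=[
        "ecephys_123456_2020-10-10_01-02-03",
        "behavior_654321_2021-01-01_00-00-00/",
    ],
)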
