Skip to content

Commit

Permalink
[hma] Add metadata to API (#1593)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored May 23, 2024
1 parent 7b2e51c commit fafe5d4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 13 deletions.
80 changes: 67 additions & 13 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@

from OpenMediaMatch import persistence
from OpenMediaMatch.utils import flask_utils
from OpenMediaMatch.storage.interface import BankConfig, SignalTypeIndexBuildCheckpoint
import OpenMediaMatch.storage.interface as iface
from OpenMediaMatch.blueprints import hashing


class BankedContentMetadata(t.TypedDict, total=False):
    """
    Optional caller-supplied metadata attached to a piece of banked content.

    Declared with total=False, so every key may be omitted.
    """

    # Caller's identifier for the content, as a string
    content_id: str
    # URI at which the content can be found
    content_uri: str
    # Free-form JSON object the caller wants kept alongside the content
    json: dict[t.Any, t.Any]


# Flask blueprint named "curation". HTTPExceptions raised while handling its
# routes are passed to flask_utils.api_error_handler.
bp = Blueprint("curation", __name__)
bp.register_error_handler(HTTPException, flask_utils.api_error_handler)

Expand Down Expand Up @@ -49,8 +55,8 @@ def bank_create():
return jsonify(bank_create_impl(name, enabled_ratio)), 201


def bank_create_impl(name: str, enabled_ratio: float = 1.0) -> BankConfig:
bank = BankConfig(name=name, matching_enabled_ratio=enabled_ratio)
def bank_create_impl(name: str, enabled_ratio: float = 1.0) -> iface.BankConfig:
bank = iface.BankConfig(name=name, matching_enabled_ratio=enabled_ratio)
try:
persistence.get_storage().bank_update(bank, create=True)
except ValueError as e:
Expand All @@ -62,7 +68,6 @@ def bank_create_impl(name: str, enabled_ratio: float = 1.0) -> BankConfig:

@bp.route("/bank/<bank_name>", methods=["PUT"])
def bank_update(bank_name: str):
# TODO - rewrite using persistence.get_storage()
storage = persistence.get_storage()
data = request.get_json()
bank = storage.get_bank(bank_name)
Expand Down Expand Up @@ -92,14 +97,57 @@ def bank_delete(bank_name: str):
return {"message": "Done"}


def _validate_bank_add_metadata() -> t.Optional[BankedContentMetadata]:
    """
    Extract and validate the optional `metadata` object from the current
    request's JSON body.

    Returns:
      The metadata dict, or None when the request is not JSON, the JSON body
      is not an object, or it has no (non-null) `metadata` entry.

    Aborts with a 400 when `metadata` is present but is not a JSON object, or
    when it contains keys not declared on BankedContentMetadata.
    """
    if not request.is_json:
        return None
    body = request.json
    if not isinstance(body, dict):
        return None
    metadata = body.get("metadata")
    if metadata is None:
        return None
    # From here on the caller clearly intended to send metadata, so
    # malformed input is an error rather than being silently ignored.
    if not isinstance(metadata, dict):
        abort(400, "metadata should be a json object")
    expected_keys = BankedContentMetadata.__optional_keys__.union(
        BankedContentMetadata.__required_keys__
    )
    unexpected = set(metadata).difference(expected_keys)
    if unexpected:
        abort(
            400, f"metadata contains unexpected keys: {', '.join(sorted(unexpected))}"
        )
    return t.cast(BankedContentMetadata, metadata)


@bp.route("/bank/<bank_name>/content", methods=["POST"])
def bank_add_file(bank_name: str):
"""
Add content to a bank by providing a URI to the content (via the `url`
query parameter), or uploading a file (via multipart/form-data).
@see OpenMediaMatch.blueprints.hashing hash_media()
@see OpenMediaMatch.blueprints.hashing hash_media_post()
Inputs:
* The content to be banked, in one of these formats:
1. URI via the `url` query parameter
2. form-data with the proper MIME type set
* Optional metadata about the file under the `metadata` key of a JSON
request body, as a json object. All keys are optional:
{
content_id: as a string, assumed (but not enforced) to be unique
content_uri: as a URI. This WILL NOT be automatically populated from
the `url` parameter; it must be provided explicitly.
json: as a json object, can be anything you plan to need in
the long term
}
Returns: the signatures created and id
{
Expand All @@ -116,19 +164,23 @@ def bank_add_file(bank_name: str):
if not bank:
abort(404, f"bank '{bank_name}' not found")

metadata = _validate_bank_add_metadata()

# Url was passed as a query param?
if request.args.get("url", None):
hashes = hashing.hash_media()
# File uploaded via multipart/form-data?
elif request.files:
hashes = hashing.hash_media_post_impl()
else:
abort(400, "Neither `url` query param nor multipart file upload was received")
return _bank_add_signals(bank, hashes)
abort(400, "Neither `url` nor multipart file upload was received")
return _bank_add_signals(bank, hashes, metadata)


def _bank_add_signals(
bank: BankConfig, signal_type_to_signal_str: dict[str, str]
bank: iface.BankConfig,
signal_type_to_signal_str: dict[str, str],
metadata: t.Optional[BankedContentMetadata],
) -> dict[str, t.Any]:
if not signal_type_to_signal_str:
abort(400, "No signals given")
Expand All @@ -145,11 +197,13 @@ def _bank_add_signals(
signals[st.signal_type] = st.signal_type.validate_signal_str(val)
except Exception as e:
abort(400, f"Invalid {name} signal: {str(e)}")
content_id = storage.bank_add_content(
bank.name,
signals,

content_config = iface.BankContentConfig(
id=0, disable_until_ts=0, collab_metadata={}, original_media_uri=None, bank=bank
)

content_id = storage.bank_add_content(bank.name, signals, content_config)

return {
"id": content_id,
"signals": {st.get_name(): val for st, val in signals.items()},
Expand All @@ -168,7 +222,7 @@ def bank_add_as_signals(bank_name: str):
bank = storage.get_bank(bank_name)
if not bank:
abort(404, f"bank '{bank_name}' not found")
return _bank_add_signals(bank, t.cast(dict[str, str], request.json))
return _bank_add_signals(bank, t.cast(dict[str, str], request.json), None)


def _get_collab(name: str):
Expand Down Expand Up @@ -452,9 +506,9 @@ def signal_type_index_status() -> dict[str, dict[str, t.Any]]:
config.signal_type,
)
if tar is None:
tar = SignalTypeIndexBuildCheckpoint.get_empty()
tar = iface.SignalTypeIndexBuildCheckpoint.get_empty()
if last is None:
last = SignalTypeIndexBuildCheckpoint.get_empty()
last = iface.SignalTypeIndexBuildCheckpoint.get_empty()
ret[name] = {
"db_size": tar.total_hash_count,
"index_size": last.total_hash_count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ def bank_add_content(

bank = self._get_bank(bank_name)
content = database.BankContent(bank=bank)
if config is not None:
content.original_content_uri = config.original_media_uri
sesh.add(content)
for content_signal, value in content_signals.items():
hash = database.ContentSignal(
Expand Down
20 changes: 20 additions & 0 deletions hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,26 @@ def test_banks_add_hash(client: FlaskClient):
}


def test_banks_add_metadata(client: FlaskClient):
    """Banking by URL rejects unknown metadata keys and accepts known ones."""
    bank_name = "NEW_BANK"
    create_bank(client, bank_name)

    image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
    endpoint = f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"

    # A key outside the declared metadata fields must be refused.
    rejected = client.post(endpoint, json={"metadata": {"invalid_metadata": 5}})
    assert rejected.status_code == 400, str(rejected.get_json())

    # Only declared keys: the request should go through.
    valid_metadata = {"content_id": "1197433091", "json": {"asdf": {}}}
    accepted = client.post(endpoint, json={"metadata": valid_metadata})
    assert accepted.status_code == 200, str(accepted.get_json())


def test_banks_add_hash_index(app: Flask, client: FlaskClient):
bank_name = "NEW_BANK"
bank_name_2 = "NEW_BANK_2"
Expand Down

0 comments on commit fafe5d4

Please sign in to comment.