
Commit

Merge pull request #1533 from GSA/notify-api-1531
Investigate bulk_save_objects for notification inserts
terrazoon authored Jan 15, 2025
2 parents 67d03dd + 59dfb05 commit 0a7ccc5
Showing 17 changed files with 278 additions and 46 deletions.
56 changes: 53 additions & 3 deletions app/celery/scheduled_tasks.py
@@ -1,10 +1,11 @@
-from datetime import timedelta
+import json
+from datetime import datetime, timedelta

from flask import current_app
from sqlalchemy import between
from sqlalchemy.exc import SQLAlchemyError

-from app import notify_celery, zendesk_client
+from app import notify_celery, redis_store, zendesk_client
from app.celery.tasks import (
    get_recipient_csv_and_template_and_sender_id,
    process_incomplete_jobs,
@@ -24,6 +25,7 @@
    find_missing_row_for_job,
)
from app.dao.notifications_dao import (
    dao_batch_insert_notifications,
    dao_close_out_delivery_receipts,
    dao_update_delivery_receipts,
    notifications_not_yet_sent,
@@ -34,7 +36,7 @@
)
from app.dao.users_dao import delete_codes_older_created_more_than_a_day_ago
from app.enums import JobStatus, NotificationType
-from app.models import Job
+from app.models import Job, Notification
from app.notifications.process_notifications import send_notification_to_queue
from app.utils import utc_now
from notifications_utils import aware_utcnow
@@ -288,3 +290,51 @@ def process_delivery_receipts(self):
)
def cleanup_delivery_receipts(self):
    dao_close_out_delivery_receipts()


@notify_celery.task(bind=True, name="batch-insert-notifications")
def batch_insert_notifications(self):
    batch = []

    # TODO We probably need some way to clear the list if
    # things go haywire. A command?

    # with redis_store.pipeline():
    #     while redis_store.llen("message_queue") > 0:
    #         redis_store.lpop("message_queue")
    #     current_app.logger.info("EMPTY!")
    #     return
    current_len = redis_store.llen("message_queue")
    with redis_store.pipeline():
        # since this list is being fed by other processes, just grab what is available when
        # this call is made and process that.

        count = 0
        while count < current_len:
            count = count + 1
            notification_bytes = redis_store.lpop("message_queue")
            notification_dict = json.loads(notification_bytes.decode("utf-8"))
            notification_dict["status"] = notification_dict.pop("notification_status")
            if not notification_dict.get("created_at"):
                notification_dict["created_at"] = utc_now()
            elif isinstance(notification_dict["created_at"], list):
                notification_dict["created_at"] = notification_dict["created_at"][0]
            notification = Notification(**notification_dict)
            if notification is not None:
                batch.append(notification)
    try:
        dao_batch_insert_notifications(batch)
    except Exception:
        current_app.logger.exception("Notification batch insert failed")
        for n in batch:
            # Use 'created_at' as a TTL so we don't retry infinitely
            notification_time = n.created_at
            if isinstance(notification_time, str):
                notification_time = datetime.fromisoformat(n.created_at)
            if notification_time < utc_now() - timedelta(seconds=50):
                current_app.logger.warning(
                    f"Abandoning stale data, could not write to db: {n.serialize_for_redis(n)}"
                )
                continue
            else:
                redis_store.rpush("message_queue", json.dumps(n.serialize_for_redis(n)))
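Taken together with the config.py and tasks.py changes below, the intended flow appears to be: persist_notification pushes serialized SMS notifications onto the Redis list, this beat task drains whatever is on the list every 10 seconds, and deliver_sms is scheduled with countdown=60, presumably so the batched insert has time to land the row before delivery looks it up. A minimal sketch of that round trip, assuming a local Redis at its default URL; the names mirror the diff, but this is illustrative, not the committed code:

import json
import redis

r = redis.Redis()

# Producer side (persist_notification): push one serialized notification.
r.rpush("message_queue", json.dumps({"id": "123", "notification_status": "pending"}))

# Consumer side (batch_insert_notifications): snapshot the length first,
# then pop only that many entries so concurrent producers are not starved.
pending = r.llen("message_queue")
batch = []
for _ in range(pending):
    raw = r.lpop("message_queue")
    item = json.loads(raw.decode("utf-8"))
    item["status"] = item.pop("notification_status")  # rename to the model's column
    batch.append(item)
# batch would then be handed to dao_batch_insert_notifications(...)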
2 changes: 1 addition & 1 deletion app/celery/tasks.py
@@ -256,7 +256,7 @@ def save_sms(self, service_id, notification_id, encrypted_notification, sender_i
        )
    )
    provider_tasks.deliver_sms.apply_async(
-       [str(saved_notification.id)], queue=QueueNames.SEND_SMS
+       [str(saved_notification.id)], queue=QueueNames.SEND_SMS, countdown=60
    )

    current_app.logger.debug(
11 changes: 11 additions & 0 deletions app/commands.py
@@ -789,6 +789,17 @@ def _update_template(id, name, template_type, content, subject):
    db.session.commit()


@notify_command(name="clear-redis-list")
@click.option("-n", "--name_of_list", required=True)
def clear_redis_list(name_of_list):
    my_len_before = redis_store.llen(name_of_list)
    redis_store.ltrim(name_of_list, 1, 0)
    my_len_after = redis_store.llen(name_of_list)
    current_app.logger.info(
        f"Cleared redis list {name_of_list}. Before: {my_len_before} after {my_len_after}"
    )
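A side note on the implementation: LTRIM with start=1 and end=0 is the standard Redis idiom for emptying a list, because an inverted range deletes the key. Assuming the usual notify_command registration under the flask "command" group (a hypothetical invocation; check the project's CLI docs), clearing the notification queue would look something like:

flask command clear-redis-list -n message_queue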


@notify_command(name="update-templates")
def update_templates():
    with open(current_app.config["CONFIG_FILES"] + "/templates.json") as f:
5 changes: 5 additions & 0 deletions app/config.py
@@ -208,6 +208,11 @@ class Config(object):
"schedule": timedelta(minutes=82),
"options": {"queue": QueueNames.PERIODIC},
},
"batch-insert-notifications": {
"task": "batch-insert-notifications",
"schedule": 10.0,
"options": {"queue": QueueNames.PERIODIC},
},
"expire-or-delete-invitations": {
"task": "expire-or-delete-invitations",
"schedule": timedelta(minutes=66),
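Note on the new entry's schedule value: Celery beat interprets a bare number as seconds, so "schedule": 10.0 runs batch-insert-notifications every 10 seconds. The equivalent explicit form, matching the style of the neighboring entries, would be:

            "schedule": timedelta(seconds=10),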
8 changes: 8 additions & 0 deletions app/dao/notifications_dao.py
@@ -799,3 +799,11 @@ def dao_close_out_delivery_receipts():
    current_app.logger.info(
        f"Marked {result.rowcount} notifications as technical failures"
    )


def dao_batch_insert_notifications(batch):
    db.session.bulk_save_objects(batch)
    db.session.commit()
    current_app.logger.info(f"Batch inserted notifications: {len(batch)}")
    return len(batch)
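A note on the DAO above: bulk_save_objects is faster than row-at-a-time inserts because it skips most per-object ORM bookkeeping, but it does not fetch server-generated defaults back onto the objects unless asked. A hedged sketch of a hypothetical variant, not the committed code, for the case where the inserted rows' defaults are needed afterwards; return_defaults trades away much of the speed gain:

def dao_batch_insert_notifications_with_defaults(batch):
    # return_defaults=True round-trips primary keys and other server
    # defaults back onto each object, at the cost of extra SQL round trips.
    db.session.bulk_save_objects(batch, return_defaults=True)
    db.session.commit()
    return [n.id for n in batch]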
29 changes: 28 additions & 1 deletion app/models.py
@@ -5,7 +5,7 @@
from sqlalchemy import CheckConstraint, Index, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.ext.declarative import DeclarativeMeta, declared_attr
from sqlalchemy.orm import validates
from sqlalchemy.orm.collections import attribute_mapped_collection

@@ -1694,6 +1694,33 @@ def get_created_by_email_address(self):
        else:
            return None

    def serialize_for_redis(self, obj):
        if isinstance(obj.__class__, DeclarativeMeta):
            fields = {}
            for column in obj.__table__.columns:
                if column.name == "notification_status":
                    new_name = "status"
                    value = getattr(obj, new_name)
                elif column.name == "created_at":
                    if isinstance(obj.created_at, str):
                        value = obj.created_at
                    else:
                        value = (obj.created_at.strftime("%Y-%m-%d %H:%M:%S"),)
                elif column.name in ["sent_at", "completed_at"]:
                    value = None
                elif column.name.endswith("_id"):
                    value = getattr(obj, column.name)
                    value = str(value)
                else:
                    value = getattr(obj, column.name)
                if column.name in ["message_id", "api_key_id"]:
                    pass  # do nothing because we don't have the message id yet
                else:
                    fields[column.name] = value

            return fields
        raise ValueError("Provided object is not a SQLAlchemy instance")

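One subtlety in serialize_for_redis above: the trailing comma in the created_at branch makes value a one-element tuple, which json.dumps serializes as a JSON list; that is what the isinstance(..., list) branch in batch_insert_notifications unwraps on the consumer side. A minimal demonstration:

import json
from datetime import datetime

value = (datetime(2025, 1, 15).strftime("%Y-%m-%d %H:%M:%S"),)  # tuple, due to the trailing comma
print(json.dumps({"created_at": value}))  # {"created_at": ["2025-01-15 00:00:00"]}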
    def serialize_for_csv(self):
        serialized = {
            "row_number": (
26 changes: 14 additions & 12 deletions app/notifications/process_notifications.py
@@ -1,3 +1,5 @@
import json
import os
import uuid

from flask import current_app
@@ -11,7 +13,7 @@
    dao_notification_exists,
    get_notification_by_id,
)
-from app.enums import KeyType, NotificationStatus, NotificationType
+from app.enums import NotificationStatus, NotificationType
from app.errors import BadRequestError
from app.models import Notification
from app.utils import hilite, utc_now
@@ -139,18 +141,18 @@ def persist_notification(

    # if simulated create a Notification model to return but do not persist the Notification to the DB
    if not simulated:
-       current_app.logger.info("Firing dao_create_notification")
-       dao_create_notification(notification)
-       if key_type != KeyType.TEST and current_app.config["REDIS_ENABLED"]:
-           current_app.logger.info(
-               "Redis enabled, querying cache key for service id: {}".format(service.id)
-           )
+       if notification.notification_type == NotificationType.SMS:
+           # it's just too hard with redis and timing to test this here
+           if os.getenv("NOTIFY_ENVIRONMENT") == "test":
+               dao_create_notification(notification)
+           else:
+               redis_store.rpush(
+                   "message_queue",
+                   json.dumps(notification.serialize_for_redis(notification)),
+               )
+       else:
+           dao_create_notification(notification)

    current_app.logger.info(
        f"{notification_type} {notification_id} created at {notification_created_at}"
    )
    return notification


@@ -172,7 +174,7 @@ def send_notification_to_queue_detached(
        deliver_task = provider_tasks.deliver_email

    try:
-       deliver_task.apply_async([str(notification_id)], queue=queue)
+       deliver_task.apply_async([str(notification_id)], queue=queue, countdown=60)
    except Exception:
        dao_delete_notifications_by_id(notification_id)
        raise
19 changes: 19 additions & 0 deletions notifications_utils/clients/redis/redis_client.py
@@ -38,6 +38,9 @@ class RedisClient:
    active = False
    scripts = {}

    def pipeline(self):
        return self.redis_store.pipeline()

    def init_app(self, app):
        self.active = app.config.get("REDIS_ENABLED")
        if self.active:
@@ -156,6 +159,22 @@ def get(self, key, raise_exception=False):

        return None

    def rpush(self, key, value):
        if self.active:
            self.redis_store.rpush(key, value)

    def lpop(self, key):
        if self.active:
            return self.redis_store.lpop(key)

    def llen(self, key):
        if self.active:
            return self.redis_store.llen(key)

    def ltrim(self, key, start, end):
        if self.active:
            return self.redis_store.ltrim(key, start, end)

    def delete(self, *keys, raise_exception=False):
        keys = [prepare_value(k) for k in keys]
        if self.active:
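These wrappers silently no-op when REDIS_ENABLED is false: rpush does nothing, and lpop/llen/ltrim return None. Callers such as batch_insert_notifications assume an active client (llen returning None would make its while-loop comparison raise TypeError), so a defensive caller might guard as below; a sketch assuming an initialized redis_store like the app's, not the committed pattern:

current_len = redis_store.llen("message_queue") or 0  # tolerate a disabled client
items = []
for _ in range(current_len):
    raw = redis_store.lpop("message_queue")
    if raw is None:  # list drained early or Redis inactive
        break
    items.append(raw)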
107 changes: 104 additions & 3 deletions tests/app/celery/test_scheduled_tasks.py
@@ -1,17 +1,20 @@
import json
from collections import namedtuple
from datetime import timedelta
from unittest import mock
-from unittest.mock import ANY, call
+from unittest.mock import ANY, MagicMock, call

import pytest

from app.celery import scheduled_tasks
from app.celery.scheduled_tasks import (
    batch_insert_notifications,
    check_for_missing_rows_in_completed_jobs,
    check_for_services_with_high_failure_rates_or_sending_to_tv_numbers,
    check_job_status,
    delete_verify_codes,
    expire_or_delete_invitations,
    process_delivery_receipts,
    replay_created_notifications,
    run_scheduled_jobs,
)
@@ -308,10 +311,10 @@ def test_replay_created_notifications(notify_db_session, sample_service, mocker)

    replay_created_notifications()
    email_delivery_queue.assert_called_once_with(
-       [str(old_email.id)], queue="send-email-tasks"
+       [str(old_email.id)], queue="send-email-tasks", countdown=60
    )
    sms_delivery_queue.assert_called_once_with(
-       [str(old_sms.id)], queue="send-sms-tasks"
+       [str(old_sms.id)], queue="send-sms-tasks", countdown=60
    )


@@ -523,3 +526,101 @@ def test_check_for_services_with_high_failure_rates_or_sending_to_tv_numbers(
        technical_ticket=True,
    )
    mock_send_ticket_to_zendesk.assert_called_once()


def test_batch_insert_with_valid_notifications(mocker):
    mocker.patch("app.celery.scheduled_tasks.dao_batch_insert_notifications")
    rs = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", rs)
    notifications = [
        {"id": 1, "notification_status": "pending"},
        {"id": 2, "notification_status": "pending"},
    ]
    serialized_notifications = [json.dumps(n).encode("utf-8") for n in notifications]

    pipeline_mock = MagicMock()

    rs.pipeline.return_value.__enter__.return_value = pipeline_mock
    rs.llen.return_value = len(notifications)
    rs.lpop.side_effect = serialized_notifications

    batch_insert_notifications()

    rs.llen.assert_called_once_with("message_queue")
    rs.lpop.assert_called_with("message_queue")


def test_batch_insert_with_expired_notifications(mocker):
    expired_time = utc_now() - timedelta(minutes=2)
    mocker.patch(
        "app.celery.scheduled_tasks.dao_batch_insert_notifications",
        side_effect=Exception("DB Error"),
    )
    rs = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", rs)
    notifications = [
        {
            "id": 1,
            "notification_status": "pending",
            "created_at": utc_now().isoformat(),
        },
        {
            "id": 2,
            "notification_status": "pending",
            "created_at": expired_time.isoformat(),
        },
    ]
    serialized_notifications = [json.dumps(n).encode("utf-8") for n in notifications]

    pipeline_mock = MagicMock()

    rs.pipeline.return_value.__enter__.return_value = pipeline_mock
    rs.llen.return_value = len(notifications)
    rs.lpop.side_effect = serialized_notifications

    batch_insert_notifications()

    rs.llen.assert_called_once_with("message_queue")
    rs.rpush.assert_called_once()
    requeued_notification = json.loads(rs.rpush.call_args[0][1])
    assert requeued_notification["id"] == 1


def test_batch_insert_with_malformed_notifications(mocker):
    rs = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", rs)
    malformed_data = b"not_a_valid_json"
    pipeline_mock = MagicMock()

    rs.pipeline.return_value.__enter__.return_value = pipeline_mock
    rs.llen.return_value = 1
    rs.lpop.side_effect = [malformed_data]

    with pytest.raises(json.JSONDecodeError):
        batch_insert_notifications()

    rs.llen.assert_called_once_with("message_queue")
    rs.rpush.assert_not_called()


def test_process_delivery_receipts_success(mocker):
    dao_update_mock = mocker.patch(
        "app.celery.scheduled_tasks.dao_update_delivery_receipts"
    )
    cloudwatch_mock = mocker.patch("app.celery.scheduled_tasks.AwsCloudwatchClient")
    cloudwatch_mock.return_value.check_delivery_receipts.return_value = (
        range(2000),
        range(500),
    )
    current_app_mock = mocker.patch("app.celery.scheduled_tasks.current_app")
    current_app_mock.return_value = MagicMock()
    processor = MagicMock()
    processor.process_delivery_receipts = process_delivery_receipts
    processor.retry = MagicMock()

    processor.process_delivery_receipts()
    assert dao_update_mock.call_count == 3
    dao_update_mock.assert_any_call(list(range(1000)), True)
    dao_update_mock.assert_any_call(list(range(1000, 2000)), True)
    dao_update_mock.assert_any_call(list(range(500)), False)
    processor.retry.assert_not_called()
