Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Investigate bulk_save_objects for notification inserts #1533

Merged
merged 35 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions app/celery/scheduled_tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from datetime import timedelta
import json
from datetime import datetime, timedelta

from flask import current_app
from sqlalchemy import between
from sqlalchemy.exc import SQLAlchemyError

from app import notify_celery, zendesk_client
from app import notify_celery, redis_store, zendesk_client
from app.celery.tasks import (
get_recipient_csv_and_template_and_sender_id,
process_incomplete_jobs,
Expand All @@ -24,6 +25,7 @@
find_missing_row_for_job,
)
from app.dao.notifications_dao import (
dao_batch_insert_notifications,
dao_close_out_delivery_receipts,
dao_update_delivery_receipts,
notifications_not_yet_sent,
Expand All @@ -34,7 +36,7 @@
)
from app.dao.users_dao import delete_codes_older_created_more_than_a_day_ago
from app.enums import JobStatus, NotificationType
from app.models import Job
from app.models import Job, Notification
from app.notifications.process_notifications import send_notification_to_queue
from app.utils import utc_now
from notifications_utils import aware_utcnow
Expand Down Expand Up @@ -286,3 +288,46 @@ def process_delivery_receipts(self):
)
def cleanup_delivery_receipts(self):
dao_close_out_delivery_receipts()


@notify_celery.task(bind=True, name="batch-insert-notifications")
def batch_insert_notifications(self):
    """Drain the Redis "message_queue" list and bulk insert its notifications.

    Only as many entries as the list held when the task started are taken,
    because other processes keep feeding the list while this runs. Each entry
    is a UTF-8 JSON object; its "notification_status" key is renamed to
    "status" and a missing "created_at" defaults to the current time before a
    ``Notification`` is built from it.

    If the bulk insert fails, notifications created less than 50 seconds ago
    are pushed back onto the list for a later run; older ones are logged and
    dropped so that a persistent failure cannot be retried forever.

    Raises:
        json.JSONDecodeError: if a queued entry is not valid JSON.
    """
    # TODO We probably need some way to clear the list if
    # things go haywire.  A command?

    # llen() returns None when Redis is disabled; treat that as an empty list.
    current_len = redis_store.llen("message_queue") or 0

    batch = []
    # NOTE(review): the pipeline object is never used; the lpop calls below go
    # through redis_store directly. Confirm whether this context is needed.
    with redis_store.pipeline():
        # since this list is being fed by other processes, just grab what is
        # available when this call is made and process that.
        for _ in range(current_len):
            notification_bytes = redis_store.lpop("message_queue")
            if notification_bytes is None:
                # The list ran dry before we reached the length sampled above.
                break
            notification_dict = json.loads(notification_bytes.decode("utf-8"))
            notification_dict["status"] = notification_dict.pop("notification_status")
            if not notification_dict.get("created_at"):
                notification_dict["created_at"] = utc_now()
            batch.append(Notification(**notification_dict))

    if not batch:
        # Nothing was queued, so there is nothing to write.
        return

    try:
        dao_batch_insert_notifications(batch)
    except Exception:
        current_app.logger.exception("Notification batch insert failed")
        # Use 'created_at' as a TTL so we don't retry infinitely
        cutoff = utc_now() - timedelta(seconds=50)
        for n in batch:
            created_at = n.created_at
            # Queued entries carry an ISO string, while the default assigned
            # above is already a datetime.
            if isinstance(created_at, str):
                created_at = datetime.fromisoformat(created_at)
            if created_at < cutoff:
                current_app.logger.warning(
                    f"Abandoning stale data, could not write to db: {n.serialize_for_redis(n)}"
                )
            else:
                redis_store.rpush("message_queue", json.dumps(n.serialize_for_redis(n)))
2 changes: 1 addition & 1 deletion app/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def save_sms(self, service_id, notification_id, encrypted_notification, sender_i
)
)
provider_tasks.deliver_sms.apply_async(
[str(saved_notification.id)], queue=QueueNames.SEND_SMS
[str(saved_notification.id)], queue=QueueNames.SEND_SMS, countdown=60
)

current_app.logger.debug(
Expand Down
5 changes: 5 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class Config(object):
"schedule": timedelta(minutes=82),
"options": {"queue": QueueNames.PERIODIC},
},
"batch-insert-notifications": {
"task": "batch-insert-notifications",
"schedule": 10.0,
"options": {"queue": QueueNames.PERIODIC},
},
"expire-or-delete-invitations": {
"task": "expire-or-delete-invitations",
"schedule": timedelta(minutes=66),
Expand Down
8 changes: 8 additions & 0 deletions app/dao/notifications_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,3 +799,11 @@ def dao_close_out_delivery_receipts():
current_app.logger.info(
f"Marked {result.rowcount} notifications as technical failures"
)


def dao_batch_insert_notifications(batch):
    """Insert a list of model objects with one bulk save followed by a commit.

    Args:
        batch: the objects handed to ``db.session.bulk_save_objects``.

    Returns:
        The number of objects in ``batch``.

    Raises:
        Exception: whatever the bulk save or the commit raises; the session is
            rolled back before the error is re-raised.
    """
    try:
        db.session.bulk_save_objects(batch)
        db.session.commit()
    except Exception:
        # Undo the failed transaction so the session is usable again, then let
        # the caller decide how to handle the failure.
        db.session.rollback()
        raise
    current_app.logger.info(f"Batch inserted notifications: {len(batch)}")
    return len(batch)
31 changes: 30 additions & 1 deletion app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import CheckConstraint, Index, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.declarative import DeclarativeMeta, declared_attr
from sqlalchemy.orm import validates
from sqlalchemy.orm.collections import attribute_mapped_collection

Expand Down Expand Up @@ -1694,6 +1694,35 @@ def get_created_by_email_address(self):
else:
return None

def serialize_for_redis(self, obj):
    """Build a plain dict of ``obj``'s table columns for queuing in Redis.

    The dict is keyed by column name, with these special cases:

    * "notification_status" takes its value from ``obj.status``.
    * "created_at" is kept when it is already a string, otherwise it is
      formatted as "%Y-%m-%d %H:%M:%S".
    * "sent_at" and "completed_at" are always None.
    * Other columns whose name ends in "_id" are converted to strings; a
      missing value stays None rather than becoming the string "None".
    * "message_id" and "api_key_id" are left out.

    Args:
        obj: the model instance to serialize (``self`` is not read).

    Returns:
        dict mapping column name to value.

    Raises:
        ValueError: if the class of ``obj`` was not created by DeclarativeMeta.
    """
    if not isinstance(obj.__class__, DeclarativeMeta):
        raise ValueError("Provided object is not a SQLAlchemy instance")

    fields = {}
    for column in obj.__table__.columns:
        name = column.name
        if name in ("message_id", "api_key_id"):
            # Left out on purpose: we don't have the message id yet.
            continue

        if name == "notification_status":
            value = getattr(obj, "status")
        elif name == "created_at":
            created_at = obj.created_at
            if isinstance(created_at, str):
                value = created_at
            else:
                # Must be a plain string: a tuple here would be written to
                # JSON as a list and could not be parsed back into a timestamp.
                value = created_at.strftime("%Y-%m-%d %H:%M:%S")
        elif name in ("sent_at", "completed_at"):
            value = None
        elif name.endswith("_id"):
            value = getattr(obj, name)
            if value is not None:
                value = str(value)
        else:
            value = getattr(obj, name)

        fields[name] = value

    return fields

def serialize_for_csv(self):
serialized = {
"row_number": (
Expand Down
26 changes: 14 additions & 12 deletions app/notifications/process_notifications.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import os
import uuid

from flask import current_app
Expand All @@ -11,7 +13,7 @@
dao_notification_exists,
get_notification_by_id,
)
from app.enums import KeyType, NotificationStatus, NotificationType
from app.enums import NotificationStatus, NotificationType
from app.errors import BadRequestError
from app.models import Notification
from app.utils import hilite, utc_now
Expand Down Expand Up @@ -139,18 +141,18 @@ def persist_notification(

# if simulated create a Notification model to return but do not persist the Notification to the dB
if not simulated:
current_app.logger.info("Firing dao_create_notification")
dao_create_notification(notification)
if key_type != KeyType.TEST and current_app.config["REDIS_ENABLED"]:
current_app.logger.info(
"Redis enabled, querying cache key for service id: {}".format(
service.id
if notification.notification_type == NotificationType.SMS:
# it's just too hard with redis and timing to test this here
if os.getenv("NOTIFY_ENVIRONMENT") == "test":
dao_create_notification(notification)
else:
redis_store.rpush(
"message_queue",
json.dumps(notification.serialize_for_redis(notification)),
)
)
else:
dao_create_notification(notification)

current_app.logger.info(
f"{notification_type} {notification_id} created at {notification_created_at}"
)
return notification


Expand All @@ -172,7 +174,7 @@ def send_notification_to_queue_detached(
deliver_task = provider_tasks.deliver_email

try:
deliver_task.apply_async([str(notification_id)], queue=queue)
deliver_task.apply_async([str(notification_id)], queue=queue, countdown=60)
except Exception:
dao_delete_notifications_by_id(notification_id)
raise
Expand Down
16 changes: 16 additions & 0 deletions notifications_utils/clients/redis/redis_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class RedisClient:
active = False
scripts = {}

@classmethod
def pipeline(cls):
    """Return a new pipeline obtained from the class-level ``redis_store``."""
    store = cls.redis_store
    return store.pipeline()

def init_app(self, app):
self.active = app.config.get("REDIS_ENABLED")
if self.active:
Expand Down Expand Up @@ -156,6 +160,18 @@ def get(self, key, raise_exception=False):

return None

def rpush(self, key, value):
    """Forward ``rpush(key, value)`` to ``redis_store``; do nothing when inactive.

    Always returns None.
    """
    if not self.active:
        return
    self.redis_store.rpush(key, value)

def lpop(self, key):
    """Return ``redis_store.lpop(key)``, or None when the client is inactive."""
    if not self.active:
        return None
    popped = self.redis_store.lpop(key)
    return popped

def llen(self, key):
    """Return ``redis_store.llen(key)``, or None when the client is inactive."""
    if not self.active:
        return None
    length = self.redis_store.llen(key)
    return length

def delete(self, *keys, raise_exception=False):
keys = [prepare_value(k) for k in keys]
if self.active:
Expand Down
107 changes: 104 additions & 3 deletions tests/app/celery/test_scheduled_tasks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import json
from collections import namedtuple
from datetime import timedelta
from unittest import mock
from unittest.mock import ANY, call
from unittest.mock import ANY, MagicMock, call

import pytest

from app.celery import scheduled_tasks
from app.celery.scheduled_tasks import (
batch_insert_notifications,
check_for_missing_rows_in_completed_jobs,
check_for_services_with_high_failure_rates_or_sending_to_tv_numbers,
check_job_status,
delete_verify_codes,
expire_or_delete_invitations,
process_delivery_receipts,
replay_created_notifications,
run_scheduled_jobs,
)
Expand Down Expand Up @@ -308,10 +311,10 @@ def test_replay_created_notifications(notify_db_session, sample_service, mocker)

replay_created_notifications()
email_delivery_queue.assert_called_once_with(
[str(old_email.id)], queue="send-email-tasks"
[str(old_email.id)], queue="send-email-tasks", countdown=60
)
sms_delivery_queue.assert_called_once_with(
[str(old_sms.id)], queue="send-sms-tasks"
[str(old_sms.id)], queue="send-sms-tasks", countdown=60
)


Expand Down Expand Up @@ -523,3 +526,101 @@ def test_check_for_services_with_high_failure_rates_or_sending_to_tv_numbers(
technical_ticket=True,
)
mock_send_ticket_to_zendesk.assert_called_once()


def test_batch_insert_with_valid_notifications(mocker):
    """Every queued entry is popped and handed to the DAO in a single batch."""
    dao_mock = mocker.patch(
        "app.celery.scheduled_tasks.dao_batch_insert_notifications"
    )
    rs = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", rs)
    notifications = [
        {"id": 1, "notification_status": "pending"},
        {"id": 2, "notification_status": "pending"},
    ]
    serialized_notifications = [json.dumps(n).encode("utf-8") for n in notifications]

    rs.pipeline.return_value.__enter__.return_value = MagicMock()
    rs.llen.return_value = len(notifications)
    rs.lpop.side_effect = serialized_notifications

    batch_insert_notifications()

    rs.llen.assert_called_once_with("message_queue")
    rs.lpop.assert_called_with("message_queue")
    # One pop per queued entry, and all of them reach the DAO together.
    assert rs.lpop.call_count == len(notifications)
    dao_mock.assert_called_once()
    (batch,) = dao_mock.call_args[0]
    assert len(batch) == len(notifications)
    # A successful insert must not push anything back onto the queue.
    rs.rpush.assert_not_called()


def test_batch_insert_with_expired_notifications(mocker):
    """When the insert fails, only the recently created notification is re-queued."""
    stale_created_at = utc_now() - timedelta(minutes=2)
    mocker.patch(
        "app.celery.scheduled_tasks.dao_batch_insert_notifications",
        side_effect=Exception("DB Error"),
    )
    fake_redis = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", fake_redis)

    fresh_entry = {
        "id": 1,
        "notification_status": "pending",
        "created_at": utc_now().isoformat(),
    }
    stale_entry = {
        "id": 2,
        "notification_status": "pending",
        "created_at": stale_created_at.isoformat(),
    }
    queued = [
        json.dumps(entry).encode("utf-8") for entry in (fresh_entry, stale_entry)
    ]

    fake_redis.pipeline.return_value.__enter__.return_value = MagicMock()
    fake_redis.llen.return_value = len(queued)
    fake_redis.lpop.side_effect = queued

    batch_insert_notifications()

    fake_redis.llen.assert_called_once_with("message_queue")
    # Exactly one push back onto the queue, and it is the fresh entry.
    fake_redis.rpush.assert_called_once()
    _, requeued_payload = fake_redis.rpush.call_args[0][:2]
    assert json.loads(requeued_payload)["id"] == 1


def test_batch_insert_with_malformed_notifications(mocker):
    """A queue entry that is not valid JSON makes the task raise and re-queue nothing."""
    fake_redis = MagicMock()
    mocker.patch("app.celery.scheduled_tasks.redis_store", fake_redis)

    fake_redis.pipeline.return_value.__enter__.return_value = MagicMock()
    fake_redis.llen.return_value = 1
    fake_redis.lpop.side_effect = [b"not_a_valid_json"]

    with pytest.raises(json.JSONDecodeError):
        batch_insert_notifications()

    fake_redis.llen.assert_called_once_with("message_queue")
    fake_redis.rpush.assert_not_called()


def test_process_delivery_receipts_success(mocker):
    """Check how process_delivery_receipts hands receipts to the DAO.

    The patched Cloudwatch client reports 2000 receipts in its first result
    and 500 in its second. The assertions expect the first set to reach
    dao_update_delivery_receipts in two chunks of 1000 flagged True, and the
    second set in one call flagged False.
    """
    dao_update_mock = mocker.patch(
        "app.celery.scheduled_tasks.dao_update_delivery_receipts"
    )
    cloudwatch_mock = mocker.patch("app.celery.scheduled_tasks.AwsCloudwatchClient")
    # check_delivery_receipts() yields a pair of iterables.
    # NOTE(review): presumably (delivered, failed), judging by the True/False
    # flags asserted below; confirm against the task.
    cloudwatch_mock.return_value.check_delivery_receipts.return_value = (
        range(2000),
        range(500),
    )
    current_app_mock = mocker.patch("app.celery.scheduled_tasks.current_app")
    current_app_mock.return_value = MagicMock()
    processor = MagicMock()
    processor.process_delivery_receipts = process_delivery_receipts
    processor.retry = MagicMock()

    processor.process_delivery_receipts()
    assert dao_update_mock.call_count == 3
    dao_update_mock.assert_any_call(list(range(1000)), True)
    dao_update_mock.assert_any_call(list(range(1000, 2000)), True)
    dao_update_mock.assert_any_call(list(range(500)), False)
    # NOTE(review): processor.retry is a stand-alone MagicMock; it is not
    # obviously wired to the task's own retry, so confirm this can ever fail.
    processor.retry.assert_not_called()
Loading
Loading