Skip to content

Commit

Permalink
Cache describe_regions using lru_cache from stdlib (#803)
Browse files Browse the repository at this point in the history
Caches EC2:DescribeRegion API calls response.

On high-volume deployments, ESF can hit the EC2:DescribeRegions API requests limit, causing throttling errors like the following:

```text
An error occurred (RequestLimitExceeded) when calling the DescribeRegions operation (reached max retries: 4): Request limit exceeded.
```

ESF uses the list of existing regions to parse incoming events from the `cloudwatch-logs` input. Since new AWS region additions do not happen frequently, picking up and caching the list of existing regions at function startup seems adequate.
  • Loading branch information
zmoog authored Sep 23, 2024
1 parent 8be4fc4 commit 29c08f4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
### v1.17.1 - 2024/09/23
##### Bug fixes
* Cache EC2:DescribeRegion API response to avoid throttling and improve performance [803](https://github.com/elastic/elastic-serverless-forwarder/pull/803).

### v1.17.0 - 2024/07/10
##### Features
* Add dead letter index for ES outputs [733](https://github.com/elastic/elastic-serverless-forwarder/pull/733).
Expand Down
16 changes: 12 additions & 4 deletions handlers/aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
import os
from functools import lru_cache
from typing import Any, Callable, Optional

import boto3
Expand Down Expand Up @@ -385,6 +386,16 @@ def get_account_id_from_arn(lambda_arn: str) -> str:
return arn_components[4]


@lru_cache()
def describe_regions(all_regions: bool = True) -> Any:
"""
Fetches all regions from AWS and returns the response.
:return: The response from the describe_regions method
"""
return get_ec2_client().describe_regions(AllRegions=all_regions)


def get_input_from_log_group_subscription_data(
config: Config, account_id: str, log_group_name: str, log_stream_name: str
) -> tuple[str, Optional[Input]]:
Expand All @@ -395,12 +406,9 @@ def get_input_from_log_group_subscription_data(
In order to not hardcode the list of regions we rely on ec2 DescribeRegions - as much weird as it is - that I found
no information about having any kind of throttling. We add IAM permissions for it in deployment.
"""
all_regions = get_ec2_client().describe_regions(AllRegions=True)
all_regions = describe_regions(all_regions=True)
assert "Regions" in all_regions
for region_data in all_regions["Regions"]:

# arn:aws:logs:region:account-id:log-group:log_group_name:*

region = region_data["RegionName"]

aws_or_gov = "aws"
Expand Down
2 changes: 1 addition & 1 deletion share/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.

version = "1.17.0"
version = "1.17.1"
60 changes: 60 additions & 0 deletions tests/handlers/aws/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any
from unittest import TestCase

import mock
import pytest

from handlers.aws.utils import (
Expand Down Expand Up @@ -66,6 +67,20 @@ def _get_random_digit_string_of_size(size: int) -> str:
return "".join(random.choices(string.digits, k=size))


def _describe_regions(AllRegions: bool) -> dict[str, Any]:
return {
"Regions": [
{
"RegionName": "us-west-2",
},
]
}


_ec2_client_mock = mock.MagicMock()
_ec2_client_mock.describe_regions = _describe_regions


@pytest.mark.unit
class TestGetTriggerTypeAndConfigSource(TestCase):
def test_get_trigger_type_and_config_source(self) -> None:
Expand Down Expand Up @@ -408,3 +423,48 @@ def test_cloudwatch_id_less_than_512bytes(self) -> None:

generated_id = cloudwatch_logs_object_id(relevant_fields_for_id)
assert _utf8len(generated_id) <= MAX_ES_ID_SIZ_BYTES


@pytest.mark.unit
class TestDescribeRegions(TestCase):

@mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
def test_cache_miss(self) -> None:
from handlers.aws.utils import describe_regions

# Reset the cache info before running the test
describe_regions.cache_clear()

# First call should be a cache miss
response = describe_regions(all_regions=False)
assert response["Regions"] is not None
assert len(response["Regions"]) == 1
assert response["Regions"][0]["RegionName"] == "us-west-2"

cache_info = describe_regions.cache_info()

assert cache_info.hits == 0
assert cache_info.misses == 1
assert cache_info.currsize == 1

@mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
def test_cache_hits(self) -> None:
from handlers.aws.utils import describe_regions

# Reset the cache info before running the test
describe_regions.cache_clear()

# First call should be a cache miss and populate the cache
# Second and third calls should be cache hits.
response = describe_regions(all_regions=False)
response = describe_regions(all_regions=False)
response = describe_regions(all_regions=False)
assert response["Regions"] is not None
assert len(response["Regions"]) == 1
assert response["Regions"][0]["RegionName"] == "us-west-2"

cache_info = describe_regions.cache_info()

assert cache_info.hits == 2
assert cache_info.misses == 1
assert cache_info.currsize == 1

0 comments on commit 29c08f4

Please sign in to comment.