Remove call to EC2:DescribeRegions API #811

Merged (15 commits) on Sep 25, 2024
CHANGELOG.md (6 changes: 5 additions & 1 deletion)
@@ -1,6 +1,10 @@
### v1.17.2 - 2024/09/24
##### Bug fixes
* Remove call to EC2:DescribeRegions API in the cloudwatch-logs input [811](https://github.com/elastic/elastic-serverless-forwarder/pull/811).

### v1.17.1 - 2024/09/23
##### Bug fixes
-* Cache EC2:DescribeRegion API response to avoid throttling and improve performance [803](https://github.com/elastic/elastic-serverless-forwarder/pull/803).
+* Cache EC2:DescribeRegions API response to avoid throttling and improve performance [803](https://github.com/elastic/elastic-serverless-forwarder/pull/803).

### v1.17.0 - 2024/07/10
##### Features
handlers/aws/handler.py (7 changes: 6 additions & 1 deletion)
@@ -31,6 +31,7 @@
expand_event_list_from_field_resolver,
get_continuing_original_input_type,
get_input_from_log_group_subscription_data,
get_lambda_region,
get_shipper_from_input,
get_sqs_client,
get_trigger_type_and_config_source,
@@ -48,7 +49,6 @@ def lambda_handler(lambda_event: dict[str, Any], lambda_context: context_.Contex
AWS Lambda handler in handler.aws package
Parses the config and acts as front controller for inputs
"""

shared_logger.debug("lambda triggered", extra={"invoked_function_arn": lambda_context.invoked_function_arn})

try:
@@ -144,11 +144,16 @@ def lambda_handler(lambda_event: dict[str, Any], lambda_context: context_.Contex

shared_logger.info("trigger", extra={"size": len(cloudwatch_logs_event["logEvents"])})

lambda_region = get_lambda_region()

input_id, event_input = get_input_from_log_group_subscription_data(
config,
cloudwatch_logs_event["owner"],
cloudwatch_logs_event["logGroup"],
cloudwatch_logs_event["logStream"],
# As of today, the cloudwatch trigger is always in
# the same region as the lambda function.
lambda_region,
)

if event_input is None:
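For orientation (not part of the diff), the lookup above receives the fields of the decoded CloudWatch Logs subscription payload plus the Lambda's own region. A minimal sketch of that payload and of the decode step, with all concrete values invented for the example:

```python
import base64
import gzip
import json

# Illustrative payload only: field names follow the CloudWatch Logs
# subscription format, values are made up.
decoded_example = {
    "messageType": "DATA_MESSAGE",
    "owner": "123456789012",                  # -> account_id argument
    "logGroup": "/aws/lambda/example-app",    # -> log_group_name argument
    "logStream": "2024/09/24/[$LATEST]abcd",  # -> log_stream_name argument
    "subscriptionFilters": ["example-filter"],
    "logEvents": [{"id": "0", "timestamp": 1727136000000, "message": "hello"}],
}

# CloudWatch Logs delivers this gzip-compressed and base64-encoded under
# event["awslogs"]["data"]; the round trip below mirrors the decode step.
raw_event = {
    "awslogs": {
        "data": base64.b64encode(gzip.compress(json.dumps(decoded_example).encode())).decode()
    }
}
cloudwatch_logs_event = json.loads(gzip.decompress(base64.b64decode(raw_event["awslogs"]["data"])))
assert cloudwatch_logs_event["owner"] == "123456789012"

# The region argument is get_lambda_region(): the CloudWatch Logs trigger is
# always in the same region as the Lambda function, so no API call is needed.
```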
handlers/aws/utils.py (77 changes: 41 additions & 36 deletions)
Expand Up @@ -2,7 +2,6 @@
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
import os
-from functools import lru_cache
from typing import Any, Callable, Optional

import boto3
@@ -32,6 +31,27 @@
INTEGRATION_SCOPE_GENERIC: str = "generic"


def get_lambda_region() -> str:
"""
Get the AWS region where the Lambda function is running.

Returns the value of the `AWS_REGION` environment variable. If the
`AWS_REGION` variable is not set, it returns the value of the
`AWS_DEFAULT_REGION` variable.

If neither variable is set, it raises a `ValueError`.

Returns:
str: The AWS region.
"""
region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")

if region is None:
raise ValueError("AWS region not found in environment variables.")

return region
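As a quick aside, the precedence this helper implements (`AWS_REGION` first, then `AWS_DEFAULT_REGION`) can be illustrated with a standalone sketch; the function body is copied inline here so the snippet runs without importing the package:

```python
import os


def get_lambda_region() -> str:
    # Inline copy of the helper above, for demonstration only.
    region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
    if region is None:
        raise ValueError("AWS region not found in environment variables.")
    return region


os.environ["AWS_REGION"] = "eu-west-1"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
assert get_lambda_region() == "eu-west-1"  # AWS_REGION wins

del os.environ["AWS_REGION"]
assert get_lambda_region() == "us-east-1"  # falls back to AWS_DEFAULT_REGION
```

In the AWS Lambda runtime `AWS_REGION` is set automatically, so the fallback mainly matters for local runs and tests.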


def get_sqs_client() -> BotoBaseClient:
"""
Getter for sqs client
@@ -386,49 +406,34 @@ def get_account_id_from_arn(lambda_arn: str) -> str:
return arn_components[4]


-@lru_cache()
-def describe_regions(all_regions: bool = True) -> Any:
-    """
-    Fetches all regions from AWS and returns the response.
-
-    :return: The response from the describe_regions method
-    """
-    return get_ec2_client().describe_regions(AllRegions=all_regions)
-
-
def get_input_from_log_group_subscription_data(
-    config: Config, account_id: str, log_group_name: str, log_stream_name: str
+    config: Config, account_id: str, log_group_name: str, log_stream_name: str, region: str
) -> tuple[str, Optional[Input]]:
    """
-    This function is not less resilient than the previous get_log_group_arn_and_region_from_log_group_name()
-    We avoid to call the describe_log_streams on the logs' client, since we have no way to apply the proper
-    throttling because we'd need to know the number of concurrent lambda running at the time of the call.
-    In order to not hardcode the list of regions we rely on ec2 DescribeRegions - as much weird as it is - that I found
-    no information about having any kind of throttling. We add IAM permissions for it in deployment.
-    """
-    all_regions = describe_regions(all_regions=True)
-    assert "Regions" in all_regions
-    for region_data in all_regions["Regions"]:
-        region = region_data["RegionName"]
-
-        aws_or_gov = "aws"
-        if "gov" in region:
-            aws_or_gov = "aws-us-gov"
-
-        log_stream_arn = (
-            f"arn:{aws_or_gov}:logs:{region}:{account_id}:log-group:{log_group_name}:log-stream:{log_stream_name}"
-        )
-        event_input = config.get_input_by_id(log_stream_arn)
-
-        if event_input is not None:
-            return log_stream_arn, event_input
-
-        log_group_arn_components = log_stream_arn.split(":")
-        log_group_arn = f"{':'.join(log_group_arn_components[:-2])}:*"
-        event_input = config.get_input_by_id(log_group_arn)
-
-        if event_input is not None:
-            return log_group_arn, event_input
+    Look up the input in the configuration using the information from the log event.
+
+    It looks for the log stream ARN first; if not found, it falls back to the log group ARN.
+    """
+    partition = "aws"
+    if "gov" in region:
+        partition = "aws-us-gov"
+
+    log_stream_arn = (
+        f"arn:{partition}:logs:{region}:{account_id}:log-group:{log_group_name}:log-stream:{log_stream_name}"
+    )
+    event_input = config.get_input_by_id(log_stream_arn)
+
+    if event_input is not None:
+        return log_stream_arn, event_input
+
+    log_group_arn_components = log_stream_arn.split(":")
+    log_group_arn = f"{':'.join(log_group_arn_components[:-2])}:*"
+    event_input = config.get_input_by_id(log_group_arn)
+
+    if event_input is not None:
+        return log_group_arn, event_input

    return f"arn:aws:logs:%AWS_REGION%:{account_id}:log-group:{log_group_name}:*", None
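To make the lookup concrete, here is a small sketch (not from the PR) of the ARNs the function builds from the caller-supplied region, with `config.get_input_by_id` stood in by a plain dict; the account id, log group, and log stream values are invented:

```python
region = "eu-central-1"          # what the handler passes in via get_lambda_region()
account_id = "123456789012"
log_group_name = "/aws/lambda/example-app"
log_stream_name = "2024/09/24/[$LATEST]0123456789abcdef"

partition = "aws-us-gov" if "gov" in region else "aws"

log_stream_arn = (
    f"arn:{partition}:logs:{region}:{account_id}:"
    f"log-group:{log_group_name}:log-stream:{log_stream_name}"
)
# Fallback id: drop the last two ARN components and match any stream in the group.
log_group_arn = f"{':'.join(log_stream_arn.split(':')[:-2])}:*"

# Stand-in for config.get_input_by_id(): a config keyed only by the log group ARN.
configured_inputs = {log_group_arn: {"type": "cloudwatch-logs"}}
event_input = configured_inputs.get(log_stream_arn) or configured_inputs.get(log_group_arn)

assert log_group_arn == f"arn:aws:logs:{region}:{account_id}:log-group:{log_group_name}:*"
assert event_input == {"type": "cloudwatch-logs"}
```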

share/version.py (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.

version = "1.17.1"
version = "1.17.2"
tests/handlers/aws/test_integrations.py (2 changes: 2 additions & 0 deletions)
@@ -79,6 +79,8 @@ def setUpClass(cls) -> None:

cls.localstack = lsc.start()

os.environ["AWS_REGION"] = _AWS_REGION

session = boto3.Session(region_name=_AWS_REGION)
cls.aws_session = session
cls.s3_client = session.client("s3", endpoint_url=cls.localstack.get_url())
tests/handlers/aws/test_utils.py (56 changes: 24 additions & 32 deletions)
@@ -3,6 +3,7 @@
# you may not use this file except in compliance with the Elastic License 2.0.


import os
import random
import string
from datetime import datetime
@@ -426,45 +427,36 @@ def test_cloudwatch_id_less_than_512bytes(self) -> None:


@pytest.mark.unit
-class TestDescribeRegions(TestCase):
-
-    @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
-    def test_cache_miss(self) -> None:
-        from handlers.aws.utils import describe_regions
-
-        # Reset the cache info before running the test
-        describe_regions.cache_clear()
-
-        # First call should be a cache miss
-        response = describe_regions(all_regions=False)
-        assert response["Regions"] is not None
-        assert len(response["Regions"]) == 1
-        assert response["Regions"][0]["RegionName"] == "us-west-2"
-
-        cache_info = describe_regions.cache_info()
-
-        assert cache_info.hits == 0
-        assert cache_info.misses == 1
-        assert cache_info.currsize == 1
-
-    @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
-    def test_cache_hits(self) -> None:
-        from handlers.aws.utils import describe_regions
-
-        # Reset the cache info before running the test
-        describe_regions.cache_clear()
-
-        # First call should be a cache miss and populate the cache
-        # Second and third calls should be cache hits.
-        response = describe_regions(all_regions=False)
-        response = describe_regions(all_regions=False)
-        response = describe_regions(all_regions=False)
-        assert response["Regions"] is not None
-        assert len(response["Regions"]) == 1
-        assert response["Regions"][0]["RegionName"] == "us-west-2"
-
-        cache_info = describe_regions.cache_info()
-
-        assert cache_info.hits == 2
-        assert cache_info.misses == 1
-        assert cache_info.currsize == 1
+class TestGetLambdaRegion(TestCase):
+
+    def test_with_aws_region(self) -> None:
+        from handlers.aws.utils import get_lambda_region
+
+        os.environ["AWS_REGION"] = "us-west-1"
+        os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
+
+        region = get_lambda_region()
+
+        assert region == "us-west-1"
+
+    def test_with_aws_default_region(self) -> None:
+        from handlers.aws.utils import get_lambda_region
+
+        if "AWS_REGION" in os.environ:
+            del os.environ["AWS_REGION"]
+        os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
+
+        region = get_lambda_region()
+
+        assert region == "us-west-2"
+
+    def test_without_variables(self) -> None:
+        from handlers.aws.utils import get_lambda_region
+
+        if "AWS_REGION" in os.environ:
+            del os.environ["AWS_REGION"]
+        if "AWS_DEFAULT_REGION" in os.environ:
+            del os.environ["AWS_DEFAULT_REGION"]
+
+        with pytest.raises(ValueError):
+            get_lambda_region()
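A side note rather than part of the change: these tests mutate `os.environ` in place, so they depend on the order in which they run. A hedged sketch of the same three cases using `unittest.mock.patch.dict`, which restores the environment automatically after each test (assuming `get_lambda_region` is importable from `handlers.aws.utils` as above):

```python
import os
from unittest import TestCase, mock

import pytest

from handlers.aws.utils import get_lambda_region


@pytest.mark.unit
class TestGetLambdaRegionIsolated(TestCase):
    @mock.patch.dict(os.environ, {"AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-west-2"})
    def test_with_aws_region(self) -> None:
        assert get_lambda_region() == "us-west-1"

    @mock.patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}, clear=True)
    def test_with_aws_default_region(self) -> None:
        assert get_lambda_region() == "us-west-2"

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_without_variables(self) -> None:
        with pytest.raises(ValueError):
            get_lambda_region()
```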