From 1e8842d2c2199c7835055fff35ecd08bd73d91f3 Mon Sep 17 00:00:00 2001 From: Maurizio Branca Date: Tue, 24 Sep 2024 22:32:22 +0200 Subject: [PATCH] Cleanup --- handlers/aws/utils.py | 54 +++++++++++---------- tests/handlers/aws/test_utils.py | 82 +++++++++++++++++++------------- 2 files changed, 78 insertions(+), 58 deletions(-) diff --git a/handlers/aws/utils.py b/handlers/aws/utils.py index 17ccc9ee..1a4b9a3d 100644 --- a/handlers/aws/utils.py +++ b/handlers/aws/utils.py @@ -31,36 +31,38 @@ INTEGRATION_SCOPE_GENERIC: str = "generic" -ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource_id"]) +ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource_name", "qualifier"]) def parse_arn(arn: str) -> ARN: """ - Parses an AWS ARN and returns a named tuple with its components. - - :param arn: The ARN string to parse - :return: A named tuple with the parsed ARN components - """ - parts = arn.split(":") - if len(parts) < 6: - raise ValueError("Invalid ARN format") - - partition = parts[1] - service = parts[2] - region = parts[3] - account_id = parts[4] - resource = parts[5] - - # Some ARNs have a resource type and resource ID separated by a slash or colon - if "/" in resource: - resource_type, resource_id = resource.split("/", 1) - elif ":" in resource: - resource_type, resource_id = resource.split(":", 1) - else: - resource_type = "" - resource_id = resource - - return ARN(partition, service, region, account_id, resource_type, resource_id) + Parse an AWS ARN (Amazon Resource Name) into a named tuple. + + Args: + arn (str): The AWS ARN to parse. + + Returns: + A named tuple with the following fields: + - partition: The AWS partition (usually 'aws'). + - service: The AWS service name. + - region: The AWS region (if applicable). + - account_id: The AWS account ID (if applicable). + - resource_type: The type of the resource. + - resource: The name of the resource. + """ + arn_parts = arn.split(":") + if len(arn_parts) < 7: + raise ValueError("Invalid AWS ARN format.") + + ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource"]) + return ARN( + partition=arn_parts[1], + service=arn_parts[2], + region=arn_parts[3] if arn_parts[3] else None, + account_id=arn_parts[4] if arn_parts[4] else None, + resource_type=arn_parts[5], + resource=arn_parts[6], + ) def get_sqs_client() -> BotoBaseClient: diff --git a/tests/handlers/aws/test_utils.py b/tests/handlers/aws/test_utils.py index e0a102c0..e163b41a 100644 --- a/tests/handlers/aws/test_utils.py +++ b/tests/handlers/aws/test_utils.py @@ -426,45 +426,63 @@ def test_cloudwatch_id_less_than_512bytes(self) -> None: @pytest.mark.unit -class TestDescribeRegions(TestCase): +class TestParseARN(TestCase): - @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock) - def test_cache_miss(self) -> None: - from handlers.aws.utils import describe_regions + def test_parse_lambda_function_arn(self) -> None: + from handlers.aws.utils import parse_arn - # Reset the cache info before running the test - describe_regions.cache_clear() + arn = ( + "arn:aws:lambda:eu-west-1:123456789:function:mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu" + ) + result = parse_arn(arn) - # First call should be a cache miss - response = describe_regions(all_regions=False) - assert response["Regions"] is not None - assert len(response["Regions"]) == 1 - assert response["Regions"][0]["RegionName"] == "us-west-2" + assert result.service == "lambda" + assert result.region == "eu-west-1" + assert result.account_id == "123456789" + assert result.resource_type == "function" + assert result.resource == "mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu" - cache_info = describe_regions.cache_info() - assert cache_info.hits == 0 - assert cache_info.misses == 1 - assert cache_info.currsize == 1 +# @pytest.mark.unit +# class TestDescribeRegions(TestCase): - @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock) - def test_cache_hits(self) -> None: - from handlers.aws.utils import describe_regions +# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock) +# def test_cache_miss(self) -> None: +# from handlers.aws.utils import describe_regions - # Reset the cache info before running the test - describe_regions.cache_clear() +# # Reset the cache info before running the test +# describe_regions.cache_clear() - # First call should be a cache miss and populate the cache - # Second and third calls should be cache hits. - response = describe_regions(all_regions=False) - response = describe_regions(all_regions=False) - response = describe_regions(all_regions=False) - assert response["Regions"] is not None - assert len(response["Regions"]) == 1 - assert response["Regions"][0]["RegionName"] == "us-west-2" +# # First call should be a cache miss +# response = describe_regions(all_regions=False) +# assert response["Regions"] is not None +# assert len(response["Regions"]) == 1 +# assert response["Regions"][0]["RegionName"] == "us-west-2" - cache_info = describe_regions.cache_info() +# cache_info = describe_regions.cache_info() - assert cache_info.hits == 2 - assert cache_info.misses == 1 - assert cache_info.currsize == 1 +# assert cache_info.hits == 0 +# assert cache_info.misses == 1 +# assert cache_info.currsize == 1 + +# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock) +# def test_cache_hits(self) -> None: +# from handlers.aws.utils import describe_regions + +# # Reset the cache info before running the test +# describe_regions.cache_clear() + +# # First call should be a cache miss and populate the cache +# # Second and third calls should be cache hits. +# response = describe_regions(all_regions=False) +# response = describe_regions(all_regions=False) +# response = describe_regions(all_regions=False) +# assert response["Regions"] is not None +# assert len(response["Regions"]) == 1 +# assert response["Regions"][0]["RegionName"] == "us-west-2" + +# cache_info = describe_regions.cache_info() + +# assert cache_info.hits == 2 +# assert cache_info.misses == 1 +# assert cache_info.currsize == 1