Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
zmoog committed Sep 24, 2024
1 parent 249567f commit 1e8842d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 58 deletions.
54 changes: 28 additions & 26 deletions handlers/aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,38 @@

INTEGRATION_SCOPE_GENERIC: str = "generic"

ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource_id"])
ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource_name", "qualifier"])


def parse_arn(arn: str) -> ARN:
"""
Parses an AWS ARN and returns a named tuple with its components.
:param arn: The ARN string to parse
:return: A named tuple with the parsed ARN components
"""
parts = arn.split(":")
if len(parts) < 6:
raise ValueError("Invalid ARN format")

partition = parts[1]
service = parts[2]
region = parts[3]
account_id = parts[4]
resource = parts[5]

# Some ARNs have a resource type and resource ID separated by a slash or colon
if "/" in resource:
resource_type, resource_id = resource.split("/", 1)
elif ":" in resource:
resource_type, resource_id = resource.split(":", 1)
else:
resource_type = ""
resource_id = resource

return ARN(partition, service, region, account_id, resource_type, resource_id)
Parse an AWS ARN (Amazon Resource Name) into a named tuple.
Args:
arn (str): The AWS ARN to parse.
Returns:
A named tuple with the following fields:
- partition: The AWS partition (usually 'aws').
- service: The AWS service name.
- region: The AWS region (if applicable).
- account_id: The AWS account ID (if applicable).
- resource_type: The type of the resource.
- resource: The name of the resource.
"""
arn_parts = arn.split(":")
if len(arn_parts) < 7:
raise ValueError("Invalid AWS ARN format.")

ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource"])
return ARN(
partition=arn_parts[1],
service=arn_parts[2],
region=arn_parts[3] if arn_parts[3] else None,
account_id=arn_parts[4] if arn_parts[4] else None,
resource_type=arn_parts[5],
resource=arn_parts[6],
)


def get_sqs_client() -> BotoBaseClient:
Expand Down
82 changes: 50 additions & 32 deletions tests/handlers/aws/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,45 +426,63 @@ def test_cloudwatch_id_less_than_512bytes(self) -> None:


@pytest.mark.unit
class TestDescribeRegions(TestCase):
class TestParseARN(TestCase):

@mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
def test_cache_miss(self) -> None:
from handlers.aws.utils import describe_regions
def test_parse_lambda_function_arn(self) -> None:
from handlers.aws.utils import parse_arn

# Reset the cache info before running the test
describe_regions.cache_clear()
arn = (
"arn:aws:lambda:eu-west-1:123456789:function:mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu"
)
result = parse_arn(arn)

# First call should be a cache miss
response = describe_regions(all_regions=False)
assert response["Regions"] is not None
assert len(response["Regions"]) == 1
assert response["Regions"][0]["RegionName"] == "us-west-2"
assert result.service == "lambda"
assert result.region == "eu-west-1"
assert result.account_id == "123456789"
assert result.resource_type == "function"
assert result.resource == "mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu"

cache_info = describe_regions.cache_info()

assert cache_info.hits == 0
assert cache_info.misses == 1
assert cache_info.currsize == 1
# @pytest.mark.unit
# class TestDescribeRegions(TestCase):

@mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
def test_cache_hits(self) -> None:
from handlers.aws.utils import describe_regions
# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
# def test_cache_miss(self) -> None:
# from handlers.aws.utils import describe_regions

# Reset the cache info before running the test
describe_regions.cache_clear()
# # Reset the cache info before running the test
# describe_regions.cache_clear()

# First call should be a cache miss and populate the cache
# Second and third calls should be cache hits.
response = describe_regions(all_regions=False)
response = describe_regions(all_regions=False)
response = describe_regions(all_regions=False)
assert response["Regions"] is not None
assert len(response["Regions"]) == 1
assert response["Regions"][0]["RegionName"] == "us-west-2"
# # First call should be a cache miss
# response = describe_regions(all_regions=False)
# assert response["Regions"] is not None
# assert len(response["Regions"]) == 1
# assert response["Regions"][0]["RegionName"] == "us-west-2"

cache_info = describe_regions.cache_info()
# cache_info = describe_regions.cache_info()

assert cache_info.hits == 2
assert cache_info.misses == 1
assert cache_info.currsize == 1
# assert cache_info.hits == 0
# assert cache_info.misses == 1
# assert cache_info.currsize == 1

# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
# def test_cache_hits(self) -> None:
# from handlers.aws.utils import describe_regions

# # Reset the cache info before running the test
# describe_regions.cache_clear()

# # First call should be a cache miss and populate the cache
# # Second and third calls should be cache hits.
# response = describe_regions(all_regions=False)
# response = describe_regions(all_regions=False)
# response = describe_regions(all_regions=False)
# assert response["Regions"] is not None
# assert len(response["Regions"]) == 1
# assert response["Regions"][0]["RegionName"] == "us-west-2"

# cache_info = describe_regions.cache_info()

# assert cache_info.hits == 2
# assert cache_info.misses == 1
# assert cache_info.currsize == 1

0 comments on commit 1e8842d

Please sign in to comment.