Skip to content

Commit

Permalink
Improve and simplify parse_arn
Browse files Browse the repository at this point in the history
  • Loading branch information
zmoog committed Sep 24, 2024
1 parent f5bf848 commit 811567c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 56 deletions.
46 changes: 25 additions & 21 deletions handlers/aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

INTEGRATION_SCOPE_GENERIC: str = "generic"

ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource_type", "resource"])
ARN = namedtuple("ARN", ["partition", "service", "region", "account_id", "resource"])


def parse_arn(arn: str) -> ARN:
Expand All @@ -50,18 +50,25 @@ def parse_arn(arn: str) -> ARN:
- resource_type: The type of the resource.
- resource: The name of the resource.
"""
arn_parts = arn.split(":")
if len(arn_parts) < 7:
# we split only 6 times to keep the resource as a single string
# even if it contains colons.
#
# For example, an CloudWatch log group ARN looks like this:
# arn:aws:logs:eu-west-1:627286350134:log-group:/aws/lambda/mbranca-esf-vGHtx0b7uzNu:*
#
# The resource is the log group name, which contains colons.
# If we split more than 6 times, we would get a wrong resource name.
arn_parts = arn.split(":", 5)
if len(arn_parts) < 6:
raise ValueError("Invalid AWS ARN format.")

return ARN(
partition=arn_parts[1],
service=arn_parts[2],
region=arn_parts[3] if arn_parts[3] else None,
account_id=arn_parts[4] if arn_parts[4] else None,
resource_type=arn_parts[5],
resource=arn_parts[6],
)
partition = arn_parts[1]
service = arn_parts[2]
region = arn_parts[3] if arn_parts[3] else None
account_id = arn_parts[4] if arn_parts[4] else None
resource = arn_parts[5]

return ARN(partition, service, region, account_id, resource)


def get_sqs_client() -> BotoBaseClient:
Expand Down Expand Up @@ -418,19 +425,16 @@ def get_account_id_from_arn(lambda_arn: str) -> str:
return arn_components[4]


# @lru_cache()
# def describe_regions(all_regions: bool = True) -> Any:
# """
# Fetches all regions from AWS and returns the response.

# :return: The response from the describe_regions method
# """
# return get_ec2_client().describe_regions(AllRegions=all_regions)


def get_input_from_log_group_subscription_data(
config: Config, account_id: str, log_group_name: str, log_stream_name: str, region: str
) -> tuple[str, Optional[Input]]:
"""
Look up for the input in the configuration using the information
from the log event.
It looks for the log stream arn, if not found it looks for the
log group arn.
"""
partition = "aws"
if "gov" in region:
partition = "aws-us-gov"
Expand Down
81 changes: 46 additions & 35 deletions tests/handlers/aws/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,50 +439,61 @@ def test_parse_lambda_function_arn(self) -> None:
assert result.service == "lambda"
assert result.region == "eu-west-1"
assert result.account_id == "123456789"
assert result.resource_type == "function"
assert result.resource == "mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu"
assert result.resource == "function:mbranca-esf-ApplicationElasticServerlessForwarder-vGHtx0b7uzNu"

def test_parse_log_group(self) -> None:
from handlers.aws.utils import parse_arn

arn = "arn:aws:logs:eu-west-1:123456789:log-group:/aws/lambda/mbranca-esf-vGHtx0b7uzNu:*"
result = parse_arn(arn)

print(result)

assert result.service == "logs"
assert result.region == "eu-west-1"
assert result.account_id == "123456789"
assert result.resource == "log-group:/aws/lambda/mbranca-esf-vGHtx0b7uzNu:*"

def test_parse_s3_bucket(self) -> None:
from handlers.aws.utils import parse_arn

# @pytest.mark.unit
# class TestDescribeRegions(TestCase):
arn = "arn:aws:s3:::mbranca-esf-data"
result = parse_arn(arn)

# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
# def test_cache_miss(self) -> None:
# from handlers.aws.utils import describe_regions
assert result.service == "s3"
assert result.region is None
assert result.account_id is None
assert result.resource == "mbranca-esf-data"

# # Reset the cache info before running the test
# describe_regions.cache_clear()
def test_parse_iam_user(self) -> None:
from handlers.aws.utils import parse_arn

# # First call should be a cache miss
# response = describe_regions(all_regions=False)
# assert response["Regions"] is not None
# assert len(response["Regions"]) == 1
# assert response["Regions"][0]["RegionName"] == "us-west-2"
arn = "arn:aws:iam::123456789012:user/johndoe"
result = parse_arn(arn)

# cache_info = describe_regions.cache_info()
assert result.service == "iam"
assert result.region is None
assert result.account_id == "123456789012"
assert result.resource == "user/johndoe"

# assert cache_info.hits == 0
# assert cache_info.misses == 1
# assert cache_info.currsize == 1
def test_parse_sns_topic(self) -> None:
from handlers.aws.utils import parse_arn

# @mock.patch("handlers.aws.utils.get_ec2_client", lambda: _ec2_client_mock)
# def test_cache_hits(self) -> None:
# from handlers.aws.utils import describe_regions
arn = "arn:aws:sns:us-east-1:123456789012:example-sns-topic-name"
result = parse_arn(arn)

# # Reset the cache info before running the test
# describe_regions.cache_clear()
assert result.service == "sns"
assert result.region == "us-east-1"
assert result.account_id == "123456789012"
assert result.resource == "example-sns-topic-name"

# # First call should be a cache miss and populate the cache
# # Second and third calls should be cache hits.
# response = describe_regions(all_regions=False)
# response = describe_regions(all_regions=False)
# response = describe_regions(all_regions=False)
# assert response["Regions"] is not None
# assert len(response["Regions"]) == 1
# assert response["Regions"][0]["RegionName"] == "us-west-2"
def test_parse_vpc(self) -> None:
from handlers.aws.utils import parse_arn

# cache_info = describe_regions.cache_info()
arn = "arn:aws:ec2:us-east-1:123456789012:vpc/vpc-0e9801d129EXAMPLE"
result = parse_arn(arn)

# assert cache_info.hits == 2
# assert cache_info.misses == 1
# assert cache_info.currsize == 1
assert result.service == "ec2"
assert result.region == "us-east-1"
assert result.account_id == "123456789012"
assert result.resource == "vpc/vpc-0e9801d129EXAMPLE"

0 comments on commit 811567c

Please sign in to comment.