Skip to content

Commit

Permalink
Generate partition aware STS endpoints for EKS Hook (#45725)
Browse files Browse the repository at this point in the history
Up until now the STS endpoint url used by the EKS hook (during token
generation) was hardcoded to be the commercial partition. This change
ensures the endpoint urls respect partitions and regions.
  • Loading branch information
o-nikolas authored Jan 16, 2025
1 parent 4bc37af commit 439f7b1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 10 additions & 3 deletions providers/src/airflow/providers/amazon/aws/hooks/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import base64
import json
import os
import sys
import tempfile
from collections.abc import Generator
Expand All @@ -32,6 +33,7 @@
from botocore.signers import RequestSigner

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.utils import yaml
from airflow.utils.json import AirflowJsonEncoder

Expand Down Expand Up @@ -612,9 +614,14 @@ def generate_config_file(
def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
session = self.get_session()
service_id = self.conn.meta.service_model.service_id
sts_url = (
f"https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15"
)
# This env variable is required so that we get a regionalized endpoint for STS in regions that
# otherwise default to global endpoints. The mechanism below to generate the token is very picky that
# the endpoint is regional.
os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
try:
sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
finally:
del os.environ["AWS_STS_REGIONAL_ENDPOINTS"]

signer = RequestSigner(
service_id=service_id,
Expand Down
4 changes: 3 additions & 1 deletion providers/tests/amazon/aws/hooks/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,10 +1283,12 @@ def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expecte
}

@mock.patch("airflow.providers.amazon.aws.hooks.eks.RequestSigner")
@mock.patch("airflow.providers.amazon.aws.hooks.eks.StsHook")
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_session")
def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn, mock_signer):
def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn, mock_sts_hook, mock_signer):
mock_signer.return_value.generate_presigned_url.return_value = "http://example.com"
mock_sts_hook.return_value.conn_client_meta.endpoint_url = "https://sts.us-east-1.amazonaws.com"
mock_get_session.return_value.region_name = "us-east-1"
hook = EksHook()
token = hook.fetch_access_token_for_cluster(eks_cluster_name="test-cluster")
Expand Down

0 comments on commit 439f7b1

Please sign in to comment.