Skip to content

Commit

Permalink
implement sts_token_buffer_time attribute for transport_options to up…
Browse files Browse the repository at this point in the history
…date token earlier than expiration time
  • Loading branch information
eisichenko committed Jan 5, 2025
1 parent 4c64cdd commit b45c6c8
Showing 1 changed file with 37 additions and 29 deletions.
66 changes: 37 additions & 29 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
},
}
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
'sts_token_timeout': 900 # optional
'sts_token_timeout': 900, # optional
'sts_token_buffer_time': 60 # optional
}
Note that FIFO and standard queues must be named accordingly (the name of
Expand All @@ -91,6 +92,9 @@
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
to 900 seconds. After the mentioned period, a new token will be created.
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
should be less than sts_token_timeout.
Expand Down Expand Up @@ -136,7 +140,7 @@
import socket
import string
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from queue import Empty

from botocore.client import Config
Expand Down Expand Up @@ -765,34 +769,38 @@ def sqs(self, queue=None):
)
return c

def _refresh_sqs_client(self, queue, q):
sts_creds = self.generate_sts_session_token_with_buffer(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900),
self.transport_options.get('sts_token_buffer_time', 0),
)
self.sts_expiration = sts_creds['Expiration']
self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return self._predefined_queue_clients[queue]

def _handle_sts_session(self, queue, q):
if not hasattr(self, 'sts_expiration'): # STS token - token init
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
else: # STS token - ruse existing
return self._predefined_queue_clients[queue]
"""
Refreshes the SQS client with a new token on STS token initialization
or expiration. Otherwise, using cached client.
"""
if (
not hasattr(self, 'sts_expiration') or
self.sts_expiration.replace(tzinfo=None) < datetime.utcnow()
):
return self._refresh_sqs_client(queue, q)
return self._predefined_queue_clients[queue]

def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds:
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
return credentials

def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
Expand Down

0 comments on commit b45c6c8

Please sign in to comment.