-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuser-data.py
160 lines (115 loc) · 4.73 KB
/
user-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Template script to be used as user-data for EC2 instances."""
import boto3
import json
import subprocess # nosec (remove bandit warning)
import re
from decimal import Decimal
from typing import List, Tuple
def _read_setting(path: str) -> str:
    """Return the contents of a one-value settings file with surrounding whitespace removed."""
    with open(path, encoding='utf-8') as handle:
        return handle.read().strip()


# Deployment parameters live in small files at the filesystem root
# (presumably written by whatever provisions this instance — confirm).
QUEUE_URL = _read_setting('/sqs-queue')
CONTROL_QUEUE_URL = _read_setting('/control-sqs-queue')
WRITE_TABLE_NAME = _read_setting('/dynamodb-write-table')
READ_TABLE_NAME = _read_setting('/dynamodb-read-table')
REGION_NAME = _read_setting('/region')
AZ_NAME = _read_setting('/az')
def test_bandwidth(server_ip: str, port: int = 5201) -> Decimal:
    """Measure TCP throughput to an iperf3 server.

    Runs ``iperf3 -c <server_ip> -p <port> -J`` and reads the
    ``end.sum_received.bits_per_second`` figure from its JSON report.

    Args:
        server_ip: Address of the host running the iperf3 server.
        port: TCP port the iperf3 server listens on.

    Returns:
        Received throughput in Gb/s (bits per second divided by 10**9).

    Raises:
        ValueError: If iperf3 exits with a non-zero status.
    """
    command = ['iperf3', '-c', server_ip, '-p', str(port), '-J']
    result = subprocess.run(command, capture_output=True, text=True, check=False)  # nosec (remove bandit warning)
    if result.returncode != 0:
        # In -J mode iperf3 can put its failure reason in an "error" field of
        # the JSON on stdout and leave stderr empty, so look there as well.
        detail = result.stderr.strip()
        try:
            report = json.loads(result.stdout)
        except ValueError:
            report = None
        if isinstance(report, dict) and report.get('error'):
            detail = str(report['error'])
        raise ValueError(f'Error running iperf3: {detail or f"exit status {result.returncode}"}')
    report = json.loads(result.stdout)
    bandwidth_bps = report['end']['sum_received']['bits_per_second']
    # Go through str() so Decimal receives the float's short repr instead of
    # its exact binary expansion (Decimal(0.1) carries dozens of digits).
    return Decimal(str(bandwidth_bps)) / Decimal(10**9)
def test_network_latency(hostname: str) -> Decimal:
    """Measure the mean ping round-trip time to a host.

    Sends 100 ICMP echo requests at 0.1 s intervals and averages every
    ``time=<value>`` figure found in ping's output.

    Args:
        hostname: Host name or IP address to ping.

    Returns:
        Mean round-trip time in the unit ping prints (ms for Linux iputils
        ping; the script stores it under ``network_latency_ms``).

    Raises:
        subprocess.CalledProcessError: If ping exits with a non-zero status.
        ValueError: If ping succeeded but no reply times could be parsed.
    """
    command = ['ping', '-c', '100', '-i', '0.1', hostname]
    result = subprocess.run(command, stdout=subprocess.PIPE, text=True, check=True)  # nosec (remove bandit warning)
    time_values = re.findall(r'time=([\d.]+)', result.stdout)
    if not time_values:
        # Without this guard the average below would fail with an opaque
        # decimal.InvalidOperation (0 / 0).
        raise ValueError(f'No round-trip times found in ping output for {hostname}')
    return sum(map(Decimal, time_values)) / Decimal(len(time_values))
def write_to_dynamodb(az_name: str, network_latency: Decimal, bandwidth: Decimal) -> None:
    """Store one AZ-to-AZ measurement in the results table.

    The row's source AZ is this instance's AZ_NAME and ``az_name`` is the
    destination that was measured. The table WRITE_TABLE_NAME is always
    addressed in us-east-1, independent of REGION_NAME.
    """
    results_table = boto3.resource('dynamodb', region_name='us-east-1').Table(WRITE_TABLE_NAME)
    measurement = {
        'availability_zone_from': AZ_NAME,
        'availability_zone_to': az_name,
        'network_latency_ms': network_latency,
        'bandwidth_gbps': bandwidth,
    }
    results_table.put_item(Item=measurement)
def read_from_dynamodb_table() -> Tuple[List[str], List[str], str]:
    """Look up the peers this AZ must measure and the queue to trigger next.

    Queries READ_TABLE_NAME (in us-east-1) for the item whose
    ``availability_zone`` key equals AZ_NAME. The item's ``azs`` attribute is
    a comma-separated list of ``<ip>:<az-name>`` entries; ``next_az_queue``
    is either an SQS queue URL or the literal ``'DONE'`` (the script's main
    section checks for that value).

    Returns:
        ``(ips, az_names, next_az_queue)`` where ``ips[i]`` belongs to
        ``az_names[i]``. Both lists are empty when ``azs`` is empty/missing.

    Raises:
        LookupError: If the table has no item for AZ_NAME.
        ValueError: If an ``azs`` entry has no ``:`` separator.
    """
    # Import explicitly: a bare ``import boto3`` does not guarantee that
    # ``boto3.dynamodb.conditions`` is reachable as an attribute.
    from boto3.dynamodb.conditions import Key

    dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
    table = dynamodb.Table(READ_TABLE_NAME)
    response = table.query(
        KeyConditionExpression=Key('availability_zone').eq(AZ_NAME)
    )
    items = response.get('Items') or []
    if not items:
        raise LookupError(f'No entry for {AZ_NAME} in table {READ_TABLE_NAME}')
    item = items[0]

    ips: List[str] = []
    az_names: List[str] = []
    for entry in (item.get('azs') or '').split(','):
        entry = entry.strip()
        if not entry:
            continue
        # Split on the LAST colon so an address containing colons stays intact.
        ip, separator, az_name = entry.rpartition(':')
        if not separator:
            raise ValueError(f'Malformed azs entry {entry!r}; expected <ip>:<az>')
        ips.append(ip)
        az_names.append(az_name)
    return ips, az_names, item.get('next_az_queue')
def trigger_next_az(next_az_queue: str) -> None:
    """Post the 'Go' message that starts the next AZ in the chain.

    Args:
        next_az_queue: URL of the SQS queue to signal (presumably the queue
            the next AZ's instance waits on in poll_sqs_queue). The SQS
            client is created in this instance's REGION_NAME.
    """
    boto3.client('sqs', region_name=REGION_NAME).send_message(
        QueueUrl=next_az_queue,
        MessageBody='Go',
    )
def trigger_done() -> None:
    """Send ``DONE - <region>`` to the us-east-1 control queue for this region."""
    control_sqs = boto3.client('sqs', region_name='us-east-1')
    completion_notice = f'DONE - {REGION_NAME}'
    control_sqs.send_message(QueueUrl=CONTROL_QUEUE_URL, MessageBody=completion_notice)
def poll_sqs_queue() -> None:
    """Block until this instance's trigger queue delivers its start message.

    Long-polls QUEUE_URL in REGION_NAME (up to 20 s per request, one message
    at a time), deletes the first message received, and returns once its
    body is confirmed to be ``'Go'``.

    Raises:
        ValueError: If the received body is anything other than ``'Go'``.
            The message has already been deleted from the queue by then.
    """
    sqs = boto3.client('sqs', region_name=REGION_NAME)
    while True:
        response = sqs.receive_message(
            QueueUrl=QUEUE_URL,
            MaxNumberOfMessages=1,
            WaitTimeSeconds=20,
        )
        messages = response.get('Messages')
        if not messages:
            # The long poll ended with nothing to deliver; ask again.
            continue
        message = messages[0]
        sqs.delete_message(
            QueueUrl=QUEUE_URL,
            ReceiptHandle=message['ReceiptHandle'],
        )
        # Explicit check instead of `assert`, which `python -O` strips and
        # would then let any message start the measurements.
        if message['Body'] != 'Go':
            raise ValueError(
                f"Unexpected message on {QUEUE_URL}: {message['Body']!r} (expected 'Go')"
            )
        return
if __name__ == '__main__':
    # Wait for this AZ's turn, measure every listed peer, then either signal
    # the next queue or, when the chain ends with 'DONE', report completion.
    poll_sqs_queue()
    peer_ips, peer_azs, successor_queue = read_from_dynamodb_table()
    for peer_ip, peer_az in zip(peer_ips, peer_azs):
        latency = test_network_latency(peer_ip)
        throughput = test_bandwidth(peer_ip)
        write_to_dynamodb(peer_az, latency, throughput)
    if successor_queue != 'DONE':
        trigger_next_az(successor_queue)
    else:
        trigger_done()