-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
130 lines (110 loc) · 4.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import string
import random
import asyncio
import redis
# PostgreSQL v3 wire-protocol constants used by the server below.

# Request codes found in bytes 4..8 of an untyped startup-phase packet.
SSLRequestCode = (80877103).to_bytes(4, byteorder="big")    # b'\x04\xd2\x16\x2f'
StartupMessageCode = (196608).to_bytes(4, byteorder="big")  # protocol 3.0: b'\x00\x03\x00\x00'

# One-byte answer to SSLRequest: 'N' means "SSL not supported".
NoSSL = b'N'

# AuthenticationRequest ('R', length 8) with sub-code 0 (OK) / 3 (cleartext password).
AuthenticationOk = b'R' + (8).to_bytes(4, byteorder="big") + (0).to_bytes(4, byteorder="big")
AuthenticationCleartextPassword = b'R' + (8).to_bytes(4, byteorder="big") + (3).to_bytes(4, byteorder="big")

# ReadyForQuery ('Z', length 5) with transaction status 'I' (idle).
ReadyForQuery = b'Z' + (5).to_bytes(4, byteorder="big") + b'I'

# Type byte of a simple Query message, kept as an int to compare with data[0].
Query = 0x51  # ord('Q')
def random_stream(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a string of *size* characters picked at random from *chars*.

    The result is used as the name of the temporary Redis stream that
    receives SELECT results.
    """
    letters = []
    for _ in range(size):
        letters.append(random.choice(chars))
    return "".join(letters)
def CommandComplete(tag):
    """Build a PostgreSQL CommandComplete ('C') message for *tag*.

    Layout: Byte1('C'), Int32 length (counting itself but not the type byte),
    then the tag as a NUL-terminated UTF-8 string.

    :param tag: command tag, e.g. ``"INSERT 0 1"`` or ``"SELECT 3"``.
    :returns: the encoded message as ``bytes``.
    """
    # Encode first: the length field counts bytes, not characters.
    body = tag.encode("utf-8") + b'\x00'
    # A full Int32 (previously a single byte, which overflowed for tags of
    # 251+ characters); +4 for the length field itself.
    length = (len(body) + 4).to_bytes(4, byteorder="big")
    return b'C' + length + body
def RowDescription(rows):
    """Build a PostgreSQL RowDescription ('T') message.

    :param rows: iterable of ``(type, name)`` pairs, one per result column.
        ``"int"`` columns are advertised with a fixed size of 8 bytes; any
        other type is advertised as variable length (-1).
    :returns: the encoded message as ``bytes``.

    Table OID, column attribute number, data type OID and type modifier are
    all sent as 0, and every column uses text format (format code 0).
    """
    fixedInt = (8).to_bytes(2, byteorder="big", signed=True)
    # Previously only "int"/"string" set dataSize, so any other type raised
    # UnboundLocalError or reused the previous column's size.
    variable = (-1).to_bytes(2, byteorder="big", signed=True)
    fields = []
    for rowType, rowName in rows:
        dataSize = fixedInt if rowType == "int" else variable
        fields.append(
            rowName.encode("utf-8") + b'\x00'   # field name, NUL-terminated
            + bytes(4)                          # table OID
            + bytes(2)                          # column attribute number
            + bytes(4)                          # data type OID
            + dataSize                          # data type size
            + bytes(4)                          # type modifier
            + (0).to_bytes(2, byteorder="big")  # format code: text
        )
    body = b"".join(fields)
    # Length counts itself (4) and the Int16 field count (2), not the type byte.
    totalLenBytes = (len(body) + 4 + 2).to_bytes(4, byteorder="big")
    totalFieldsBytes = len(fields).to_bytes(2, byteorder="big")
    return b'T' + totalLenBytes + totalFieldsBytes + body
def DataRow(row):
    """Build a PostgreSQL DataRow ('D') message for one result row.

    :param row: mapping of column keys (``b"type:name"`` in RediSQL result
        streams) to column values. Only the values are used, in mapping
        order. ``bytes`` values are sent verbatim in text format; any other
        value is sent as its ``str()`` encoded as UTF-8; ``None`` is sent as
        SQL NULL.
    :returns: the encoded message as ``bytes``.

    NOTE(review): if RediSQL marks NULL columns by key type rather than by
    value, they should be mapped to ``None`` (length -1) — confirm.
    """
    columns = []
    for value in row.values():
        if value is None:
            # Length -1 marks NULL; no value bytes follow.
            columns.append((-1).to_bytes(4, byteorder="big", signed=True))
            continue
        raw = bytes(value) if isinstance(value, (bytes, bytearray)) else str(value).encode("utf-8")
        # Int32 byte length (excluding itself) then the raw bytes. Column
        # values are not NUL-terminated, and every column must be sent
        # (the old code dropped non-"int" columns but still counted them).
        columns.append(len(raw).to_bytes(4, byteorder="big") + raw)
    body = b"".join(columns)
    # Length counts itself (4) and the Int16 column count (2), not the type byte.
    totalLenBytes = (len(body) + 4 + 2).to_bytes(4, byteorder="big")
    totalFieldsBytes = len(columns).to_bytes(2, byteorder="big")
    return b'D' + totalLenBytes + totalFieldsBytes + body
class PostgresProtocol(asyncio.Protocol):
    """Minimal PostgreSQL (wire protocol v3) front end backed by RediSQL.

    Implements just enough of the protocol for a client to connect without
    SSL or a password and send simple ('Q') queries. Every statement is
    forwarded to the RediSQL database named ``self.db`` through a
    ``redis.Redis()`` client created with default connection settings.

    NOTE(review): the redis client is synchronous, so each query blocks the
    event loop (and every other connection) until Redis answers.
    """

    # GSSENCRequest code: a client that may use GSSAPI encryption sends this
    # during startup and waits for a one-byte answer ('N' declines).
    _GSSENC_REQUEST_CODE = (80877104).to_bytes(4, byteorder="big")
    # EmptyQueryResponse: the reply to a query string with no statement in it.
    _EMPTY_QUERY_RESPONSE = b'I\x00\x00\x00\x04'
    # Type byte of the Terminate message a client sends before disconnecting.
    _TERMINATE = ord('X')

    def __init__(self):
        self.redis = redis.Redis()
        self.state = "initial"
        self.db = "DB"
        self.transport = None
        # Bytes received but not yet forming a complete protocol message:
        # TCP may split one message across reads or deliver several at once.
        self._buffer = bytearray()

    @staticmethod
    def _error_response(message):
        """Return an ErrorResponse ('E') with severity ERROR and *message*.

        Each field is a type byte followed by a NUL-terminated string; the
        field list ends with an extra NUL. SQLSTATE XX000 is PostgreSQL's
        generic ``internal_error`` code.
        """
        text = message.replace('\x00', '').encode("utf-8", "replace")
        fields = b'SERROR\x00' + b'CXX000\x00' + b'M' + text + b'\x00' + b'\x00'
        return b'E' + (len(fields) + 4).to_bytes(4, byteorder="big") + fields

    def _execute_query(self, query):
        """Run one simple-protocol *query* and write the complete reply.

        The reply always ends with ReadyForQuery. Errors raised by the redis
        client (``redis.RedisError``) are reported as an ErrorResponse rather
        than escaping data_received(), which would make the asyncio transport
        close the connection.
        """
        tokens = query.split(None, 1)
        if not tokens:
            self.transport.write(self._EMPTY_QUERY_RESPONSE)
            self.transport.write(ReadyForQuery)
            return
        # Split on any whitespace so "  SELECT" or "select\n..." are recognised;
        # drop a trailing ';' from one-word statements such as "BEGIN;".
        firstToken = tokens[0].rstrip(';')
        command = firstToken.upper()
        try:
            if command == "INSERT":
                result = self.redis.execute_command("REDISQL.EXEC", self.db, query)
                numberInserted = result[1]
                self.transport.write(CommandComplete("INSERT 0 " + str(numberInserted)))
            elif command == "SELECT":
                self._execute_select(query)
            else:
                self.redis.execute_command("REDISQL.EXEC", self.db, query)
                self.transport.write(CommandComplete(firstToken))
        except redis.RedisError as err:
            self.transport.write(self._error_response(str(err)))
        self.transport.write(ReadyForQuery)

    def _execute_select(self, query):
        """Run a SELECT via a temporary stream; write rows and CommandComplete."""
        stream = random_stream()
        try:
            self.redis.execute_command("REDISQL.QUERY.INTO", stream, self.db, query)
            # One XREAD returns every entry; the first entry's keys
            # ("type:name") also describe the columns.
            streamResult = self.redis.xread({stream: "0"})
        finally:
            # The stream is only a transfer buffer; delete it so temporary
            # keys do not accumulate in Redis.
            self.redis.delete(stream)
        entries = streamResult[0][1] if streamResult else []
        columns = []
        if entries:
            for key in entries[0][1]:
                text = key.decode("utf-8") if isinstance(key, bytes) else key
                rowType, _, rowName = text.partition(':')
                columns.append((rowType.strip(), rowName.strip()))
        # NOTE(review): with no rows there is no column information, so an
        # empty RowDescription is sent.
        self.transport.write(RowDescription(columns))
        for _, row in entries:
            self.transport.write(DataRow(row))
        # The SELECT command tag carries the number of rows returned.
        self.transport.write(CommandComplete("SELECT " + str(len(entries))))

    def _next_message(self):
        """Remove and return the next complete message from the buffer.

        Returns ``None`` when the buffer does not yet hold a whole message,
        or when the declared length is invalid (the connection is then closed).
        """
        # Startup-phase packets begin with the Int32 length; later messages
        # have a one-byte type in front of it.
        offset = 0 if self.state == "initial" else 1
        header = offset + 4
        if len(self._buffer) < header:
            return None
        # The length field counts itself but not the type byte.
        length = int.from_bytes(self._buffer[offset:header], "big")
        if length < 4:
            # The stream is corrupt and cannot be resynchronised; drop the client.
            self._buffer.clear()
            self.transport.close()
            return None
        end = offset + length
        if len(self._buffer) < end:
            return None
        message = bytes(self._buffer[:end])
        del self._buffer[:end]
        return message

    def _reply(self, data):
        """Handle one complete client message *data* and write any response.

        Before startup completes, messages carry no type byte and are told
        apart by the Int32 request code at bytes 4..8; afterwards the first
        byte is the message type.
        """
        if self.state == "initial":
            code = data[4:8]
            if code in (SSLRequestCode, self._GSSENC_REQUEST_CODE):
                # Decline encryption; the client continues in plaintext.
                self.transport.write(NoSSL)
            elif code == StartupMessageCode:
                # we don't require a password
                self.transport.write(AuthenticationOk)
                # good to go for the first query!
                self.transport.write(ReadyForQuery)
                self.state = "readyForQuery"
            return
        messageType = data[0]
        if messageType == Query:
            # Payload after the 5-byte header is the SQL text plus a NUL.
            query = data[5:-1].decode("utf-8")
            self._execute_query(query)
        elif messageType == self._TERMINATE:
            self.transport.close()
        # NOTE(review): other message types (e.g. the extended-query
        # protocol) are ignored, as before.

    def connection_made(self, transport):
        print('New Connection Made')
        self.transport = transport

    def data_received(self, data):
        print('Data received: {!r}'.format(bytearray(data)))
        self._buffer += data
        # Handle every complete message now buffered; stop once the
        # connection is closing (Terminate or malformed input).
        while not self.transport.is_closing():
            message = self._next_message()
            if message is None:
                break
            self._reply(message)

    def connection_lost(self, exc):
        print('The server closed the connection')
        print('Stop the event loop')
async def _serve(host='127.0.0.1', port=8888):
    """Accept PostgresProtocol connections on *host*:*port* until cancelled."""
    loop = asyncio.get_running_loop()
    server = await loop.create_server(PostgresProtocol, host, port)
    async with server:
        await server.serve_forever()


# asyncio.get_event_loop() with no running loop is deprecated (and raises on
# Python 3.14+); asyncio.run() creates, runs and finally closes the loop.
asyncio.run(_serve())