-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path generate_tests.py
131 lines (108 loc) · 3.82 KB
/
generate_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import ast
import astunparse
# Top-level module names to rename in the generated tests.
module_translations = dict(requests='async_requests')
# Class names to rename to their asyncio counterparts.
class_translations = dict(
    HTTPAdapter='AsyncHTTPAdapter',
    HTTPDigestAuth='AsyncHTTPDigestAuth',
    Response='AsyncResponse',
    Session='AsyncSession',
)
# Test functions to drop from the generated module.
tests_to_skip = ['test_response_is_iterable']
# Method names whose calls are run to completion on the event loop.
methods_to_yield_from = [
    'get', 'head', 'put', 'patch', 'post', 'send', 'request',
]
# Replacement source for ``<lhs> = <call>``; format with ``lhs`` and ``rhs``.
assign_template = '\n'.join((
    'task = asyncio.Task({rhs})',
    'loop = asyncio.get_event_loop()',
    '{lhs} = loop.run_until_complete(task)',
))
# Replacement source for a bare ``<call>`` statement; format with ``expr``.
expr_template = '\n'.join((
    'task = asyncio.Task({expr})',
    'loop = asyncio.get_event_loop()',
    'loop.run_until_complete(task)',
))
class TestRequestsTransformer(ast.NodeTransformer):
    """Rewrite the synchronous ``requests`` test-suite AST into an asyncio variant.

    The transformer:

    * renames modules listed in ``module_translations`` (in imports and bare
      names) and classes listed in ``class_translations`` (in imports, bare
      names and attribute accesses);
    * emits ``import asyncio`` together with the first ``import`` statement
      it visits;
    * replaces calls to the methods in ``methods_to_yield_from`` with the
      statements produced by ``assign_template`` / ``expr_template``;
    * drops test functions whose names appear in ``tests_to_skip``.
    """

    def __init__(self):
        super().__init__()
        # Guards against emitting ``import asyncio`` more than once.
        self.imported_asyncio = False

    @staticmethod
    def _translate_module(dotted_name):
        """Return *dotted_name* with its top-level package translated, if mapped."""
        head, sep, tail = dotted_name.partition('.')
        if head in module_translations:
            return module_translations[head] + sep + tail
        return dotted_name

    @staticmethod
    def _is_async_call(value):
        """Return True if *value* is an ``<obj>.<method>(...)`` call to run on the loop."""
        if not isinstance(value, ast.Call):
            return False
        func = value.func
        if not isinstance(func, ast.Attribute):
            return False
        if func.attr not in methods_to_yield_from:
            return False
        # os.environ.get shares the ``get`` name but is a plain synchronous call.
        return astunparse.unparse(func).strip() != 'os.environ.get'

    def visit_Import(self, node):
        """Translate every imported module (keeping aliases) and add ``import asyncio`` once."""
        names = [
            ast.alias(self._translate_module(alias.name), alias.asname)
            for alias in node.names
        ]
        new_node = ast.copy_location(ast.Import(names), node)
        if self.imported_asyncio:
            return new_node
        self.imported_asyncio = True
        # Returning a list splices both statements into the enclosing body.
        return [ast.Import([ast.alias('asyncio', None)]), new_node]

    def visit_ImportFrom(self, node):
        """Translate the source module and any mapped class names, keeping aliases."""
        module = node.module
        # ``module`` is None for relative imports such as ``from . import x``.
        if module is not None:
            module = self._translate_module(module)
        names = [
            ast.alias(class_translations.get(alias.name, alias.name), alias.asname)
            for alias in node.names
        ]
        return ast.copy_location(ast.ImportFrom(module, names, node.level), node)

    def visit_Name(self, node):
        """Rename bare references to translated modules or classes."""
        if node.id in module_translations:
            new_id = module_translations[node.id]
        elif node.id in class_translations:
            new_id = class_translations[node.id]
        else:
            return node
        return ast.copy_location(ast.Name(new_id, node.ctx), node)

    def visit_Attribute(self, node):
        """Rename ``<obj>.<Class>`` attribute accesses after translating ``<obj>``."""
        self.generic_visit(node)
        if node.attr not in class_translations:
            return node
        return ast.copy_location(
            ast.Attribute(node.value, class_translations[node.attr], node.ctx), node
        )

    def visit_Expr(self, node):
        """Replace a bare async call statement with the ``expr_template`` statements."""
        self.generic_visit(node)
        if not self._is_async_call(node.value):
            return node
        expr = astunparse.unparse(node.value).strip()
        # Return the parsed statements (not the Module wrapper) so they are
        # spliced into the enclosing body as NodeTransformer expects.
        return ast.parse(expr_template.format(expr=expr)).body

    def visit_Assign(self, node):
        """Replace ``<targets> = <async call>`` with the ``assign_template`` statements."""
        self.generic_visit(node)
        if not self._is_async_call(node.value):
            return node
        # Unparse every target so attribute, subscript, tuple and chained
        # (``a = b = ...``) assignments are rewritten instead of skipped.
        lhs = ' = '.join(astunparse.unparse(target).strip() for target in node.targets)
        rhs = astunparse.unparse(node.value).strip()
        return ast.parse(assign_template.format(lhs=lhs, rhs=rhs)).body

    def visit_FunctionDef(self, node):
        """Drop skipped tests; otherwise transform the function body."""
        if node.name in tests_to_skip:
            # Returning None removes the node from the tree.
            return None
        self.generic_visit(node)
        return node
def main(source_path='test_requests.py', target_path='test_async_requests.py'):
    """Translate the test module at *source_path* and write it to *target_path*.

    Both files are read and written as UTF-8 (the default encoding of Python
    source files) instead of the platform's locale encoding.
    """
    with open(source_path, encoding='utf-8') as f:
        source = f.read()
    tree = TestRequestsTransformer().visit(ast.parse(source))
    with open(target_path, 'w', encoding='utf-8') as f:
        f.write(astunparse.unparse(tree))
    print('Generated {}'.format(target_path))
    # Tests known to need hand edits after generation.
    print('Manual tweaks required for:')
    print(' test_basicauth_with_netrc')
    print(' test_prepared_request_hook')
    print(' test_DIGEST_STREAM')


if __name__ == '__main__':
    main()