-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathapp.py
61 lines (50 loc) · 1.85 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from starlette.applications import Starlette
from starlette.responses import UJSONResponse
import gpt_2_simple as gpt2
import tensorflow as tf
import uvicorn
import os
import gc
# Number of generations served since the model was last (re)loaded.
generate_count = 0

# Sent with every response so that browsers on other origins may call the API.
response_header = {'Access-Control-Allow-Origin': '*'}

app = Starlette(debug=False)

# One single-threaded TensorFlow session, with the GPT-2 model loaded into it
# at import time; request handlers reuse this session.
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)
@app.route('/', methods=['GET', 'POST', 'HEAD'])
async def homepage(request):
    """Generate text with the loaded GPT-2 model and return it as JSON.

    Generation parameters are read from the query string on GET and from the
    JSON body on POST. Recognised keys and their defaults:

    * ``length`` (int, default 1023)
    * ``temperature`` (float, default 0.7)
    * ``top_k`` (int, default 0)
    * ``top_p`` (float, default 0)
    * ``prefix`` (str, default ``''``; only the first 500 characters are used)
    * ``truncate`` (default ``None``)
    * ``include_prefix`` (default ``True``; any value whose string form is
      not ``'true'``, case-insensitively, is treated as false)

    HEAD requests are answered immediately with an empty ``text`` and never
    invoke the model.

    Returns:
        A ``UJSONResponse`` with body ``{'text': <generated text>}`` and the
        cross-origin header attached.
    """
    global generate_count
    global sess

    # HEAD needs no generation: answer right away.
    if request.method == 'HEAD':
        return UJSONResponse({'text': ''},
                             headers=response_header)

    if request.method == 'POST':
        params = await request.json()
    else:
        # The route only admits GET, POST and HEAD, so this branch is GET.
        params = request.query_params

    # A JSON body may carry an explicit ``"prefix": null``; treat that like a
    # missing prefix instead of failing on ``None[:500]``.
    prefix = params.get('prefix') or ''

    text = gpt2.generate(sess,
                         length=int(params.get('length', 1023)),
                         temperature=float(params.get('temperature', 0.7)),
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix=prefix[:500],
                         truncate=params.get('truncate', None),
                         include_prefix=str(params.get(
                             'include_prefix', True)).lower() == 'true',
                         return_as_list=True
                         )[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    return UJSONResponse({'text': text},
                         headers=response_header)
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 8080 when it is unset.
    port = int(os.environ.get('PORT', 8080))
    uvicorn.run(app, host='0.0.0.0', port=port)