forked from JiahuiYu/generative_inpainting
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path: server.py
102 lines (76 loc) · 2.79 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import cv2
import time
import os
import threading
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
from inpainting import inpaint
from queue import Queue, Empty
from inpaint_model import InpaintCAModel
from flask import Flask, request, render_template, send_file, jsonify
# Flask application; HTML templates (e.g. index.html) are looked up in ./templates/.
app = Flask(__name__, template_folder='./templates/')
# --- Model / TensorFlow setup, performed once at import time -----------------
# Checkpoint directories for the two released models. inpainting() selects
# Places2 when the form field model == "places2" and CelebA-HQ otherwise.
CHECKPOINT_DIR_PLACES2 = "./model_logs/release_places2_256_deepfill_v2"
CHECKPOINT_DIR_CELEBA = "./model_logs/release_celeba_hq_256_deepfill_v2"
# Single shared model instance, passed to inpaint() for every request.
model = InpaintCAModel()
# allow_growth: let TensorFlow claim GPU memory incrementally rather than
# reserving it all up front.
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
# NOTE(review): `sess` is not referenced in the code visible here; presumably it
# is created for its side effects (or picked up by inpaint()) — confirm before
# removing.
sess = tf.Session(config=sess_config)
# Default graph; run() clears all of its collections after each successful inpaint.
g = tf.get_default_graph()
# --- Request queue shared by the HTTP handler and the background worker -----
# Filled by inpainting(), drained by handle_requests_by_batch().
requests_queue = Queue()
# Number of queued requests the worker waits for before processing; also the
# backlog size above which inpainting() answers 429.
BATCH_SIZE: int = 1
# Polling period in seconds: the worker's queue.get() timeout and the handler's
# sleep between checks for a finished result.
CHECK_INTERVAL: float = 0.1
def handle_requests_by_batch():
    """Background worker loop: drain ``requests_queue`` and fill in results.

    Runs forever. Each queued item is a dict whose ``'input'`` holds
    ``[image, mask, checkpoint]`` (built by the ``/inpainting`` handler). The
    loop stores the return value of ``run(...)`` under the item's ``'output'``
    key, which the handler polls for.

    Every job is guaranteed to receive an ``'output'`` (``"error"`` on
    failure). The handler busy-waits on that key, so a job left without one,
    or a dead worker thread, would hang requests indefinitely.
    """
    import logging

    while True:
        batch = []
        # Wait, in CHECK_INTERVAL-second slices, until a full batch is queued.
        while len(batch) < BATCH_SIZE:
            try:
                batch.append(requests_queue.get(timeout=CHECK_INTERVAL))
            except Empty:
                continue
        # `job` rather than `request` so Flask's module-level `request` is not shadowed.
        for job in batch:
            try:
                job['output'] = run(*job['input'])
            except Exception:
                logging.getLogger(__name__).exception("Worker failed to process a request")
                job['output'] = "error"
# Start the single background worker that services requests_queue.
# daemon=True: the worker loops forever, and a non-daemon thread would keep the
# interpreter from exiting after the web server stops (e.g. on Ctrl-C).
threading.Thread(target=handle_requests_by_batch, daemon=True).start()
@app.route("/", methods=["GET"])
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
def _decode_upload(field):
    """Decode the uploaded file in form field ``field`` as a 3-channel image.

    Returns the array produced by ``cv2.imdecode`` with ``IMREAD_COLOR``, or
    None when the upload is empty or cannot be decoded. A missing field raises
    from ``request.files[field]``, exactly as before.
    """
    data = request.files[field].read()
    if not data:
        # cv2.imdecode may raise on an empty buffer instead of returning None,
        # so short-circuit to the "undecodable" result.
        return None
    # np.frombuffer: the binary mode of np.fromstring is deprecated in NumPy.
    return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)


@app.route("/inpainting", methods=["POST"])
def inpainting():
    """Inpaint an uploaded image and return the result as a PNG.

    Expects multipart form data with file fields ``image`` and ``mask`` and a
    form field ``model``. ``"places2"`` selects the Places2 checkpoint; any
    other value selects CelebA-HQ. The work is handed to the background
    worker via ``requests_queue``, and this handler polls until the worker
    writes ``'output'``.

    Returns:
        The PNG image on success. Otherwise JSON ``{'error': ...}`` with
        status 429 when the queue backlog exceeds BATCH_SIZE, 400 when an
        upload cannot be decoded, or 500 when inpainting fails.
    """
    # Soft back-pressure only: this check and the put() below are not atomic.
    if requests_queue.qsize() > BATCH_SIZE:
        # The misspelled value is part of the existing response; kept for
        # client compatibility.
        return jsonify({'error': 'TooManyReqeusts'}), 429

    image = _decode_upload('image')
    mask = _decode_upload('mask')
    if image is None or mask is None:
        # Reject bad uploads here instead of letting the worker fail with a 500.
        return jsonify({'error': 'Invalid image'}), 400

    model_name = request.form['model']
    checkpoint = CHECKPOINT_DIR_PLACES2 if model_name == "places2" else CHECKPOINT_DIR_CELEBA

    req = {'input': [image, mask, checkpoint]}
    requests_queue.put(req)
    # The worker thread sets req['output'] when it is done; poll for it.
    while 'output' not in req:
        time.sleep(CHECK_INTERVAL)

    result = req['output']
    if result == "error":
        return jsonify({'error': 'Server error'}), 500
    return send_file(result, mimetype="image/png")
def run(image, mask, checkpoint):
    """Inpaint ``image`` under ``mask`` using ``checkpoint`` and encode as PNG.

    Args:
        image: input image as decoded by the caller (``cv2.imdecode``).
        mask: mask image, decoded the same way.
        checkpoint: checkpoint directory, passed through to ``inpaint``.

    Returns:
        A ``BytesIO`` rewound to position 0 containing the PNG bytes, or the
        string ``"error"`` if any step fails. Failures are logged with their
        traceback rather than silently discarded.
    """
    import logging

    try:
        output = inpaint(image, mask, model, checkpoint)
        # The output is treated as OpenCV BGR order; PIL expects RGB.
        rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        # Empty every collection of the shared graph after each request
        # (presumably so per-request entries don't pile up — confirm intent).
        for key in g.get_all_collection_keys():
            g.clear_collection(key)
        buf = BytesIO()
        Image.fromarray(rgb).save(buf, "PNG")
        buf.seek(0)
        return buf
    except Exception:
        logging.getLogger(__name__).exception("Inpainting failed")
        return "error"
@app.route("/healthz", methods=["GET"])
def checkHealth():
    """Health-check endpoint: always answers HTTP 200 with the body "ok"."""
    status_code = 200
    return "ok", status_code
if __name__ == "__main__":
    # Built-in server entry point. Listen on all interfaces on port 80 (usually
    # needs elevated privileges). threaded=True handles each request in its own
    # thread, so one request polling for its result doesn't block the others.
    server_options = {'host': '0.0.0.0', 'port': 80, 'threaded': True}
    app.run(**server_options)