-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathclient.py
117 lines (93 loc) · 4.11 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import multiprocessing as mp
import numpy as np
import cv2
import torch
from PIL import Image
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from flask import Flask, request, Response, render_template, jsonify, send_file
import io
# Flask application object; static assets are served under the "/static" URL prefix.
app = Flask(import_name=__name__, static_url_path='/static')
# BT.601 RGB -> YUV conversion matrix (each row yields Y, U, V from R, G, B
# scaled to [0, 1]). Previously referenced below but never defined, so the
# "YUV-BT.601" branch raised NameError.
_M_RGB2YUV = [
    [0.299, 0.587, 0.114],
    [-0.14713, -0.28886, 0.436],
    [0.615, -0.51499, -0.10001],
]


def convert_PIL_to_numpy(image, format):
    """Convert a PIL image to a numpy array in the requested format.

    Args:
        image: a PIL image (anything exposing ``convert(mode)`` that numpy
            can turn into an array).
        format: target format. ``None`` keeps the image's own mode;
            ``"BGR"`` converts to RGB and reverses the channel order;
            ``"YUV-BT.601"`` converts to RGB, scales to [0, 1] and applies
            the BT.601 matrix (float result); ``"L"`` returns an HxWx1
            array; any other value is passed to ``image.convert`` as a mode.

    Returns:
        np.ndarray: the converted image.
    """
    if format is not None:
        # PIL has no BGR / YUV-BT.601 modes: go through RGB and finish in numpy.
        conversion_format = "RGB" if format in ("BGR", "YUV-BT.601") else format
        image = image.convert(conversion_format)
    image = np.asarray(image)

    if format == "L":
        # PIL squeezes out the channel dimension for "L"; restore it (HWC).
        image = np.expand_dims(image, -1)
    elif format == "BGR":
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)
    return image
def read_image(file, format=None):
    """Decode an image file with PIL and return it as a numpy array.

    Args:
        file: a filesystem path or a readable binary file-like object.
        format: passed to ``convert_PIL_to_numpy`` (e.g. "BGR", "RGB", "L",
            "YUV-BT.601"); ``None`` keeps the image's native mode.

    Returns:
        np.ndarray: the decoded (and optionally converted) image.
    """
    # Using the image as a context manager releases the file handle PIL
    # opens itself when given a path; the pixel data has already been copied
    # into the returned array by then. Caller-owned file objects stay open.
    with Image.open(file) as image:
        return convert_PIL_to_numpy(image, format)
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: answer any GET with the plain-text body "ok"."""
    status = "ok"
    return status


# Root page currently disabled; uncomment to serve templates/index.html.
# @app.route('/')
# def main():
#     return render_template('index.html')
def _build_predictor(config_path, score_thresh=None):
    """Build a detectron2 config and DefaultPredictor for a model-zoo entry.

    Args:
        config_path: model-zoo config path, e.g.
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml".
        score_thresh: if not None, sets MODEL.ROI_HEADS.SCORE_THRESH_TEST.

    Returns:
        tuple: ``(cfg, predictor)``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_path))
    if score_thresh is not None:
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)
    # detectron2's default config targets CUDA; fall back to CPU so the
    # service can still start on a host without a GPU.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    return cfg, DefaultPredictor(cfg)


# Built once at import time and shared by every request handled by predict().
panoptic_cfg, panopticPredictor = _build_predictor(
    "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
instance_cfg, instancePredictor = _build_predictor(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", score_thresh=0.5)
keypoint_cfg, keypointPredictor = _build_predictor(
    "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml", score_thresh=0.7)
# Upload content types accepted by predict(); "image/jpg" is non-standard
# but some clients send it.
_ALLOWED_CONTENT_TYPES = ('image/jpeg', 'image/jpg', 'image/png')


def _decode_upload(file_storage):
    """Decode an uploaded image into an HxWx3 BGR array, or return None if PIL cannot read it.

    Converting through RGB (inside convert_PIL_to_numpy) also normalises
    grayscale, palette, RGBA and CMYK uploads to three channels.
    """
    try:
        return read_image(file_storage, format='BGR')
    except OSError:
        # PIL reports unreadable/truncated data as OSError
        # (UnidentifiedImageError is a subclass).
        return None


def _render_jpeg(cfg, predictions, image_bgr):
    """Draw model *predictions* over *image_bgr* and return a rewound in-memory JPEG.

    Panoptic-segmentation outputs are drawn as panoptic segments; all other
    outputs (instance masks, keypoints) via draw_instance_predictions.
    """
    # Visualizer takes RGB, while the image fed to the model is BGR.
    visualizer = Visualizer(image_bgr[:, :, ::-1],
                            MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
                            scale=0.6)
    if 'panoptic_seg' in predictions:
        panoptic_seg, segments_info = predictions['panoptic_seg']
        vis_output = visualizer.draw_panoptic_seg_predictions(
            panoptic_seg.to('cpu'), segments_info)
    else:
        vis_output = visualizer.draw_instance_predictions(
            predictions=predictions['instances'].to('cpu'))
    buffer = io.BytesIO()
    # get_image() already returns RGB, which is what PIL expects: no flip here.
    Image.fromarray(vis_output.get_image()).save(buffer, 'JPEG', quality=95)
    buffer.seek(0)
    return buffer


@app.route('/<path>', methods=['POST'])
def predict(path):
    """Run the model selected by *path* on the uploaded image and return the rendering.

    *path* must be ``keypoint``, ``instancesegmentation`` or
    ``panopticsegmentation``. The image is read from the multipart form
    field ``file`` and must have a JPEG or PNG content type.

    Returns:
        An ``image/jpeg`` response with the predictions drawn on the image,
        or a JSON ``{'message': ...}`` body with status 400 (bad request)
        or 500 (unexpected failure).
    """
    models = {
        'keypoint': (keypoint_cfg, keypointPredictor),
        'instancesegmentation': (instance_cfg, instancePredictor),
        'panopticsegmentation': (panoptic_cfg, panopticPredictor),
    }
    # Validate cheap request properties before decoding or running a model.
    if path not in models:
        return jsonify({'message': 'path is not valid'}), 400
    cfg, predictor = models[path]

    input_file = request.files.get('file')
    if input_file is None:
        return jsonify({'message': "Missing 'file' field"}), 400
    if input_file.content_type not in _ALLOWED_CONTENT_TYPES:
        return jsonify({'message': 'Only support jpeg, jpg or png'}), 400

    try:
        image = _decode_upload(input_file)
        if image is None:
            return jsonify({'message': 'Could not decode image'}), 400
        result = _render_jpeg(cfg, predictor(image), image)
    except Exception:
        # Request boundary: keep the JSON 500 contract, but log the traceback.
        app.logger.exception('Prediction failed for path %r', path)
        return jsonify({'message': 'Server error'}), 500
    return send_file(result, mimetype='image/jpeg')
if __name__ == "__main__":
    # "spawn" start method, presumably carried over from detectron2's demo
    # script; force=True overrides any start method already set.
    mp.set_start_method("spawn", force=True)
    # Serve on all interfaces, port 80 (a privileged port on typical Unix hosts).
    app.run(port=80, host="0.0.0.0")