-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
executable file
·240 lines (204 loc) · 8.28 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python3
#import matplotlib.pyplot as plt
import argparse as ap
import numpy as np
import cv2 as cv
import os
import sys
class CustomFormatter(ap.HelpFormatter):
    """Help formatter that prints an option's flags once, followed by one argument string.

    The stock ``HelpFormatter`` repeats the argument after every flag
    (``-s SAVE, --save SAVE``); this one renders ``-s, --save SAVE`` instead.
    """

    def _format_action_invocation(self, action):
        # Positional arguments have no flags: show just their metavar.
        if not action.option_strings:
            (metavar,) = self._metavar_formatter(action, action.dest)(1)
            return metavar

        flags = list(action.option_strings)
        # Options that consume values get the formatted argument string
        # appended once, after the last flag; pure switches (nargs == 0) don't.
        if action.nargs != 0:
            args_string = self._format_args(action, action.dest.upper())
            flags[-1] = flags[-1] + ' %s' % args_string
        return ', '.join(flags)
# Parser Arguments
parser = ap.ArgumentParser(description='Cascade Classifier', formatter_class=CustomFormatter)
parser.add_argument("-s", "--save", metavar='', help="specify output name")
parser.add_argument("-c", "--cas", metavar='', help="specify specific trained cascade", default="./stage_outputs/cascade.xml")
parser.add_argument("-i", "--img", metavar='', help="specify image to be classified")
parser.add_argument("-d", "--dir", metavar='', help="specify directory of images to be classified")
parser.add_argument("-v", "--vid", metavar='', help="specify video to be classified")
# --cam is an on/off switch, but it used to demand a value, so a bare `-w` was
# rejected. nargs='?' accepts a bare `-w` while still allowing the old `-w X` form.
# const makes a bare `-w` non-None, which is what the dispatch tests for.
parser.add_argument("-w", "--cam", nargs='?', const='0', help="enable camera access for classification")
parser.add_argument("-f", "--fps", help="enable frames text (TODO)", action="store_true")
parser.add_argument("-o", "--circle", help="enable circle detection", action="store_true")
parser.add_argument("-z", "--scale", metavar='', help="decrease video scale by scale factor", type=int, default=1)
parser.add_argument("-t", "--track", metavar='', help="select tracking algorithm [KCF, CSRT, MEDIANFLOW]", choices=['KCF', 'CSRT', 'MEDIANFLOW'])
args = parser.parse_args(sys.argv[1:])
# Frame sizes are divided by --scale, so 0 or negative values cannot work;
# reject them up front with a usage message instead of crashing mid-run.
if args.scale < 1:
    parser.error("--scale must be a positive integer")

# Load the trained cascade
cascade = cv.CascadeClassifier()
if not cascade.load(args.cas):
    # Nothing can be classified without a cascade: report the path actually
    # tried and exit with a failure status (the old exit(0) signalled success).
    print("Can't find cascade file '%s'. Check the path given with -c/--cas "
          "(default ./stage_outputs/cascade.xml)." % args.cas, file=sys.stderr)
    sys.exit(1)
def plot():
    """Placeholder for future plotting support; does nothing and returns None."""
def detect_circles(src, min_radius=10, max_radius=15, param2=15):
    """Find circles in a BGR image with the Hough transform and draw them onto it.

    Drawing is done in place on ``src``, which is also returned.

    Args:
        src: BGR image, e.g. a region cropped out of a cascade detection.
        min_radius: smallest circle radius, in pixels, passed to cv.HoughCircles.
        max_radius: largest circle radius, in pixels, passed to cv.HoughCircles.
        param2: accumulator threshold for cv.HOUGH_GRADIENT; lower values
            accept more (and more spurious) circles. The defaults suit small
            cropped regions. A setting previously noted for whole images was
            param2=40, maxRadius=40.

    Returns:
        ``src`` with a small centre dot and an outline drawn for every circle
        found (unchanged if none were found or ``src`` is empty).
    """
    # An empty crop would make cvtColor raise; there is nothing to search anyway.
    if src is None or src.size == 0:
        return src
    img_gray = cv.cvtColor(src, cv.COLOR_BGR2GRAY)
    # Median blur suppresses speckle noise before the gradient-based search.
    img_blur = cv.medianBlur(img_gray, 5)
    rows = img_blur.shape[0]
    # rows / 3 is the minimum distance allowed between detected centres.
    circles = cv.HoughCircles(img_blur, cv.HOUGH_GRADIENT, 1, rows / 3,
                              param1=100, param2=param2,
                              minRadius=min_radius, maxRadius=max_radius)
    if circles is None:
        return src
    # circles[0] holds one (x, y, radius) float triple per circle. Round them
    # and pass plain Python ints, which every OpenCV version accepts.
    for cx, cy, radius in np.around(circles[0]).astype(int):
        center = (int(cx), int(cy))
        # circle center
        cv.circle(src, center, 1, (0, 100, 100), 3)
        # circle outline
        cv.circle(src, center, int(radius), (255, 0, 255), 3)
    return src
def choose_tracker(name=None):
    """Create a new, un-initialised OpenCV tracker for the requested algorithm.

    Args:
        name: 'KCF', 'CSRT' or 'MEDIANFLOW'. Defaults to the --track option.

    Returns:
        A fresh tracker object; call ``init(frame, box)`` before ``update``.

    Raises:
        KeyError: if ``name`` is not one of the supported algorithms.
        AttributeError: if the installed OpenCV build offers no factory for it.
    """
    if name is None:
        name = args.track
    # Map to factory *names* so that only the requested tracker is built.
    # Building all three up front made every choice fail on OpenCV builds
    # that lack one factory in the main namespace (MedianFlow, in particular,
    # is only exposed under cv.legacy in recent 4.x releases).
    factory_names = {
        'KCF': 'TrackerKCF_create',
        'CSRT': 'TrackerCSRT_create',
        'MEDIANFLOW': 'TrackerMedianFlow_create',
    }
    factory_name = factory_names[name]
    # Prefer the main namespace and fall back to cv.legacy when present.
    for namespace in (cv, getattr(cv, 'legacy', None)):
        factory = getattr(namespace, factory_name, None)
        if factory is not None:
            return factory()
    raise AttributeError("This OpenCV build has no %s (is opencv-contrib installed?)" % factory_name)
def tracking(vid, tracker, frame=None):
    """Advance ``tracker`` by one frame and draw its current box onto that frame.

    Args:
        vid: cv.VideoCapture to read the next frame from when ``frame`` is None.
        tracker: an initialised OpenCV tracker.
        frame: optional frame the caller has already read. When given, ``vid``
            is not read, so the caller's annotations are kept and no video
            frame is skipped. When omitted, the next frame is read from ``vid``
            and scaled by --scale (the original behaviour).

    Returns:
        The frame, with a green box and centre dot drawn if the tracker still
        has its target; None if ``vid`` had no frame left to read.
    """
    if frame is None:
        ok, frame = vid.read()
        if not ok:
            # End of stream: don't hand None to scale().
            return None
        frame = scale(frame, args.scale)
    ok, roi = tracker.update(frame)
    if ok:
        p1 = (int(roi[0]), int(roi[1]))
        p2 = (int(roi[0] + roi[2]), int(roi[1] + roi[3]))
        cv.rectangle(frame, p1, p2, (0, 255, 0), 2, 1)
        center = (int(roi[0] + (roi[2] / 2)), int(roi[1] + (roi[3] / 2)))
        cv.circle(frame, center, 3, (0, 255, 0), 3)
    return frame
def save(frame, fps=30.0):
    """Open an XVID video writer at '<--save>.avi' sized to match ``frame``.

    Args:
        frame: a sample frame; only its height, width and number of dimensions
            are used. Frames later written to the writer should have the same size.
        fps: frame rate recorded in the output file (previously fixed at 30).

    Returns:
        The opened cv.VideoWriter; the caller is responsible for release().
    """
    fourcc = cv.VideoWriter_fourcc(*'XVID')
    # shape[:2] works for both (h, w, channels) and grayscale (h, w) frames.
    height, width = frame.shape[:2]
    # isColor must be False for single-channel frames; True (the default) otherwise.
    is_color = frame.ndim == 3
    return cv.VideoWriter(args.save + '.avi', fourcc, fps, (width, height), is_color)
def get_roi(frame):
    """Return the first cascade detection in ``frame`` as (x, y, w, h), or [] if none.

    Used to seed a tracker with an initial bounding box. ``frame`` is not modified.
    """
    # Detect on a grayscale copy with a light 3x3 Gaussian blur applied.
    blurred = cv.GaussianBlur(cv.cvtColor(frame, cv.COLOR_BGR2GRAY), (3, 3), 0)
    detections = cascade.detectMultiScale(blurred)
    if not len(detections):
        return []
    first = detections[0]
    return (first[0], first[1], first[2], first[3])
def get_cascade(frame):
    """Run the cascade on ``frame`` and draw every detection onto it in red.

    Each detection gets a bounding rectangle plus a small dot at its centre.
    ``frame`` is modified in place and also returned.
    """
    # Detection runs on a grayscale copy with a light 3x3 Gaussian blur applied.
    gray = cv.GaussianBlur(cv.cvtColor(frame, cv.COLOR_BGR2GRAY), (3, 3), 0)
    red = (0, 0, 255)
    for x, y, w, h in cascade.detectMultiScale(gray):
        cv.rectangle(frame, (x, y), (x + w, y + h), red, 2)
        cv.circle(frame, (int(x + (w / 2)), int(y + (h / 2))), 3, red, 3)
    return frame
def scale(frame, scale_factor):
    """Return ``frame`` resized down by ``scale_factor`` in both dimensions.

    Args:
        frame: image array of shape (h, w) or (h, w, channels).
        scale_factor: positive divisor applied to height and width (1 keeps the size).

    Returns:
        The resized image; each side is at least 1 pixel.

    Raises:
        ValueError: if ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError("scale_factor must be positive, got %r" % (scale_factor,))
    # shape[:2] also accepts grayscale frames, which a 3-way unpack rejects.
    height, width = frame.shape[:2]
    # Clamp to 1 px: cv.resize rejects a zero-sized destination.
    scaled_height = max(1, int(height / scale_factor))
    scaled_width = max(1, int(width / scale_factor))
    return cv.resize(frame, (scaled_width, scaled_height))
def img_classifier():
    """Classify the single image given by --img and display the result.

    Each cascade detection is outlined in red with a dot at its centre. With
    --circle, Hough circles are also searched for and drawn inside every
    detection, in the full image. The window stays open until a key is pressed.
    Exits with status 1 if the image cannot be read.
    """
    # Read the image, convert to gray and detect (histogram equalisation is not applied).
    img = cv.imread(args.img, cv.IMREAD_COLOR)
    if img is None:
        # imread reports a missing or undecodable file by returning None.
        print("Could not read image '%s'" % args.img, file=sys.stderr)
        sys.exit(1)
    img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    cas_object = cascade.detectMultiScale(img_gray)
    for (x, y, w, h) in cas_object:
        cv.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
        cv.circle(img, (int(x + (w / 2)), int(y + (h / 2))), 3, (0, 0, 255), 3)
        if args.circle:
            # Write the annotated crop back into the full image. Replacing
            # `img` with the crop made every later detection (whose coordinates
            # are relative to the full image) slice the wrong region.
            img[y:y + h, x:x + w] = detect_circles(img[y:y + h, x:x + w].copy())
    cv.imshow('image', img)
    cv.waitKey(0)
    cv.destroyAllWindows()
def dir_classifier():
    """Classify every readable image in --dir, showing each one until a key is pressed.

    Files are visited in sorted name order; anything cv.imread cannot decode
    is skipped. Each cascade detection is outlined in red with a centre dot,
    and with --circle, Hough circles are drawn inside each detection. Each
    image is shown in a window titled with its file name, and that window is
    closed before the next image is shown.

    Returns:
        List of the annotated images, in the order they were shown.
    """
    imgs = []
    # sorted(): os.listdir order is arbitrary.
    for filename in sorted(os.listdir(args.dir)):
        img = cv.imread(os.path.join(args.dir, filename))
        if img is None:
            continue  # not an image, or unreadable
        img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        for (x, y, w, h) in cascade.detectMultiScale(img_gray):
            cv.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv.circle(img, (int(x + (w / 2)), int(y + (h / 2))), 3, (0, 0, 255), 3)
            if args.circle:
                # Draw into the full image so later detections still index the right region.
                img[y:y + h, x:x + w] = detect_circles(img[y:y + h, x:x + w].copy())
        window = str(filename)
        cv.imshow(window, img)
        cv.waitKey(0)
        # Close each window so they don't accumulate over a large directory.
        cv.destroyWindow(window)
        imgs.append(img)
    return imgs
def vid_classifier():
    """Classify each frame of the video given by --vid and display it live.

    For each frame:
      1. Scale it by --scale.
      2. Draw all cascade detections in red.
      3. With --track, draw the tracker's box in green. The tracker is seeded
         from the first frame that has a cascade detection.
      4. With --circle, draw Hough circles inside the first cascade detection.

    Frames are shown in a 'video' window and, with --save, written to
    '<save>.avi'. Press 'q' to stop early. The capture, the writer and the
    windows are released however the loop ends. Exits with status 1 if the
    video cannot be opened or read, or if it ends before a tracker could be seeded.
    """
    vid = cv.VideoCapture(args.vid)
    if not vid.isOpened():
        print("Could not open video")
        sys.exit(1)
    out = None
    try:
        # Read the first frame; check the read before touching the frame (a failed read yields None).
        ok, frame = vid.read()
        if not ok:
            print("Cannot read video file")
            sys.exit(1)
        frame = scale(frame, args.scale)
        if args.save is not None:
            out = save(frame=frame)

        tracker = None
        if args.track is not None:
            # Skip ahead until the cascade finds something to seed the tracker with.
            cas_roi = get_roi(frame)
            while len(cas_roi) == 0:
                ok, frame = vid.read()
                if not ok:
                    print("No cascade detection found to start tracking")
                    sys.exit(1)
                frame = scale(frame, args.scale)
                cas_roi = get_roi(frame)
            tracker = choose_tracker()
            # Plain ints for the box; some OpenCV versions reject numpy integers here.
            tracker.init(frame, tuple(int(v) for v in cas_roi))

        while vid.isOpened():
            ok, frame = vid.read()
            if not ok:
                break  # end of the video
            frame = scale(frame, args.scale)

            # Update the tracker and locate the circle region on the clean frame,
            # before any annotation is drawn onto it. The tracker must use this
            # frame rather than read another one, which would skip every other frame.
            tracked_ok, tracked_box = tracker.update(frame) if tracker is not None else (False, None)
            circle_roi = get_roi(frame) if args.circle else []

            frame = get_cascade(frame)
            if tracked_ok:
                bx, by, bw, bh = tracked_box
                cv.rectangle(frame, (int(bx), int(by)), (int(bx + bw), int(by + bh)), (0, 255, 0), 2, 1)
                cv.circle(frame, (int(bx + (bw / 2)), int(by + (bh / 2))), 3, (0, 255, 0), 3)
            if len(circle_roi) != 0:
                x, y, w, h = (int(v) for v in circle_roi)
                # Draw into the full frame: a cropped frame would no longer
                # match the size the video writer was opened with.
                frame[y:y + h, x:x + w] = detect_circles(frame[y:y + h, x:x + w].copy())

            cv.imshow('video', frame)
            if out is not None:
                out.write(frame)
            if cv.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        if out is not None:
            out.release()
        vid.release()
        cv.destroyAllWindows()
def cam_classifier():
    """Run the cascade live on camera device 0 until 'q' is pressed.

    Each frame is annotated with the cascade detections and shown in a
    'camera' window. The camera and windows are released however the loop ends.

    Raises:
        IOError: if the camera cannot be opened.
    """
    cam = cv.VideoCapture(0)
    if not cam.isOpened():
        raise IOError("Cannot access camera")
    try:
        while cam.isOpened():
            ok, frame = cam.read()
            if not ok:
                break  # camera stopped delivering frames
            frame = get_cascade(frame)
            cv.imshow('camera', frame)
            if cv.waitKey(10) & 0xFF == ord('q'):
                break
    finally:
        cam.release()
        cv.destroyAllWindows()
if __name__ == "__main__":
    # Exactly one input mode runs, checked in this order: image, video,
    # directory, camera. With none given, print usage.
    if args.img is not None:
        img_classifier()
    elif args.vid is not None:
        vid_classifier()
    elif args.dir is not None:
        dir_classifier()
    elif args.cam is not None:
        cam_classifier()
    else:
        parser.print_help()