-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTrainSSD.py
183 lines (150 loc) · 7.71 KB
/
TrainSSD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import cv2
import numpy as np
import pickle
import keras
from keras.applications.imagenet_utils import preprocess_input
from keras.preprocessing import image
from rely.ssd_training import MultiboxLoss
from rely.ssd_v2 import SSD300v2
from rely.ssd_utils import BBoxUtility
from rely.generate import Generator
class TrainSSD(object):
    """Build, train and run an SSD300 (v2) object detector with Keras.

    Typical use: construct the object, call ``SD_fn_BuildSSD()`` to create
    ``self.model`` and load weights, then call either ``SD_fn_TrainSSD()``
    or ``SD_fn_Predict()``.

    Args:
        weights_path: HDF5 weights file loaded (by layer name) into the model.
        Label_Data_path: pickle file holding the ground-truth annotation dict.
            Its keys are sorted and split into train/validation keys.
            Presumably the keys are image file names relative to
            ``ImgDir_path`` (verify against rely.generate.Generator).
        ImgDir_path: path prefix handed to the data generator for images.
        PriorBoxes_path: pickle file holding the SSD prior boxes.
        classes_name: ordered list of class names, excluding background.

    Raises:
        TypeError: if ``classes_name`` is not provided.
    """

    # Layers frozen before training. Only layers whose names are NOT listed
    # here receive gradient updates during fine-tuning.
    FROZEN_LAYERS = (
        'input_1', 'conv1_1', 'conv1_2', 'pool1',
        'conv2_1', 'conv2_2', 'pool2',
        'conv3_1', 'conv3_2', 'conv3_3', 'pool3',
        'conv4_1', 'conv4_2', 'conv4_3', 'pool4', 'conv4_3_norm',
        'conv5_1', 'conv5_2', 'conv5_3', 'pool5', 'fc6', 'fc7',
        'conv6_1', 'conv6_2',
        'conv7_1', 'conv7_1z', 'conv7_2',
        'conv8_1',
        'pool6',
    )

    def __init__(self, weights_path='', Label_Data_path='', ImgDir_path='',
                 PriorBoxes_path='', classes_name=None):
        if classes_name is None:
            raise TypeError('classes_name must be a list of class names')
        # --- model construction settings ---
        self.voc_classes = classes_name
        # +1 for the background class; SD_fn_Predict shifts labels by -1
        # so that label 0 (background) is excluded from voc_classes indexing.
        self.NUM_CLASSES = len(self.voc_classes) + 1
        self.input_shape = (300, 300, 3)
        self.weights_path = weights_path
        # Built without priors; it is only used for detection_out in
        # SD_fn_Predict (presumably priors are read from the model output).
        self.bbox_util = BBoxUtility(self.NUM_CLASSES)
        # --- training settings ---
        self.XML_Data_path = Label_Data_path  # pickled annotation data
        self.ImgDir_path = ImgDir_path  # image directory / path prefix
        self.PriorBoxes_path = PriorBoxes_path  # pickled prior boxes

    @staticmethod
    def _load_pickle(path):
        """Load one pickled object from ``path`` and close the file."""
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def SD_fn_BuildSSD(self):
        """Create the SSD300v2 network as ``self.model`` and load weights.

        Weights are loaded with ``by_name=True``, so only layers whose names
        match those stored in ``self.weights_path`` receive values.
        """
        self.model = SSD300v2(self.input_shape, num_classes=self.NUM_CLASSES)
        self.model.load_weights(self.weights_path, by_name=True)

    def SD_fn_TrainSSD(self, batch_size=16, nb_epoch=100, train_ratio=0.8,
                       base_lr=3e-4):
        """Fine-tune ``self.model`` on the annotated data set.

        The sorted annotation keys are split: the first ``train_ratio`` of
        them are used for training and the rest for validation. Layers listed
        in ``FROZEN_LAYERS`` are frozen. The model is then trained with Adam
        and an exponentially decaying learning rate (see ``schedule``).
        Weights are checkpointed into ``./weights/`` and TensorBoard logs are
        written to ``./logs``.

        Must be called after ``SD_fn_BuildSSD()`` (it uses ``self.model``).

        Args:
            batch_size: batch size handed to the data generator.
            nb_epoch: number of training epochs.
            train_ratio: fraction of annotation keys used for training.
            base_lr: initial learning rate; stored as ``self.base_lr``.

        Returns:
            The object returned by ``self.model.fit_generator``.
        """
        import os

        # --- data preparation ---
        priors = self._load_pickle(self.PriorBoxes_path)
        bbox_util = BBoxUtility(self.NUM_CLASSES, priors)
        gt = self._load_pickle(self.XML_Data_path)
        keys = sorted(gt.keys())
        num_train = int(round(train_ratio * len(keys)))
        train_keys = keys[:num_train]
        val_keys = keys[num_train:]
        gen = Generator(gt=gt,
                        bbox_util=bbox_util,
                        batch_size=batch_size,
                        path_prefix=self.ImgDir_path,
                        train_keys=train_keys, val_keys=val_keys,
                        image_size=(self.input_shape[0], self.input_shape[1]),
                        do_crop=False)

        # --- freeze layers (must happen before compile to take effect) ---
        frozen = set(self.FROZEN_LAYERS)
        for layer in self.model.layers:
            if layer.name in frozen:
                layer.trainable = False

        # schedule() reads self.base_lr at the start of every epoch.
        self.base_lr = base_lr

        # --- callbacks ---
        # ModelCheckpoint does not create missing directories itself.
        checkpoint_dir = './weights'
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        tensorboard = keras.callbacks.TensorBoard(log_dir='./logs',
                                                  histogram_freq=0,
                                                  batch_size=8,
                                                  write_graph=True,
                                                  write_grads=True,
                                                  write_images=False,
                                                  embeddings_freq=0,)
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                './weights/weights.{epoch:02d}-{val_loss:.4f}.hdf5',
                verbose=1,
                save_weights_only=True),
            # Newer Keras calls schedule(epoch, lr). Passed straight through,
            # that lr would be bound to schedule's `decay` argument and
            # collapse the learning rate, so the current lr is ignored here.
            keras.callbacks.LearningRateScheduler(
                lambda epoch, lr=None: self.schedule(epoch)),
            tensorboard,
        ]

        # --- compile and train ---
        optim = keras.optimizers.Adam(lr=self.base_lr)
        self.model.compile(optimizer=optim,
                           loss=MultiboxLoss(self.NUM_CLASSES,
                                             neg_pos_ratio=2.0).compute_loss)
        # NOTE(review): the 2nd positional argument is samples_per_epoch in
        # Keras 1 but steps_per_epoch in Keras 2. Confirm whether
        # gen.train_batches counts samples or batches for the Keras in use.
        history = self.model.fit_generator(
            gen.generate(True),  # training data generator
            gen.train_batches,  # amount of data per epoch (see note above)
            nb_epoch,  # number of epochs
            verbose=1,  # 0 = silent, 1 = progress bar, 2 = one line per epoch
            callbacks=callbacks,
            validation_data=gen.generate(False),  # validation generator
            nb_val_samples=gen.val_batches,
            nb_worker=1)
        return history

    def schedule(self, epoch, decay=0.9):
        """Return the learning rate for ``epoch``: ``base_lr * decay ** epoch``."""
        return self.base_lr * decay ** (epoch)

    def _decode_detections(self, detections, h, w, min_score):
        """Convert one image's ``detection_out`` rows into pixel-space tuples.

        Each row is read as (label, confidence, xmin, ymin, xmax, ymax), with
        coordinates given as fractions of the image width/height. Rows whose
        confidence is below ``min_score`` are dropped; row order is kept.

        Returns:
            list of (label - 1, score, xmin, ymin, xmax, ymax) tuples with
            integer pixel coordinates.
        """
        # detection_out may return an empty list rather than an (0, 6) array
        # when nothing survives its own threshold (as in ssd_keras'
        # BBoxUtility; verify against rely.ssd_utils). 2-D slicing would fail.
        if len(detections) == 0:
            return []
        detections = np.asarray(detections)
        keep = detections[:, 1] >= min_score
        preds = []
        for label, score, xmin, ymin, xmax, ymax in detections[keep, :6]:
            preds.append((int(label) - 1,  # drop background offset
                          score,
                          int(round(xmin * w)),
                          int(round(ymin * h)),
                          int(round(xmax * w)),
                          int(round(ymax * h))))
        return preds

    def SD_fn_Predict(self, img, min_score=0.6):
        """Run the detector on one image.

        Args:
            img: image array indexed as (height, width, ...). It is resized
                with cv2 and passed through Keras ``preprocess_input``.
                NOTE(review): preprocess_input assumes RGB input; confirm the
                channel order callers pass (cv2.imread yields BGR).
            min_score: minimum confidence (inclusive) for a detection to be kept.

        Returns:
            list of (label, score, xmin, ymin, xmax, ymax) tuples. ``label``
            indexes ``self.voc_classes`` and the coordinates are pixels in
            ``img``. The list is empty when nothing is detected.
        """
        # cv2 expects dsize as (width, height).
        inputs = cv2.resize(img, (self.input_shape[1], self.input_shape[0]))
        inputs = image.img_to_array(inputs)
        inputs = np.expand_dims(inputs, axis=0)
        inputs = preprocess_input(inputs)
        raw = self.model.predict(inputs, batch_size=1, verbose=0)
        results = self.bbox_util.detection_out(raw)  # decode + NMS
        if len(results) == 0:
            return []
        h, w = img.shape[:2]
        return self._decode_detections(results[0], h, w, min_score)

    def GetPosition(self, preds):
        """Return (xmin, ymin, xmax, ymax) of every prediction of class 'Target'."""
        return [(xmin, ymin, xmax, ymax)
                for lab, _score, xmin, ymin, xmax, ymax in preds
                if self.voc_classes[lab] == 'Target']
if __name__ == '__main__':
    import os
    import sys

    # Run from the script's own directory so the relative paths below
    # resolve. os.path.split(sys.argv[0])[0] is '' when launched as
    # "python TrainSSD.py", and os.chdir('') raises, so use an absolute path.
    work_space = os.path.dirname(os.path.abspath(sys.argv[0]))
    os.chdir(work_space)

    # Class names of the labelled data set. They must match the name.txt file
    # from tutorial 4.1 exactly, including order and spelling.
    name_class = ['Target', 'person', 'cup', 'fan']
    my_ssd = TrainSSD(
        # Initial weights file. It can later point at trained weights to
        # resume training. NOTE: checkpoints are saved to './weights/', not
        # './weight/'.
        weights_path='./weight/weights_SSD300.hdf5',
        # Annotation data file generated in tutorial 5.1.
        Label_Data_path='my_new_data.pkl',
        # Folder holding the data-set images (used as a path prefix).
        ImgDir_path='D:\\my_data\\JPEGImages\\',
        # Prior-box file; normally left unchanged.
        PriorBoxes_path='prior_boxes_ssd300.pkl',
        classes_name=name_class)
    my_ssd.SD_fn_BuildSSD()
    my_ssd.SD_fn_TrainSSD()