diff --git a/documentation/tf/odas.png b/documentation/tf/odas.png index cc3ac9df..1d86eeb2 100644 Binary files a/documentation/tf/odas.png and b/documentation/tf/odas.png differ diff --git a/documentation/tf/tf.drawio b/documentation/tf/tf.drawio index 8bf61d08..308a867d 100644 --- a/documentation/tf/tf.drawio +++ b/documentation/tf/tf.drawio @@ -1,6 +1,6 @@ - + - + @@ -264,12 +264,12 @@ - - + + - + diff --git a/ros/demos/control_panel/src/control_panel_node.cpp b/ros/demos/control_panel/src/control_panel_node.cpp index b17a2742..3ef46328 100644 --- a/ros/demos/control_panel/src/control_panel_node.cpp +++ b/ros/demos/control_panel/src/control_panel_node.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -53,7 +54,8 @@ int startNode(int argc, char* argv[]) } auto solver = make_unique(); - HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver)); + auto strategyStateLogger = make_unique(nodeHandle); + HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver), move(strategyStateLogger)); QApplication application(argc, argv); ControlPanel controlPanel(nodeHandle, desireSet, camera2dWideEnabled); diff --git a/ros/demos/home_logger/src/home_logger_node.cpp b/ros/demos/home_logger/src/home_logger_node.cpp index 57b060c3..2b278fc1 100644 --- a/ros/demos/home_logger/src/home_logger_node.cpp +++ b/ros/demos/home_logger/src/home_logger_node.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include @@ -93,7 +94,8 @@ void startNode( strategies.emplace_back(createPlaySoundStrategy(filterPool, desireSet, nodeHandle)); auto solver = make_unique(); - HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver)); + auto strategyStateLogger = make_unique(nodeHandle); + HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver), move(strategyStateLogger)); VolumeManager volumeManager(nodeHandle); SQLite::Database database(databasePath, SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE); diff --git a/ros/demos/smart_speaker/src/smart_speaker_rss_node.cpp b/ros/demos/smart_speaker/src/smart_speaker_rss_node.cpp index 335cb4cc..7fea2cdd 100644 --- a/ros/demos/smart_speaker/src/smart_speaker_rss_node.cpp +++ b/ros/demos/smart_speaker/src/smart_speaker_rss_node.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -59,7 +60,8 @@ void startNode( strategies.emplace_back(createPlaySoundStrategy(filterPool, desireSet, nodeHandle)); auto solver = make_unique(); - HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver)); + auto strategyStateLogger = make_unique(nodeHandle); + HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver), move(strategyStateLogger)); StateManager stateManager; type_index idleStateType(typeid(RssIdleState)); diff --git a/ros/demos/smart_speaker/src/smart_speaker_smart_node.cpp b/ros/demos/smart_speaker/src/smart_speaker_smart_node.cpp index 20f642e9..dc5b5af4 100644 --- a/ros/demos/smart_speaker/src/smart_speaker_smart_node.cpp +++ b/ros/demos/smart_speaker/src/smart_speaker_smart_node.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include @@ -61,7 +62,8 @@ void startNode( strategies.emplace_back(createPlaySoundStrategy(filterPool, desireSet, nodeHandle)); auto solver = make_unique(); - HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver)); + auto 
strategyStateLogger = make_unique(nodeHandle); + HbbaLite hbba(desireSet, move(strategies), {{"motor", 1}, {"sound", 1}, {"led", 1}}, move(solver), move(strategyStateLogger)); StateManager stateManager; type_index askOtherTaskStateType(typeid(SmartAskOtherTaskState)); diff --git a/ros/perceptions/audio_analyzer/msg/AudioAnalysis.msg b/ros/perceptions/audio_analyzer/msg/AudioAnalysis.msg index 861656db..32347159 100644 --- a/ros/perceptions/audio_analyzer/msg/AudioAnalysis.msg +++ b/ros/perceptions/audio_analyzer/msg/AudioAnalysis.msg @@ -1,5 +1,6 @@ std_msgs/Header header +int64 tracking_id audio_utils/AudioFrame audio_frame string[] audio_classes diff --git a/ros/perceptions/audio_analyzer/scripts/audio_analyzer_node.py b/ros/perceptions/audio_analyzer/scripts/audio_analyzer_node.py index 85fc68f1..7140d1ca 100755 --- a/ros/perceptions/audio_analyzer/scripts/audio_analyzer_node.py +++ b/ros/perceptions/audio_analyzer/scripts/audio_analyzer_node.py @@ -86,7 +86,7 @@ def _audio_cb(self, msg): self._audio_analysis_count += audio_frame.shape[0] def _analyse(self): - audio_buffer = self._get_audio_buffer() + audio_buffer, sst_id = self._get_audio_buffer_and_sst_id() audio_descriptor_buffer = audio_buffer[-self._audio_descriptor_extractor.get_supported_duration():] audio_descriptor, audio_class_probabilities = self._audio_descriptor_extractor(audio_descriptor_buffer) audio_descriptor = audio_descriptor.tolist() @@ -99,21 +99,22 @@ def _analyse(self): voice_descriptor = [] audio_classes = self._get_audio_classes(audio_class_probabilities) - self._publish_audio_analysis(audio_buffer, audio_classes, audio_descriptor, voice_descriptor) + self._publish_audio_analysis(sst_id, audio_buffer, audio_classes, audio_descriptor, voice_descriptor) - def _get_audio_buffer(self): + def _get_audio_buffer_and_sst_id(self): with self._audio_frames_lock: + sst_id = self._sst_id audio_buffer = torch.cat(self._audio_frames, dim=0) if audio_buffer.size()[0] < self._audio_buffer_duration: - return torch.cat([torch.zeros(self._audio_buffer_duration - audio_buffer.size()[0]), audio_buffer], dim=0) + return torch.cat([torch.zeros(self._audio_buffer_duration - audio_buffer.size()[0]), audio_buffer], dim=0), sst_id else: - return audio_buffer[-self._audio_buffer_duration:] + return audio_buffer[-self._audio_buffer_duration:], sst_id def _get_audio_classes(self, audio_class_probabilities): return [self._class_names[i] for i in range(len(self._class_names)) if audio_class_probabilities[i].item() >= self._class_probability_threshold] - def _publish_audio_analysis(self, audio_buffer, audio_classes, audio_descriptor, voice_descriptor): + def _publish_audio_analysis(self, sst_id, audio_buffer, audio_classes, audio_descriptor, voice_descriptor): with self._audio_direction_lock: frame_id, direction_x, direction_y, direction_z = self._audio_direction @@ -122,6 +123,8 @@ def _publish_audio_analysis(self, audio_buffer, audio_classes, audio_descriptor, msg.header.stamp = rospy.Time.now() msg.header.frame_id = frame_id + msg.tracking_id = sst_id + msg.audio_frame.format = 'float' msg.audio_frame.channel_count = SUPPORTED_CHANNEL_COUNT msg.audio_frame.sampling_frequency = self._supported_sampling_frequency diff --git a/ros/perceptions/person_identification/README.md b/ros/perceptions/person_identification/README.md index a5b2987a..ef508f3a 100644 --- a/ros/perceptions/person_identification/README.md +++ b/ros/perceptions/person_identification/README.md @@ -13,6 +13,7 @@ roslaunch person_identification capture_face.launch name:= neural_n 
#### Parameters - `name` (string): The person name. - `mean_size` (int): How many descriptor to average. + - `face_sharpness_score_threshold` (double): The threshold to consider the face sharp enough. #### Subscribed Topics - `video_analysis` ([video_analyzer/VideoAnalysis](../video_analyzer/msg/VideoAnalysis.msg)): The video analysis containing the detected objects. @@ -38,6 +39,7 @@ roslaunch person_identification capture_voice.launch name:= neural_ This node performs person identification. The people must be already added to `people.json` with the previous nodes. #### Parameters + - `face_sharpness_score_threshold` (double): The threshold to consider the face sharp enough. - `face_descriptor_threshold` (double): The maximum distance between two face descriptors to be considered the same person. - `voice_descriptor_threshold` (double): The maximum distance between two voice descriptors to be considered the same person. - `face_voice_descriptor_threshold` (double): The maximum distance between two merged descriptors to be considered the same person. diff --git a/ros/perceptions/person_identification/launch/capture_face.launch b/ros/perceptions/person_identification/launch/capture_face.launch index 131ba77c..200c3ceb 100644 --- a/ros/perceptions/person_identification/launch/capture_face.launch +++ b/ros/perceptions/person_identification/launch/capture_face.launch @@ -2,23 +2,32 @@ - + + + + + + + + + + - - - - - - + + + + + - + + @@ -42,5 +51,6 @@ + diff --git a/ros/perceptions/person_identification/scripts/capture_face_node.py b/ros/perceptions/person_identification/scripts/capture_face_node.py index 23ba5ea2..0bc52bfb 100755 --- a/ros/perceptions/person_identification/scripts/capture_face_node.py +++ b/ros/perceptions/person_identification/scripts/capture_face_node.py @@ -18,22 +18,23 @@ class CaptureFaceNode: def __init__(self): self._name = rospy.get_param('~name') self._mean_size = rospy.get_param('~mean_size') + self._face_sharpness_score_threshold = rospy.get_param('~face_sharpness_score_threshold') self._descriptors_lock = threading.Lock() self._descriptors = [] self._video_analysis_sub = rospy.Subscriber('video_analysis', VideoAnalysis, self._video_analysis_cb, queue_size=1) def _video_analysis_cb(self, msg): - face_descriptor = None + face_object = None for object in msg.objects: - if len(object.face_descriptor) > 0 and face_descriptor is not None: + if len(object.face_descriptor) > 0 and face_object is not None: rospy.logwarn('Only one face must be present in the image.') elif len(object.face_descriptor) > 0: - face_descriptor = object.face_descriptor + face_object = object - if face_descriptor is not None: + if face_object is not None and face_object.face_sharpness_score >= self._face_sharpness_score_threshold: with self._descriptors_lock: - self._descriptors.append(face_descriptor) + self._descriptors.append(face_object.face_descriptor) def run(self): self.enable_video_analyzer() diff --git a/ros/perceptions/person_identification/scripts/person_identification_node.py b/ros/perceptions/person_identification/scripts/person_identification_node.py index bb201812..3e308a3b 100755 --- a/ros/perceptions/person_identification/scripts/person_identification_node.py +++ b/ros/perceptions/person_identification/scripts/person_identification_node.py @@ -41,6 +41,7 @@ def __init__(self, descriptor, direction): class PersonIdentificationNode: def __init__(self): + self._face_sharpness_score_threshold = rospy.get_param('~face_sharpness_score_threshold') self._face_descriptor_threshold = 
rospy.get_param('~face_descriptor_threshold') self._voice_descriptor_threshold = rospy.get_param('~voice_descriptor_threshold') self._face_voice_descriptor_threshold = rospy.get_param('~face_voice_descriptor_threshold') @@ -83,7 +84,8 @@ def _video_analysis_cb(self, msg): for object in msg.objects: if len(object.face_descriptor) == 0 or len(object.person_pose_2d) == 0 or len(object.person_pose_3d) == 0 \ or len(object.person_pose_confidence) == 0 \ - or object.person_pose_confidence[PERSON_POSE_NOSE_INDEX] < self._nose_confidence_threshold: + or object.person_pose_confidence[PERSON_POSE_NOSE_INDEX] < self._nose_confidence_threshold \ + or object.face_sharpness_score < self._face_sharpness_score_threshold: continue position_2d = object.person_pose_2d[PERSON_POSE_NOSE_INDEX] diff --git a/ros/perceptions/video_analyzer/README.md b/ros/perceptions/video_analyzer/README.md index b6737871..e2a49813 100644 --- a/ros/perceptions/video_analyzer/README.md +++ b/ros/perceptions/video_analyzer/README.md @@ -7,7 +7,9 @@ This node detects objects. If the audio contains a person, it estimates the pose This node uses RGB images, so the 3D positions are not set. #### Parameters - - `use_descriptor_yolo_v4` (bool): Indicates to use the network extracting a object embedding or not. + - `use_descriptor_yolo` (bool): Indicates whether to use the network that extracts an object embedding. + - `yolo_model` (string): If `use_descriptor_yolo` is false, it indicates which model to use for YOLO (yolo_v4_coco, yolo_v4_tiny_coco, yolo_v7_coco, yolo_v7_tiny_coco or yolo_v7_objects365). + If `use_descriptor_yolo` is true, it indicates which model to use for Descriptor-YOLO (yolo_v4_tiny_coco, yolo_v7_coco or yolo_v7_objects365). - `confidence_threshold` (double): The object confidence threshold. - `nms_threshold` (double): The Non-Maximum Suppression threshold. - `person_probability_threshold` (double): The person confidence threshold. @@ -36,7 +38,9 @@ This node detects objects. If the audio contains a person, it estimates the pose This node uses RGB-D images, so the 3D positions are set. #### Parameters - - `use_descriptor_yolo_v4` (bool): Indicates to use the network extracting a object embedding or not. + - `use_descriptor_yolo` (bool): Indicates whether to use the network that extracts an object embedding. + - `yolo_model` (string): If `use_descriptor_yolo` is false, it indicates which model to use for YOLO (yolo_v4_coco, yolo_v4_tiny_coco, yolo_v7_coco, yolo_v7_tiny_coco or yolo_v7_objects365). + If `use_descriptor_yolo` is true, it indicates which model to use for Descriptor-YOLO (yolo_v4_tiny_coco, yolo_v7_coco or yolo_v7_objects365). - `confidence_threshold` (double): The object confidence threshold. - `nms_threshold` (double): The Non-Maximum Suppression threshold. - `person_probability_threshold` (double): The person confidence threshold.
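Reviewer note: a minimal sketch of how the `use_descriptor_yolo` and `yolo_model` parameters above select the object detector, following the `lib_video_analyzer_node.py` changes later in this diff. `create_object_detector` is a hypothetical helper, parameter handling is simplified compared to the real node, and the allowed model-name sets are taken from the README lists above.

```python
import rospy

from dnn_utils import DescriptorYolo, Yolo  # provided by this repository's dnn_utils package

# Model names documented above for each detector family.
YOLO_MODELS = {'yolo_v4_coco', 'yolo_v4_tiny_coco', 'yolo_v7_coco', 'yolo_v7_tiny_coco', 'yolo_v7_objects365'}
DESCRIPTOR_YOLO_MODELS = {'yolo_v4_tiny_coco', 'yolo_v7_coco', 'yolo_v7_objects365'}


def create_object_detector():
    # Hypothetical helper: read the parameters described above and build the detector.
    use_descriptor_yolo = rospy.get_param('~use_descriptor_yolo')
    yolo_model = rospy.get_param('~yolo_model', None)
    confidence_threshold = rospy.get_param('~confidence_threshold')
    nms_threshold = rospy.get_param('~nms_threshold')

    if use_descriptor_yolo:
        if yolo_model not in DESCRIPTOR_YOLO_MODELS:
            raise ValueError(f'Invalid Descriptor-YOLO model ({yolo_model})')
        return DescriptorYolo(yolo_model, confidence_threshold=confidence_threshold, nms_threshold=nms_threshold)
    else:
        if yolo_model not in YOLO_MODELS:
            raise ValueError(f'Invalid YOLO model ({yolo_model})')
        return Yolo(yolo_model, confidence_threshold=confidence_threshold, nms_threshold=nms_threshold)
```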
diff --git a/ros/perceptions/video_analyzer/msg/VideoAnalysisObject.msg b/ros/perceptions/video_analyzer/msg/VideoAnalysisObject.msg index ae798cbf..042d41ff 100644 --- a/ros/perceptions/video_analyzer/msg/VideoAnalysisObject.msg +++ b/ros/perceptions/video_analyzer/msg/VideoAnalysisObject.msg @@ -6,6 +6,7 @@ float32 height_2d # Normalized string object_class float32 object_confidence +float32 object_class_probability float32[] object_descriptor sensor_msgs/Image object_image @@ -15,4 +16,6 @@ float32[] person_pose_confidence sensor_msgs/Image person_pose_image float32[] face_descriptor +int32 face_alignment_keypoint_count +float32 face_sharpness_score sensor_msgs/Image face_image diff --git a/ros/perceptions/video_analyzer/scripts/video_analyzer_2d_node.py b/ros/perceptions/video_analyzer/scripts/video_analyzer_2d_node.py index 6fcdd157..fc643011 100755 --- a/ros/perceptions/video_analyzer/scripts/video_analyzer_2d_node.py +++ b/ros/perceptions/video_analyzer/scripts/video_analyzer_2d_node.py @@ -52,6 +52,7 @@ def _object_analysis_to_msg(self, object_analysis, image_height, image_width): o.height_2d = object_analysis.height / image_height o.object_class = object_analysis.object_class o.object_confidence = object_analysis.object_confidence + o.object_class_probability = object_analysis.object_class_probability if object_analysis.object_image is not None: o.object_image = self._cv_bridge.cv2_to_imgmsg(object_analysis.object_image, encoding='rgb8') o.object_descriptor = object_analysis.object_descriptor @@ -65,6 +66,8 @@ def _object_analysis_to_msg(self, object_analysis, image_height, image_width): if object_analysis.face_analysis is not None: o.face_descriptor = object_analysis.face_analysis.descriptor + o.face_alignment_keypoint_count = object_analysis.face_analysis.alignment_keypoint_count + o.face_sharpness_score = object_analysis.face_analysis.sharpness_score if object_analysis.face_analysis.face_image is not None: o.face_image = self._cv_bridge.cv2_to_imgmsg(object_analysis.face_analysis.face_image, encoding='rgb8') diff --git a/ros/perceptions/video_analyzer/scripts/video_analyzer_3d_node.py b/ros/perceptions/video_analyzer/scripts/video_analyzer_3d_node.py index 95737785..77b0fb95 100755 --- a/ros/perceptions/video_analyzer/scripts/video_analyzer_3d_node.py +++ b/ros/perceptions/video_analyzer/scripts/video_analyzer_3d_node.py @@ -73,6 +73,7 @@ def _object_analysis_to_msg(self, object_analysis, image_height, image_width, de o.height_2d = object_analysis.height / image_height o.object_class = object_analysis.object_class o.object_confidence = object_analysis.object_confidence + o.object_class_probability = object_analysis.object_class_probability if object_analysis.object_image is not None: o.object_image = self._cv_bridge.cv2_to_imgmsg(object_analysis.object_image, encoding='rgb8') o.object_descriptor = object_analysis.object_descriptor @@ -89,6 +90,8 @@ def _object_analysis_to_msg(self, object_analysis, image_height, image_width, de if object_analysis.face_analysis is not None: o.face_descriptor = object_analysis.face_analysis.descriptor + o.face_alignment_keypoint_count = object_analysis.face_analysis.alignment_keypoint_count + o.face_sharpness_score = object_analysis.face_analysis.sharpness_score if object_analysis.face_analysis.face_image is not None: o.face_image = self._cv_bridge.cv2_to_imgmsg(object_analysis.face_analysis.face_image, encoding='rgb8') diff --git a/ros/perceptions/video_analyzer/src/video_analyzer/lib_video_analyzer_node.py 
b/ros/perceptions/video_analyzer/src/video_analyzer/lib_video_analyzer_node.py index f8ef22d0..91777ae1 100644 --- a/ros/perceptions/video_analyzer/src/video_analyzer/lib_video_analyzer_node.py +++ b/ros/perceptions/video_analyzer/src/video_analyzer/lib_video_analyzer_node.py @@ -13,10 +13,12 @@ from geometry_msgs.msg import Point from video_analyzer.msg import VideoAnalysis, SemanticSegmentation -from dnn_utils import DescriptorYoloV4, YoloV4, PoseEstimator, FaceDescriptorExtractor, SemanticSegmentationNetwork +from dnn_utils import DescriptorYolo, Yolo, PoseEstimator, FaceDescriptorExtractor, SemanticSegmentationNetwork import hbba_lite +BOX_COLOR = (255, 0, 0) +BOX_TEXT_COLOR = (0, 255, 0) PERSON_POSE_KEYPOINT_COLORS = [(0, 255, 0), (255, 0, 0), (0, 0, 255), @@ -38,7 +40,7 @@ class ObjectAnalysis: def __init__(self, center_x, center_y, width, height, - object_class, object_confidence, + object_class, object_confidence, object_class_probability, object_descriptor=None, object_image=None, pose_analysis=None, face_analysis=None): self.center_x = center_x @@ -47,6 +49,7 @@ def __init__(self, center_x, center_y, width, height, self.height = height self.object_class = object_class self.object_confidence = object_confidence + self.object_class_probability = object_class_probability self.object_descriptor = object_descriptor self.object_image = object_image @@ -57,7 +60,7 @@ def __init__(self, center_x, center_y, width, height, def from_yoloV4_prediction(prediction, object_class_names): return ObjectAnalysis(prediction.center_x, prediction.center_y, prediction.width, prediction.height, object_class_names[prediction.class_index], prediction.confidence, - prediction.descriptor) + prediction.class_probabilities[prediction.class_index], prediction.descriptor) class PoseAnalysis: def __init__(self, pose_coordinates, pose_confidence, pose_image): @@ -67,14 +70,17 @@ def __init__(self, pose_coordinates, pose_confidence, pose_image): class FaceAnalysis: - def __init__(self, descriptor, face_image=None): + def __init__(self, descriptor, alignment_keypoint_count, sharpness_score, face_image=None): self.descriptor = descriptor + self.alignment_keypoint_count = alignment_keypoint_count + self.sharpness_score = sharpness_score self.face_image = face_image class VideoAnalyzerNode: def __init__(self): - self._use_descriptor_yolo_v4 = rospy.get_param('~use_descriptor_yolo_v4') + self._use_descriptor_yolo = rospy.get_param('~use_descriptor_yolo') + self._yolo_model = rospy.get_param('~yolo_model', None) self._confidence_threshold = rospy.get_param('~confidence_threshold') self._nms_threshold = rospy.get_param('~nms_threshold') self._person_probability_threshold = rospy.get_param('~person_probability_threshold') @@ -90,12 +96,12 @@ def __init__(self): if self._face_descriptor_enabled and not self._pose_enabled: raise ValueError('The pose estimation must be enabled when the face descriptor extraction is enabled.') - if self._use_descriptor_yolo_v4: - self._object_detector = DescriptorYoloV4(confidence_threshold=self._confidence_threshold, - nms_threshold=self._nms_threshold, inference_type=self._inference_type) + if self._use_descriptor_yolo: + self._object_detector = DescriptorYolo(self._yolo_model, confidence_threshold=self._confidence_threshold, + nms_threshold=self._nms_threshold, inference_type=self._inference_type) else: - self._object_detector = YoloV4(confidence_threshold=self._confidence_threshold, - nms_threshold=self._nms_threshold, inference_type=self._inference_type) + self._object_detector = 
Yolo(self._yolo_model, confidence_threshold=self._confidence_threshold, + nms_threshold=self._nms_threshold, inference_type=self._inference_type) self._object_class_names = self._object_detector.get_class_names() if self._pose_enabled: @@ -156,13 +162,15 @@ def _analyse_person(self, cv_color_image, color_image_tensor, x0, y0): face_analysis = None if self._face_descriptor_enabled: try: - face_descriptor, face_image = self._face_descriptor_extractor(color_image_tensor, - pose_coordinates, pose_confidence) + face_descriptor, face_image, face_alignment_keypoint_count, face_sharpness_score = self._face_descriptor_extractor( + color_image_tensor, pose_coordinates, pose_confidence, self._pose_confidence_threshold) except ValueError: face_descriptor = torch.tensor([]) + face_sharpness_score = -1.0 face_image = None + face_alignment_keypoint_count = 0 - face_analysis = FaceAnalysis(face_descriptor.tolist()) + face_analysis = FaceAnalysis(face_descriptor.tolist(), face_alignment_keypoint_count, face_sharpness_score) if self._cropped_image_enabled: face_analysis.face_image = face_image @@ -197,8 +205,9 @@ def _publish_analysed_image(self, color_image, header, object_analyses): def _draw_object_analysis(self, image, object_analysis): x0, y0, x1, y1 = self._get_bbox(object_analysis, image.shape[1], image.shape[0]) - color = (255, 0, 0) - cv2.rectangle(image, (x0, y0), (x1, y1), color, thickness=4) + cv2.rectangle(image, (x0, y0), (x1, y1), BOX_COLOR, thickness=4) + text = f'{object_analysis.object_class}({object_analysis.object_confidence:.2f}, {object_analysis.object_class_probability:.2f})' + cv2.putText(image, text, (x0, y0), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1.0, color=BOX_TEXT_COLOR, thickness=3) if object_analysis.pose_analysis is not None: self._draw_person_pose(image, diff --git a/ros/t_top/config/configuration_16SoundsUSB.cfg b/ros/t_top/config/configuration_16SoundsUSB.cfg index 27618e1e..ac5ccf9d 100644 --- a/ros/t_top/config/configuration_16SoundsUSB.cfg +++ b/ros/t_top/config/configuration_16SoundsUSB.cfg @@ -55,7 +55,7 @@ general: # Microphone 1 { - mu = ( 0.08089938405708777, 0.12107448437676213, -0.24828407179715611 ); + mu = ( 0.08089938405708777, 0.12107448437676213, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.5487302415312265, 0.8212328416301297, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -63,7 +63,7 @@ general: # Microphone 2 { - mu = ( 0.02967060476785486, 0.14916420310718073, -0.2891432518975096 ); + mu = ( 0.02967060476785486, 0.14916420310718073, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.19268843641472078, 0.9687101860798543, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -71,7 +71,7 @@ general: # Microphone 3 { - mu = ( -0.028408085870891577, 0.14281709199205483, -0.24828407179715611 ); + mu = ( -0.028408085870891577, 0.14281709199205483, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.1926884364209004, 0.968710186078625, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -79,7 +79,7 @@ general: # Microphone 4 { - mu = ( -0.0844947336941189, 0.12645530536063107, -0.2891432518975096 ); + mu = ( -0.0844947336941189, 0.12645530536063107, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.5487302415364653, 0.8212328416266292, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -87,7 +87,7 @@ general: # Microphone 5 { - mu = ( 
-0.12107448437676213, 0.08089938405708777, -0.24828407179715611 ); + mu = ( -0.12107448437676213, 0.08089938405708777, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.8212328416301297, 0.5487302415312265, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -95,7 +95,7 @@ general: # Microphone 6 { - mu = ( -0.14916420310718073, 0.029670604767854866, -0.2891432518975096 ); + mu = ( -0.14916420310718073, 0.029670604767854866, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.9687101860798543, 0.19268843641472083, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -103,7 +103,7 @@ general: # Microphone 7 { - mu = ( -0.14281709199205483, -0.02840808587089157, -0.24828407179715611 ); + mu = ( -0.14281709199205483, -0.02840808587089157, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.968710186078625, -0.19268843642090033, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -111,7 +111,7 @@ general: # Microphone 8 { - mu = ( -0.12645530536063107, -0.08449473369411888, -0.2891432518975096 ); + mu = ( -0.12645530536063107, -0.08449473369411888, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.8212328416266292, -0.5487302415364652, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -119,15 +119,15 @@ general: # Microphone 9 { - mu = ( -0.08089938405708777, -0.12107448437676213, -0.24828407179715611 ); + mu = ( -0.08089938405708778, -0.12107448437676213, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); - direction = ( -0.5487302415312265, -0.8212328416301297, 0.15643446505544817 ); + direction = ( -0.5487302415312266, -0.8212328416301297, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); }, # Microphone 10 { - mu = ( -0.029670604767854943, -0.14916420310718073, -0.2891432518975096 ); + mu = ( -0.029670604767854943, -0.14916420310718073, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( -0.19268843641472133, -0.9687101860798542, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -135,7 +135,7 @@ general: # Microphone 11 { - mu = ( 0.02840808587089156, -0.14281709199205483, -0.24828407179715611 ); + mu = ( 0.02840808587089156, -0.14281709199205483, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.19268843642090028, -0.9687101860786251, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -143,7 +143,7 @@ general: # Microphone 12 { - mu = ( 0.08449473369411893, -0.12645530536063104, -0.2891432518975096 ); + mu = ( 0.08449473369411893, -0.12645530536063104, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.5487302415364655, -0.821232841626629, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -151,7 +151,7 @@ general: # Microphone 13 { - mu = ( 0.12107448437676219, -0.08089938405708767, -0.24828407179715611 ); + mu = ( 0.12107448437676219, -0.08089938405708767, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.8212328416301301, -0.5487302415312258, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -159,7 +159,7 @@ general: # Microphone 14 { - mu = ( 0.14916420310718076, -0.029670604767854686, -0.2891432518975096 ); + mu = ( 0.14916420310718076, -0.029670604767854686, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 
0.0, +1E-6 ); direction = ( 0.9687101860798545, -0.19268843641471967, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -167,7 +167,7 @@ general: # Microphone 15 { - mu = ( 0.1428170919920548, 0.028408085870891803, -0.24828407179715611 ); + mu = ( 0.1428170919920548, 0.028408085870891803, 0.020429590050176742 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.9687101860786248, 0.19268843642090192, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); @@ -175,13 +175,11 @@ general: # Microphone 16 { - mu = ( 0.1264553053606309, 0.08449473369411915, -0.2891432518975096 ); + mu = ( 0.1264553053606309, 0.08449473369411915, -0.02042959005017674 ); sigma2 = ( +1E-6, 0.0, 0.0, 0.0, +1E-6, 0.0, 0.0, 0.0, +1E-6 ); direction = ( 0.8212328416266281, 0.548730241536467, 0.15643446505544817 ); angle = ( 80.0, 100.0 ); } - - ); # Spatial filters to include only a range of direction if required diff --git a/ros/t_top/launch/perceptions/audio_analyzer.launch b/ros/t_top/launch/perceptions/audio_analyzer.launch index 76b23c76..7a6c044f 100644 --- a/ros/t_top/launch/perceptions/audio_analyzer.launch +++ b/ros/t_top/launch/perceptions/audio_analyzer.launch @@ -3,7 +3,7 @@ - + diff --git a/ros/t_top/launch/perceptions/music_beat_detector.launch b/ros/t_top/launch/perceptions/music_beat_detector.launch index 3400021b..d95314bc 100644 --- a/ros/t_top/launch/perceptions/music_beat_detector.launch +++ b/ros/t_top/launch/perceptions/music_beat_detector.launch @@ -6,7 +6,6 @@ - diff --git a/ros/t_top/launch/perceptions/person_identification.launch b/ros/t_top/launch/perceptions/person_identification.launch index a0fbb802..0af7a3e1 100644 --- a/ros/t_top/launch/perceptions/person_identification.launch +++ b/ros/t_top/launch/perceptions/person_identification.launch @@ -1,6 +1,7 @@ + diff --git a/ros/t_top/launch/perceptions/video_analyzer.launch b/ros/t_top/launch/perceptions/video_analyzer.launch index b02ebd1b..37d2c2bb 100644 --- a/ros/t_top/launch/perceptions/video_analyzer.launch +++ b/ros/t_top/launch/perceptions/video_analyzer.launch @@ -4,7 +4,8 @@ - + + @@ -33,7 +34,8 @@ - + + @@ -42,7 +44,7 @@ - + diff --git a/ros/t_top/launch/platform/daemon_ros_client.launch b/ros/t_top/launch/platform/daemon_ros_client.launch index 8a7d8a01..26683578 100644 --- a/ros/t_top/launch/platform/daemon_ros_client.launch +++ b/ros/t_top/launch/platform/daemon_ros_client.launch @@ -15,7 +15,7 @@ - + diff --git a/ros/t_top/launch/tests/test_robot_name_detector.launch b/ros/t_top/launch/tests/test_robot_name_detector.launch new file mode 100644 index 00000000..019226c1 --- /dev/null +++ b/ros/t_top/launch/tests/test_robot_name_detector.launch @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ros/t_top/launch/tests/test_video_analyzer_2d_node.launch b/ros/t_top/launch/tests/test_video_analyzer_2d_node.launch index a9307762..b338b595 100644 --- a/ros/t_top/launch/tests/test_video_analyzer_2d_node.launch +++ b/ros/t_top/launch/tests/test_video_analyzer_2d_node.launch @@ -9,15 +9,16 @@ - - + + + - + - - + + diff --git a/ros/t_top/launch/tests/test_video_analyzer_3d_node.launch b/ros/t_top/launch/tests/test_video_analyzer_3d_node.launch index 00885738..d89d1c13 100644 --- a/ros/t_top/launch/tests/test_video_analyzer_3d_node.launch +++ b/ros/t_top/launch/tests/test_video_analyzer_3d_node.launch @@ -15,7 +15,8 @@ - + + diff --git a/ros/utils/dnn_utils/scripts/export_models.bash b/ros/utils/dnn_utils/scripts/export_models.bash index ff4be691..97ed3ca0 100755 --- 
a/ros/utils/dnn_utils/scripts/export_models.bash +++ b/ros/utils/dnn_utils/scripts/export_models.bash @@ -14,7 +14,7 @@ if [ -f Weights.zip ]; then OLD_TIME=$(stat Weights.zip -c %Y) fi -if OUT=$(wget -N https://introlab.3it.usherbrooke.ca/mediawiki-introlab/images/4/4e/Weights.zip 2>&1); then +if OUT=$(wget -N https://github.com/introlab/t-top/releases/download/DNN_Weights_v4.0.0/Weights.zip 2>&1); then # Output to stdout on success echo $OUT else @@ -43,11 +43,19 @@ trap 'jobs -p | xargs -I '{}' kill '{}' &> /dev/null; wait' INT QUIT KILL TERM set -e -python3 export_descriptor_yolo_v4.py --dataset_type coco --model_type yolo_v4_tiny --descriptor_size 128 --output_dir $SCRIPT_PATH/../models --torch_script_filename descriptor_yolo_v4.ts.pth --trt_filename descriptor_yolo_v4.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/descriptor_yolo_v4_tiny_3.pth --trt_fp16 $FORCE_EXPORT -python3 export_yolo_v4.py --model_type yolo_v4_tiny --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v4.ts.pth --trt_filename yolo_v4.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v4_tiny.pth --trt_fp16 $FORCE_EXPORT +python3 export_descriptor_yolo.py --dataset_type coco --model_type yolo_v4_tiny --descriptor_size 128 --output_dir $SCRIPT_PATH/../models --torch_script_filename descriptor_yolo_v4_tiny_coco.ts.pth --trt_filename descriptor_yolo_v4_tiny_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/descriptor_yolo_v4_tiny_coco.pth --trt_fp16 $FORCE_EXPORT +python3 export_descriptor_yolo.py --dataset_type coco --model_type yolo_v7 --descriptor_size 128 --output_dir $SCRIPT_PATH/../models --torch_script_filename descriptor_yolo_v7_coco.ts.pth --trt_filename descriptor_yolo_v7_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/descriptor_yolo_v7_coco.pth --trt_fp16 $FORCE_EXPORT -python3 export_pose_estimator.py --backbone_type resnet18 --upsampling_count 3 --output_dir $SCRIPT_PATH/../models --torch_script_filename pose_estimator.ts.pth --trt_filename pose_estimator.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/pose_estimator_resnet18_upsampling_count_3.pth --trt_fp16 $FORCE_EXPORT -python3 export_face_descriptor_extractor.py --embedding_size 256 --output_dir $SCRIPT_PATH/../models --torch_script_filename face_descriptor_extractor.ts.pth --trt_filename face_descriptor_extractor.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/face_descriptor_extractor_embedding_size_256.pth --trt_fp16 $FORCE_EXPORT +python3 export_yolo.py --dataset_type coco --model_type yolo_v4 --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v4_coco.ts.pth --trt_filename yolo_v4_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v4_coco.pth --trt_fp16 $FORCE_EXPORT +python3 export_yolo.py --dataset_type coco --model_type yolo_v4_tiny --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v4_tiny_coco.ts.pth --trt_filename yolo_v4_tiny_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v4_tiny_coco.pth --trt_fp16 $FORCE_EXPORT +python3 export_yolo.py --dataset_type coco --model_type yolo_v7 --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v7_coco.ts.pth --trt_filename yolo_v7_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v7_coco.pth --trt_fp16 $FORCE_EXPORT +python3 export_yolo.py --dataset_type coco --model_type yolo_v7_tiny --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v7_tiny_coco.ts.pth --trt_filename yolo_v7_tiny_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v7_tiny_coco.pth --trt_fp16 
$FORCE_EXPORT + +python3 export_yolo.py --dataset_type objects365 --model_type yolo_v7 --output_dir $SCRIPT_PATH/../models --torch_script_filename yolo_v7_objects365.ts.pth --trt_filename yolo_v7_objects365.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/yolo_v7_objects365.pth --trt_fp16 $FORCE_EXPORT + +python3 export_pose_estimator.py --backbone_type efficientnet_b0 --output_dir $SCRIPT_PATH/../models --torch_script_filename pose_estimator_efficientnet_b0.ts.pth --trt_filename pose_estimator_efficientnet_b0.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/pose_estimator_efficientnet_b0.pth --trt_fp16 $FORCE_EXPORT + +python3 export_face_descriptor_extractor.py --backbone_type open_face --embedding_size 256 --output_dir $SCRIPT_PATH/../models --torch_script_filename face_descriptor_open_face_e256.ts.pth --trt_filename face_descriptor_open_face_e256.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/face_descriptor_open_face_e256.pth --trt_fp16 $FORCE_EXPORT python3 export_semantic_segmentation_network.py --dataset_type coco --backbone_type stdc1 --channel_scale 1 --output_dir $SCRIPT_PATH/../models --torch_script_filename semantic_segmentation_network_coco.ts.pth --trt_filename semantic_segmentation_network_coco.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/semantic_segmentation_network_coco_stdc1_s1.pth --trt_fp16 $FORCE_EXPORT python3 export_semantic_segmentation_network.py --dataset_type kitchen_open_images --backbone_type stdc1 --channel_scale 1 --output_dir $SCRIPT_PATH/../models --torch_script_filename semantic_segmentation_network_kitchen_open_images.ts.pth --trt_filename semantic_segmentation_network_kitchen_open_images.trt.pth --model_checkpoint $SCRIPT_PATH/../weights/semantic_segmentation_network_kitchen_open_images_stdc1_s1.pth --trt_fp16 $FORCE_EXPORT diff --git a/ros/utils/dnn_utils/scripts/test.py b/ros/utils/dnn_utils/scripts/test.py index 49603016..5ebad9f2 100755 --- a/ros/utils/dnn_utils/scripts/test.py +++ b/ros/utils/dnn_utils/scripts/test.py @@ -8,7 +8,7 @@ import rospy -from dnn_utils import DescriptorYoloV4, YoloV4, PoseEstimator, FaceDescriptorExtractor +from dnn_utils import DescriptorYolo, Yolo, PoseEstimator, FaceDescriptorExtractor from dnn_utils import MulticlassAudioDescriptorExtractor, VoiceDescriptorExtractor, TTopKeywordSpotter from dnn_utils import SemanticSegmentationNetwork @@ -31,12 +31,12 @@ def launch_test(function, *args): print() -def test_descriptor_yolo_v4(): - print('----------test_descriptor_yolo_v4----------') +def test_descriptor_yolo(model_name): + print(f'----------test_descriptor_{model_name}----------') - cpu_model = DescriptorYoloV4(inference_type='cpu') - torch_gpu_model = DescriptorYoloV4(inference_type='torch_gpu') - trt_gpu_model = DescriptorYoloV4(inference_type='trt_gpu') + cpu_model = DescriptorYolo(model_name, inference_type='cpu') + torch_gpu_model = DescriptorYolo(model_name, inference_type='torch_gpu') + trt_gpu_model = DescriptorYolo(model_name, inference_type='trt_gpu') IMAGE_SIZE = cpu_model.get_supported_image_size() x = torch.rand(3, IMAGE_SIZE[0], IMAGE_SIZE[1]) @@ -52,19 +52,19 @@ def test_descriptor_yolo_v4(): mean_abs_diff(cpu_predictions[i], trt_gpu_predictions[i])) -def test_yolo_v4(): - print('----------test_yolo_v4----------') +def test_yolo(model_name): + print(f'----------test_{model_name}----------') - cpu_model = YoloV4(inference_type='cpu') - torch_gpu_model = YoloV4(inference_type='torch_gpu') - trt_gpu_model = YoloV4(inference_type='trt_gpu') + cpu_model = Yolo(model_name, inference_type='cpu') + 
torch_gpu_model = Yolo(model_name, inference_type='torch_gpu') + trt_gpu_model = Yolo(model_name, inference_type='trt_gpu') IMAGE_SIZE = cpu_model.get_supported_image_size() x = torch.rand(3, IMAGE_SIZE[0], IMAGE_SIZE[1]) - _, cpu_predictions = cpu_model.forward_raw(x) - _, torch_gpu_predictions = torch_gpu_model.forward_raw(x) - _, trt_gpu_predictions = trt_gpu_model.forward_raw(x) + _, _, _, cpu_predictions = cpu_model.forward_raw(x) + _, _, _, torch_gpu_predictions = torch_gpu_model.forward_raw(x) + _, _, _, trt_gpu_predictions = trt_gpu_model.forward_raw(x) for i in range(len(cpu_predictions)): print('mean(abs(cpu_predictions[{}] - torch_gpu_predictions[{}])) ='.format(i, i), @@ -98,7 +98,7 @@ def test_pose_estimator(): def test_face_descriptor_extractor(): - print('----------test_face_descriptor_extractor----------') + print(f'----------test_face_descriptor_extractor----------') cpu_model = FaceDescriptorExtractor(inference_type='cpu') torch_gpu_model = FaceDescriptorExtractor(inference_type='torch_gpu') @@ -111,10 +111,11 @@ def test_face_descriptor_extractor(): [0.75 * IMAGE_SIZE[1], 0.25 * IMAGE_SIZE[0]], [0.25 * IMAGE_SIZE[1], 0.25 * IMAGE_SIZE[0]]]) pose_presence = np.array([1.0, 1.0, 1.0, 0.0, 0.0]) + pose_confidence_threshold = 0.4 - cpu_descriptor = cpu_model(x, pose_coordinates, pose_presence)[0] - torch_gpu_descriptor = torch_gpu_model(x, pose_coordinates, pose_presence)[0] - trt_gpu_descriptor = trt_gpu_model(x, pose_coordinates, pose_presence)[0] + cpu_descriptor = cpu_model(x, pose_coordinates, pose_presence, pose_confidence_threshold)[0] + torch_gpu_descriptor = torch_gpu_model(x, pose_coordinates, pose_presence, pose_confidence_threshold)[0] + trt_gpu_descriptor = trt_gpu_model(x, pose_coordinates, pose_presence, pose_confidence_threshold)[0] print('mean(abs(cpu_descriptor - torch_gpu_descriptor)) =', mean_abs_diff(cpu_descriptor, torch_gpu_descriptor)) @@ -202,8 +203,13 @@ def test_semantic_segmentation_network(dataset): def main(): rospy.init_node('dnn_utils_test', disable_signals=True) - launch_test(test_descriptor_yolo_v4) - launch_test(test_yolo_v4) + launch_test(test_descriptor_yolo, 'yolo_v4_tiny_coco') + launch_test(test_descriptor_yolo, 'yolo_v7_coco') + launch_test(test_yolo, 'yolo_v4_coco') + launch_test(test_yolo, 'yolo_v4_tiny_coco') + launch_test(test_yolo, 'yolo_v7_coco') + launch_test(test_yolo, 'yolo_v7_tiny_coco') + launch_test(test_yolo, 'yolo_v7_objects365') launch_test(test_pose_estimator) launch_test(test_face_descriptor_extractor) launch_test(test_multiclass_audio_descriptor_extractor) diff --git a/ros/utils/dnn_utils/src/dnn_utils/__init__.py b/ros/utils/dnn_utils/src/dnn_utils/__init__.py index b8a16bd1..b217209f 100644 --- a/ros/utils/dnn_utils/src/dnn_utils/__init__.py +++ b/ros/utils/dnn_utils/src/dnn_utils/__init__.py @@ -1,10 +1,10 @@ from dnn_utils.dnn_model import DnnModel -from dnn_utils.descriptor_yolo_v4 import DescriptorYoloV4, DescriptorYoloV4Prediction +from dnn_utils.descriptor_yolo import DescriptorYolo, DescriptorYoloPrediction from dnn_utils.face_descriptor_extractor import FaceDescriptorExtractor from dnn_utils.pose_estimator import PoseEstimator from dnn_utils.ttop_keyword_spotter import TTopKeywordSpotter from dnn_utils.voice_descriptor_extractor import VoiceDescriptorExtractor from dnn_utils.multiclass_audio_descriptor_extractor import MulticlassAudioDescriptorExtractor -from dnn_utils.yolo_v4 import YoloV4, YoloV4Prediction +from dnn_utils.yolo import Yolo, YoloPrediction from dnn_utils.semantic_segmentation_network import 
SemanticSegmentationNetwork diff --git a/ros/utils/dnn_utils/src/dnn_utils/yolo_v4.py b/ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo.py similarity index 56% rename from ros/utils/dnn_utils/src/dnn_utils/yolo_v4.py rename to ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo.py index 30006a65..93bcf6f8 100644 --- a/ros/utils/dnn_utils/src/dnn_utils/yolo_v4.py +++ b/ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo.py @@ -6,58 +6,67 @@ import torchaudio from dnn_utils.dnn_model import PACKAGE_PATH, DnnModel +from dnn_utils.yolo import COCO_CLASS_NAMES, OBJECTS365_CLASS_NAMES sys.path.append(os.path.join(PACKAGE_PATH, '..', '..', '..', 'tools', 'dnn_training')) from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CONFIDENCE_INDEX, CLASSES_INDEX from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions -IMAGE_SIZE = (416, 416) +IMAGE_SIZE_BY_MODEL_NAME = { + 'yolo_v4_tiny_coco' : (416, 416), + 'yolo_v7_coco' : (640, 640), + 'yolo_v7_objects365' : (640, 640), +} + IN_CHANNELS = 3 -CLASS_COUNT = 80 + +CLASS_NAMES_BY_MODEL_NAME = { + 'yolo_v4_tiny_coco' : COCO_CLASS_NAMES, + 'yolo_v7_coco' : COCO_CLASS_NAMES, + 'yolo_v7_objects365' : OBJECTS365_CLASS_NAMES, +} -class YoloV4Prediction: - def __init__(self, prediction_tensor, scale): +class DescriptorYoloPrediction: + def __init__(self, prediction_tensor, scale, class_count): self.center_x = (prediction_tensor[X_INDEX] / scale).item() self.center_y = (prediction_tensor[Y_INDEX] / scale).item() self.width = (prediction_tensor[W_INDEX] / scale).item() self.height = (prediction_tensor[H_INDEX] / scale).item() self.confidence = prediction_tensor[CONFIDENCE_INDEX].item() - class_probabilities = F.softmax(prediction_tensor[CLASSES_INDEX:CLASSES_INDEX + CLASS_COUNT], dim=0) + class_probabilities = F.softmax(prediction_tensor[CLASSES_INDEX:CLASSES_INDEX + class_count], dim=0) self.class_index = torch.argmax(class_probabilities, dim=0).item() self.class_probabilities = class_probabilities.tolist() - self.descriptor = [] + self.descriptor = prediction_tensor[CLASSES_INDEX + class_count:].tolist() -class YoloV4(DnnModel): - def __init__(self, confidence_threshold=0.99, nms_threshold=0.5, inference_type=None): +class DescriptorYolo(DnnModel): + def __init__(self, model_name, confidence_threshold=0.99, nms_threshold=0.5, inference_type=None): + if model_name not in IMAGE_SIZE_BY_MODEL_NAME: + raise ValueError(f'Invalid model name ({model_name})') + + self._model_name = model_name self._confidence_threshold = confidence_threshold self._nms_threshold = nms_threshold - torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', 'yolo_v4.ts.pth') - tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', 'yolo_v4.trt.pth') - sample_input = torch.ones(1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]) + self._image_size = IMAGE_SIZE_BY_MODEL_NAME[model_name] + + torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', f'descriptor_{model_name}.ts.pth') + tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', f'descriptor_{model_name}.trt.pth') + sample_input = torch.ones(1, 3, self._image_size[0], self._image_size[1]) - super(YoloV4, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, - inference_type=inference_type) - self._padded_image = torch.ones(1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(self._device) + super(DescriptorYolo, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, + inference_type=inference_type) + self._padded_image = torch.ones(1, 
3, self._image_size[0], self._image_size[1]).to(self._device) def get_supported_image_size(self): - return IMAGE_SIZE + return IMAGE_SIZE_BY_MODEL_NAME[self._model_name] def get_class_names(self): - return ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light', - 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', - 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', - 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', - 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', - 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', - 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + return CLASS_NAMES_BY_MODEL_NAME[self._model_name] def __call__(self, image_tensor): with torch.no_grad(): @@ -67,11 +76,12 @@ def __call__(self, image_tensor): confidence_threshold=self._confidence_threshold, nms_threshold=self._nms_threshold) - return [YoloV4Prediction(prediction.cpu(), scale) for prediction in predictions] + class_count = len(CLASS_NAMES_BY_MODEL_NAME[self._model_name]) + return [DescriptorYoloPrediction(prediction.cpu(), scale, class_count) for prediction in predictions] def forward_raw(self, image_tensor): scale = self._set_image(image_tensor.to(self._device).unsqueeze(0)) - predictions = super(YoloV4, self).__call__(self._padded_image) + predictions = super(DescriptorYolo, self).__call__(self._padded_image) return scale, predictions def _set_image(self, image_tensor): @@ -80,5 +90,5 @@ def _set_image(self, image_tensor): image_tensor = F.interpolate(image_tensor, size=output_size, mode='bilinear') self._padded_image[:, :, :image_tensor.size()[2], :image_tensor.size()[3]] = image_tensor - self._padded_image[:, :, image_tensor.size()[2]:, image_tensor.size()[3]:] = 128 + self._padded_image[:, :, image_tensor.size()[2]:, image_tensor.size()[3]:] = 0.44705882352 return scale diff --git a/ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo_v4.py b/ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo_v4.py deleted file mode 100644 index 3f4eff87..00000000 --- a/ros/utils/dnn_utils/src/dnn_utils/descriptor_yolo_v4.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import sys - -import torch -import torch.nn.functional as F -import torchaudio - -from dnn_utils.dnn_model import PACKAGE_PATH, DnnModel - -sys.path.append(os.path.join(PACKAGE_PATH, '..', '..', '..', 'tools', 'dnn_training')) -from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CONFIDENCE_INDEX, CLASSES_INDEX -from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions - - -IMAGE_SIZE = (416, 416) -IN_CHANNELS = 3 -CLASS_COUNT = 80 - - -class DescriptorYoloV4Prediction: - def __init__(self, prediction_tensor, scale): - self.center_x = (prediction_tensor[X_INDEX] / scale).item() - self.center_y = (prediction_tensor[Y_INDEX] / scale).item() - self.width = (prediction_tensor[W_INDEX] / scale).item() - self.height = (prediction_tensor[H_INDEX] / scale).item() - self.confidence = prediction_tensor[CONFIDENCE_INDEX].item() - - class_probabilities = 
F.softmax(prediction_tensor[CLASSES_INDEX:CLASSES_INDEX + CLASS_COUNT], dim=0) - self.class_index = torch.argmax(class_probabilities, dim=0).item() - self.class_probabilities = class_probabilities.tolist() - - self.descriptor = prediction_tensor[CLASSES_INDEX + CLASS_COUNT:].tolist() - - -class DescriptorYoloV4(DnnModel): - def __init__(self, confidence_threshold=0.99, nms_threshold=0.5, inference_type=None): - self._confidence_threshold = confidence_threshold - self._nms_threshold = nms_threshold - - torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', 'descriptor_yolo_v4.ts.pth') - tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', 'descriptor_yolo_v4.trt.pth') - sample_input = torch.ones(1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]) - - super(DescriptorYoloV4, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, - inference_type=inference_type) - self._padded_image = torch.ones(1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]).to(self._device) - - def get_supported_image_size(self): - return IMAGE_SIZE - - def get_class_names(self): - return ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light', - 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', - 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', - 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', - 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', - 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', - 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] - - def __call__(self, image_tensor): - with torch.no_grad(): - scale, predictions = self.forward_raw(image_tensor.to(self._device)) - predictions = group_predictions(predictions)[0] - predictions = filter_yolo_predictions(predictions, - confidence_threshold=self._confidence_threshold, - nms_threshold=self._nms_threshold) - - return [DescriptorYoloV4Prediction(prediction.cpu(), scale) for prediction in predictions] - - def forward_raw(self, image_tensor): - scale = self._set_image(image_tensor.to(self._device).unsqueeze(0)) - predictions = super(DescriptorYoloV4, self).__call__(self._padded_image) - return scale, predictions - - def _set_image(self, image_tensor): - scale = min(self._padded_image.size()[2] / image_tensor.size()[2], self._padded_image.size()[3] / image_tensor.size()[3]) - output_size = ((int(image_tensor.size()[2] * scale), int(image_tensor.size()[3] * scale))) - - image_tensor = F.interpolate(image_tensor, size=output_size, mode='bilinear') - self._padded_image[:, :, :image_tensor.size()[2], :image_tensor.size()[3]] = image_tensor - self._padded_image[:, :, image_tensor.size()[2]:, image_tensor.size()[3]:] = 128 - return scale diff --git a/ros/utils/dnn_utils/src/dnn_utils/face_descriptor_extractor.py b/ros/utils/dnn_utils/src/dnn_utils/face_descriptor_extractor.py index c2aee481..6593daa7 100644 --- a/ros/utils/dnn_utils/src/dnn_utils/face_descriptor_extractor.py +++ b/ros/utils/dnn_utils/src/dnn_utils/face_descriptor_extractor.py @@ -6,6 +6,7 @@ import cv2 import torch +import torch.nn as nn import torch.nn.functional as F import 
torchvision.transforms as transforms @@ -16,23 +17,29 @@ IMAGE_SIZE = (128, 96) +SHARPNESS_SCORE_SCALE = 2.0 class FaceDescriptorExtractor(DnnModel): def __init__(self, inference_type=None): - torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', 'face_descriptor_extractor.ts.pth') - tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', 'face_descriptor_extractor.trt.pth') + torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', f'face_descriptor_open_face_e256.ts.pth') + tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', f'face_descriptor_open_face_e256.trt.pth') sample_input = torch.ones((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1])) super(FaceDescriptorExtractor, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, inference_type=inference_type) self._normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self._sharpness_score_kernel = torch.tensor([[-1.0, -1.0, -1.0], + [-1.0, 8.0, -1.0], + [-1.0, -1.0, -1.0]], device=self._device) + self._sharpness_score_kernel = self._sharpness_score_kernel.repeat(3, 1, 1).unsqueeze(0) + def get_supported_image_size(self): return IMAGE_SIZE - def __call__(self, image_tensor, pose_coordinates, pose_presence): - landmarks, theoretical_landmark = get_landmarks_from_pose(pose_coordinates, pose_presence) + def __call__(self, image_tensor, pose_coordinates, pose_presence, pose_confidence_threshold): + landmarks, theoretical_landmark, alignment_keypoint_count = get_landmarks_from_pose(pose_coordinates, pose_presence, pose_confidence_threshold) transform = cv2.getAffineTransform(landmarks.astype(np.float32), (theoretical_landmark * np.array((IMAGE_SIZE[1], IMAGE_SIZE[0]))).astype(np.float32)) try: @@ -42,8 +49,10 @@ def __call__(self, image_tensor, pose_coordinates, pose_presence): with torch.no_grad(): grid = F.affine_grid(theta, torch.Size((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]))).to(self._device) - aligned_image = F.grid_sample(image_tensor.unsqueeze(0).to(self._device), grid, mode='nearest').squeeze(0) + aligned_image = F.grid_sample(image_tensor.unsqueeze(0).to(self._device), grid, mode='bilinear').squeeze(0) + sharpness_score = SHARPNESS_SCORE_SCALE * torch.std(F.conv2d(aligned_image, self._sharpness_score_kernel)).item() cv2_aligned_image = (255 * aligned_image.permute(1, 2, 0)).to(torch.uint8).cpu().numpy() normalized_aligned_image = self._normalization(aligned_image) - return super(FaceDescriptorExtractor, self).__call__(normalized_aligned_image.unsqueeze(0))[0].cpu(), cv2_aligned_image + descriptor = super(FaceDescriptorExtractor, self).__call__(normalized_aligned_image.unsqueeze(0))[0].cpu() + return descriptor, cv2_aligned_image, alignment_keypoint_count, sharpness_score diff --git a/ros/utils/dnn_utils/src/dnn_utils/pose_estimator.py b/ros/utils/dnn_utils/src/dnn_utils/pose_estimator.py index 87ad51f2..691b690a 100644 --- a/ros/utils/dnn_utils/src/dnn_utils/pose_estimator.py +++ b/ros/utils/dnn_utils/src/dnn_utils/pose_estimator.py @@ -12,12 +12,13 @@ IMAGE_SIZE = (256, 192) +PRESENCE_SCALE = 4.0 class PoseEstimator(DnnModel): def __init__(self, inference_type=None): - torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', 'pose_estimator.ts.pth') - tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', 'pose_estimator.trt.pth') + torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', 'pose_estimator_efficientnet_b0.ts.pth') + tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', 'pose_estimator_efficientnet_b0.trt.pth') sample_input = 
torch.ones((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1])) super(PoseEstimator, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, @@ -81,4 +82,4 @@ def __call__(self, image_tensor): scaled_coordinates[:, 0] = heatmap_coordinates[0, :, 0] / pose_heatmaps.size()[3] * width scaled_coordinates[:, 1] = heatmap_coordinates[0, :, 1] / pose_heatmaps.size()[2] * height - return scaled_coordinates.cpu().numpy(), presence.cpu().numpy()[0] + return scaled_coordinates.cpu().numpy(), presence.cpu().numpy()[0] * PRESENCE_SCALE diff --git a/ros/utils/dnn_utils/src/dnn_utils/yolo.py b/ros/utils/dnn_utils/src/dnn_utils/yolo.py new file mode 100644 index 00000000..15260302 --- /dev/null +++ b/ros/utils/dnn_utils/src/dnn_utils/yolo.py @@ -0,0 +1,159 @@ +import os +import sys + +import torch +import torch.nn.functional as F + +from dnn_utils.dnn_model import PACKAGE_PATH, DnnModel + +sys.path.append(os.path.join(PACKAGE_PATH, '..', '..', '..', 'tools', 'dnn_training')) +from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CONFIDENCE_INDEX, CLASSES_INDEX +from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions_by_classes + + +IMAGE_SIZE_BY_MODEL_NAME = { + 'yolo_v4_coco' : (608, 608), + 'yolo_v4_tiny_coco' : (416, 416), + 'yolo_v7_coco' : (640, 640), + 'yolo_v7_tiny_coco' : (640, 640), + 'yolo_v7_objects365' : (640, 640), +} +IN_CHANNELS = 3 + +COCO_CLASS_NAMES = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', + 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed', 'dining table', + 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair dryer', + 'toothbrush'] + +OBJECTS365_CLASS_NAMES = ['person', 'sneakers', 'chair', 'other shoes', 'hat', 'car', 'lamp', 'glasses', 'bottle', + 'desk', 'cup', 'street lights', 'cabinet/shelf', 'handbag', 'bracelet', 'plate', + 'picture/frame', 'helmet', 'book', 'gloves', 'storage box', 'boat', 'leather shoes', + 'flower', 'bench', 'pottedplant', 'bowl', 'flag', 'pillow', 'boots', 'vase', + 'microphone', 'necklace', 'ring', 'suv', 'wine glass', 'belt', 'tvmonitor', 'backpack', + 'umbrella', 'traffic light', 'speaker', 'watch', 'tie', 'trash bin can', 'slippers', + 'bicycle', 'stool', 'barrel/bucket', 'van', 'couch', 'sandals', 'basket', 'drum', + 'pen/pencil', 'bus', 'bird', 'high heels', 'motorbike', 'guitar', 'carpet', + 'cell phone', 'bread', 'camera', 'canned', 'truck', 'traffic cone', 'cymbal', 'lifesaver', + 'towel', 'stuffed toy', 'candle', 'sailboat', 'laptop', 'awning', 'bed', 'faucet', 'tent', + 'horse', 'mirror', 'power outlet', 'sink', 'apple', 'air conditioner', 'knife', + 'hockey stick', 'paddle', 'pickup truck', 'fork', 'traffic sign', 'balloon', 'tripod', 'dog', + 'spoon', 'clock', 'pot', 'cow', 'cake', 'dining table', 'sheep', 'hanger', + 'blackboard/whiteboard', 'napkin', 'other fish', 'orange', 
'toiletry', 'keyboard', + 'tomato', 'lantern', 'machinery vehicle', 'fan', 'green vegetables', 'banana', + 'baseball glove', 'aeroplane', 'mouse', 'train', 'pumpkin', 'soccer', 'skis', 'luggage', + 'nightstand', 'tea pot', 'telephone', 'trolley', 'head phone', 'sports car', 'stop sign', + 'dessert', 'scooter', 'stroller', 'crane', 'remote', 'refrigerator', 'oven', 'lemon', 'duck', + 'baseball bat', 'surveillance camera', 'cat', 'jug', 'broccoli', 'piano', 'pizza', + 'elephant', 'skateboard', 'surfboard', 'gun', 'skating and skiing shoes', 'gas stove', + 'donut', 'bow tie', 'carrot', 'toilet', 'kite', 'strawberry', 'other balls', 'shovel', + 'pepper', 'computer box', 'toilet paper', 'cleaning products', 'chopsticks', 'microwave', + 'pigeon', 'baseball', 'cutting/chopping board', 'coffee table', 'side table', 'scissors', + 'marker', 'pie', 'ladder', 'snowboard', 'cookies', 'radiator', 'fire hydrant', 'basketball', + 'zebra', 'grape', 'giraffe', 'potato', 'sausage', 'tricycle', 'violin', 'egg', + 'fire extinguisher', 'candy', 'fire truck', 'billiards', 'converter', 'bathtub', + 'wheelchair', 'golf club', 'suitcase', 'cucumber', 'cigar/cigarette', 'paint brush', 'pear', + 'heavy truck', 'hamburger', 'extractor', 'extension cord', 'tong', 'tennis racket', + 'folder', 'american football', 'earphone', 'mask', 'kettle', 'tennis', 'ship', 'swing', + 'coffee machine', 'slide', 'carriage', 'onion', 'green beans', 'projector', 'frisbee', + 'washing machine/drying machine', 'chicken', 'printer', 'watermelon', 'saxophone', 'tissue', + 'toothbrush', 'ice cream', 'hot-air balloon', 'cello', 'french fries', 'scale', 'trophy', + 'cabbage', 'hot dog', 'blender', 'peach', 'rice', 'wallet/purse', 'volleyball', 'deer', + 'goose', 'tape', 'tablet', 'cosmetics', 'trumpet', 'pineapple', 'golf ball', 'ambulance', + 'parking meter', 'mango', 'key', 'hurdle', 'fishing rod', 'medal', 'flute', 'brush', + 'penguin', 'megaphone', 'corn', 'lettuce', 'garlic', 'swan', 'helicopter', 'green onion', + 'sandwich', 'nuts', 'speed limit sign', 'induction cooker', 'broom', 'trombone', 'plum', + 'rickshaw', 'goldfish', 'kiwi fruit', 'router/modem', 'poker card', 'toaster', 'shrimp', + 'sushi', 'cheese', 'notepaper', 'cherry', 'pliers', 'cd', 'pasta', 'hammer', 'cue', + 'avocado', 'hamimelon', 'flask', 'mushroom', 'screwdriver', 'soap', 'recorder', 'bear', + 'eggplant', 'board eraser', 'coconut', 'tape measure/ruler', 'pig', 'showerhead', 'globe', + 'chips', 'steak', 'crosswalk sign', 'stapler', 'camel', 'formula 1', 'pomegranate', + 'dishwasher', 'crab', 'hoverboard', 'meat ball', 'rice cooker', 'tuba', 'calculator', + 'papaya', 'antelope', 'parrot', 'seal', 'butterfly', 'dumbbell', 'donkey', 'lion', 'urinal', + 'dolphin', 'electric drill', 'hair dryer', 'egg tart', 'jellyfish', 'treadmill', 'lighter', + 'grapefruit', 'game board', 'mop', 'radish', 'baozi', 'target', 'french', 'spring rolls', + 'monkey', 'rabbit', 'pencil case', 'yak', 'red cabbage', 'binoculars', 'asparagus', 'barbell', + 'scallop', 'noddles', 'comb', 'dumpling', 'oyster', 'table tennis paddle', + 'cosmetics brush/eyeliner pencil', 'chainsaw', 'eraser', 'lobster', 'durian', 'okra', + 'lipstick', 'cosmetics mirror', 'curling', 'table tennis'] + +CLASS_NAMES_BY_MODEL_NAME = { + 'yolo_v4_coco' : COCO_CLASS_NAMES, + 'yolo_v4_tiny_coco' : COCO_CLASS_NAMES, + 'yolo_v7_coco' : COCO_CLASS_NAMES, + 'yolo_v7_tiny_coco' : COCO_CLASS_NAMES, + 'yolo_v7_objects365' : OBJECTS365_CLASS_NAMES, +} + + +class YoloPrediction: + def __init__(self, prediction_tensor, scale, offset_x, 
offset_y, class_count): + self.center_x = ((prediction_tensor[X_INDEX] - offset_x) / scale).item() + self.center_y = ((prediction_tensor[Y_INDEX] - offset_y) / scale).item() + self.width = (prediction_tensor[W_INDEX] / scale).item() + self.height = (prediction_tensor[H_INDEX] / scale).item() + self.confidence = prediction_tensor[CONFIDENCE_INDEX].item() + + class_probabilities = prediction_tensor[CLASSES_INDEX:CLASSES_INDEX + class_count] + self.class_index = torch.argmax(class_probabilities, dim=0).item() + self.class_probabilities = class_probabilities.tolist() + + self.descriptor = [] + + +class Yolo(DnnModel): + def __init__(self, model_name, confidence_threshold=0.99, nms_threshold=0.5, inference_type=None): + if model_name not in IMAGE_SIZE_BY_MODEL_NAME: + raise ValueError(f'Invalid model name ({model_name})') + + self._model_name = model_name + self._confidence_threshold = confidence_threshold + self._nms_threshold = nms_threshold + + self._image_size = IMAGE_SIZE_BY_MODEL_NAME[model_name] + + torch_script_model_path = os.path.join(PACKAGE_PATH, 'models', f'{model_name}.ts.pth') + tensor_rt_model_path = os.path.join(PACKAGE_PATH, 'models', f'{model_name}.trt.pth') + sample_input = torch.ones(1, 3, self._image_size[0], self._image_size[1]) + + super(Yolo, self).__init__(torch_script_model_path, tensor_rt_model_path, sample_input, + inference_type=inference_type) + self._padded_image = torch.ones(1, 3, self._image_size[0], self._image_size[1]).to(self._device) + + def get_supported_image_size(self): + return IMAGE_SIZE_BY_MODEL_NAME[self._model_name] + + def get_class_names(self): + return CLASS_NAMES_BY_MODEL_NAME[self._model_name] + + def __call__(self, image_tensor): + with torch.no_grad(): + scale, offset_x, offset_y, predictions = self.forward_raw(image_tensor.to(self._device)) + predictions = group_predictions(predictions)[0] + predictions = filter_yolo_predictions_by_classes(predictions, + confidence_threshold=self._confidence_threshold, + nms_threshold=self._nms_threshold) + + class_count = len(CLASS_NAMES_BY_MODEL_NAME[self._model_name]) + return [YoloPrediction(prediction.cpu(), scale, offset_x, offset_y, class_count) for prediction in predictions] + + def forward_raw(self, image_tensor): + scale, offset_x, offset_y = self._set_image(image_tensor.to(self._device).unsqueeze(0)) + predictions = super(Yolo, self).__call__(self._padded_image) + return scale, offset_x, offset_y, predictions + + def _set_image(self, image_tensor): + scale = min(self._padded_image.size()[2] / image_tensor.size()[2], self._padded_image.size()[3] / image_tensor.size()[3]) + output_size = ((int(image_tensor.size()[2] * scale), int(image_tensor.size()[3] * scale))) + offset_y = int((self._padded_image.size()[2] - output_size[0]) / 2) + offset_x = int((self._padded_image.size()[3] - output_size[1]) / 2) + + image_tensor = F.interpolate(image_tensor, size=output_size, mode='bilinear') + self._padded_image[:] = 0.44705882352 + self._padded_image[:, :, offset_y:offset_y + image_tensor.size()[2], offset_x:offset_x + image_tensor.size()[3]] = image_tensor + return scale, offset_x, offset_y diff --git a/ros/utils/recorders/CMakeLists.txt b/ros/utils/recorders/CMakeLists.txt index 2b690675..a7f35eb2 100644 --- a/ros/utils/recorders/CMakeLists.txt +++ b/ros/utils/recorders/CMakeLists.txt @@ -139,11 +139,13 @@ add_library(perception_logger src/perception_logger/VideoAnalysisLogger.cpp src/perception_logger/BinarySerialization.cpp src/perception_logger/SpeechLogger.cpp + 
src/perception_logger/HbbaStrategyStateLogger.cpp src/perception_logger/sqlite/SQLiteMigration.cpp src/perception_logger/sqlite/SQLitePerceptionLogger.cpp src/perception_logger/sqlite/SQLiteAudioAnalysisLogger.cpp src/perception_logger/sqlite/SQLiteVideoAnalysisLogger.cpp src/perception_logger/sqlite/SQLiteSpeechLogger.cpp + src/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.cpp ) ## Add cmake target dependencies of the library @@ -232,11 +234,13 @@ catkin_add_gtest(perception_logger-test test/perception_logger/AudioAnalysisLoggerTests.cpp test/perception_logger/VideoAnalysisLoggerTests.cpp test/perception_logger/SpeechLoggerTests.cpp + test/perception_logger/HbbaStrategyStateLoggerTests.cpp test/perception_logger/sqlite/SQLiteMigrationTests.cpp test/perception_logger/sqlite/SQLitePerceptionLoggerTests.cpp test/perception_logger/sqlite/SQLiteAudioAnalysisLoggerTests.cpp test/perception_logger/sqlite/SQLiteVideoAnalysisLoggerTests.cpp test/perception_logger/sqlite/SQLiteSpeechLoggerTests.cpp + test/perception_logger/sqlite/SQLiteHbbaStrategyStateLoggerTests.cpp ) if(TARGET perception_logger-test) target_link_libraries(perception_logger-test perception_logger ${catkin_LIBRARIES}) diff --git a/ros/utils/recorders/include/perception_logger/AudioAnalysisLogger.h b/ros/utils/recorders/include/perception_logger/AudioAnalysisLogger.h index d05ac7bf..973cb4b4 100644 --- a/ros/utils/recorders/include/perception_logger/AudioAnalysisLogger.h +++ b/ros/utils/recorders/include/perception_logger/AudioAnalysisLogger.h @@ -11,11 +11,17 @@ struct AudioAnalysis { Timestamp timestamp; Direction direction; + int64_t trackingId; std::string classes; std::optional> voiceDescriptor; - AudioAnalysis(Timestamp timestamp, Direction direction, std::string classes); - AudioAnalysis(Timestamp timestamp, Direction direction, std::string classes, std::vector voiceDescriptor); + AudioAnalysis(Timestamp timestamp, Direction direction, int64_t trackingId, std::string classes); + AudioAnalysis( + Timestamp timestamp, + Direction direction, + int64_t trackingId, + std::string classes, + std::vector voiceDescriptor); }; class AudioAnalysisLogger diff --git a/ros/utils/recorders/include/perception_logger/HbbaStrategyStateLogger.h b/ros/utils/recorders/include/perception_logger/HbbaStrategyStateLogger.h new file mode 100644 index 00000000..60865f96 --- /dev/null +++ b/ros/utils/recorders/include/perception_logger/HbbaStrategyStateLogger.h @@ -0,0 +1,27 @@ +#ifndef RECORDERS_PERCEPTION_LOGGER_HBBA_STRATEGY_STATE_LOGGER_H +#define RECORDERS_PERCEPTION_LOGGER_HBBA_STRATEGY_STATE_LOGGER_H + +#include + +#include + +struct HbbaStrategyState +{ + Timestamp timestamp; + std::string desireTypeName; + std::string strategyTypeName; + bool enabled; + + HbbaStrategyState(Timestamp timestamp, std::string desireTypeName, std::string strategyTypeName, bool enabled); +}; + +class HbbaStrategyStateLogger +{ +public: + HbbaStrategyStateLogger(); + virtual ~HbbaStrategyStateLogger(); + + virtual int64_t log(const HbbaStrategyState& state) = 0; +}; + +#endif diff --git a/ros/utils/recorders/include/perception_logger/VideoAnalysisLogger.h b/ros/utils/recorders/include/perception_logger/VideoAnalysisLogger.h index 812ace17..ed71f975 100644 --- a/ros/utils/recorders/include/perception_logger/VideoAnalysisLogger.h +++ b/ros/utils/recorders/include/perception_logger/VideoAnalysisLogger.h @@ -14,6 +14,8 @@ struct VideoAnalysis Direction direction; std::string objectClass; + float objectConfidence; + float objectClassProbability; BoundingBox 
boundingBox; std::optional> personPoseImage; @@ -21,18 +23,24 @@ struct VideoAnalysis std::optional> personPoseConfidence; std::optional> faceDescriptor; + std::optional faceAlignmentKeypointCount; + std::optional faceSharpnessScore; VideoAnalysis( Timestamp timestamp, Position position, Direction direction, std::string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox); VideoAnalysis( Timestamp timestamp, Position position, Direction direction, std::string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox, std::vector personPoseImage, std::vector personPose, @@ -42,11 +50,15 @@ struct VideoAnalysis Position position, Direction direction, std::string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox, std::vector personPoseImage, std::vector personPose, std::vector personPoseConfidence, - std::vector faceDescriptor); + std::vector faceDescriptor, + int32_t faceAlignmentKeypointCount, + float faceSharpnessScore); }; class VideoAnalysisLogger diff --git a/ros/utils/recorders/include/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.h b/ros/utils/recorders/include/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.h new file mode 100644 index 00000000..1b5742b5 --- /dev/null +++ b/ros/utils/recorders/include/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.h @@ -0,0 +1,20 @@ +#ifndef RECORDERS_PERCEPTION_LOGGER_SQLITE_SQLITE_HBBA_STRATEGY_STATE_LOGGER_H +#define RECORDERS_PERCEPTION_LOGGER_SQLITE_SQLITE_HBBA_STRATEGY_STATE_LOGGER_H + +#include + +#include + +class SQLiteHbbaStrategyStateLogger : public HbbaStrategyStateLogger +{ + SQLite::Database& m_database; + +public: + SQLiteHbbaStrategyStateLogger(SQLite::Database& database); + ~SQLiteHbbaStrategyStateLogger() override; + + int64_t log(const HbbaStrategyState& state) override; +}; + + +#endif diff --git a/ros/utils/recorders/src/perception_logger/AudioAnalysisLogger.cpp b/ros/utils/recorders/src/perception_logger/AudioAnalysisLogger.cpp index 0d1f1870..36f47d55 100644 --- a/ros/utils/recorders/src/perception_logger/AudioAnalysisLogger.cpp +++ b/ros/utils/recorders/src/perception_logger/AudioAnalysisLogger.cpp @@ -2,16 +2,23 @@ using namespace std; -AudioAnalysis::AudioAnalysis(Timestamp timestamp, Direction direction, string classes) +AudioAnalysis::AudioAnalysis(Timestamp timestamp, Direction direction, int64_t trackingId, string classes) : timestamp(timestamp), direction(direction), + trackingId(trackingId), classes(move(classes)) { } -AudioAnalysis::AudioAnalysis(Timestamp timestamp, Direction direction, string classes, vector voiceDescriptor) +AudioAnalysis::AudioAnalysis( + Timestamp timestamp, + Direction direction, + int64_t trackingId, + string classes, + vector voiceDescriptor) : timestamp(timestamp), direction(direction), + trackingId(trackingId), classes(move(classes)), voiceDescriptor(move(voiceDescriptor)) { diff --git a/ros/utils/recorders/src/perception_logger/HbbaStrategyStateLogger.cpp b/ros/utils/recorders/src/perception_logger/HbbaStrategyStateLogger.cpp new file mode 100644 index 00000000..2aa2a31d --- /dev/null +++ b/ros/utils/recorders/src/perception_logger/HbbaStrategyStateLogger.cpp @@ -0,0 +1,15 @@ +#include + +using namespace std; + +HbbaStrategyState::HbbaStrategyState(Timestamp timestamp, string desireTypeName, string strategyTypeName, bool enabled) + : timestamp(timestamp), + desireTypeName(move(desireTypeName)), + strategyTypeName(move(strategyTypeName)), + 
enabled(enabled) +{ +} + +HbbaStrategyStateLogger::HbbaStrategyStateLogger() {} + +HbbaStrategyStateLogger::~HbbaStrategyStateLogger() {} diff --git a/ros/utils/recorders/src/perception_logger/VideoAnalysisLogger.cpp b/ros/utils/recorders/src/perception_logger/VideoAnalysisLogger.cpp index 16f7dcaf..91877113 100644 --- a/ros/utils/recorders/src/perception_logger/VideoAnalysisLogger.cpp +++ b/ros/utils/recorders/src/perception_logger/VideoAnalysisLogger.cpp @@ -7,11 +7,15 @@ VideoAnalysis::VideoAnalysis( Position position, Direction direction, string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox) : timestamp{timestamp}, position{position}, direction{direction}, objectClass{move(objectClass)}, + objectConfidence{objectConfidence}, + objectClassProbability{objectClassProbability}, boundingBox{boundingBox}, personPoseImage{std::nullopt}, personPose{std::nullopt}, @@ -24,6 +28,8 @@ VideoAnalysis::VideoAnalysis( Position position, Direction direction, string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox, vector personPoseImage, vector personPose, @@ -32,6 +38,8 @@ VideoAnalysis::VideoAnalysis( position{position}, direction{direction}, objectClass{move(objectClass)}, + objectConfidence(objectConfidence), + objectClassProbability(objectClassProbability), boundingBox{boundingBox}, personPoseImage{move(personPoseImage)}, personPose{move(personPose)}, @@ -39,25 +47,34 @@ VideoAnalysis::VideoAnalysis( faceDescriptor{std::nullopt} { } + VideoAnalysis::VideoAnalysis( Timestamp timestamp, Position position, Direction direction, string objectClass, + float objectConfidence, + float objectClassProbability, BoundingBox boundingBox, vector personPoseImage, vector personPose, vector personPoseConfidence, - vector faceDescriptor) + vector faceDescriptor, + int32_t faceAlignmentKeypointCount, + float faceSharpnessScore) : timestamp{timestamp}, position{position}, direction{direction}, objectClass{move(objectClass)}, + objectConfidence(objectConfidence), + objectClassProbability(objectClassProbability), boundingBox{boundingBox}, personPoseImage{move(personPoseImage)}, personPose{move(personPose)}, personPoseConfidence{move(personPoseConfidence)}, - faceDescriptor{move(faceDescriptor)} + faceDescriptor{move(faceDescriptor)}, + faceAlignmentKeypointCount{faceAlignmentKeypointCount}, + faceSharpnessScore{faceSharpnessScore} { } diff --git a/ros/utils/recorders/src/perception_logger/sqlite/SQLiteAudioAnalysisLogger.cpp b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteAudioAnalysisLogger.cpp index 0630ddc2..086d5337 100644 --- a/ros/utils/recorders/src/perception_logger/sqlite/SQLiteAudioAnalysisLogger.cpp +++ b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteAudioAnalysisLogger.cpp @@ -7,13 +7,17 @@ using namespace std; SQLiteAudioAnalysisLogger::SQLiteAudioAnalysisLogger(SQLite::Database& database) : SQLitePerceptionLogger(database) { - vector migrations{SQLiteMigration("BEGIN;" - "CREATE TABLE audio_analysis(" - " perception_id INTEGER PRIMARY KEY," - " classes TEXT," - " voice_descriptor BLOB" - ");" - "COMMIT;")}; + vector migrations{ + SQLiteMigration("BEGIN;" + "CREATE TABLE audio_analysis(" + " perception_id INTEGER PRIMARY KEY," + " classes TEXT," + " voice_descriptor BLOB" + ");" + "COMMIT;"), + SQLiteMigration("BEGIN;" + "ALTER TABLE audio_analysis ADD tracking_id INTEGER;" + "COMMIT;")}; applyMigrations(database, "audio_analysis", migrations); } @@ -25,16 +29,17 @@ int64_t 
SQLiteAudioAnalysisLogger::log(const AudioAnalysis& analysis) int64_t id = insertPerception(analysis.timestamp, nullopt, analysis.direction); SQLite::Statement insert( m_database, - "INSERT INTO audio_analysis(perception_id, classes, voice_descriptor) VALUES(?, ?, ?)"); + "INSERT INTO audio_analysis(perception_id, tracking_id, classes, voice_descriptor) VALUES(?, ?, ?, ?)"); insert.clearBindings(); insert.bind(1, id); - insert.bindNoCopy(2, analysis.classes); + insert.bind(2, analysis.trackingId); + insert.bindNoCopy(3, analysis.classes); optional voiceDescriptorBytes; if (analysis.voiceDescriptor.has_value()) { voiceDescriptorBytes = serializeToBytesNoCopy(analysis.voiceDescriptor.value()); - insert.bindNoCopy(3, voiceDescriptorBytes->data(), voiceDescriptorBytes->size()); + insert.bindNoCopy(4, voiceDescriptorBytes->data(), voiceDescriptorBytes->size()); } insert.exec(); diff --git a/ros/utils/recorders/src/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.cpp b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.cpp new file mode 100644 index 00000000..a91bd1a4 --- /dev/null +++ b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteHbbaStrategyStateLogger.cpp @@ -0,0 +1,38 @@ +#include + +#include + +using namespace std; + +SQLiteHbbaStrategyStateLogger::SQLiteHbbaStrategyStateLogger(SQLite::Database& database) : m_database(database) +{ + vector migrations{SQLiteMigration("BEGIN;" + "CREATE TABLE hbba_strategy_state(" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " timestamp_ms INTEGER," + " desire_type_name TEXT," + " strategy_type_name TEXT," + " enabled INTEGER" + ");" + "COMMIT;")}; + + applyMigrations(database, "hbba_strategy_state", migrations); +} + +SQLiteHbbaStrategyStateLogger::~SQLiteHbbaStrategyStateLogger() {} + +int64_t SQLiteHbbaStrategyStateLogger::log(const HbbaStrategyState& state) +{ + SQLite::Statement insert( + m_database, + "INSERT INTO hbba_strategy_state(timestamp_ms, desire_type_name, strategy_type_name, enabled)" + " VALUES(?, ?, ?, ?)"); + insert.clearBindings(); + insert.bind(1, state.timestamp.unixEpochMs); + insert.bindNoCopy(2, state.desireTypeName); + insert.bindNoCopy(3, state.strategyTypeName); + insert.bind(4, static_cast(state.enabled)); + + insert.exec(); + return m_database.getLastInsertRowid(); +} diff --git a/ros/utils/recorders/src/perception_logger/sqlite/SQLiteVideoAnalysisLogger.cpp b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteVideoAnalysisLogger.cpp index 5674467c..4bd94a06 100644 --- a/ros/utils/recorders/src/perception_logger/sqlite/SQLiteVideoAnalysisLogger.cpp +++ b/ros/utils/recorders/src/perception_logger/sqlite/SQLiteVideoAnalysisLogger.cpp @@ -7,20 +7,28 @@ using namespace std; SQLiteVideoAnalysisLogger::SQLiteVideoAnalysisLogger(SQLite::Database& database) : SQLitePerceptionLogger(database) { - vector migrations{SQLiteMigration("BEGIN;" - "CREATE TABLE video_analysis(" - " perception_id INTEGER PRIMARY KEY," - " object_class TEXT," - " bounding_box_centre_x REAL," - " bounding_box_centre_y REAL," - " bounding_box_width REAL," - " bounding_box_height REAL," - " person_pose_image BLOB," - " person_pose BLOB," - " person_pose_confidence BLOB," - " face_descriptor BLOB" - ");" - "COMMIT;")}; + vector migrations{ + SQLiteMigration("BEGIN;" + "CREATE TABLE video_analysis(" + " perception_id INTEGER PRIMARY KEY," + " object_class TEXT," + " bounding_box_centre_x REAL," + " bounding_box_centre_y REAL," + " bounding_box_width REAL," + " bounding_box_height REAL," + " person_pose_image BLOB," + " 
person_pose BLOB," + " person_pose_confidence BLOB," + " face_descriptor BLOB" + ");" + "COMMIT;"), + SQLiteMigration("BEGIN;" + "ALTER TABLE video_analysis ADD object_confidence REAL;" + "ALTER TABLE video_analysis ADD object_class_probability REAL;" + "ALTER TABLE video_analysis ADD face_alignment_keypoint_count INTEGER;" + "ALTER TABLE video_analysis ADD face_sharpness_score REAL;" + "COMMIT;"), + }; applyMigrations(database, "video_analysis", migrations); } @@ -32,45 +40,57 @@ int64_t SQLiteVideoAnalysisLogger::log(const VideoAnalysis& analysis) int64_t id = insertPerception(analysis.timestamp, analysis.position, analysis.direction); SQLite::Statement insert( m_database, - "INSERT INTO video_analysis(perception_id, object_class, bounding_box_centre_x, bounding_box_centre_y, " - "bounding_box_width, bounding_box_height, person_pose_image, person_pose, " - "person_pose_confidence, face_descriptor)" - " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); + "INSERT INTO video_analysis(perception_id, object_class, object_confidence, object_class_probability, " + "bounding_box_centre_x, bounding_box_centre_y, bounding_box_width, bounding_box_height, " + "person_pose_image, person_pose, person_pose_confidence, " + "face_descriptor, face_alignment_keypoint_count, face_sharpness_score)" + " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.clearBindings(); insert.bind(1, id); insert.bind(2, analysis.objectClass); - insert.bind(3, analysis.boundingBox.center.x); - insert.bind(4, analysis.boundingBox.center.y); - insert.bind(5, analysis.boundingBox.width); - insert.bind(6, analysis.boundingBox.height); + insert.bind(3, analysis.objectConfidence); + insert.bind(4, analysis.objectClassProbability); + insert.bind(5, analysis.boundingBox.center.x); + insert.bind(6, analysis.boundingBox.center.y); + insert.bind(7, analysis.boundingBox.width); + insert.bind(8, analysis.boundingBox.height); optional personPoseImageBytes; if (analysis.personPoseImage.has_value()) { personPoseImageBytes = serializeToBytesNoCopy(analysis.personPoseImage.value()); - insert.bindNoCopy(7, personPoseImageBytes->data(), personPoseImageBytes->size()); + insert.bindNoCopy(9, personPoseImageBytes->data(), personPoseImageBytes->size()); } optional personPoseBytes; if (analysis.personPose.has_value()) { personPoseBytes = serializeToBytesNoCopy(analysis.personPose.value()); - insert.bindNoCopy(8, personPoseBytes->data(), personPoseBytes->size()); + insert.bindNoCopy(10, personPoseBytes->data(), personPoseBytes->size()); } optional personPoseConfidenceBytes; if (analysis.personPoseConfidence.has_value()) { personPoseConfidenceBytes = serializeToBytesNoCopy(analysis.personPoseConfidence.value()); - insert.bindNoCopy(9, personPoseConfidenceBytes->data(), personPoseConfidenceBytes->size()); + insert.bindNoCopy(11, personPoseConfidenceBytes->data(), personPoseConfidenceBytes->size()); } optional faceDescriptorBytes; if (analysis.faceDescriptor.has_value()) { faceDescriptorBytes = serializeToBytesNoCopy(analysis.faceDescriptor.value()); - insert.bindNoCopy(10, faceDescriptorBytes->data(), faceDescriptorBytes->size()); + insert.bindNoCopy(12, faceDescriptorBytes->data(), faceDescriptorBytes->size()); + } + + if (analysis.faceAlignmentKeypointCount.has_value()) + { + insert.bind(13, analysis.faceAlignmentKeypointCount.value()); + } + if (analysis.faceSharpnessScore.has_value()) + { + insert.bind(14, analysis.faceSharpnessScore.value()); } insert.exec(); diff --git a/ros/utils/recorders/src/perception_logger_node.cpp 
b/ros/utils/recorders/src/perception_logger_node.cpp index 868664ef..8688b363 100644 --- a/ros/utils/recorders/src/perception_logger_node.cpp +++ b/ros/utils/recorders/src/perception_logger_node.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -9,6 +10,7 @@ #include #include #include +#include #include #include @@ -48,6 +50,7 @@ class PerceptionLoggerNode unique_ptr m_videoAnalysisLogger; unique_ptr m_audioAnalysisLogger; unique_ptr m_speechLogger; + unique_ptr m_hbbaStrategyStateLogger; tf::TransformListener m_listener; @@ -55,6 +58,7 @@ class PerceptionLoggerNode ros::Subscriber m_audioAnalysisSubscriber; ros::Subscriber m_talkTextSubscriber; ros::Subscriber m_speechToTextTranscriptSubscriber; + ros::Subscriber m_hbbaStrategyStateSubscriber; public: PerceptionLoggerNode(ros::NodeHandle& nodeHandle, PerceptionLoggerNodeConfiguration configuration) @@ -65,6 +69,7 @@ class PerceptionLoggerNode m_videoAnalysisLogger = make_unique(m_database); m_audioAnalysisLogger = make_unique(m_database); m_speechLogger = make_unique(m_database); + m_hbbaStrategyStateLogger = make_unique(m_database); m_videoAnalysis3dSubscriber = m_nodeHandle.subscribe("video_analysis", 10, &PerceptionLoggerNode::videoAnalysisSubscriberCallback, this); @@ -77,6 +82,11 @@ class PerceptionLoggerNode 10, &PerceptionLoggerNode::speechToTextTranscriptSubscriberCallback, this); + m_hbbaStrategyStateSubscriber = m_nodeHandle.subscribe( + "hbba_strategy_state_log", + 10, + &PerceptionLoggerNode::hbbaStrategyStateSubscriberCallback, + this); } virtual ~PerceptionLoggerNode() {} @@ -135,6 +145,12 @@ class PerceptionLoggerNode } } + void hbbaStrategyStateSubscriberCallback(const hbba_lite::StrategyState::ConstPtr& msg) + { + m_hbbaStrategyStateLogger->log( + HbbaStrategyState(ros::Time::now(), msg->desire_type_name, msg->strategy_type_name, msg->enabled)); + } + void run() { ros::spin(); } private: @@ -171,6 +187,8 @@ class PerceptionLoggerNode position, positionToDirection(position), msg.object_class, + msg.object_confidence, + msg.object_class_probability, centreWidthHeightToBoundingBox(msg.center_2d, msg.width_2d, msg.height_2d)}; if (!msg.person_pose_2d.empty()) @@ -199,6 +217,8 @@ class PerceptionLoggerNode } if (!msg.face_descriptor.empty()) { + videoAnalysis.faceAlignmentKeypointCount = msg.face_alignment_keypoint_count; + videoAnalysis.faceSharpnessScore = msg.face_sharpness_score; videoAnalysis.faceDescriptor = msg.face_descriptor; } @@ -210,6 +230,7 @@ class PerceptionLoggerNode AudioAnalysis audioAnalysis( msg->header.stamp, Direction{msg->direction_x, msg->direction_y, msg->direction_z}, + msg->tracking_id, mergeClasses(msg->audio_classes)); if (!msg->voice_descriptor.empty()) diff --git a/ros/utils/recorders/test/perception_logger/AudioAnalysisLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/AudioAnalysisLoggerTests.cpp index 0090ab19..ecf12fba 100644 --- a/ros/utils/recorders/test/perception_logger/AudioAnalysisLoggerTests.cpp +++ b/ros/utils/recorders/test/perception_logger/AudioAnalysisLoggerTests.cpp @@ -6,19 +6,21 @@ using namespace std; TEST(AudioAnalysisLoggerTests, audioAnalysis_constructor_shouldSetAttributes) { - AudioAnalysis analysis0(Timestamp(1), Direction{2.0, 3.0, 4.0}, "a"); + AudioAnalysis analysis0(Timestamp(1), Direction{2.0, 3.0, 4.0}, 5, "a"); EXPECT_EQ(analysis0.timestamp.unixEpochMs, 1); EXPECT_EQ(analysis0.direction.x, 2.0); EXPECT_EQ(analysis0.direction.y, 3.0); EXPECT_EQ(analysis0.direction.z, 4.0); + EXPECT_EQ(analysis0.trackingId, 5); 
EXPECT_EQ(analysis0.classes, "a"); EXPECT_EQ(analysis0.voiceDescriptor, std::nullopt); - AudioAnalysis analysis1(Timestamp(5), Direction{6.0, 7.0, 8.0}, "b", {9.f}); + AudioAnalysis analysis1(Timestamp(5), Direction{6.0, 7.0, 8.0}, 9, "b", {10.f}); EXPECT_EQ(analysis1.timestamp.unixEpochMs, 5.0); EXPECT_EQ(analysis1.direction.x, 6.0); EXPECT_EQ(analysis1.direction.y, 7.0); EXPECT_EQ(analysis1.direction.z, 8.0); + EXPECT_EQ(analysis1.trackingId, 9); EXPECT_EQ(analysis1.classes, "b"); - EXPECT_EQ(analysis1.voiceDescriptor, vector({9.f})); + EXPECT_EQ(analysis1.voiceDescriptor, vector({10.f})); } diff --git a/ros/utils/recorders/test/perception_logger/HbbaStrategyStateLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/HbbaStrategyStateLoggerTests.cpp new file mode 100644 index 00000000..a021c79d --- /dev/null +++ b/ros/utils/recorders/test/perception_logger/HbbaStrategyStateLoggerTests.cpp @@ -0,0 +1,14 @@ +#include + +#include + +using namespace std; + +TEST(HbbaStrategyStateLoggerTests, hbbaStrategyState_constructor_shouldSetAttributes) +{ + HbbaStrategyState state(Timestamp(1), "d", "s", true); + EXPECT_EQ(state.timestamp.unixEpochMs, 1); + EXPECT_EQ(state.desireTypeName, "d"); + EXPECT_EQ(state.strategyTypeName, "s"); + EXPECT_TRUE(state.enabled); +} diff --git a/ros/utils/recorders/test/perception_logger/VideoAnalysisLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/VideoAnalysisLoggerTests.cpp index b0c92e4d..9101894f 100644 --- a/ros/utils/recorders/test/perception_logger/VideoAnalysisLoggerTests.cpp +++ b/ros/utils/recorders/test/perception_logger/VideoAnalysisLoggerTests.cpp @@ -13,6 +13,8 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) Position{2.0, 3.0, 4.0}, Direction{5.0, 6.0, 7.0}, "a", + 0.75f, + 0.5f, BoundingBox{{8.0, 9.0}, 10.0, 11.0}}; EXPECT_EQ(analysis0.timestamp.unixEpochMs, 1); EXPECT_EQ(analysis0.position.x, 2.0); @@ -22,20 +24,26 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) EXPECT_EQ(analysis0.direction.y, 6.0); EXPECT_EQ(analysis0.direction.z, 7.0); EXPECT_EQ(analysis0.objectClass, "a"); + EXPECT_EQ(analysis0.objectConfidence, 0.75); + EXPECT_EQ(analysis0.objectClassProbability, 0.5); EXPECT_EQ(analysis0.boundingBox.center.x, 8.0); EXPECT_EQ(analysis0.boundingBox.center.y, 9.0); EXPECT_EQ(analysis0.boundingBox.width, 10.0); EXPECT_EQ(analysis0.boundingBox.height, 11.0); - EXPECT_EQ(analysis0.personPoseImage, std::nullopt); - EXPECT_EQ(analysis0.personPose, std::nullopt); - EXPECT_EQ(analysis0.personPoseConfidence, std::nullopt); - EXPECT_EQ(analysis0.faceDescriptor, std::nullopt); + EXPECT_EQ(analysis0.personPoseImage, nullopt); + EXPECT_EQ(analysis0.personPose, nullopt); + EXPECT_EQ(analysis0.personPoseConfidence, nullopt); + EXPECT_EQ(analysis0.faceDescriptor, nullopt); + EXPECT_EQ(analysis0.faceAlignmentKeypointCount, nullopt); + EXPECT_EQ(analysis0.faceSharpnessScore, nullopt); VideoAnalysis analysis1{ Timestamp(10), Position{20.0, 30.0, 40.0}, Direction{50.0, 60.0, 70.0}, "b", + 0.5f, + 0.25f, BoundingBox{{80.0, 90.0}, 100.0, 110.0}, {ImagePosition{120.0, 130.0}}, {Position{140.0, 150.0, 160.0}}, @@ -48,6 +56,8 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) EXPECT_EQ(analysis1.direction.y, 60.0); EXPECT_EQ(analysis1.direction.z, 70.0); EXPECT_EQ(analysis1.objectClass, "b"); + EXPECT_EQ(analysis1.objectConfidence, 0.5f); + EXPECT_EQ(analysis1.objectClassProbability, 0.25f); EXPECT_EQ(analysis1.boundingBox.center.x, 80.0); 
EXPECT_EQ(analysis1.boundingBox.center.y, 90.0); EXPECT_EQ(analysis1.boundingBox.width, 100.0); @@ -55,18 +65,24 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) EXPECT_EQ(analysis1.personPoseImage, vector({ImagePosition{120.0, 130.0}})); EXPECT_EQ(analysis1.personPose, vector({Position{140.0, 150.0, 160.0}})); EXPECT_EQ(analysis1.personPoseConfidence, vector({0.5})); - EXPECT_EQ(analysis1.faceDescriptor, std::nullopt); + EXPECT_EQ(analysis1.faceDescriptor, nullopt); + EXPECT_EQ(analysis1.faceAlignmentKeypointCount, nullopt); + EXPECT_EQ(analysis1.faceSharpnessScore, nullopt); VideoAnalysis analysis2{ Timestamp(100), Position{200.0, 300.0, 400.0}, Direction{500.0, 600.0, 700.0}, "c", + 0.25f, + 0.125f, BoundingBox{{800.0, 900.0}, 1000.0, 1100.0}, {ImagePosition{1200.0, 1300.0}}, {Position{1400.0, 1500.0, 1600.0}}, {0.5f}, - {200.f}}; + {200.f}, + 5, + 0.625f}; EXPECT_EQ(analysis2.timestamp.unixEpochMs, 100); EXPECT_EQ(analysis2.position.x, 200.0); EXPECT_EQ(analysis2.position.y, 300.0); @@ -75,6 +91,8 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) EXPECT_EQ(analysis2.direction.y, 600.0); EXPECT_EQ(analysis2.direction.z, 700.0); EXPECT_EQ(analysis2.objectClass, "c"); + EXPECT_EQ(analysis2.objectConfidence, 0.25); + EXPECT_EQ(analysis2.objectClassProbability, 0.125); EXPECT_EQ(analysis2.boundingBox.center.x, 800.0); EXPECT_EQ(analysis2.boundingBox.center.y, 900.0); EXPECT_EQ(analysis2.boundingBox.width, 1000.0); @@ -83,4 +101,6 @@ TEST(VideoAnalysisLoggerTests, videoAnalysis_constructor_shouldSetAttributes) EXPECT_EQ(analysis2.personPose, vector({Position{1400.0, 1500.0, 1600.0}})); EXPECT_EQ(analysis2.personPoseConfidence, vector({0.5f})); EXPECT_EQ(analysis2.faceDescriptor, vector({200.f})); + EXPECT_EQ(analysis2.faceAlignmentKeypointCount, 5); + EXPECT_EQ(analysis2.faceSharpnessScore, 0.625); } diff --git a/ros/utils/recorders/test/perception_logger/sqlite/SQLiteAudioAnalysisLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteAudioAnalysisLoggerTests.cpp index c8bc3fac..6507a5d1 100644 --- a/ros/utils/recorders/test/perception_logger/sqlite/SQLiteAudioAnalysisLoggerTests.cpp +++ b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteAudioAnalysisLoggerTests.cpp @@ -6,9 +6,16 @@ using namespace std; -void readAudioAnalysis(SQLite::Database& database, int64_t id, string& classes, vector& voiceDescriptor) +void readAudioAnalysis( + SQLite::Database& database, + int64_t id, + int64_t& trackingId, + string& classes, + vector& voiceDescriptor) { - SQLite::Statement query(database, "SELECT classes, voice_descriptor FROM audio_analysis WHERE perception_id=?"); + SQLite::Statement query( + database, + "SELECT tracking_id, classes, voice_descriptor FROM audio_analysis WHERE perception_id=?"); query.bind(1, id); if (!query.executeStep()) @@ -18,8 +25,9 @@ void readAudioAnalysis(SQLite::Database& database, int64_t id, string& classes, return; } - classes = query.getColumn(0).getString(); - columnToVector(query.getColumn(1), voiceDescriptor); + trackingId = query.getColumn(0).getInt64(); + classes = query.getColumn(1).getString(); + columnToVector(query.getColumn(2), voiceDescriptor); } TEST(SQLiteAudioAnalysisLoggerTests, log_shouldInsertAndReturnId) @@ -28,20 +36,23 @@ TEST(SQLiteAudioAnalysisLoggerTests, log_shouldInsertAndReturnId) SQLiteAudioAnalysisLogger logger(database); - int64_t id0 = logger.log(AudioAnalysis(Timestamp(101), Direction{1, 2, 3}, "music,water")); - int64_t id1 = 
logger.log(AudioAnalysis(Timestamp(102), Direction{4, 5, 6}, "voice", {7.f, 8.f})); + int64_t id0 = logger.log(AudioAnalysis(Timestamp(101), Direction{1, 2, 3}, 4, "music,water")); + int64_t id1 = logger.log(AudioAnalysis(Timestamp(102), Direction{4, 5, 6}, 7, "voice", {8.f, 9.f})); EXPECT_TRUE(perceptionExists(database, id0, Timestamp(101), Direction{1, 2, 3})); EXPECT_TRUE(perceptionExists(database, id1, Timestamp(102), Direction{4, 5, 6})); + int64_t trackingId; string classes; vector voiceDescriptor; - readAudioAnalysis(database, id0, classes, voiceDescriptor); + readAudioAnalysis(database, id0, trackingId, classes, voiceDescriptor); + EXPECT_EQ(trackingId, 4); EXPECT_EQ(classes, "music,water"); EXPECT_EQ(voiceDescriptor, vector({})); - readAudioAnalysis(database, id1, classes, voiceDescriptor); + readAudioAnalysis(database, id1, trackingId, classes, voiceDescriptor); + EXPECT_EQ(trackingId, 7); EXPECT_EQ(classes, "voice"); - EXPECT_EQ(voiceDescriptor, vector({7.f, 8.f})); + EXPECT_EQ(voiceDescriptor, vector({8.f, 9.f})); } diff --git a/ros/utils/recorders/test/perception_logger/sqlite/SQLiteHbbaStrategyStateLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteHbbaStrategyStateLoggerTests.cpp new file mode 100644 index 00000000..5f48908f --- /dev/null +++ b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteHbbaStrategyStateLoggerTests.cpp @@ -0,0 +1,61 @@ +#include + +#include + +using namespace std; + +void readHbbaStrategyState( + SQLite::Database& database, + int64_t id, + int64_t& timestampMs, + string& desireTypeName, + string& strategyTypeName, + bool& enabled) +{ + SQLite::Statement query( + database, + "SELECT timestamp_ms, desire_type_name, strategy_type_name, enabled" + " FROM hbba_strategy_state WHERE id=?"); + + query.bind(1, id); + if (!query.executeStep()) + { + timestampMs = -1; + desireTypeName = ""; + strategyTypeName = ""; + enabled = false; + return; + } + + timestampMs = query.getColumn(0).getInt64(); + desireTypeName = query.getColumn(1).getString(); + strategyTypeName = query.getColumn(2).getString(); + enabled = static_cast(query.getColumn(3).getInt()); +} + +TEST(SQLiteHbbaStrategyStateLoggerTests, log_shouldInsertAndReturnId) +{ + SQLite::Database database(":memory:", SQLite::OPEN_READWRITE); + + SQLiteHbbaStrategyStateLogger logger(database); + + int64_t id0 = logger.log(HbbaStrategyState(Timestamp(101), "d1", "s1", true)); + int64_t id1 = logger.log(HbbaStrategyState(Timestamp(102), "d2", "s2", false)); + + int64_t timestampMs; + string desireTypeName; + string strategyTypeName; + bool enabled; + + readHbbaStrategyState(database, id0, timestampMs, desireTypeName, strategyTypeName, enabled); + EXPECT_EQ(timestampMs, 101); + EXPECT_EQ(desireTypeName, "d1"); + EXPECT_EQ(strategyTypeName, "s1"); + EXPECT_TRUE(enabled); + + readHbbaStrategyState(database, id1, timestampMs, desireTypeName, strategyTypeName, enabled); + EXPECT_EQ(timestampMs, 102); + EXPECT_EQ(desireTypeName, "d2"); + EXPECT_EQ(strategyTypeName, "s2"); + EXPECT_FALSE(enabled); +} diff --git a/ros/utils/recorders/test/perception_logger/sqlite/SQLiteVideoAnalysisLoggerTests.cpp b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteVideoAnalysisLoggerTests.cpp index 59cd7ad0..b28273e2 100644 --- a/ros/utils/recorders/test/perception_logger/sqlite/SQLiteVideoAnalysisLoggerTests.cpp +++ b/ros/utils/recorders/test/perception_logger/sqlite/SQLiteVideoAnalysisLoggerTests.cpp @@ -12,16 +12,22 @@ void readVideoAnalysis( SQLite::Database& database, int64_t id, string& 
objectClass, + float& objectConfidence, + float& objectClassProbability, BoundingBox& boundingBox, vector& personPoseImage, vector& personPose, vector& personPoseConfidence, - vector& faceDescriptor) + vector& faceDescriptor, + optional& faceAlignmentKeypointCount, + optional& faceSharpnessScore) { SQLite::Statement query( database, - "SELECT object_class, bounding_box_centre_x, bounding_box_centre_y, bounding_box_width, bounding_box_height, " - "person_pose_image, person_pose, person_pose_confidence, face_descriptor FROM video_analysis WHERE " + "SELECT object_class, object_confidence, object_class_probability, " + "bounding_box_centre_x, bounding_box_centre_y, bounding_box_width, bounding_box_height, " + "person_pose_image, person_pose, person_pose_confidence, " + "face_descriptor, face_alignment_keypoint_count, face_sharpness_score FROM video_analysis WHERE " "perception_id=?"); query.bind(1, id); @@ -29,22 +35,37 @@ void readVideoAnalysis( { objectClass = ""; boundingBox = {}; + objectConfidence = -1.f; + objectClassProbability = -1.f; personPoseImage.clear(); personPose.clear(); personPoseConfidence.clear(); faceDescriptor.clear(); + faceAlignmentKeypointCount = nullopt; + faceSharpnessScore = nullopt; return; } objectClass = query.getColumn(0).getString(); - boundingBox.center.x = query.getColumn(1).getDouble(); - boundingBox.center.y = query.getColumn(2).getDouble(); - boundingBox.width = query.getColumn(3).getDouble(); - boundingBox.height = query.getColumn(4).getDouble(); - columnToVector(query.getColumn(5), personPoseImage); - columnToVector(query.getColumn(6), personPose); - columnToVector(query.getColumn(7), personPoseConfidence); - columnToVector(query.getColumn(8), faceDescriptor); + objectConfidence = query.getColumn(1).getDouble(); + objectClassProbability = query.getColumn(2).getDouble(); + boundingBox.center.x = query.getColumn(3).getDouble(); + boundingBox.center.y = query.getColumn(4).getDouble(); + boundingBox.width = query.getColumn(5).getDouble(); + boundingBox.height = query.getColumn(6).getDouble(); + columnToVector(query.getColumn(7), personPoseImage); + columnToVector(query.getColumn(8), personPose); + columnToVector(query.getColumn(9), personPoseConfidence); + columnToVector(query.getColumn(10), faceDescriptor); + + if (!query.getColumn(11).isNull()) + { + faceAlignmentKeypointCount = query.getColumn(11).getInt(); + } + if (!query.getColumn(12).isNull()) + { + faceSharpnessScore = query.getColumn(12).getDouble(); + } } TEST(SQLiteVideoAnalysisLoggerTests, log_shouldInsertAndReturnId) @@ -53,84 +74,124 @@ TEST(SQLiteVideoAnalysisLoggerTests, log_shouldInsertAndReturnId) SQLiteVideoAnalysisLogger logger(database); - int64_t id0 = logger.log( - VideoAnalysis(Timestamp(101), Position{1, 2, 3}, Direction{4, 5, 6}, "banana", BoundingBox{{7, 8}, 9, 10})); + int64_t id0 = logger.log(VideoAnalysis( + Timestamp(101), + Position{1, 2, 3}, + Direction{4, 5, 6}, + "banana", + 1.f, + 0.5f, + BoundingBox{{7, 8}, 9, 10})); int64_t id1 = logger.log(VideoAnalysis( Timestamp(102), Position{11, 12, 13}, Direction{14, 15, 16}, "person", + 0.75f, + 0.25f, BoundingBox{{17, 18}, 19, 20}, {{21, 22}}, {{23, 24, 25}}, - {0.5})); + {0.5f})); int64_t id2 = logger.log(VideoAnalysis( Timestamp(103), Position{26, 27, 28}, Direction{29, 30, 31}, "person", + 0.5f, + 0.75f, BoundingBox{{32, 33}, 34, 35}, {{36, 37}}, {{38, 39, 40}}, - {0.75}, - {41, 42})); + {0.75f}, + {41, 42}, + 5, + 0.125f)); EXPECT_TRUE(perceptionExists(database, id0, Timestamp(101), Position{1, 2, 3}, Direction{4, 5, 6})); 
EXPECT_TRUE(perceptionExists(database, id1, Timestamp(102), Position{11, 12, 13}, Direction{14, 15, 16})); EXPECT_TRUE(perceptionExists(database, id2, Timestamp(103), Position{26, 27, 28}, Direction{29, 30, 31})); string objectClass; + float objectConfidence; + float objectClassProbability; BoundingBox boundingBox; vector personPoseImage; vector personPose; vector personPoseConfidence; vector faceDescriptor; + optional faceAlignmentKeypointCount; + optional faceSharpnessScore; readVideoAnalysis( database, id0, objectClass, + objectConfidence, + objectClassProbability, boundingBox, personPoseImage, personPose, personPoseConfidence, - faceDescriptor); + faceDescriptor, + faceAlignmentKeypointCount, + faceSharpnessScore); EXPECT_EQ(objectClass, "banana"); + EXPECT_EQ(objectConfidence, 1.f); + EXPECT_EQ(objectClassProbability, 0.5f); EXPECT_EQ(boundingBox, (BoundingBox{{7, 8}, 9, 10})); EXPECT_EQ(personPoseImage, vector({})); EXPECT_EQ(personPose, vector({})); EXPECT_EQ(personPoseConfidence, vector({})); EXPECT_EQ(faceDescriptor, vector({})); + EXPECT_EQ(faceAlignmentKeypointCount, nullopt); + EXPECT_EQ(faceSharpnessScore, nullopt); readVideoAnalysis( database, id1, objectClass, + objectConfidence, + objectClassProbability, boundingBox, personPoseImage, personPose, personPoseConfidence, - faceDescriptor); + faceDescriptor, + faceAlignmentKeypointCount, + faceSharpnessScore); EXPECT_EQ(objectClass, "person"); + EXPECT_EQ(objectConfidence, 0.75f); + EXPECT_EQ(objectClassProbability, 0.25f); EXPECT_EQ(boundingBox, (BoundingBox{{17, 18}, 19, 20})); ASSERT_EQ(personPoseImage, vector({ImagePosition{21, 22}})); ASSERT_EQ(personPose, vector({Position{23, 24, 25}})); EXPECT_EQ(personPoseConfidence, vector({0.5f})); EXPECT_EQ(faceDescriptor, vector({})); + EXPECT_EQ(faceAlignmentKeypointCount, nullopt); + EXPECT_EQ(faceSharpnessScore, nullopt); readVideoAnalysis( database, id2, objectClass, + objectConfidence, + objectClassProbability, boundingBox, personPoseImage, personPose, personPoseConfidence, - faceDescriptor); + faceDescriptor, + faceAlignmentKeypointCount, + faceSharpnessScore); EXPECT_EQ(objectClass, "person"); + EXPECT_EQ(objectConfidence, 0.5f); + EXPECT_EQ(objectClassProbability, 0.75f); EXPECT_EQ(boundingBox, (BoundingBox{{32, 33}, 34, 35})); ASSERT_EQ(personPoseImage, vector({ImagePosition{36, 37}})); ASSERT_EQ(personPose, vector({Position{38, 39, 40}})); EXPECT_EQ(personPoseConfidence, vector({0.75f})); EXPECT_EQ(faceDescriptor, vector({41.f, 42.f})); + EXPECT_EQ(faceAlignmentKeypointCount, 5); + EXPECT_EQ(faceSharpnessScore, 0.125f); } diff --git a/ros/utils/t_top_hbba_lite/include/t_top_hbba_lite/Strategies.h b/ros/utils/t_top_hbba_lite/include/t_top_hbba_lite/Strategies.h index c794b5bb..8c210b36 100644 --- a/ros/utils/t_top_hbba_lite/include/t_top_hbba_lite/Strategies.h +++ b/ros/utils/t_top_hbba_lite/include/t_top_hbba_lite/Strategies.h @@ -24,6 +24,8 @@ class FaceAnimationStrategy : public Strategy DECLARE_NOT_COPYABLE(FaceAnimationStrategy); DECLARE_NOT_MOVABLE(FaceAnimationStrategy); + StrategyType strategyType() override; + protected: void onEnabling(const FaceAnimationDesire& desire) override; void onDisabling() override; @@ -40,6 +42,8 @@ class LedEmotionStrategy : public Strategy DECLARE_NOT_COPYABLE(LedEmotionStrategy); DECLARE_NOT_MOVABLE(LedEmotionStrategy); + StrategyType strategyType() override; + protected: void onEnabling(const LedEmotionDesire& desire) override; }; @@ -58,6 +62,8 @@ class SpecificFaceFollowingStrategy : public Strategy 
DECLARE_NOT_COPYABLE(TalkStrategy); DECLARE_NOT_MOVABLE(TalkStrategy); + StrategyType strategyType() override; + protected: void onEnabling(const TalkDesire& desire) override; @@ -103,6 +111,8 @@ class GestureStrategy : public Strategy DECLARE_NOT_COPYABLE(GestureStrategy); DECLARE_NOT_MOVABLE(GestureStrategy); + StrategyType strategyType() override; + protected: void onEnabling(const GestureDesire& desire) override; @@ -127,6 +137,8 @@ class PlaySoundStrategy : public Strategy DECLARE_NOT_COPYABLE(PlaySoundStrategy); DECLARE_NOT_MOVABLE(PlaySoundStrategy); + StrategyType strategyType() override; + protected: void onEnabling(const PlaySoundDesire& desire) override; diff --git a/ros/utils/t_top_hbba_lite/src/Strategies.cpp b/ros/utils/t_top_hbba_lite/src/Strategies.cpp index 41ad9c6f..757b85e9 100644 --- a/ros/utils/t_top_hbba_lite/src/Strategies.cpp +++ b/ros/utils/t_top_hbba_lite/src/Strategies.cpp @@ -17,6 +17,11 @@ FaceAnimationStrategy::FaceAnimationStrategy( m_animationPublisher = nodeHandle.advertise("face/animation", 1); } +StrategyType FaceAnimationStrategy::strategyType() +{ + return StrategyType::get(); +} + void FaceAnimationStrategy::onEnabling(const FaceAnimationDesire& desire) { Strategy::onEnabling(desire); @@ -46,6 +51,11 @@ LedEmotionStrategy::LedEmotionStrategy(uint16_t utility, shared_ptr m_emotionPublisher = nodeHandle.advertise("led_emotions/name", 1); } +StrategyType LedEmotionStrategy::strategyType() +{ + return StrategyType::get(); +} + void LedEmotionStrategy::onEnabling(const LedEmotionDesire& desire) { Strategy::onEnabling(desire); @@ -70,6 +80,11 @@ SpecificFaceFollowingStrategy::SpecificFaceFollowingStrategy( m_targetNamePublisher = nodeHandle.advertise("face_following/target_name", 1); } +StrategyType SpecificFaceFollowingStrategy::strategyType() +{ + return StrategyType::get(); +} + void SpecificFaceFollowingStrategy::onEnabling(const SpecificFaceFollowingDesire& desire) { Strategy::onEnabling(desire); @@ -96,6 +111,11 @@ TalkStrategy::TalkStrategy( m_talkDoneSubscriber = nodeHandle.subscribe("talk/done", 10, &TalkStrategy::talkDoneSubscriberCallback, this); } +StrategyType TalkStrategy::strategyType() +{ + return StrategyType::get(); +} + void TalkStrategy::onEnabling(const TalkDesire& desire) { Strategy::onEnabling(desire); @@ -132,6 +152,11 @@ GestureStrategy::GestureStrategy( nodeHandle.subscribe("gesture/done", 1, &GestureStrategy::gestureDoneSubscriberCallback, this); } +StrategyType GestureStrategy::strategyType() +{ + return StrategyType::get(); +} + void GestureStrategy::onEnabling(const GestureDesire& desire) { Strategy::onEnabling(desire); @@ -168,6 +193,11 @@ PlaySoundStrategy::PlaySoundStrategy( nodeHandle.subscribe("sound_player/done", 1, &PlaySoundStrategy::soundDoneSubscriberCallback, this); } +StrategyType PlaySoundStrategy::strategyType() +{ + return StrategyType::get(); +} + void PlaySoundStrategy::onEnabling(const PlaySoundDesire& desire) { Strategy::onEnabling(desire); diff --git a/tools/dnn_training/README.md b/tools/dnn_training/README.md index f04f3bff..681393a8 100644 --- a/tools/dnn_training/README.md +++ b/tools/dnn_training/README.md @@ -2,16 +2,21 @@ This folder contains the tools to train the neural networks used by T-Top. +## Notes +The `descriptor_yolo_v7` network is trained with [https://github.com/mamaheux/descriptor-yolov7](https://github.com/mamaheux/descriptor-yolov7). 
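This changeset also adds a `Yolo` inference wrapper to `dnn_utils` (see the new `yolo.py` earlier in this diff). The sketch below shows how that wrapper might be used; it is a minimal example, assuming the package is importable as `dnn_utils.yolo`, the exported `yolo_v7_tiny_coco` model files exist under `dnn_utils/models`, a runtime matching the default `inference_type` is available, and the input is an RGB tensor in CHW layout with values in [0, 1].

```python
# Hypothetical usage sketch of the new dnn_utils Yolo wrapper (not part of the diff).
import torch

from dnn_utils.yolo import Yolo  # assumed import path, mirroring dnn_utils.dnn_model

yolo = Yolo('yolo_v7_tiny_coco', confidence_threshold=0.5, nms_threshold=0.45)
class_names = yolo.get_class_names()

image = torch.rand(3, 480, 640)  # placeholder for a real camera frame
predictions = yolo(image)  # the wrapper letterboxes the image to the model input size internally

for p in predictions:
    # YoloPrediction maps box coordinates back to the original image scale.
    print(class_names[p.class_index], round(p.confidence, 2),
          round(p.center_x), round(p.center_y), round(p.width), round(p.height))
```

The confidence and NMS thresholds above are illustrative; the wrapper's defaults (0.99 and 0.5) are stricter and tuned for the robot's perception pipeline.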
+ + ## Folder Structure - Training scripts - The [train_audio_descriptor_extractor.py](train_audio_descriptor_extractor.py) script trains a neural network that classifies a sound to one class and extracts an embedding. The network name is `audio_descriptor_extractor`. - The [train_backbone.py](train_backbone.py) script trains a neural network that classifies images. - - The [train_backbone_distillation.py](train_backbone_distillation.py) script trains a student neural network that + - The [train_backbone_distillation.py](train_backbone_distillation.py) script trains a student neural network that classifies images from a teacher one. - The [train_descriptor_yolo_v4.py](train_descriptor_yolo_v4.py) script trains a neural network that detects objects, classifies them and extracts embeddings. The network name is `descriptor_yolo_v4`. + The training is not working properly. - The [train_face_descriptor_extractor.py](train_face_descriptor_extractor.py) script trains a neural network that extracts an embedding for a face. The network name is `face_descriptor_extractor`. - The [train_keyword_spotter.py](train_keyword_spotter.py) script trains a neural network that detects a wake-up @@ -21,12 +26,12 @@ This folder contains the tools to train the neural networks used by T-Top. is `audio_descriptor_extractor`. - The [train_pose_estimator.py](train_pose_estimator.py) script trains a neural network that estimates the pose of a person. The network name is `pose_estimator`. - - The [train_semantic_segmentation_network.py](train_semantic_segmentation_network.py) script trains a neural + - The [train_semantic_segmentation_network.py](train_semantic_segmentation_network.py) script trains a neural network that performs semantic segmentation. The network name is `semantic_segmentation_network`. - Export scripts - The [export_audio_descriptor_extractor.py](export_audio_descriptor_extractor.py) script exports the `audio_descriptor_extractor` network to a TorchScript file and a TensorRT file. - - The [export_descriptor_yolo_v4.py](export_descriptor_yolo_v4.py) script exports the `descriptor_yolo_v4` network + - The [export_descriptor_yolo.py](export_descriptor_yolo.py) script exports the `descriptor_yolo` network to a TorchScript file and a TensorRT file. - The [export_face_descriptor_extractor.py](export_face_descriptor_extractor.py) script exports the `face_descriptor_extractor` network to a TorchScript file and a TensorRT file. @@ -36,7 +41,7 @@ This folder contains the tools to train the neural networks used by T-Top. TorchScript file and a TensorRT file. - The [export_yolo_v4.py](export_yolo_v4.py) script exports the `yolo_v4` network to a TorchScript file and a TensorRT file. - - The [export_semantic_segmentation_network.py](export_semantic_segmentation_network.py) script exports the + - The [export_semantic_segmentation_network.py](export_semantic_segmentation_network.py) script exports the `semantic_segmentation_network` network to a TorchScript file and a TensorRT file. - Test scripts - The [test_exported_audio_descriptor_extractor.py](test_exported_audio_descriptor_extractor.py) script tests the @@ -53,7 +58,7 @@ This folder contains the tools to train the neural networks used by T-Top. the `audio_descriptor_extractor` network and the `face_descriptor_extractor` network. - The [test_pose_estimator_with_yolo_v4.py](test_pose_estimator_with_yolo_v4.py) script tests the `pose_estimator` network with the `yolo_v4` network on the COCO dataset. 
- - The [test_exported_semantic_segmentation_network.py](test_exported_semantic_segmentation_network.py) script tests + - The [test_exported_semantic_segmentation_network.py](test_exported_semantic_segmentation_network.py) script tests the exported `semantic_segmentation_network` network. ## Setup diff --git a/tools/dnn_training/audio_descriptor/audio_descriptor_extractor.py b/tools/dnn_training/audio_descriptor/audio_descriptor_extractor.py index 07732729..16a8bcc0 100644 --- a/tools/dnn_training/audio_descriptor/audio_descriptor_extractor.py +++ b/tools/dnn_training/audio_descriptor/audio_descriptor_extractor.py @@ -1,11 +1,11 @@ import torch.nn as nn -from common.modules import L2Normalization, GlobalAvgPool2d, GlobalHeightAvgPool2d, AmSoftmaxLinear, NetVLAD -from audio_descriptor.modules import SAP +from common.modules import L2Normalization, GlobalAvgPool2d, GlobalHeightAvgPool2d, NormalizedLinear, NetVLAD +from audio_descriptor.modules import SAP, PSLAAttention class AudioDescriptorExtractor(nn.Module): - def __init__(self, backbone, embedding_size=128, class_count=None, am_softmax_linear=False): + def __init__(self, backbone, embedding_size=128, class_count=None, normalized_linear=False): super(AudioDescriptorExtractor, self).__init__() self._backbone = backbone @@ -16,7 +16,7 @@ def __init__(self, backbone, embedding_size=128, class_count=None, am_softmax_li L2Normalization() ) - self._classifier = _create_classifier(embedding_size, class_count, am_softmax_linear) + self._classifier = _create_classifier(embedding_size, class_count, normalized_linear) self._class_count = class_count def class_count(self): @@ -35,7 +35,7 @@ def forward(self, x): class AudioDescriptorExtractorVLAD(nn.Module): - def __init__(self, backbone, embedding_size, class_count=None, am_softmax_linear=False): + def __init__(self, backbone, embedding_size, class_count=None, normalized_linear=False): super(AudioDescriptorExtractorVLAD, self).__init__() self._backbone = backbone @@ -49,7 +49,7 @@ def __init__(self, backbone, embedding_size, class_count=None, am_softmax_linear L2Normalization() ) - self._classifier = _create_classifier(embedding_size, class_count, am_softmax_linear) + self._classifier = _create_classifier(embedding_size, class_count, normalized_linear) self._class_count = class_count def class_count(self): @@ -67,7 +67,7 @@ def forward(self, x): class AudioDescriptorExtractorSAP(nn.Module): - def __init__(self, backbone, embedding_size=128, class_count=None, am_softmax_linear=False): + def __init__(self, backbone, embedding_size=128, class_count=None, normalized_linear=False): super(AudioDescriptorExtractorSAP, self).__init__() self._backbone = backbone @@ -79,7 +79,7 @@ def __init__(self, backbone, embedding_size=128, class_count=None, am_softmax_li L2Normalization() ) - self._classifier = _create_classifier(embedding_size, class_count, am_softmax_linear) + self._classifier = _create_classifier(embedding_size, class_count, normalized_linear) self._class_count = class_count def class_count(self): @@ -98,10 +98,42 @@ def forward(self, x): return descriptor -def _create_classifier(embedding_size, class_count, am_softmax_linear): +class AudioDescriptorExtractorPSLA(nn.Module): + def __init__(self, backbone, embedding_size=128, class_count=None, normalized_linear=False): + super(AudioDescriptorExtractorPSLA, self).__init__() + + self._backbone = backbone + self._frequency_pooling = GlobalHeightAvgPool2d() + self._psla_attention = PSLAAttention(backbone.last_channel_count(), 
backbone.last_channel_count()) + + self._descriptor_layers = nn.Sequential( + nn.Linear(backbone.last_channel_count(), embedding_size), + L2Normalization() + ) + + self._classifier = _create_classifier(embedding_size, class_count, normalized_linear) + self._class_count = class_count + + def class_count(self): + return self._class_count + + def forward(self, x): + features = self._backbone(x) + features = self._frequency_pooling(features) + features = self._psla_attention(features) + + descriptor = self._descriptor_layers(features) + if self._classifier is not None: + class_scores = self._classifier(descriptor) + return descriptor, class_scores + else: + return descriptor + + +def _create_classifier(embedding_size, class_count, normalized_linear): if class_count is not None: - if am_softmax_linear: - return AmSoftmaxLinear(embedding_size, class_count) + if normalized_linear: + return NormalizedLinear(embedding_size, class_count) else: return nn.Linear(embedding_size, class_count) else: diff --git a/tools/dnn_training/audio_descriptor/backbones/__init__.py b/tools/dnn_training/audio_descriptor/backbones/__init__.py index e36788a1..078455b3 100644 --- a/tools/dnn_training/audio_descriptor/backbones/__init__.py +++ b/tools/dnn_training/audio_descriptor/backbones/__init__.py @@ -1,6 +1,6 @@ from audio_descriptor.backbones.ecapa_tdnn import EcapaTdnn, SmallEcapaTdnn from audio_descriptor.backbones.mnasnet import Mnasnet0_5, Mnasnet1_0 -from audio_descriptor.backbones.resnet import Resnet18, Resnet34, Resnet50 +from audio_descriptor.backbones.resnet import Resnet18, Resnet34, Resnet50, Resnet101 from audio_descriptor.backbones.open_face_inception import OpenFaceInception from audio_descriptor.backbones.thin_resnet_34 import ThinResnet34 from audio_descriptor.backbones.tiny_cnn import TinyCnn diff --git a/tools/dnn_training/audio_descriptor/backbones/resnet.py b/tools/dnn_training/audio_descriptor/backbones/resnet.py index 8c56c72f..ec358855 100644 --- a/tools/dnn_training/audio_descriptor/backbones/resnet.py +++ b/tools/dnn_training/audio_descriptor/backbones/resnet.py @@ -48,3 +48,11 @@ def __init__(self, pretrained=False): def last_channel_count(self): return 2048 + + +class Resnet101(_Resnet): + def __init__(self, pretrained=False): + super(Resnet101, self).__init__(models.resnet101(pretrained=pretrained)) + + def last_channel_count(self): + return 2048 diff --git a/tools/dnn_training/audio_descriptor/datasets/__init__.py b/tools/dnn_training/audio_descriptor/datasets/__init__.py index 36b0ff8a..3da85d23 100644 --- a/tools/dnn_training/audio_descriptor/datasets/__init__.py +++ b/tools/dnn_training/audio_descriptor/datasets/__init__.py @@ -1,4 +1,7 @@ from audio_descriptor.datasets.audio_descriptor_transforms import AudioDescriptorTrainingTransforms, \ AudioDescriptorValidationTransforms, AudioDescriptorTestTransforms from audio_descriptor.datasets.audio_descriptor_dataset import AudioDescriptorDataset -from audio_descriptor.datasets.fsd50k_dataset import Fsd50kDataset, FSDK50k_POS_WEIGHT +from audio_descriptor.datasets.audio_set_dataset import AudioSetDataset +from audio_descriptor.datasets.fsd50k_dataset import Fsd50kDataset +from audio_descriptor.datasets.imbalanced_multiclass_audio_descriptor_dataset_sampler import \ + ImbalancedMulticlassAudioDescriptorDatasetSampler diff --git a/tools/dnn_training/audio_descriptor/datasets/audio_descriptor_transforms.py b/tools/dnn_training/audio_descriptor/datasets/audio_descriptor_transforms.py index 70821df4..65f42aa3 100644 --- 
a/tools/dnn_training/audio_descriptor/datasets/audio_descriptor_transforms.py +++ b/tools/dnn_training/audio_descriptor/datasets/audio_descriptor_transforms.py @@ -2,11 +2,17 @@ import random import torch +import torch.nn as nn import torchaudio import torchaudio.transforms as transforms from common.datasets.audio_transform_utils import to_mono, resample, resize_waveform, resize_waveform_random, \ - normalize, standardize_every_frame, RandomPitchShift, RandomTimeStretch + normalize, RandomPitchShift, RandomTimeStretch + + +class LogModule(nn.Module): + def forward(self, x, eps=1e-6): + return torch.log10(x + eps) class _AudioDescriptorTransforms: @@ -26,6 +32,13 @@ def __init__(self, sample_rate=16000, waveform_size=64000, n_features=128, n_fft self._audio_transform = transforms.MelSpectrogram(sample_rate=self._sample_rate, n_fft=n_fft, n_mels=n_features) + elif audio_transform_type == 'log_mel_spectrogram': + self._audio_transform = nn.Sequential( + transforms.MelSpectrogram(sample_rate=self._sample_rate, + n_fft=n_fft, + n_mels=n_features), + LogModule() + ) elif audio_transform_type == 'spectrogram': if n_features != (n_fft // 2 + 1): raise ValueError('n_features must be equal to (n_fft // 2 + 1) ' @@ -115,7 +128,6 @@ def __call__(self, waveform, target, metadata): if self._enable_frequency_masking and random.random() < self._frequency_masking_p: spectrogram = self._frequency_masking(spectrogram) - spectrogram = standardize_every_frame(spectrogram) return spectrogram, target, metadata def _add_noise(self, waveform): @@ -135,7 +147,6 @@ def __call__(self, waveform, target, metadata): waveform = normalize(waveform) spectrogram = self._audio_transform(waveform) - spectrogram = standardize_every_frame(spectrogram) return spectrogram, target, metadata @@ -146,5 +157,4 @@ def __call__(self, waveform, target, metadata): waveform = normalize(waveform) spectrogram = self._audio_transform(waveform) - spectrogram = standardize_every_frame(spectrogram) return spectrogram, target, metadata diff --git a/tools/dnn_training/audio_descriptor/datasets/audio_set_dataset.py b/tools/dnn_training/audio_descriptor/datasets/audio_set_dataset.py new file mode 100644 index 00000000..7ebdc43e --- /dev/null +++ b/tools/dnn_training/audio_descriptor/datasets/audio_set_dataset.py @@ -0,0 +1,44 @@ +import os +import json + +from audio_descriptor.datasets.multiclass_audio_descriptor_dataset import MulticlassAudioDescriptorDataset + + +class AudioSetDataset(MulticlassAudioDescriptorDataset): + def _list_classes(self, root): + with open(os.path.join(root, 'ontology.json')) as ontology_file: + ontology_data = json.load(ontology_file) + + class_names = [d['id'] for d in ontology_data] + class_names.sort() + return {i: c for i, c in enumerate(class_names)} + + def _list_sounds(self, root, split, enhanced_targets): + folder, filename = self._get_folder_and_sound_file(split, enhanced_targets) + + sounds = [] + with open(os.path.join(root, filename), 'r') as sound_file: + for line in sound_file: + values = line.split(' ') + filename = values[0] + class_names = (n.strip() for n in values[1:]) + sounds.append({ + 'path': os.path.join(folder, filename[:2], filename), + 'target': self._create_target(class_names) + }) + + return sounds + + def _get_folder_and_sound_file(self, split, enhanced_targets): + if split == 'training' and enhanced_targets: + return 'train', 'train_enhanced.txt' + elif split == 'training' and not enhanced_targets: + return 'train', 'train.txt' + elif split == 'validation' and enhanced_targets: + return 
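The new `log_mel_spectrogram` option above pairs torchaudio's `MelSpectrogram` with the small `LogModule`, and the per-frame standardization step is dropped from the transforms. A quick way to see what the network receives with this front end; the sample rate, `n_fft` and `n_mels` values below are the transform defaults and may differ from the actual training configuration:

```python
import torch
import torch.nn as nn
import torchaudio.transforms as transforms


class LogModule(nn.Module):
    # Same as the module added above: log10 compression with a small epsilon
    # so silent frames do not produce -inf.
    def forward(self, x, eps=1e-6):
        return torch.log10(x + eps)


log_mel = nn.Sequential(
    transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=128),
    LogModule(),
)

waveform = torch.randn(1, 64000)    # 4 s of fake mono audio at 16 kHz
spectrogram = log_mel(waveform)
print(spectrogram.shape)            # torch.Size([1, 128, 321])
```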
'balanced_train', 'validation_enhanced.txt' + elif split == 'validation' and not enhanced_targets: + return 'balanced_train', 'validation.txt' + elif split == 'testing': + return 'eval', 'test.txt' + else: + raise ValueError('Invalid split') diff --git a/tools/dnn_training/audio_descriptor/datasets/fsd50k_dataset.py b/tools/dnn_training/audio_descriptor/datasets/fsd50k_dataset.py index 2f8769ed..004ec4f1 100644 --- a/tools/dnn_training/audio_descriptor/datasets/fsd50k_dataset.py +++ b/tools/dnn_training/audio_descriptor/datasets/fsd50k_dataset.py @@ -1,52 +1,10 @@ import os import csv -import random -import torch -from torch.utils.data import Dataset -import torchaudio +from audio_descriptor.datasets.multiclass_audio_descriptor_dataset import MulticlassAudioDescriptorDataset -# 1 - AP of each class during the training -FSDK50k_POS_WEIGHT = torch.tensor([0.5651, 0.1267, 0.1715, 0.4201, 0.1626, 0.1172, 0.1135, 0.1636, 0.2194, - 0.1674, 0.6011, 0.1615, 0.4142, 0.1635, 0.1679, 0.3407, 0.6103, 0.5347, - 0.4462, 0.0848, 0.0880, 0.2101, 0.1174, 0.3452, 0.3135, 0.4181, 0.3304, - 0.5456, 0.1882, 0.7136, 0.2944, 0.5132, 0.1999, 0.2401, 0.1726, 0.4264, - 0.3671, 0.5323, 0.3620, 0.2148, 0.2273, 0.3999, 0.5439, 0.5621, 0.2584, - 0.0998, 0.8355, 0.5168, 0.1512, 0.2777, 0.3545, 0.3391, 0.5969, 0.7260, - 0.3474, 0.3558, 0.4137, 0.0947, 0.3535, 0.1371, 0.1311, 0.1961, 0.2600, - 0.1426, 0.6236, 0.2040, 0.4149, 0.1309, 0.1180, 0.1205, 0.4038, 0.5866, - 0.2246, 0.1319, 0.1774, 0.1733, 0.5300, 0.1873, 0.6034, 0.3428, 0.3794, - 0.1900, 0.3687, 0.6971, 0.3248, 0.3333, 0.2516, 0.1687, 0.1891, 0.4374, - 0.0784, 0.4582, 0.3191, 0.4168, 0.5475, 0.2403, 0.1014, 0.2052, 0.1083, - 0.5914, 0.1457, 0.0759, 0.5624, 0.2385, 0.1468, 0.4627, 0.2696, 0.1478, - 0.2991, 0.2345, 0.2632, 0.2202, 0.1617, 0.1774, 0.5143, 0.3080, 0.1851, - 0.4282, 0.2364, 0.4489, 0.0200, 0.0198, 0.3145, 0.2032, 0.7746, 0.0797, - 0.1643, 0.0784, 0.4913, 0.1963, 0.4238, 0.2741, 0.4078, 0.2325, 0.4009, - 0.3452, 0.3277, 0.7384, 0.2016, 0.1768, 0.1613, 0.3781, 0.3899, 0.7663, - 0.1736, 0.1545, 0.4978, 0.2567, 0.2145, 0.2204, 0.1840, 0.4632, 0.3326, - 0.2809, 0.4492, 0.5211, 0.1237, 0.3485, 0.1280, 0.2009, 0.3581, 0.4747, - 0.3072, 0.1372, 0.2562, 0.3254, 0.1270, 0.6841, 0.4878, 0.1764, 0.5468, - 0.2546, 0.2394, 0.9123, 0.3800, 0.2726, 0.3231, 0.5599, 0.4232, 0.4885, - 0.7342, 0.0858, 0.5061, 0.4035, 0.1759, 0.1579, 0.5611, 0.2854, 0.4690, - 0.3196, 0.2173, 0.3066, 0.1498, 0.5192, 0.2322, 0.0570, 0.7181, 0.5055, - 0.2353, 0.3516], dtype=torch.float64) - - -class Fsd50kDataset(Dataset): - def __init__(self, root, split=None, transforms=None, enable_mixup=True): - self._class_indexes_by_name = self._list_classes(root) - - if split == 'training': - self._sounds = self._list_sounds(root, 'dev') - elif split == 'validation': - self._sounds = self._list_sounds(root, 'eval') - else: - raise ValueError('Invalid split') - - self._transforms = transforms - self._enable_mixup = enable_mixup - +class Fsd50kDataset(MulticlassAudioDescriptorDataset): def _list_classes(self, root): class_indexes_by_name = {} with open(os.path.join(root, 'FSD50K.ground_truth', 'vocabulary.csv'), newline='') as vocabulary_file: @@ -56,57 +14,31 @@ def _list_classes(self, root): return class_indexes_by_name - def _list_sounds(self, root, fsd50k_split): + def _list_sounds(self, root, split, enhanced_targets): + folder, filename = self._get_folder_and_sound_file(split, enhanced_targets) + sounds = [] - with open(os.path.join(root, 'FSD50K.ground_truth', '{}.csv'.format(fsd50k_split)), 
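`AudioSetDataset` reads a plain index file in which each line is an audio file name followed by the clip's ontology ids, and turns the ids into a multi-hot target. A small illustration of that mapping; the file name and ids below are made up:

```python
import torch


def create_multi_hot_target(class_names, class_indexes_by_name):
    # Same idea as the dataset's _create_target: one slot per class,
    # set to 1.0 for every label attached to the clip.
    target = torch.zeros(len(class_indexes_by_name), dtype=torch.float)
    for class_name in class_names:
        target[class_indexes_by_name[class_name]] = 1.0
    return target


line = 'abc123.wav /m/09x0r /m/04rlf\n'          # hypothetical index line
class_indexes_by_name = {'/m/04rlf': 0, '/m/09x0r': 1}
values = line.split(' ')
target = create_multi_hot_target((v.strip() for v in values[1:]), class_indexes_by_name)
print(values[0], target)  # abc123.wav tensor([1., 1.])
```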
newline='') as sound_file: - sound_reader = csv.reader(sound_file, delimiter=',', quotechar='"') - next(sound_reader) - for row in sound_reader: - id = row[0] - class_names = row[1] - class_names = class_names.split(',') + with open(os.path.join(root, filename), 'r') as sound_file: + for line in sound_file: + values = line.split(' ') + class_names = (n.strip() for n in values[1:]) sounds.append({ - 'path': os.path.join(root, 'FSD50K.{}_audio'.format(fsd50k_split), '{}.wav'.format(id)), + 'path': os.path.join(folder, values[0]), 'target': self._create_target(class_names) }) return sounds - def _create_target(self, class_names): - target = torch.zeros(len(self._class_indexes_by_name), dtype=torch.float) - for class_name in class_names: - target[self._class_indexes_by_name[class_name]] = 1.0 - - return target - - def __len__(self): - return len(self._sounds) - - def __getitem__(self, index): - waveform, target, metadata = self._get_item_without_mixup(index) - - if self._enable_mixup: - mixup_index = random.randrange(len(self._sounds)) - alpha = random.random() - mixup_waveform, mixup_target, _ = self._get_item_without_mixup(mixup_index) - - waveform = alpha * waveform + (1 - alpha) * mixup_waveform - target = alpha * target + (1 - alpha) * mixup_target - - return waveform, target, metadata - - def _get_item_without_mixup(self, index): - waveform, sample_rate = torchaudio.load(self._sounds[index]['path']) - target = self._sounds[index]['target'].clone() - - metadata = { - 'original_sample_rate': sample_rate - } - - if self._transforms is not None: - waveform, target, metadata = self._transforms(waveform, target, metadata) - - return waveform, target, metadata - - def transforms(self): - return self._transforms + def _get_folder_and_sound_file(self, split, enhanced_targets): + if split == 'training' and enhanced_targets: + return 'FSD50K.dev_audio', 'train_enhanced.txt' + elif split == 'training' and not enhanced_targets: + return 'FSD50K.dev_audio', 'train.txt' + elif split == 'validation' and enhanced_targets: + return 'FSD50K.dev_audio', 'validation_enhanced.txt' + elif split == 'validation' and not enhanced_targets: + return 'FSD50K.dev_audio', 'validation.txt' + elif split == 'testing': + return 'FSD50K.eval_audio', 'test.txt' + else: + raise ValueError('Invalid split') diff --git a/tools/dnn_training/audio_descriptor/datasets/imbalanced_multiclass_audio_descriptor_dataset_sampler.py b/tools/dnn_training/audio_descriptor/datasets/imbalanced_multiclass_audio_descriptor_dataset_sampler.py new file mode 100644 index 00000000..585ffb19 --- /dev/null +++ b/tools/dnn_training/audio_descriptor/datasets/imbalanced_multiclass_audio_descriptor_dataset_sampler.py @@ -0,0 +1,23 @@ +import torch + + +class ImbalancedMulticlassAudioDescriptorDatasetSampler(torch.utils.data.sampler.Sampler): + def __init__(self, multiclass_audio_descriptor_dataset): + self._sound_count = len(multiclass_audio_descriptor_dataset) + + class_counts = torch.ones(multiclass_audio_descriptor_dataset.class_count()) * 1e-6 + for i in range(self._sound_count): + class_counts += multiclass_audio_descriptor_dataset.get_target(i) + + class_weights = 1.0 / class_counts + self._image_weights = [] + for i in range(self._sound_count): + self._image_weights.append((class_weights * multiclass_audio_descriptor_dataset.get_target(i)).sum()) + self._image_weights = torch.tensor(self._image_weights) + + def __iter__(self): + indexes = torch.multinomial(self._image_weights, self._sound_count, replacement=True) + return iter(indexes.tolist()) + 
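The sampler above draws clips with probability proportional to the summed inverse frequencies of their labels, so rare classes are seen more often. A self-contained illustration of the same weighting idea using `WeightedRandomSampler`; the toy dataset and label counts are invented for the example:

```python
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class ToyMultilabelDataset(Dataset):
    """Invented dataset exposing the interface the sampler relies on."""

    def __init__(self, targets):
        self._targets = targets

    def __len__(self):
        return len(self._targets)

    def class_count(self):
        return self._targets[0].size(0)

    def get_target(self, index):
        return self._targets[index].clone()

    def __getitem__(self, index):
        return torch.randn(8), self._targets[index]


targets = [torch.tensor([1., 0., 0.]),   # class 0 is frequent
           torch.tensor([1., 0., 0.]),
           torch.tensor([1., 0., 0.]),
           torch.tensor([0., 1., 1.])]   # classes 1 and 2 are rare
dataset = ToyMultilabelDataset(targets)

# Same weighting as the sampler above: a clip's weight is the sum of the
# inverse frequencies of its labels, so the rare clip is drawn far more often.
class_counts = torch.stack([dataset.get_target(i) for i in range(len(dataset))]).sum(dim=0) + 1e-6
class_weights = 1.0 / class_counts
clip_weights = torch.tensor([(class_weights * dataset.get_target(i)).sum() for i in range(len(dataset))])

loader = DataLoader(dataset,
                    batch_size=2,
                    sampler=WeightedRandomSampler(clip_weights, num_samples=len(dataset), replacement=True))
for batch in loader:
    pass
```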
+ def __len__(self): + return self._sound_count diff --git a/tools/dnn_training/audio_descriptor/datasets/multiclass_audio_descriptor_dataset.py b/tools/dnn_training/audio_descriptor/datasets/multiclass_audio_descriptor_dataset.py new file mode 100644 index 00000000..126b2763 --- /dev/null +++ b/tools/dnn_training/audio_descriptor/datasets/multiclass_audio_descriptor_dataset.py @@ -0,0 +1,74 @@ +import os +import random + +import torch +from torch.utils.data import Dataset +import torchaudio + + +class MulticlassAudioDescriptorDataset(Dataset): + def __init__(self, root, split=None, transforms=None, mixup_rate=0.5, mixup_alpha=10.0, enhanced_targets=True): + self._root = root + self._class_indexes_by_name = self._list_classes(root) + + self._sounds = self._list_sounds(root, split, enhanced_targets) + self._transforms = transforms + + self._mixup_rate = mixup_rate if split == 'training' else -1.0 + self._mixup_alpha = mixup_alpha + + def _list_classes(self, root): + raise NotImplementedError() + + def _list_sounds(self, root, split, enhanced_targets): + raise NotImplementedError() + + def _create_target(self, class_names): + target = torch.zeros(len(self._class_indexes_by_name), dtype=torch.float) + for class_name in class_names: + target[self._class_indexes_by_name[class_name]] = 1.0 + + return target + + def class_count(self): + if len(self._sounds) > 0: + return self._sounds[0]['target'].size(0) + else: + return 0 + + def __len__(self): + return len(self._sounds) + + def __getitem__(self, index): + waveform, target, metadata = self._get_item_without_mixup(index) + + if random.random() < self._mixup_rate: + mixup_index = random.randrange(len(self._sounds)) + mixup_waveform, mixup_target, _ = self._get_item_without_mixup(mixup_index) + l = random.betavariate(self._mixup_alpha, self._mixup_alpha) + + waveform = l * waveform + (1.0 - l) * mixup_waveform + target = l * target + (1.0 - l) * mixup_target + + return waveform, target, metadata + + def get_target(self, index): + return self._sounds[index]['target'].clone() + + def _get_item_without_mixup(self, index): + waveform, sample_rate = torchaudio.load(os.path.join(self._root, self._sounds[index]['path'])) + target = self._sounds[index]['target'].clone() + + metadata = { + 'original_sample_rate': sample_rate + } + + if self._transforms is not None: + waveform, target, metadata = self._transforms(waveform, target, metadata) + + return waveform, target, metadata + + def transforms(self): + return self._transforms + + diff --git a/tools/dnn_training/audio_descriptor/modules/__init__.py b/tools/dnn_training/audio_descriptor/modules/__init__.py index 7346e09c..b6b2c171 100644 --- a/tools/dnn_training/audio_descriptor/modules/__init__.py +++ b/tools/dnn_training/audio_descriptor/modules/__init__.py @@ -1 +1,2 @@ from audio_descriptor.modules.sap import SAP +from audio_descriptor.modules.psla_attention import PSLAAttention diff --git a/tools/dnn_training/audio_descriptor/modules/psla_attention.py b/tools/dnn_training/audio_descriptor/modules/psla_attention.py new file mode 100644 index 00000000..9e46e09f --- /dev/null +++ b/tools/dnn_training/audio_descriptor/modules/psla_attention.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + + +class PSLAAttention(nn.Module): + def __init__(self, in_channels, out_channels, head_count=4): + super(PSLAAttention, self).__init__() + + self._attention_convs = nn.ModuleList() + self._feature_convs = nn.ModuleList() + for i in range(head_count): + self._attention_convs.append(nn.Sequential( + 
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.Sigmoid() + )) + self._feature_convs.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.Sigmoid() + )) + + self._head_weights = nn.Parameter(torch.tensor([1.0 / head_count] * head_count)) + + def forward(self, x, attention_eps=1e-7): + y = [] + for attention_conv, feature_conv, head_weight in zip(self._attention_convs, self._feature_convs, self._head_weights): + attention = torch.clamp(attention_conv(x), attention_eps, 1.0 - attention_eps) + normalized_attention = attention / torch.sum(attention, dim=3, keepdim=True) + + feature = feature_conv(x) + y.append(torch.sum(normalized_attention * feature, dim=3) * head_weight) + + return torch.stack(y, dim=0).sum(dim=0).squeeze(2) diff --git a/tools/dnn_training/audio_descriptor/trainers/audio_descriptor_extractor_trainer.py b/tools/dnn_training/audio_descriptor/trainers/audio_descriptor_extractor_trainer.py index fda73edc..1b2e0a3a 100644 --- a/tools/dnn_training/audio_descriptor/trainers/audio_descriptor_extractor_trainer.py +++ b/tools/dnn_training/audio_descriptor/trainers/audio_descriptor_extractor_trainer.py @@ -125,11 +125,6 @@ def _measure_training_metrics(self, loss, model_output, target): if self._criterion_type != 'triplet_loss': self._training_accuracy_metric.add(model_output[1], target) - def _validate(self): - super(AudioDescriptorExtractorTrainer, self)._validate() - if self._criterion_type == 'am_softmax_loss': - self._criterion.next_epoch() - def _clear_between_validation_epoch(self): self._validation_loss_metric.clear() if self._criterion_type != 'triplet_loss': diff --git a/tools/dnn_training/audio_descriptor/trainers/multiclass_audio_descriptor_extractor_trainer.py b/tools/dnn_training/audio_descriptor/trainers/multiclass_audio_descriptor_extractor_trainer.py index 330fed2f..659387c3 100644 --- a/tools/dnn_training/audio_descriptor/trainers/multiclass_audio_descriptor_extractor_trainer.py +++ b/tools/dnn_training/audio_descriptor/trainers/multiclass_audio_descriptor_extractor_trainer.py @@ -11,18 +11,20 @@ from common.metrics import MulticlassClassificationAccuracyMetric, MulticlassClassificationPrecisionRecallMetric, \ LossMetric, LossAccuracyMeanAveragePrecisionLearningCurves, MulticlassClassificationMeanAveragePrecisionMetric -from audio_descriptor.datasets import Fsd50kDataset, FSDK50k_POS_WEIGHT, AudioDescriptorTrainingTransforms, \ - AudioDescriptorValidationTransforms, AudioDescriptorTestTransforms +from audio_descriptor.datasets import Fsd50kDataset, AudioSetDataset, AudioDescriptorTrainingTransforms, \ + ImbalancedMulticlassAudioDescriptorDatasetSampler, AudioDescriptorValidationTransforms, \ + AudioDescriptorTestTransforms from audio_descriptor.metrics import AudioDescriptorEvaluation class MulticlassAudioDescriptorExtractorTrainer(Trainer): - def __init__(self, device, model, dataset_root='', output_path='', + def __init__(self, device, model, dataset_root='', dataset_type='fsd50k', output_path='', epoch_count=10, learning_rate=0.01, weight_decay=0, batch_size=128, criterion_type='bce_loss', waveform_size=64000, n_features=128, n_fft=400, audio_transform_type='mel_spectrogram', enable_pitch_shifting=False, enable_time_stretching=False, enable_time_masking=False, enable_frequency_masking=False, - enable_pos_weight=False, enable_mixup=True, model_checkpoint=None): + enhanced_targets=False, model_checkpoint=None): + self._dataset_type = 
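`PSLAAttention` pools the time axis of a frequency-pooled feature map with learned attention weights, one set per head. A simplified single-head version, shown only to make the tensor shapes explicit; the channel count and frame count are arbitrary:

```python
import torch
import torch.nn as nn


class SingleHeadAttentionPooling(nn.Module):
    """Simplified one-head version of the attention pooling, kept only to show the shapes."""

    def __init__(self, channels):
        super().__init__()
        self._attention_conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=False), nn.BatchNorm2d(channels), nn.Sigmoid())
        self._feature_conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=False), nn.BatchNorm2d(channels), nn.Sigmoid())

    def forward(self, x, eps=1e-7):
        # x: (N, C, 1, T), i.e. a backbone feature map after frequency pooling
        attention = torch.clamp(self._attention_conv(x), eps, 1.0 - eps)
        attention = attention / attention.sum(dim=3, keepdim=True)        # normalize over time
        return (attention * self._feature_conv(x)).sum(dim=3).squeeze(2)  # (N, C)


features = torch.randn(2, 2048, 1, 63)                   # e.g. ResNet output after GlobalHeightAvgPool2d
print(SingleHeadAttentionPooling(2048)(features).shape)  # torch.Size([2, 2048])
```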
dataset_type self._criterion_type = criterion_type self._waveform_size = waveform_size self._n_features = n_features @@ -32,8 +34,7 @@ def __init__(self, device, model, dataset_root='', output_path='', self._enable_time_stretching = enable_time_stretching self._enable_time_masking = enable_time_masking self._enable_frequency_masking = enable_frequency_masking - self._enable_pos_weight = enable_pos_weight - self._enable_mixup = enable_mixup + self._enhanced_targets = enhanced_targets self._class_count = model.class_count() super(MulticlassAudioDescriptorExtractorTrainer, self).__init__(device, model, dataset_root=dataset_root, @@ -59,13 +60,11 @@ def __init__(self, device, model, dataset_root='', output_path='', self._validation_map_metric = MulticlassClassificationMeanAveragePrecisionMetric(self._class_count) def _create_criterion(self, model): - pos_weight = FSDK50k_POS_WEIGHT.to(self._device) if self._enable_pos_weight else None - if self._criterion_type == 'bce_loss': - criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + criterion = nn.BCEWithLogitsLoss() return lambda model_output, target: criterion(model_output[1], target) elif self._criterion_type == 'sigmoid_focal_loss': - criterion = SigmoidFocalLossWithLogits(pos_weight=pos_weight) + criterion = SigmoidFocalLossWithLogits() return lambda model_output, target: criterion(model_output[1], target) else: raise ValueError('Invalid criterion type') @@ -82,27 +81,33 @@ def _create_training_dataset_loader(self, dataset_root, batch_size, batch_size_d enable_time_stretching=self._enable_time_stretching, enable_time_masking=self._enable_time_masking, enable_frequency_masking=self._enable_frequency_masking) - return self._create_dataset_loader(dataset_root, batch_size, batch_size_division, 'training', transforms, - enable_mixup=self._enable_mixup) + return self._create_dataset_loader(dataset_root, batch_size, batch_size_division, 'training', transforms) def _create_validation_dataset_loader(self, dataset_root, batch_size, batch_size_division): transforms = AudioDescriptorValidationTransforms(waveform_size=self._waveform_size, n_features=self._n_features, n_fft=self._n_fft, audio_transform_type=self._audio_transform_type) - return self._create_dataset_loader(dataset_root, batch_size, batch_size_division, 'validation', transforms, - enable_mixup=False) + return self._create_dataset_loader(dataset_root, batch_size, batch_size_division, 'validation', transforms) def _create_testing_dataset_loader(self, dataset_root, batch_size, batch_size_division): transforms = AudioDescriptorTestTransforms(waveform_size=self._waveform_size, n_features=self._n_features, n_fft=self._n_fft, audio_transform_type=self._audio_transform_type) - return self._create_dataset_loader(dataset_root, 1, 1, 'validation', transforms, enable_mixup=False) - - def _create_dataset_loader(self, dataset_root, batch_size, batch_size_division, split, transforms, enable_mixup): - dataset = Fsd50kDataset(dataset_root, split=split, transforms=transforms, enable_mixup=enable_mixup) - return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, shuffle=True, + return self._create_dataset_loader(dataset_root, 1, 1, 'testing', transforms) + + def _create_dataset_loader(self, dataset_root, batch_size, batch_size_division, split, transforms): + if self._dataset_type == 'audio_set': + dataset = AudioSetDataset(dataset_root, split=split, transforms=transforms, + enhanced_targets=self._enhanced_targets) + elif self._dataset_type == 'fsd50k': + dataset = 
Fsd50kDataset(dataset_root, split=split, transforms=transforms, + enhanced_targets=self._enhanced_targets) + else: + raise ValueError('Invalid dataset type') + sampler = ImbalancedMulticlassAudioDescriptorDatasetSampler(dataset) if split == 'training' else None + return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, sampler=sampler, num_workers=4) def _clear_between_training(self): diff --git a/tools/dnn_training/backbone/datasets/__init__.py b/tools/dnn_training/backbone/datasets/__init__.py index 398bc793..b4b61a76 100644 --- a/tools/dnn_training/backbone/datasets/__init__.py +++ b/tools/dnn_training/backbone/datasets/__init__.py @@ -1,2 +1,3 @@ from backbone.datasets.classification_image_net import ClassificationImageNet from backbone.datasets.classification_open_images import ClassificationOpenImages +from backbone.datasets.mixup_classification_dataset import MixupClassificationDataset diff --git a/tools/dnn_training/backbone/datasets/classification_image_net.py b/tools/dnn_training/backbone/datasets/classification_image_net.py index fd4e29f1..47ae6125 100644 --- a/tools/dnn_training/backbone/datasets/classification_image_net.py +++ b/tools/dnn_training/backbone/datasets/classification_image_net.py @@ -62,6 +62,9 @@ def __getitem__(self, index): return image, self._images[index]['class_index'] + def class_count(self): + return CLASS_COUNT + def _is_jpeg(filename): filename = filename.upper() diff --git a/tools/dnn_training/backbone/datasets/classification_open_images.py b/tools/dnn_training/backbone/datasets/classification_open_images.py index efe58d95..3e658e1c 100644 --- a/tools/dnn_training/backbone/datasets/classification_open_images.py +++ b/tools/dnn_training/backbone/datasets/classification_open_images.py @@ -50,3 +50,6 @@ def _list_images(self): def __getitem__(self, index): image, target, _ = super(ClassificationOpenImages, self).__getitem__(index) return image, target + + def class_count(self): + return CLASS_COUNT diff --git a/tools/dnn_training/backbone/datasets/mixup_classification_dataset.py b/tools/dnn_training/backbone/datasets/mixup_classification_dataset.py new file mode 100644 index 00000000..481ca4eb --- /dev/null +++ b/tools/dnn_training/backbone/datasets/mixup_classification_dataset.py @@ -0,0 +1,32 @@ +import random + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F + + +class MixupClassificationDataset(Dataset): + def __init__(self, dataset, mixup_rate=0.5, mixup_alpha=10.0): + self._dataset = dataset + self._mixup_rate = mixup_rate + self._mixup_alpha = mixup_alpha + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, index): + image, class_index = self._dataset[index] + target = F.one_hot(torch.tensor(class_index), self._dataset.class_count()).float() + + if random.random() < self._mixup_rate: + mixup_index = random.randrange(len(self)) + mixup_image, mixup_class_index = self._dataset[mixup_index] + mixup_target = F.one_hot(torch.tensor(mixup_class_index), self._dataset.class_count()).float() + + l = random.betavariate(self._mixup_alpha, self._mixup_alpha) + image = l * image + (1.0 - l) * mixup_image + target = l * target + (1.0 - l) * mixup_target + + return image, target + + diff --git a/tools/dnn_training/backbone/trainers/__init__.py b/tools/dnn_training/backbone/trainers/__init__.py index d33ee0d0..7c9db3cc 100644 --- a/tools/dnn_training/backbone/trainers/__init__.py +++ b/tools/dnn_training/backbone/trainers/__init__.py @@ -1,2 +1,2 @@ -from 
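`MixupClassificationDataset` applies the same mixup recipe as `MulticlassAudioDescriptorDataset` earlier in this diff: with probability `mixup_rate`, blend two samples and their targets with a weight drawn from Beta(alpha, alpha). A minimal sketch of the blend; with alpha=10 the weight concentrates around 0.5:

```python
import random

import torch


def mixup(sample_a, target_a, sample_b, target_b, mixup_alpha=10.0):
    # Blend two samples and their targets; Beta(10, 10) concentrates the
    # weight around 0.5, so both clips stay clearly present in the mix.
    l = random.betavariate(mixup_alpha, mixup_alpha)
    return l * sample_a + (1.0 - l) * sample_b, l * target_a + (1.0 - l) * target_b


waveform_a, target_a = torch.randn(64000), torch.tensor([1.0, 0.0, 0.0])
waveform_b, target_b = torch.randn(64000), torch.tensor([0.0, 0.0, 1.0])
mixed_waveform, mixed_target = mixup(waveform_a, target_a, waveform_b, target_b)
print(mixed_target)  # soft labels, e.g. tensor([0.5500, 0.0000, 0.4500])
```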
backbone.trainers.backbone_trainer import BackboneTrainer +from backbone.trainers.backbone_trainer import BackboneTrainer, IMAGE_SIZE from backbone.trainers.backbone_distillation_trainer import BackboneDistillationTrainer diff --git a/tools/dnn_training/backbone/trainers/backbone_distillation_trainer.py b/tools/dnn_training/backbone/trainers/backbone_distillation_trainer.py index b2e24c83..37d69381 100644 --- a/tools/dnn_training/backbone/trainers/backbone_distillation_trainer.py +++ b/tools/dnn_training/backbone/trainers/backbone_distillation_trainer.py @@ -11,7 +11,7 @@ from common.metrics import ClassificationAccuracyMetric, LossMetric, LossAccuracyLearningCurves, \ TopNClassificationAccuracyMetric -from backbone.datasets import ClassificationOpenImages, ClassificationImageNet +from backbone.datasets import ClassificationOpenImages, ClassificationImageNet, MixupClassificationDataset from backbone.trainers.backbone_trainer import create_training_image_transform, create_validation_image_transform @@ -41,7 +41,7 @@ def __init__(self, device, student_model, teacher_model, image_net_root='', open self._teacher_model = teacher_model.to(device) self._teacher_model.eval() - self._optimizer = torch.optim.Adam(self._student_model.parameters(), lr=learning_rate, + self._optimizer = torch.optim.AdamW(self._student_model.parameters(), lr=learning_rate, weight_decay=weight_decay) self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, epoch_count) diff --git a/tools/dnn_training/backbone/trainers/backbone_trainer.py b/tools/dnn_training/backbone/trainers/backbone_trainer.py index 546c3f5e..ac6d5938 100644 --- a/tools/dnn_training/backbone/trainers/backbone_trainer.py +++ b/tools/dnn_training/backbone/trainers/backbone_trainer.py @@ -14,7 +14,7 @@ from common.metrics import ClassificationAccuracyMetric, LossMetric, LossAccuracyLearningCurves, \ TopNClassificationAccuracyMetric -from backbone.datasets import ClassificationOpenImages, ClassificationImageNet +from backbone.datasets import ClassificationOpenImages, ClassificationImageNet, MixupClassificationDataset IMAGE_SIZE = (224, 224) @@ -80,7 +80,9 @@ def _create_dataset_loader(self, dataset_root, batch_size, batch_size_division, else: raise ValueError('Invalid dataset type') - return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, shuffle=shuffle, + return torch.utils.data.DataLoader(MixupClassificationDataset(dataset), + batch_size=batch_size // batch_size_division, + shuffle=shuffle, num_workers=4) def _clear_between_training(self): diff --git a/tools/dnn_training/backbone/vit.py b/tools/dnn_training/backbone/vit.py new file mode 100644 index 00000000..fb4defc7 --- /dev/null +++ b/tools/dnn_training/backbone/vit.py @@ -0,0 +1,181 @@ +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair +import torch.nn.functional as F + + +class ImagePatchEmbedding(nn.Module): + def __init__(self, in_channels, patch_size, stride, embedding_size): + super(ImagePatchEmbedding, self).__init__() + self._projection = nn.Conv2d(in_channels, out_channels=embedding_size, kernel_size=patch_size, stride=stride) + + def forward(self, x): + return self._projection(x) + + +class SelfAttention(nn.Module): + def __init__(self, dim, head_count, attention_dropout_rate=0.0, projection_dropout_rate=0.0): + super(SelfAttention, self).__init__() + + self._head_count = head_count + head_dim = dim // head_count + self._scale = 1.0 / math.sqrt(head_dim) + + self._qkv = 
nn.Linear(in_features=dim, out_features=3 * dim) + self._attention_dropout = nn.Dropout(attention_dropout_rate) + self._projection = nn.Linear(in_features=dim, out_features=dim) + self._projection_dropout = nn.Dropout(projection_dropout_rate) + + def forward(self, x): + B, N, C = x.shape + qkv = self._qkv(x).reshape(B, N, 3, self._head_count, C // self._head_count).permute(2, 0, 3, 1, 4) + + attention = (torch.matmul(qkv[0], qkv[1].transpose(-2, -1))) * self._scale + normalized_attention = F.softmax(attention, dim=-1) + normalized_attention = self._attention_dropout(normalized_attention) + + y = torch.matmul(normalized_attention, qkv[2]).transpose(1, 2).reshape(B, N, C) + return self._projection_dropout(self._projection(y)) + + +class EncoderBlock(nn.Module): + def __init__(self, dim, head_count, dropout_rate=0.0, attention_dropout_rate=0.0): + super(EncoderBlock, self).__init__() + self._attention = SelfAttention(dim, + head_count, + projection_dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate) + self._attention_norm = nn.LayerNorm(dim) + self._mlp = nn.Sequential( + nn.Linear(in_features=dim, out_features=4 * dim), + nn.GELU(), + nn.Dropout(dropout_rate), + nn.Linear(in_features=4 * dim, out_features=dim), + nn.Dropout(dropout_rate) + ) + self._mlp_norm = nn.LayerNorm(dim) + + def forward(self, x): + y0 = x + self._attention(self._attention_norm(x)) + y1 = y0 + self._mlp(self._mlp_norm(y0)) + return y1 + + +class Vit(nn.Module): + def __init__(self, image_size, in_channels=3, patch_size=16, stride=16, embedding_size=768, head_count=12, depth=7, + class_count=1000, distilled=False, patchout_time=5, patchout_freq=5, + dropout_rate=0.1, attention_dropout_rate=0.1, output_embeddings=False): + super(Vit, self).__init__() + + self._patchout_time = patchout_time + self._patchout_freq = patchout_freq + self._class_count = class_count + self._output_embeddings = output_embeddings + + image_size = _pair(image_size) + grid_size = (image_size[0] // stride, image_size[1] // stride) + self._image_patch_embedding = ImagePatchEmbedding(in_channels, patch_size, stride, embedding_size) + + self._class_token = nn.Parameter(torch.zeros(1, 1, embedding_size)) + self._distillation_token = nn.Parameter(torch.zeros(1, 1, embedding_size)) if distilled else None + + self._class_positional_embedding = nn.Parameter(torch.zeros(1, 1, embedding_size)) + self._distillation_positional_embedding = nn.Parameter(torch.zeros(1, 1, embedding_size)) if distilled else None + self._freq_positional_embedding = nn.Parameter(torch.zeros(1, embedding_size, grid_size[0], 1)) + self._time_positional_embedding = nn.Parameter(torch.zeros(1, embedding_size, 1, grid_size[1])) + + self._encoders = nn.Sequential(*[EncoderBlock(embedding_size, head_count, dropout_rate, attention_dropout_rate) + for _ in range(depth)], + nn.LayerNorm(embedding_size)) + self._head = nn.Linear(in_features=embedding_size, out_features=class_count, bias=False) + + if distilled: + self._head_distillation = nn.Linear(in_features=embedding_size, out_features=class_count, bias=False) + + self._embedding_dropout = nn.Dropout(dropout_rate) + + self._init_embeddings() + for name, module in self.named_modules(): + self._init_module_weights(module, name) + + def class_count(self): + return self._class_count + + def no_weight_decay_parameters(self): + return {'_class_token', '_distillation_token', '_class_positional_embedding', + '_distillation_positional_embedding', '_freq_positional_embedding', '_time_positional_embedding'} + + def 
_init_embeddings(self): + nn.init.trunc_normal_(self._class_token, std=0.02) + if self._distillation_token is not None: + nn.init.trunc_normal_(self._distillation_token, std=0.02) + nn.init.trunc_normal_(self._distillation_positional_embedding, std=0.02) + nn.init.trunc_normal_(self._class_positional_embedding, std=0.02) + nn.init.trunc_normal_(self._freq_positional_embedding, std=0.02) + nn.init.trunc_normal_(self._time_positional_embedding, std=0.02) + + def _init_module_weights(self, module, name): + if isinstance(module, nn.Linear): + if 'head' in name: + nn.init.zeros_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + else: + nn.init.trunc_normal_(module.weight, std=.02) + nn.init.zeros_(module.bias) + + def forward(self, x): + embedding = self._forward_embeddings(x) + all_features = self._encoders(embedding) + class_token_embedding = F.normalize(all_features[:, 0], p=2.0, dim=1) + y = self._head(class_token_embedding) + + if self._distillation_token is not None: + y_distillation = self._head_distillation(F.normalize(all_features[:, 1], p=2.0, dim=1)) + if self.training: + if self._output_embeddings: + return class_token_embedding, y, y_distillation + else: + return y, y_distillation + else: + if self._output_embeddings: + return class_token_embedding, (y + y_distillation) / 2 + else: + return (y + y_distillation) / 2 + else: + if self._output_embeddings: + return class_token_embedding, y + else: + return y + + def _forward_embeddings(self, x): + embedding = self._image_patch_embedding(x) + + if embedding.size(-1) < self._time_positional_embedding.size(-1): + time_positional_embedding = self._time_positional_embedding[:, :, :, :embedding.size(-1)] + elif embedding.size(-1) > self._time_positional_embedding.size(-1): + raise ValueError('The image is bigger than the positional embedding') + else: + time_positional_embedding = self._time_positional_embedding + + embedding = embedding + time_positional_embedding + self._freq_positional_embedding + + # Structured patchout + _, _, F, T = embedding.size() + if self.training and self._patchout_time > 0: + random_indices = torch.randperm(T)[:T - self._patchout_time].sort().values + embedding = embedding[:, :, :, random_indices] + if self.training and self._patchout_freq > 0: + random_indices = torch.randperm(F)[:F - self._patchout_freq].sort().values + embedding = embedding[:, :, random_indices, :] + + embedding = embedding.flatten(2).transpose(1, 2) + B, N, C = embedding.size() + class_token = self._class_token.expand(B, -1, -1) + self._class_positional_embedding + if self._distillation_token is None: + return torch.cat([class_token, embedding], dim=1) + else: + distillation_token = self._distillation_token.expand(B, -1, -1) + self._distillation_positional_embedding + return torch.cat([class_token, distillation_token, embedding], dim=1) diff --git a/tools/dnn_training/common/criterions/__init__.py b/tools/dnn_training/common/criterions/__init__.py index debde36b..e788d974 100644 --- a/tools/dnn_training/common/criterions/__init__.py +++ b/tools/dnn_training/common/criterions/__init__.py @@ -1,4 +1,5 @@ from common.criterions.am_softmax_loss import AmSoftmaxLoss +from common.criterions.arc_face import ArcFaceLoss from common.criterions.jensen_shannon_divergence import JensenShannonDivergence from common.criterions.ohem_cross_entropy_loss import OhemCrossEntropyLoss from common.criterions.sigmoid_focal_loss import SigmoidFocalLossWithLogits, SigmoidFocalLoss diff --git 
a/tools/dnn_training/common/criterions/am_softmax_loss.py b/tools/dnn_training/common/criterions/am_softmax_loss.py index 6fb65978..031431fc 100644 --- a/tools/dnn_training/common/criterions/am_softmax_loss.py +++ b/tools/dnn_training/common/criterions/am_softmax_loss.py @@ -18,7 +18,7 @@ def forward(self, scores, target): scores = scores.clone() numerator = self._s * (scores[range(scores.size(0)), target] - self._m) - scores[range(scores.size(0)), target] = 0.0 + scores[range(scores.size(0)), target] = -float('inf') denominator = torch.exp(numerator) + torch.sum(torch.exp(self._s * scores), dim=1) loss = numerator - torch.log(denominator) return -loss.mean() diff --git a/tools/dnn_training/common/criterions/arc_face.py b/tools/dnn_training/common/criterions/arc_face.py new file mode 100644 index 00000000..c5960899 --- /dev/null +++ b/tools/dnn_training/common/criterions/arc_face.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class ArcFaceLoss(nn.Module): + def __init__(self, s=64.0, m=0.5, start_annealing_epoch=0, end_annealing_epoch=0): + super(ArcFaceLoss, self).__init__() + self._s = s + self._m = 0.0 + self._target_m = m + + self._epoch = -1 + self._start_annealing_epoch = start_annealing_epoch + self._end_annealing_epoch = end_annealing_epoch + self.next_epoch() + + def forward(self, scores, target): + angles = torch.arccos(scores) + scores = scores.clone() + + numerator = self._s * torch.cos(angles[range(scores.size(0)), target] + self._m) + + scores[range(scores.size(0)), target] = -float('inf') + denominator = torch.exp(numerator) + torch.sum(torch.exp(self._s * scores), dim=1) + loss = numerator - torch.log(denominator) + return -loss.mean() + + def next_epoch(self): + self._epoch += 1 + if self._epoch >= self._end_annealing_epoch: + self._m = self._target_m + elif self._epoch < self._start_annealing_epoch: + self._m = 0.0 + else: + diff = (self._end_annealing_epoch - self._start_annealing_epoch) + self._m = self._target_m * (self._epoch - self._start_annealing_epoch) / diff diff --git a/tools/dnn_training/common/datasets/audio_transform_utils.py b/tools/dnn_training/common/datasets/audio_transform_utils.py index f5ce316f..3b09be35 100644 --- a/tools/dnn_training/common/datasets/audio_transform_utils.py +++ b/tools/dnn_training/common/datasets/audio_transform_utils.py @@ -12,6 +12,7 @@ def to_mono(waveform): def normalize(waveform): + waveform = waveform - waveform.mean() s = max(waveform.min().abs(), waveform.max().abs()) + 1e-6 return waveform / s @@ -73,7 +74,7 @@ def forward(self, x): if random.random() < self._p: n_steps = random.randint(self._min_steps, self._max_steps) - return torch.from_numpy(pitch_shift(x[0].numpy(), self._sample_rate, n_steps)).unsqueeze(0) + return torch.from_numpy(pitch_shift(x[0].numpy(), sr=self._sample_rate, n_steps=n_steps)).unsqueeze(0) else: return x @@ -91,6 +92,6 @@ def forward(self, x): if random.random() < self._p: rate = random.uniform(self._min_rate, self._max_rate) - return torch.from_numpy(time_stretch(x[0].numpy(), rate)).unsqueeze(0) + return torch.from_numpy(time_stretch(x[0].numpy(), rate=rate)).unsqueeze(0) else: return x diff --git a/tools/dnn_training/common/metrics/classification_accuracy_metric.py b/tools/dnn_training/common/metrics/classification_accuracy_metric.py index 819f70bc..915981ab 100644 --- a/tools/dnn_training/common/metrics/classification_accuracy_metric.py +++ b/tools/dnn_training/common/metrics/classification_accuracy_metric.py @@ -10,6 +10,11 @@ def clear(self): def add(self, 
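`ArcFaceLoss` above adds its angular margin gradually: the margin stays at 0 before `start_annealing_epoch`, ramps linearly, and reaches the target value at `end_annealing_epoch` (and, like the fixed `AmSoftmaxLoss`, it masks the target column with `-inf` so the target score is not counted twice in the denominator). A quick check of that schedule with illustrative epoch bounds:

```python
# Margin schedule of ArcFaceLoss.next_epoch(), with illustrative values:
# target margin 0.5, annealing between epochs 2 and 6.
target_m, start_annealing_epoch, end_annealing_epoch = 0.5, 2, 6


def margin_at(epoch):
    if epoch >= end_annealing_epoch:
        return target_m
    elif epoch < start_annealing_epoch:
        return 0.0
    diff = end_annealing_epoch - start_annealing_epoch
    return target_m * (epoch - start_annealing_epoch) / diff


print([round(margin_at(epoch), 3) for epoch in range(8)])
# [0.0, 0.0, 0.0, 0.125, 0.25, 0.375, 0.5, 0.5]
```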
predicted_class_scores, target_classes): predicted_classes = predicted_class_scores.argmax(dim=1) + if len(target_classes.size()) == 2: + target_classes = target_classes.argmax(dim=1) + elif len(target_classes.size()) > 2: + raise ValueError('Invalid target_classes size') + self._good += (predicted_classes == target_classes).sum().item() self._total += target_classes.size()[0] @@ -32,6 +37,11 @@ def clear(self): def add(self, predicted_class_scores, target_classes): top_n_predicted_classes = predicted_class_scores.argsort(dim=1, descending=True)[:, :self._n] + if len(target_classes.size()) == 2: + target_classes = target_classes.argmax(dim=1) + elif len(target_classes.size()) > 2: + raise ValueError('Invalid target_classes size') + for i in range(self._n): self._good += (top_n_predicted_classes[:, i] == target_classes).sum().item() self._total += target_classes.size()[0] diff --git a/tools/dnn_training/common/metrics/multiclass_classification_mean_average_precision.py b/tools/dnn_training/common/metrics/multiclass_classification_mean_average_precision.py index 809fb0fe..e913c797 100644 --- a/tools/dnn_training/common/metrics/multiclass_classification_mean_average_precision.py +++ b/tools/dnn_training/common/metrics/multiclass_classification_mean_average_precision.py @@ -5,8 +5,10 @@ class MulticlassClassificationMeanAveragePrecisionMetric: - def __init__(self, class_count): + def __init__(self, class_count, apply_sigmoid=True): self._class_count = class_count + self._apply_sigmoid = apply_sigmoid + self._predictions = [] self._targets = [] @@ -15,7 +17,10 @@ def clear(self): self._targets = [] def add(self, prediction, target): - prediction = torch.sigmoid(prediction).cpu().detach().numpy() + if self._apply_sigmoid: + prediction = torch.sigmoid(prediction) + + prediction = prediction.cpu().detach().numpy() target = (target > 0.0).float().cpu().detach().numpy() self._predictions.append(prediction) diff --git a/tools/dnn_training/common/modules/__init__.py b/tools/dnn_training/common/modules/__init__.py index 7a8da63d..099be463 100644 --- a/tools/dnn_training/common/modules/__init__.py +++ b/tools/dnn_training/common/modules/__init__.py @@ -1,4 +1,4 @@ -from common.modules.am_softmax_linear import AmSoftmaxLinear +from common.modules.normalized_linear import NormalizedLinear from common.modules.depth_wise_separable_conv2d import DepthWiseSeparableConv2d from common.modules.global_avg_pool_1d import global_avg_pool_1d, GlobalAvgPool1d from common.modules.global_avg_pool_2d import global_avg_pool_2d, GlobalAvgPool2d, GlobalHeightAvgPool2d diff --git a/tools/dnn_training/common/modules/am_softmax_linear.py b/tools/dnn_training/common/modules/normalized_linear.py similarity index 80% rename from tools/dnn_training/common/modules/am_softmax_linear.py rename to tools/dnn_training/common/modules/normalized_linear.py index f7e5dcf2..4e64506b 100644 --- a/tools/dnn_training/common/modules/am_softmax_linear.py +++ b/tools/dnn_training/common/modules/normalized_linear.py @@ -2,9 +2,9 @@ import torch.nn.functional as F -class AmSoftmaxLinear(nn.Module): +class NormalizedLinear(nn.Module): def __init__(self, in_features, out_features): - super(AmSoftmaxLinear, self).__init__() + super(NormalizedLinear, self).__init__() self._weight = nn.Linear(in_features, out_features, bias=False).weight def forward(self, x): diff --git a/tools/dnn_training/common/trainers/__init__.py b/tools/dnn_training/common/trainers/__init__.py index 233955c0..0dedffe3 100644 --- a/tools/dnn_training/common/trainers/__init__.py +++ 
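`NormalizedLinear` is only renamed here, so its body is not shown, but the losses that consume its output treat the scores as cosines (`ArcFaceLoss` applies `torch.arccos` to them). A typical implementation of such a layer, written as an assumption rather than a copy of the module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineLinear(nn.Module):
    """Assumed behaviour of a normalized linear layer: cosine similarity
    between the embedding and each class weight vector, no bias."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self._weight = nn.Linear(in_features, out_features, bias=False).weight

    def forward(self, x):
        return F.linear(F.normalize(x, dim=1), F.normalize(self._weight, dim=1))


scores = CosineLinear(128, 10)(torch.randn(4, 128))
print(scores.abs().max().item())  # stays within [-1, 1] up to floating-point error
```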
b/tools/dnn_training/common/trainers/__init__.py @@ -1 +1,2 @@ from common.trainers.trainer import Trainer +from common.trainers.distillation_trainer import DistillationTrainer diff --git a/tools/dnn_training/common/trainers/distillation_trainer.py b/tools/dnn_training/common/trainers/distillation_trainer.py new file mode 100644 index 00000000..cf475a34 --- /dev/null +++ b/tools/dnn_training/common/trainers/distillation_trainer.py @@ -0,0 +1,179 @@ +import os + +import torch +import torch.nn as nn + +from tqdm import tqdm + +from common.modules import load_checkpoint + + +class DistillationTrainer: + def __init__(self, device, student_model, teacher_model, dataset_root='', output_path='', + epoch_count=10, learning_rate=0.01, weight_decay=0.0, batch_size=128, batch_size_division=4, + student_model_checkpoint=None, teacher_model_checkpoint=None): + self._device = device + self._output_path = output_path + os.makedirs(self._output_path, exist_ok=True) + + self._epoch_count = epoch_count + self._batch_size = batch_size + self._batch_size_division = batch_size_division + + self._criterion = self._create_criterion(student_model, teacher_model) + + if teacher_model_checkpoint is None: + raise ValueError('teacher_model_checkpoint should be set.') + + load_checkpoint(teacher_model, teacher_model_checkpoint, strict=True) + if student_model_checkpoint is not None: + load_checkpoint(student_model, student_model_checkpoint, strict=False) + if device.type == 'cuda' and torch.cuda.device_count() > 1: + print('DataParallel - GPU count:', torch.cuda.device_count()) + student_model = nn.DataParallel(student_model) + teacher_model = nn.DataParallel(teacher_model) + + self._student_model = student_model.to(device) + self._teacher_model = teacher_model.to(device) + + no_weight_decay_parameters = getattr(self._criterion, 'no_weight_decay_parameters', None) + if no_weight_decay_parameters is None: + no_weight_decay_parameters = {} + else: + no_weight_decay_parameters = no_weight_decay_parameters() + no_weight_decay_parameters = [parameter for name, parameter in student_model.named_parameters() + if name.endswith('.bias') or name in no_weight_decay_parameters] + other_parameters = [parameter for name, parameter in student_model.named_parameters() + if not name.endswith('.bias') and name not in no_weight_decay_parameters] + parameter_groups = [ + {'params': other_parameters}, + {'params': no_weight_decay_parameters, 'weight_decay': 0.0} + ] + self._optimizer = torch.optim.AdamW(parameter_groups, lr=learning_rate, weight_decay=weight_decay) + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, epoch_count) + + self._training_dataset_loader = self._create_training_dataset_loader(dataset_root, + batch_size, + batch_size_division) + self._validation_dataset_loader = self._create_validation_dataset_loader(dataset_root, + batch_size, + batch_size_division) + self._testing_dataset_loader = self._create_testing_dataset_loader(dataset_root, + batch_size, + batch_size_division) + if self._testing_dataset_loader is None: + self._testing_dataset_loader = self._validation_dataset_loader + + def _create_criterion(self, student_model, teacher_model): + raise NotImplementedError() + + def _create_training_dataset_loader(self, dataset_root, batch_size, batch_size_division): + raise NotImplementedError() + + def _create_validation_dataset_loader(self, dataset_root, batch_size, batch_size_division): + raise NotImplementedError() + + def _create_testing_dataset_loader(self, dataset_root, batch_size, 
batch_size_division): + return None + + def train(self): + self._teacher_model.eval() + + self._clear_between_training() + for epoch in range(self._epoch_count): + print('Training - Epoch [{}/{}]'.format(epoch + 1, self._epoch_count), flush=True) + self._train_one_epoch() + + print('\nValidation - Epoch [{}/{}]'.format(epoch + 1, self._epoch_count), flush=True) + self._validate() + self._scheduler.step() + next_epoch_method = getattr(self._criterion, 'next_epoch', None) + if next_epoch_method is not None: + next_epoch_method() + + self._print_performances() + self._save_learning_curves() + self._save_states(epoch + 1) + + with torch.no_grad(): + self._student_model.eval() + self._evaluate(self._student_model, self._device, self._testing_dataset_loader, self._output_path) + + def _clear_between_training(self): + raise NotImplementedError() + + def _train_one_epoch(self): + self._clear_between_training_epoch() + + self._student_model.train() + self._optimizer.zero_grad() + + division = 0 + for data in tqdm(self._training_dataset_loader): + input_data = data[0].to(self._device) + with torch.no_grad(): + teacher_model_output = self._teacher_model(input_data) + student_model_output = self._student_model(input_data) + target = self._move_target_to_device(data[1], self._device) + loss = self._criterion(student_model_output, target, teacher_model_output) + + if torch.all(torch.isfinite(loss)): + loss.backward() + self._measure_training_metrics(loss, student_model_output, target) + else: + print('Warning the loss is not finite.') + + division += 1 + if division == self._batch_size_division: + division = 0 + torch.nn.utils.clip_grad_value_(self._student_model.parameters(), 1) + self._optimizer.step() + self._optimizer.zero_grad() + + if division != 0: + torch.nn.utils.clip_grad_value_(self._student_model.parameters(), 1) + self._optimizer.step() + self._optimizer.zero_grad() + + def _clear_between_training_epoch(self): + raise NotImplementedError() + + def _move_target_to_device(self, target, device): + raise NotImplementedError() + + def _measure_training_metrics(self, loss, model_output, target): + raise NotImplementedError() + + def _validate(self): + with torch.no_grad(): + self._clear_between_validation_epoch() + + self._student_model.eval() + + for data in tqdm(self._validation_dataset_loader): + input_data = data[0].to(self._device) + teacher_model_output = self._teacher_model(input_data) + student_model_output = self._student_model(input_data) + target = self._move_target_to_device(data[1], self._device) + loss = self._criterion(student_model_output, target, teacher_model_output) + + self._measure_validation_metrics(loss, student_model_output, target) + + def _clear_between_validation_epoch(self): + raise NotImplementedError() + + def _measure_validation_metrics(self, loss, model_output, target): + raise NotImplementedError() + + def _print_performances(self): + raise NotImplementedError() + + def _save_learning_curves(self): + raise NotImplementedError() + + def _save_states(self, epoch): + torch.save(self._student_model.state_dict(), + os.path.join(self._output_path, 'model_checkpoint_epoch_{}.pth'.format(epoch))) + + def _evaluate(self, model, device, dataset_loader, output_path): + raise NotImplementedError() diff --git a/tools/dnn_training/common/trainers/trainer.py b/tools/dnn_training/common/trainers/trainer.py index 6cc9196c..198353fd 100644 --- a/tools/dnn_training/common/trainers/trainer.py +++ b/tools/dnn_training/common/trainers/trainer.py @@ -1,5 +1,4 @@ import os -import time 
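`DistillationTrainer._train_one_epoch` above accumulates gradients over `batch_size_division` consecutive mini-batches before a single clipped optimizer step, which emulates a larger effective batch size. The pattern in isolation, with a toy model standing in for the student:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
batch_size_division = 4

optimizer.zero_grad()
division = 0
for step in range(10):                                     # stand-in for the dataset loader
    x, target = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = criterion(model(x), target)
    if torch.all(torch.isfinite(loss)):
        loss.backward()                                     # gradients accumulate in .grad
    division += 1
    if division == batch_size_division:
        division = 0
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()

if division != 0:                                           # flush the remainder, as the trainer does
    torch.nn.utils.clip_grad_value_(model.parameters(), 1)
    optimizer.step()
    optimizer.zero_grad()
```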
import torch import torch.nn as nn @@ -30,13 +29,20 @@ def __init__(self, device, model, dataset_root='', output_path='', model = nn.DataParallel(model) self._model = model.to(device) - bias_parameters = [parameter for name, parameter in model.named_parameters() if name.endswith('.bias')] - other_parameters = [parameter for name, parameter in model.named_parameters() if not name.endswith('.bias')] + no_weight_decay_parameters = getattr(self._criterion, 'no_weight_decay_parameters', None) + if no_weight_decay_parameters is None: + no_weight_decay_parameters = {} + else: + no_weight_decay_parameters = no_weight_decay_parameters() + no_weight_decay_parameters = [parameter for name, parameter in model.named_parameters() + if name.endswith('.bias') or name in no_weight_decay_parameters] + other_parameters = [parameter for name, parameter in model.named_parameters() + if not name.endswith('.bias') and name not in no_weight_decay_parameters] parameter_groups = [ {'params': other_parameters}, - {'params': bias_parameters, 'weight_decay': 0.0} + {'params': no_weight_decay_parameters, 'weight_decay': 0.0} ] - self._optimizer = torch.optim.Adam(parameter_groups, lr=learning_rate, weight_decay=weight_decay) + self._optimizer = torch.optim.AdamW(parameter_groups, lr=learning_rate, weight_decay=weight_decay) self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, epoch_count) self._training_dataset_loader = self._create_training_dataset_loader(dataset_root, @@ -73,11 +79,15 @@ def train(self): print('\nValidation - Epoch [{}/{}]'.format(epoch + 1, self._epoch_count), flush=True) self._validate() self._scheduler.step() + next_epoch_method = getattr(self._criterion, 'next_epoch', None) + if next_epoch_method is not None: + next_epoch_method() self._print_performances() self._save_learning_curves() self._save_states(epoch + 1) + with torch.no_grad(): self._model.eval() self._evaluate(self._model, self._device, self._testing_dataset_loader, self._output_path) diff --git a/tools/dnn_training/export_audio_descriptor_extractor.py b/tools/dnn_training/export_audio_descriptor_extractor.py index eb7412d5..0e7ea659 100644 --- a/tools/dnn_training/export_audio_descriptor_extractor.py +++ b/tools/dnn_training/export_audio_descriptor_extractor.py @@ -6,6 +6,14 @@ from common.file_presence_checker import terminate_if_already_exported +from backbone.vit import Vit + +from audio_descriptor.backbones import Mnasnet0_5, Mnasnet1_0, Resnet18, Resnet34, Resnet50, Resnet101 +from audio_descriptor.backbones import OpenFaceInception, ThinResnet34, EcapaTdnn, SmallEcapaTdnn +from audio_descriptor.audio_descriptor_extractor import AudioDescriptorExtractor, AudioDescriptorExtractorVLAD +from audio_descriptor.audio_descriptor_extractor import AudioDescriptorExtractorSAP, AudioDescriptorExtractorPSLA + + def main(): parser = argparse.ArgumentParser(description='Export audio descriptor extractor') parser.add_argument('--backbone_type', choices=['mnasnet0.5', 'mnasnet1.0', @@ -25,7 +33,8 @@ def main(): help='Choose the dataset class count when criterion_type is "cross_entropy_loss" or ' '"am_softmax_loss"', default=None) - parser.add_argument('--am_softmax_linear', action='store_true', help='Use "AmSoftmaxLinear" instead of "nn.Linear"') + parser.add_argument('--normalized_linear', action='store_true', + help='Use "NormalizedLinear" instead of "nn.Linear"') parser.add_argument('--conv_bias', action='store_true', help='Set bias=True to Conv2d (open_face_inception only)') parser.add_argument('--output_dir', 
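The trainer now builds two AdamW parameter groups: biases and any names reported through `no_weight_decay_parameters()` are optimized without weight decay. A small sketch of that split; the model and the extra name in `no_decay_names` are placeholders:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))
no_decay_names = {'1.weight'}    # placeholder for names reported by no_weight_decay_parameters()

no_decay, decay = [], []
for name, parameter in model.named_parameters():
    if name.endswith('.bias') or name in no_decay_names:
        no_decay.append(parameter)
    else:
        decay.append(parameter)

optimizer = torch.optim.AdamW([{'params': decay},
                               {'params': no_decay, 'weight_decay': 0.0}],
                              lr=0.01, weight_decay=0.05)
print(len(decay), len(no_decay))  # 2 4
```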
type=str, help='Choose the output directory', required=True) @@ -45,16 +54,74 @@ def main(): from common.model_exporter import export_model - from train_audio_descriptor_extractor import create_model - image_size = (args.n_features, args.waveform_size // (args.n_fft // 2) + 1) model = create_model(args.backbone_type, args.n_features, args.embedding_size, args.dataset_class_count, - args.am_softmax_linear, args.pooling_layer, conv_bias=args.conv_bias) + args.normalized_linear, args.pooling_layer, conv_bias=args.conv_bias) x = torch.ones((1, 1, image_size[0], image_size[1])) keys_to_remove = ['_classifier._weight'] if args.dataset_class_count is None else [] export_model(model, args.model_checkpoint, x, args.output_dir, args.torch_script_filename, args.trt_filename, trt_fp16=args.trt_fp16, keys_to_remove=keys_to_remove) +def create_model(backbone_type, n_features, embedding_size, + class_count=None, normalized_linear=False, pooling_layer='avg', conv_bias=False): + pretrained = True + if backbone_type == 'passt_s_n': + return Vit((n_features, 1000), embedding_size=embedding_size, class_count=class_count, + in_channels=1, depth=12, dropout_rate=0.0, attention_dropout_rate=0.0, output_embeddings=True) + elif backbone_type == 'passt_s_n_l': + return Vit((n_features, 1000), embedding_size=embedding_size, class_count=class_count, + in_channels=1, depth=7, dropout_rate=0.0, attention_dropout_rate=0.0, output_embeddings=True) + + backbone = create_backbone(backbone_type, n_features, pretrained, conv_bias) + if pooling_layer == 'avg': + return AudioDescriptorExtractor(backbone, embedding_size=embedding_size, + class_count=class_count, normalized_linear=normalized_linear) + elif pooling_layer == 'vlad': + return AudioDescriptorExtractorVLAD(backbone, embedding_size=embedding_size, + class_count=class_count, normalized_linear=normalized_linear) + elif pooling_layer == 'sap': + return AudioDescriptorExtractorSAP(backbone, embedding_size=embedding_size, + class_count=class_count, normalized_linear=normalized_linear) + elif pooling_layer == 'psla': + return AudioDescriptorExtractorPSLA(backbone, embedding_size=embedding_size, + class_count=class_count, normalized_linear=normalized_linear) + else: + raise ValueError('Invalid pooling layer') + + +def create_backbone(backbone_type, n_features, pretrained, conv_bias=False): + if backbone_type == 'mnasnet0.5': + return Mnasnet0_5(pretrained=pretrained) + elif backbone_type == 'mnasnet1.0': + return Mnasnet1_0(pretrained=pretrained) + elif backbone_type == 'resnet18': + return Resnet18(pretrained=pretrained) + elif backbone_type == 'resnet34': + return Resnet34(pretrained=pretrained) + elif backbone_type == 'resnet50': + return Resnet50(pretrained=pretrained) + elif backbone_type == 'resnet101': + return Resnet101(pretrained=pretrained) + elif backbone_type == 'open_face_inception': + return OpenFaceInception(conv_bias) + elif backbone_type == 'thin_resnet_34': + return ThinResnet34() + elif backbone_type == 'ecapa_tdnn_512': + return EcapaTdnn(n_features, channels=512) + elif backbone_type == 'ecapa_tdnn_1024': + return EcapaTdnn(n_features, channels=1024) + elif backbone_type == 'small_ecapa_tdnn_128': + return SmallEcapaTdnn(n_features, channels=128) + elif backbone_type == 'small_ecapa_tdnn_256': + return SmallEcapaTdnn(n_features, channels=256) + elif backbone_type == 'small_ecapa_tdnn_512': + return SmallEcapaTdnn(n_features, channels=512) + elif backbone_type == 'small_ecapa_tdnn_1024': + return SmallEcapaTdnn(n_features, channels=1024) + else: + raise 
ValueError('Invalid backbone type') + + if __name__ == '__main__': main() diff --git a/tools/dnn_training/export_descriptor_yolo_v4.py b/tools/dnn_training/export_descriptor_yolo.py similarity index 90% rename from tools/dnn_training/export_descriptor_yolo_v4.py rename to tools/dnn_training/export_descriptor_yolo.py index a7613fe1..89db37bb 100644 --- a/tools/dnn_training/export_descriptor_yolo_v4.py +++ b/tools/dnn_training/export_descriptor_yolo.py @@ -8,10 +8,10 @@ def main(): - parser = argparse.ArgumentParser(description='Export descriptor yolo v4') - parser.add_argument('--dataset_type', choices=['coco', 'open_images'], help='Choose the database type', - required=True) - parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny'], + parser = argparse.ArgumentParser(description='Export descriptor yolo') + parser.add_argument('--dataset_type', choices=['coco', 'open_images', 'objects365'], + help='Choose the database type', required=True) + parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7'], help='Choose the model type', required=True) parser.add_argument('--descriptor_size', type=int, help='Choose the descriptor size', required=True) @@ -32,7 +32,7 @@ def main(): from common.model_exporter import export_model - from train_descriptor_yolo_v4 import create_model + from train_descriptor_yolo import create_model model = create_model(args.model_type, args.descriptor_size, args.dataset_type) x = torch.ones((1, 3, model.get_image_size()[0], model.get_image_size()[1])) diff --git a/tools/dnn_training/export_face_descriptor_extractor.py b/tools/dnn_training/export_face_descriptor_extractor.py index 15e55944..5e0af12f 100644 --- a/tools/dnn_training/export_face_descriptor_extractor.py +++ b/tools/dnn_training/export_face_descriptor_extractor.py @@ -6,9 +6,13 @@ from common.file_presence_checker import terminate_if_already_exported +BACKBONE_TYPES = ['open_face', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] + def main(): parser = argparse.ArgumentParser(description='Export face descriptor') + parser.add_argument('--backbone_type', choices=BACKBONE_TYPES, help='Choose the backbone type', required=True) parser.add_argument('--embedding_size', type=int, help='Set the embedding size', required=True) parser.add_argument('--output_dir', type=str, help='Choose the output directory', required=True) @@ -29,9 +33,9 @@ def main(): from common.model_exporter import export_model from face_recognition.datasets import IMAGE_SIZE - from face_recognition.face_descriptor_extractor import FaceDescriptorExtractor + from train_face_descriptor_extractor import create_model - model = FaceDescriptorExtractor(embedding_size=args.embedding_size) + model = create_model(args.backbone_type, args.embedding_size) x = torch.ones((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1])) export_model(model, args.model_checkpoint, x, args.output_dir, args.torch_script_filename, args.trt_filename, trt_fp16=args.trt_fp16, keys_to_remove=['_classifier._weight']) diff --git a/tools/dnn_training/export_keyword_spotter.py b/tools/dnn_training/export_keyword_spotter.py index 4dc3d1b9..691bc202 100644 --- a/tools/dnn_training/export_keyword_spotter.py +++ b/tools/dnn_training/export_keyword_spotter.py @@ -6,6 +6,8 @@ from common.file_presence_checker import terminate_if_already_exported +from keyword_spotting.keyword_spotter import KeywordSpotter + def main(): parser = 
argparse.ArgumentParser(description='Export keyword spotter') @@ -30,8 +32,6 @@ def main(): from common.model_exporter import export_model - from train_keyword_spotter import create_model - model = create_model(args.dataset_type) x = torch.ones((1, 1, args.mfcc_feature_count, 51)) @@ -39,5 +39,14 @@ def main(): trt_fp16=args.trt_fp16) +def create_model(dataset_type): + if dataset_type == 'google_speech_commands': + return KeywordSpotter(class_count=36, use_softmax=False) + elif dataset_type == 'ttop_keyword': + return KeywordSpotter(class_count=2, use_softmax=False) + else: + raise ValueError('Invalid database type') + + if __name__ == '__main__': main() diff --git a/tools/dnn_training/export_pose_estimator.py b/tools/dnn_training/export_pose_estimator.py index c01a9b31..af46b3ee 100644 --- a/tools/dnn_training/export_pose_estimator.py +++ b/tools/dnn_training/export_pose_estimator.py @@ -6,11 +6,13 @@ from common.file_presence_checker import terminate_if_already_exported +BACKBONE_TYPES = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', + 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] + + def main(): parser = argparse.ArgumentParser(description='Export pose estimator') - parser.add_argument('--backbone_type', choices=['mnasnet0.5', 'mnasnet1.0', 'resnet18', 'resnet34', 'resnet50'], - help='Choose the backbone type', required=True) - parser.add_argument('--upsampling_count', type=int, help='Set the upsamping layer count', required=True) + parser.add_argument('--backbone_type', choices=BACKBONE_TYPES, help='Choose the backbone type', required=True) parser.add_argument('--output_dir', type=str, help='Choose the output directory', required=True) parser.add_argument('--torch_script_filename', type=str, help='Choose the TorchScript filename', required=True) @@ -32,7 +34,7 @@ def main(): from pose_estimation.trainers.pose_estimator_trainer import IMAGE_SIZE from train_pose_estimator import create_model - model = create_model(args.backbone_type, args.upsampling_count) + model = create_model(args.backbone_type) x = torch.ones((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1])) export_model(model, args.model_checkpoint, x, args.output_dir, args.torch_script_filename, args.trt_filename, trt_fp16=args.trt_fp16) diff --git a/tools/dnn_training/export_yolo_v4.py b/tools/dnn_training/export_yolo.py similarity index 74% rename from tools/dnn_training/export_yolo_v4.py rename to tools/dnn_training/export_yolo.py index 0a0da3ac..9a6e908c 100644 --- a/tools/dnn_training/export_yolo_v4.py +++ b/tools/dnn_training/export_yolo.py @@ -6,9 +6,11 @@ from common.file_presence_checker import terminate_if_already_exported + def main(): - parser = argparse.ArgumentParser(description='Export yolo v4') - parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny'], + parser = argparse.ArgumentParser(description='Export yolo') + parser.add_argument('--dataset_type', choices=['coco', 'objects365'], help='Choose the dataset type', required=True) + parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7', 'yolo_v7_tiny'], help='Choose the model type', required=True) parser.add_argument('--output_dir', type=str, help='Choose the output directory', required=True) @@ -28,25 +30,13 @@ def main(): from common.model_exporter import export_model - from object_detection.modules.yolo_v4_tiny import YoloV4Tiny - from object_detection.modules.yolo_v4 import YoloV4 - - def create_model(model_type): - if model_type == 'yolo_v4': - model = YoloV4() - elif 
model_type == 'yolo_v4_tiny': - model = YoloV4Tiny() - else: - raise ValueError('Invalid model type') + from object_detection.modules.test_converted_yolo import create_model - return model - - model = create_model(args.model_type) + model = create_model(args.model_type, args.dataset_type, class_probs=True) x = torch.ones((1, 3, model.get_image_size()[0], model.get_image_size()[1])) export_model(model, args.model_checkpoint, x, args.output_dir, args.torch_script_filename, args.trt_filename, trt_fp16=args.trt_fp16) - if __name__ == '__main__': main() diff --git a/tools/dnn_training/face_recognition/criterions/__init__.py b/tools/dnn_training/face_recognition/criterions/__init__.py index 6d0e696d..73b1fc6a 100644 --- a/tools/dnn_training/face_recognition/criterions/__init__.py +++ b/tools/dnn_training/face_recognition/criterions/__init__.py @@ -1 +1,2 @@ -from face_recognition.criterions.face_descriptor_am_softmax_loss import FaceDescriptorAmSoftmaxLoss +from face_recognition.criterions.face_descriptor_loss import FaceDescriptorAmSoftmaxLoss, FaceDescriptorArcFaceLoss, \ + FaceDescriptorCrossEntropyLoss, FaceDescriptorDistillationLoss diff --git a/tools/dnn_training/face_recognition/criterions/face_descriptor_am_softmax_loss.py b/tools/dnn_training/face_recognition/criterions/face_descriptor_am_softmax_loss.py deleted file mode 100644 index e192cfa6..00000000 --- a/tools/dnn_training/face_recognition/criterions/face_descriptor_am_softmax_loss.py +++ /dev/null @@ -1,6 +0,0 @@ -from common.criterions import AmSoftmaxLoss - - -class FaceDescriptorAmSoftmaxLoss(AmSoftmaxLoss): - def forward(self, model_output, target): - return super(FaceDescriptorAmSoftmaxLoss, self).forward(model_output[1], target) diff --git a/tools/dnn_training/face_recognition/criterions/face_descriptor_loss.py b/tools/dnn_training/face_recognition/criterions/face_descriptor_loss.py new file mode 100644 index 00000000..7a957231 --- /dev/null +++ b/tools/dnn_training/face_recognition/criterions/face_descriptor_loss.py @@ -0,0 +1,38 @@ +import torch.nn as nn +import torch.nn.functional as F + +from common.criterions import AmSoftmaxLoss, ArcFaceLoss + + +class FaceDescriptorCrossEntropyLoss(nn.CrossEntropyLoss): + def forward(self, model_output, target): + return super(FaceDescriptorCrossEntropyLoss, self).forward(model_output[1], target) + + +class FaceDescriptorAmSoftmaxLoss(AmSoftmaxLoss): + def forward(self, model_output, target): + return super(FaceDescriptorAmSoftmaxLoss, self).forward(model_output[1], target) + + +class FaceDescriptorArcFaceLoss(ArcFaceLoss): + def forward(self, model_output, target): + return super(FaceDescriptorArcFaceLoss, self).forward(model_output[1], target) + + +class FaceDescriptorDistillationLoss(nn.Module): + def __init__(self, target_criterion, alpha=0.25): + super(FaceDescriptorDistillationLoss, self).__init__() + + self._target_criterion = target_criterion + self._alpha = alpha + + def forward(self, student_model_output, target, teacher_model_output): + student_embedding = student_model_output[0] if isinstance(student_model_output, tuple) else student_model_output + teacher_embedding = teacher_model_output[0] if isinstance(teacher_model_output, tuple) else teacher_model_output + + target_loss = self._target_criterion(student_model_output, target) + teacher_loss = F.mse_loss(student_embedding, teacher_embedding) + return self._alpha * target_loss + (1 - self._alpha) * teacher_loss + + + diff --git a/tools/dnn_training/face_recognition/datasets/__init__.py 
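A minimal usage sketch for the new FaceDescriptorDistillationLoss above, assuming the student returns an (embedding, class scores) tuple while the teacher returns only an embedding; the batch size, embedding size and class count below are made up for illustration.

    import torch
    from face_recognition.criterions import FaceDescriptorCrossEntropyLoss, FaceDescriptorDistillationLoss

    # Hypothetical shapes: batch of 8, 256-d embeddings, 100 identity classes.
    student_output = (torch.randn(8, 256), torch.randn(8, 100))  # (embedding, class scores)
    teacher_output = torch.randn(8, 256)                         # the teacher may return the embedding only
    target = torch.randint(0, 100, (8,))

    criterion = FaceDescriptorDistillationLoss(FaceDescriptorCrossEntropyLoss(), alpha=0.25)
    loss = criterion(student_output, target, teacher_output)     # 0.25 * CE + 0.75 * MSE on the embeddings

With the default alpha of 0.25, most of the gradient comes from matching the teacher embedding rather than from the classification term.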
b/tools/dnn_training/face_recognition/datasets/__init__.py index 08a15990..9c076f94 100644 --- a/tools/dnn_training/face_recognition/datasets/__init__.py +++ b/tools/dnn_training/face_recognition/datasets/__init__.py @@ -1,2 +1,4 @@ from face_recognition.datasets.align_faces import ALIGNED_IMAGE_SIZE as IMAGE_SIZE -from face_recognition.datasets.vggface2_dataset import Vggface2Dataset, LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES +from face_recognition.datasets.face_dataset import FaceDataset, FaceConcatDataset, LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES +from face_recognition.datasets.face_dataset import LFW_OVERLAPPED_MSCELEB1M_CLASS_NAMES +from face_recognition.datasets.imbalanced_face_dataset_sampler import ImbalancedFaceDatasetSampler diff --git a/tools/dnn_training/face_recognition/datasets/align_faces.py b/tools/dnn_training/face_recognition/datasets/align_faces.py index 1b085b54..3ce65b54 100644 --- a/tools/dnn_training/face_recognition/datasets/align_faces.py +++ b/tools/dnn_training/face_recognition/datasets/align_faces.py @@ -1,6 +1,7 @@ import os import argparse import math +import shutil import numpy as np from PIL import Image @@ -10,30 +11,33 @@ import torch.nn.functional as F import torchvision.transforms as transforms +from tqdm import tqdm + from common.modules import load_checkpoint from pose_estimation.pose_estimator import get_coordinates from pose_estimation.trainers.pose_estimator_trainer import IMAGE_SIZE as POSE_ESTIMATOR_IMAGE_SIZE -from train_pose_estimator import create_model +from train_pose_estimator import create_model, BACKBONE_TYPES ALIGNED_IMAGE_SIZE = (128, 96) -PRESENCE_THRESHOLD = 0.4 class FolderFaceAligner: - def __init__(self, pose_estimator_model): - self._pose_estimator_model = pose_estimator_model + def __init__(self, pose_estimator_model, device, presence_threshold, ignore_presence_threshold_for_nose_eyes): + self._device = device + self._pose_estimator_model = pose_estimator_model.to(device) self._pose_estimator_image_transform = transforms.Compose([ transforms.Resize(POSE_ESTIMATOR_IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) + self._presence_threshold = presence_threshold + self._ignore_presence_threshold_for_nose_eyes = ignore_presence_threshold_for_nose_eyes - def align_lfw(self, input_path, output_path): + def align(self, input_path, output_path): person_names = [o for o in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, o))] - for person_name in person_names: - print('Processing {} images'.format(person_name)) + for person_name in tqdm(person_names): self._align_person_images(input_path, output_path, person_name) def _align_person_images(self, input_path, output_path, person_name): @@ -45,10 +49,13 @@ def _align_person_images(self, input_path, output_path, person_name): self._align_person_image(input_path, output_path, person_name, image_filename) except (ValueError, np.linalg.LinAlgError): print('Warning: the alignment is impossible ({})'.format(image_filename)) + if self._ignore_presence_threshold_for_nose_eyes: + shutil.copyfile(os.path.join(input_path, person_name, image_filename), + os.path.join(output_path, person_name, image_filename)) def _align_person_image(self, input_path, output_path, person_name, image_filename): output_size = (ALIGNED_IMAGE_SIZE[1], ALIGNED_IMAGE_SIZE[0]) - landmarks, theoretical_landmark = self._get_landmarks(input_path, person_name, image_filename) + landmarks, theoretical_landmark, _ = self._get_landmarks(input_path, person_name, 
image_filename) cv2_image = cv2.imread(os.path.join(input_path, person_name, image_filename)) cv2_transform = cv2.getAffineTransform(landmarks.astype(np.float32), @@ -59,71 +66,82 @@ def _align_person_image(self, input_path, output_path, person_name, image_filena grid = F.affine_grid(theta, torch.Size((1, 3, ALIGNED_IMAGE_SIZE[0], ALIGNED_IMAGE_SIZE[1]))) torch_image = torch.from_numpy(cv2_image).permute(2, 0, 1).unsqueeze(0).float() - torch_aligned_image = F.grid_sample(torch_image, grid, mode='nearest').squeeze(0) + torch_aligned_image = F.grid_sample(torch_image, grid, mode='bilinear').squeeze(0) cv2_aligned_image = torch_aligned_image.permute(1, 2, 0).numpy() cv2.imwrite(os.path.join(output_path, person_name, image_filename), cv2_aligned_image) def _get_landmarks(self, input_path, person_name, image_filename): - image = Image.open(os.path.join(input_path, person_name, image_filename)).convert('RGB') - pose_estimator_image = self._pose_estimator_image_transform(image) - pose_heatmaps = self._pose_estimator_model(pose_estimator_image.unsqueeze(0)) - heatmap_coordinates, presence = get_coordinates(pose_heatmaps) + with torch.no_grad(): + image = Image.open(os.path.join(input_path, person_name, image_filename)).convert('RGB') + pose_estimator_image = self._pose_estimator_image_transform(image) + pose_heatmaps = self._pose_estimator_model(pose_estimator_image.unsqueeze(0).to(self._device)) + heatmap_coordinates, presence = get_coordinates(pose_heatmaps) - scaled_coordinates = np.zeros((heatmap_coordinates.size()[1], 2)) + scaled_coordinates = np.zeros((heatmap_coordinates.size()[1], 2)) - for i in range(heatmap_coordinates.size()[1]): - scaled_coordinates[i, 0] = heatmap_coordinates[0, i, 0] / pose_heatmaps.size()[3] * image.width - scaled_coordinates[i, 1] = heatmap_coordinates[0, i, 1] / pose_heatmaps.size()[2] * image.height + for i in range(heatmap_coordinates.size()[1]): + scaled_coordinates[i, 0] = heatmap_coordinates[0, i, 0] / pose_heatmaps.size()[3] * image.width + scaled_coordinates[i, 1] = heatmap_coordinates[0, i, 1] / pose_heatmaps.size()[2] * image.height - return get_landmarks_from_pose(scaled_coordinates, presence[0]) + return get_landmarks_from_pose(scaled_coordinates, presence[0].cpu().numpy(), + self._presence_threshold, self._ignore_presence_threshold_for_nose_eyes) -def get_landmarks_from_pose(pose, presence): - if (presence[0:5] > PRESENCE_THRESHOLD).all(): +def get_landmarks_from_pose(pose, presence, presence_threshold, ignore_presence_threshold_for_nose_eyes=False): + if (presence[0:5] > presence_threshold).all(): eyes_center = (pose[1] + pose[2]) / 2 - hears_center = (pose[3] + pose[4]) / 2 + ears_center = (pose[3] + pose[4]) / 2 landmarks = np.zeros((3, 2)) - landmarks[0] = 2 * pose[0] - eyes_center + hears_center - eyes_center - landmarks[1] = pose[3] - np.array([0, hears_center[1] - eyes_center[1]]) - landmarks[2] = pose[4] - np.array([0, hears_center[1] - eyes_center[1]]) + landmarks[0] = 2 * pose[0] - eyes_center + ears_center - eyes_center + landmarks[1] = pose[3] - np.array([0, ears_center[1] - eyes_center[1]]) + landmarks[2] = pose[4] - np.array([0, ears_center[1] - eyes_center[1]]) - theoretical_landmark = np.array([[0.5, 0.75], + theoretical_landmarks = np.array([[0.5, 0.75], [0.9, 0.25], [0.1, 0.25]]) - return landmarks, theoretical_landmark - elif (presence[0:3] > PRESENCE_THRESHOLD).all() and presence[3] > PRESENCE_THRESHOLD: + return landmarks, theoretical_landmarks, 5 + elif (presence[0:3] > presence_threshold).all() and presence[3] > 
presence_threshold: landmarks = np.zeros((3, 2)) landmarks[0] = pose[0] landmarks[1] = np.array([pose[3, 0], pose[1, 1]]) landmarks[2] = pose[2] - theoretical_landmark = np.array([[0.25, 0.5], - [0.9, 0.25], + eye_ear_x_diff = landmarks[1, 0] - landmarks[2, 0] + eye_nose_x_diff = landmarks[0, 0] - landmarks[2, 0] + + theoretical_landmarks = np.array([[0.25 + 0.6 * eye_nose_x_diff / eye_ear_x_diff, 0.45], + [0.85, 0.25], [0.25, 0.25]]) - return landmarks, theoretical_landmark + return landmarks, theoretical_landmarks, 4 - elif (presence[0:3] > PRESENCE_THRESHOLD).all() and presence[4] > PRESENCE_THRESHOLD: + elif (presence[0:3] > presence_threshold).all() and presence[4] > presence_threshold: landmarks = np.zeros((3, 2)) landmarks[0] = pose[0] landmarks[1] = pose[1] landmarks[2] = np.array([pose[4, 0], pose[1, 1]]) - theoretical_landmark = np.array([[0.75, 0.5], + eye_ear_x_diff = landmarks[1, 0] - landmarks[2, 0] + eye_nose_x_diff = landmarks[0, 0] - landmarks[2, 0] + + theoretical_landmarks = np.array([[0.15 + 0.6 * eye_nose_x_diff / eye_ear_x_diff, 0.45], [0.75, 0.25], - [0.1, 0.25]]) + [0.15, 0.25]]) - return landmarks, theoretical_landmark + return landmarks, theoretical_landmarks, 4 - elif (presence[0:3] > PRESENCE_THRESHOLD).all(): - theoretical_landmark = np.array([[0.5, 0.5144414], - [0.75, 0.25], - [0.25, 0.25]]) + elif (presence[0:3] > presence_threshold).all() or ignore_presence_threshold_for_nose_eyes: + eyes_x_diff = pose[1, 0] - pose[2, 0] + eye_nose_x_diff = pose[0, 0] - pose[2, 0] + + theoretical_landmarks = np.array([[0.35 + 0.3 * eye_nose_x_diff / eyes_x_diff, 0.5], + [0.7, 0.35], + [0.3, 0.35]]) - return pose[0:3], theoretical_landmark + return pose[0:3], theoretical_landmarks, 3 else: - raise ValueError('The aligment is not possible') + raise ValueError('The alignment is not possible') def cv2_transform_to_theta(transform, source_height, source_width, destination_height, destination_width): @@ -174,26 +192,29 @@ def det_3x3(x): def main(): - parser = argparse.ArgumentParser(description='Align LFW faces') - parser.add_argument('--pose_estimator_backbone_type', - choices=['mnasnet0.5', 'mnasnet1.0', 'resnet18', 'resnet34', 'resnet50'], + parser = argparse.ArgumentParser(description='Align faces') + parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') + parser.add_argument('--pose_estimator_backbone_type', choices=BACKBONE_TYPES, help='Choose the pose estimator backbone type', required=True) - parser.add_argument('--pose_estimator_upsampling_count', type=int, - help='Set the pose estimator upsamping layer count', required=True) parser.add_argument('--pose_estimator_model_checkpoint', type=str, help='Choose the pose estimator model checkpoint file', required=True) + parser.add_argument('--presence_threshold', type=float, help='Choose the presence threshold', required=True) + parser.add_argument('--ignore_presence_threshold_for_nose_eyes', action='store_true', + help='Ignore the presence threshold for nose and eyes keypoint') parser.add_argument('--input', type=str, help='Choose the input path', required=True) parser.add_argument('--output', type=str, help='Choose the output path', required=True) args = parser.parse_args() - pose_estimator_model = create_model(args.pose_estimator_backbone_type, args.pose_estimator_upsampling_count) + device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') + pose_estimator_model = create_model(args.pose_estimator_backbone_type) load_checkpoint(pose_estimator_model, 
args.pose_estimator_model_checkpoint) pose_estimator_model.eval() - aligner = FolderFaceAligner(pose_estimator_model) - aligner.align_lfw(args.input, args.output) + aligner = FolderFaceAligner(pose_estimator_model, device, + args.presence_threshold, args.ignore_presence_threshold_for_nose_eyes) + aligner.align(args.input, args.output) if __name__ == '__main__': diff --git a/tools/dnn_training/face_recognition/datasets/face_dataset.py b/tools/dnn_training/face_recognition/datasets/face_dataset.py new file mode 100644 index 00000000..24d242cc --- /dev/null +++ b/tools/dnn_training/face_recognition/datasets/face_dataset.py @@ -0,0 +1,150 @@ +import os + +import torch +from PIL import Image + +from torch.utils.data import Dataset + + +LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES = ['n000021', 'n000137', 'n000172', 'n000184', 'n000195', 'n000199', 'n000220', 'n000242', 'n000255', 'n000272', 'n000281', 'n000297', 'n000310', 'n000359', 'n000373', 'n000379', 'n000402', 'n000420', 'n000427', 'n000429', 'n000483', 'n000560', 'n000562', 'n000568', 'n000645', 'n000667', 'n000692', 'n000709', 'n000712', 'n000755', 'n000780', 'n000810', 'n000816', 'n000817', 'n000835', 'n000887', 'n000888', 'n000902', 'n000906', 'n000912', 'n000946', 'n000947', 'n001042', 'n001048', 'n001051', 'n001054', 'n001057', 'n001060', 'n001091', 'n001095', 'n001096', 'n001105', 'n001106', 'n001110', 'n001111', 'n001113', 'n001142', 'n001143', 'n001155', 'n001156', 'n001159', 'n001171', 'n001200', 'n001228', 'n001235', 'n001243', 'n001257', 'n001288', 'n001309', 'n001329', 'n001360', 'n001367', 'n001379', 'n001387', 'n001419', 'n001434', 'n001437', 'n001473', 'n001478', 'n001538', 'n001540', 'n001547', 'n001548', 'n001551', 'n001567', 'n001568', 'n001595', 'n001607', 'n001642', 'n001656', 'n001663', 'n001695', 'n001698', 'n001700', 'n001727', 'n001775', 'n001779', 'n001781', 'n001795', 'n001805', 'n001813', 'n001825', 'n001839', 'n001873', 'n001874', 'n001934', 'n001962', 'n001968', 'n001971', 'n001986', 'n002017', 'n002018', 'n002019', 'n002025', 'n002028', 'n002038', 'n002056', 'n002057', 'n002067', 'n002081', 'n002094', 'n002120', 'n002133', 'n002135', 'n002141', 'n002173', 'n002178', 'n002227', 'n002244', 'n002251', 'n002259', 'n002269', 'n002272', 'n002274', 'n002278', 'n002298', 'n002310', 'n002316', 'n002342', 'n002356', 'n002360', 'n002388', 'n002391', 'n002413', 'n002452', 'n002454', 'n002460', 'n002465', 'n002471', 'n002478', 'n002494', 'n002497', 'n002537', 'n002547', 'n002572', 'n002574', 'n002595', 'n002630', 'n002654', 'n002666', 'n002673', 'n002684', 'n002725', 'n002759', 'n002782', 'n002814', 'n002844', 'n002854', 'n002868', 'n002947', 'n002969', 'n002982', 'n003013', 'n003020', 'n003022', 'n003039', 'n003098', 'n003101', 'n003147', 'n003186', 'n003198', 'n003205', 'n003206', 'n003212', 'n003228', 'n003257', 'n003266', 'n003303', 'n003311', 'n003353', 'n003359', 'n003360', 'n003363', 'n003383', 'n003385', 'n003391', 'n003392', 'n003422', 'n003426', 'n003474', 'n003476', 'n003523', 'n003619', 'n003623', 'n003628', 'n003648', 'n003671', 'n003685', 'n003689', 'n003709', 'n003712', 'n003724', 'n003728', 'n003749', 'n003756', 'n003767', 'n003785', 'n003803', 'n003806', 'n003816', 'n003817', 'n003820', 'n003866', 'n003876', 'n003879', 'n003880', 'n003885', 'n003904', 'n003908', 'n003912', 'n003937', 'n003946', 'n003952', 'n003969', 'n003972', 'n003973', 'n003974', 'n003979', 'n003982', 'n004014', 'n004020', 'n004030', 'n004043', 'n004081', 'n004083', 'n004084', 'n004106', 'n004132', 'n004144', 'n004150', 'n004151', 'n004154', 
'n004181', 'n004197', 'n004198', 'n004204', 'n004216', 'n004218', 'n004220', 'n004231', 'n004236', 'n004263', 'n004327', 'n004328', 'n004375', 'n004407', 'n004408', 'n004429', 'n004439', 'n004451', 'n004457', 'n004480', 'n004481', 'n004488', 'n004580', 'n004588', 'n004617', 'n004644', 'n004646', 'n004651', 'n004652', 'n004661', 'n004674', 'n004703', 'n004715', 'n004720', 'n004737', 'n004756', 'n004781', 'n004782', 'n004841', 'n004844', 'n004861', 'n004893', 'n004902', 'n004905', 'n004949', 'n004990', 'n005009', 'n005026', 'n005035', 'n005053', 'n005054', 'n005056', 'n005057', 'n005060', 'n005098', 'n005104', 'n005106', 'n005110', 'n005116', 'n005123', 'n005136', 'n005140', 'n005159', 'n005196', 'n005203', 'n005204', 'n005206', 'n005208', 'n005212', 'n005219', 'n005226', 'n005227', 'n005248', 'n005292', 'n005400', 'n005423', 'n005491', 'n005531', 'n005588', 'n005642', 'n005659', 'n005662', 'n005666', 'n005690', 'n005726', 'n005737', 'n005745', 'n005761', 'n005762', 'n005770', 'n005815', 'n005831', 'n005836', 'n005846', 'n005874', 'n005933', 'n005946', 'n005948', 'n005949', 'n005950', 'n005983', 'n005985', 'n005988', 'n006003', 'n006004', 'n006008', 'n006013', 'n006015', 'n006047', 'n006048', 'n006055', 'n006061', 'n006073', 'n006083', 'n006103', 'n006110', 'n006111', 'n006120', 'n006155', 'n006158', 'n006160', 'n006221', 'n006234', 'n006236', 'n006246', 'n006260', 'n006348', 'n006350', 'n006359', 'n006363', 'n006374', 'n006380', 'n006384', 'n006386', 'n006395', 'n006408', 'n006414', 'n006512', 'n006592', 'n006604', 'n006662', 'n006691', 'n006695', 'n006696', 'n006699', 'n006716', 'n006720', 'n006784', 'n006793', 'n006827', 'n006840', 'n006846', 'n006860', 'n006868', 'n006910', 'n006913', 'n006922', 'n006928', 'n006933', 'n006935', 'n006966', 'n006969', 'n006970', 'n007009', 'n007024', 'n007075', 'n007122', 'n007123', 'n007157', 'n007166', 'n007188', 'n007190', 'n007213', 'n007226', 'n007242', 'n007277', 'n007300', 'n007320', 'n007326', 'n007331', 'n007352', 'n007356', 'n007357', 'n007366', 'n007383', 'n007389', 'n007399', 'n007402', 'n007412', 'n007428', 'n007461', 'n007476', 'n007493', 'n007501', 'n007510', 'n007565', 'n007574', 'n007581', 'n007584', 'n007595', 'n007596', 'n007613', 'n007616', 'n007620', 'n007623', 'n007645', 'n007679', 'n007704', 'n007712', 'n007787', 'n007794', 'n007801', 'n007847', 'n007901', 'n007974', 'n007980', 'n007981', 'n008018', 'n008065', 'n008079', 'n008127', 'n008162', 'n008166', 'n008180', 'n008205', 'n008215', 'n008239', 'n008242', 'n008272', 'n008277', 'n008316', 'n008333', 'n008347', 'n008428', 'n008431', 'n008446', 'n008466', 'n008496', 'n008524', 'n008545', 'n008551', 'n008556', 'n008557', 'n008568', 'n008575', 'n008592', 'n008614', 'n008638', 'n008646', 'n008658', 'n008659', 'n008679', 'n008690', 'n008707', 'n008708', 'n008725', 'n008782', 'n008791', 'n008814', 'n008839', 'n008859', 'n008865', 'n008893', 'n008902', 'n008917', 'n008927', 'n008999', 'n009008', 'n009050', 'n009139', 'n009179', 'n009283'] +LFW_OVERLAPPED_MSCELEB1M_CLASS_NAMES = ['m.03t4cz', 'm.03p4zn', 'm.01341q', 'm.023066', 'm.01yynm', 'm.07s7wk', 'm.05y82p', 'm.01d_7n', 'm.0224hd', 'm.0245tw', 'm.03mkck', 'm.05_70r', 'm.023phr', 'm.02vsp', 'm.0586vd', 'm.018ygt', 'm.05j0zz', 'm.013b7z', 'm.02tpys', 'm.014kb2', 'm.084dwc', 'm.01bnr9', 'm.02qzsv', 'm.05gttk', 'm.03qsqt', 'm.06cjq3', 'm.01y_rh', 'm.0gskh7', 'm.01fkqs', 'm.0261k6z', 'm.0c01c', 'm.03q1tm', 'm.03jd7z', 'm.068ggb', 'm.01cp_j', 'm.0347xl', 'm.03ds83', 'm.031h9y', 'm.03fqjd', 'm.01p9wn', 'm.0j28l1t', 'm.01tr5l', 'm.03lhhq', 
'm.035m7s', 'm.09l7gt', 'm.0264h51', 'm.01kzy3', 'm.0bxr_8', 'm.048g68', 'm.023v5v', 'm.01t2h2', 'm.01lmd6', 'm.05n8_wc', 'm.0f4vbz', 'm.0pdcdhz', 'm.0ftfrr', 'm.02x5q_k', 'm.01vhb0', 'm.02cgqp', 'm.05p1y0r', 'm.036wtl', 'm.01p_kk', 'm.01yym5', 'm.012bk', 'm.0204ym', 'm.0ckpyv', 'm.051qh7', 'm.029985', 'm.082hs2', 'm.04fbt1', 'm.0fmpmx', 'm.0pstz', 'm.011p3', 'm.02l3b1', 'm.022wp3', 'm.025wkbs', 'm.021m32', 'm.02pr_2h', 'm.04rb3d', 'm.044khm', 'm.0b742yn', 'm.09y_q5', 'm.098sq7', 'm.02n670', 'm.013tcv', 'm.01nbjr', 'm.01bhq8', 'm.058rsb', 'm.06_1s3', 'm.04l118', 'm.02t244', 'm.02843j', 'm.06plcb', 'm.023lqp', 'm.01qkdh', 'm.09_9nl', 'm.01mpq7s', 'm.05214n', 'm.01jjdg7', 'm.0b_j2', 'm.02_fs7', 'm.05nbg1', 'm.07zk65', 'm.0154r3', 'm.028t43', 'm.017dzj', 'm.01_pb_', 'm.01sd9r', 'm.01s7zw', 'm.04lj06', 'm.0gfgzcp', 'm.01ym3s', 'm.03bnrg', 'm.01qwpt', 'm.01kyns', 'm.06f091', 'm.0hyrg', 'm.05dhdz', 'm.024k6h', 'm.01l87db', 'm.05xc43', 'm.05chnj', 'm.012gq6', 'm.08rt4p', 'm.0518nv', 'm.02c6bt', 'm.031pbz', 'm.01krs', 'm.01dzl3', 'm.0hr3h8d', 'm.01t_mn', 'm.0fvszq', 'm.03qfp6', 'm.02vy8cg', 'm.05p8_y2', 'm.0g5sm_y', 'm.0hhxv91', 'm.03vttl', 'm.053nn8', 'm.03fv56', 'm.0cn6gm', 'm.0bx_q', 'm.01kk0s', 'm.0801_6', 'm.06njqy', 'm.023s8', 'm.0bksh', 'm.04y3tl', 'm.04v16_', 'm.0d7_4', 'm.031dns', 'm.03cvx8g', 'm.01bzhh', 'm.06k68k', 'm.01xh6j', 'm.0pnf3', 'm.049dy4y', 'm.0d06qs', 'm.01l1ls', 'm.019k2w', 'm.01kj0p', 'm.019n7x', 'm.03tt0_', 'm.013knm', 'm.01k7lw', 'm.0fswr7', 'm.01cwhp', 'm.03yjtb', 'm.031513', 'm.020ymy', 'm.02z0ck', 'm.0yc8l', 'm.053f8h', 'm.0411ytm', 'm.01l9p', 'm.04fkdr', 'm.01chwz', 'm.05tcsz2', 'm.0c4b30', 'm.0hwmx', 'm.07nj7w', 'm.07vyft', 'm.063f7w', 'm.01ybxf', 'm.019qv6', 'm.01cp71', 'm.0dpqlq', 'm.05m53h', 'm.06ml5k', 'm.04cf09', 'm.016_mj', 'm.01900g', 'm.0g8mk1', 'm.0260ct', 'm.0bfwxz', 'm.08z9xg', 'm.05vxct', 'm.07dn76', 'm.027m7j3', 'm.0253pn', 'm.05spws', 'm.01l9pz', 'm.027nx2m', 'm.0cc58j_', 'm.0211tw', 'm.05x2hd', 'm.0f41k8', 'm.027d12t', 'm.0421x9', 'm.017tn4', 'm.03mfqm', 'm.02p1q9', 'm.0bqs56', 'm.03c30n', 'm.01pybj', 'm.011zfz', 'm.069_ks', 'm.05myt24', 'm.083pq2', 'm.01r3ct', 'm.015c2f', 'm.09wjvn', 'm.06zkr3z', 'm.04d9yz', 'm.04t33m', 'm.0751tg', 'm.04_x29', 'm.01dhpj', 'm.06qv4y', 'm.0jl26', 'm.02fhw', 'm.01ctnp', 'm.03wyyk', 'm.02bh9', 'm.0205dx', 'm.0htcn', 'm.04pn7c', 'm.051z8k', 'm.01r7wc', 'm.01s4kt', 'm.02v0hy', 'm.02ghgl', 'm.040jzf', 'm.05p36x1', 'm.02hh_y', 'm.025ycc', 'm.01mh7tb', 'm.0fcmr1', 'm.034bb7', 'm.025r7k', 'm.03sqfv', 'm.046xmg', 'm.0138zs', 'm.02ct_k', 'm.01my4f', 'm.03gbkk', 'm.02cgqp', 'm.02dlfh', 'm.0dql67', 'm.043ftz', 'm.05wgvhp', 'm.08w95x', 'm.0cjng5', 'm.06j0cg', 'm.04jw71', 'm.02_75_', 'm.07_ttx', 'm.019sj8', 'm.0pmhf', 'm.01t3s_', 'm.0822l3', 'm.037bzt', 'm.01gqws', 'm.04mxvcc', 'm.0d0vj4', 'm.07_v3p', 'm.01wsj06', 'm.073y_j', 'm.06p39_', 'm.01vt9p3', 'm.08_nmx', 'm.0260zf', 'm.02f1c', 'm.01kwld', 'm.04_vq_', 'm.015yb0', 'm.0cp14lp', 'm.01v3vb', 'm.01mbwlb', 'm.06j02v', 'm.078q1l', 'm.03gbqk', 'm.01_j71', 'm.0ks3w3', 'm.069ggg', 'm.01r4ft', 'm.03mswq', 'm.04fhqh', 'm.06v1ms', 'm.09gc_wk', 'm.0bl2g', 'm.02tlx1', 'm.06pl0j', 'm.06nmdb', 'm.0235mg', 'm.027nt8', 'm.04dr0x', 'm.02r7jn', 'm.01bzhh', 'm.01zkk1', 'm.01515w', 'm.02502p', 'm.01qrf2', 'm.0411lx', 'm.0267d_', 'm.0ccqd7', 'm.03wgt48', 'm.01jb26', 'm.04sj1_', 'm.01gbbz', 'm.05np4c', 'm.03jq8c', 'm.01w3fc0', 'm.01ksss', 'm.012sk_', 'm.022wf_', 'm.085qrf', 'm.05g7t5', 'm.018pj3', 'm.03d7ykg', 'm.0159h6', 'm.0cjbl1s', 'm.03m83s', 'm.0jbv8', 'm.02vn3vm', 'm.027kn0', 'm.01jhw2', 
'm.09jgcv', 'm.07f3t6', 'm.04lvys', 'm.04smkr', 'm.03hjrr', 'm.02nxk', 'm.01r9lx', 'm.026fb9', 'm.0cg81d', 'm.03z5hz', 'm.0263mk', 'm.02c_q_', 'm.01zzy_', 'm.08mm_5', 'm.08ckfm', 'm.03c44y', 'm.03m83s', 'm.05241j0', 'm.03v68d', 'm.01pt2r', 'm.0k269', 'm.033rq', 'm.07k4g7g', 'm.02pk0t2', 'm.03rlj6', 'm.03c8zx', 'm.022vdk', 'm.04ykxz', 'm.02wy12', 'm.03q5dr', 'm.0fqm7g0', 'm.0dvv9s', 'm.05d0sm', 'm.06n_nn', 'm.02vyw', 'm.08sb57', 'm.03dhy8', 'm.0412t6', 'm.03m7rkt', 'm.05khsy', 'm.01gtsh', 'm.036g2b', 'm.04k1kk', 'm.01q7cb_', 'm.0ks957', 'm.02xbw2', 'm.02pt11', 'm.07tmq9', 'm.037w1', 'm.03kn29', 'm.0d_hr', 'm.04bgs7', 'm.0344jy', 'm.07s_31', 'm.038p7s', 'm.0166w_', 'm.0c6vjs', 'm.020sr8', 'm.0343h', 'm.01c_xx', 'm.0191bn', 'm.034ls', 'm.037w7r', 'm.014jy6', 'm.027c96m', 'm.0c9hm', 'm.03mt9', 'm.0d_9f1', 'm.04l9wz', 'm.051ztf', 'm.01l1hr', 'm.06v748', 'm.03b0n5', 'm.05k13g', 'm.0502rv', 'm.03mfjm', 'm.038zc', 'm.0cwtm', 'm.05tbhj', 'm.01lykw', 'm.074pxw', 'm.016k38', 'm.080p_h', 'm.0jwwgnw', 'm.09z0lc', 'm.03pt18', 'm.0p81w', 'm.05lqps', 'm.0bnyhq', 'm.01qwly', 'm.01q4s5', 'm.03g4bf', 'm.02l6dy', 'm.03g0p', 'm.03h1tqw', 'm.02xjlj', 'm.02qppnm', 'm.01g1lp', 'm.016fnb', 'm.0p921', 'm.067tsz', 'm.02569p', 'm.0kxrb', 'm.036df9', 'm.072jm3', 'm.0402rg', 'm.024wt7', 'm.07flkk', 'm.0288vd8', 'm.02l1g0', 'm.05hj_k', 'm.09qvwf', 'm.02wwj9', 'm.05bm10', 'm.047m111', 'm.01vysy8', 'm.04n1l2', 'm.0flspx', 'm.03hjhb4', 'm.08l6td', 'm.09hnb', 'm.02bcgg', 'm.03pg42', 'm.02gjzj', 'm.0261rs', 'm.02w_xk', 'm.01b1gs', 'm.04d1cq', 'm.0d0gzz', 'm.036m8s', 'm.0cmmy9', 'm.05czcv', 'm.03935p', 'm.0f5zj6', 'm.04hrhm', 'm.04jg9ff', 'm.02r5h2t', 'm.0230wx', 'm.0bpr4y', 'm.0dgfsk', 'm.03y23jt', 'm.0jb54', 'm.0182q2', 'm.019d8_', 'm.019xyd', 'm.014jb8', 'm.01pwrjt', 'm.0cv34x', 'm.01xdlx', 'm.05l97n', 'm.06_69f', 'm.03qspn', 'm.01fq0x', 'm.0k2hrth', 'm.0bgjc8', 'm.02348n', 'm.01k5tg', 'm.01770r', 'm.09xg8', 'm.03bzz8t', 'm.066xc8', 'm.03c5bz', 'm.03v3j_', 'm.08jj7t', 'm.03d_03', 'm.06ln1j', 'm.0ksv3d', 'm.01z7_f', 'm.07fcvr', 'm.02751h2', 'm.04_vgv', 'm.034bg5', 'm.03crtr', 'm.039gzc', 'm.0chhjz', 'm.01rn_x', 'm.01qr1_', 'm.0ckb3y', 'm.06tc98', 'm.0245wb', 'm.04flrx', 'm.02g87m', 'm.079nvf', 'm.03syvx', 'm.0gprt0', 'm.0477_', 'm.0hvby', 'm.0231bb', 'm.081vxj', 'm.0m8_v', 'm.015ybz', 'm.03gbf7', 'm.01g7hs', 'm.033hb0', 'm.0459k', 'm.03kth6', 'm.02rq9n', 'm.04n0lkh', 'm.03bmvc', 'm.0320cg', 'm.01gvv5', 'm.0dst4x', 'm.03jjzf', 'm.023pzh', 'm.046l2', 'm.041xfx', 'm.02cnq1', 'm.0240vt', 'm.056svc', 'm.02h73f', 'm.0_5w6', 'm.05ty1w', 'm.0hd1l', 'm.0s8tynj', 'm.09zvcg', 'm.01m54p', 'm.04rxwx', 'm.0270jd', 'm.02238b', 'm.0684tm', 'm.03qhcnp', 'm.03ldpc', 'm.01kvrj', 'm.0fc34h', 'm.03h979', 'm.014ptb', 'm.07r_sc', 'm.027tcmx', 'm.01_8rq', 'm.056mxl', 'm.013vtr', 'm.05db50', 'm.02w5_k', 'm.07p160', 'm.08cfkh', 'm.050p2t', 'm.02n45z', 'm.0408r5', 'm.026l37', 'm.05cnkm', 'm.06x328', 'm.01wgsvv', 'm.01cbjz', 'm.046c6', 'm.036z2s', 'm.044h4', 'm.01sthx', 'm.02fp95', 'm.0150p7', 'm.01s4ss', 'm.036_x3', 'm.04yw4x', 'm.04yj5z', 'm.079m5z', 'm.01gkbj', 'm.017r13', 'm.0bymv', 'm.0pyqh', 'm.0208bk', 'm.02wvnnx', 'm.05yhv', 'm.0336gg', 'm.0gg6xn1', 'm.03ms0t', 'm.0l65n', 'm.0h653gg', 'm.02vxmw3', 'm.01xcly', 'm.03sfbh', 'm.0fyf5g', 'm.03ph6k', 'm.098hm1', 'm.0h7f2f', 'm.0jgvf', 'm.01z1ws', 'm.04vkvw', 'm.03whg42', 'm.0kc54', 'm.0bbwlp5', 'm.021v2z', 'm.0b16s0', 'm.01r93l', 'm.0lpjn', 'm.01frrf', 'm.02f91x', 'm.03npb_', 'm.0509bl', 'm.082r6n', 'm.02g0mx', 'm.01q9m5', 'm.0d5cy', 'm.02b6km', 'm.03ghnx', 'm.0gx02bb', 'm.0182qx', 'm.0b6m063', 
'm.04n0yqn', 'm.0fq2760', 'm.02f95t', 'm.01f8ld', 'm.065h1p', 'm.019ncn', 'm.02050j', 'm.03bdbl', 'm.06bxgv', 'm.037721', 'm.02_bdt', 'm.01pfh3w', 'm.05zrbnd', 'm.01yk06', 'm.049l7', 'm.0127m7', 'm.04s6ts', 'm.05qrgk', 'm.0273c_', 'm.0jt9z', 'm.048lv', 'm.0235t5', 'm.03bzdvy', 'm.06cpj9', 'm.069nbk', 'm.01p85y', 'm.013yvd', 'm.047sth0', 'm.015tp7', 'm.01cy1c', 'm.01kwlwp', 'm.04fzk', 'm.02pl46g', 'm.0498f', 'm.03cmk_', 'm.0155zp', 'm.03p4qd', 'm.04crpl', 'm.02pj01j', 'm.069zyg', 'm.04znp2', 'm.015bw2', 'm.0337vz', 'm.04fhqh', 'm.04my5mx', 'm.08047z', 'm.04r6kn', 'm.045qln', 'm.0dpt2x', 'm.09j_2f', 'm.03qc9cc', 'm.027ypcs', 'm.04g8d', 'm.020_95', 'm.0krz6v', 'm.047n7c', 'm.0dfphs', 'm.014gf8', 'm.012tmz', 'm.03nzr_', 'm.01qx13', 'm.078dv3', 'm.01fkxr', 'm.09l7gt', 'm.04q4l', 'm.01pxrx', 'm.043ls76', 'm.05yqrl', 'm.03h_x0', 'm.04sh_1', 'm.0dvmd', 'm.021yhj', 'm.04svb2', 'm.04gsvmw', 'm.088_bk', 'm.02nwxc', 'm.05zv32', 'm.011_3s', 'm.0g9y1p7', 'm.0347ls', 'm.0byhnl', 'm.08qvjx', 'm.0gkxg7b', 'm.064r8yv', 'm.0qlry', 'm.026qnkm', 'm.03d70n0', 'm.0g476', 'm.01n048', 'm.01pfdg', 'm.0qs96j4', 'm.05dd_l', 'm.01rsdq', 'm.01270s', 'm.0p_2r', 'm.02lf70', 'm.09gcs', 'm.05q2n2', 'm.03fqzp', 'm.01zm1l', 'm.05_xn9', 'm.05dztx', 'm.01fhst', 'm.01fdc0', 'm.04_0sw4', 'm.01vs_v8', 'm.0139q5', 'm.01p8qv', 'm.043jp1_', 'm.02fzd3', 'm.026tmp', 'm.02wxbt', 'm.01bqmg', 'm.08d65g', 'm.03s4j_', 'm.01wv9p', 'm.03th34', 'm.044gjv', 'm.04cp9x', 'm.06jrvl', 'm.02kt6r', 'm.06l8jw', 'm.047dbx9', 'm.0mgb9', 'm.0kb3n', 'm.05zsgz', 'm.0c06sr', 'm.0375zc', 'm.0b3cs1', 'm.033h2v', 'm.0jwyq2x', 'm.09g82r', 'm.018phr', 'm.05ly3r', 'm.01mqnr', 'm.05jc7m', 'm.0b6z2c', 'm.0gmjf8', 'm.0jfzc', 'm.01v1ys3', 'm.024xmf', 'm.01msrs', 'm.019tyn', 'm.057hz', 'm.01bffy', 'm.09dh7n', 'm.03vjl6', 'm.05_tcb', 'm.026r8q', 'm.01rrd4', 'm.082vtn', 'm.028wtc', 'm.0dzyp6', 'm.022q61', 'm.05n_d_', 'm.0227p5', 'm.01jglh', 'm.04rdhn', 'm.0lgsq', 'm.080sy1', 'm.052hl', 'm.04g09qn', 'm.05fj6r', 'm.04bktz', 'm.0gvr2mx', 'm.043mw84', 'm.01j5ws', 'm.01r_drv', 'm.01fdpj', 'm.070j61', 'm.02r1qhb', 'm.0h3n5yz', 'm.054bt3', 'm.053n8s', 'm.03p5vy', 'm.04fw8yd', 'm.06qy4p', 'm.027f6w', 'm.04jfgvr', 'm.02_0bm', 'm.0263y0', 'm.0ck7wj', 'm.02722_', 'm.05c8f_', 'm.03qp6s', 'm.03lr3z', 'm.01dtxd', 'm.038m0d', 'm.0265bkr', 'm.01zgyp', 'm.0294fd', 'm.01phtd', 'm.07nb54', 'm.05s74y', 'm.04ffv_', 'm.09gh9sl', 'm.01728m', 'm.09r_jb', 'm.0k3mtc8', 'm.081xj6', 'm.02vmmz', 'm.02dwq5', 'm.03dpm5', 'm.07cktb', 'm.0641q5', 'm.03kml6', 'm.01_x4x', 'm.0296q2', 'm.03nw4q', 'm.02y_4xw', 'm.0346l4', 'm.032bfz', 'm.02dvwl', 'm.0d9mt_', 'm.01nxzv', 'm.03cymlv', 'm.043q6n_', 'm.01hhd7', 'm.018fzs', 'm.02s5m3', 'm.0lkr7', 'm.024zc8', 'm.0f9dpr', 'm.032kyt', 'm.02ps9k', 'm.07bp3c', 'm.08f2m4', 'm.05zwl4', 'm.0251xd', 'm.01t6xz', 'm.01rjfj', 'm.069lpq', 'm.05dgpl', 'm.02jxnz', 'm.01jhtj', 'm.0krnw', 'm.0ft68', 'm.016mbz', 'm.09rcjg', 'm.05kfs', 'm.059zx3', 'm.01zz8t', 'm.034g1z', 'm.06l0c3', 'm.016mj4', 'm.07s8p52', 'm.06w7jl9', 'm.01q_ph', 'm.03ghnx', 'm.05r5w', 'm.01cygd', 'm.051j_m', 'm.0615j_', 'm.0c3yd_j', 'm.0gt1yv', 'm.0b0mpg', 'm.05sy8', 'm.026qhcv', 'm.059g24', 'm.01f9yn', 'm.0czcd7c', 'm.01h75w', 'm.02l2s2', 'm.09s3gm', 'm.01gzfn', 'm.012nry', 'm.01kb6l', 'm.0drhp3', 'm.0206p0', 'm.037d35', 'm.01nzww', 'm.03q35x', 'm.0g58864', 'm.010ngb', 'm.038786', 'm.03bghb', 'm.067g_', 'm.01nw3d', 'm.061s_', 'm.0p3q3', 'm.028b5vn', 'm.0d8kxt', 'm.01lghn', 'm.07s6zxp', 'm.0dxg6', 'm.0hgst', 'm.053kd63', 'm.056z6_', 'm.02qgqt', 'm.02655s', 'm.01llhq', 'm.026t71b', 'm.01rjtl', 'm.026db3c', 
'm.06b3bx', 'm.02xcc35', 'm.01pcz9', 'm.02s8vf', 'm.06v054', 'm.01csb_', 'm.0xnc3', 'm.0pj4n', 'm.0xm_0', 'm.01ndxw', 'm.09v7n0f', 'm.0ddh63s', 'm.018n52', 'm.0626jk', 'm.0b6hgm8', 'm.04l5px', 'm.069wr', 'm.01c1ww', 'm.01_587', 'm.0235fz', 'm.03h047', 'm.031fq_', 'm.0bbbky', 'm.04kyzg', 'm.0d_sbv', 'm.029dwg', 'm.02z1r_0', 'm.016srn', 'm.01t3w_', 'm.0270lg_', 'm.0ctnsv', 'm.02ppxfg', 'm.02px9j4', 'm.01ywbz', 'm.043gpt', 'm.06_74n', 'm.02j490', 'm.01h910', 'm.0kssdg', 'm.03yj5dv', 'm.07f_x', 'm.0n6f8', 'm.065k5_', 'm.027s38x', 'm.02ykkk', 'm.01chdy', 'm.016z51', 'm.0b0s3', 'm.0163l9', 'm.06h93l', 'm.0hn9fbl', 'm.03nf04', 'm.01d_bx', 'm.07xb48', 'm.020yj1', 'm.01sp75k', 'm.03hfvl', 'm.02cdqb', 'm.0bq8p7', 'm.0h3tl82', 'm.04l20l', 'm.03wqs65', 'm.07l1h3', 'm.0g56339', 'm.01x209s', 'm.044zvm', 'm.0552zz', 'm.0d9q7w', 'm.026qc96', 'm.065zxc', 'm.016z2j', 'm.0269l3g', 'm.027_zq', 'm.0bkjs4', 'm.02l5km', 'm.019f8r', 'm.01tgq_', 'm.02djmh', 'm.0bsfy', 'm.05vslll', 'm.04kc9m', 'm.0985c0', 'm.015b67', 'm.06yks9', 'm.02_z4g', 'm.05b6sg5', 'm.04ghydx', 'm.08xh0f', 'm.05pdct', 'm.01my95', 'm.04d49y', 'm.02h8sh', 'm.0gr21', 'm.08ydvp', 'm.05k185', 'm.01zl71', 'm.074tyf', 'm.06b_0', 'm.02wxcr', 'm.053ksp', 'm.05ckg_', 'm.06c0j', 'm.0f8d6c', 'm.0lf9j', 'm.0b_lfh', 'm.0j8g6', 'm.034fn4', 'm.0264f6', 'm.02776bg', 'm.0473p_', 'm.03dlc9', 'm.02xxbs', 'm.04grr30', 'm.02pb1n', 'm.079dy', 'm.03mkf8', 'm.01jw4r', 'm.019d0l', 'm.026y49b', 'm.0bq7m', 'm.0c7wm8', 'm.05y7nf', 'm.0q9zs', 'm.023rpc', 'm.0m66w', 'm.06w6_', 'm.03wdk0', 'm.04bs3j', 'm.03z1cn', 'm.0289v04', 'm.01ls77', 'm.049kjv', 'm.0f98y', 'm.0418vd', 'm.031k24', 'm.02jj7b', 'm.051mg4', 'm.031h8r', 'm.0bjtyg', 'm.03ndwy', 'm.01qp6qt', 'm.02643n_', 'm.01rs0h', 'm.0163t3', 'm.039tcz', 'm.0lq90wy', 'm.01k5d4', 'm.077874', 'm.06dn4y', 'm.03f5mf', 'm.05579g', 'm.08m_0v', 'm.05hsn2', 'm.08c941', 'm.0cgzj', 'm.01csyt', 'm.0h96g', 'm.01ghyw', 'm.07_ty0', 'm.06zqs8n', 'm.0418ft', 'm.02m0nf', 'm.06pk8', 'm.05yvhj', 'm.05557v4', 'm.0659sj', 'm.0fpx_7', 'm.0h7h5p', 'm.07tmq9', 'm.0h_cvzm', 'm.0f3wr9', 'm.025k5p', 'm.01_j1t', 'm.01jtkg', 'm.03nk3t', 'm.04hp8s', 'm.0hqly', 'm.055m_v', 'm.03y7hp', 'm.07bkv', 'm.01p__8', 'm.06t_zq', 'm.094bfl', 'm.01xyq6', 'm.092g11', 'm.08rr56', 'm.03yx01', 'm.061wxp', 'm.013qwl', 'm.02mqc4', 'm.0crg7tp', 'm.02dc3_', 'm.033bkd', 'm.02z0ck8', 'm.01h434', 'm.01tgny', 'm.064kks2', 'm.0fml1k', 'm.05x1gn', 'm.01p1cn', 'm.04gd84', 'm.0f98y', 'm.0m3rl', 'm.0b6k9d', 'm.06608wm', 'm.04v62b', 'm.03cq1yy', 'm.07fsdg', 'm.020_qy', 'm.01cvp9', 'm.084f8h', 'm.07wbzj', 'm.01g4bk', 'm.026lr8', 'm.04p_tk', 'm.054xvs', 'm.02f_sx', 'm.05df3p', 'm.08q7kb', 'm.04z84', 'm.0gfg8x0', 'm.0fn_fh', 'm.0206mj', 'm.02yjf8', 'm.033f6x', 'm.03x2sj', 'm.0290v1', 'm.02mhfy', 'm.07h5d', 'm.02dgtb', 'm.0dk5zn', 'm.01s9ym', 'm.07vsx3', 'm.0643nx', 'm.02jxq1', 'm.01k53f', 'm.03m6v_', 'm.05z5y_', 'm.04zp_j', 'm.06dkf7', 'm.07rp8', 'm.0jdhp', 'm.01gct2', 'm.026s3n', 'm.01qm9d', 'm.01jf6n', 'm.01nr36', 'm.043p624', 'm.04y6_th', 'm.03lhmg', 'm.079vh1', 'm.04z598', 'm.07lmp', 'm.02zb92', 'm.04tmr4', 'm.0h79sh', 'm.01fnd0', 'm.02646db', 'm.024sld', 'm.0c837g', 'm.01d_c9', 'm.0q5fw', 'm.027d5g5', 'm.0g7kkb', 'm.02pqxl3', 'm.02661h', 'm.02j8fb', 'm.01l_r6', 'm.0gwyvzh', 'm.0dr5g9', 'm.06zttt', 'm.04m0nc', 'm.06t87d', 'm.0mbs_', 'm.0h1dy9t', 'm.05jyjn', 'm.01fq2k', 'm.04y0yc', 'm.03v1rk', 'm.0chdxy', 'm.04gq97', 'm.01b90h', 'm.01zbf1', 'm.03ndwy', 'm.0kn91', 'm.0pyg6', 'm.07ymn5', 'm.09nj72', 'm.07vm2b', 'm.02w09gx', 'm.022p28', 'm.01k0z1', 'm.01fxck', 'm.02fn5r', 'm.025_wg9', 
'm.02ts3h', 'm.029tx1', 'm.0p720', 'm.0cm4xj', 'm.05_5txt', 'm.0gyx4', 'm.02y2nr', 'm.049dzvg', 'm.05r73g', 'm.04k5fq', 'm.02ppy7', 'm.0260x42', 'm.01l64q', 'm.013zyw', 'm.0h2ftf', 'm.01qn6k', 'm.0bgy72', 'm.030lw5', 'm.023w_z', 'm.023kzp', 'm.0166z2', 'm.05nhh1', 'm.0668g4', 'm.03lp0g', 'm.0lkgc', 'm.05zx3p', 'm.045qln', 'm.0fjgy', 'm.0dr046', 'm.047cl49', 'm.02lqby', 'm.08849', 'm.02pmfg3', 'm.07hsw6', 'm.04cxpks', 'm.07x_rh', 'm.028lkr', 'm.08887m', 'm.01bjnp', 'm.014hdb', 'm.0kcv4', 'm.02qtk76', 'm.01br1k', 'm.0gtj4r']
+
+
+class FaceDataset(Dataset):
+    def __init__(self, root, split, transforms=None, ignored_classes=None):
+        self._root = os.path.join(root, 'images')
+        if ignored_classes is None:
+            ignored_classes = []
+
+        self._class_names = [o for o in os.listdir(self._root) if os.path.isdir(os.path.join(self._root, o))]
+        self._class_names = list(set(self._class_names) - set(ignored_classes))
+        self._class_names.sort()
+
+        if split == 'training':
+            self._all_images, self._images_by_class = self._list_images(root, 'train.txt')
+        elif split == 'validation':
+            self._all_images, self._images_by_class = self._list_images(root, 'validation.txt')
+        else:
+            raise ValueError('Invalid split')
+
+        self._transforms = transforms
+
+    def _list_images(self, root, filename):
+        class_indexes_by_class_name = {self._class_names[i]: i for i in range(len(self._class_names))}
+
+        with open(os.path.join(root, filename), 'r') as image_file:
+            images_lines = [line.strip() for line in image_file.readlines()]
+
+        images = []
+        images_by_class = [[] for _ in self._class_names]
+        for images_line in images_lines:
+            class_name, filename = images_line.split(' ')
+            if class_name not in class_indexes_by_class_name:
+                continue
+
+            class_index = class_indexes_by_class_name[class_name]
+            sound_index = len(images)
+            images.append({
+                'path': os.path.join(class_name, filename),
+                'class_index': class_index
+            })
+            images_by_class[class_index].append({'index': sound_index})
+
+        return images, images_by_class
+
+    def __len__(self):
+        return len(self._all_images)
+
+    def __getitem__(self, index):
+        image = Image.open(os.path.join(self._root, self._all_images[index]['path'])).convert('RGB')
+        if self._transforms is not None:
+            image = self._transforms(image)
+
+        return image, self._all_images[index]['class_index']
+
+    def class_count(self):
+        return len(self._images_by_class)
+
+    def class_indexes(self):
+        return [d['class_index'] for d in self._all_images]
+
+    def lens_by_class(self):
+        return [len(x) for x in self._images_by_class]
+
+    def get_all_indexes(self, class_, index):
+        return self._images_by_class[class_][index]['index']
+
+    def transforms(self):
+        return self._transforms
+
+
+class FaceConcatDataset(Dataset):
+    def __init__(self, face_datasets, transforms=None):
+        if any((d.transforms() is not None for d in face_datasets)):
+            raise ValueError('All face datasets must have their transforms set to None')
+
+        self._face_datasets = face_datasets
+        self._class_offsets = self._compute_class_offsets(face_datasets)
+
+        self._transforms = transforms
+
+    def _compute_class_offsets(self, face_datasets):
+        offsets = []
+        offset = 0
+        for d in face_datasets:
+            offsets.append(offset)
+            offset += d.class_count()
+        return offsets
+
+    def __len__(self):
+        return sum((len(d) for d in self._face_datasets))
+
+    def __getitem__(self, index):
+        dataset_index, index = self._transform_image_index(index)
+
+        image, class_index = self._face_datasets[dataset_index][index]
+        if self._transforms is not None:
+            image = self._transforms(image)
+
+        return image, class_index +
self._class_offsets[dataset_index] + + def class_count(self): + return sum((d.class_count() for d in self._face_datasets)) + + def class_indexes(self): + class_indexes_list = [] + for dataset_index, dataset in enumerate(self._face_datasets): + for class_index in dataset.class_indexes(): + class_indexes_list.append(class_index + self._class_offsets[dataset_index]) + return class_indexes_list + + def lens_by_class(self): + lens_by_class_list = [] + for d in self._face_datasets: + lens_by_class_list += d.lens_by_class() + return lens_by_class_list + + def get_all_indexes(self, class_index, index): + dataset_index, class_index = self._transform_class_index(class_index) + return self._face_datasets[dataset_index].get_all_indexes(class_index, index) + + def transforms(self): + return self._transforms + + def _transform_image_index(self, image_index): + for i, d in enumerate(self._face_datasets): + if image_index < len(d): + return i, image_index + else: + image_index -= len(d) + + raise IndexError(f'Image index out of range ({image_index})') + + def _transform_class_index(self, class_index): + for i, d in enumerate(self._face_datasets): + if class_index < d.class_count(): + return i, class_index + else: + class_index -= d.class_count() + + raise IndexError(f'Class index out of range ({class_index})') diff --git a/tools/dnn_training/face_recognition/datasets/imbalanced_face_dataset_sampler.py b/tools/dnn_training/face_recognition/datasets/imbalanced_face_dataset_sampler.py new file mode 100644 index 00000000..9acb6fc1 --- /dev/null +++ b/tools/dnn_training/face_recognition/datasets/imbalanced_face_dataset_sampler.py @@ -0,0 +1,17 @@ +import torch + + +class ImbalancedFaceDatasetSampler(torch.utils.data.sampler.Sampler): + def __init__(self, face_dataset): + self._image_count = len(face_dataset) + + class_weights = [1.0 / (c + 1e-6) for c in face_dataset.lens_by_class()] + self._image_weights = [class_weights[class_index] for class_index in face_dataset.class_indexes()] + self._image_weights = torch.tensor(self._image_weights) + + def __iter__(self): + indexes = torch.multinomial(self._image_weights, self._image_count, replacement=True) + return iter(indexes.tolist()) + + def __len__(self): + return self._image_count diff --git a/tools/dnn_training/face_recognition/datasets/vggface2_dataset.py b/tools/dnn_training/face_recognition/datasets/vggface2_dataset.py deleted file mode 100644 index 673fe082..00000000 --- a/tools/dnn_training/face_recognition/datasets/vggface2_dataset.py +++ /dev/null @@ -1,70 +0,0 @@ -import os - -from PIL import Image - -from torch.utils.data import Dataset - - -LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES = ['n000021', 'n000137', 'n000172', 'n000184', 'n000195', 'n000199', 'n000220', 'n000242', 'n000255', 'n000272', 'n000281', 'n000297', 'n000310', 'n000359', 'n000373', 'n000379', 'n000402', 'n000420', 'n000427', 'n000429', 'n000483', 'n000560', 'n000562', 'n000568', 'n000645', 'n000667', 'n000692', 'n000709', 'n000712', 'n000755', 'n000780', 'n000810', 'n000816', 'n000817', 'n000835', 'n000887', 'n000888', 'n000902', 'n000906', 'n000912', 'n000946', 'n000947', 'n001042', 'n001048', 'n001051', 'n001054', 'n001057', 'n001060', 'n001091', 'n001095', 'n001096', 'n001105', 'n001106', 'n001110', 'n001111', 'n001113', 'n001142', 'n001143', 'n001155', 'n001156', 'n001159', 'n001171', 'n001200', 'n001228', 'n001235', 'n001243', 'n001257', 'n001288', 'n001309', 'n001329', 'n001360', 'n001367', 'n001379', 'n001387', 'n001419', 'n001434', 'n001437', 'n001473', 'n001478', 'n001538', 
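A minimal usage sketch for the new ImbalancedFaceDatasetSampler, using a hypothetical toy dataset that stands in for FaceDataset and implements only the methods the sampler and DataLoader rely on; the class sizes are made up.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from face_recognition.datasets import ImbalancedFaceDatasetSampler

    class ToyImbalancedFaceDataset(Dataset):
        # Toy stand-in: class 0 has 6 images, class 1 has 2.
        def __len__(self):
            return 8

        def __getitem__(self, index):
            return torch.zeros(3, 128, 96), 0 if index < 6 else 1

        def class_indexes(self):
            return [0] * 6 + [1] * 2

        def lens_by_class(self):
            return [6, 2]

    dataset = ToyImbalancedFaceDataset()
    loader = DataLoader(dataset, batch_size=4, sampler=ImbalancedFaceDatasetSampler(dataset))
    for images, labels in loader:
        pass  # minority-class images are drawn more often (multinomial sampling with replacement)

Because indexes are drawn with replacement using inverse class frequencies, the two images of class 1 are sampled about as often as the six images of class 0.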
'n001540', 'n001547', 'n001548', 'n001551', 'n001567', 'n001568', 'n001595', 'n001607', 'n001642', 'n001656', 'n001663', 'n001695', 'n001698', 'n001700', 'n001727', 'n001775', 'n001779', 'n001781', 'n001795', 'n001805', 'n001813', 'n001825', 'n001839', 'n001873', 'n001874', 'n001934', 'n001962', 'n001968', 'n001971', 'n001986', 'n002017', 'n002018', 'n002019', 'n002025', 'n002028', 'n002038', 'n002056', 'n002057', 'n002067', 'n002081', 'n002094', 'n002120', 'n002133', 'n002135', 'n002141', 'n002173', 'n002178', 'n002227', 'n002244', 'n002251', 'n002259', 'n002269', 'n002272', 'n002274', 'n002278', 'n002298', 'n002310', 'n002316', 'n002342', 'n002356', 'n002360', 'n002388', 'n002391', 'n002413', 'n002452', 'n002454', 'n002460', 'n002465', 'n002471', 'n002478', 'n002494', 'n002497', 'n002537', 'n002547', 'n002572', 'n002574', 'n002595', 'n002630', 'n002654', 'n002666', 'n002673', 'n002684', 'n002725', 'n002759', 'n002782', 'n002814', 'n002844', 'n002854', 'n002868', 'n002947', 'n002969', 'n002982', 'n003013', 'n003020', 'n003022', 'n003039', 'n003098', 'n003101', 'n003147', 'n003186', 'n003198', 'n003205', 'n003206', 'n003212', 'n003228', 'n003257', 'n003266', 'n003303', 'n003311', 'n003353', 'n003359', 'n003360', 'n003363', 'n003383', 'n003385', 'n003391', 'n003392', 'n003422', 'n003426', 'n003474', 'n003476', 'n003523', 'n003619', 'n003623', 'n003628', 'n003648', 'n003671', 'n003685', 'n003689', 'n003709', 'n003712', 'n003724', 'n003728', 'n003749', 'n003756', 'n003767', 'n003785', 'n003803', 'n003806', 'n003816', 'n003817', 'n003820', 'n003866', 'n003876', 'n003879', 'n003880', 'n003885', 'n003904', 'n003908', 'n003912', 'n003937', 'n003946', 'n003952', 'n003969', 'n003972', 'n003973', 'n003974', 'n003979', 'n003982', 'n004014', 'n004020', 'n004030', 'n004043', 'n004081', 'n004083', 'n004084', 'n004106', 'n004132', 'n004144', 'n004150', 'n004151', 'n004154', 'n004181', 'n004197', 'n004198', 'n004204', 'n004216', 'n004218', 'n004220', 'n004231', 'n004236', 'n004263', 'n004327', 'n004328', 'n004375', 'n004407', 'n004408', 'n004429', 'n004439', 'n004451', 'n004457', 'n004480', 'n004481', 'n004488', 'n004580', 'n004588', 'n004617', 'n004644', 'n004646', 'n004651', 'n004652', 'n004661', 'n004674', 'n004703', 'n004715', 'n004720', 'n004737', 'n004756', 'n004781', 'n004782', 'n004841', 'n004844', 'n004861', 'n004893', 'n004902', 'n004905', 'n004949', 'n004990', 'n005009', 'n005026', 'n005035', 'n005053', 'n005054', 'n005056', 'n005057', 'n005060', 'n005098', 'n005104', 'n005106', 'n005110', 'n005116', 'n005123', 'n005136', 'n005140', 'n005159', 'n005196', 'n005203', 'n005204', 'n005206', 'n005208', 'n005212', 'n005219', 'n005226', 'n005227', 'n005248', 'n005292', 'n005400', 'n005423', 'n005491', 'n005531', 'n005588', 'n005642', 'n005659', 'n005662', 'n005666', 'n005690', 'n005726', 'n005737', 'n005745', 'n005761', 'n005762', 'n005770', 'n005815', 'n005831', 'n005836', 'n005846', 'n005874', 'n005933', 'n005946', 'n005948', 'n005949', 'n005950', 'n005983', 'n005985', 'n005988', 'n006003', 'n006004', 'n006008', 'n006013', 'n006015', 'n006047', 'n006048', 'n006055', 'n006061', 'n006073', 'n006083', 'n006103', 'n006110', 'n006111', 'n006120', 'n006155', 'n006158', 'n006160', 'n006221', 'n006234', 'n006236', 'n006246', 'n006260', 'n006348', 'n006350', 'n006359', 'n006363', 'n006374', 'n006380', 'n006384', 'n006386', 'n006395', 'n006408', 'n006414', 'n006512', 'n006592', 'n006604', 'n006662', 'n006691', 'n006695', 'n006696', 'n006699', 'n006716', 'n006720', 'n006784', 'n006793', 'n006827', 'n006840', 
'n006846', 'n006860', 'n006868', 'n006910', 'n006913', 'n006922', 'n006928', 'n006933', 'n006935', 'n006966', 'n006969', 'n006970', 'n007009', 'n007024', 'n007075', 'n007122', 'n007123', 'n007157', 'n007166', 'n007188', 'n007190', 'n007213', 'n007226', 'n007242', 'n007277', 'n007300', 'n007320', 'n007326', 'n007331', 'n007352', 'n007356', 'n007357', 'n007366', 'n007383', 'n007389', 'n007399', 'n007402', 'n007412', 'n007428', 'n007461', 'n007476', 'n007493', 'n007501', 'n007510', 'n007565', 'n007574', 'n007581', 'n007584', 'n007595', 'n007596', 'n007613', 'n007616', 'n007620', 'n007623', 'n007645', 'n007679', 'n007704', 'n007712', 'n007787', 'n007794', 'n007801', 'n007847', 'n007901', 'n007974', 'n007980', 'n007981', 'n008018', 'n008065', 'n008079', 'n008127', 'n008162', 'n008166', 'n008180', 'n008205', 'n008215', 'n008239', 'n008242', 'n008272', 'n008277', 'n008316', 'n008333', 'n008347', 'n008428', 'n008431', 'n008446', 'n008466', 'n008496', 'n008524', 'n008545', 'n008551', 'n008556', 'n008557', 'n008568', 'n008575', 'n008592', 'n008614', 'n008638', 'n008646', 'n008658', 'n008659', 'n008679', 'n008690', 'n008707', 'n008708', 'n008725', 'n008782', 'n008791', 'n008814', 'n008839', 'n008859', 'n008865', 'n008893', 'n008902', 'n008917', 'n008927', 'n008999', 'n009008', 'n009050', 'n009139', 'n009179', 'n009283'] - - -class Vggface2Dataset(Dataset): - def __init__(self, root, split, transforms=None, ignored_classes=None): - images_path = os.path.join(root, 'images') - if ignored_classes is None: - ignored_classes = [] - - self._class_names = [o for o in os.listdir(images_path) if os.path.isdir(os.path.join(images_path, o))] - self._class_names = list(set(self._class_names) - set(ignored_classes)) - self._class_names.sort() - - if split == 'training': - self._all_images, self._images_by_class = self._list_images(root, 'train.txt') - elif split == 'validation': - self._all_images, self._images_by_class = self._list_images(root, 'validation.txt') - else: - raise ValueError('Invalid split') - - self._transforms = transforms - - def _list_images(self, root, filename): - class_indexes_by_class_name = {self._class_names[i]: i for i in range(len(self._class_names))} - - with open(os.path.join(root, filename), 'r') as image_file: - images_lines = [line.strip() for line in image_file.readlines()] - - images = [] - images_by_class = [[] for _ in self._class_names] - for images_line in images_lines: - class_name, filename = images_line.split(' ') - if class_name not in class_indexes_by_class_name: - continue - - class_index = class_indexes_by_class_name[class_name] - sound_index = len(images) - images.append({ - 'path': os.path.join(root, 'images', class_name, filename), - 'class_index': class_index - }) - images_by_class[class_index].append({'index': sound_index}) - - return images, images_by_class - - def __len__(self): - return len(self._all_images) - - def __getitem__(self, index): - image = Image.open(self._all_images[index]['path']).convert('RGB') - if self._transforms is not None: - image = self._transforms(image) - - return image, self._all_images[index]['class_index'] - - def lens_by_class(self): - return [len(x) for x in self._images_by_class] - - def get_all_indexes(self, class_, index): - return self._images_by_class[class_][index]['index'] - - def transforms(self): - return self._transforms diff --git a/tools/dnn_training/face_recognition/face_descriptor_extractor.py b/tools/dnn_training/face_recognition/face_descriptor_extractor.py index 8d412d0b..e8ddcaf0 100644 --- 
a/tools/dnn_training/face_recognition/face_descriptor_extractor.py +++ b/tools/dnn_training/face_recognition/face_descriptor_extractor.py @@ -1,14 +1,15 @@ import torch.nn as nn +import torchvision.models as models + from common.modules import L2Normalization -from common.modules import InceptionModule, PaddedLPPool2d, Lrn2d, AmSoftmaxLinear +from common.modules import InceptionModule, PaddedLPPool2d, Lrn2d, NormalizedLinear, GlobalAvgPool2d -# Based on OpenFace (https://cmusatyalab.github.io/openface/) -class FaceDescriptorExtractor(nn.Module): - def __init__(self, embedding_size=128, class_count=None, am_softmax_linear=False): - super(FaceDescriptorExtractor, self).__init__() +class OpenFaceBackbone(nn.Module): + def __init__(self): + super(OpenFaceBackbone, self).__init__() self._features_layers = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False), @@ -76,14 +77,60 @@ def __init__(self, embedding_size=128, class_count=None, am_softmax_linear=False nn.AvgPool2d(kernel_size=3, stride=(2, 1)) ) + def forward(self, x): + return self._features_layers(x) + + def last_channel_count(self): + return 736 + + +class EfficientNetBackbone(nn.Module): + SUPPORTED_TYPES = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] + LAST_CHANNEL_COUNT_BY_TYPE = {'efficientnet_b0': 1280, + 'efficientnet_b1': 1280, + 'efficientnet_b2': 1408, + 'efficientnet_b3': 1536, + 'efficientnet_b4': 1792, + 'efficientnet_b5': 2048, + 'efficientnet_b6': 2304, + 'efficientnet_b7': 2560} + def __init__(self, type, pretrained_backbone=True): + super(EfficientNetBackbone, self).__init__() + + if pretrained_backbone: + backbone_weights = 'DEFAULT' + else: + backbone_weights = None + + if (type not in self.SUPPORTED_TYPES or type not in self.LAST_CHANNEL_COUNT_BY_TYPE): + raise ValueError('Invalid backbone type') + + self._features_layers = models.__dict__[type](weights=backbone_weights).features + self._last_channel_count = self.LAST_CHANNEL_COUNT_BY_TYPE[type] + + def forward(self, x): + return self._features_layers(x) + + def last_channel_count(self): + return self._last_channel_count + + +# Based on OpenFace (https://cmusatyalab.github.io/openface/) +class FaceDescriptorExtractor(nn.Module): + def __init__(self, backbone, embedding_size=128, class_count=None, normalized_linear=False): + super(FaceDescriptorExtractor, self).__init__() + + self._backbone = backbone + self._global_avg_pool = GlobalAvgPool2d() self._descriptor_layers = nn.Sequential( - nn.Linear(in_features=736, out_features=embedding_size), + nn.Linear(in_features=self._backbone.last_channel_count(), out_features=embedding_size), L2Normalization() ) self._class_count = class_count - if class_count is not None and am_softmax_linear: - self._classifier = AmSoftmaxLinear(embedding_size, class_count) + if class_count is not None and normalized_linear: + self._classifier = NormalizedLinear(embedding_size, class_count) elif class_count is not None: self._classifier = nn.Linear(embedding_size, class_count) else: @@ -93,7 +140,7 @@ def class_count(self): return self._class_count def forward(self, x): - features = self._features_layers(x) + features = self._global_avg_pool(self._backbone(x)) descriptor = self._descriptor_layers(features.view(x.size()[0], -1)) if self._classifier is not None: diff --git a/tools/dnn_training/face_recognition/metrics/lfw_evaluation.py 
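A minimal construction sketch for the refactored FaceDescriptorExtractor above, assuming that with class_count set the model still returns an (embedding, class scores) pair, as the criterions expect; the embedding size and class count are arbitrary, and pretrained weights are disabled here to avoid a download.

    import torch
    from face_recognition.datasets import IMAGE_SIZE  # aligned face size, (128, 96)
    from face_recognition.face_descriptor_extractor import EfficientNetBackbone, FaceDescriptorExtractor

    backbone = EfficientNetBackbone('efficientnet_b0', pretrained_backbone=False)
    model = FaceDescriptorExtractor(backbone, embedding_size=256, class_count=100, normalized_linear=True)
    model.eval()

    x = torch.ones((1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]))
    with torch.no_grad():
        embedding, class_scores = model(x)  # assumed shapes: (1, 256) and (1, 100)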
b/tools/dnn_training/face_recognition/metrics/lfw_evaluation.py index 0951cb96..bd01fc24 100644 --- a/tools/dnn_training/face_recognition/metrics/lfw_evaluation.py +++ b/tools/dnn_training/face_recognition/metrics/lfw_evaluation.py @@ -57,6 +57,7 @@ def _read_image_pairs(self): fold_size = int(p[1]) lines = lines[1:] + not_available_pairs = 0 for line in lines: p = line.strip().split() if len(p) == 3: @@ -72,7 +73,10 @@ def _read_image_pairs(self): if os.path.exists(image_path1) and os.path.exists(image_path2): image_pairs.append((image_path1, image_path2, is_same_person)) + else: + not_available_pairs += 1 + print('Not available pairs:', not_available_pairs) return fold_count, fold_size, image_pairs def _calculate_distances(self): diff --git a/tools/dnn_training/face_recognition/trainers/__init__.py b/tools/dnn_training/face_recognition/trainers/__init__.py index da214f4c..1aa17069 100644 --- a/tools/dnn_training/face_recognition/trainers/__init__.py +++ b/tools/dnn_training/face_recognition/trainers/__init__.py @@ -1 +1,3 @@ from face_recognition.trainers.face_descriptor_extractor_trainer import FaceDescriptorExtractorTrainer +from face_recognition.trainers.face_descriptor_extractor_distillation_trainer import \ + FaceDescriptorExtractorDistillationTrainer diff --git a/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_distillation_trainer.py b/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_distillation_trainer.py new file mode 100644 index 00000000..a2650b59 --- /dev/null +++ b/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_distillation_trainer.py @@ -0,0 +1,125 @@ +import os + +from common.datasets import TripletLossBatchSampler +from common.trainers import DistillationTrainer +from common.metrics import LossMetric, ClassificationAccuracyMetric, LossLearningCurves, LossAccuracyLearningCurves + +from face_recognition.criterions import FaceDescriptorDistillationLoss +from face_recognition.datasets import ImbalancedFaceDatasetSampler +from face_recognition.metrics import LfwEvaluation +from face_recognition.trainers.face_descriptor_extractor_trainer import _create_criterion, _create_dataset, \ + _evaluate_classification_accuracy, create_training_image_transform, create_validation_image_transform + +import torch +import torch.utils.data + + +class FaceDescriptorExtractorDistillationTrainer(DistillationTrainer): + def __init__(self, device, model, teacher_model, dataset_roots='', lfw_dataset_root='', output_path='', + epoch_count=10, learning_rate=0.01, weight_decay=0.0, criterion_type='triplet_loss', + batch_size=128, margin=0.2, + student_model_checkpoint=None, teacher_model_checkpoint=None): + self._lfw_dataset_root = lfw_dataset_root + self._criterion_type = criterion_type + self._margin = margin + self._class_count = model.class_count() + + super(FaceDescriptorExtractorDistillationTrainer, self).__init__( + device, model, teacher_model, + dataset_root=dataset_roots, + output_path=output_path, + epoch_count=epoch_count, + learning_rate=learning_rate, + weight_decay=weight_decay, + batch_size=batch_size, + batch_size_division=1, + student_model_checkpoint=student_model_checkpoint, + teacher_model_checkpoint=teacher_model_checkpoint) + + self._training_loss_metric = LossMetric() + self._validation_loss_metric = LossMetric() + + if self._criterion_type == 'triplet_loss': + self._learning_curves = LossLearningCurves() + else: + self._learning_curves = LossAccuracyLearningCurves() + self._training_accuracy_metric = 
ClassificationAccuracyMetric() + self._validation_accuracy_metric = ClassificationAccuracyMetric() + + def _create_criterion(self, student_model, teacher_model): + return FaceDescriptorDistillationLoss(_create_criterion(self._criterion_type, self._margin, self._epoch_count)) + + def _create_training_dataset_loader(self, dataset_roots, batch_size, batch_size_division): + dataset = _create_dataset(dataset_roots, 'training', create_training_image_transform()) + return self._create_dataset_loader(dataset, batch_size, batch_size_division, + use_imbalanced_face_dataset_sampler=True) + + def _create_validation_dataset_loader(self, dataset_roots, batch_size, batch_size_division): + dataset = _create_dataset(dataset_roots, 'validation', create_validation_image_transform()) + return self._create_dataset_loader(dataset, batch_size, batch_size_division) + + def _create_dataset_loader(self, dataset, batch_size, batch_size_division, + use_imbalanced_face_dataset_sampler=False): + if self._criterion_type == 'triplet_loss': + batch_sampler = TripletLossBatchSampler(dataset, batch_size=batch_size // batch_size_division) + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=8) + else: + sampler = ImbalancedFaceDatasetSampler(dataset) if use_imbalanced_face_dataset_sampler else None + return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, sampler=sampler, + num_workers=8) + + def _clear_between_training(self): + self._learning_curves.clear() + + def _clear_between_training_epoch(self): + self._training_loss_metric.clear() + if self._criterion_type != 'triplet_loss': + self._training_accuracy_metric.clear() + + def _move_target_to_device(self, target, device): + return target.to(device) + + def _measure_training_metrics(self, loss, model_output, target): + self._training_loss_metric.add(loss.item()) + if self._criterion_type != 'triplet_loss': + self._training_accuracy_metric.add(model_output[1], target) + + def _clear_between_validation_epoch(self): + self._validation_loss_metric.clear() + if self._criterion_type != 'triplet_loss': + self._validation_accuracy_metric.clear() + + def _measure_validation_metrics(self, loss, model_output, target): + self._validation_loss_metric.add(loss.item()) + if self._criterion_type != 'triplet_loss': + self._validation_accuracy_metric.add(model_output[1], target) + + def _print_performances(self): + if self._criterion_type != 'triplet_loss': + print('\nTraining : Loss={}, Accuracy={}'.format(self._training_loss_metric.get_loss(), + self._training_accuracy_metric.get_accuracy())) + print('Validation : Loss={}, Accuracy={}\n'.format(self._validation_loss_metric.get_loss(), + self._validation_accuracy_metric.get_accuracy())) + else: + print('\nTraining : Loss={}'.format(self._training_loss_metric.get_loss())) + print('Validation : Loss={}\n'.format(self._validation_loss_metric.get_loss())) + + def _save_learning_curves(self): + self._learning_curves.add_training_loss_value(self._training_loss_metric.get_loss()) + self._learning_curves.add_validation_loss_value(self._validation_loss_metric.get_loss()) + if self._criterion_type != 'triplet_loss': + self._learning_curves.add_training_accuracy_value(self._training_accuracy_metric.get_accuracy()) + self._learning_curves.add_validation_accuracy_value(self._validation_accuracy_metric.get_accuracy()) + + self._learning_curves.save(os.path.join(self._output_path, 'learning_curves.png'), + os.path.join(self._output_path, 'learning_curves.json')) + + def _evaluate(self, 
model, device, dataset_loader, output_path): + print('Evaluation', flush=True) + + lfw_evaluation = LfwEvaluation(model, device, dataset_loader.dataset.transforms(), + self._lfw_dataset_root, output_path) + lfw_evaluation.evaluate() + + if self._criterion_type != 'triplet_loss': + _evaluate_classification_accuracy(model, device, dataset_loader, self._class_count) diff --git a/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_trainer.py b/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_trainer.py index cefef622..cb27f713 100644 --- a/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_trainer.py +++ b/tools/dnn_training/face_recognition/trainers/face_descriptor_extractor_trainer.py @@ -1,6 +1,5 @@ import os -import torch.nn as nn import torchvision.transforms as transforms from tqdm import tqdm @@ -12,8 +11,11 @@ from common.metrics import LossMetric, ClassificationAccuracyMetric, TopNClassificationAccuracyMetric, \ ClassificationMeanAveragePrecisionMetric, LossLearningCurves, LossAccuracyLearningCurves -from face_recognition.criterions import FaceDescriptorAmSoftmaxLoss -from face_recognition.datasets import IMAGE_SIZE, Vggface2Dataset, LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES +from face_recognition.criterions import FaceDescriptorAmSoftmaxLoss, FaceDescriptorArcFaceLoss, \ + FaceDescriptorCrossEntropyLoss +from face_recognition.datasets import IMAGE_SIZE, FaceDataset, FaceConcatDataset, LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES, \ + LFW_OVERLAPPED_MSCELEB1M_CLASS_NAMES +from face_recognition.datasets import ImbalancedFaceDatasetSampler from face_recognition.metrics import LfwEvaluation import torch @@ -21,7 +23,7 @@ class FaceDescriptorExtractorTrainer(Trainer): - def __init__(self, device, model, vvgface2_dataset_root='', lfw_dataset_root='', output_path='', + def __init__(self, device, model, dataset_roots='', lfw_dataset_root='', output_path='', epoch_count=10, learning_rate=0.01, weight_decay=0.0, criterion_type='triplet_loss', batch_size=128, margin=0.2, model_checkpoint=None): @@ -31,7 +33,7 @@ def __init__(self, device, model, vvgface2_dataset_root='', lfw_dataset_root='', self._class_count = model.class_count() super(FaceDescriptorExtractorTrainer, self).__init__(device, model, - dataset_root=vvgface2_dataset_root, + dataset_root=dataset_roots, output_path=output_path, epoch_count=epoch_count, learning_rate=learning_rate, @@ -51,37 +53,27 @@ def __init__(self, device, model, vvgface2_dataset_root='', lfw_dataset_root='', self._validation_accuracy_metric = ClassificationAccuracyMetric() def _create_criterion(self, model): - if self._criterion_type == 'triplet_loss': - return TripletLoss(margin=self._margin) - elif self._criterion_type == 'cross_entropy_loss': - criterion = nn.CrossEntropyLoss() - return lambda model_output, target: criterion(model_output[1], target) - elif self._criterion_type == 'am_softmax_loss': - return FaceDescriptorAmSoftmaxLoss(s=30.0, m=self._margin, - start_annealing_epoch=0, - end_annealing_epoch=self._epoch_count // 4) - else: - raise ValueError('Invalid criterion type') + return _create_criterion(self._criterion_type, self._margin, self._epoch_count) - def _create_training_dataset_loader(self, dataset_root, batch_size, batch_size_division): - dataset = Vggface2Dataset(dataset_root, split='training', - transforms=create_training_image_transform(), - ignored_classes=LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES) - return self._create_dataset_loader(dataset, batch_size, batch_size_division) + def 
_create_training_dataset_loader(self, dataset_roots, batch_size, batch_size_division): + dataset = _create_dataset(dataset_roots, 'training', create_training_image_transform()) + return self._create_dataset_loader(dataset, batch_size, batch_size_division, + use_imbalanced_face_dataset_sampler=True) - def _create_validation_dataset_loader(self, dataset_root, batch_size, batch_size_division): - dataset = Vggface2Dataset(dataset_root, split='validation', - transforms=create_validation_image_transform(), - ignored_classes=LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES) + def _create_validation_dataset_loader(self, dataset_roots, batch_size, batch_size_division): + dataset = _create_dataset(dataset_roots, 'validation', create_validation_image_transform()) return self._create_dataset_loader(dataset, batch_size, batch_size_division) - def _create_dataset_loader(self, dataset, batch_size, batch_size_division): + def _create_dataset_loader(self, dataset, batch_size, batch_size_division, + use_imbalanced_face_dataset_sampler=False): if self._criterion_type == 'triplet_loss': batch_sampler = TripletLossBatchSampler(dataset, batch_size=batch_size // batch_size_division) - return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0) + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=8) else: - return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, shuffle=True, - num_workers=2) + sampler = ImbalancedFaceDatasetSampler(dataset) if use_imbalanced_face_dataset_sampler else None + return torch.utils.data.DataLoader(dataset, batch_size=batch_size // batch_size_division, + sampler=sampler, + num_workers=8) def _clear_between_training(self): self._learning_curves.clear() @@ -99,11 +91,6 @@ def _measure_training_metrics(self, loss, model_output, target): if self._criterion_type != 'triplet_loss': self._training_accuracy_metric.add(model_output[1], target) - def _validate(self): - super(FaceDescriptorExtractorTrainer, self)._validate() - if self._criterion_type == 'am_softmax_loss': - self._criterion.next_epoch() - def _clear_between_validation_epoch(self): self._validation_loss_metric.clear() if self._criterion_type != 'triplet_loss': @@ -142,24 +129,7 @@ def _evaluate(self, model, device, dataset_loader, output_path): lfw_evaluation.evaluate() if self._criterion_type != 'triplet_loss': - self._evaluate_classification_accuracy(model, device, dataset_loader) - - def _evaluate_classification_accuracy(self, model, device, dataset_loader): - print('Evaluation - Classification') - top1_accuracy_metric = ClassificationAccuracyMetric() - top5_accuracy_metric = TopNClassificationAccuracyMetric(5) - map_metric = ClassificationMeanAveragePrecisionMetric(self._class_count) - - for data in tqdm(dataset_loader): - model_output = model(data[0].to(device)) - target = self._move_target_to_device(data[1], device) - top1_accuracy_metric.add(model_output[1], target) - top5_accuracy_metric.add(model_output[1], target) - map_metric.add(model_output[1], target) - - print('\nTest : Top 1 Accuracy={}, Top 5 Accuracy={}, mAP={}'.format(top1_accuracy_metric.get_accuracy(), - top5_accuracy_metric.get_accuracy(), - map_metric.get_value())) + _evaluate_classification_accuracy(model, device, dataset_loader, self._class_count) def create_training_image_transform(): @@ -183,3 +153,53 @@ def create_validation_image_transform(): transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 
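For context, a sketch (not taken from this diff; paths, the class count and the train() entry point are assumptions) of how the refactored trainers could be driven now that they accept a list of dataset roots and the distillation variant wraps a student/teacher pair:

import torch

from face_recognition.face_descriptor_extractor import (
    FaceDescriptorExtractor, OpenFaceBackbone, EfficientNetBackbone)
from face_recognition.trainers import (
    FaceDescriptorExtractorTrainer, FaceDescriptorExtractorDistillationTrainer)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 'vgg'/'ms' in a root path selects which LFW-overlapped class list is ignored
# (see _create_dataset below); the paths themselves are placeholders.
dataset_roots = ['/datasets/vggface2', '/datasets/msceleb1m']

student = FaceDescriptorExtractor(EfficientNetBackbone('efficientnet_b0'),
                                  embedding_size=128, class_count=10000)  # placeholder class count
teacher = FaceDescriptorExtractor(OpenFaceBackbone(),
                                  embedding_size=128, class_count=10000)

trainer = FaceDescriptorExtractorDistillationTrainer(
    device, student, teacher,
    dataset_roots=dataset_roots,
    lfw_dataset_root='/datasets/lfw',
    output_path='output/face_descriptor_distillation',
    epoch_count=50, criterion_type='am_softmax_loss', batch_size=128, margin=0.2,
    teacher_model_checkpoint='output/teacher/model.pth')
trainer.train()  # assumed entry point on the Trainer/DistillationTrainer base classes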
+ + +def _create_criterion(criterion_type, margin, epoch_count): + if criterion_type == 'triplet_loss': + return TripletLoss(margin=margin) + elif criterion_type == 'cross_entropy_loss': + return FaceDescriptorCrossEntropyLoss() + elif criterion_type == 'am_softmax_loss': + return FaceDescriptorAmSoftmaxLoss(s=30.0, m=margin, + start_annealing_epoch=0, + end_annealing_epoch=epoch_count // 4) + elif criterion_type == 'arc_face_loss': + return FaceDescriptorArcFaceLoss(s=30.0, m=margin, + start_annealing_epoch=0, + end_annealing_epoch=epoch_count // 4) + else: + raise ValueError('Invalid criterion type') + + +def _create_dataset(dataset_roots, split, transforms): + datasets = [] + for dataset_root in dataset_roots: + ignored_classes = [] + if 'vgg' in dataset_root.lower(): + ignored_classes = LFW_OVERLAPPED_VGGFACE2_CLASS_NAMES + elif 'ms' in dataset_root.lower(): + ignored_classes = LFW_OVERLAPPED_MSCELEB1M_CLASS_NAMES + + dataset = FaceDataset(dataset_root, split=split, ignored_classes=ignored_classes) + datasets.append(dataset) + + return FaceConcatDataset(datasets, transforms=transforms) + + +def _evaluate_classification_accuracy(model, device, dataset_loader, class_count): + print('Evaluation - Classification') + top1_accuracy_metric = ClassificationAccuracyMetric() + top5_accuracy_metric = TopNClassificationAccuracyMetric(5) + map_metric = ClassificationMeanAveragePrecisionMetric(class_count) + + for data in tqdm(dataset_loader): + model_output = model(data[0].to(device)) + target = data[1].to(device) + top1_accuracy_metric.add(model_output[1], target) + top5_accuracy_metric.add(model_output[1], target) + map_metric.add(model_output[1], target) + + print('\nTest : Top 1 Accuracy={}, Top 5 Accuracy={}, mAP={}'.format(top1_accuracy_metric.get_accuracy(), + top5_accuracy_metric.get_accuracy(), + map_metric.get_value())) diff --git a/tools/dnn_training/object_detection/datasets/coco_detection_transforms.py b/tools/dnn_training/object_detection/datasets/coco_detection_transforms.py index 9b809b15..0dce10ac 100644 --- a/tools/dnn_training/object_detection/datasets/coco_detection_transforms.py +++ b/tools/dnn_training/object_detection/datasets/coco_detection_transforms.py @@ -63,10 +63,12 @@ def _resize_image(image, size): image = image.resize((int(w * scale), int(h * scale)), Image.BILINEAR) - padded_image = Image.new('RGB', (size[0], size[1]), (128, 128, 128)) - padded_image.paste(image) + offset_x = int((size[0] - image.width) / 2) + offset_y = int((size[1] - image.height) / 2) + padded_image = Image.new('RGB', (size[0], size[1]), (114, 114, 114)) + padded_image.paste(image, (offset_x, offset_y)) - return padded_image, scale + return padded_image, scale, offset_x, offset_y def _hflip_bbox(target, image_size): @@ -122,7 +124,7 @@ def __call__(self, image, target): image, target = _random_crop(image, target) - resized_image, scale = _resize_image(image, self._image_size) + resized_image, scale, offset_x, offset_y = _resize_image(image, self._image_size) target = _convert_bbox_to_yolo(target, scale, self._image_size, self._one_hot_class) if random.random() < self._horizontal_flip_p: @@ -132,7 +134,9 @@ def __call__(self, image, target): resized_image_tensor = F.to_tensor(resized_image) metadata = { - 'scale': scale + 'scale': scale, + 'offset_x': offset_x, + 'offset_y': offset_y } return resized_image_tensor, target, metadata @@ -143,13 +147,15 @@ def __init__(self, image_size, one_hot_class): self._one_hot_class = one_hot_class def __call__(self, image, target): - resized_image, scale = 
_resize_image(image, self._image_size) + resized_image, scale, offset_x, offset_y = _resize_image(image, self._image_size) resized_image_tensor = F.to_tensor(resized_image) if target is not None: target = _convert_bbox_to_yolo(target, scale, self._image_size, self._one_hot_class) metadata = { - 'scale': scale + 'scale': scale, + 'offset_x': offset_x, + 'offset_y': offset_y } return resized_image_tensor, target, metadata diff --git a/tools/dnn_training/object_detection/datasets/object_detection_coco.py b/tools/dnn_training/object_detection/datasets/object_detection_coco.py index c343c48a..0fa32044 100644 --- a/tools/dnn_training/object_detection/datasets/object_detection_coco.py +++ b/tools/dnn_training/object_detection/datasets/object_detection_coco.py @@ -29,7 +29,9 @@ def __getitem__(self, index): 'image_id': image_id, 'initial_width': initial_width, 'initial_height': initial_height, - 'scale': transforms_metadata['scale'] + 'scale': transforms_metadata['scale'], + 'offset_x': transforms_metadata['offset_x'], + 'offset_y': transforms_metadata['offset_y'] } return image, target, metadata diff --git a/tools/dnn_training/object_detection/datasets/open_images_detection_transforms.py b/tools/dnn_training/object_detection/datasets/open_images_detection_transforms.py index 43002147..94cc88c8 100644 --- a/tools/dnn_training/object_detection/datasets/open_images_detection_transforms.py +++ b/tools/dnn_training/object_detection/datasets/open_images_detection_transforms.py @@ -59,13 +59,15 @@ def __init__(self, image_size, one_hot_class, class_count): def __call__(self, image, target): image = self._image_only_transform(image) - resized_image, scale = _resize_image(image, self._image_size) + resized_image, scale, offset_x, offset_y = _resize_image(image, self._image_size) target = _convert_bbox_to_yolo(target, scale, self._image_size, self._one_hot_class, self._class_count) resized_image_tensor = F.to_tensor(resized_image) metadata = { - 'scale': scale + 'scale': scale, + 'offset_x': offset_x, + 'offset_y': offset_y } return resized_image_tensor, target, metadata @@ -77,13 +79,15 @@ def __init__(self, image_size, one_hot_class, class_count): self._class_count = class_count def __call__(self, image, target): - resized_image, scale = _resize_image(image, self._image_size) + resized_image, scale, offset_x, offset_y = _resize_image(image, self._image_size) resized_image_tensor = F.to_tensor(resized_image) if target is not None: target = _convert_bbox_to_yolo(target, scale, self._image_size, self._one_hot_class, self._class_count) metadata = { - 'scale': scale + 'scale': scale, + 'offset_x': offset_x, + 'offset_y': offset_y } return resized_image_tensor, target, metadata diff --git a/tools/dnn_training/object_detection/descriptor_yolo_v4.py b/tools/dnn_training/object_detection/descriptor_yolo_v4.py index e70e0aa4..ad5c3d2d 100644 --- a/tools/dnn_training/object_detection/descriptor_yolo_v4.py +++ b/tools/dnn_training/object_detection/descriptor_yolo_v4.py @@ -5,7 +5,7 @@ from common.modules import Mish -from object_detection.modules.descriptor_yolo_layer import DescriptorYoloLayer +from object_detection.modules.descriptor_yolo_layer import DescriptorYoloV4Layer IMAGE_SIZE = (608, 608) IN_CHANNELS = 3 @@ -535,8 +535,8 @@ def __init__(self, class_count, descriptor_size): ) self._anchors.append(np.array([(12, 16), (19, 36), (40, 28)])) self._output_strides.append(8) - self._yolo139 = DescriptorYoloLayer(IMAGE_SIZE, 8, self._anchors[-1].tolist(), class_count, descriptor_size, - 1.2) + self._yolo139 = 
DescriptorYoloV4Layer(IMAGE_SIZE, 8, self._anchors[-1].tolist(), class_count, descriptor_size, + 1.2) self._conv141 = nn.Sequential( nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False), @@ -579,8 +579,8 @@ def __init__(self, class_count, descriptor_size): ) self._anchors.append(np.array([(36, 75), (76, 55), (72, 146)])) self._output_strides.append(16) - self._yolo150 = DescriptorYoloLayer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, descriptor_size, - 1.1) + self._yolo150 = DescriptorYoloV4Layer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, descriptor_size, + 1.1) self._conv152 = nn.Sequential( nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False), @@ -623,8 +623,8 @@ def __init__(self, class_count, descriptor_size): ) self._anchors.append(np.array([(142, 110), (192, 243), (459, 401)])) self._output_strides.append(32) - self._yolo161 = DescriptorYoloLayer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, descriptor_size, - 1.05) + self._yolo161 = DescriptorYoloV4Layer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, descriptor_size, + 1.05) def get_image_size(self): return IMAGE_SIZE diff --git a/tools/dnn_training/object_detection/descriptor_yolo_v4_tiny.py b/tools/dnn_training/object_detection/descriptor_yolo_v4_tiny.py index 6dcead3d..98bb9e52 100644 --- a/tools/dnn_training/object_detection/descriptor_yolo_v4_tiny.py +++ b/tools/dnn_training/object_detection/descriptor_yolo_v4_tiny.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from object_detection.modules.descriptor_yolo_layer import DescriptorYoloLayer +from object_detection.modules.descriptor_yolo_layer import DescriptorYoloV4Layer IMAGE_SIZE = (416, 416) IN_CHANNELS = 3 @@ -120,8 +120,8 @@ def __init__(self, class_count, descriptor_size): ) self._anchors.append(np.array([(81, 82), (135, 169), (344, 319)])) self._output_strides.append(32) - self._yolo30 = DescriptorYoloLayer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, descriptor_size, - 1.05) + self._yolo30 = DescriptorYoloV4Layer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, descriptor_size, + 1.05) self._conv32 = nn.Sequential( nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False), @@ -140,8 +140,8 @@ def __init__(self, class_count, descriptor_size): ) self._anchors.append(np.array([(23, 27), (37, 58), (81, 82)])) self._output_strides.append(16) - self._yolo37 = DescriptorYoloLayer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, descriptor_size, - 1.05) + self._yolo37 = DescriptorYoloV4Layer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, descriptor_size, + 1.05) def get_image_size(self): return IMAGE_SIZE diff --git a/tools/dnn_training/object_detection/descriptor_yolo_v7.py b/tools/dnn_training/object_detection/descriptor_yolo_v7.py new file mode 100644 index 00000000..555435c2 --- /dev/null +++ b/tools/dnn_training/object_detection/descriptor_yolo_v7.py @@ -0,0 +1,625 @@ +from collections import OrderedDict + +import numpy as np + +import torch +import torch.nn as nn + +from object_detection.modules.descriptor_yolo_layer import DescriptorYoloV7Layer +from object_detection.modules.yolo_v7_modules import YoloV7SPPCSPC, RepConv + + +IMAGE_SIZE = (640, 640) +IN_CHANNELS = 3 + + +# Generated from: yolov7.yaml: +class DescriptorYoloV7(nn.Module): + def __init__(self, class_count=80, embedding_size=128, class_probs=False): + super(DescriptorYoloV7, self).__init__() + + self._anchors = [] + self._output_strides = [8, 16, 32] + self._anchors.append(np.array([(12, 16), (19, 36), (40, 28)])) + 
self._anchors.append(np.array([(36, 75), (76, 55), (72, 146)])) + self._anchors.append(np.array([(142, 110), (192, 243), (459, 401)])) + + self._conv0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.SiLU(), + ) + self._conv1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv3 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv4 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv5 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv6 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv7 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv8 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv9 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + + self._conv11 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._max_pool12 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv13 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv14 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv15 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv17 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv18 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv19 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv20 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + 
nn.SiLU(), + ) + self._conv21 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv22 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv24 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._max_pool25 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv26 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv27 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv28 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv30 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv31 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv32 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv33 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv34 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv35 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv37 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(1024, eps=0.001), + nn.SiLU(), + ) + self._max_pool38 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv39 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv40 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv41 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + + self._conv43 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv44 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + 
self._conv45 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv46 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv47 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv48 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv50 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(1024, eps=0.001), + nn.SiLU(), + ) + self._sppcspc51 = YoloV7SPPCSPC(1024, 512) + self._conv52 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._upsample53 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv54 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv56 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv57 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv58 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv59 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv60 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv61 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv63 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv64 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._upsample65 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv66 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv68 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv69 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + 
nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv70 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv71 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv72 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv73 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + + self._conv75 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._max_pool76 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv77 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv78 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv79 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv81 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv82 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv83 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv84 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv85 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv86 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv88 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._max_pool89 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv90 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv91 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv92 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + 
nn.SiLU(), + ) + + self._conv94 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv95 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv96 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv97 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv98 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv99 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv101 = nn.Sequential( + nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._rep_conv102 = RepConv(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + self._rep_conv103 = RepConv(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + self._rep_conv104 = RepConv(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + + self._yolo0 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=self._anchors[0].shape[0] * (embedding_size + 5), kernel_size=1), + DescriptorYoloV7Layer(IMAGE_SIZE, 8, self._anchors[0], embedding_size) + ) + self._yolo1 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=self._anchors[1].shape[0] * (embedding_size + 5), kernel_size=1), + DescriptorYoloV7Layer(IMAGE_SIZE, 16, self._anchors[1], embedding_size) + ) + self._yolo2 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=self._anchors[2].shape[0] * (embedding_size + 5), kernel_size=1), + DescriptorYoloV7Layer(IMAGE_SIZE, 32, self._anchors[2], embedding_size) + ) + + self._classifier = nn.Linear(embedding_size, class_count, bias=False) + self._class_probs = class_probs + + def get_image_size(self): + return IMAGE_SIZE + + def get_anchors(self): + return self._anchors + + def get_output_strides(self): + return self._output_strides + + def forward(self, x): + y0 = self._conv0(x) + y1 = self._conv1(y0) + y2 = self._conv2(y1) + y3 = self._conv3(y2) + y4 = self._conv4(y3) + y5 = self._conv5(y3) + y6 = self._conv6(y5) + y7 = self._conv7(y6) + y8 = self._conv8(y7) + y9 = self._conv9(y8) + y10 = torch.cat([y9, y7, y5, y4], dim=1) + + y11 = self._conv11(y10) + y12 = self._max_pool12(y11) + y13 = self._conv13(y12) + y14 = self._conv14(y11) + y15 = self._conv15(y14) + y16 = torch.cat([y15, y13], dim=1) + + y17 = self._conv17(y16) + y18 = self._conv18(y16) + y19 = self._conv19(y18) + y20 = self._conv20(y19) + y21 = self._conv21(y20) + y22 = self._conv22(y21) + y23 = torch.cat([y22, y20, y18, y17], dim=1) + + y24 = self._conv24(y23) + y25 = self._max_pool25(y24) + y26 = self._conv26(y25) + y27 = self._conv27(y24) + y28 = self._conv28(y27) + y29 = torch.cat([y28, y26], dim=1) + + y30 = self._conv30(y29) + y31 
= self._conv31(y29) + y32 = self._conv32(y31) + y33 = self._conv33(y32) + y34 = self._conv34(y33) + y35 = self._conv35(y34) + y36 = torch.cat([y35, y33, y31, y30], dim=1) + + y37 = self._conv37(y36) + y38 = self._max_pool38(y37) + y39 = self._conv39(y38) + y40 = self._conv40(y37) + y41 = self._conv41(y40) + y42 = torch.cat([y41, y39], dim=1) + + y43 = self._conv43(y42) + y44 = self._conv44(y42) + y45 = self._conv45(y44) + y46 = self._conv46(y45) + y47 = self._conv47(y46) + y48 = self._conv48(y47) + y49 = torch.cat([y48, y46, y44, y43], dim=1) + + y50 = self._conv50(y49) + y51 = self._sppcspc51(y50) + y52 = self._conv52(y51) + y53 = self._upsample53(y52) + y54 = self._conv54(y37) + y55 = torch.cat([y54, y53], dim=1) + + y56 = self._conv56(y55) + y57 = self._conv57(y55) + y58 = self._conv58(y57) + y59 = self._conv59(y58) + y60 = self._conv60(y59) + y61 = self._conv61(y60) + y62 = torch.cat([y61, y60, y59, y58, y57, y56], dim=1) + + y63 = self._conv63(y62) + y64 = self._conv64(y63) + y65 = self._upsample65(y64) + y66 = self._conv66(y24) + y67 = torch.cat([y66, y65], dim=1) + + y68 = self._conv68(y67) + y69 = self._conv69(y67) + y70 = self._conv70(y69) + y71 = self._conv71(y70) + y72 = self._conv72(y71) + y73 = self._conv73(y72) + y74 = torch.cat([y73, y72, y71, y70, y69, y68], dim=1) + + y75 = self._conv75(y74) + y76 = self._max_pool76(y75) + y77 = self._conv77(y76) + y78 = self._conv78(y75) + y79 = self._conv79(y78) + y80 = torch.cat([y79, y77, y63], dim=1) + + y81 = self._conv81(y80) + y82 = self._conv82(y80) + y83 = self._conv83(y82) + y84 = self._conv84(y83) + y85 = self._conv85(y84) + y86 = self._conv86(y85) + y87 = torch.cat([y86, y85, y84, y83, y82, y81], dim=1) + + y88 = self._conv88(y87) + y89 = self._max_pool89(y88) + y90 = self._conv90(y89) + y91 = self._conv91(y88) + y92 = self._conv92(y91) + y93 = torch.cat([y92, y90, y51], dim=1) + + y94 = self._conv94(y93) + y95 = self._conv95(y93) + y96 = self._conv96(y95) + y97 = self._conv97(y96) + y98 = self._conv98(y97) + y99 = self._conv99(y98) + y100 = torch.cat([y99, y98, y97, y96, y95, y94], dim=1) + + y101 = self._conv101(y100) + y102 = self._rep_conv102(y75) + y103 = self._rep_conv103(y88) + y104 = self._rep_conv104(y101) + + box0, embedding0 = self._yolo0(y102) + box1, embedding1 = self._yolo1(y103) + box2, embedding2 = self._yolo2(y104) + + d0 = self._classify_embeddings(box0, embedding0) + d1 = self._classify_embeddings(box1, embedding1) + d2 = self._classify_embeddings(box2, embedding2) + + return [d0, d1, d2] + + def _classify_embeddings(self, box, embedding): + classes = self._classifier(embedding) + if self._class_probs: + classes = torch.softmax(classes, dim=4) + + return torch.cat([box, classes, embedding], dim=4) + + def load_weights(self, weights_path): + loaded_state_dict = self._filter_static_dict(torch.load(weights_path), 'anchor') + current_state_dict = self._filter_static_dict(self.state_dict(), 'offset') + + for i, (kl, kc) in enumerate(zip(loaded_state_dict.keys(), current_state_dict.keys())): + if current_state_dict[kc].size() != loaded_state_dict[kl].size(): + raise ValueError('Mismatching size.') + current_state_dict[kc] = loaded_state_dict[kl] + + self.load_state_dict(current_state_dict, strict=False) + + def _filter_static_dict(self, state_dict, x): + return OrderedDict([(k, v) for k, v in state_dict.items() if x not in k]) diff --git a/tools/dnn_training/object_detection/filter_yolo_predictions.py b/tools/dnn_training/object_detection/filter_yolo_predictions.py index 06fac7b4..379ae83a 100644 --- 
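A small sketch of how one row of the concatenated DescriptorYoloV7 output splits back into box, confidence, class scores and embedding after group_predictions(); the index constants and their ordering are assumptions inferred from the imports used below in filter_yolo_predictions.py, not values confirmed by this diff:

from object_detection.modules.yolo_layer import CONFIDENCE_INDEX, CLASSES_INDEX

def split_prediction(row, class_count):
    # row: one flattened prediction (last dimension of the grouped output).
    # Layout follows _classify_embeddings: [box..., confidence, class scores, embedding],
    # with confidence assumed to sit just before CLASSES_INDEX.
    box = row[:CONFIDENCE_INDEX]                     # x, y, w, h (assumed order)
    confidence = row[CONFIDENCE_INDEX]
    class_scores = row[CLASSES_INDEX:CLASSES_INDEX + class_count]
    embedding = row[CLASSES_INDEX + class_count:]    # the 128-d descriptor by default
    return box, confidence, class_scores, embedding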
a/tools/dnn_training/object_detection/filter_yolo_predictions.py +++ b/tools/dnn_training/object_detection/filter_yolo_predictions.py @@ -1,7 +1,7 @@ import torch from object_detection.criterions.yolo_v4_loss import calculate_iou -from object_detection.modules.yolo_layer import CONFIDENCE_INDEX +from object_detection.modules.yolo_layer import CONFIDENCE_INDEX, CLASSES_INDEX def group_predictions(predictions): @@ -19,7 +19,36 @@ def group_predictions(predictions): def filter_yolo_predictions(predictions, confidence_threshold=0.7, nms_threshold=0.6): predictions = predictions[predictions[:, CONFIDENCE_INDEX] > confidence_threshold] + return _nms(predictions, nms_threshold) + +def filter_yolo_predictions_by_classes(predictions, confidence_threshold=0.7, nms_threshold=0.6): + predictions = predictions[predictions[:, CONFIDENCE_INDEX] > confidence_threshold] + predictions_by_class_index = _group_predictions_by_class_index(predictions) + + valid_predictions = [] + for c, p in predictions_by_class_index.items(): + valid_predictions += _nms(p, nms_threshold) + return valid_predictions + + +def _group_predictions_by_class_index(predictions): + class_count = predictions[:, CLASSES_INDEX:].size(1) + class_indexes = torch.argmax(predictions[:, CLASSES_INDEX:], dim=1).tolist() + + predictions_by_class_index = {c: [] for c in range(class_count)} + for i, c in enumerate(class_indexes): + predictions_by_class_index[c].append(predictions[i:i + 1]) + + tensor_predictions_by_class_index = {} + for c in predictions_by_class_index.keys(): + if len(predictions_by_class_index[c]) > 0: + tensor_predictions_by_class_index[c] = torch.cat(predictions_by_class_index[c], dim=0) + + return tensor_predictions_by_class_index + + +def _nms(predictions, nms_threshold): sorted_index = torch.argsort(predictions[:, CONFIDENCE_INDEX], descending=True) predictions = predictions[sorted_index] diff --git a/tools/dnn_training/object_detection/metrics/coco_object_evaluation.py b/tools/dnn_training/object_detection/metrics/coco_object_evaluation.py index c672c414..438a6c08 100644 --- a/tools/dnn_training/object_detection/metrics/coco_object_evaluation.py +++ b/tools/dnn_training/object_detection/metrics/coco_object_evaluation.py @@ -7,11 +7,11 @@ from object_detection.datasets.coco_detection_transforms import CLASS_INDEX_TO_CATEGORY_ID_MAPPING from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CONFIDENCE_INDEX, CLASSES_INDEX -from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions +from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions_by_classes class CocoObjectEvaluation: - def __init__(self, model, device, dataset_loader, output_path, confidence_threshold=0.005, nms_threshold=0.45): + def __init__(self, model, device, dataset_loader, output_path, confidence_threshold=0.001, nms_threshold=0.65): self._model = model self._device = device self._dataset_loader = dataset_loader @@ -29,47 +29,46 @@ def evaluate(self): return self._dataset.evaluate(self._results_file_path) def _get_results(self): - results = [] - for image, target, metadata in tqdm(self._dataset_loader): - predictions = self._model(image.to(self._device)) - predictions = group_predictions(predictions) - - for n in range(image.size()[0]): - image_id = metadata['image_id'][n].item() - scale = metadata['scale'][n].item() - results.extend(self._get_result(image_id, scale, predictions[n])) - - return results - - def _get_result(self, image_id, scale, 
predictions): - predictions = filter_yolo_predictions(predictions, self._confidence_threshold, self._nms_threshold) + with torch.no_grad(): + results = [] + for image, target, metadata in tqdm(self._dataset_loader): + predictions = self._model(image.to(self._device)) + predictions = group_predictions(predictions) + + for n in range(image.size()[0]): + image_id = metadata['image_id'][n].item() + scale = metadata['scale'][n].item() + offset_x = metadata['offset_x'][n].item() + offset_y = metadata['offset_y'][n].item() + results.extend(self._get_result(image_id, scale, offset_x, offset_y, predictions[n])) + + return results + + def _get_result(self, image_id, scale, offset_x, offset_y, predictions): + predictions = filter_yolo_predictions_by_classes(predictions, self._confidence_threshold, self._nms_threshold) if len(predictions) == 0: return [] - predictions = torch.stack(predictions) - - sorted_index = torch.argsort(predictions[:, CONFIDENCE_INDEX], descending=True) - sorted_predictions = predictions[sorted_index] - result = [] - for i in range(len(sorted_predictions)): - center_x = sorted_predictions[i][X_INDEX].item() / scale - center_y = sorted_predictions[i][Y_INDEX].item() / scale - w = sorted_predictions[i][W_INDEX].item() / scale - h = sorted_predictions[i][H_INDEX].item() / scale + for i in range(len(predictions)): + center_x = (predictions[i][X_INDEX].item() - offset_x) / scale + center_y = (predictions[i][Y_INDEX].item() - offset_y) / scale + w = predictions[i][W_INDEX].item() / scale + h = predictions[i][H_INDEX].item() / scale x = center_x - w / 2 y = center_y - h / 2 - class_index = torch.argmax(sorted_predictions[i][CLASSES_INDEX:CLASSES_INDEX + self._class_count]).item() + class_probs = torch.sigmoid(predictions[i][CLASSES_INDEX:CLASSES_INDEX + self._class_count]) + class_index = torch.argmax(class_probs).item() category_id = CLASS_INDEX_TO_CATEGORY_ID_MAPPING[class_index] result.append({ 'image_id': image_id, 'category_id': category_id, - 'bbox': [round(x), round(y), round(w), round(h)], - 'score': sorted_predictions[i][CONFIDENCE_INDEX].item(), + 'bbox': [round(x, 3), round(y, 3), round(w, 3), round(h, 3)], + 'score': predictions[i][CONFIDENCE_INDEX].item() * class_probs[class_index].item(), 'segmentation': [] }) diff --git a/tools/dnn_training/object_detection/modules/__init__.py b/tools/dnn_training/object_detection/modules/__init__.py index 4c02214a..e08e2721 100644 --- a/tools/dnn_training/object_detection/modules/__init__.py +++ b/tools/dnn_training/object_detection/modules/__init__.py @@ -1,4 +1,4 @@ -from object_detection.modules.yolo_layer import YoloLayer +from object_detection.modules.yolo_layer import YoloV4Layer, YoloV7Layer from object_detection.modules.yolo_v4 import YoloV4 from object_detection.modules.yolo_v4_tiny import YoloV4Tiny diff --git a/tools/dnn_training/object_detection/modules/convert_darknet_cfg_to_pytorch_module.py b/tools/dnn_training/object_detection/modules/convert_darknet_cfg_to_pytorch_module.py index 852219e9..ad4b99bb 100644 --- a/tools/dnn_training/object_detection/modules/convert_darknet_cfg_to_pytorch_module.py +++ b/tools/dnn_training/object_detection/modules/convert_darknet_cfg_to_pytorch_module.py @@ -14,370 +14,384 @@ } -class DarknetCfgToPytorchModuleConverter: - def convert(self, cfg_path, python_output_path, class_name): - layers = self._read_layers(cfg_path) - with open(python_output_path, 'w') as python_file: - in_channels = self._write_header(python_file, layers[0], class_name, cfg_path) - self._write_init(python_file, 
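The evaluation above undoes the new centred letterbox padding before reporting COCO boxes; a small worked sketch of that coordinate bookkeeping (the numbers are illustrative only, not taken from this diff):

def to_original_coordinates(center_x, center_y, w, h, scale, offset_x, offset_y):
    # _resize_image scales the source image by `scale` and pastes it centred at
    # (offset_x, offset_y), so positions are shifted back then unscaled, while
    # sizes are only unscaled.
    cx = (center_x - offset_x) / scale
    cy = (center_y - offset_y) / scale
    ow, oh = w / scale, h / scale
    # COCO bboxes are [top-left x, top-left y, width, height].
    return cx - ow / 2, cy - oh / 2, ow, oh

# Example with made-up numbers: a 1280x720 image letterboxed into 608x608 gives
# scale = 608 / 1280 = 0.475, offset_x = 0 and offset_y = (608 - 342) / 2 = 133,
# so a predicted centre y of 400 maps back to (400 - 133) / 0.475 ≈ 562 in the
# original image.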
layers[1:], in_channels, class_name) - self._write_getters(python_file) - self._write_forward(python_file, layers[1:]) - self._write_load_weights(python_file, layers[1:]) - self._write_load_batch_norm_conv_weights(python_file) - self._write_load_conv_weights(python_file) - - def _read_layers(self, cfg_path): - with open(cfg_path, 'r') as cfg_file: - lines = cfg_file.readlines() - - layer = {} - layers = [] - - for i, line in enumerate(lines): - line = line.strip() - if len(line) == 0 or line.startswith('#'): - continue - - if line.startswith('[') and line.endswith(']'): - if len(layer) != 0: - layers.append(layer) - layer = {} - type = line[1:-1] - if type in SUPPORTED_LAYER_TYPES: - layer['type'] = type - layer['line_no'] = i + 1 - else: - raise ValueError('Invalid layer type (line={})'.format(i + 1)) - elif line.startswith('[') or line.endswith(']'): - raise ValueError('Invalid cfg file (line={})'.format(i + 1)) - elif line.count('=') == 1: - key, value = line.split('=') - layer[key.strip()] = value.strip() +def convert(cfg_path, python_output_path, class_name): + layers = _read_layers(cfg_path) + with open(python_output_path, 'w') as python_file: + in_channels = _write_header(python_file, layers[0], class_name, cfg_path) + _write_init(python_file, layers[1:], in_channels, class_name) + _write_getters(python_file) + _write_forward(python_file, layers[1:]) + _write_load_weights(python_file, layers[1:]) + _write_load_batch_norm_conv_weights(python_file) + _write_load_conv_weights(python_file) + + +def _read_layers(cfg_path): + with open(cfg_path, 'r') as cfg_file: + lines = cfg_file.readlines() + + layer = {} + layers = [] + + for i, line in enumerate(lines): + line = line.strip() + if len(line) == 0 or line.startswith('#'): + continue + + if line.startswith('[') and line.endswith(']'): + if len(layer) != 0: + layers.append(layer) + layer = {} + type = line[1:-1] + if type in SUPPORTED_LAYER_TYPES: + layer['type'] = type + layer['line_no'] = i + 1 else: - raise ValueError('Invalid cfg file (line={})'.format(i + 1)) - - if len(layer) != 0: - layers.append(layer) - return layers - - def _write_header(self, python_file, layer, class_name, cfg_path): - if layer['type'] != 'net': - raise ValueError('The type of the first cfg block must be "net" (line={}).'.format(layer['line_no'])) - if 'height' not in layer: - raise ValueError('A "net" block must contain the "height" attribute (line={}).'.format(layer['line_no'])) - if 'width' not in layer: - raise ValueError('A "net" block must contain the "width" attribute (line={}).'.format(layer['line_no'])) - if 'channels' not in layer: - raise ValueError('A "net" block must contain the "channels" attribute (line={}).'.format(layer['line_no'])) - - python_file.write('import numpy as np\n') - python_file.write('\n') - python_file.write('import torch\n') - python_file.write('import torch.nn as nn\n') - python_file.write('\n') - python_file.write('from common.modules import Mish, Swish\n') - python_file.write('\n') - python_file.write('from object_detection.modules.yolo_layer import YoloLayer\n') - python_file.write('\n') - python_file.write('\n') - python_file.write('IMAGE_SIZE = ({}, {})\n'.format(layer['height'], layer['width'])) - python_file.write('IN_CHANNELS = {}\n'.format(layer['channels'])) - python_file.write('\n') - python_file.write('\n') - python_file.write('# Genereated from: {}:\n'.format(os.path.basename(cfg_path))) - python_file.write('class {}(nn.Module):\n'.format(class_name)) - - return layer['channels'] - - def _write_init(self, 
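Since the converter above is now a plain module-level function rather than a class, a short invocation sketch (the cfg and output file names are placeholders; the import assumes tools/dnn_training is the package root):

from object_detection.modules.convert_darknet_cfg_to_pytorch_module import convert

# Regenerate a PyTorch module definition from a Darknet cfg file.
convert('yolov4.cfg', 'generated_yolo_v4.py', 'YoloV4')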
python_file, layers, in_channels, class_name): - python_file.write(' def __init__(self):\n') - python_file.write(' super({}, self).__init__()\n'.format(class_name)) - python_file.write(' self._anchors = []\n') - python_file.write(' self._output_strides = []\n') - - cumulated_strides = [1] - in_channels = [in_channels] - for i, layer in enumerate(layers): - if layer['type'] == 'convolutional': - out_channels, stride = self._write_init_convolutional(python_file, i, layer, in_channels[-1]) - in_channels.append(out_channels) - cumulated_strides.append(cumulated_strides[-1] * stride) - elif layer['type'] == 'upsample': - out_channels, stride = self._write_init_upsample(python_file, i, layer, in_channels[-1]) - in_channels.append(out_channels) - cumulated_strides.append(cumulated_strides[-1] // stride) - elif layer['type'] == 'maxpool': - out_channels, stride = self._write_init_maxpool(python_file, i, layer, in_channels[-1]) - in_channels.append(out_channels) - cumulated_strides.append(cumulated_strides[-1] * stride) - elif layer['type'] == 'yolo': - in_channels.append(self._write_init_yolo(python_file, i, layer, cumulated_strides[-1])) - cumulated_strides.append(0) - elif layer['type'] == 'route': - out_channels, stride = self._write_init_route(layer, in_channels, cumulated_strides) - in_channels.append(out_channels) - cumulated_strides.append(stride) - elif layer['type'] == 'shortcut': - out_channels, stride = self._write_init_shortcut(layer, in_channels, cumulated_strides) - in_channels.append(out_channels) - cumulated_strides.append(stride) - else: - raise ValueError('Not supported layer (type={})'.format(layer['type'])) - python_file.write('\n') + raise ValueError('Invalid layer type (line={})'.format(i + 1)) + elif line.startswith('[') or line.endswith(']'): + raise ValueError('Invalid cfg file (line={})'.format(i + 1)) + elif line.count('=') == 1: + key, value = line.split('=') + layer[key.strip()] = value.strip() + else: + raise ValueError('Invalid cfg file (line={})'.format(i + 1)) + + if len(layer) != 0: + layers.append(layer) + return layers + + +def _write_header(python_file, layer, class_name, cfg_path): + if layer['type'] != 'net': + raise ValueError('The type of the first cfg block must be "net" (line={}).'.format(layer['line_no'])) + if 'height' not in layer: + raise ValueError('A "net" block must contain the "height" attribute (line={}).'.format(layer['line_no'])) + if 'width' not in layer: + raise ValueError('A "net" block must contain the "width" attribute (line={}).'.format(layer['line_no'])) + if 'channels' not in layer: + raise ValueError('A "net" block must contain the "channels" attribute (line={}).'.format(layer['line_no'])) + + python_file.write('import numpy as np\n') + python_file.write('\n') + python_file.write('import torch\n') + python_file.write('import torch.nn as nn\n') + python_file.write('\n') + python_file.write('from common.modules import Mish, Swish\n') + python_file.write('\n') + python_file.write('from object_detection.modules.yolo_layer import YoloV4Layer\n') + python_file.write('\n') + python_file.write('\n') + python_file.write('IMAGE_SIZE = ({}, {})\n'.format(layer['height'], layer['width'])) + python_file.write('IN_CHANNELS = {}\n'.format(layer['channels'])) + python_file.write('\n') + python_file.write('\n') + python_file.write('# Generated from: {}:\n'.format(os.path.basename(cfg_path))) + python_file.write('class {}(nn.Module):\n'.format(class_name)) + + return layer['channels'] + + +def _write_init(python_file, layers, in_channels, class_name): + 
python_file.write(' def __init__(self, class_probs=False):\n') + python_file.write(' super({}, self).__init__()\n'.format(class_name)) + python_file.write(' self._anchors = []\n') + python_file.write(' self._output_strides = []\n') + + cumulated_strides = [1] + in_channels = [in_channels] + for i, layer in enumerate(layers): + if layer['type'] == 'convolutional': + out_channels, stride = _write_init_convolutional(python_file, i, layer, in_channels[-1]) + in_channels.append(out_channels) + cumulated_strides.append(cumulated_strides[-1] * stride) + elif layer['type'] == 'upsample': + out_channels, stride = _write_init_upsample(python_file, i, layer, in_channels[-1]) + in_channels.append(out_channels) + cumulated_strides.append(cumulated_strides[-1] // stride) + elif layer['type'] == 'maxpool': + out_channels, stride = _write_init_maxpool(python_file, i, layer, in_channels[-1]) + in_channels.append(out_channels) + cumulated_strides.append(cumulated_strides[-1] * stride) + elif layer['type'] == 'yolo': + in_channels.append(_write_init_yolo(python_file, i, layer, cumulated_strides[-1])) + cumulated_strides.append(0) + elif layer['type'] == 'route': + out_channels, stride = _write_init_route(layer, in_channels, cumulated_strides) + in_channels.append(out_channels) + cumulated_strides.append(stride) + elif layer['type'] == 'shortcut': + out_channels, stride = _write_init_shortcut(layer, in_channels, cumulated_strides) + in_channels.append(out_channels) + cumulated_strides.append(stride) + else: + raise ValueError('Not supported layer (type={})'.format(layer['type'])) python_file.write('\n') + python_file.write('\n') - def _write_init_convolutional(self, python_file, i, layer, in_channels): - not_supported_options = ['stride_x', 'dilation', 'antialiasing', 'padding', - 'binary', 'xnor', 'bin_output', 'sway', 'rotate', 'stretch', 'stretch_sway', - 'flipped', 'dot', 'angle', 'grad_centr', 'reverse', 'coordconv', - 'stream', 'wait_stream'] - mandatory_options = ['activation', 'filters', 'size', 'stride'] - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) +def _write_init_convolutional(python_file, i, layer, in_channels): + not_supported_options = ['stride_x', 'dilation', 'antialiasing', 'padding', + 'binary', 'xnor', 'bin_output', 'sway', 'rotate', 'stretch', 'stretch_sway', + 'flipped', 'dot', 'angle', 'grad_centr', 'reverse', 'coordconv', + 'stream', 'wait_stream'] + mandatory_options = ['activation', 'filters', 'size', 'stride'] - activation = layer['activation'] - out_channels = int(layer['filters']) - kernel_size = int(layer['size']) - stride = int(layer['stride']) - padding = 0 - if 'pad' in layer and int(layer['pad']) == 1: - padding = kernel_size // 2 - has_batch_norm = 'batch_normalize' in layer and int(layer['batch_normalize']) == 1 - - groups = 1 if 'groups' not in layer else int(layer['groups']) - - python_file.write(' self._conv{} = nn.Sequential(\n'.format(i)) - python_file.write(' nn.Conv2d({}, {}, {}, stride={}, padding={}, bias={}, groups={}),\n' - .format(in_channels, out_channels, kernel_size, stride, padding, not has_batch_norm, groups)) - if has_batch_norm: - python_file.write(' nn.BatchNorm2d({}),\n'.format(out_channels)) - if activation != 'linear': - python_file.write(' {}\n'.format(ACTIVATION_MODULES_BY_NAME[activation])) - python_file.write(' )') - - return out_channels, stride - - def _write_init_upsample(self, python_file, i, layer, in_channels): - not_supported_options = ['scale'] - mandatory_options = 
['stride'] - - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) - - stride = int(layer['stride']) - python_file.write(' self._upsample{} = nn.Upsample(scale_factor={}, mode=\'nearest\')'.format(i, stride)) - return in_channels, stride - - def _write_init_maxpool(self, python_file, i, layer, in_channels): - not_supported_options = ['stride_x', 'stride_y', 'padding', 'maxpool_depth', 'out_channels', 'antialiasing'] - mandatory_options = ['size', 'stride'] - - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) - - kernel_size = int(layer['size']) - stride = int(layer['stride']) - - padding = 0 - if stride == 1: - padding = kernel_size // 2 - python_file.write(' self._max_pool{} = nn.MaxPool2d({}, stride={}, padding={})' - .format(i, kernel_size, stride, padding)) - - return in_channels, stride - - def _write_init_route(self, layer, in_channels, cumulated_strides): - not_supported_options = ['stream', 'wait_stream'] - mandatory_options = ['layers'] - - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) - - if 'layers' not in layer: - raise ValueError('A "route" block must contain the "layers" attribute (line={}).'.format(layer['line_no'])) - layers = [int(l.strip()) for l in layer['layers'].split(',')] - out_channels = 0 - for l in layers: - out_channels += in_channels[l] - - if 'groups' in layer and 'group_id' in layer: - out_channels //= int(layer['groups']) - - return out_channels, cumulated_strides[layers[0]] - - def _write_init_shortcut(self, layer, in_channels, cumulated_strides): - not_supported_options = ['weights_type', 'weights_normalization'] - mandatory_options = ['from', 'activation'] - - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) - - if layer['activation'] != 'linear': - raise ValueError('Not supported activation (line={}).'.format(layer['line_no'])) - - from_ = int(layer['from']) - if in_channels[-1] != in_channels[from_]: - raise ValueError('The channel counts must be equal (line={}).'.format(layer['line_no'])) - - return in_channels[-1], cumulated_strides[-1] - - def _write_init_yolo(self, python_file, i, layer, cumulated_stride): - not_supported_options = ['show_details', 'counters_per_class', 'label_smooth_eps', - 'objectness_smooth', 'new_coords', 'focal_loss', - 'track_history_size', 'sim_thresh', 'dets_for_track', 'dets_for_show', - 'track_ciou_norm', 'embedding_layer'] - mandatory_options = ['mask', 'anchors', 'classes'] - - _check_not_supported_options(layer, not_supported_options) - _check_mandatory_options(layer, mandatory_options) - - masks = [int(index.strip()) for index in layer['mask'].split(',')] - anchors = [int(size.strip()) for size in layer['anchors'].split(',')] - anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] - anchors = [anchors[i] for i in masks] - class_count = int(layer['classes']) - scale_x_y = 1 if 'scale_x_y' not in layer else layer['scale_x_y'] - - python_file.write(' self._anchors.append(np.array({}))\n'.format(anchors)) - python_file.write(' self._output_strides.append({})\n'.format(cumulated_stride)) - python_file.write(' self._yolo{} = YoloLayer({}, {}, {}, {}, {})' - .format(i, 'IMAGE_SIZE', cumulated_stride, 'self._anchors[-1].tolist()', class_count, - scale_x_y)) - return 0 - - def _write_getters(self, python_file): - python_file.write(' def 
get_image_size(self):\n') - python_file.write(' return IMAGE_SIZE\n') - python_file.write('\n') - python_file.write(' def get_anchors(self):\n') - python_file.write(' return self._anchors\n') - python_file.write('\n') - python_file.write(' def get_output_strides(self):\n') - python_file.write(' return self._output_strides\n\n') - - def _write_forward(self, python_file, layers): - python_file.write(' def forward(self, x):\n') - output_names = [] - yolo_outputs = [] - for i, layer in enumerate(layers): - input_name = 'x' if len(output_names) == 0 else output_names[-1] - output_name = 'y{}'.format(i) - if layer['type'] == 'convolutional': - python_file.write(' {} = self._conv{}({})'.format(output_name, i, input_name)) - elif layer['type'] == 'upsample': - python_file.write(' {} = self._upsample{}({})'.format(output_name, i, input_name)) - elif layer['type'] == 'maxpool': - python_file.write(' {} = self._max_pool{}({})'.format(output_name, i, input_name)) - elif layer['type'] == 'yolo': - python_file.write(' {} = self._yolo{}({})\n'.format(output_name, i, input_name)) - yolo_outputs.append(output_name) - elif layer['type'] == 'route': - output_name = self._write_forward_route(python_file, i, layer, output_names, output_name) - elif layer['type'] == 'shortcut': - from_ = int(layer['from']) - python_file.write(' {} = {} + {}\n'.format(output_name, input_name, output_names[from_])) - else: - raise ValueError('Not supported layer (type={})'.format(layer['type'])) - output_names.append(output_name) - python_file.write('\n') - - python_file.write(' return {}\n\n'.format(str(yolo_outputs).replace('\'', ''))) - - def _write_forward_route(self, python_file, i, layer, output_names, output_name): - routes = [int(l.strip()) for l in layer['layers'].split(',')] - route_outputs = [output_names[i] for i in routes] - - if len(route_outputs) == 1 and 'groups' in layer and 'group_id' in layer: - groups = int(layer['groups']) - group_id = int(layer['group_id']) - - python_file.write(' C = {}.size()[1]\n'.format(route_outputs[0])) - python_file.write(' {} = {}[:, C // {} * {}:C // {} * ({} + 1), :, :]\n' - .format(output_name, route_outputs[0], groups, group_id, groups, group_id)) - elif len(route_outputs) == 1: - output_name = route_outputs[0] - elif 'groups' not in layer and 'group_id' not in layer: - route_outputs = str(route_outputs).replace('\'', '') - python_file.write(' {} = torch.cat({}, dim=1)\n'.format(output_name, route_outputs)) + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + activation = layer['activation'] + out_channels = int(layer['filters']) + kernel_size = int(layer['size']) + stride = int(layer['stride']) + padding = 0 + if 'pad' in layer and int(layer['pad']) == 1: + padding = kernel_size // 2 + has_batch_norm = 'batch_normalize' in layer and int(layer['batch_normalize']) == 1 + + groups = 1 if 'groups' not in layer else int(layer['groups']) + + python_file.write(' self._conv{} = nn.Sequential(\n'.format(i)) + python_file.write(' nn.Conv2d({}, {}, {}, stride={}, padding={}, bias={}, groups={}),\n' + .format(in_channels, out_channels, kernel_size, stride, padding, not has_batch_norm, groups)) + if has_batch_norm: + python_file.write(' nn.BatchNorm2d({}),\n'.format(out_channels)) + if activation != 'linear': + python_file.write(' {}\n'.format(ACTIVATION_MODULES_BY_NAME[activation])) + python_file.write(' )') + + return out_channels, stride + + +def _write_init_upsample(python_file, i, layer, in_channels): + 
not_supported_options = ['scale'] + mandatory_options = ['stride'] + + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + stride = int(layer['stride']) + python_file.write(' self._upsample{} = nn.Upsample(scale_factor={}, mode=\'nearest\')'.format(i, stride)) + return in_channels, stride + + +def _write_init_maxpool(python_file, i, layer, in_channels): + not_supported_options = ['stride_x', 'stride_y', 'padding', 'maxpool_depth', 'out_channels', 'antialiasing'] + mandatory_options = ['size', 'stride'] + + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + kernel_size = int(layer['size']) + stride = int(layer['stride']) + + padding = 0 + if stride == 1: + padding = kernel_size // 2 + python_file.write(' self._max_pool{} = nn.MaxPool2d({}, stride={}, padding={})' + .format(i, kernel_size, stride, padding)) + + return in_channels, stride + + +def _write_init_route(layer, in_channels, cumulated_strides): + not_supported_options = ['stream', 'wait_stream'] + mandatory_options = ['layers'] + + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + if 'layers' not in layer: + raise ValueError('A "route" block must contain the "layers" attribute (line={}).'.format(layer['line_no'])) + layers = [int(l.strip()) for l in layer['layers'].split(',')] + out_channels = 0 + for l in layers: + out_channels += in_channels[l] + + if 'groups' in layer and 'group_id' in layer: + out_channels //= int(layer['groups']) + + return out_channels, cumulated_strides[layers[0]] + + +def _write_init_shortcut(layer, in_channels, cumulated_strides): + not_supported_options = ['weights_type', 'weights_normalization'] + mandatory_options = ['from', 'activation'] + + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + if layer['activation'] != 'linear': + raise ValueError('Not supported activation (line={}).'.format(layer['line_no'])) + + from_ = int(layer['from']) + if in_channels[-1] != in_channels[from_]: + raise ValueError('The channel counts must be equal (line={}).'.format(layer['line_no'])) + + return in_channels[-1], cumulated_strides[-1] + + +def _write_init_yolo(python_file, i, layer, cumulated_stride): + not_supported_options = ['show_details', 'counters_per_class', 'label_smooth_eps', + 'objectness_smooth', 'new_coords', 'focal_loss', + 'track_history_size', 'sim_thresh', 'dets_for_track', 'dets_for_show', + 'track_ciou_norm', 'embedding_layer'] + mandatory_options = ['mask', 'anchors', 'classes'] + + _check_not_supported_options(layer, not_supported_options) + _check_mandatory_options(layer, mandatory_options) + + masks = [int(index.strip()) for index in layer['mask'].split(',')] + anchors = [int(size.strip()) for size in layer['anchors'].split(',')] + anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] + anchors = [anchors[i] for i in masks] + class_count = int(layer['classes']) + scale_x_y = 1 if 'scale_x_y' not in layer else layer['scale_x_y'] + + python_file.write(' self._anchors.append(np.array({}))\n'.format(anchors)) + python_file.write(' self._output_strides.append({})\n'.format(cumulated_stride)) + python_file.write(' self._yolo{} = YoloV4Layer({}, {}, {}, {}, scale_x_y={}, class_probs=class_probs)' + .format(i, 'IMAGE_SIZE', cumulated_stride, 'self._anchors[-1].tolist()', class_count, + scale_x_y)) + return 0 + + +def 
_write_getters(python_file): + python_file.write(' def get_image_size(self):\n') + python_file.write(' return IMAGE_SIZE\n') + python_file.write('\n') + python_file.write(' def get_anchors(self):\n') + python_file.write(' return self._anchors\n') + python_file.write('\n') + python_file.write(' def get_output_strides(self):\n') + python_file.write(' return self._output_strides\n\n') + + +def _write_forward(python_file, layers): + python_file.write(' def forward(self, x):\n') + output_names = [] + yolo_outputs = [] + for i, layer in enumerate(layers): + input_name = 'x' if len(output_names) == 0 else output_names[-1] + output_name = 'y{}'.format(i) + if layer['type'] == 'convolutional': + python_file.write(' {} = self._conv{}({})'.format(output_name, i, input_name)) + elif layer['type'] == 'upsample': + python_file.write(' {} = self._upsample{}({})'.format(output_name, i, input_name)) + elif layer['type'] == 'maxpool': + python_file.write(' {} = self._max_pool{}({})'.format(output_name, i, input_name)) + elif layer['type'] == 'yolo': + python_file.write(' {} = self._yolo{}({})\n'.format(output_name, i, input_name)) + yolo_outputs.append(output_name) + elif layer['type'] == 'route': + output_name = _write_forward_route(python_file, i, layer, output_names, output_name) + elif layer['type'] == 'shortcut': + from_ = int(layer['from']) + python_file.write(' {} = {} + {}\n'.format(output_name, input_name, output_names[from_])) else: - raise ValueError('Invalid route (line={})'.format(layer['line_no'])) - return output_name - - def _write_load_weights(self, python_file, layers): - python_file.write(' def load_weights(self, weights_file_path):\n') - python_file.write(' with open(weights_file_path, \'r\') as weights_file:\n') - python_file.write(' header1 = np.fromfile(weights_file, dtype=np.int32, count=3)\n') - python_file.write(' header2 = np.fromfile(weights_file, dtype=np.int64, count=1)\n') - python_file.write(' weights = np.fromfile(weights_file, dtype=np.float32)\n') - python_file.write('\n') - python_file.write(' print(\'load_weights - Major version:\', header1[0])\n') - python_file.write(' print(\'load_weights - Minor version:\', header1[1])\n') - python_file.write(' print(\'load_weights - Subversion:\', header1[2])\n') - python_file.write(' print(\'load_weights - # images:\', header2[0])\n') + raise ValueError('Not supported layer (type={})'.format(layer['type'])) + output_names.append(output_name) python_file.write('\n') - python_file.write(' offset = 0\n') - - for i, layer in enumerate(layers): - if layer['type'] == 'convolutional': - if 'batch_normalize' in layer and int(layer['batch_normalize']) == 1: - python_file.write(' offset = self._load_batch_norm_conv_weights(' - 'self._conv{}, weights, offset)\n'.format(i)) - else: - python_file.write(' offset = self._load_conv_weights(' - 'self._conv{}, weights, offset)\n'.format(i)) - - elif layer['type'] == 'upsample' or layer['type'] == 'maxpool' or layer['type'] == 'yolo' or \ - layer['type'] == 'route' or layer['type'] == 'shortcut': - continue + + python_file.write(' return {}\n\n'.format(str(yolo_outputs).replace('\'', ''))) + + +def _write_forward_route(python_file, i, layer, output_names, output_name): + routes = [int(l.strip()) for l in layer['layers'].split(',')] + route_outputs = [output_names[i] for i in routes] + + if len(route_outputs) == 1 and 'groups' in layer and 'group_id' in layer: + groups = int(layer['groups']) + group_id = int(layer['group_id']) + + python_file.write(' C = {}.size()[1]\n'.format(route_outputs[0])) + 
python_file.write(' {} = {}[:, C // {} * {}:C // {} * ({} + 1), :, :]\n' + .format(output_name, route_outputs[0], groups, group_id, groups, group_id)) + elif len(route_outputs) == 1: + output_name = route_outputs[0] + elif 'groups' not in layer and 'group_id' not in layer: + route_outputs = str(route_outputs).replace('\'', '') + python_file.write(' {} = torch.cat({}, dim=1)\n'.format(output_name, route_outputs)) + else: + raise ValueError('Invalid route (line={})'.format(layer['line_no'])) + return output_name + + +def _write_load_weights(python_file, layers): + python_file.write(' def load_weights(self, weights_file_path):\n') + python_file.write(' with open(weights_file_path, \'r\') as weights_file:\n') + python_file.write(' header1 = np.fromfile(weights_file, dtype=np.int32, count=3)\n') + python_file.write(' header2 = np.fromfile(weights_file, dtype=np.int64, count=1)\n') + python_file.write(' weights = np.fromfile(weights_file, dtype=np.float32)\n') + python_file.write('\n') + python_file.write(' print(\'load_weights - Major version:\', header1[0])\n') + python_file.write(' print(\'load_weights - Minor version:\', header1[1])\n') + python_file.write(' print(\'load_weights - Subversion:\', header1[2])\n') + python_file.write(' print(\'load_weights - # images:\', header2[0])\n') + python_file.write('\n') + python_file.write(' offset = 0\n') + + for i, layer in enumerate(layers): + if layer['type'] == 'convolutional': + if 'batch_normalize' in layer and int(layer['batch_normalize']) == 1: + python_file.write(' offset = self._load_batch_norm_conv_weights(' + 'self._conv{}, weights, offset)\n'.format(i)) else: - raise ValueError('Not supported layer (type={})'.format(layer['type'])) + python_file.write(' offset = self._load_conv_weights(' + 'self._conv{}, weights, offset)\n'.format(i)) - python_file.write('\n') - python_file.write(' if offset != weights.size:\n') - python_file.write(' raise ValueError(\'Invalid weights file.\')\n\n') - - def _write_load_batch_norm_conv_weights(self, python_file): - python_file.write(' def _load_batch_norm_conv_weights(self, conv, weights, offset):\n') - python_file.write(' n = conv[1].bias.numel()\n') - python_file.write(' bias_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[1].bias.data.copy_(bias_data.view_as(conv[1].bias.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[1].weight.data.copy_(weight_data.view_as(conv[1].weight.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' running_mean_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[1].running_mean.data.copy_(running_mean_data.view_as(' - 'conv[1].running_mean.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' running_var_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[1].running_var.data.copy_(running_var_data.view_as(' - 'conv[1].running_var.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' n = conv[0].weight.numel()\n') - python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[0].weight.data.copy_(weight_data.view_as(conv[0].weight.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' return offset\n\n') - - def 
_write_load_conv_weights(self, python_file): - python_file.write(' def _load_conv_weights(self, conv, weights, offset):\n') - python_file.write(' n = conv[0].bias.numel()\n') - python_file.write(' bias_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[0].bias.data.copy_(bias_data.view_as(conv[0].bias.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' n = conv[0].weight.numel()\n') - python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') - python_file.write(' conv[0].weight.data.copy_(weight_data.view_as(conv[0].weight.data))\n') - python_file.write(' offset += n\n') - python_file.write('\n') - python_file.write(' return offset\n') + elif layer['type'] == 'upsample' or layer['type'] == 'maxpool' or layer['type'] == 'yolo' or \ + layer['type'] == 'route' or layer['type'] == 'shortcut': + continue + else: + raise ValueError('Not supported layer (type={})'.format(layer['type'])) + + python_file.write('\n') + python_file.write(' if offset != weights.size:\n') + python_file.write(' raise ValueError(\'Invalid weights file.\')\n\n') + + +def _write_load_batch_norm_conv_weights(python_file): + python_file.write(' def _load_batch_norm_conv_weights(self, conv, weights, offset):\n') + python_file.write(' n = conv[1].bias.numel()\n') + python_file.write(' bias_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[1].bias.data.copy_(bias_data.view_as(conv[1].bias.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[1].weight.data.copy_(weight_data.view_as(conv[1].weight.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' running_mean_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[1].running_mean.data.copy_(running_mean_data.view_as(' + 'conv[1].running_mean.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' running_var_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[1].running_var.data.copy_(running_var_data.view_as(' + 'conv[1].running_var.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' n = conv[0].weight.numel()\n') + python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[0].weight.data.copy_(weight_data.view_as(conv[0].weight.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' return offset\n\n') + + +def _write_load_conv_weights(python_file): + python_file.write(' def _load_conv_weights(self, conv, weights, offset):\n') + python_file.write(' n = conv[0].bias.numel()\n') + python_file.write(' bias_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[0].bias.data.copy_(bias_data.view_as(conv[0].bias.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' n = conv[0].weight.numel()\n') + python_file.write(' weight_data = torch.from_numpy(weights[offset:offset + n])\n') + python_file.write(' conv[0].weight.data.copy_(weight_data.view_as(conv[0].weight.data))\n') + python_file.write(' offset += n\n') + python_file.write('\n') + python_file.write(' return offset\n') def _check_not_supported_options(layer, not_supported_options): @@ -401,9 +415,7 @@ def main(): 
parser.add_argument('--class_name', type=str, help='Choose the class name', required=True) args = parser.parse_args() - - converter = DarknetCfgToPytorchModuleConverter() - converter.convert(args.cfg_path, args.python_output_path, args.class_name) + convert(args.cfg_path, args.python_output_path, args.class_name) if __name__ == '__main__': diff --git a/tools/dnn_training/object_detection/modules/convert_yolov7_yaml_to_pytorch_module.py b/tools/dnn_training/object_detection/modules/convert_yolov7_yaml_to_pytorch_module.py new file mode 100644 index 00000000..2ff0b424 --- /dev/null +++ b/tools/dnn_training/object_detection/modules/convert_yolov7_yaml_to_pytorch_module.py @@ -0,0 +1,354 @@ +import argparse +import os + +import yaml + + +class Layer: + def __init__(self, init_code, forward_code): + self.init_code = init_code + self.forward_code = forward_code + + +def convert(yaml_path, python_output_path, class_name): + with open(yaml_path, 'r') as yaml_file: + yaml_data = yaml.safe_load(yaml_file) + + all_anchor_counts = [len(a) // 2 for a in yaml_data['anchors']] + + layers, output_strides = _convert_yaml_to_layers(yaml_data['backbone'] + yaml_data['head'], yaml_data['nc'], + all_anchor_counts) + + with open(python_output_path, 'w') as python_file: + _write_header(python_file, class_name, yaml_path) + _write_init(python_file, layers, class_name, yaml_data['anchors'], output_strides) + _write_getters(python_file) + _write_forward(python_file, layers) + _write_load_weights(python_file) + + +def _convert_yaml_to_layers(yaml_list, class_count, all_anchor_counts): + if len(yaml_list) == 0: + raise ValueError('At least one layer is required.') + + all_outputs = [] + all_channels = [] + all_strides = [] + layers = [] + + i = 0 + for i, (input_index, c, layer_type, arguments) in enumerate(yaml_list): + if c != 1: + raise ValueError('C must be 1.') + layers.append( + _convert_to_layer(class_count, all_anchor_counts, all_outputs, all_channels, all_strides, input_index, + layer_type, arguments, i)) + + if layer_type == 'Detect': + break + + output_strides = [all_strides[i] for i in yaml_list[i][0]] + return layers, output_strides + + +def _convert_to_layer(class_count, all_anchor_counts, all_outputs, all_channels, all_strides, input_index, layer_type, + arguments, i): + input = _input_index_to_input(input_index, all_outputs) + input_channels = _input_index_to_channels(input_index, all_channels) + input_stride = _input_index_to_stride(input_index, all_strides) + + if layer_type == 'Conv': + layer, output, output_channels, stride = _convert_conv_to_layer(input, input_channels, arguments, i) + elif layer_type == 'MP': + layer, output, output_channels, stride = _convert_mp_to_layer(input, input_channels, arguments, i) + elif layer_type == 'SP': + layer, output, output_channels, stride = _convert_sp_to_layer(input, input_channels, arguments, i) + elif layer_type == 'SPPCSPC': + layer, output, output_channels, stride = _convert_sppcspc_to_layer(input, input_channels, arguments, i) + elif layer_type == 'RepConv': + layer, output, output_channels, stride = _convert_rep_conv_to_layer(input, input_channels, arguments, i) + elif layer_type == 'Concat': + stride = 1 + layer, output, output_channels = _convert_concat_to_layer(input, input_channels, arguments, i) + elif layer_type == 'nn.Upsample': + layer, output, output_channels, stride = _convert_upsample_to_layer(input, input_channels, arguments, i) + elif layer_type == 'Detect': + return _convert_detect_to_layer(class_count, all_anchor_counts, input_index, 
all_outputs, all_channels, + all_strides) + else: + raise ValueError('Invalid layer type (' + layer_type + ')') + + all_outputs.append(output) + all_channels.append(output_channels) + all_strides.append(int(input_stride * stride)) + + return layer + + +def _input_index_to_input(input_index, all_outputs): + if len(all_outputs) == 0: + return 'x' + elif isinstance(input_index, list): + return '[' + ', '.join((all_outputs[i] for i in input_index)) + ']' + else: + return all_outputs[input_index] + + +def _input_index_to_channels(input_index, all_channels): + if len(all_channels) == 0: + return 3 + elif isinstance(input_index, list): + return sum((all_channels[i] for i in input_index)) + else: + return all_channels[input_index] + + +def _input_index_to_stride(input_index, all_strides): + if len(all_strides) == 0: + return 1 + elif isinstance(input_index, list): + return all_strides[input_index[0]] + else: + return all_strides[input_index] + + +def _convert_conv_to_layer(input, input_channels, arguments, i): + if len(arguments) > 6: + raise ValueError('Too many arguments') + + output_channels = arguments[0] + kernel_size = arguments[1] if len(arguments) > 1 else 1 + stride = arguments[2] if len(arguments) > 2 else 1 + padding = arguments[3] if len(arguments) > 3 and arguments[3] != 'None' else kernel_size // 2 + groups = arguments[4] if len(arguments) > 4 else 1 + activation = arguments[5] if len(arguments) > 5 else 'nn.SiLU()' + + layer_name = f'self._conv{i}' + output = 'y' + str(i) + + init_code = (f' {layer_name} = nn.Sequential(\n' + f' nn.Conv2d(in_channels={input_channels}, out_channels={output_channels}, kernel_size={kernel_size}, stride={stride}, padding={padding}, groups={groups}, bias=False),\n' + f' nn.BatchNorm2d({output_channels}, eps=0.001),\n' + f' {activation},\n' + f' )' + ) + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, output_channels, stride + + +def _convert_mp_to_layer(input, input_channels, arguments, i): + if len(arguments) > 1: + raise ValueError('Too many arguments') + + kernel_size = arguments[0] if len(arguments) > 0 else 2 + stride = kernel_size + + layer_name = f'self._max_pool{i}' + output = 'y' + str(i) + + init_code = f' {layer_name} = nn.MaxPool2d(kernel_size={kernel_size}, stride={stride})' + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, input_channels, stride + + +def _convert_sp_to_layer(input, input_channels, arguments, i): + if len(arguments) > 2: + raise ValueError('Too many arguments') + + kernel_size = arguments[0] if len(arguments) > 0 else 3 + stride = arguments[1] if len(arguments) > 1 else 1 + + layer_name = f'self._max_pool{i}' + output = 'y' + str(i) + + init_code = f' {layer_name} = nn.MaxPool2d(kernel_size={kernel_size}, stride={stride}, padding={kernel_size // 2})' + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, input_channels, stride + + +def _convert_sppcspc_to_layer(input, input_channels, arguments, i): + if len(arguments) > 1: + raise ValueError('Too many arguments') + + output_channels = arguments[0] + stride = 1 + + layer_name = f'self._sppcspc{i}' + output = 'y' + str(i) + + init_code = f' {layer_name} = YoloV7SPPCSPC({input_channels}, {output_channels})' + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, output_channels, stride + + +def _convert_rep_conv_to_layer(input, input_channels, arguments, i): + if 
len(arguments) > 6: + raise ValueError('Too many arguments') + + output_channels = arguments[0] + kernel_size = arguments[1] if len(arguments) > 1 else 1 + stride = arguments[2] if len(arguments) > 2 else 1 + padding = arguments[3] if len(arguments) > 3 and arguments[3] != 'None' else kernel_size // 2 + groups = arguments[4] if len(arguments) > 4 else 1 + activation = arguments[5] if len(arguments) > 5 else 'nn.SiLU()' + + layer_name = f'self._rep_conv{i}' + output = 'y' + str(i) + + init_code = f' {layer_name} = RepConv(in_channels={input_channels}, out_channels={output_channels}, kernel_size={kernel_size}, stride={stride}, padding={padding}, groups={groups}, activation={activation})' + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, output_channels, stride + + +def _convert_concat_to_layer(input, input_channels, arguments, i): + if len(arguments) > 1: + raise ValueError('Too many arguments') + + dim = arguments[0] if len(arguments) > 0 else 1 + + output = 'y' + str(i) + + init_code = '' + forward_code = f' {output} = torch.cat({input}, dim={dim})\n' + + return Layer(init_code, forward_code), output, input_channels + + +def _convert_upsample_to_layer(input, input_channels, arguments, i): + if len(arguments) > 3: + raise ValueError('Too many arguments') + + size = arguments[0] if len(arguments) > 0 else None + scale_factor = arguments[1] if len(arguments) > 1 else None + mode = arguments[2] if len(arguments) > 2 else 'nearest' + + if size != 'None': + raise ValueError(f'Size must be None ({size})') + + layer_name = f'self._upsample{i}' + output = 'y' + str(i) + + init_code = f' {layer_name} = nn.Upsample(scale_factor={scale_factor}, mode=\'{mode}\')' + forward_code = f' {output} = {layer_name}({input})' + + return Layer(init_code, forward_code), output, input_channels, 1.0 / scale_factor + + +def _convert_detect_to_layer(class_count, all_anchor_counts, input_indexes, all_outputs, all_channels, all_strides): + init_code = '\n' + forward_code = '\n' + + layer_names = [] + output_names = [] + + for i, input_index in enumerate(input_indexes): + layer_names.append(f'self._yolo{i}') + output_names.append(f'd{i}') + + init_code += (f' {layer_names[i]} = nn.Sequential(\n' + f' nn.Conv2d(in_channels={all_channels[input_index]}, out_channels={all_anchor_counts[i] * (class_count + 5)}, kernel_size=1),\n' + f' YoloV7Layer(IMAGE_SIZE, {all_strides[input_index]}, self._anchors[{i}], {class_count}, class_probs=class_probs)\n' + f' )\n' + ) + forward_code += f' {output_names[i]} = {layer_names[i]}({all_outputs[input_index]})\n' + + forward_code += f' return [{", ".join(output_names)}]' + + return Layer(init_code, forward_code) + + +def _write_header(python_file, class_name, yaml_path): + python_file.write('from collections import OrderedDict\n') + python_file.write('\n') + python_file.write('import numpy as np\n') + python_file.write('\n') + python_file.write('import torch\n') + python_file.write('import torch.nn as nn\n') + python_file.write('\n') + python_file.write('from object_detection.modules.yolo_layer import YoloV7Layer\n') + python_file.write('from object_detection.modules.yolo_v7_modules import YoloV7SPPCSPC, RepConv\n') + python_file.write('\n') + python_file.write('\n') + python_file.write(f'IMAGE_SIZE = (640, 640)\n') + python_file.write('IN_CHANNELS = 3\n') + python_file.write('\n') + python_file.write('\n') + python_file.write(f'# Generated from: {os.path.basename(yaml_path)}:\n') + python_file.write(f'class {class_name}(nn.Module):\n') 
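A minimal sketch of driving this converter programmatically and loading the file it emits; the yaml path, output path, and class name below are placeholders rather than files tracked in this repository, and it assumes the object_detection package imported by the generated code is on the Python path:

import importlib.util

# Illustrative only: call the convert() defined above to generate a PyTorch module file
# from a YOLOv7 yaml description (placeholder paths and class name).
convert('yolov7.yaml', 'yolo_v7_generated.py', 'YoloV7Generated')

# Import the freshly generated file and instantiate the class it defines.
spec = importlib.util.spec_from_file_location('yolo_v7_generated', 'yolo_v7_generated.py')
generated_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(generated_module)
model = generated_module.YoloV7Generated(class_probs=True)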
+ + +def _write_init(python_file, layers, class_name, anchors, output_strides): + python_file.write(' def __init__(self, class_probs=False):\n') + python_file.write(f' super({class_name}, self).__init__()\n\n') + + python_file.write(' self._anchors = []\n') + python_file.write(f' self._output_strides = {output_strides}\n') + + for a in anchors: + python_file.write( + f' self._anchors.append(np.array([({a[0]}, {a[1]}), ({a[2]}, {a[3]}), ({a[4]}, {a[5]})]))\n') + + python_file.write('\n') + for layer in layers: + python_file.write(layer.init_code) + python_file.write('\n') + + python_file.write('\n') + + +def _write_getters(python_file): + python_file.write(' def get_image_size(self):\n') + python_file.write(' return IMAGE_SIZE\n') + python_file.write('\n') + python_file.write(' def get_anchors(self):\n') + python_file.write(' return self._anchors\n') + python_file.write('\n') + python_file.write(' def get_output_strides(self):\n') + python_file.write(' return self._output_strides\n\n') + + +def _write_forward(python_file, layers): + python_file.write(' def forward(self, x):\n') + + for layer in layers: + python_file.write(layer.forward_code) + python_file.write('\n') + + python_file.write('\n') + + +def _write_load_weights(python_file): + python_file.write(' def load_weights(self, weights_path):\n') + python_file.write(' loaded_state_dict = self._filter_static_dict(torch.load(weights_path), \'anchor\')\n') + python_file.write(' current_state_dict = self._filter_static_dict(self.state_dict(), \'offset\')\n') + python_file.write('\n') + python_file.write( + ' for i, (kl, kc) in enumerate(zip(loaded_state_dict.keys(), current_state_dict.keys())):\n') + python_file.write(' if current_state_dict[kc].size() != loaded_state_dict[kl].size():\n') + python_file.write(' raise ValueError(\'Mismatching size.\')\n') + python_file.write(' current_state_dict[kc] = loaded_state_dict[kl]\n') + python_file.write('\n') + python_file.write(' self.load_state_dict(current_state_dict, strict=False)\n') + python_file.write('\n') + + python_file.write(' def _filter_static_dict(self, state_dict, x):\n') + python_file.write(' return OrderedDict([(k, v) for k, v in state_dict.items() if x not in k])\n') + + +def main(): + parser = argparse.ArgumentParser(description='Convert the specified YOLOv7 yaml configuration file to PyTorch') + parser.add_argument('--yaml_path', type=str, help='Choose the configuration file', required=True) + parser.add_argument('--python_output_path', type=str, help='Choose the Python output file', required=True) + parser.add_argument('--class_name', type=str, help='Choose the class name', required=True) + + args = parser.parse_args() + convert(args.yaml_path, args.python_output_path, args.class_name) + + +if __name__ == '__main__': + main() diff --git a/tools/dnn_training/object_detection/modules/descriptor_yolo_layer.py b/tools/dnn_training/object_detection/modules/descriptor_yolo_layer.py index d3104f14..e09e32d0 100644 --- a/tools/dnn_training/object_detection/modules/descriptor_yolo_layer.py +++ b/tools/dnn_training/object_detection/modules/descriptor_yolo_layer.py @@ -10,9 +10,9 @@ CLASSES_INDEX = 5 -class DescriptorYoloLayer(nn.Module): +class DescriptorYoloV4Layer(nn.Module): def __init__(self, image_size, stride, anchors, class_count, descriptor_size, scale_x_y): - super(DescriptorYoloLayer, self).__init__() + super(DescriptorYoloV4Layer, self).__init__() self._grid_size = (image_size[1] // stride, image_size[0] // stride) self._stride = stride @@ -98,3 +98,66 @@ def 
_forward_descriptor(self, x): class_scores = class_scores.permute(0, 3, 2, 1).reshape(N, H, W, N_ANCHORS, self._class_count) return torch.cat([bboxes_and_confidences, class_scores, descriptors], dim=4) + + +class DescriptorYoloV7Layer(nn.Module): + def __init__(self, image_size, stride, anchors, embedding_size): + super(DescriptorYoloV7Layer, self).__init__() + self._grid_size = (image_size[1] // stride, image_size[0] // stride) + self._stride = stride + + self._anchors = [(a[0] / stride, a[1] / stride) for a in anchors] + self._embedding_size = embedding_size + + x = torch.arange(self._grid_size[1]) + y = torch.arange(self._grid_size[0]) + y_offset, x_offset = torch.meshgrid(y, x, indexing='ij') + self.register_buffer('_x_offset', x_offset.float().clone()) + self.register_buffer('_y_offset', y_offset.float().clone()) + + # Fix scripting errors + self._x_index = X_INDEX + self._y_index = Y_INDEX + self._w_index = W_INDEX + self._h_index = H_INDEX + self._confidence_index = CONFIDENCE_INDEX + self._classes_index = CLASSES_INDEX + + def forward(self, t): + N = t.size()[0] + N_ANCHORS = len(self._anchors) + H = self._grid_size[1] + W = self._grid_size[0] + N_PREDICTION = 5 + self._embedding_size + + # Transform x + t = t.view(N, N_ANCHORS, N_PREDICTION, H, W).permute(0, 1, 3, 4, 2).contiguous() + x = torch.sigmoid(t[:, :, :, :, self._x_index]) * 2.0 - 0.5 + x += self._x_offset + x *= self._stride + x = x.unsqueeze(4).permute(0, 2, 3, 1, 4) + + # Transform y + y = torch.sigmoid(t[:, :, :, :, self._y_index]) * 2.0 - 0.5 + y += self._y_offset + y *= self._stride + y = y.unsqueeze(4).permute(0, 2, 3, 1, 4) + + t = t.permute(0, 2, 3, 1, 4) + + # Transform w and h + w = [] + h = [] + for i in range(N_ANCHORS): + w.append(4 * torch.sigmoid(t[:, :, :, i, self._w_index:self._w_index + 1]) ** 2 * self._anchors[i][0]) + h.append(4 * torch.sigmoid(t[:, :, :, i, self._h_index:self._h_index + 1]) ** 2 * self._anchors[i][1]) + + w = torch.cat(w, dim=3).unsqueeze(4) * self._stride + h = torch.cat(h, dim=3).unsqueeze(4) * self._stride + + # Transform confidence + confidence = torch.sigmoid(t[:, :, :, :, self._confidence_index:self._confidence_index + 1]) + + embedding = F.normalize(t[:, :, :, :, self._classes_index:], dim=4, p=2.0) + + return torch.cat([x, y, w, h, confidence], dim=4), embedding diff --git a/tools/dnn_training/object_detection/modules/evaluate_converted_yolo.py b/tools/dnn_training/object_detection/modules/evaluate_converted_yolo.py new file mode 100644 index 00000000..dee01048 --- /dev/null +++ b/tools/dnn_training/object_detection/modules/evaluate_converted_yolo.py @@ -0,0 +1,48 @@ +import argparse +import os + +import torch + +from common.modules import load_checkpoint + +from object_detection.datasets.yolo_collate import yolo_collate +from object_detection.datasets.coco_detection_transforms import CocoDetectionValidationTransforms +from object_detection.datasets.object_detection_coco import ObjectDetectionCoco +from object_detection.metrics import CocoObjectEvaluation +from object_detection.modules.test_converted_yolo import create_model + + +def main(): + parser = argparse.ArgumentParser(description='Test the specified converted model') + parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') + parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7', 'yolo_v7_tiny'], + help='Choose the mode', required=True) + parser.add_argument('--model_checkpoint', type=str, help='Choose the model checkpoint file', required=True) + 
parser.add_argument('--coco_root', type=str, help='Choose the image file', required=True) + parser.add_argument('--batch_size', type=int, help='Choose the batch size', default=4) + parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) + + args = parser.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') + model = create_model(args.model_type, class_probs=False) + load_checkpoint(model, args.model_checkpoint) + model.eval() + + dataset = ObjectDetectionCoco( + os.path.join(args.coco_root, 'val2017'), + os.path.join(args.coco_root, 'instances_val2017.json'), + transforms=CocoDetectionValidationTransforms(model.get_image_size(), one_hot_class=True)) + data_loader = torch.utils.data.DataLoader(dataset, + batch_size=args.batch_size, + collate_fn=yolo_collate, + shuffle=False, + num_workers=1) + + os.makedirs(args.output_path, exist_ok=True) + evaluation = CocoObjectEvaluation(model.to(device), device, data_loader, args.output_path) + evaluation.evaluate() + + +if __name__ == '__main__': + main() diff --git a/tools/dnn_training/object_detection/modules/test_converted_yolo.py b/tools/dnn_training/object_detection/modules/test_converted_yolo.py index 4765f845..adad2cc0 100644 --- a/tools/dnn_training/object_detection/modules/test_converted_yolo.py +++ b/tools/dnn_training/object_detection/modules/test_converted_yolo.py @@ -5,12 +5,17 @@ import torch +from object_detection.descriptor_yolo_v7 import DescriptorYoloV7 from object_detection.modules.yolo_v4 import YoloV4 from object_detection.modules.yolo_v4_tiny import YoloV4Tiny +from object_detection.modules.yolo_v7 import YoloV7 +from object_detection.modules.yolo_v7_tiny import YoloV7Tiny from object_detection.datasets.coco_detection_transforms import CocoDetectionValidationTransforms from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CLASSES_INDEX +from train_descriptor_yolo import _get_class_count + COCO_CLASSES = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', @@ -21,54 +26,123 @@ 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] +OBJECTS365_CLASS_NAMES = ['person', 'sneakers', 'chair', 'other shoes', 'hat', 'car', 'lamp', 'glasses', 'bottle', + 'desk', 'cup', 'street lights', 'cabinet/shelf', 'handbag', 'bracelet', 'plate', + 'picture/frame', 'helmet', 'book', 'gloves', 'storage box', 'boat', 'leather shoes', + 'flower', 'bench', 'pottedplant', 'bowl', 'flag', 'pillow', 'boots', 'vase', + 'microphone', 'necklace', 'ring', 'suv', 'wine glass', 'belt', 'tvmonitor', 'backpack', + 'umbrella', 'traffic light', 'speaker', 'watch', 'tie', 'trash bin can', 'slippers', + 'bicycle', 'stool', 'barrel/bucket', 'van', 'couch', 'sandals', 'basket', 'drum', + 'pen/pencil', 'bus', 'bird', 'high heels', 'motorbike', 'guitar', 'carpet', + 'cell phone', 'bread', 'camera', 'canned', 'truck', 'traffic cone', 'cymbal', 'lifesaver', + 'towel', 'stuffed toy', 'candle', 'sailboat', 'laptop', 'awning', 'bed', 'faucet', 'tent', + 'horse', 'mirror', 'power outlet', 'sink', 'apple', 'air 
conditioner', 'knife', + 'hockey stick', 'paddle', 'pickup truck', 'fork', 'traffic sign', 'balloon', 'tripod', 'dog', + 'spoon', 'clock', 'pot', 'cow', 'cake', 'dining table', 'sheep', 'hanger', + 'blackboard/whiteboard', 'napkin', 'other fish', 'orange', 'toiletry', 'keyboard', + 'tomato', 'lantern', 'machinery vehicle', 'fan', 'green vegetables', 'banana', + 'baseball glove', 'aeroplane', 'mouse', 'train', 'pumpkin', 'soccer', 'skis', 'luggage', + 'nightstand', 'tea pot', 'telephone', 'trolley', 'head phone', 'sports car', 'stop sign', + 'dessert', 'scooter', 'stroller', 'crane', 'remote', 'refrigerator', 'oven', 'lemon', 'duck', + 'baseball bat', 'surveillance camera', 'cat', 'jug', 'broccoli', 'piano', 'pizza', + 'elephant', 'skateboard', 'surfboard', 'gun', 'skating and skiing shoes', 'gas stove', + 'donut', 'bow tie', 'carrot', 'toilet', 'kite', 'strawberry', 'other balls', 'shovel', + 'pepper', 'computer box', 'toilet paper', 'cleaning products', 'chopsticks', 'microwave', + 'pigeon', 'baseball', 'cutting/chopping board', 'coffee table', 'side table', 'scissors', + 'marker', 'pie', 'ladder', 'snowboard', 'cookies', 'radiator', 'fire hydrant', 'basketball', + 'zebra', 'grape', 'giraffe', 'potato', 'sausage', 'tricycle', 'violin', 'egg', + 'fire extinguisher', 'candy', 'fire truck', 'billiards', 'converter', 'bathtub', + 'wheelchair', 'golf club', 'suitcase', 'cucumber', 'cigar/cigarette', 'paint brush', 'pear', + 'heavy truck', 'hamburger', 'extractor', 'extension cord', 'tong', 'tennis racket', + 'folder', 'american football', 'earphone', 'mask', 'kettle', 'tennis', 'ship', 'swing', + 'coffee machine', 'slide', 'carriage', 'onion', 'green beans', 'projector', 'frisbee', + 'washing machine/drying machine', 'chicken', 'printer', 'watermelon', 'saxophone', 'tissue', + 'toothbrush', 'ice cream', 'hot-air balloon', 'cello', 'french fries', 'scale', 'trophy', + 'cabbage', 'hot dog', 'blender', 'peach', 'rice', 'wallet/purse', 'volleyball', 'deer', + 'goose', 'tape', 'tablet', 'cosmetics', 'trumpet', 'pineapple', 'golf ball', 'ambulance', + 'parking meter', 'mango', 'key', 'hurdle', 'fishing rod', 'medal', 'flute', 'brush', + 'penguin', 'megaphone', 'corn', 'lettuce', 'garlic', 'swan', 'helicopter', 'green onion', + 'sandwich', 'nuts', 'speed limit sign', 'induction cooker', 'broom', 'trombone', 'plum', + 'rickshaw', 'goldfish', 'kiwi fruit', 'router/modem', 'poker card', 'toaster', 'shrimp', + 'sushi', 'cheese', 'notepaper', 'cherry', 'pliers', 'cd', 'pasta', 'hammer', 'cue', + 'avocado', 'hamimelon', 'flask', 'mushroom', 'screwdriver', 'soap', 'recorder', 'bear', + 'eggplant', 'board eraser', 'coconut', 'tape measure/ruler', 'pig', 'showerhead', 'globe', + 'chips', 'steak', 'crosswalk sign', 'stapler', 'camel', 'formula 1', 'pomegranate', + 'dishwasher', 'crab', 'hoverboard', 'meat ball', 'rice cooker', 'tuba', 'calculator', + 'papaya', 'antelope', 'parrot', 'seal', 'butterfly', 'dumbbell', 'donkey', 'lion', 'urinal', + 'dolphin', 'electric drill', 'hair dryer', 'egg tart', 'jellyfish', 'treadmill', 'lighter', + 'grapefruit', 'game board', 'mop', 'radish', 'baozi', 'target', 'french', 'spring rolls', + 'monkey', 'rabbit', 'pencil case', 'yak', 'red cabbage', 'binoculars', 'asparagus', 'barbell', + 'scallop', 'noddles', 'comb', 'dumpling', 'oyster', 'table tennis paddle', + 'cosmetics brush/eyeliner pencil', 'chainsaw', 'eraser', 'lobster', 'durian', 'okra', + 'lipstick', 'cosmetics mirror', 'curling', 'table tennis'] + +CLASSES_BY_DATASET_TYPE = {'coco': COCO_CLASSES, 'objects365': 
OBJECTS365_CLASS_NAMES} + def main(): parser = argparse.ArgumentParser(description='Test the specified converted model') - parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny'], help='Choose the mode', required=True) + parser.add_argument('--dataset_type', choices=['coco', 'objects365'], help='Choose the dataset type', required=True) + parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7', 'yolo_v7_tiny', + 'descriptor_yolo_v7'], + help='Choose the mode', required=True) + parser.add_argument('--embedding_size', type=int, help='Choose the embedding size for descriptor_yolo_v7') parser.add_argument('--weights_path', type=str, help='Choose the weights file path', required=True) parser.add_argument('--image_path', type=str, help='Choose the image file', required=True) args = parser.parse_args() - model = create_model(args.model_type) + model = create_model(args.model_type, args.dataset_type, embedding_size=args.embedding_size) model.load_weights(args.weights_path) + model.eval() image = Image.open(args.image_path) - predictions, scale = get_predictions(model, image) - display_predictions(predictions, scale, image) + predictions, scale, offset_x, offset_y = get_predictions(model, image) + display_predictions(predictions, scale, offset_x, offset_y, image, CLASSES_BY_DATASET_TYPE[args.dataset_type]) -def create_model(model_type): +def create_model(model_type, dataset_type, embedding_size=None, class_probs=False): + class_count = _get_class_count(dataset_type) if model_type == 'yolo_v4': - return YoloV4() + model = YoloV4(class_count, class_probs=class_probs) elif model_type == 'yolo_v4_tiny': - return YoloV4Tiny() + model = YoloV4Tiny(class_count, class_probs=class_probs) + elif model_type == 'yolo_v7': + model = YoloV7(dataset_type, class_probs=class_probs) + elif model_type == 'yolo_v7_tiny': + model = YoloV7Tiny(dataset_type, class_probs=class_probs) + elif model_type == 'descriptor_yolo_v7': + model = DescriptorYoloV7(class_count, embedding_size=embedding_size, class_probs=class_probs) else: raise ValueError('Invalid model type') + return model + def get_predictions(model, image): with torch.no_grad(): - transforms = CocoDetectionValidationTransforms(model.get_image_size()) + transforms = CocoDetectionValidationTransforms(model.get_image_size(), one_hot_class=True) image_tensor, _, metadata = transforms(image, None) start = time.time() predictions = model(image_tensor.unsqueeze(0)) print('Inference time: ', time.time() - start, 's') + start = time.time() predictions = group_predictions(predictions)[0] predictions = filter_yolo_predictions(predictions, confidence_threshold=0.5, nms_threshold=0.45) + print('Postprocessing time: ', time.time() - start, 's') - return predictions, metadata['scale'] + return predictions, metadata['scale'], metadata['offset_x'], metadata['offset_y'] -def display_predictions(predictions, scale, image): +def display_predictions(predictions, scale, offset_x, offset_y, image, classses): draw = ImageDraw.Draw(image) for prediction in predictions: - center_x = prediction[X_INDEX].item() / scale - center_y = prediction[Y_INDEX].item() / scale + center_x = (prediction[X_INDEX].item() - offset_x) / scale + center_y = (prediction[Y_INDEX].item() - offset_y) / scale w = prediction[W_INDEX].item() / scale h = prediction[H_INDEX].item() / scale class_index = torch.argmax(prediction[CLASSES_INDEX:]).item() @@ -79,7 +153,7 @@ def display_predictions(predictions, scale, image): y1 = center_y + h / 2 draw.rectangle([x0, y0, x1, y1], 
outline='red') - draw.text((x0, y0), COCO_CLASSES[class_index], fill='red') + draw.text((x0, y0), classses[class_index], fill='red') del draw image.show() diff --git a/tools/dnn_training/object_detection/modules/yolo_layer.py b/tools/dnn_training/object_detection/modules/yolo_layer.py index 8381f7d9..d1f433cc 100644 --- a/tools/dnn_training/object_detection/modules/yolo_layer.py +++ b/tools/dnn_training/object_detection/modules/yolo_layer.py @@ -1,3 +1,7 @@ + + +import pickle + import torch import torch.nn as nn @@ -9,13 +13,15 @@ CLASSES_INDEX = 5 -class YoloLayer(nn.Module): - def __init__(self, image_size, stride, anchors, class_count, scale_x_y): - super(YoloLayer, self).__init__() +class YoloV4Layer(nn.Module): + def __init__(self, image_size, stride, anchors, class_count, scale_x_y=1.0, class_probs=False): + super(YoloV4Layer, self).__init__() + self._class_probs = class_probs + self._grid_size = (image_size[1] // stride, image_size[0] // stride) self._stride = stride - self._anchors = [(a[0] // stride, a[1] // stride) for a in anchors] + self._anchors = [(a[0] / stride, a[1] / stride) for a in anchors] self._class_count = class_count self._scale_x_y = scale_x_y @@ -69,6 +75,74 @@ def forward(self, t): # Transform confidence confidence = torch.sigmoid(t[:, :, :, :, self._confidence_index:self._confidence_index + 1]) - descriptors = t[:, :, :, :, self._classes_index:] + classes = t[:, :, :, :, self._classes_index:] + if self._class_probs: + classes = torch.sigmoid(classes) + + return torch.cat([x, y, w, h, confidence, classes], dim=4) + + +class YoloV7Layer(nn.Module): + def __init__(self, image_size, stride, anchors, class_count, class_probs=False): + super(YoloV7Layer, self).__init__() + self._class_probs = class_probs + + self._grid_size = (image_size[1] // stride, image_size[0] // stride) + self._stride = stride + + self._anchors = [(a[0] / stride, a[1] / stride) for a in anchors] + self._class_count = class_count + + x = torch.arange(self._grid_size[1]) + y = torch.arange(self._grid_size[0]) + y_offset, x_offset = torch.meshgrid(y, x, indexing='ij') + self.register_buffer('_x_offset', x_offset.float().clone()) + self.register_buffer('_y_offset', y_offset.float().clone()) + + # Fix scripting errors + self._x_index = X_INDEX + self._y_index = Y_INDEX + self._w_index = W_INDEX + self._h_index = H_INDEX + self._confidence_index = CONFIDENCE_INDEX + self._classes_index = CLASSES_INDEX + + def forward(self, t): + N = t.size()[0] + N_ANCHORS = len(self._anchors) + H = self._grid_size[1] + W = self._grid_size[0] + N_PREDICTION = 5 + self._class_count + + # Transform x + t = t.view(N, N_ANCHORS, N_PREDICTION, H, W).permute(0, 1, 3, 4, 2).contiguous() + x = torch.sigmoid(t[:, :, :, :, self._x_index]) * 2.0 - 0.5 + x += self._x_offset + x *= self._stride + x = x.unsqueeze(4).permute(0, 2, 3, 1, 4) + + # Transform y + y = torch.sigmoid(t[:, :, :, :, self._y_index]) * 2.0 - 0.5 + y += self._y_offset + y *= self._stride + y = y.unsqueeze(4).permute(0, 2, 3, 1, 4) + + t = t.permute(0, 2, 3, 1, 4) + + # Transform w and h + w = [] + h = [] + for i in range(N_ANCHORS): + w.append(4 * torch.sigmoid(t[:, :, :, i, self._w_index:self._w_index + 1]) ** 2 * self._anchors[i][0]) + h.append(4 * torch.sigmoid(t[:, :, :, i, self._h_index:self._h_index + 1]) ** 2 * self._anchors[i][1]) + + w = torch.cat(w, dim=3).unsqueeze(4) * self._stride + h = torch.cat(h, dim=3).unsqueeze(4) * self._stride + + # Transform confidence + confidence = torch.sigmoid(t[:, :, :, :, self._confidence_index:self._confidence_index 
+ 1]) + classes = t[:, :, :, :, self._classes_index:] + if self._class_probs: + classes = torch.sigmoid(classes) - return torch.cat([x, y, w, h, confidence, descriptors], dim=4) + return torch.cat([x, y, w, h, confidence, classes], dim=4) diff --git a/tools/dnn_training/object_detection/modules/yolo_v4.py b/tools/dnn_training/object_detection/modules/yolo_v4.py index be590bd2..e118375e 100644 --- a/tools/dnn_training/object_detection/modules/yolo_v4.py +++ b/tools/dnn_training/object_detection/modules/yolo_v4.py @@ -5,15 +5,15 @@ from common.modules import Mish -from object_detection.modules.yolo_layer import YoloLayer +from object_detection.modules.yolo_layer import YoloV4Layer IMAGE_SIZE = (608, 608) IN_CHANNELS = 3 -# Genereated from: yolov4.cfg: +# Generated from: yolov4.cfg: class YoloV4(nn.Module): - def __init__(self): + def __init__(self, class_count, class_probs=False): super(YoloV4, self).__init__() self._anchors = [] self._output_strides = [] @@ -528,11 +528,11 @@ def __init__(self): nn.LeakyReLU(0.1, inplace=True) ) self._conv138 = nn.Sequential( - nn.Conv2d(256, 255, 1, stride=1, padding=0, bias=True), + nn.Conv2d(256, 3 * (class_count + 5), 1, stride=1, padding=0, bias=True), ) self._anchors.append(np.array([(12, 16), (19, 36), (40, 28)])) self._output_strides.append(8) - self._yolo139 = YoloLayer(IMAGE_SIZE, 8, self._anchors[-1].tolist(), 80, 1.2) + self._yolo139 = YoloV4Layer(IMAGE_SIZE, 8, self._anchors[-1].tolist(), class_count, scale_x_y=1.2, class_probs=class_probs) self._conv141 = nn.Sequential( nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False), @@ -571,11 +571,11 @@ def __init__(self): nn.LeakyReLU(0.1, inplace=True) ) self._conv149 = nn.Sequential( - nn.Conv2d(512, 255, 1, stride=1, padding=0, bias=True), + nn.Conv2d(512, 3 * (class_count + 5), 1, stride=1, padding=0, bias=True), ) self._anchors.append(np.array([(36, 75), (76, 55), (72, 146)])) self._output_strides.append(16) - self._yolo150 = YoloLayer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), 80, 1.1) + self._yolo150 = YoloV4Layer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, scale_x_y=1.1, class_probs=class_probs) self._conv152 = nn.Sequential( nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False), @@ -614,11 +614,11 @@ def __init__(self): nn.LeakyReLU(0.1, inplace=True) ) self._conv160 = nn.Sequential( - nn.Conv2d(1024, 255, 1, stride=1, padding=0, bias=True), + nn.Conv2d(1024, 3 * (class_count + 5), 1, stride=1, padding=0, bias=True), ) self._anchors.append(np.array([(142, 110), (192, 243), (459, 401)])) self._output_strides.append(32) - self._yolo161 = YoloLayer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), 80, 1.05) + self._yolo161 = YoloV4Layer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, scale_x_y=1.05, class_probs=class_probs) def get_class_count(self): return 80 diff --git a/tools/dnn_training/object_detection/modules/yolo_v4_tiny.py b/tools/dnn_training/object_detection/modules/yolo_v4_tiny.py index 0cd15520..39b3aa90 100644 --- a/tools/dnn_training/object_detection/modules/yolo_v4_tiny.py +++ b/tools/dnn_training/object_detection/modules/yolo_v4_tiny.py @@ -3,14 +3,15 @@ import torch import torch.nn as nn -from object_detection.modules.yolo_layer import YoloLayer +from object_detection.modules.yolo_layer import YoloV4Layer IMAGE_SIZE = (416, 416) IN_CHANNELS = 3 -# Genereated from: yolov4-tiny.cfg: + +# Generated from: yolov4-tiny.cfg: class YoloV4Tiny(nn.Module): - def __init__(self): + def __init__(self, class_count, class_probs=False): super(YoloV4Tiny, self).__init__() 
self._anchors = [] @@ -113,11 +114,11 @@ def __init__(self): nn.LeakyReLU(0.1, inplace=True) ) self._conv29 = nn.Sequential( - nn.Conv2d(512, 255, 1, stride=1, padding=0, bias=True), + nn.Conv2d(512, 3 * (class_count + 5), 1, stride=1, padding=0, bias=True), ) self._anchors.append(np.array([(81, 82), (135, 169), (344, 319)])) self._output_strides.append(32) - self._yolo30 = YoloLayer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), 80, 1.05) + self._yolo30 = YoloV4Layer(IMAGE_SIZE, 32, self._anchors[-1].tolist(), class_count, scale_x_y=1.05, class_probs=class_probs) self._conv32 = nn.Sequential( nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False), @@ -132,11 +133,11 @@ def __init__(self): nn.LeakyReLU(0.1, inplace=True) ) self._conv36 = nn.Sequential( - nn.Conv2d(256, 255, 1, stride=1, padding=0, bias=True), + nn.Conv2d(256, 3 * (class_count + 5), 1, stride=1, padding=0, bias=True), ) self._anchors.append(np.array([(23, 27), (37, 58), (81, 82)])) self._output_strides.append(16) - self._yolo37 = YoloLayer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), 80, 1.05) + self._yolo37 = YoloV4Layer(IMAGE_SIZE, 16, self._anchors[-1].tolist(), class_count, scale_x_y=1.05, class_probs=class_probs) def get_image_size(self): return IMAGE_SIZE diff --git a/tools/dnn_training/object_detection/modules/yolo_v7.py b/tools/dnn_training/object_detection/modules/yolo_v7.py new file mode 100644 index 00000000..5b9a5ba2 --- /dev/null +++ b/tools/dnn_training/object_detection/modules/yolo_v7.py @@ -0,0 +1,621 @@ +from collections import OrderedDict + +import numpy as np + +import torch +import torch.nn as nn + +from object_detection.modules.yolo_layer import YoloV7Layer +from object_detection.modules.yolo_v7_modules import YoloV7SPPCSPC, RepConv + +from object_detection.datasets.object_detection_coco import CLASS_COUNT as COCO_CLASS_COUNT + + +IMAGE_SIZE = (640, 640) +IN_CHANNELS = 3 + + +# Generated from: yolov7.yaml: +class YoloV7(nn.Module): + def __init__(self, dataset_type='coco', class_probs=False): + super(YoloV7, self).__init__() + + self._anchors = [] + self._output_strides = [8, 16, 32] + + if dataset_type == 'coco': + class_count = COCO_CLASS_COUNT + self._anchors.append(np.array([(12, 16), (19, 36), (40, 28)])) + self._anchors.append(np.array([(36, 75), (76, 55), (72, 146)])) + self._anchors.append(np.array([(142, 110), (192, 243), (459, 401)])) + elif dataset_type == 'objects365': + class_count = 365 + self._anchors.append(np.array([(8, 7), (15, 14), (17, 36)])) + self._anchors.append(np.array([(38, 22), (39, 53), (93, 59)])) + self._anchors.append(np.array([(55, 122), (126, 179), (257, 324)])) + + self._conv0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.SiLU(), + ) + self._conv1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv3 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv4 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) 
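# Note (interpretation, not stated in the generated source): in this backbone, _conv4 and the following
# _conv5 appear to be the two parallel 1x1 branches and _conv6 to _conv9 the stacked 3x3 convolutions of
# the first ELAN aggregation block of yolov7.yaml; four of their 64-channel outputs are concatenated in
# forward(), which is why _conv11 expects 256 input channels.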
+ self._conv5 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv6 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv7 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv8 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv9 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + + self._conv11 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._max_pool12 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv13 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv14 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv15 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv17 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv18 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv19 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv20 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv21 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv22 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv24 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._max_pool25 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv26 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv27 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv28 = nn.Sequential( + 
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv30 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv31 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv32 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv33 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv34 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv35 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv37 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(1024, eps=0.001), + nn.SiLU(), + ) + self._max_pool38 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv39 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv40 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv41 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + + self._conv43 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv44 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv45 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv46 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv47 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv48 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv50 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(1024, eps=0.001), + nn.SiLU(), + ) + self._sppcspc51 = YoloV7SPPCSPC(1024, 512) + self._conv52 = nn.Sequential( + nn.Conv2d(in_channels=512, 
out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._upsample53 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv54 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv56 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv57 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv58 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv59 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv60 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv61 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv63 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv64 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._upsample65 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv66 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv68 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv69 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv70 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv71 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv72 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + self._conv73 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.SiLU(), + ) + + self._conv75 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._max_pool76 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv77 = 
nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv78 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv79 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv81 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv82 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv83 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv84 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv85 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + self._conv86 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.SiLU(), + ) + + self._conv88 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._max_pool89 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv90 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv91 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv92 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv94 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv95 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._conv96 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv97 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv98 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + self._conv99 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, 
stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.SiLU(), + ) + + self._conv101 = nn.Sequential( + nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.SiLU(), + ) + self._rep_conv102 = RepConv(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + self._rep_conv103 = RepConv(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + self._rep_conv104 = RepConv(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, groups=1, activation=nn.SiLU()) + + self._yolo0 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=self._anchors[0].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 8, self._anchors[0], class_count, class_probs=class_probs) + ) + self._yolo1 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=self._anchors[1].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 16, self._anchors[1], class_count, class_probs=class_probs) + ) + self._yolo2 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=self._anchors[2].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 32, self._anchors[2], class_count, class_probs=class_probs) + ) + + + def get_image_size(self): + return IMAGE_SIZE + + def get_anchors(self): + return self._anchors + + def get_output_strides(self): + return self._output_strides + + def forward(self, x): + y0 = self._conv0(x) + y1 = self._conv1(y0) + y2 = self._conv2(y1) + y3 = self._conv3(y2) + y4 = self._conv4(y3) + y5 = self._conv5(y3) + y6 = self._conv6(y5) + y7 = self._conv7(y6) + y8 = self._conv8(y7) + y9 = self._conv9(y8) + y10 = torch.cat([y9, y7, y5, y4], dim=1) + + y11 = self._conv11(y10) + y12 = self._max_pool12(y11) + y13 = self._conv13(y12) + y14 = self._conv14(y11) + y15 = self._conv15(y14) + y16 = torch.cat([y15, y13], dim=1) + + y17 = self._conv17(y16) + y18 = self._conv18(y16) + y19 = self._conv19(y18) + y20 = self._conv20(y19) + y21 = self._conv21(y20) + y22 = self._conv22(y21) + y23 = torch.cat([y22, y20, y18, y17], dim=1) + + y24 = self._conv24(y23) + y25 = self._max_pool25(y24) + y26 = self._conv26(y25) + y27 = self._conv27(y24) + y28 = self._conv28(y27) + y29 = torch.cat([y28, y26], dim=1) + + y30 = self._conv30(y29) + y31 = self._conv31(y29) + y32 = self._conv32(y31) + y33 = self._conv33(y32) + y34 = self._conv34(y33) + y35 = self._conv35(y34) + y36 = torch.cat([y35, y33, y31, y30], dim=1) + + y37 = self._conv37(y36) + y38 = self._max_pool38(y37) + y39 = self._conv39(y38) + y40 = self._conv40(y37) + y41 = self._conv41(y40) + y42 = torch.cat([y41, y39], dim=1) + + y43 = self._conv43(y42) + y44 = self._conv44(y42) + y45 = self._conv45(y44) + y46 = self._conv46(y45) + y47 = self._conv47(y46) + y48 = self._conv48(y47) + y49 = torch.cat([y48, y46, y44, y43], dim=1) + + y50 = self._conv50(y49) + y51 = self._sppcspc51(y50) + y52 = self._conv52(y51) + y53 = self._upsample53(y52) + y54 = self._conv54(y37) + y55 = torch.cat([y54, y53], dim=1) + + y56 = self._conv56(y55) + y57 = self._conv57(y55) + y58 = self._conv58(y57) + y59 = self._conv59(y58) + y60 = self._conv60(y59) + y61 = self._conv61(y60) + y62 = torch.cat([y61, y60, y59, y58, y57, y56], dim=1) + + y63 = self._conv63(y62) + y64 = self._conv64(y63) + y65 = self._upsample65(y64) + y66 = self._conv66(y24) + y67 = torch.cat([y66, y65], dim=1) + + y68 = self._conv68(y67) + y69 = 
self._conv69(y67) + y70 = self._conv70(y69) + y71 = self._conv71(y70) + y72 = self._conv72(y71) + y73 = self._conv73(y72) + y74 = torch.cat([y73, y72, y71, y70, y69, y68], dim=1) + + y75 = self._conv75(y74) + y76 = self._max_pool76(y75) + y77 = self._conv77(y76) + y78 = self._conv78(y75) + y79 = self._conv79(y78) + y80 = torch.cat([y79, y77, y63], dim=1) + + y81 = self._conv81(y80) + y82 = self._conv82(y80) + y83 = self._conv83(y82) + y84 = self._conv84(y83) + y85 = self._conv85(y84) + y86 = self._conv86(y85) + y87 = torch.cat([y86, y85, y84, y83, y82, y81], dim=1) + + y88 = self._conv88(y87) + y89 = self._max_pool89(y88) + y90 = self._conv90(y89) + y91 = self._conv91(y88) + y92 = self._conv92(y91) + y93 = torch.cat([y92, y90, y51], dim=1) + + y94 = self._conv94(y93) + y95 = self._conv95(y93) + y96 = self._conv96(y95) + y97 = self._conv97(y96) + y98 = self._conv98(y97) + y99 = self._conv99(y98) + y100 = torch.cat([y99, y98, y97, y96, y95, y94], dim=1) + + y101 = self._conv101(y100) + y102 = self._rep_conv102(y75) + y103 = self._rep_conv103(y88) + y104 = self._rep_conv104(y101) + + d0 = self._yolo0(y102) + d1 = self._yolo1(y103) + d2 = self._yolo2(y104) + return [d0, d1, d2] + + def load_weights(self, weights_path): + loaded_state_dict = self._filter_static_dict(torch.load(weights_path), 'anchor') + current_state_dict = self._filter_static_dict(self.state_dict(), 'offset') + + for i, (kl, kc) in enumerate(zip(loaded_state_dict.keys(), current_state_dict.keys())): + if current_state_dict[kc].size() != loaded_state_dict[kl].size(): + raise ValueError('Mismatching size.') + current_state_dict[kc] = loaded_state_dict[kl] + + self.load_state_dict(current_state_dict, strict=False) + + def _filter_static_dict(self, state_dict, x): + return OrderedDict([(k, v) for k, v in state_dict.items() if x not in k]) diff --git a/tools/dnn_training/object_detection/modules/yolo_v7_modules.py b/tools/dnn_training/object_detection/modules/yolo_v7_modules.py new file mode 100644 index 00000000..aa94373a --- /dev/null +++ b/tools/dnn_training/object_detection/modules/yolo_v7_modules.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn + + +class YoloV7SPPCSPC(nn.Module): + def __init__(self, in_channels, out_channels): + super(YoloV7SPPCSPC, self).__init__() + + self._conv0 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + self._conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + self._conv2 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + self._conv3 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + + self._max_pool0 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) + self._max_pool1 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4) + self._max_pool2 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6) + + self._conv4 = nn.Sequential( + nn.Conv2d(4 * out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + self._conv5 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + 
nn.SiLU() + ) + self._conv6 = nn.Sequential( + nn.Conv2d(2 * out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001), + nn.SiLU() + ) + + def forward(self, x0): + x1 = self._conv3(self._conv2(self._conv0(x0))) + x2 = torch.cat([x1, self._max_pool0(x1), self._max_pool1(x1), self._max_pool2(x1)], dim=1) + y1 = self._conv5(self._conv4(x2)) + y2 = self._conv1(x0) + return self._conv6(torch.cat([y1, y2], dim=1)) + + +# TODO use only one Layer +class RepConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, activation): + super(RepConv, self).__init__() + + if padding is None: + padding = kernel_size // 2 + padding_1x1 = padding - kernel_size // 2 + + self._identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None + self._conv_kxk = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001) + ) + + self._conv_1x1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=padding_1x1, groups=groups, bias=False), + nn.BatchNorm2d(out_channels, eps=0.001) + ) + + self._activation = activation + + def forward(self, x): + if self._identity is None: + return self._activation(self._conv_kxk(x) + self._conv_1x1(x)) + else: + return self._activation(self._conv_kxk(x) + self._conv_1x1(x) + self._identity(x)) diff --git a/tools/dnn_training/object_detection/modules/yolo_v7_tiny.py b/tools/dnn_training/object_detection/modules/yolo_v7_tiny.py new file mode 100644 index 00000000..c40ebe29 --- /dev/null +++ b/tools/dnn_training/object_detection/modules/yolo_v7_tiny.py @@ -0,0 +1,465 @@ +from collections import OrderedDict + +import numpy as np + +import torch +import torch.nn as nn + +from object_detection.modules.yolo_layer import YoloV7Layer + +from object_detection.datasets.object_detection_coco import CLASS_COUNT as COCO_CLASS_COUNT + +IMAGE_SIZE = (640, 640) +IN_CHANNELS = 3 + + +# Generated from: yolov7-tiny.yaml: +class YoloV7Tiny(nn.Module): + def __init__(self, dataset_type='coco', class_probs=False): + super(YoloV7Tiny, self).__init__() + + self._anchors = [] + self._output_strides = [8, 16, 32] + + if dataset_type == 'coco': + class_count = COCO_CLASS_COUNT + self._anchors.append(np.array([(12, 16), (19, 36), (40, 28)])) + self._anchors.append(np.array([(36, 75), (76, 55), (72, 146)])) + self._anchors.append(np.array([(142, 110), (192, 243), (459, 401)])) + elif dataset_type == 'objects365': + class_count = 365 + self._anchors.append(np.array([(8, 7), (15, 14), (17, 36)])) + self._anchors.append(np.array([(38, 22), (39, 53), (93, 59)])) + self._anchors.append(np.array([(55, 122), (126, 179), (257, 324)])) + + self._conv0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv3 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0, groups=1, 
bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv4 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv5 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv7 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._max_pool8 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv9 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv10 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv11 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv12 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv14 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._max_pool15 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv16 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv17 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv18 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv19 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv21 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._max_pool22 = nn.MaxPool2d(kernel_size=2, stride=2) + self._conv23 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv24 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv25 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv26 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + 
nn.LeakyReLU(0.1), + ) + + self._conv28 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv29 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv30 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._max_pool31 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) + self._max_pool32 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4) + self._max_pool33 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6) + + self._conv35 = nn.Sequential( + nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv37 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv38 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._upsample39 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv40 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv42 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv43 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv44 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv45 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv47 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv48 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._upsample49 = nn.Upsample(scale_factor=2, mode='nearest') + self._conv50 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv52 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv53 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv54 = nn.Sequential( + 
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv55 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(32, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv57 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv58 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv60 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv61 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv62 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv63 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(64, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv65 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv66 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv68 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv69 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv70 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv71 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._conv73 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv74 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(128, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv75 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(256, eps=0.001), + nn.LeakyReLU(0.1), + ) + self._conv76 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, groups=1, bias=False), + nn.BatchNorm2d(512, eps=0.001), + nn.LeakyReLU(0.1), + ) + + self._yolo0 = nn.Sequential( + nn.Conv2d(in_channels=128, 
out_channels=self._anchors[0].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 8, self._anchors[0], class_count, class_probs=class_probs) + ) + self._yolo1 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=self._anchors[1].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 16, self._anchors[1], class_count, class_probs=class_probs) + ) + self._yolo2 = nn.Sequential( + nn.Conv2d(in_channels=512, out_channels=self._anchors[2].shape[0] * (class_count + 5), kernel_size=1), + YoloV7Layer(IMAGE_SIZE, 32, self._anchors[2], class_count, class_probs=class_probs) + ) + + def get_image_size(self): + return IMAGE_SIZE + + def get_anchors(self): + return self._anchors + + def get_output_strides(self): + return self._output_strides + + def forward(self, x): + y0 = self._conv0(x) + y1 = self._conv1(y0) + y2 = self._conv2(y1) + y3 = self._conv3(y1) + y4 = self._conv4(y3) + y5 = self._conv5(y4) + y6 = torch.cat([y5, y4, y3, y2], dim=1) + + y7 = self._conv7(y6) + y8 = self._max_pool8(y7) + y9 = self._conv9(y8) + y10 = self._conv10(y8) + y11 = self._conv11(y10) + y12 = self._conv12(y11) + y13 = torch.cat([y12, y11, y10, y9], dim=1) + + y14 = self._conv14(y13) + y15 = self._max_pool15(y14) + y16 = self._conv16(y15) + y17 = self._conv17(y15) + y18 = self._conv18(y17) + y19 = self._conv19(y18) + y20 = torch.cat([y19, y18, y17, y16], dim=1) + + y21 = self._conv21(y20) + y22 = self._max_pool22(y21) + y23 = self._conv23(y22) + y24 = self._conv24(y22) + y25 = self._conv25(y24) + y26 = self._conv26(y25) + y27 = torch.cat([y26, y25, y24, y23], dim=1) + + y28 = self._conv28(y27) + y29 = self._conv29(y28) + y30 = self._conv30(y28) + y31 = self._max_pool31(y30) + y32 = self._max_pool32(y30) + y33 = self._max_pool33(y30) + y34 = torch.cat([y33, y32, y31, y30], dim=1) + + y35 = self._conv35(y34) + y36 = torch.cat([y35, y29], dim=1) + + y37 = self._conv37(y36) + y38 = self._conv38(y37) + y39 = self._upsample39(y38) + y40 = self._conv40(y21) + y41 = torch.cat([y40, y39], dim=1) + + y42 = self._conv42(y41) + y43 = self._conv43(y41) + y44 = self._conv44(y43) + y45 = self._conv45(y44) + y46 = torch.cat([y45, y44, y43, y42], dim=1) + + y47 = self._conv47(y46) + y48 = self._conv48(y47) + y49 = self._upsample49(y48) + y50 = self._conv50(y14) + y51 = torch.cat([y50, y49], dim=1) + + y52 = self._conv52(y51) + y53 = self._conv53(y51) + y54 = self._conv54(y53) + y55 = self._conv55(y54) + y56 = torch.cat([y55, y54, y53, y52], dim=1) + + y57 = self._conv57(y56) + y58 = self._conv58(y57) + y59 = torch.cat([y58, y47], dim=1) + + y60 = self._conv60(y59) + y61 = self._conv61(y59) + y62 = self._conv62(y61) + y63 = self._conv63(y62) + y64 = torch.cat([y63, y62, y61, y60], dim=1) + + y65 = self._conv65(y64) + y66 = self._conv66(y65) + y67 = torch.cat([y66, y37], dim=1) + + y68 = self._conv68(y67) + y69 = self._conv69(y67) + y70 = self._conv70(y69) + y71 = self._conv71(y70) + y72 = torch.cat([y71, y70, y69, y68], dim=1) + + y73 = self._conv73(y72) + y74 = self._conv74(y57) + y75 = self._conv75(y65) + y76 = self._conv76(y73) + + d0 = self._yolo0(y74) + d1 = self._yolo1(y75) + d2 = self._yolo2(y76) + return [d0, d1, d2] + + def load_weights(self, weights_path): + loaded_state_dict = self._filter_static_dict(torch.load(weights_path), 'anchor') + current_state_dict = self._filter_static_dict(self.state_dict(), 'offset') + + for i, (kl, kc) in enumerate(zip(loaded_state_dict.keys(), current_state_dict.keys())): + if current_state_dict[kc].size() != loaded_state_dict[kl].size(): + raise 
ValueError('Mismatching size.') + current_state_dict[kc] = loaded_state_dict[kl] + + self.load_state_dict(current_state_dict, strict=False) + + def _filter_static_dict(self, state_dict, x): + return OrderedDict([(k, v) for k, v in state_dict.items() if x not in k]) diff --git a/tools/dnn_training/pose_estimation/backbones/__init__.py b/tools/dnn_training/pose_estimation/backbones/__init__.py deleted file mode 100644 index fed25af4..00000000 --- a/tools/dnn_training/pose_estimation/backbones/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from pose_estimation.backbones.mnasnet import Mnasnet0_5, Mnasnet1_0 -from pose_estimation.backbones.resnet import Resnet18, Resnet34, Resnet50 diff --git a/tools/dnn_training/pose_estimation/backbones/mnasnet.py b/tools/dnn_training/pose_estimation/backbones/mnasnet.py deleted file mode 100644 index ef1b1199..00000000 --- a/tools/dnn_training/pose_estimation/backbones/mnasnet.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch.nn as nn - -import torchvision.models as models - - -class _Mnasnet(nn.Module): - def __init__(self, mnasnet): - super(_Mnasnet, self).__init__() - self._mnasnet_feature_extractor = mnasnet.layers - - def forward(self, x): - return self._mnasnet_feature_extractor(x) - - def last_channel_count(self): - return 1280 - - -class Mnasnet0_5(_Mnasnet): - def __init__(self, pretrained=False): - super(Mnasnet0_5, self).__init__(models.mnasnet0_5(pretrained=pretrained)) - - -class Mnasnet1_0(_Mnasnet): - def __init__(self, pretrained=False): - super(Mnasnet1_0, self).__init__(models.mnasnet1_0(pretrained=pretrained)) diff --git a/tools/dnn_training/pose_estimation/backbones/resnet.py b/tools/dnn_training/pose_estimation/backbones/resnet.py deleted file mode 100644 index ca824b15..00000000 --- a/tools/dnn_training/pose_estimation/backbones/resnet.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch.nn as nn - -import torchvision.models as models - - -class _Resnet(nn.Module): - def __init__(self, resnet): - super(_Resnet, self).__init__() - self._resnet_feature_extractor = nn.Sequential( - resnet.conv1, - resnet.bn1, - resnet.relu, - resnet.maxpool, - - resnet.layer1, - resnet.layer2, - resnet.layer3, - resnet.layer4 - ) - - def forward(self, x): - return self._resnet_feature_extractor(x) - - -class Resnet18(_Resnet): - def __init__(self, pretrained=False): - super(Resnet18, self).__init__(models.resnet18(pretrained=pretrained)) - - def last_channel_count(self): - return 512 - - -class Resnet34(_Resnet): - def __init__(self, pretrained=False): - super(Resnet34, self).__init__(models.resnet34(pretrained=pretrained)) - - def last_channel_count(self): - return 512 - - -class Resnet50(_Resnet): - def __init__(self, pretrained=False): - super(Resnet50, self).__init__(models.resnet50(pretrained=pretrained)) - - def last_channel_count(self): - return 2048 diff --git a/tools/dnn_training/pose_estimation/criterions/__init__.py b/tools/dnn_training/pose_estimation/criterions/__init__.py index d30bc018..b14f55b2 100644 --- a/tools/dnn_training/pose_estimation/criterions/__init__.py +++ b/tools/dnn_training/pose_estimation/criterions/__init__.py @@ -1 +1,2 @@ from pose_estimation.criterions.pose_estimation_loss import PoseEstimationLoss +from pose_estimation.criterions.pose_estimation_loss import PoseEstimationDistillationLoss diff --git a/tools/dnn_training/pose_estimation/criterions/pose_estimation_loss.py b/tools/dnn_training/pose_estimation/criterions/pose_estimation_loss.py index a45fa9b8..b5d7742c 100644 --- 
a/tools/dnn_training/pose_estimation/criterions/pose_estimation_loss.py +++ b/tools/dnn_training/pose_estimation/criterions/pose_estimation_loss.py @@ -11,6 +11,24 @@ def forward(self, heatmap_prediction, target): heatmap_target = F.interpolate(heatmap_target, (heatmap_prediction.size()[2], heatmap_prediction.size()[3])) - loss = ((heatmap_prediction - heatmap_target) ** 2 / (1 - heatmap_target + 0.005)).mean() + return F.binary_cross_entropy(heatmap_prediction, heatmap_target) - return loss + +class PoseEstimationDistillationLoss(nn.Module): + def __init__(self, alpha=0.25): + super(PoseEstimationDistillationLoss, self).__init__() + self._alpha = alpha + + def forward(self, student_heatmap_prediction, target, teacher_heatmap_prediction): + heatmap_target, presence_target, _ = target + + heatmap_target = F.interpolate(heatmap_target, + (student_heatmap_prediction.size()[2], student_heatmap_prediction.size()[3])) + teacher_heatmap_prediction = F.interpolate( + teacher_heatmap_prediction, + (student_heatmap_prediction.size()[2], student_heatmap_prediction.size()[3]) + ) + + target_loss = F.binary_cross_entropy(student_heatmap_prediction, heatmap_target) + teacher_loss = F.binary_cross_entropy(student_heatmap_prediction, teacher_heatmap_prediction) + return self._alpha * target_loss + (1 - self._alpha) * teacher_loss diff --git a/tools/dnn_training/pose_estimation/datasets/pose_estimation_coco.py b/tools/dnn_training/pose_estimation/datasets/pose_estimation_coco.py index 1fae3003..01d07961 100644 --- a/tools/dnn_training/pose_estimation/datasets/pose_estimation_coco.py +++ b/tools/dnn_training/pose_estimation/datasets/pose_estimation_coco.py @@ -18,13 +18,12 @@ RANDOM_KEYPOINT_MASK_P = 1.0 RANDOM_KEYPOINT_MASK_RATIO = 0.2 -HEATMAP_SIGMA = 10 - class PoseEstimationCoco(Dataset): - def __init__(self, root, train=True, data_augmentation=False, image_transforms=None): + def __init__(self, root, train=True, data_augmentation=False, image_transforms=None, heatmap_sigma=10): self._data_augmentation = data_augmentation self._image_transforms = image_transforms + self._heatmap_sigma = heatmap_sigma if train: self._image_root = os.path.join(root, 'train2017') @@ -156,7 +155,7 @@ def _generate_heatmap(self, keypoint_x, keypoint_y, image_width, image_height, i heatmap_grid_y, heatmap_grid_x = torch.meshgrid(heatmap_y, heatmap_x, indexing='ij') return torch.exp(-(torch.pow(heatmap_grid_x - keypoint_x, 2) + torch.pow(heatmap_grid_y - keypoint_y, 2)) / - (2 * HEATMAP_SIGMA ** 2)) + (2 * self._heatmap_sigma ** 2)) def evaluate(self, result_file): coco_gt = COCO(self._annotation_file_path) diff --git a/tools/dnn_training/pose_estimation/metrics/pose_accuracy_metric.py b/tools/dnn_training/pose_estimation/metrics/pose_accuracy_metric.py index 6d909711..e2d7e554 100644 --- a/tools/dnn_training/pose_estimation/metrics/pose_accuracy_metric.py +++ b/tools/dnn_training/pose_estimation/metrics/pose_accuracy_metric.py @@ -26,7 +26,7 @@ def clear(self): self._false_negative_count = 0 def add(self, heatmap_prediction, heatmap_target, presence_target, oks_scale): - heatmap_target = F.interpolate(heatmap_target, (heatmap_prediction.size()[2], heatmap_prediction.size()[3])) + heatmap_prediction = F.interpolate(heatmap_prediction, (heatmap_target.size()[2], heatmap_target.size()[3])) predicted_coordinates, presence_prediction = get_coordinates(heatmap_prediction) target_coordinates, _ = get_coordinates(heatmap_target) diff --git a/tools/dnn_training/pose_estimation/metrics/pose_map_metric.py 
b/tools/dnn_training/pose_estimation/metrics/pose_map_metric.py index d09c41c5..320a7746 100644 --- a/tools/dnn_training/pose_estimation/metrics/pose_map_metric.py +++ b/tools/dnn_training/pose_estimation/metrics/pose_map_metric.py @@ -2,7 +2,7 @@ from pose_estimation.metrics.pose_accuracy_metric import PoseAccuracyMetric -THRESHOLDS = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.8, 0.85, 0.9] +THRESHOLDS = np.linspace(0.0, 1.0, num=10).tolist() class PoseMapMetric: @@ -32,22 +32,4 @@ def _calculate_ap(recall, precision): sorted_indexes = np.argsort(recall) recall = recall[sorted_indexes] precision = precision[sorted_indexes] - - recall_points = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] - ap = 0 - for point in recall_points: - ap += _find_interpolated_precision(recall, precision, point) - ap /= len(recall_points) - - return ap - - -def _find_interpolated_precision(recall, precision, point): - s = 0 - while s < recall.shape[0] and recall[s] < point: - s += 1 - - if s == recall.shape[0]: - return 0 - else: - return np.amax(precision[s:]) + return np.trapz(precision, recall) diff --git a/tools/dnn_training/pose_estimation/pose_estimator.py b/tools/dnn_training/pose_estimation/pose_estimator.py index cd334f3f..c4599253 100644 --- a/tools/dnn_training/pose_estimation/pose_estimator.py +++ b/tools/dnn_training/pose_estimation/pose_estimator.py @@ -1,45 +1,135 @@ import torch import torch.nn as nn +import torchvision.models as models + + +class EfficientNetPoseEstimator(nn.Module): + SUPPORTED_BACKBONE_TYPES = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] + HEATMAP_LAYER_CHANNELS_BY_BACKBONE_TYPE = {'efficientnet_b0': [320, 80, 40, 24], + 'efficientnet_b1': [320, 80, 40, 24], + 'efficientnet_b2': [352, 88, 48, 24], + 'efficientnet_b3': [384, 96, 48, 32], + 'efficientnet_b4': [448, 112, 56, 32], + 'efficientnet_b5': [512, 128, 64, 40], + 'efficientnet_b6': [576, 144, 72, 40], + 'efficientnet_b7': [640, 160, 80, 48]} + + def __init__(self, backbone_type, keypoint_count=17, pretrained_backbone=True): + super(EfficientNetPoseEstimator, self).__init__() + + if pretrained_backbone: + backbone_weights = 'DEFAULT' + else: + backbone_weights = None + + if (backbone_type not in self.SUPPORTED_BACKBONE_TYPES or + backbone_type not in self.HEATMAP_LAYER_CHANNELS_BY_BACKBONE_TYPE): + raise ValueError('Invalid backbone type') + + backbone_layers = list(models.__dict__[backbone_type](weights=backbone_weights).features) + self._features_layers = nn.ModuleList(backbone_layers[:-1]) + self._heatmap_layers = self._create_heatmap_layers(self.HEATMAP_LAYER_CHANNELS_BY_BACKBONE_TYPE[backbone_type], keypoint_count) + + def _create_heatmap_layers(self, channels, keypoint_count): + heatmap_layers = nn.ModuleList() + for i in range(len(channels)): + if i < len(channels) - 1: + output_channels = channels[i + 1] + else: + output_channels = channels[i] + + heatmap_layers.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels=channels[i], + out_channels=channels[i], + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=False), + nn.BatchNorm2d(channels[i]), + nn.SiLU(inplace=True), + + nn.Conv2d(in_channels=channels[i], + out_channels=output_channels, + kernel_size=3, + padding=1, + bias=False), + nn.BatchNorm2d(output_channels), + nn.SiLU(inplace=True), + + nn.Conv2d(in_channels=output_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, + bias=False), + 
nn.BatchNorm2d(output_channels), + nn.SiLU(inplace=True), + ) + ) + + heatmap_layers.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels=channels[-1], + out_channels=channels[-1], + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=False), + nn.BatchNorm2d(channels[-1]), + nn.SiLU(inplace=True), + + nn.Conv2d( + in_channels=channels[-1], + out_channels=channels[-1], + kernel_size=3, + padding=1, + bias=False), + nn.BatchNorm2d(channels[-1]), + nn.SiLU(inplace=True), + + nn.Conv2d(in_channels=channels[-1], + out_channels=keypoint_count, + kernel_size=1, + padding=0, + bias=True), + nn.Sigmoid() + ) + ) + + return heatmap_layers + + def forward(self, x0): + x3, x4, x5, x8 = self._forward_features(x0) + y4 = self._forward_heatmap(x3, x4, x5, x8) + return y4 + + def _forward_features(self, x0): + assert len(self._features_layers) == 8 + x1 = self._features_layers[0](x0) + x2 = self._features_layers[1](x1) + x3 = self._features_layers[2](x2) + x4 = self._features_layers[3](x3) + x5 = self._features_layers[4](x4) + x6 = self._features_layers[5](x5) + x7 = self._features_layers[6](x6) + x8 = self._features_layers[7](x7) + + return x3, x4, x5, x8 + + def _forward_heatmap(self, x3, x4, x5, x8): + y0 = self._heatmap_layers[0](x8) + y1 = self._heatmap_layers[1](y0 + x5) + y2 = self._heatmap_layers[2](y1 + x4) + y3 = self._heatmap_layers[3](y2 + x3) + y4 = self._heatmap_layers[4](y3) + return y4 -class PoseEstimator(nn.Module): - def __init__(self, backbone, keypoint_count=17, upsampling_count=3): - super(PoseEstimator, self).__init__() - self._backbone = backbone - self._heatmap_layers = self._create_heatmap_layers(self._backbone.last_channel_count(), - keypoint_count, - upsampling_count) - - def forward(self, x): - features = self._backbone(x) - heatmaps = self._heatmap_layers(features) - - return heatmaps - - def _create_heatmap_layers(self, last_channel_count, keypoint_count, upsampling_count): - layers = [] - in_channels = last_channel_count - for _ in range(upsampling_count): - layers.append(nn.ConvTranspose2d( - in_channels=in_channels, - out_channels=256, - kernel_size=4, - stride=2, - padding=1, - output_padding=0, - bias=False)) - layers.append(nn.BatchNorm2d(256)) - layers.append(nn.ReLU(inplace=True)) - in_channels = 256 - - layers.append(nn.Conv2d( - in_channels=256, - out_channels=keypoint_count, - kernel_size=1, - stride=1, - padding=0)) - - return nn.Sequential(*layers) def get_coordinates(heatmaps): diff --git a/tools/dnn_training/pose_estimation/trainers/__init__.py b/tools/dnn_training/pose_estimation/trainers/__init__.py index aab4e54b..8aba913c 100644 --- a/tools/dnn_training/pose_estimation/trainers/__init__.py +++ b/tools/dnn_training/pose_estimation/trainers/__init__.py @@ -1 +1,2 @@ from pose_estimation.trainers.pose_estimator_trainer import PoseEstimatorTrainer +from pose_estimation.trainers.pose_estimator_distillation_trainer import PoseEstimatorDistillationTrainer diff --git a/tools/dnn_training/pose_estimation/trainers/pose_estimator_distillation_trainer.py b/tools/dnn_training/pose_estimation/trainers/pose_estimator_distillation_trainer.py new file mode 100644 index 00000000..e6cc5d47 --- /dev/null +++ b/tools/dnn_training/pose_estimation/trainers/pose_estimator_distillation_trainer.py @@ -0,0 +1,110 @@ +import os + +import torch + +from common.trainers import DistillationTrainer +from common.metrics import LossMetric + +from pose_estimation.criterions import PoseEstimationDistillationLoss +from pose_estimation.metrics import PoseAccuracyMetric, 
PoseMapMetric, PoseLearningCurves, CocoPoseEvaluation + +from pose_estimation.trainers.pose_estimator_trainer import _create_training_dataset_loader,\ + _create_validation_dataset_loader + + +class PoseEstimatorDistillationTrainer(DistillationTrainer): + def __init__(self, device, student_model, teacher_model, dataset_root='', output_path='', + epoch_count=10, learning_rate=0.01, weight_decay=0.0, batch_size=128, batch_size_division=4, + heatmap_sigma=10, + student_model_checkpoint=None, teacher_model_checkpoint=None, loss_alpha=0.25): + self._heatmap_sigma = heatmap_sigma + self._loss_alpha = loss_alpha + + super(PoseEstimatorDistillationTrainer, self).__init__(device, student_model, teacher_model, + dataset_root=dataset_root, + output_path=output_path, + epoch_count=epoch_count, + learning_rate=learning_rate, + weight_decay=weight_decay, + batch_size=batch_size, + batch_size_division=batch_size_division, + student_model_checkpoint=student_model_checkpoint, + teacher_model_checkpoint=teacher_model_checkpoint) + + self._training_loss_metric = LossMetric() + self._training_accuracy_metric = PoseAccuracyMetric() + self._training_map_metric = PoseMapMetric() + self._validation_loss_metric = LossMetric() + self._validation_accuracy_metric = PoseAccuracyMetric() + self._validation_map_metric = PoseMapMetric() + self._learning_curves = PoseLearningCurves() + + def _create_criterion(self, student_model, teacher_model): + return PoseEstimationDistillationLoss(alpha=self._loss_alpha) + + def _create_training_dataset_loader(self, dataset_root, batch_size, batch_size_division): + return _create_training_dataset_loader(dataset_root, batch_size, batch_size_division, self._heatmap_sigma) + + def _create_validation_dataset_loader(self, dataset_root, batch_size, batch_size_division): + return _create_validation_dataset_loader(dataset_root, batch_size, batch_size_division, self._heatmap_sigma) + + def _clear_between_training(self): + self._learning_curves.clear() + + def _clear_between_training_epoch(self): + self._training_loss_metric.clear() + self._training_accuracy_metric.clear() + self._training_map_metric.clear() + + def _move_target_to_device(self, target, device): + return target[0].to(device), target[1].to(device), target[2].to(device) + + def _measure_training_metrics(self, loss, heatmap_prediction, target): + heatmap_target, presence_target, oks_scale = target + + self._training_loss_metric.add(loss.item()) + self._training_accuracy_metric.add(heatmap_prediction, heatmap_target, presence_target, oks_scale) + self._training_map_metric.add(heatmap_prediction, heatmap_target, presence_target, oks_scale) + + def _clear_between_validation_epoch(self): + self._validation_loss_metric.clear() + self._validation_accuracy_metric.clear() + self._validation_map_metric.clear() + + def _measure_validation_metrics(self, loss, heatmap_prediction, target): + heatmap_target, presence_target, oks_scale = target + + self._validation_loss_metric.add(loss.item()) + self._validation_accuracy_metric.add(heatmap_prediction, heatmap_target, presence_target, oks_scale) + self._validation_map_metric.add(heatmap_prediction, heatmap_target, presence_target, oks_scale) + + def _print_performances(self): + print('\nTraining : Loss={}, Accuracy={}, Precision={}, Recall={}, mAP={}'.format( + self._training_loss_metric.get_loss(), + self._training_accuracy_metric.get_accuracy(), + self._training_accuracy_metric.get_precision(), + self._training_accuracy_metric.get_recall(), + self._training_map_metric.get_map())) + 
print('Validation : Loss={}, Accuracy={}, Precision={}, Recall={}, mAP={}\n'.format( + self._validation_loss_metric.get_loss(), + self._validation_accuracy_metric.get_accuracy(), + self._validation_accuracy_metric.get_precision(), + self._validation_accuracy_metric.get_recall(), + self._validation_map_metric.get_map())) + + def _save_learning_curves(self): + self._learning_curves.add_training_loss_value(self._training_loss_metric.get_loss()) + self._learning_curves.add_training_accuracy_value(self._training_accuracy_metric.get_accuracy()) + self._learning_curves.add_training_map_value(self._training_map_metric.get_map()) + + self._learning_curves.add_validation_loss_value(self._validation_loss_metric.get_loss()) + self._learning_curves.add_validation_accuracy_value(self._validation_accuracy_metric.get_accuracy()) + self._learning_curves.add_validation_map_value(self._validation_map_metric.get_map()) + + self._learning_curves.save(os.path.join(self._output_path, 'learning_curves.png'), + os.path.join(self._output_path, 'learning_curves.json')) + + def _evaluate(self, model, device, dataset_loader, output_path): + print('Evaluation', flush=True) + coco_pose_evaluation = CocoPoseEvaluation(model, device, dataset_loader, output_path) + coco_pose_evaluation.evaluate() diff --git a/tools/dnn_training/pose_estimation/trainers/pose_estimator_trainer.py b/tools/dnn_training/pose_estimation/trainers/pose_estimator_trainer.py index 3d541b8c..d3fda87b 100644 --- a/tools/dnn_training/pose_estimation/trainers/pose_estimator_trainer.py +++ b/tools/dnn_training/pose_estimation/trainers/pose_estimator_trainer.py @@ -4,8 +4,10 @@ import torchvision.transforms as transforms from common.trainers import Trainer +from common.datasets import RandomSharpnessChange, RandomAutocontrast, RandomEqualize, RandomPosterize from common.metrics import LossMetric + from pose_estimation.criterions import PoseEstimationLoss from pose_estimation.datasets import PoseEstimationCoco from pose_estimation.metrics import PoseAccuracyMetric, PoseMapMetric, PoseLearningCurves, CocoPoseEvaluation @@ -14,14 +16,18 @@ class PoseEstimatorTrainer(Trainer): - def __init__(self, device, model, dataset_root='', output_path='', epoch_count=10, learning_rate=0.01, - batch_size=128, batch_size_division=4, + def __init__(self, device, model, dataset_root='', output_path='', + epoch_count=10, learning_rate=0.01, weight_decay=0.0, batch_size=128, batch_size_division=4, + heatmap_sigma=10, model_checkpoint=None): + self._heatmap_sigma = heatmap_sigma + super(PoseEstimatorTrainer, self).__init__(device, model, dataset_root=dataset_root, output_path=output_path, epoch_count=epoch_count, learning_rate=learning_rate, + weight_decay=weight_decay, batch_size=batch_size, batch_size_division=batch_size_division, model_checkpoint=model_checkpoint) @@ -38,24 +44,10 @@ def _create_criterion(self, model): return PoseEstimationLoss() def _create_training_dataset_loader(self, dataset_root, batch_size, batch_size_division): - training_dataset = PoseEstimationCoco(dataset_root, - train=True, - data_augmentation=True, - image_transforms=create_training_image_transform()) - return torch.utils.data.DataLoader(training_dataset, - batch_size=batch_size // batch_size_division, - shuffle=True, - num_workers=2) + return _create_training_dataset_loader(dataset_root, batch_size, batch_size_division, self._heatmap_sigma) def _create_validation_dataset_loader(self, dataset_root, batch_size, batch_size_division): - validation_dataset = PoseEstimationCoco(dataset_root, - 
train=False, - data_augmentation=False, - image_transforms=create_validation_image_transform()) - return torch.utils.data.DataLoader(validation_dataset, - batch_size=batch_size // batch_size_division, - shuffle=True, - num_workers=2) + return _create_validation_dataset_loader(dataset_root, batch_size, batch_size_division, self._heatmap_sigma) def _clear_between_training(self): self._learning_curves.clear() @@ -124,6 +116,10 @@ def create_training_image_transform(): transforms.Resize(IMAGE_SIZE), transforms.ColorJitter(brightness=0.2, saturation=0.2, contrast=0.2, hue=0.2), transforms.RandomGrayscale(p=0.1), + RandomSharpnessChange(), + RandomAutocontrast(), + RandomEqualize(), + RandomPosterize(), transforms.ToTensor(), transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @@ -135,3 +131,27 @@ def create_validation_image_transform(): transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) + + +def _create_training_dataset_loader(dataset_root, batch_size, batch_size_division, heatmap_sigma): + training_dataset = PoseEstimationCoco(dataset_root, + train=True, + data_augmentation=True, + image_transforms=create_training_image_transform(), + heatmap_sigma=heatmap_sigma) + return torch.utils.data.DataLoader(training_dataset, + batch_size=batch_size // batch_size_division, + shuffle=True, + num_workers=8) + + +def _create_validation_dataset_loader(dataset_root, batch_size, batch_size_division, heatmap_sigma): + validation_dataset = PoseEstimationCoco(dataset_root, + train=False, + data_augmentation=False, + image_transforms=create_validation_image_transform(), + heatmap_sigma=heatmap_sigma) + return torch.utils.data.DataLoader(validation_dataset, + batch_size=batch_size // batch_size_division, + shuffle=False, + num_workers=8) diff --git a/tools/dnn_training/requirements.txt b/tools/dnn_training/requirements.txt index 4b22f232..2de50eea 100644 --- a/tools/dnn_training/requirements.txt +++ b/tools/dnn_training/requirements.txt @@ -10,3 +10,4 @@ opencv-python git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI librosa sklearn +pyyaml diff --git a/tools/dnn_training/test_pose_estimator_with_yolo_v4.py b/tools/dnn_training/test_pose_estimator_with_yolo.py similarity index 76% rename from tools/dnn_training/test_pose_estimator_with_yolo_v4.py rename to tools/dnn_training/test_pose_estimator_with_yolo.py index d0edeb1b..a2775d8b 100644 --- a/tools/dnn_training/test_pose_estimator_with_yolo_v4.py +++ b/tools/dnn_training/test_pose_estimator_with_yolo.py @@ -16,15 +16,15 @@ from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval -from export_yolo_v4 import create_model as create_yolo_v4_model -from train_pose_estimator import create_model as create_pose_estimator_model +from train_pose_estimator import create_model as create_pose_estimator_model, BACKBONE_TYPES as POSE_BACKBONE_TYPES from common.modules import load_checkpoint from object_detection.datasets import ObjectDetectionCoco, CocoDetectionValidationTransforms -from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions +from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions_by_classes from object_detection.modules.yolo_layer import X_INDEX, Y_INDEX, W_INDEX, H_INDEX, CONFIDENCE_INDEX from object_detection.modules.yolo_layer import CLASSES_INDEX +from 
object_detection.modules.test_converted_yolo import create_model as create_yolo_model from pose_estimation.trainers.pose_estimator_trainer import IMAGE_SIZE as POSE_ESTIMATOR_IMAGE_SIZE from pose_estimation.datasets.pose_estimation_coco import COCO_PERSON_CATEGORY_ID @@ -57,7 +57,9 @@ def __getitem__(self, index): 'image_id': image_id, 'initial_width': initial_width, 'initial_height': initial_height, - 'scale': transforms_metadata['scale'] + 'scale': transforms_metadata['scale'], + 'offset_x': transforms_metadata['offset_x'], + 'offset_y': transforms_metadata['offset_y'] } return image, target, metadata @@ -66,11 +68,12 @@ def __len__(self): # TODO Refactor to reduce code duplication -class CocoPoseEvaluationWithYoloV4(): - def __init__(self, yolo_model, pose_estimator_model, dataset_root, dataset_split, output_path, +class CocoPoseEvaluationWithYolo(): + def __init__(self, yolo_model, pose_estimator_model, device, dataset_root, dataset_split, output_path, confidence_threshold=0.01, nms_threshold=0.5, presence_threshold=0.0): - self._yolo_model = yolo_model - self._pose_estimator_model = pose_estimator_model + self._device = device + self._yolo_model = yolo_model.to(device) + self._pose_estimator_model = pose_estimator_model.to(device) transforms = CocoDetectionValidationTransforms(yolo_model.get_image_size(), one_hot_class=False) @@ -110,33 +113,37 @@ def evaluate(self): return self._evaluate_coco() def _get_results(self): - results = [] - for image, _, metadata in tqdm(self._dataset): - yolo_predictions = self._yolo_model.forward(image.unsqueeze(0)) - yolo_predictions = group_predictions(yolo_predictions)[0] - yolo_predictions = filter_yolo_predictions(yolo_predictions, - confidence_threshold=self._confidence_threshold, - nms_threshold=self._nms_threshold) + with torch.no_grad(): + results = [] + for image, _, metadata in tqdm(self._dataset): + yolo_predictions = self._yolo_model.forward(image.unsqueeze(0).to(self._device)) + yolo_predictions = group_predictions(yolo_predictions)[0] + yolo_predictions = filter_yolo_predictions_by_classes(yolo_predictions, + confidence_threshold=self._confidence_threshold, + nms_threshold=self._nms_threshold) - results.extend(self._get_image_results(yolo_predictions, metadata)) + results.extend(self._get_image_results(yolo_predictions, metadata)) - return results + return results def _get_image_results(self, yolo_predictions, metadata): image_id = metadata['image_id'] scale = metadata['scale'] + offset_x = metadata['offset_x'] + offset_y = metadata['offset_y'] initial_width = metadata['initial_width'] initial_height = metadata['initial_height'] results = [] for yolo_prediction in yolo_predictions: - class_index = torch.argmax(yolo_prediction[CLASSES_INDEX:], dim=0).item() + class_probs = yolo_prediction[CLASSES_INDEX:] + class_index = torch.argmax(class_probs, dim=0).item() confidence = yolo_prediction[CONFIDENCE_INDEX].item() if class_index != PERSON_CLASS_INDEX or confidence < self._confidence_threshold: continue - center_x = (yolo_prediction[X_INDEX] / scale).item() - center_y = (yolo_prediction[Y_INDEX] / scale).item() + center_x = ((yolo_prediction[X_INDEX] - offset_x) / scale).item() + center_y = ((yolo_prediction[Y_INDEX] - offset_y) / scale).item() width = (yolo_prediction[W_INDEX] / scale).item() * BBOX_SCALE height = (yolo_prediction[H_INDEX] / scale).item() * BBOX_SCALE @@ -155,7 +162,7 @@ def _get_image_results(self, yolo_predictions, metadata): def _get_heatmap_prediction(self, image_id, x0, y0, x1, y1): file = '{:012d}.jpg'.format(image_id) 
image_tensor = TF.to_tensor(Image.open(os.path.join(self._image_root_path, file)).convert('RGB')) - image_tensor = image_tensor[:, y0:y1, x0:x1] + image_tensor = image_tensor[:, y0:y1, x0:x1].to(self._device) image_tensor = F.interpolate(image_tensor.unsqueeze(0), size=POSE_ESTIMATOR_IMAGE_SIZE, mode='bilinear') image_tensor = self._pose_estimator_normalization(image_tensor.squeeze(0)) @@ -208,30 +215,32 @@ def _evaluate_coco(self): def main(): parser = argparse.ArgumentParser(description='Test pose estimator with detected person') + parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path', required=True) parser.add_argument('--dataset_split', choices=['validation', 'test'], required=True) parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) - parser.add_argument('--yolo_model_type', choices=['yolo_v4', 'yolo_v4_tiny'], + parser.add_argument('--yolo_model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7', 'yolo_v7_tiny'], help='Choose the model type', required=True) - parser.add_argument('--yolo_model_checkpoint', type=str, help='Choose the model checkpoint file', required=True) + parser.add_argument('--yolo_model_checkpoint', type=str, help='Choose the model checkpoint file for YOLO', + required=True) - parser.add_argument('--pose_backbone_type', - choices=['mnasnet0.5', 'mnasnet1.0', 'resnet18', 'resnet34', 'resnet50'], + parser.add_argument('--pose_backbone_type', choices=POSE_BACKBONE_TYPES, help='Choose the backbone type', required=True) - parser.add_argument('--pose_upsampling_count', type=int, help='Set the upsamping layer count', required=True) - parser.add_argument('--pose_model_checkpoint', type=str, help='Choose the model checkpoint file', required=True) + parser.add_argument('--pose_model_checkpoint', type=str, help='Choose the model checkpoint file for the pose', + required=True) args = parser.parse_args() - yolo_model = create_yolo_v4_model(args.yolo_model_type) + device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') + yolo_model = create_yolo_model(args.yolo_model_type, class_probs=True) load_checkpoint(yolo_model, args.yolo_model_checkpoint) - pose_estimator_model = create_pose_estimator_model(args.pose_backbone_type, args.pose_upsampling_count) + pose_estimator_model = create_pose_estimator_model(args.pose_backbone_type) load_checkpoint(pose_estimator_model, args.pose_model_checkpoint) - evaluation = CocoPoseEvaluationWithYoloV4(yolo_model, pose_estimator_model, - args.dataset_root, args.dataset_split, args.output_path) + evaluation = CocoPoseEvaluationWithYolo(yolo_model, pose_estimator_model, device, + args.dataset_root, args.dataset_split, args.output_path) evaluation.evaluate() diff --git a/tools/dnn_training/train_audio_descriptor_extractor.py b/tools/dnn_training/train_audio_descriptor_extractor.py index 5d62213c..dc48045e 100644 --- a/tools/dnn_training/train_audio_descriptor_extractor.py +++ b/tools/dnn_training/train_audio_descriptor_extractor.py @@ -3,12 +3,10 @@ import torch +from export_audio_descriptor_extractor import create_model + from common.program_arguments import save_arguments, print_arguments -from audio_descriptor.backbones import Mnasnet0_5, Mnasnet1_0, Resnet18, Resnet34, Resnet50, OpenFaceInception -from audio_descriptor.backbones import ThinResnet34, EcapaTdnn, SmallEcapaTdnn -from audio_descriptor.audio_descriptor_extractor import AudioDescriptorExtractor, 
AudioDescriptorExtractorVLAD -from audio_descriptor.audio_descriptor_extractor import AudioDescriptorExtractorSAP from audio_descriptor.trainers import AudioDescriptorExtractorTrainer @@ -22,10 +20,11 @@ def main(): 'open_face_inception', 'thin_resnet_34', 'ecapa_tdnn_512', 'ecapa_tdnn_1024', 'small_ecapa_tdnn_128', 'small_ecapa_tdnn_256', - 'small_ecapa_tdnn_512'], + 'small_ecapa_tdnn_512', 'small_ecapa_tdnn_1024', + 'passt_s_n', 'passt_s_n_l'], help='Choose the backbone type', required=True) parser.add_argument('--embedding_size', type=int, help='Set the embedding size', required=True) - parser.add_argument('--pooling_layer', choices=['avg', 'vlad', 'sap'], help='Set the pooling layer') + parser.add_argument('--pooling_layer', choices=['avg', 'vlad', 'sap', 'psla'], help='Set the pooling layer') parser.add_argument('--waveform_size', type=int, help='Set the waveform size', required=True) parser.add_argument('--n_features', type=int, help='Set n_features', required=True) parser.add_argument('--n_fft', type=int, help='Set n_fft', required=True) @@ -59,7 +58,7 @@ def main(): pooling_layer=args.pooling_layer) elif args.criterion_type == 'am_softmax_loss' and args.dataset_class_count is not None: model = create_model(args.backbone_type, args.n_features, args.embedding_size, args.dataset_class_count, - am_softmax_linear=True, pooling_layer=args.pooling_layer) + normalized_linear=True, pooling_layer=args.pooling_layer) else: raise ValueError('--dataset_class_count must be used with "cross_entropy_loss" and "am_softmax_loss" criterion ' 'types') @@ -93,52 +92,5 @@ def main(): trainer.train() -def create_model(backbone_type, n_features, embedding_size, - class_count=None, am_softmax_linear=False, pooling_layer='avg', conv_bias=False): - pretrained = True - - backbone = create_backbone(backbone_type, n_features, pretrained, conv_bias) - if pooling_layer == 'avg': - return AudioDescriptorExtractor(backbone, embedding_size=embedding_size, - class_count=class_count, am_softmax_linear=am_softmax_linear) - elif pooling_layer == 'vlad': - return AudioDescriptorExtractorVLAD(backbone, embedding_size=embedding_size, - class_count=class_count, am_softmax_linear=am_softmax_linear) - elif pooling_layer == 'sap': - return AudioDescriptorExtractorSAP(backbone, embedding_size=embedding_size, - class_count=class_count, am_softmax_linear=am_softmax_linear) - else: - raise ValueError('Invalid pooling layer') - - -def create_backbone(backbone_type, n_features, pretrained, conv_bias=False): - if backbone_type == 'mnasnet0.5': - return Mnasnet0_5(pretrained=pretrained) - elif backbone_type == 'mnasnet1.0': - return Mnasnet1_0(pretrained=pretrained) - elif backbone_type == 'resnet18': - return Resnet18(pretrained=pretrained) - elif backbone_type == 'resnet34': - return Resnet34(pretrained=pretrained) - elif backbone_type == 'resnet50': - return Resnet50(pretrained=pretrained) - elif backbone_type == 'open_face_inception': - return OpenFaceInception(conv_bias) - elif backbone_type == 'thin_resnet_34': - return ThinResnet34() - elif backbone_type == 'ecapa_tdnn_512': - return EcapaTdnn(n_features, channels=512) - elif backbone_type == 'ecapa_tdnn_1024': - return EcapaTdnn(n_features, channels=1024) - elif backbone_type == 'small_ecapa_tdnn_128': - return SmallEcapaTdnn(n_features, channels=128) - elif backbone_type == 'small_ecapa_tdnn_256': - return SmallEcapaTdnn(n_features, channels=256) - elif backbone_type == 'small_ecapa_tdnn_512': - return SmallEcapaTdnn(n_features, channels=512) - else: - raise 
ValueError('Invalid backbone type') - - if __name__ == '__main__': main() diff --git a/tools/dnn_training/train_backbone.py b/tools/dnn_training/train_backbone.py index d9a9495f..79227305 100644 --- a/tools/dnn_training/train_backbone.py +++ b/tools/dnn_training/train_backbone.py @@ -6,7 +6,8 @@ from common.program_arguments import save_arguments, print_arguments from backbone.stdc import Stdc1, Stdc2 -from backbone.trainers import BackboneTrainer +from backbone.vit import Vit +from backbone.trainers import BackboneTrainer, IMAGE_SIZE from backbone.datasets.classification_image_net import CLASS_COUNT as IMAGE_NET_CLASS_COUNT from backbone.datasets.classification_open_images import CLASS_COUNT as OPEN_IMAGES_CLASS_COUNT @@ -18,7 +19,10 @@ def main(): parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) parser.add_argument('--dataset_type', choices=['image_net', 'open_images'], help='Choose the dataset type', required=True) - parser.add_argument('--model_type', choices=['stdc1', 'stdc2'], help='Choose the model type', required=True) + parser.add_argument('--model_type', choices=['stdc1', 'stdc2', 'passt_s_n', 'passt_s_n_l'], + help='Choose the model type', required=True) + parser.add_argument('--dropout_rate', type=float, help='Choose the dropout rate for passt_s_n and passt_s_n_l', + default=0.0) parser.add_argument('--learning_rate', type=float, help='Choose the learning rate', required=True) parser.add_argument('--weight_decay', type=float, help='Choose the weight decay', required=True) @@ -32,7 +36,7 @@ def main(): args = parser.parse_args() - model = create_model(args.model_type, args.dataset_type) + model = create_model(args.model_type, args.dataset_type, args.dropout_rate) device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') output_path = os.path.join(args.output_path, args.model_type + '_' + args.criterion_type + '_' + @@ -53,7 +57,7 @@ def main(): trainer.train() -def create_model(model_type, dataset_type): +def create_model(model_type, dataset_type, dropout_rate): if dataset_type == 'image_net': class_count = IMAGE_NET_CLASS_COUNT elif dataset_type == 'open_images': @@ -65,6 +69,12 @@ def create_model(model_type, dataset_type): return Stdc1(class_count=class_count, dropout=0.0) elif model_type == 'stdc2': return Stdc2(class_count=class_count, dropout=0.0) + elif model_type == 'passt_s_n': + return Vit(IMAGE_SIZE, class_count=class_count, depth=12, + dropout_rate=dropout_rate, attention_dropout_rate=dropout_rate) + elif model_type == 'passt_s_n_l': + return Vit(IMAGE_SIZE, class_count=class_count, depth=7, + dropout_rate=dropout_rate, attention_dropout_rate=dropout_rate) else: raise ValueError('Invalid backbone type') diff --git a/tools/dnn_training/train_descriptor_yolo_v4.py b/tools/dnn_training/train_descriptor_yolo.py similarity index 93% rename from tools/dnn_training/train_descriptor_yolo_v4.py rename to tools/dnn_training/train_descriptor_yolo.py index 1197ee0c..c91b309a 100644 --- a/tools/dnn_training/train_descriptor_yolo_v4.py +++ b/tools/dnn_training/train_descriptor_yolo.py @@ -11,6 +11,7 @@ from object_detection.trainers.descriptor_yolo_v4_trainer import DescriptorYoloV4Trainer from object_detection.descriptor_yolo_v4 import DescriptorYoloV4 from object_detection.descriptor_yolo_v4_tiny import DescriptorYoloV4Tiny +from object_detection.descriptor_yolo_v7 import DescriptorYoloV7 def main(): @@ -20,7 +21,7 @@ def main(): parser.add_argument('--dataset_type', choices=['coco', 'open_images'], 
help='Choose the database type', required=True) parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) - parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny'], + parser.add_argument('--model_type', choices=['yolo_v4', 'yolo_v4_tiny', 'yolo_v7'], help='Choose the model type', required=True) parser.add_argument('--descriptor_size', type=int, help='Choose the descriptor size', required=True) parser.add_argument('--class_criterion_type', choices=['cross_entropy_loss', 'bce_loss', 'sigmoid_focal_loss'], @@ -61,6 +62,8 @@ def create_model(model_type, descriptor_size, dataset_type): model = DescriptorYoloV4(class_count, descriptor_size) elif model_type == 'yolo_v4_tiny': model = DescriptorYoloV4Tiny(class_count, descriptor_size) + elif model_type == 'yolo_v7': + model = DescriptorYoloV7(class_count, descriptor_size) else: raise ValueError('Invalid model type') @@ -72,6 +75,8 @@ def _get_class_count(dataset_type): return COCO_CLASS_COUNT elif dataset_type == 'open_images': return OPEN_IMAGES_CLASS_COUNT + elif dataset_type == 'objects365': + return 365 else: raise ValueError('Invalid dataset type') diff --git a/tools/dnn_training/train_face_descriptor_extractor.py b/tools/dnn_training/train_face_descriptor_extractor.py index 5a84d4f2..d0892fa0 100644 --- a/tools/dnn_training/train_face_descriptor_extractor.py +++ b/tools/dnn_training/train_face_descriptor_extractor.py @@ -5,16 +5,20 @@ from common.program_arguments import save_arguments, print_arguments -from face_recognition.face_descriptor_extractor import FaceDescriptorExtractor -from face_recognition.trainers import FaceDescriptorExtractorTrainer +from face_recognition.face_descriptor_extractor import FaceDescriptorExtractor, OpenFaceBackbone, EfficientNetBackbone +from face_recognition.trainers import FaceDescriptorExtractorTrainer, FaceDescriptorExtractorDistillationTrainer + +BACKBONE_TYPES = ['open_face', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] def main(): parser = argparse.ArgumentParser(description='Train Backbone') parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') - parser.add_argument('--vvgface2_dataset_root', type=str, help='Choose the Vggface2 root path', required=True) + parser.add_argument('--dataset_roots', nargs='+', type=str, help='Choose the dataset root paths', required=True) parser.add_argument('--lfw_dataset_root', type=str, help='Choose the LFW root path', required=True) parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) + parser.add_argument('--backbone_type', choices=BACKBONE_TYPES, help='Choose the backbone type', required=True) parser.add_argument('--embedding_size', type=int, help='Set the embedding size', required=True) parser.add_argument('--margin', type=float, help='Set the margin', required=True) @@ -22,51 +26,90 @@ def main(): parser.add_argument('--weight_decay', type=float, help='Choose the weight decay', required=True) parser.add_argument('--batch_size', type=int, help='Set the batch size for the training', required=True) parser.add_argument('--epoch_count', type=int, help='Choose the epoch count', required=True) - parser.add_argument('--criterion_type', choices=['triplet_loss', 'cross_entropy_loss', 'am_softmax_loss'], + parser.add_argument('--criterion_type', + choices=['triplet_loss', 'cross_entropy_loss', 'am_softmax_loss', 'arc_face_loss'], help='Choose the criterion 
type', required=True) parser.add_argument('--dataset_class_count', type=int, - help='Choose the dataset class count when criterion_type is "cross_entropy_loss" or ' - '"am_softmax_loss"', + help='Choose the dataset class count when criterion_type is "cross_entropy_loss", ' + '"am_softmax_loss" or "arc_face_loss"', default=None) parser.add_argument('--model_checkpoint', type=str, help='Choose the model checkpoint file', default=None) + parser.add_argument('--teacher_backbone_type', choices=BACKBONE_TYPES, help='Choose the teacher backbone type', + default=None) + parser.add_argument('--teacher_model_checkpoint', type=str, help='Choose the teacher model checkpoint file', + default=None) + args = parser.parse_args() - if args.criterion_type == 'triplet_loss' and args.dataset_class_count is None: - model = create_model(args.embedding_size) - elif args.criterion_type == 'cross_entropy_loss' and args.dataset_class_count is not None: - model = create_model(args.embedding_size, args.dataset_class_count) - elif args.criterion_type == 'am_softmax_loss' and args.dataset_class_count is not None: - model = create_model(args.embedding_size, args.dataset_class_count, am_softmax_linear=True) - else: - raise ValueError('--dataset_class_count must be used with "cross_entropy_loss" or "am_softmax_loss" types') + model = create_model_from_criterion_type(args.criterion_type, args.dataset_class_count, args.backbone_type, + args.embedding_size) device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') - output_path = os.path.join(args.output_path, 'e' + str(args.embedding_size) + + output_path = os.path.join(args.output_path, args.backbone_type + '_e' + str(args.embedding_size) + '_' + args.criterion_type + '_lr' + str(args.learning_rate) + - '_wd' + str(args.weight_decay)) + '_wd' + str(args.weight_decay) + '_m' + str(args.margin) + + '_t' + str(args.teacher_backbone_type)) save_arguments(output_path, args) print_arguments(args) - trainer = FaceDescriptorExtractorTrainer(device, model, - epoch_count=args.epoch_count, - learning_rate=args.learning_rate, - weight_decay=args.weight_decay, - criterion_type=args.criterion_type, - vvgface2_dataset_root=args.vvgface2_dataset_root, - lfw_dataset_root=args.lfw_dataset_root, - output_path=output_path, - batch_size=args.batch_size, - margin=args.margin, - model_checkpoint=args.model_checkpoint) + if args.teacher_backbone_type is not None and args.teacher_model_checkpoint is not None: + teacher_model = create_model_from_criterion_type(args.criterion_type, args.dataset_class_count, + args.teacher_backbone_type, args.embedding_size) + trainer = FaceDescriptorExtractorDistillationTrainer(device, model, teacher_model, + epoch_count=args.epoch_count, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + criterion_type=args.criterion_type, + dataset_roots=args.dataset_roots, + lfw_dataset_root=args.lfw_dataset_root, + output_path=output_path, + batch_size=args.batch_size, + margin=args.margin, + student_model_checkpoint=args.model_checkpoint, + teacher_model_checkpoint=args.teacher_model_checkpoint) + elif args.teacher_backbone_type is not None or args.teacher_model_checkpoint is not None: + raise ValueError('teacher_backbone_type and teacher_model_checkpoint must be set.') + else: + trainer = FaceDescriptorExtractorTrainer(device, model, + epoch_count=args.epoch_count, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + criterion_type=args.criterion_type, + dataset_roots=args.dataset_roots, + 
lfw_dataset_root=args.lfw_dataset_root, + output_path=output_path, + batch_size=args.batch_size, + margin=args.margin, + model_checkpoint=args.model_checkpoint) trainer.train() -def create_model(embedding_size, class_count=None, am_softmax_linear=False): - return FaceDescriptorExtractor(embedding_size=embedding_size, +def create_model_from_criterion_type(criterion_type, dataset_class_count, backbone_type, embedding_size): + if criterion_type == 'triplet_loss' and dataset_class_count is None: + return create_model(backbone_type, embedding_size) + elif criterion_type == 'cross_entropy_loss' and dataset_class_count is not None: + return create_model(backbone_type, embedding_size, dataset_class_count) + elif (criterion_type == 'am_softmax_loss' or criterion_type == 'arc_face_loss') \ + and dataset_class_count is not None: + return create_model(backbone_type, embedding_size, dataset_class_count, normalized_linear=True) + else: + raise ValueError('--dataset_class_count must be used with "cross_entropy_loss", "am_softmax_loss" or "arc_face_loss" types') + + +def create_model(backbone_type, embedding_size, class_count=None, normalized_linear=False): + if backbone_type == 'open_face': + backbone = OpenFaceBackbone() + elif backbone_type.startswith('efficientnet_b'): + backbone = EfficientNetBackbone(backbone_type, pretrained_backbone=True) + else: + raise ValueError('Invalid backbone') + + return FaceDescriptorExtractor(backbone, + embedding_size=embedding_size, class_count=class_count, - am_softmax_linear=am_softmax_linear) + normalized_linear=normalized_linear) if __name__ == '__main__': diff --git a/tools/dnn_training/train_keyword_spotter.py b/tools/dnn_training/train_keyword_spotter.py index c0b9b008..8aba6075 100644 --- a/tools/dnn_training/train_keyword_spotter.py +++ b/tools/dnn_training/train_keyword_spotter.py @@ -5,7 +5,7 @@ from common.program_arguments import save_arguments, print_arguments -from keyword_spotting.keyword_spotter import KeywordSpotter +from export_keyword_spotter import create_model from keyword_spotting.trainers import KeywordSpotterTrainer @@ -51,14 +51,5 @@ def main(): trainer.train() -def create_model(dataset_type): - if dataset_type == 'google_speech_commands': - return KeywordSpotter(class_count=36, use_softmax=False) - elif dataset_type == 'ttop_keyword': - return KeywordSpotter(class_count=2, use_softmax=False) - else: - raise ValueError('Invalid database type') - - if __name__ == '__main__': main() diff --git a/tools/dnn_training/train_multiclass_audio_descriptor_extractor.py b/tools/dnn_training/train_multiclass_audio_descriptor_extractor.py index 6fbe2b31..4f9f2712 100644 --- a/tools/dnn_training/train_multiclass_audio_descriptor_extractor.py +++ b/tools/dnn_training/train_multiclass_audio_descriptor_extractor.py @@ -7,34 +7,37 @@ from audio_descriptor.trainers import MulticlassAudioDescriptorExtractorTrainer -from train_audio_descriptor_extractor import create_model +from export_audio_descriptor_extractor import create_model def main(): parser = argparse.ArgumentParser(description='Train Backbone') parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') - parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path (FSD50k)', required=True) + parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path (AudioSet or FSD50k)', required=True) + parser.add_argument('--dataset_type', choices=['audio_set', 'fsd50k'], help='Choose the database type', + required=True) parser.add_argument('--output_path', type=str, help='Choose the output 
path', required=True) parser.add_argument('--backbone_type', choices=['mnasnet0.5', 'mnasnet1.0', - 'resnet18', 'resnet34', 'resnet50', + 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'open_face_inception', 'thin_resnet_34', 'ecapa_tdnn_512', 'ecapa_tdnn_1024', 'small_ecapa_tdnn_128', 'small_ecapa_tdnn_256', - 'small_ecapa_tdnn_512'], + 'small_ecapa_tdnn_512', + 'passt_s_n', 'passt_s_n_l'], help='Choose the backbone type', required=True) parser.add_argument('--embedding_size', type=int, help='Set the embedding size', required=True) - parser.add_argument('--pooling_layer', choices=['avg', 'vlad', 'sap'], help='Set the pooling layer') + parser.add_argument('--pooling_layer', choices=['avg', 'vlad', 'sap', 'psla'], help='Set the pooling layer') parser.add_argument('--waveform_size', type=int, help='Set the waveform size', required=True) parser.add_argument('--n_features', type=int, help='Set n_features', required=True) parser.add_argument('--n_fft', type=int, help='Set n_fft', required=True) - parser.add_argument('--audio_transform_type', choices=['mfcc', 'mel_spectrogram', 'spectrogram'], + parser.add_argument('--audio_transform_type', + choices=['mfcc', 'log_mel_spectrogram', 'mel_spectrogram', 'spectrogram', 'psla'], help='Choose the audio transform type', required=True) parser.add_argument('--enable_pitch_shifting', action='store_true', help='Use pitch shifting data augmentation') parser.add_argument('--enable_time_stretching', action='store_true', help='Use pitch shifting data augmentation') parser.add_argument('--enable_time_masking', action='store_true', help='Use time masking data augmentation') parser.add_argument('--enable_frequency_masking', action='store_true', help='Use time masking data augmentation') - parser.add_argument('--enable_pos_weight', action='store_true', help='Use pos weight in the loss') - parser.add_argument('--enable_mixup', action='store_true', help='Use pos weight in the loss') + parser.add_argument('--enhanced_targets', action='store_true', help='Use enhanced targets') parser.add_argument('--learning_rate', type=float, help='Choose the learning rate', required=True) parser.add_argument('--weight_decay', type=float, help='Choose the weight decay', required=True) @@ -47,14 +50,21 @@ def main(): args = parser.parse_args() - model = create_model(args.backbone_type, args.n_features, args.embedding_size, class_count=200, + if args.dataset_type == 'audio_set': + class_count = 527 + elif args.dataset_type == 'fsd50k': + class_count = 200 + else: + raise ValueError('Invalid dataset type') + + model = create_model(args.backbone_type, args.n_features, args.embedding_size, class_count=class_count, pooling_layer=args.pooling_layer) device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') - output_path = os.path.join(args.output_path, args.backbone_type + '_e' + str(args.embedding_size) + - '_' + args.audio_transform_type + '_mixup' + str(int(args.enable_mixup)) + - '_' + args.criterion_type + '_lr' + str(args.learning_rate) + - '_wd' + str(args.weight_decay)) + output_path = os.path.join(args.output_path, args.backbone_type + '_' + args.dataset_type + + '_e' + str(args.embedding_size) + '_' + args.audio_transform_type + + '_enhanced' + str(int(args.enhanced_targets)) + '_' + args.criterion_type + + '_lr' + str(args.learning_rate) + '_wd' + str(args.weight_decay)) save_arguments(output_path, args) print_arguments(args) @@ -63,6 +73,7 @@ def main(): learning_rate=args.learning_rate, weight_decay=args.weight_decay, 
dataset_root=args.dataset_root, + dataset_type=args.dataset_type, output_path=output_path, batch_size=args.batch_size, waveform_size=args.waveform_size, @@ -73,8 +84,7 @@ def main(): enable_time_stretching=args.enable_time_stretching, enable_time_masking=args.enable_time_masking, enable_frequency_masking=args.enable_frequency_masking, - enable_pos_weight=args.enable_pos_weight, - enable_mixup=args.enable_mixup, + enhanced_targets=args.enhanced_targets, model_checkpoint=args.model_checkpoint) trainer.train() diff --git a/tools/dnn_training/train_pose_estimator.py b/tools/dnn_training/train_pose_estimator.py index bf853315..11d46b52 100644 --- a/tools/dnn_training/train_pose_estimator.py +++ b/tools/dnn_training/train_pose_estimator.py @@ -5,73 +5,83 @@ from common.program_arguments import save_arguments, print_arguments -from pose_estimation.backbones import Mnasnet0_5, Mnasnet1_0, Resnet18, Resnet34, Resnet50 -from pose_estimation.pose_estimator import PoseEstimator -from pose_estimation.trainers import PoseEstimatorTrainer +from pose_estimation.pose_estimator import EfficientNetPoseEstimator +from pose_estimation.trainers import PoseEstimatorTrainer, PoseEstimatorDistillationTrainer +BACKBONE_TYPES = ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', + 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7'] -# Train a model like : https://github.com/microsoft/human-pose-estimation.pytorch + +# Train a model similar to https://github.com/microsoft/human-pose-estimation.pytorch, but with residual connections. def main(): + parser = argparse.ArgumentParser(description='Train Backbone') parser.add_argument('--use_gpu', action='store_true', help='Use the GPU') parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path', required=True) parser.add_argument('--output_path', type=str, help='Choose the output path', required=True) - parser.add_argument('--backbone_type', choices=['mnasnet0.5', 'mnasnet1.0', 'resnet18', 'resnet34', 'resnet50'], - help='Choose the backbone type', required=True) - parser.add_argument('--upsampling_count', type=int, help='Set the upsamping layer count', required=True) + parser.add_argument('--backbone_type', choices=BACKBONE_TYPES, help='Choose the backbone type', required=True) parser.add_argument('--learning_rate', type=float, help='Choose the learning rate', required=True) + parser.add_argument('--weight_decay', type=float, help='Choose the weight decay', required=True) parser.add_argument('--batch_size', type=int, help='Set the batch size for the training', required=True) parser.add_argument('--batch_size_division', type=int, help='Set the batch size for the training', required=True) parser.add_argument('--epoch_count', type=int, help='Choose the epoch count', required=True) + parser.add_argument('--heatmap_sigma', type=float, help='Choose sigma to create the heatmap', required=True) + parser.add_argument('--model_checkpoint', type=str, help='Choose the model checkpoint file', default=None) + parser.add_argument('--teacher_backbone_type', choices=BACKBONE_TYPES, help='Choose the teacher backbone type', + default=None) + parser.add_argument('--teacher_model_checkpoint', type=str, help='Choose the teacher model checkpoint file', + default=None) + parser.add_argument('--distillation_loss_alpha', type=float, help='Choose the alpha for the distillation loss', + default=0.25) + args = parser.parse_args() - model = create_model(args.backbone_type, args.upsampling_count) + model = 
create_model(args.backbone_type) device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu') - output_path = os.path.join(args.output_path, args.backbone_type) + output_path = os.path.join(args.output_path, args.backbone_type + '_sig' + str(args.heatmap_sigma) + + '_lr' + str(args.learning_rate) + '_wd' + str(args.weight_decay) + + '_t' + str(args.teacher_backbone_type) + '_a' + str(args.distillation_loss_alpha)) save_arguments(output_path, args) print_arguments(args) - trainer = PoseEstimatorTrainer(device, model, - epoch_count=args.epoch_count, - learning_rate=args.learning_rate, - dataset_root=args.dataset_root, - output_path=output_path, - batch_size=args.batch_size, - batch_size_division=args.batch_size_division, - model_checkpoint=args.model_checkpoint) + if args.teacher_backbone_type is not None and args.teacher_model_checkpoint is not None: + teacher_model = create_model(args.teacher_backbone_type) + trainer = PoseEstimatorDistillationTrainer(device, model, teacher_model, + epoch_count=args.epoch_count, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + dataset_root=args.dataset_root, + output_path=output_path, + batch_size=args.batch_size, + batch_size_division=args.batch_size_division, + heatmap_sigma=args.heatmap_sigma, + student_model_checkpoint=args.model_checkpoint, + teacher_model_checkpoint=args.teacher_model_checkpoint, + loss_alpha=args.distillation_loss_alpha) + elif args.teacher_backbone_type is not None or args.teacher_model_checkpoint is not None: + raise ValueError('teacher_backbone_type and teacher_model_checkpoint must be set.') + else: + trainer = PoseEstimatorTrainer(device, model, + epoch_count=args.epoch_count, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + dataset_root=args.dataset_root, + output_path=output_path, + batch_size=args.batch_size, + batch_size_division=args.batch_size_division, + heatmap_sigma=args.heatmap_sigma, + model_checkpoint=args.model_checkpoint) trainer.train() -def create_model(backbone_type, upsampling_count): - pretrained = True - keypoint_count = 17 - if backbone_type == 'mnasnet0.5': - return PoseEstimator(Mnasnet0_5(pretrained=pretrained), - keypoint_count=keypoint_count, - upsampling_count=upsampling_count) - elif backbone_type == 'mnasnet1.0': - return PoseEstimator(Mnasnet1_0(pretrained=pretrained), - keypoint_count=keypoint_count, - upsampling_count=upsampling_count) - elif backbone_type == 'resnet18': - return PoseEstimator(Resnet18(pretrained=pretrained), - keypoint_count=keypoint_count, - upsampling_count=upsampling_count) - elif backbone_type == 'resnet34': - return PoseEstimator(Resnet34(pretrained=pretrained), - keypoint_count=keypoint_count, - upsampling_count=upsampling_count) - elif backbone_type == 'resnet50': - return PoseEstimator(Resnet50(pretrained=pretrained), - keypoint_count=keypoint_count, - upsampling_count=upsampling_count) - else: - raise ValueError('Invalid backbone type') + +def create_model(backbone_type): + return EfficientNetPoseEstimator(backbone_type, keypoint_count=17, pretrained_backbone=True) if __name__ == '__main__': diff --git a/tools/odas_configuration_generator/odas_microphone_configuration.py b/tools/odas_configuration_generator/odas_microphone_configuration.py index 9609501d..12ad2cf3 100644 --- a/tools/odas_configuration_generator/odas_microphone_configuration.py +++ b/tools/odas_configuration_generator/odas_microphone_configuration.py @@ -18,8 +18,7 @@ TOP_MICROPHONE_WALL_HEIGHT = 0.0599928907 MIDDLE_HEIGHT = 
(BOTTOM_MICROPHONE_WALL_HEIGHT + TOP_MICROPHONE_WALL_HEIGHT) / 2 -# Allow to specify the height of the desired odas frame relative to the robot base, or use the default value (classic T-Top) -Z_OFFSET = (MIDDLE_HEIGHT + float(argv[1])) if len(argv) >= 2 else -0.30753835046 +Z_OFFSET = -MIDDLE_HEIGHT * math.sin(WALL_ANGLE) #Position and direction calculation @@ -89,9 +88,9 @@ ax.mouse_init() axis_length = BOTTOM_RADIUS / 2 -ax.plot([0, axis_length], [0, 0], [Z_OFFSET, Z_OFFSET], color='red') -ax.plot([0, 0], [0, axis_length], [Z_OFFSET, Z_OFFSET], color='green') -ax.plot([0, 0], [0, 0], [Z_OFFSET, Z_OFFSET + axis_length], color='blue') +ax.plot([0, axis_length], [0, 0], [0, 0], color='red') +ax.plot([0, 0], [0, axis_length], [0, 0], color='green') +ax.plot([0, 0], [0, 0], [0, axis_length], color='blue') direction_length = BOTTOM_RADIUS / 4 for i in range(MICROPHONE_COUNT):
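Usage note for the new pose estimator distillation path: the sketch below shows how the student/teacher pair introduced in this changeset can be wired into PoseEstimatorDistillationTrainer, mirroring the distillation branch of train_pose_estimator.py. Only create_model and PoseEstimatorDistillationTrainer come from the changeset; the backbone choices, paths, checkpoint name and hyperparameter values are illustrative placeholders, not values taken from the repository.

# Minimal sketch, assuming the dataset root and teacher checkpoint below exist.
import torch

from train_pose_estimator import create_model
from pose_estimation.trainers import PoseEstimatorDistillationTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

student_model = create_model('efficientnet_b0')  # smaller student backbone
teacher_model = create_model('efficientnet_b4')  # larger teacher backbone

trainer = PoseEstimatorDistillationTrainer(device, student_model, teacher_model,
                                           dataset_root='datasets/coco',            # placeholder path
                                           output_path='output/pose_distillation',  # placeholder path
                                           epoch_count=10,
                                           learning_rate=0.001,
                                           weight_decay=0.0,
                                           batch_size=32,
                                           batch_size_division=4,
                                           heatmap_sigma=10,
                                           teacher_model_checkpoint='teacher.pth',  # placeholder checkpoint
                                           loss_alpha=0.25)
trainer.train()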