From 42c30d9aaa41fa2fa2daf30c65469f64372ce509 Mon Sep 17 00:00:00 2001
From: umerhasan17
Date: Sun, 13 Aug 2023 22:23:30 +0100
Subject: [PATCH] Implement Horizontal Flipping for test time augmentation for
 object detection

---
 src/openpifpaf/decoder/decoder.py     | 26 ++++++++++++-
 src/openpifpaf/decoder/utils/hflip.py | 54 +++++++++++++++++++++++++++
 src/openpifpaf/predictor.py           |  9 ++++-
 3 files changed, 86 insertions(+), 3 deletions(-)
 create mode 100644 src/openpifpaf/decoder/utils/hflip.py

diff --git a/src/openpifpaf/decoder/decoder.py b/src/openpifpaf/decoder/decoder.py
index 9d26463e9..20c15f567 100644
--- a/src/openpifpaf/decoder/decoder.py
+++ b/src/openpifpaf/decoder/decoder.py
@@ -7,6 +7,7 @@
 
 import torch
 
+from .utils.hflip import hflip_average_fields_batch
 from .. import annotation, visualizer
 
 LOG = logging.getLogger(__name__)
@@ -111,10 +112,31 @@ def apply(f, items):
         LOG.debug('nn processing time: %.1fms', (time.time() - start) * 1000.0)
         return heads
 
-    def batch(self, model, image_batch, *, device=None, gt_anns_batch=None):
+    def batch(self, model, image_batch, *, device=None, hflip=False, gt_anns_batch=None):
         """From image batch straight to annotations batch."""
         start_nn = time.perf_counter()
-        fields_batch = self.fields_batch(model, image_batch, device=device)
+
+        if hflip:
+            # The horizontal-flip evaluation technique improves accuracy when evaluating the test set.
+            # We average the predictions generated for the original and flipped image and use that as the
+            # final prediction. This method reduces prediction noise.
+
+            # Take the horizontally flipped image and generate fields.
+            hflip_image_batch = torch.flip(image_batch, [-1])
+            combined_image_batch = torch.cat((image_batch, hflip_image_batch), dim=0)
+            combined_fields_batch = self.fields_batch(model, combined_image_batch, device=device)
+            cfb_len = len(combined_fields_batch)
+            assert cfb_len % 2 == 0
+            fields_batch = combined_fields_batch[:cfb_len // 2]
+            hflip_fields_batch = combined_fields_batch[cfb_len // 2:]
+
+            # Average the fields with the original fields before decoding to the final prediction.
+            fields_batch = hflip_average_fields_batch(
+                fields_batch=fields_batch, hflip_fields_batch=hflip_fields_batch, head_metas=model.head_metas
+            )
+        else:
+            fields_batch = self.fields_batch(model, image_batch, device=device)
+
         self.last_nn_time = time.perf_counter() - start_nn
 
         if gt_anns_batch is None:
diff --git a/src/openpifpaf/decoder/utils/hflip.py b/src/openpifpaf/decoder/utils/hflip.py
new file mode 100644
index 000000000..47f70911b
--- /dev/null
+++ b/src/openpifpaf/decoder/utils/hflip.py
@@ -0,0 +1,54 @@
+"""
+Helper methods for horizontally flipping field representations of the image during evaluation.
+"""
+
+import torch
+
+
+def hflip_average_fields_batch(fields_batch, hflip_fields_batch, head_metas):
+    """ Entrypoint function for horizontal flipping. """
+    hflip_funcs = []
+    for head_meta in head_metas:
+        if head_meta.name == 'cifdet':
+            hflip_func = hflip_average_cifdet_fields_batch
+        else:
+            raise ValueError(f'Unsupported head meta for hflip: {head_meta.name}.')
+        hflip_funcs.append(hflip_func)
+
+    for i, current_batch in enumerate(fields_batch):
+        assert len(current_batch) == len(head_metas)
+        for j, field_set in enumerate(current_batch):
+            # Additional processing for hflip field set specific to heads used.
+            hflip_field_set = hflip_funcs[j](hflip_fields_batch[i][j])
+
+            # Take an average of both fields for the final fields batch prediction.
+            field_set = field_set.add(hflip_field_set)
+            field_set = torch.div(field_set, 2)
+            fields_batch[i][j] = field_set
+
+    return fields_batch
+
+
+def hflip_handle_reg_x_offset(hflip_field_set, offset_field_index=2):
+    """ Handle the set of x regression fields that the cifdet head produces. """
+    # Horizontally flip field to perform offset (reverse the operation of flipping all fields)
+    hflip_field_set[:, offset_field_index, :, :] = torch.flip(hflip_field_set[:, offset_field_index, :, :], [-1])
+    # Deal with vector offsets for x regression field
+    fields_shape = hflip_field_set.shape
+    offset_tensor = torch.arange(fields_shape[3]).repeat(fields_shape[0], fields_shape[2], 1)
+    # Remove offset
+    hflip_field_set[:, offset_field_index, :, :] = hflip_field_set[:, offset_field_index, :, :].subtract(offset_tensor)
+    # Horizontally flip field again
+    hflip_field_set[:, offset_field_index, :, :] = torch.flip(hflip_field_set[:, offset_field_index, :, :], [-1])
+    # Negate x regression field
+    hflip_field_set[:, offset_field_index, :, :] = torch.neg(hflip_field_set[:, offset_field_index, :, :])
+    # Add back offset
+    hflip_field_set[:, offset_field_index, :, :] = hflip_field_set[:, offset_field_index, :, :].add(offset_tensor)
+    return hflip_field_set
+
+
+def hflip_average_cifdet_fields_batch(hflip_field_set):
+    """ Function returns the horizontally flipped set of cifdet fields used in object detection tasks. """
+    hflip_field_set = torch.flip(hflip_field_set, [-1])
+    hflip_field_set = hflip_handle_reg_x_offset(hflip_field_set)
+    return hflip_field_set
diff --git a/src/openpifpaf/predictor.py b/src/openpifpaf/predictor.py
index fdd8278d9..d354f25c8 100644
--- a/src/openpifpaf/predictor.py
+++ b/src/openpifpaf/predictor.py
@@ -71,6 +71,9 @@ def cli(cls, parser: argparse.ArgumentParser, *,
         group.add_argument('--precise-rescaling', dest='fast_rescaling',
                            default=True, action='store_false',
                            help='use more exact image rescaling (requires scipy)')
+        group.add_argument('--tta-hflip', dest='tta_hflip',
+                           default=False, action='store_true',
+                           help='apply horizontal flipping as test time augmentation')
 
     @classmethod
     def configure(cls, args: argparse.Namespace):
@@ -80,6 +83,7 @@ def configure(cls, args: argparse.Namespace):
         cls.fast_rescaling = args.fast_rescaling
         cls.loader_workers = args.loader_workers
         cls.long_edge = args.long_edge
+        cls.tta_hflip = args.tta_hflip
 
     def preprocess_factory(self):
         rescale_t = None
@@ -125,7 +129,10 @@ def enumerated_dataloader(self, enumerated_dataloader):
         if self.visualize_processed_image:
             visualizer.Base.processed_image(processed_image_batch[0])
 
-        pred_batch = self.processor.batch(self.model, processed_image_batch, device=self.device)
+        pred_batch = self.processor.batch(
+            self.model, processed_image_batch, hflip=self.tta_hflip, device=self.device
+        )
+
         self.last_decoder_time = self.processor.last_decoder_time
         self.last_nn_time = self.processor.last_nn_time
         self.total_decoder_time += self.processor.last_decoder_time