diff --git a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md index 38f8b42f..aa1b9ca6 100644 --- a/luxonis_train/attached_modules/losses/README.md +++ b/luxonis_train/attached_modules/losses/README.md @@ -118,6 +118,6 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). **Parameters:** -| Key | Type | Default value | Description | -| --------------- | ------- | ------------- | ----------------------------------------------- | -| `object_weight` | `float` | `1000` | Weight for the objects in the loss calculation. | +| Key | Type | Default value | Description | +| --------------- | ------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `object_weight` | `float` | `1000` | Weight for the objects in the loss calculation. Training with a larger `object_weight` in the loss parameters may result in more false positives (FP), but it will improve accuracy. | diff --git a/luxonis_train/config/predefined_models/README.md b/luxonis_train/config/predefined_models/README.md index 9c6727fb..629b17a5 100644 --- a/luxonis_train/config/predefined_models/README.md +++ b/luxonis_train/config/predefined_models/README.md @@ -160,6 +160,8 @@ The `FOMOModel` allows for both `"light"` and `"heavy"` variants, where the `"he See an example configuration file using this predefined model [here](../../../configs/detection_fomo_light_model.yaml) for the `"light"` variant, and [here](../../../configs/detection_fomo_heavy_model.yaml) for the `"heavy"` variant. +There is a trade-off in this simple model: training with a larger `object_weight` in the loss parameters may result in more false positives (FP), but it will improve accuracy. You can also set `use_nms: True` in the `head_params` to enable NMS, which can reduce FP, but it will also reduce true positives (TP) for close neighbors. 
+ ### **Components** | Name | Alias | Function | diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 31f1f6c2..18beb17a 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -268,7 +268,8 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). **Parameters:** -| Key | Type | Default value | Description | -| ----------------- | ----- | ------------- | ------------------------------------------------------- | -| `num_conv_layers` | `int` | `3` | Number of convolutional layers to use in the model. | -| `conv_channels` | `int` | `16` | Number of output channels for each convolutional layer. | +| Key | Type | Default value | Description | +| ----------------- | ------ | ------------- | ---------------------------------------------------------------------------------------- | +| `num_conv_layers` | `int` | `3` | Number of convolutional layers to use in the model. | +| `conv_channels` | `int` | `16` | Number of output channels for each convolutional layer. | +| `use_nms` | `bool` | `False` | If True, enable NMS. This can reduce FP, but it will also reduce TP for close neighbors. | diff --git a/luxonis_train/nodes/heads/fomo_head.py b/luxonis_train/nodes/heads/fomo_head.py index cc9ea1a0..ce63f975 100644 --- a/luxonis_train/nodes/heads/fomo_head.py +++ b/luxonis_train/nodes/heads/fomo_head.py @@ -21,6 +21,7 @@ def __init__( self, num_conv_layers: int = 3, conv_channels: int = 16, + use_nms: bool = False, **kwargs: Any, ): """FOMO Head for object detection using heatmaps. 
@@ -37,6 +38,7 @@ def __init__( self.original_img_size = self.original_in_shape[1:] self.num_conv_layers = num_conv_layers self.conv_channels = conv_channels + self.use_nms = use_nms current_channels = self.in_channels @@ -76,7 +78,11 @@ def wrap(self, heatmap: Tensor) -> Packet[Tensor]: def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]: """Convert heatmap to keypoint pairs using local-max NMS so that - only the strongest local peak in a neighborhood is retained.""" + only the strongest local peak in a neighborhood is retained. + + @type heatmap: Tensor + @param heatmap: Heatmap to convert to keypoints. + """ device = heatmap.device batch_size, num_classes, height, width = heatmap.shape @@ -87,19 +93,25 @@ def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]: for c in range(num_classes): prob_map = torch.sigmoid(heatmap[batch_idx, c, :, :]) - pooled_map = ( - F.max_pool2d( - prob_map.unsqueeze(0).unsqueeze(0), # shape [1,1,H,W] - kernel_size=3, - stride=1, - padding=1, - ) - .squeeze(0) - .squeeze(0) - ) # back to [H,W] - - threshold = 0.5 - keep = (prob_map == pooled_map) & (prob_map > threshold) + if self.use_nms: + pooled_map = ( + F.max_pool2d( + prob_map.unsqueeze(0).unsqueeze( + 0 + ), # shape [1,1,H,W] + kernel_size=3, + stride=1, + padding=1, + ) + .squeeze(0) + .squeeze(0) + ) # back to [H,W] + + threshold = 0.5 + keep = (prob_map == pooled_map) & (prob_map > threshold) + else: + threshold = 0.5 + keep = prob_map > threshold y_indices, x_indices = torch.where(keep) kpts = []