From 9869dee5bba650db0fed55141b79b75aae66694e Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Sun, 22 Sep 2024 19:50:26 +0200 Subject: [PATCH 1/6] Added medium and large options for EfficientRep and RepPANNeck with CSP blocks --- .../backbones/efficientrep/efficientrep.py | 44 ++- .../nodes/backbones/efficientrep/variants.py | 10 + luxonis_train/nodes/blocks/__init__.py | 6 +- luxonis_train/nodes/blocks/blocks.py | 200 ++++++------ .../nodes/necks/reppan_neck/__init__.py | 3 + .../nodes/necks/reppan_neck/blocks.py | 287 ++++++++++++++++++ .../necks/{ => reppan_neck}/reppan_neck.py | 99 ++++-- .../nodes/necks/reppan_neck/variants.py | 54 ++++ 8 files changed, 567 insertions(+), 136 deletions(-) create mode 100644 luxonis_train/nodes/necks/reppan_neck/__init__.py create mode 100644 luxonis_train/nodes/necks/reppan_neck/blocks.py rename luxonis_train/nodes/necks/{ => reppan_neck}/reppan_neck.py (57%) create mode 100644 luxonis_train/nodes/necks/reppan_neck/variants.py diff --git a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py index 0143855c..f41580c6 100644 --- a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py +++ b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py @@ -1,11 +1,12 @@ import logging -from typing import Any +from typing import Any, Literal from torch import Tensor, nn from luxonis_train.nodes.base_node import BaseNode from luxonis_train.nodes.blocks import ( BlockRepeater, + CSPStackRepBlock, RepVGGBlock, SpatialPyramidPoolingBlock, ) @@ -26,9 +27,12 @@ def __init__( n_repeats: list[int] | None = None, depth_mul: float | None = None, width_mul: float | None = None, + block: Literal["RepBlock", "CSPStackRepBlock"] | None = None, + csp_e: float | None = None, **kwargs: Any, ): - """Implementation of the EfficientRep backbone. + """Implementation of the EfficientRep backbone. Supports the + version with RepBlock and CSPStackRepBlock (for larger networks) Adapted from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications @@ -36,13 +40,13 @@ def __init__( @type variant: Literal["n", "nano", "s", "small", "m", "medium", "l", "large"] @param variant: EfficientRep variant. Defaults to "nano". - The variant determines the depth and width multipliers. + The variant determines the depth and width multipliers, block used and intermediate channel scaling factor. The depth multiplier determines the number of blocks in each stage and the width multiplier determines the number of channels. The following variants are available: - - "n" or "nano" (default): depth_multiplier=0.33, width_multiplier=0.25 - - "s" or "small": depth_multiplier=0.33, width_multiplier=0.50 - - "m" or "medium": depth_multiplier=0.60, width_multiplier=0.75 - - "l" or "large": depth_multiplier=1.0, width_multiplier=1.0 + - "n" or "nano" (default): depth_multiplier=0.33, width_multiplier=0.25, block=RepBlock, e=None + - "s" or "small": depth_multiplier=0.33, width_multiplier=0.50, block=RepBlock, e=None + - "m" or "medium": depth_multiplier=0.60, width_multiplier=0.75, block=CSPStackRepBlock, e=2/3 + - "l" or "large": depth_multiplier=1.0, width_multiplier=1.0, block=CSPStackRepBlock, e=1/2 @type channels_list: list[int] | None @param channels_list: List of number of channels for each block. If unspecified, defaults to [64, 128, 256, 512, 1024]. @@ -53,12 +57,19 @@ def __init__( @param depth_mul: Depth multiplier. If provided, overrides the variant value. @type width_mul: float @param width_mul: Width multiplier. If provided, overrides the variant value. + @type block: Literal["RepBlock", "CSPStackRepBlock"] | None + @param block: Base block used when building the backbone. If provided, overrides the variant value. + @tpe csp_e: float | None + @param csp_e: Factor that controls number of intermediate channels if block="CSPStackRepBlock". If provided, + overrides the variant value. """ super().__init__(**kwargs) var = get_variant(variant) depth_mul = depth_mul or var.depth_multiplier width_mul = width_mul or var.width_multiplier + block = block or var.block + csp_e = csp_e or var.csp_e or 0.5 channels_list = channels_list or [64, 128, 256, 512, 1024] n_repeats = n_repeats or [1, 6, 12, 18, 6] @@ -85,11 +96,20 @@ def __init__( kernel_size=3, stride=2, ), - BlockRepeater( - block=RepVGGBlock, - in_channels=channels_list[i + 1], - out_channels=channels_list[i + 1], - n_blocks=n_repeats[i + 1], + ( + BlockRepeater( + block=RepVGGBlock, + in_channels=channels_list[i + 1], + out_channels=channels_list[i + 1], + n_blocks=n_repeats[i + 1], + ) + if block == "RepBlock" + else CSPStackRepBlock( + in_channels=channels_list[i + 1], + out_channels=channels_list[i + 1], + n_blocks=n_repeats[i + 1], + e=csp_e, + ) ), ) self.blocks.append(curr_block) diff --git a/luxonis_train/nodes/backbones/efficientrep/variants.py b/luxonis_train/nodes/backbones/efficientrep/variants.py index 7ced749e..c5640237 100644 --- a/luxonis_train/nodes/backbones/efficientrep/variants.py +++ b/luxonis_train/nodes/backbones/efficientrep/variants.py @@ -10,6 +10,8 @@ class EfficientRepVariant(BaseModel): depth_multiplier: float width_multiplier: float + block: Literal["RepBlock", "CSPStackRepBlock"] + csp_e: float | None def get_variant(variant: VariantLiteral) -> EfficientRepVariant: @@ -17,18 +19,26 @@ def get_variant(variant: VariantLiteral) -> EfficientRepVariant: "n": EfficientRepVariant( depth_multiplier=0.33, width_multiplier=0.25, + block="RepBlock", + csp_e=None, ), "s": EfficientRepVariant( depth_multiplier=0.33, width_multiplier=0.50, + block="RepBlock", + csp_e=None, ), "m": EfficientRepVariant( depth_multiplier=0.60, width_multiplier=0.75, + block="CSPStackRepBlock", + csp_e=2 / 3, ), "l": EfficientRepVariant( depth_multiplier=1.0, width_multiplier=1.0, + block="CSPStackRepBlock", + csp_e=1 / 2, ), } variants["nano"] = variants["n"] diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py index 52c3408e..c35186e1 100644 --- a/luxonis_train/nodes/blocks/__init__.py +++ b/luxonis_train/nodes/blocks/__init__.py @@ -4,6 +4,7 @@ BlockRepeater, Bottleneck, ConvModule, + CSPStackRepBlock, DropPath, EfficientDecoupledBlock, FeatureFusionBlock, @@ -11,8 +12,6 @@ LearnableAdd, LearnableMulAddConv, LearnableMultiply, - RepDownBlock, - RepUpBlock, RepVGGBlock, SpatialPyramidPoolingBlock, SqueezeExciteBlock, @@ -26,10 +25,10 @@ "EfficientDecoupledBlock", "ConvModule", "UpBlock", - "RepDownBlock", "SqueezeExciteBlock", "RepVGGBlock", "BlockRepeater", + "CSPStackRepBlock", "AttentionRefinmentBlock", "SpatialPyramidPoolingBlock", "FeatureFusionBlock", @@ -37,7 +36,6 @@ "LearnableMultiply", "LearnableMulAddConv", "KeypointBlock", - "RepUpBlock", "BasicResNetBlock", "Bottleneck", "UpscaleOnline", diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 99fe2a9a..81c3bbf7 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -113,7 +113,7 @@ def __init__( @type bias: bool @param bias: Whether to use bias. Defaults to False. @type activation: L{nn.Module} | None - @param activation: Activation function. Defaults to None. + @param activation: Activation function. If None then nn.ReLU. """ super().__init__( nn.Conv2d( @@ -407,13 +407,14 @@ def __init__( """ super().__init__() - in_channels = in_channels self.blocks = nn.ModuleList() - for _ in range(n_blocks): + self.blocks.append( + block(in_channels=in_channels, out_channels=out_channels) + ) + for _ in range(n_blocks - 1): self.blocks.append( - block(in_channels=in_channels, out_channels=out_channels) + block(in_channels=out_channels, out_channels=out_channels) ) - in_channels = out_channels def forward(self, x: Tensor) -> Tensor: for block in self.blocks: @@ -421,6 +422,97 @@ def forward(self, x: Tensor) -> Tensor: return x +class CSPStackRepBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + n_blocks: int = 1, + e: float = 0.5, + ): + super().__init__() + """Module composed of three 1x1 conv layers and a stack of sub- + blocks consisting of two RepVGG blocks with a residual + connection. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_blocks: int + @param n_blocks: Number of blocks to repeat. Defaults to C{1}. + @type e: float + @param e: Factor for number of intermediate channels. Defaults + to C{0.5}. + """ + intermediate_channels = int(out_channels * e) + self.conv_1 = ConvModule( + in_channels=in_channels, + out_channels=intermediate_channels, + kernel_size=1, + padding=autopad(1, None), + ) + self.conv_2 = ConvModule( + in_channels=in_channels, + out_channels=intermediate_channels, + kernel_size=1, + padding=autopad(1, None), + ) + self.conv_3 = ConvModule( + in_channels=intermediate_channels * 2, + out_channels=out_channels, + kernel_size=1, + padding=autopad(1, None), + ) + self.rep_stack = BlockRepeater( + block=BottleRep, + in_channels=intermediate_channels, + out_channels=intermediate_channels, + n_blocks=n_blocks // 2, + ) + + def forward(self, x: Tensor) -> Tensor: + out_1 = self.conv_1(x) + out_1 = self.rep_stack(out_1) + out_2 = self.conv_2(x) + out = torch.cat((out_1, out_2), dim=1) + return self.conv_3(out) + + +class BottleRep(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + block: type[nn.Module] = RepVGGBlock, + weight: bool = True, + ): + super().__init__() + """RepVGG bottleneck module. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + @type block: L{nn.Module} + @param block: Block to use. Defaults to C{RepVGGBlock}. + @type weight: bool + @param weight: If using learnable or static shortcut weight. + Defaults to C{True}. + """ + self.conv_1 = block(in_channels=in_channels, out_channels=out_channels) + self.conv_2 = block( + in_channels=out_channels, out_channels=out_channels + ) + self.shortcut = in_channels == out_channels + self.alpha = nn.Parameter(torch.ones(1)) if weight else 1.0 + + def forward(self, x: Tensor) -> Tensor: + out = self.conv_1(x) + out = self.conv_2(out) + return out + self.alpha * x if self.shortcut else out + + class SpatialPyramidPoolingBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 5 @@ -608,104 +700,6 @@ def forward(self, x: Tensor) -> Tensor: return out -class RepUpBlock(nn.Module): - def __init__( - self, - in_channels: int, - in_channels_next: int, - out_channels: int, - n_repeats: int, - ): - """UpBlock used in RepPAN neck. - - @type in_channels: int - @param in_channels: Number of input channels. - @type in_channels_next: int - @param in_channels_next: Number of input channels of next input - which is used in concat. - @type out_channels: int - @param out_channels: Number of output channels. - @type n_repeats: int - @param n_repeats: Number of RepVGGBlock repeats. - """ - - super().__init__() - - self.conv = ConvModule( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - ) - self.upsample = torch.nn.ConvTranspose2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=2, - stride=2, - bias=True, - ) - self.rep_block = BlockRepeater( - block=RepVGGBlock, - in_channels=in_channels_next + out_channels, - out_channels=out_channels, - n_blocks=n_repeats, - ) - - def forward(self, x0: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]: - conv_out = self.conv(x0) - upsample_out = self.upsample(conv_out) - concat_out = torch.cat([upsample_out, x1], dim=1) - out = self.rep_block(concat_out) - return conv_out, out - - -class RepDownBlock(nn.Module): - def __init__( - self, - in_channels: int, - downsample_out_channels: int, - in_channels_next: int, - out_channels: int, - n_repeats: int, - ): - """DownBlock used in RepPAN neck. - - @type in_channels: int - @param in_channels: Number of input channels. - @type downsample_out_channels: int - @param downsample_out_channels: Number of output channels after - downsample. - @type in_channels_next: int - @param in_channels_next: Number of input channels of next input - which is used in concat. - @type out_channels: int - @param out_channels: Number of output channels. - @type n_repeats: int - @param n_repeats: Number of RepVGGBlock repeats. - """ - super().__init__() - - self.downsample = ConvModule( - in_channels=in_channels, - out_channels=downsample_out_channels, - kernel_size=3, - stride=2, - padding=3 // 2, - ) - self.rep_block = BlockRepeater( - block=RepVGGBlock, - in_channels=downsample_out_channels + in_channels_next, - out_channels=out_channels, - n_blocks=n_repeats, - ) - - def forward(self, x0: Tensor, x1: Tensor) -> Tensor: - x = self.downsample(x0) - x = torch.cat([x, x1], dim=1) - x = self.rep_block(x) - return x - - T = TypeVar("T", int, tuple[int, ...]) diff --git a/luxonis_train/nodes/necks/reppan_neck/__init__.py b/luxonis_train/nodes/necks/reppan_neck/__init__.py new file mode 100644 index 00000000..eef2e9a0 --- /dev/null +++ b/luxonis_train/nodes/necks/reppan_neck/__init__.py @@ -0,0 +1,3 @@ +from .reppan_neck import RepPANNeck + +__all__ = ["RepPANNeck"] diff --git a/luxonis_train/nodes/necks/reppan_neck/blocks.py b/luxonis_train/nodes/necks/reppan_neck/blocks.py new file mode 100644 index 00000000..9f7eda2b --- /dev/null +++ b/luxonis_train/nodes/necks/reppan_neck/blocks.py @@ -0,0 +1,287 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +from torch import Tensor + +from luxonis_train.nodes.blocks import ( + BlockRepeater, + ConvModule, + CSPStackRepBlock, + RepVGGBlock, +) + + +class PANUpBlockBase(ABC, nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + """Base RepPANNeck up block. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + """ + + super().__init__() + + self.conv = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + ) + self.upsample = torch.nn.ConvTranspose2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=2, + stride=2, + bias=True, + ) + + @property + @abstractmethod + def encode_block(self) -> nn.Module: + """Encode block that is used. + + Make sure actual module is initialized in the __init__ and not + inside this function otherwise it will be reinitialized every + time + """ + ... + + def forward(self, x0: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]: + conv_out = self.conv(x0) + upsample_out = self.upsample(conv_out) + concat_out = torch.cat([upsample_out, x1], dim=1) + out = self.encode_block(concat_out) + return conv_out, out + + +class RepUpBlock(PANUpBlockBase): + def __init__( + self, + in_channels: int, + in_channels_next: int, + out_channels: int, + n_repeats: int, + ): + """RepPANNeck up block for smaller networks that uses RepBlock. + + @type in_channels: int + @param in_channels: Number of input channels. + @type in_channels_next: int + @param in_channels_next: Number of input channels of next input + which is used in concat. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_repeats: int + @param n_repeats: Number of RepVGGBlock repeats. + """ + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + ) + + self._encode_block = BlockRepeater( + block=RepVGGBlock, + in_channels=in_channels_next + out_channels, + out_channels=out_channels, + n_blocks=n_repeats, + ) + + @property + def encode_block(self) -> nn.Module: + return self._encode_block + + +class CSPUpBlock(PANUpBlockBase): + def __init__( + self, + in_channels: int, + in_channels_next: int, + out_channels: int, + n_repeats: int, + e: float, + ): + """RepPANNeck up block for larger networks that uses + CSPStackRepBlock. + + @type in_channels: int + @param in_channels: Number of input channels. + @type in_channels_next: int + @param in_channels_next: Number of input channels of next input + which is used in concat. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_repeats: int + @param n_repeats: Number of RepVGGBlock repeats. + @type e: float + @param e: Factor that controls number of intermediate channels. + """ + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + ) + self._encode_block = CSPStackRepBlock( + in_channels=in_channels_next + out_channels, + out_channels=out_channels, + n_blocks=n_repeats, + e=e, + ) + + @property + def encode_block(self) -> nn.Module: + return self._encode_block + + +class PANDownBlockBase(ABC, nn.Module): + def __init__( + self, + in_channels: int, + downsample_out_channels: int, + ): + """Base RepPANNeck up block. + + @type in_channels: int + @param in_channels: Number of input channels. + @type downsample_out_channels: int + @param downsample_out_channels: Number of output channels after + downsample. + @type in_channels_next: int + @param in_channels_next: Number of input channels of next input + which is used in concat. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_repeats: int + @param n_repeats: Number of RepVGGBlock repeats. + """ + super().__init__() + + self.downsample = ConvModule( + in_channels=in_channels, + out_channels=downsample_out_channels, + kernel_size=3, + stride=2, + padding=3 // 2, + ) + + @property + @abstractmethod + def encode_block(self) -> nn.Module: + """Encode block that is used. + + Make sure actual module is initialized in the __init__ and not + inside this function otherwise it will be reinitialized every + time + """ + ... + + def forward(self, x0: Tensor, x1: Tensor) -> Tensor: + x = self.downsample(x0) + x = torch.cat([x, x1], dim=1) + x = self.encode_block(x) + return x + + +class RepDownBlock(PANDownBlockBase): + def __init__( + self, + in_channels: int, + downsample_out_channels: int, + in_channels_next: int, + out_channels: int, + n_repeats: int, + ): + """RepPANNeck down block for smaller networks that uses + RepBlock. + + @type in_channels: int + @param in_channels: Number of input channels. + @type downsample_out_channels: int + @param downsample_out_channels: Number of output channels after + downsample. + @type in_channels_next: int + @param in_channels_next: Number of input channels of next input + which is used in concat. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_repeats: int + @param n_repeats: Number of RepVGGBlock repeats. + """ + super().__init__( + in_channels=in_channels, + downsample_out_channels=downsample_out_channels, + ) + + self._encode_block = BlockRepeater( + block=RepVGGBlock, + in_channels=downsample_out_channels + in_channels_next, + out_channels=out_channels, + n_blocks=n_repeats, + ) + + @property + def encode_block(self) -> nn.Module: + return self._encode_block + + def forward(self, x0: Tensor, x1: Tensor) -> Tensor: + x = self.downsample(x0) + x = torch.cat([x, x1], dim=1) + x = self.encode_block(x) + return x + + +class CSPDownBlock(PANDownBlockBase): + def __init__( + self, + in_channels: int, + downsample_out_channels: int, + in_channels_next: int, + out_channels: int, + n_repeats: int, + e: float, + ): + """RepPANNeck up block for larger networks that uses + CSPStackRepBlock. + + @type in_channels: int + @param in_channels: Number of input channels. + @type downsample_out_channels: int + @param downsample_out_channels: Number of output channels after + downsample. + @type in_channels_next: int + @param in_channels_next: Number of input channels of next input + which is used in concat. + @type out_channels: int + @param out_channels: Number of output channels. + @type n_repeats: int + @param n_repeats: Number of RepVGGBlock repeats. + @type e: float + @param e: Factor that controls number of intermediate channels. + """ + super().__init__( + in_channels=in_channels, + downsample_out_channels=downsample_out_channels, + ) + + self._encode_block = CSPStackRepBlock( + in_channels=downsample_out_channels + in_channels_next, + out_channels=out_channels, + n_blocks=n_repeats, + e=e, + ) + + @property + def encode_block(self) -> nn.Module: + return self._encode_block + + def forward(self, x0: Tensor, x1: Tensor) -> Tensor: + x = self.downsample(x0) + x = torch.cat([x, x1], dim=1) + x = self.encode_block(x) + return x diff --git a/luxonis_train/nodes/necks/reppan_neck.py b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py similarity index 57% rename from luxonis_train/nodes/necks/reppan_neck.py rename to luxonis_train/nodes/necks/reppan_neck/reppan_neck.py index 107151a6..8187b3ed 100644 --- a/luxonis_train/nodes/necks/reppan_neck.py +++ b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py @@ -1,30 +1,49 @@ +import logging from typing import Any, Literal from torch import Tensor, nn from luxonis_train.nodes.base_node import BaseNode -from luxonis_train.nodes.blocks import RepDownBlock, RepUpBlock +from luxonis_train.nodes.blocks import RepVGGBlock from luxonis_train.utils import make_divisible +from .blocks import CSPDownBlock, CSPUpBlock, RepDownBlock, RepUpBlock +from .variants import VariantLiteral, get_variant + +logger = logging.getLogger(__name__) + class RepPANNeck(BaseNode[list[Tensor], list[Tensor]]): in_channels: list[int] def __init__( self, + variant: VariantLiteral = "nano", n_heads: Literal[2, 3, 4] = 3, channels_list: list[int] | None = None, n_repeats: list[int] | None = None, - depth_mul: float = 0.33, - width_mul: float = 0.25, + depth_mul: float | None = None, + width_mul: float | None = None, + block: Literal["RepBlock", "CSPStackRepBlock"] | None = None, + csp_e: float | None = None, **kwargs: Any, ): - """Implementation of the RepPANNeck module. + """Implementation of the RepPANNeck module. Supports the version + with RepBlock and CSPStackRepBlock (for larger networks) Adapted from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications}. It has the balance of feature fusion ability and hardware efficiency. + @type variant: Literal["n", "nano", "s", "small", "m", "medium", "l", "large"] + @param variant: RepPANNeck variant. Defaults to "nano". + The variant determines the depth and width multipliers, block used and intermediate channel scaling factor. + The depth multiplier determines the number of blocks in each stage and the width multiplier determines the number of channels. + The following variants are available: + - "n" or "nano" (default): depth_multiplier=0.33, width_multiplier=0.25, block=RepBlock, e=None + - "s" or "small": depth_multiplier=0.33, width_multiplier=0.50, block=RepBlock, e=None + - "m" or "medium": depth_multiplier=0.60, width_multiplier=0.75, block=CSPStackRepBlock, e=2/3 + - "l" or "large": depth_multiplier=1.0, width_multiplier=1.0, block=CSPStackRepBlock, e=1/2 @type n_heads: Literal[2,3,4] @param n_heads: Number of output heads. Defaults to 3. B{Note: Should be same also on head in most cases.} @@ -38,21 +57,32 @@ def __init__( @param depth_mul: Depth multiplier. Defaults to C{0.33}. @type width_mul: float @param width_mul: Width multiplier. Defaults to C{0.25}. + @type block: Literal["RepBlock", "CSPStackRepBlock"] | None + @param block: Base block used when building the backbone. If provided, overrides the variant value. + @tpe csp_e: float | None + @param csp_e: Factor that controls number of intermediate channels if block="CSPStackRepBlock". If provided, + overrides the variant value. """ super().__init__(**kwargs) self.n_heads = n_heads - n_repeats = n_repeats or [12, 12, 12, 12] - channels_list = channels_list or [256, 128, 128, 256, 256, 512] + var = get_variant(variant) + depth_mul = depth_mul or var.depth_multiplier + width_mul = width_mul or var.width_multiplier + block = block or var.block + csp_e = csp_e or var.csp_e or 0.5 + channels_list = channels_list or [256, 128, 128, 256, 256, 512] + n_repeats = n_repeats or [12, 12, 12, 12] channels_list = [ make_divisible(ch * width_mul, 8) for ch in channels_list ] n_repeats = [ (max(round(i * depth_mul), 1) if i > 1 else i) for i in n_repeats ] + channels_list, n_repeats = self._fit_to_n_heads( channels_list, n_repeats ) @@ -66,11 +96,21 @@ def __init__( up_out_channel_list = [in_channels] # used in DownBlocks for i in range(1, n_heads): - curr_up_block = RepUpBlock( - in_channels=in_channels, - in_channels_next=in_channels_next, - out_channels=out_channels, - n_repeats=curr_n_repeats, + curr_up_block = ( + RepUpBlock( + in_channels=in_channels, + in_channels_next=in_channels_next, + out_channels=out_channels, + n_repeats=curr_n_repeats, + ) + if block == "RepBlock" + else CSPUpBlock( + in_channels=in_channels, + in_channels_next=in_channels_next, + out_channels=out_channels, + n_repeats=curr_n_repeats, + e=csp_e, + ) ) up_out_channel_list.append(out_channels) self.up_blocks.append(curr_up_block) @@ -94,12 +134,23 @@ def __init__( curr_n_repeats = n_repeats_down_blocks[0] for i in range(1, n_heads): - curr_down_block = RepDownBlock( - in_channels=in_channels, - downsample_out_channels=downsample_out_channels, - in_channels_next=in_channels_next, - out_channels=out_channels, - n_repeats=curr_n_repeats, + curr_down_block = ( + RepDownBlock( + in_channels=in_channels, + downsample_out_channels=downsample_out_channels, + in_channels_next=in_channels_next, + out_channels=out_channels, + n_repeats=curr_n_repeats, + ) + if block == "RepBlock" + else CSPDownBlock( + in_channels=in_channels, + downsample_out_channels=downsample_out_channels, + in_channels_next=in_channels_next, + out_channels=out_channels, + n_repeats=curr_n_repeats, + e=csp_e, + ) ) self.down_blocks.append(curr_down_block) if len(self.down_blocks) == (n_heads - 1): @@ -128,6 +179,20 @@ def forward(self, inputs: list[Tensor]) -> list[Tensor]: outs.append(x) return outs + def set_export_mode(self, mode: bool = True) -> None: + """Reparametrizes instances of L{RepVGGBlock} in the network. + + @type mode: bool + @param mode: Whether to set the export mode. Defaults to + C{True}. + """ + super().set_export_mode(mode) + if self.export: + logger.info("Reparametrizing 'RepPANNeck'.") + for module in self.modules(): + if isinstance(module, RepVGGBlock): + module.reparametrize() + def _fit_to_n_heads( self, channels_list: list[int], n_repeats: list[int] ) -> tuple[list[int], list[int]]: diff --git a/luxonis_train/nodes/necks/reppan_neck/variants.py b/luxonis_train/nodes/necks/reppan_neck/variants.py new file mode 100644 index 00000000..9ca9fc72 --- /dev/null +++ b/luxonis_train/nodes/necks/reppan_neck/variants.py @@ -0,0 +1,54 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel + +VariantLiteral: TypeAlias = Literal[ + "n", "nano", "s", "small", "m", "medium", "l", "large" +] + + +class RepPANNeckVariant(BaseModel): + depth_multiplier: float + width_multiplier: float + block: Literal["RepBlock", "CSPStackRepBlock"] + csp_e: float | None + + +def get_variant(variant: VariantLiteral) -> RepPANNeckVariant: + variants = { + "n": RepPANNeckVariant( + depth_multiplier=0.33, + width_multiplier=0.25, + block="RepBlock", + csp_e=None, + ), + "s": RepPANNeckVariant( + depth_multiplier=0.33, + width_multiplier=0.50, + block="RepBlock", + csp_e=None, + ), + "m": RepPANNeckVariant( + depth_multiplier=0.60, + width_multiplier=0.75, + block="CSPStackRepBlock", + csp_e=2 / 3, + ), + "l": RepPANNeckVariant( + depth_multiplier=1.0, + width_multiplier=1.0, + block="CSPStackRepBlock", + csp_e=1 / 2, + ), + } + variants["nano"] = variants["n"] + variants["small"] = variants["s"] + variants["medium"] = variants["m"] + variants["large"] = variants["l"] + + if variant not in variants: # pragma: no cover + raise ValueError( + f"EfficientRep variant should be one of " + f"{list(variants.keys())}, got '{variant}'." + ) + return variants[variant] From 0786779a6cf88a79df4095f3ae63218973fe340d Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Sun, 22 Sep 2024 20:01:11 +0200 Subject: [PATCH 2/6] updated docs --- luxonis_train/nodes/README.md | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 60e5971c..f63b2885 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -74,13 +74,15 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf). **Params** -| Key | Type | Default value | Description | -| ------------- | ----------- | --------------------------- | --------------------------------------------------- | -| channels_list | List\[int\] | \[64, 128, 256, 512, 1024\] | List of number of channels for each block | -| n_repeats | List\[int\] | \[1, 6, 12, 18, 6\] | List of number of repeats of RepVGGBlock | -| in_channels | int | 3 | Number of input channels, should be 3 in most cases | -| depth_mul | int | 0.33 | Depth multiplier | -| width_mul | int | 0.25 | Width multiplier | +| Key | Type | Default value | Description | +| ------------- | ----------------------------------------------------------------- | --------------------------- | --------------------------------------------------------------- | +| variant | Literal\["n", "nano", "s", "small", "m", "medium", "l", "large"\] | "nano" | Variant of the network | +| channels_list | List\[int\] | \[64, 128, 256, 512, 1024\] | List of number of channels for each block | +| n_repeats | List\[int\] | \[1, 6, 12, 18, 6\] | List of number of repeats of RepVGGBlock | +| depth_mul | float | 0.33 | Depth multiplier | +| width_mul | float | 0.25 | Width multiplier | +| block | Literal\["RepBlock", "CSPStackRepBlock"\] | "RepBlock" | Base block used | +| csp_e | float | 0.5 | Factor for intermediate channels when block=="CSPStackRepBlock" | ## RexNetV1_lite @@ -143,13 +145,16 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf). **Params** -| Key | Type | Default value | Description | -| ------------- | ---------------- | ------------------------------------------------------- | ----------------------------------------- | -| n_heads | Literal\[2,3,4\] | 3 ***Note:** Should be same also on head in most cases* | Number of output heads | -| channels_list | List\[int\] | \[256, 128, 128, 256, 256, 512\] | List of number of channels for each block | -| n_repeats | List\[int\] | \[12, 12, 12, 12\] | List of number of repeats of RepVGGBlock | -| depth_mul | int | 0.33 | Depth multiplier | -| width_mul | int | 0.25 | Width multiplier | +| Key | Type | Default value | Description | +| ------------- | ----------------------------------------------------------------- | ------------------------------------------------------- | --------------------------------------------------------------- | +| variant | Literal\["n", "nano", "s", "small", "m", "medium", "l", "large"\] | "nano" | Variant of the network | +| n_heads | Literal\[2,3,4\] | 3 ***Note:** Should be same also on head in most cases* | Number of output heads | +| channels_list | List\[int\] | \[256, 128, 128, 256, 256, 512\] | List of number of channels for each block | +| n_repeats | List\[int\] | \[12, 12, 12, 12\] | List of number of repeats of RepVGGBlock | +| depth_mul | float | 0.33 | Depth multiplier | +| width_mul | float | 0.25 | Width multiplier | +| block | Literal\["RepBlock", "CSPStackRepBlock"\] | "RepBlock" | Base block used | +| csp_e | float | 0.5 | Factor for intermediate channels when block=="CSPStackRepBlock" | ## ClassificationHead From 3b7cccb4c292d5e5d3380da1545d6f6b0c003ec1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 23 Sep 2024 17:42:51 +0200 Subject: [PATCH 3/6] fixed indented epytext field --- luxonis_train/nodes/necks/reppan_neck/reppan_neck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py index 8187b3ed..eaf397bc 100644 --- a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py +++ b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py @@ -57,7 +57,7 @@ def __init__( @param depth_mul: Depth multiplier. Defaults to C{0.33}. @type width_mul: float @param width_mul: Width multiplier. Defaults to C{0.25}. - @type block: Literal["RepBlock", "CSPStackRepBlock"] | None + @type block: Literal["RepBlock", "CSPStackRepBlock"] | None @param block: Base block used when building the backbone. If provided, overrides the variant value. @tpe csp_e: float | None @param csp_e: Factor that controls number of intermediate channels if block="CSPStackRepBlock". If provided, From c39a85d3fca448408c1294a95d45a9cd1337ba85 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Fri, 27 Sep 2024 16:38:23 +0200 Subject: [PATCH 4/6] added tests for variants --- tests/integration/test_detection.py | 51 +++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index fb184b6f..4ee98dae 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -7,7 +7,7 @@ from luxonis_train.nodes.backbones import __all__ as BACKBONES -def get_opts(backbone: str) -> dict[str, Any]: +def get_opts_backbone(backbone: str) -> dict[str, Any]: return { "model": { "nodes": [ @@ -70,6 +70,42 @@ def get_opts(backbone: str) -> dict[str, Any]: } +def get_opts_variant(variant: str): + return { + "model": { + "nodes": [ + { + "name": "EfficientRep", + "alias": "backbone", + "params": {"variant": variant}, + }, + { + "name": "RepPANNeck", + "alias": "neck", + "inputs": ["backbone"], + "params": {"variant": variant}, + }, + { + "name": "EfficientBBoxHead", + "inputs": ["neck"], + }, + ], + "losses": [ + { + "name": "AdaptiveDetectionLoss", + "attached_to": "EfficientBBoxHead", + }, + ], + "metrics": [ + { + "name": "MeanAveragePrecision", + "attached_to": "EfficientBBoxHead", + }, + ], + } + } + + def train_and_test( config: dict[str, Any], opts: dict[str, Any], @@ -90,6 +126,17 @@ def test_backbones( config: dict[str, Any], parking_lot_dataset: LuxonisDataset, ): - opts = get_opts(backbone) + opts = get_opts_backbone(backbone) + opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + train_and_test(config, opts) + + +@pytest.mark.parametrize("variant", ["n", "s", "m", "l"]) +def test_variants( + variant: str, + config: dict[str, Any], + parking_lot_dataset: LuxonisDataset, +): + opts = get_opts_variant(variant) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier train_and_test(config, opts) From 30ae9da9477b65049b3d50056447df57bbc32993 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Fri, 27 Sep 2024 16:46:04 +0200 Subject: [PATCH 5/6] added return type hint --- tests/integration/test_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 4ee98dae..9df9350f 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -70,7 +70,7 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: } -def get_opts_variant(variant: str): +def get_opts_variant(variant: str) -> dict[str, Any]: return { "model": { "nodes": [ From faf6e6377ae82a8dce9e35654878f0a73d3ff2ed Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Sat, 28 Sep 2024 21:30:26 +0200 Subject: [PATCH 6/6] removed extra code --- luxonis_train/nodes/necks/reppan_neck/blocks.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/luxonis_train/nodes/necks/reppan_neck/blocks.py b/luxonis_train/nodes/necks/reppan_neck/blocks.py index 9f7eda2b..b7c05a0f 100644 --- a/luxonis_train/nodes/necks/reppan_neck/blocks.py +++ b/luxonis_train/nodes/necks/reppan_neck/blocks.py @@ -229,12 +229,6 @@ def __init__( def encode_block(self) -> nn.Module: return self._encode_block - def forward(self, x0: Tensor, x1: Tensor) -> Tensor: - x = self.downsample(x0) - x = torch.cat([x, x1], dim=1) - x = self.encode_block(x) - return x - class CSPDownBlock(PANDownBlockBase): def __init__( @@ -279,9 +273,3 @@ def __init__( @property def encode_block(self) -> nn.Module: return self._encode_block - - def forward(self, x0: Tensor, x1: Tensor) -> Tensor: - x = self.downsample(x0) - x = torch.cat([x, x1], dim=1) - x = self.encode_block(x) - return x