Added medium and large options for EfficientRep and RepPANNeck with CSP blocks
klemen1999 committed Sep 22, 2024
1 parent 2449850 commit 9869dee
Showing 8 changed files with 567 additions and 136 deletions.
44 changes: 32 additions & 12 deletions luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -1,11 +1,12 @@
import logging
from typing import Any
from typing import Any, Literal

from torch import Tensor, nn

from luxonis_train.nodes.base_node import BaseNode
from luxonis_train.nodes.blocks import (
BlockRepeater,
CSPStackRepBlock,
RepVGGBlock,
SpatialPyramidPoolingBlock,
)
@@ -26,23 +27,26 @@ def __init__(
n_repeats: list[int] | None = None,
depth_mul: float | None = None,
width_mul: float | None = None,
block: Literal["RepBlock", "CSPStackRepBlock"] | None = None,
csp_e: float | None = None,
**kwargs: Any,
):
"""Implementation of the EfficientRep backbone.
"""Implementation of the EfficientRep backbone. Supports the
versions with RepBlock and CSPStackRepBlock (for larger networks).
Adapted from U{YOLOv6: A Single-Stage Object Detection Framework
for Industrial Applications
<https://arxiv.org/pdf/2209.02976.pdf>}.
@type variant: Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]
@param variant: EfficientRep variant. Defaults to "nano".
The variant determines the depth and width multipliers.
The variant determines the depth and width multipliers, the block used, and the intermediate channel scaling factor.
The depth multiplier determines the number of blocks in each stage and the width multiplier determines the number of channels.
The following variants are available:
- "n" or "nano" (default): depth_multiplier=0.33, width_multiplier=0.25
- "s" or "small": depth_multiplier=0.33, width_multiplier=0.50
- "m" or "medium": depth_multiplier=0.60, width_multiplier=0.75
- "l" or "large": depth_multiplier=1.0, width_multiplier=1.0
- "n" or "nano" (default): depth_multiplier=0.33, width_multiplier=0.25, block=RepBlock, e=None
- "s" or "small": depth_multiplier=0.33, width_multiplier=0.50, block=RepBlock, e=None
- "m" or "medium": depth_multiplier=0.60, width_multiplier=0.75, block=CSPStackRepBlock, e=2/3
- "l" or "large": depth_multiplier=1.0, width_multiplier=1.0, block=CSPStackRepBlock, e=1/2
@type channels_list: list[int] | None
@param channels_list: List of number of channels for each block. If unspecified,
defaults to [64, 128, 256, 512, 1024].
@@ -53,12 +57,19 @@ def __init__(
@param depth_mul: Depth multiplier. If provided, overrides the variant value.
@type width_mul: float
@param width_mul: Width multiplier. If provided, overrides the variant value.
@type block: Literal["RepBlock", "CSPStackRepBlock"] | None
@param block: Base block used when building the backbone. If provided, overrides the variant value.
@type csp_e: float | None
@param csp_e: Factor that controls number of intermediate channels if block="CSPStackRepBlock". If provided,
overrides the variant value.
"""
super().__init__(**kwargs)

var = get_variant(variant)
depth_mul = depth_mul or var.depth_multiplier
width_mul = width_mul or var.width_multiplier
block = block or var.block
csp_e = csp_e or var.csp_e or 0.5

channels_list = channels_list or [64, 128, 256, 512, 1024]
n_repeats = n_repeats or [1, 6, 12, 18, 6]
@@ -85,11 +96,20 @@ def __init__(
kernel_size=3,
stride=2,
),
BlockRepeater(
block=RepVGGBlock,
in_channels=channels_list[i + 1],
out_channels=channels_list[i + 1],
n_blocks=n_repeats[i + 1],
(
BlockRepeater(
block=RepVGGBlock,
in_channels=channels_list[i + 1],
out_channels=channels_list[i + 1],
n_blocks=n_repeats[i + 1],
)
if block == "RepBlock"
else CSPStackRepBlock(
in_channels=channels_list[i + 1],
out_channels=channels_list[i + 1],
n_blocks=n_repeats[i + 1],
e=csp_e,
)
),
)
self.blocks.append(curr_block)
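A minimal sketch of how the variant defaults and overrides above resolve. The resolve_config helper below is hypothetical (not part of the library); it only mirrors the precedence used in the constructor: explicit arguments win, then the variant's values, with csp_e falling back to 0.5.

```python
from luxonis_train.nodes.backbones.efficientrep.variants import get_variant


def resolve_config(
    variant="n",
    depth_mul=None,
    width_mul=None,
    block=None,
    csp_e=None,
):
    # Mirrors the override logic in EfficientRep.__init__ shown above.
    var = get_variant(variant)
    return {
        "depth_mul": depth_mul or var.depth_multiplier,
        "width_mul": width_mul or var.width_multiplier,
        # "n"/"s" default to RepBlock; "m"/"l" to CSPStackRepBlock.
        "block": block or var.block,
        # csp_e is only defined for the CSP variants; otherwise fall back to 0.5.
        "csp_e": csp_e or var.csp_e or 0.5,
    }


print(resolve_config("m"))  # CSPStackRepBlock with csp_e=2/3
print(resolve_config("s", block="CSPStackRepBlock"))  # explicit block override, csp_e=0.5
```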
10 changes: 10 additions & 0 deletions luxonis_train/nodes/backbones/efficientrep/variants.py
@@ -10,25 +10,35 @@
class EfficientRepVariant(BaseModel):
depth_multiplier: float
width_multiplier: float
block: Literal["RepBlock", "CSPStackRepBlock"]
csp_e: float | None


def get_variant(variant: VariantLiteral) -> EfficientRepVariant:
variants = {
"n": EfficientRepVariant(
depth_multiplier=0.33,
width_multiplier=0.25,
block="RepBlock",
csp_e=None,
),
"s": EfficientRepVariant(
depth_multiplier=0.33,
width_multiplier=0.50,
block="RepBlock",
csp_e=None,
),
"m": EfficientRepVariant(
depth_multiplier=0.60,
width_multiplier=0.75,
block="CSPStackRepBlock",
csp_e=2 / 3,
),
"l": EfficientRepVariant(
depth_multiplier=1.0,
width_multiplier=1.0,
block="CSPStackRepBlock",
csp_e=1 / 2,
),
}
variants["nano"] = variants["n"]
6 changes: 2 additions & 4 deletions luxonis_train/nodes/blocks/__init__.py
@@ -4,15 +4,14 @@
BlockRepeater,
Bottleneck,
ConvModule,
CSPStackRepBlock,
DropPath,
EfficientDecoupledBlock,
FeatureFusionBlock,
KeypointBlock,
LearnableAdd,
LearnableMulAddConv,
LearnableMultiply,
RepDownBlock,
RepUpBlock,
RepVGGBlock,
SpatialPyramidPoolingBlock,
SqueezeExciteBlock,
@@ -26,18 +25,17 @@
"EfficientDecoupledBlock",
"ConvModule",
"UpBlock",
"RepDownBlock",
"SqueezeExciteBlock",
"RepVGGBlock",
"BlockRepeater",
"CSPStackRepBlock",
"AttentionRefinmentBlock",
"SpatialPyramidPoolingBlock",
"FeatureFusionBlock",
"LearnableAdd",
"LearnableMultiply",
"LearnableMulAddConv",
"KeypointBlock",
"RepUpBlock",
"BasicResNetBlock",
"Bottleneck",
"UpscaleOnline",
200 changes: 97 additions & 103 deletions luxonis_train/nodes/blocks/blocks.py
@@ -113,7 +113,7 @@ def __init__(
@type bias: bool
@param bias: Whether to use bias. Defaults to False.
@type activation: L{nn.Module} | None
@param activation: Activation function. Defaults to None.
@param activation: Activation function. If C{None}, C{nn.ReLU} is used.
"""
super().__init__(
nn.Conv2d(
@@ -407,20 +407,112 @@ def __init__(
"""
super().__init__()

in_channels = in_channels
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(
block(in_channels=in_channels, out_channels=out_channels)
)
for _ in range(n_blocks - 1):
self.blocks.append(
block(in_channels=in_channels, out_channels=out_channels)
block(in_channels=out_channels, out_channels=out_channels)
)
in_channels = out_channels

def forward(self, x: Tensor) -> Tensor:
for block in self.blocks:
x = block(x)
return x


class CSPStackRepBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
n_blocks: int = 1,
e: float = 0.5,
):
super().__init__()
"""Module composed of three 1x1 conv layers and a stack of sub-
blocks consisting of two RepVGG blocks with a residual
connection.
@type in_channels: int
@param in_channels: Number of input channels.
@type out_channels: int
@param out_channels: Number of output channels.
@type n_blocks: int
@param n_blocks: Number of blocks to repeat. Defaults to C{1}.
@type e: float
@param e: Factor for number of intermediate channels. Defaults
to C{0.5}.
"""
intermediate_channels = int(out_channels * e)
self.conv_1 = ConvModule(
in_channels=in_channels,
out_channels=intermediate_channels,
kernel_size=1,
padding=autopad(1, None),
)
self.conv_2 = ConvModule(
in_channels=in_channels,
out_channels=intermediate_channels,
kernel_size=1,
padding=autopad(1, None),
)
self.conv_3 = ConvModule(
in_channels=intermediate_channels * 2,
out_channels=out_channels,
kernel_size=1,
padding=autopad(1, None),
)
self.rep_stack = BlockRepeater(
block=BottleRep,
in_channels=intermediate_channels,
out_channels=intermediate_channels,
n_blocks=n_blocks // 2,
)

def forward(self, x: Tensor) -> Tensor:
out_1 = self.conv_1(x)
out_1 = self.rep_stack(out_1)
out_2 = self.conv_2(x)
out = torch.cat((out_1, out_2), dim=1)
return self.conv_3(out)


class BottleRep(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
block: type[nn.Module] = RepVGGBlock,
weight: bool = True,
):
super().__init__()
"""RepVGG bottleneck module.
@type in_channels: int
@param in_channels: Number of input channels.
@type out_channels: int
@param out_channels: Number of output channels.
@type block: L{nn.Module}
@param block: Block to use. Defaults to C{RepVGGBlock}.
@type weight: bool
@param weight: If using learnable or static shortcut weight.
Defaults to C{True}.
"""
self.conv_1 = block(in_channels=in_channels, out_channels=out_channels)
self.conv_2 = block(
in_channels=out_channels, out_channels=out_channels
)
self.shortcut = in_channels == out_channels
self.alpha = nn.Parameter(torch.ones(1)) if weight else 1.0

def forward(self, x: Tensor) -> Tensor:
out = self.conv_1(x)
out = self.conv_2(out)
return out + self.alpha * x if self.shortcut else out


class SpatialPyramidPoolingBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 5
@@ -608,104 +700,6 @@ def forward(self, x: Tensor) -> Tensor:
return out


class RepUpBlock(nn.Module):
def __init__(
self,
in_channels: int,
in_channels_next: int,
out_channels: int,
n_repeats: int,
):
"""UpBlock used in RepPAN neck.
@type in_channels: int
@param in_channels: Number of input channels.
@type in_channels_next: int
@param in_channels_next: Number of input channels of next input
which is used in concat.
@type out_channels: int
@param out_channels: Number of output channels.
@type n_repeats: int
@param n_repeats: Number of RepVGGBlock repeats.
"""

super().__init__()

self.conv = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
)
self.upsample = torch.nn.ConvTranspose2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=2,
stride=2,
bias=True,
)
self.rep_block = BlockRepeater(
block=RepVGGBlock,
in_channels=in_channels_next + out_channels,
out_channels=out_channels,
n_blocks=n_repeats,
)

def forward(self, x0: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]:
conv_out = self.conv(x0)
upsample_out = self.upsample(conv_out)
concat_out = torch.cat([upsample_out, x1], dim=1)
out = self.rep_block(concat_out)
return conv_out, out


class RepDownBlock(nn.Module):
def __init__(
self,
in_channels: int,
downsample_out_channels: int,
in_channels_next: int,
out_channels: int,
n_repeats: int,
):
"""DownBlock used in RepPAN neck.
@type in_channels: int
@param in_channels: Number of input channels.
@type downsample_out_channels: int
@param downsample_out_channels: Number of output channels after
downsample.
@type in_channels_next: int
@param in_channels_next: Number of input channels of next input
which is used in concat.
@type out_channels: int
@param out_channels: Number of output channels.
@type n_repeats: int
@param n_repeats: Number of RepVGGBlock repeats.
"""
super().__init__()

self.downsample = ConvModule(
in_channels=in_channels,
out_channels=downsample_out_channels,
kernel_size=3,
stride=2,
padding=3 // 2,
)
self.rep_block = BlockRepeater(
block=RepVGGBlock,
in_channels=downsample_out_channels + in_channels_next,
out_channels=out_channels,
n_blocks=n_repeats,
)

def forward(self, x0: Tensor, x1: Tensor) -> Tensor:
x = self.downsample(x0)
x = torch.cat([x, x1], dim=1)
x = self.rep_block(x)
return x


T = TypeVar("T", int, tuple[int, ...])


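A short usage sketch for the new CSPStackRepBlock, based only on the definitions in this diff: the input passes through two parallel 1x1 convs to int(out_channels * e) channels, a stack of n_blocks // 2 BottleRep blocks on the first branch, a channel-wise concat, and a final 1x1 conv back to out_channels, so spatial dimensions are preserved.

```python
import torch

from luxonis_train.nodes.blocks import CSPStackRepBlock

# 128 -> int(192 * 2/3) = 128 intermediate channels per branch,
# 2 * 128 = 256 channels after concat, projected to 192 by the final 1x1 conv.
block = CSPStackRepBlock(in_channels=128, out_channels=192, n_blocks=4, e=2 / 3)

x = torch.randn(2, 128, 32, 32)
out = block(x)
print(out.shape)  # expected: torch.Size([2, 192, 32, 32])
```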