From 57d7527d83e2721aa6e1fa06728700cb9d999ca5 Mon Sep 17 00:00:00 2001 From: Ren Tianhe <48727989+rentainhe@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:31:54 +0800 Subject: [PATCH] [Feature] Support new focal-net backbone (#145) * refine focalnet backbone * refine focalnet args format * add dino-focalnet-large configs * refine config Co-authored-by: ntianhe ren --- detrex/modeling/backbone/focalnet.py | 358 +++++++++--------- ...dino_focalnet_large_lrf_384_4scale_12ep.py | 11 + ..._focalnet_large_lrf_384_fl4_4scale_12ep.py | 17 + projects/dino/configs/models/dino_focalnet.py | 28 ++ 4 files changed, 240 insertions(+), 174 deletions(-) create mode 100644 projects/dino/configs/dino_focalnet_large_lrf_384_4scale_12ep.py create mode 100644 projects/dino/configs/dino_focalnet_large_lrf_384_fl4_4scale_12ep.py create mode 100644 projects/dino/configs/models/dino_focalnet.py diff --git a/detrex/modeling/backbone/focalnet.py b/detrex/modeling/backbone/focalnet.py index 254ac3f7..e2548eed 100644 --- a/detrex/modeling/backbone/focalnet.py +++ b/detrex/modeling/backbone/focalnet.py @@ -16,7 +16,7 @@ # Copyright (c) 2022 Microsoft # ------------------------------------------------------------------------------------------------ # Modified from: -# https://github.com/microsoft/FocalNet/blob/main/detection/mmdet/models/backbones/focalnet.py +# https://github.com/FocalNet/FocalNet-DINO/blob/main/models/dino/focal.py # ------------------------------------------------------------------------------------------------ import torch @@ -32,7 +32,12 @@ class Mlp(nn.Module): """Multilayer perceptron.""" def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0 ): super().__init__() out_features = out_features or in_features @@ -52,20 +57,28 @@ def forward(self, x): class FocalModulation(nn.Module): - """Focal Modulation - + """ Focal Modulation + Args: dim (int): Number of input channels. - proj_drop (float, optional): Dropout ratio of output. Default: 0.0. - focal_level (int): Number of focal levels. Default: 2. - focal_window (int): Focal window size at focal level 1. Default: 7. - focal_factor (int, default=2): Step to increase the focal window. Default: 2. - use_postln (bool, default=False): Whether use post-modulation layernorm. Default: False. + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + focal_factor (int, default=2): Step to increase the focal window + use_postln (bool, default=False): Whether use post-modulation layernorm """ def __init__( - self, dim, proj_drop=0.0, focal_level=2, focal_window=7, focal_factor=2, use_postln=False - ): + self, + dim, + proj_drop=0., + focal_level=2, + focal_window=7, + focal_factor=2, + use_postln=False, + use_postln_in_modulation=False, + normalize_modulator=False + ): super().__init__() self.dim = dim @@ -74,9 +87,10 @@ def __init__( self.focal_level = focal_level self.focal_window = focal_window self.focal_factor = focal_factor - self.use_postln = use_postln + self.use_postln_in_modulation = use_postln_in_modulation + self.normalize_modulator = normalize_modulator - self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True) + self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) self.act = nn.GELU() @@ -84,98 +98,98 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) self.focal_layers = nn.ModuleList() - if self.use_postln: + if self.use_postln_in_modulation: self.ln = nn.LayerNorm(dim) for k in range(self.focal_level): - kernel_size = self.focal_factor * k + self.focal_window + kernel_size = self.focal_factor*k + self.focal_window self.focal_layers.append( nn.Sequential( - nn.Conv2d( - dim, - dim, - kernel_size=kernel_size, - stride=1, - groups=dim, - padding=kernel_size // 2, - bias=False, - ), + nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, + padding=kernel_size//2, bias=False), nn.GELU(), + ) ) - ) def forward(self, x): - """Forward function of `FocalModulation` - + """ Forward function. Args: x: input features with shape of (B, H, W, C) """ B, nH, nW, C = x.shape x = self.f(x) x = x.permute(0, 3, 1, 2).contiguous() - q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1) - + q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) + ctx_all = 0 - for l in range(self.focal_level): + for l in range(self.focal_level): ctx = self.focal_layers[l](ctx) - ctx_all = ctx_all + ctx * gates[:, l : l + 1] + ctx_all = ctx_all + ctx*gates[:, l:l+1] ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) - ctx_all = ctx_all + ctx_global * gates[:, self.focal_level :] + ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] + if self.normalize_modulator: + ctx_all = ctx_all / (self.focal_level+1) x_out = q * self.h(ctx_all) x_out = x_out.permute(0, 2, 3, 1).contiguous() - if self.use_postln: - x_out = self.ln(x_out) + if self.use_postln_in_modulation: + x_out = self.ln(x_out) x_out = self.proj(x_out) x_out = self.proj_drop(x_out) return x_out class FocalModulationBlock(nn.Module): - """Focal Modulation Block. - + """ Focal Modulation Block. Args: dim (int): Number of input channels. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0. - drop_path (float, optional): Stochastic depth rate. Default: 0.0. - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. - focal_level (int): number of focal levels. Default: 2. - focal_window (int): focal kernel size at level 1. Default: 9. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + focal_level (int): number of focal levels + focal_window (int): focal kernel size at level 1 """ def __init__( - self, - dim, - mlp_ratio=4.0, - drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - focal_level=2, - focal_window=9, - use_layerscale=False, - layerscale_value=1e-4, - ): + self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + focal_level=2, + focal_window=9, + use_postln=False, + use_postln_in_modulation=False, + normalize_modulator=False, + use_layerscale=False, + layerscale_value=1e-4 + ): super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio self.focal_window = focal_window self.focal_level = focal_level + self.use_postln = use_postln self.use_layerscale = use_layerscale self.norm1 = norm_layer(dim) self.modulation = FocalModulation( - dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop - ) - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + dim, + focal_window=self.focal_window, + focal_level=self.focal_level, + proj_drop=drop, + use_postln_in_modulation=use_postln_in_modulation, + normalize_modulator=normalize_modulator, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop - ) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None @@ -187,8 +201,7 @@ def __init__( self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) def forward(self, x): - """Forward function of `FocalModulationBlock`. - + """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. @@ -198,89 +211,94 @@ def forward(self, x): assert L == H * W, "input feature has wrong size" shortcut = x - x = self.norm1(x) + if not self.use_postln: + x = self.norm1(x) x = x.view(B, H, W, C) - + # FM x = self.modulation(x).view(B, H * W, C) + if self.use_postln: + x = self.norm1(x) # FFN x = shortcut + self.drop_path(self.gamma_1 * x) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + + if self.use_postln: + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class BasicLayer(nn.Module): - """A basic focal modulation layer for one stage. - + """ A basic focal modulation layer for one stage. Args: - dim (int): Number of feature channels. + dim (int): Number of feature channels depth (int): Depths of this stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - drop (float, optional): Dropout rate. Default: 0.0. - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None. - focal_level (int): Number of focal levels. Default: 2. - focal_window (int): Focal window size at focal level 1. Default: 9. - use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False. - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ - def __init__( - self, - dim, - depth, - mlp_ratio=4.0, - drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - focal_window=9, - focal_level=2, - use_conv_embed=False, - use_layerscale=False, - use_checkpoint=False, - ): + def __init__(self, + dim, + depth, + mlp_ratio=4., + drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + focal_window=9, + focal_level=2, + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + normalize_modulator=False, + use_layerscale=False, + use_checkpoint=False + ): super().__init__() self.depth = depth self.use_checkpoint = use_checkpoint # build blocks - self.blocks = nn.ModuleList( - [ - FocalModulationBlock( - dim=dim, - mlp_ratio=mlp_ratio, - drop=drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - focal_window=focal_window, - focal_level=focal_level, - use_layerscale=use_layerscale, - norm_layer=norm_layer, - ) - for i in range(depth) - ] - ) + self.blocks = nn.ModuleList([ + FocalModulationBlock( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + focal_window=focal_window, + focal_level=focal_level, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + normalize_modulator=normalize_modulator, + use_layerscale=use_layerscale, + norm_layer=norm_layer) + for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( - patch_size=2, - in_chans=dim, - embed_dim=2 * dim, - use_conv_embed=use_conv_embed, - norm_layer=norm_layer, - is_stem=False, + patch_size=2, + in_chans=dim, embed_dim=2*dim, + use_conv_embed=use_conv_embed, + norm_layer=norm_layer, + is_stem=False ) else: self.downsample = None def forward(self, x, H, W): - """Forward function of `BasicLayer`. - + """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. @@ -294,8 +312,8 @@ def forward(self, x, H, W): x = blk(x) if self.downsample is not None: x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) - x_down = self.downsample(x_reshaped) - x_down = x_down.flatten(2).transpose(1, 2) + x_down = self.downsample(x_reshaped) + x_down = x_down.flatten(2).transpose(1, 2) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: @@ -303,26 +321,25 @@ def forward(self, x, H, W): class PatchEmbed(nn.Module): - """Image to Patch Embedding - + """ Image to Patch Embedding Args: patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None. - use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False. - is_stem (bool): Is the stem block or not. Default: False. + norm_layer (nn.Module, optional): Normalization layer. Default: None + use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False + is_stem (bool): Is the stem block or not. """ def __init__( - self, - patch_size=4, - in_chans=3, - embed_dim=96, - norm_layer=None, - use_conv_embed=False, - is_stem=False, - ): + self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None, + use_conv_embed=False, + is_stem=False + ): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size @@ -333,16 +350,10 @@ def __init__( if use_conv_embed: # if we choose to use conv embedding, then we treat the stem and non-stem differently if is_stem: - kernel_size = 7 - padding = 3 - stride = 4 + kernel_size = 7; padding = 2; stride = 4 else: - kernel_size = 3 - padding = 1 - stride = 2 - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding - ) + kernel_size = 3; padding = 1; stride = 2 + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) else: self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) @@ -371,27 +382,25 @@ def forward(self, x): class FocalNet(Backbone): """Implement paper `Focal Modulation Networks `_ - + Args: pretrain_img_size (int): Input image size for training the pretrained model, used in absolute postion embedding. Default 224. patch_size (int | tuple(int)): Patch size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. - depths (tuple[int]): Depths of each FocalNet stage. Default: `[2, 2, 6, 2]`. + depths (tuple[int]): Depths of each Swin Transformer stage. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - drop_rate (float): Dropout rate. Default: 0.0. + drop_rate (float): Dropout rate. drop_path_rate (float): Stochastic depth rate. Default: 0.2. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True. - out_indices (Sequence[int]): Output from which stages. Default: `(0, 1, 2, 3)`. + out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). - -1 means not freezing any parameters. Default: -1. - focal_levels (Sequence[int]): Number of focal levels at four stages. - Default: `[2, 2, 2, 2]` - focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages. - Default: `[9, 9, 9, 9]`. - use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False. + -1 means not freezing any parameters. + focal_levels (Sequence[int]): Number of focal levels at four stages + focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages + use_conv_embed (bool): Whether use overlapped convolution for patch embedding use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ @@ -402,18 +411,21 @@ def __init__( in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], - mlp_ratio=4.0, - drop_rate=0.0, - drop_path_rate=0.2, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.3, # 0.3 or 0.4 works better for large+ models norm_layer=nn.LayerNorm, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, - focal_levels=[2, 2, 2, 2], - focal_windows=[9, 9, 9, 9], - use_conv_embed=False, - use_checkpoint=False, - use_layerscale=False, + focal_levels=[3, 3, 3, 3], + focal_windows=[3, 3, 3, 3], + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + use_layerscale=False, + normalize_modulator=False, + use_checkpoint=False, ): super().__init__() @@ -426,38 +438,34 @@ def __init__( # split image into non-overlapping patches self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None, - use_conv_embed=use_conv_embed, - is_stem=True, - ) + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + use_conv_embed=use_conv_embed, is_stem=True) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) - ] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( - dim=int(embed_dim * 2**i_layer), + dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], mlp_ratio=mlp_ratio, drop=drop_rate, - drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, - focal_window=focal_windows[i_layer], - focal_level=focal_levels[i_layer], + focal_window=focal_windows[i_layer], + focal_level=focal_levels[i_layer], use_conv_embed=use_conv_embed, - use_layerscale=use_layerscale, - use_checkpoint=use_checkpoint, - ) + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + normalize_modulator=normalize_modulator, + use_layerscale=use_layerscale, + use_checkpoint=use_checkpoint) self.layers.append(layer) num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] @@ -470,6 +478,8 @@ def __init__( self.add_module(layer_name, layer) self._freeze_stages() + + # add basic info self._out_features = ["p{}".format(i) for i in self.out_indices] self._out_feature_channels = { "p{}".format(i): self.embed_dim * 2**i for i in self.out_indices diff --git a/projects/dino/configs/dino_focalnet_large_lrf_384_4scale_12ep.py b/projects/dino/configs/dino_focalnet_large_lrf_384_4scale_12ep.py new file mode 100644 index 00000000..fd3addf0 --- /dev/null +++ b/projects/dino/configs/dino_focalnet_large_lrf_384_4scale_12ep.py @@ -0,0 +1,11 @@ +from .dino_r50_4scale_12ep import ( + train, + dataloader, + optimizer, + lr_multiplier, +) +from .models.dino_focalnet import model + +# modify training config +train.init_checkpoint = "/path/to/focalnet_large_lrf_384.pth" +train.output_dir = "./output/dino_focalnet_large_4scale_12ep" diff --git a/projects/dino/configs/dino_focalnet_large_lrf_384_fl4_4scale_12ep.py b/projects/dino/configs/dino_focalnet_large_lrf_384_fl4_4scale_12ep.py new file mode 100644 index 00000000..241acfc0 --- /dev/null +++ b/projects/dino/configs/dino_focalnet_large_lrf_384_fl4_4scale_12ep.py @@ -0,0 +1,17 @@ +from .dino_focalnet_large_lrf_384_4scale_12ep import ( + train, + dataloader, + optimizer, + lr_multiplier, + model, +) + + +# modify training config +train.init_checkpoint = "/path/to/focalnet_large_lrf_384_fl4.pth" +train.output_dir = "./output/dino_focalnet_large_fl4_4scale_12ep" + + +# convert to 4 focal-level +model.backbone.focal_levels = (4, 4, 4, 4) +model.backbone.focal_windows = (3, 3, 3, 3) diff --git a/projects/dino/configs/models/dino_focalnet.py b/projects/dino/configs/models/dino_focalnet.py new file mode 100644 index 00000000..54497825 --- /dev/null +++ b/projects/dino/configs/models/dino_focalnet.py @@ -0,0 +1,28 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detrex.modeling.backbone import FocalNet + +from .dino_r50 import model + + +# focalnet-large-4scale baseline +model.backbone = L(FocalNet)( + embed_dim=192, + depths=(2, 2, 18, 2), + focal_levels=(3, 3, 3, 3), + focal_windows=(5, 5, 5, 5), + use_conv_embed=True, + use_postln=True, + use_postln_in_modulation=False, + use_layerscale=True, + normalize_modulator=False, + out_indices=(1, 2, 3), +) + +# modify neck config +model.neck.input_shapes = { + "p1": ShapeSpec(channels=384), + "p2": ShapeSpec(channels=768), + "p3": ShapeSpec(channels=1536), +} +model.neck.in_features = ["p1", "p2", "p3"]