forked from open-mmlab/mmsegmentation
[Feature] Support PIDNet (open-mmlab#2609)
## Motivation

Support a SOTA real-time semantic segmentation method listed on [Papers with Code](https://paperswithcode.com/task/real-time-semantic-segmentation).

Paper: https://arxiv.org/pdf/2206.02066.pdf
Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

|Model|Ref mIoU|mIoU (ours)|
|---|---|---|
|PIDNet-S|78.8|78.74|
|PIDNet-M|79.9|80.22|
|PIDNet-L|80.9|80.89|

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
Showing 20 changed files with 1,646 additions and 4 deletions.
@@ -0,0 +1,50 @@
# PIDNet

> [PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller](https://arxiv.org/pdf/2206.02066.pdf)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/XuJiacong/PIDNet">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Two-branch network architecture has shown its efficiency and effectiveness for real-time semantic segmentation tasks. However, direct fusion of low-level details and high-level semantics will lead to a phenomenon that the detailed features are easily overwhelmed by surrounding contextual information, namely overshoot in this paper, which limits the improvement of the accuracy of existing two-branch models. In this paper, we bridge a connection between Convolutional Neural Network (CNN) and Proportional-Integral-Derivative (PID) controller and reveal that the two-branch network is nothing but a Proportional-Integral (PI) controller, which inherently suffers from the similar overshoot issue. To alleviate this issue, we propose a novel three-branch network architecture: PIDNet, which possesses three branches to parse the detailed, context and boundary information (derivative of semantics), respectively, and employs boundary attention to guide the fusion of detailed and context branches in the final stage. The family of PIDNets achieves the best trade-off between inference speed and accuracy, and their test accuracy surpasses all existing models with similar inference speed on the Cityscapes, CamVid and COCO-Stuff datasets. In particular, PIDNet-S achieves 78.6% mIoU with an inference speed of 93.2 FPS on the Cityscapes test set and 80.1% mIoU with a speed of 153.7 FPS on the CamVid test set.

<!-- [IMAGE] -->

<div align=center>
<img src="https://raw.githubusercontent.com/XuJiacong/PIDNet/main/figs/pidnet.jpg" width="800"/>
</div>

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------ | -------- |
| PIDNet | PIDNet-S | 1024x1024 | 120000 | 3.38 | 80.82 | 78.74 | 80.87 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700.json) |
| PIDNet | PIDNet-M | 1024x1024 | 120000 | 5.14 | 71.98 | 80.22 | 82.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452.json) |
| PIDNet | PIDNet-L | 1024x1024 | 120000 | 5.83 | 60.06 | 80.89 | 82.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514.json) |

## Notes

The pretrained weights in config files are converted from [the official repo](https://github.com/XuJiacong/PIDNet#models).

## Citation

```bibtex
@misc{xu2022pidnet,
      title={PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller},
      author={Jiacong Xu and Zixiang Xiong and Shankar P. Bhattacharyya},
      year={2022},
      eprint={2206.02066},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
@@ -0,0 +1,10 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth'  # noqa
model = dict(
    backbone=dict(
        channels=64,
        ppm_channels=112,
        num_stem_blocks=3,
        num_branch_blocks=4,
        init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256, channels=256))
@@ -0,0 +1,5 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth'  # noqa
model = dict(
    backbone=dict(channels=64, init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256))
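
The PIDNet-M and PIDNet-L configs above list only the keys that differ from the PIDNet-S base file named in `_base_`; everything else (for example `type='PIDNet'` or `num_classes=19`) is inherited. The sketch below illustrates that idea with a plain recursive dictionary merge. It is a simplified stand-in, not MMEngine's actual `Config` implementation; `merge_config` is a hypothetical helper, the base dictionary is only a subset of the PIDNet-S config, and the checkpoint strings are placeholders.

```python
import copy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a deep copy of `base`.

    Hypothetical helper: nested dicts are merged key by key,
    any other value in `override` replaces the base value.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Subset of the PIDNet-S base config (checkpoint names are placeholders).
base_model = dict(
    backbone=dict(
        type='PIDNet',
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        init_cfg=dict(type='Pretrained', checkpoint='pidnet-s.pth')),
    decode_head=dict(
        type='PIDHead', in_channels=128, channels=128, num_classes=19))

# Overrides taken from the PIDNet-M and PIDNet-L configs.
m_override = dict(
    backbone=dict(channels=64, init_cfg=dict(checkpoint='pidnet-m.pth')),
    decode_head=dict(in_channels=256))
l_override = dict(
    backbone=dict(
        channels=64,
        ppm_channels=112,
        num_stem_blocks=3,
        num_branch_blocks=4,
        init_cfg=dict(checkpoint='pidnet-l.pth')),
    decode_head=dict(in_channels=256, channels=256))

pidnet_m = merge_config(base_model, m_override)
pidnet_l = merge_config(base_model, l_override)

for name, cfg in (('S', base_model), ('M', pidnet_m), ('L', pidnet_l)):
    print(name, cfg['backbone'], cfg['decode_head'])
```

In all three variants the head's `in_channels` equals four times the backbone's `channels` (128 = 4 × 32, 256 = 4 × 64), which is presumably why the two values are overridden together.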
configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py (113 additions, 0 deletions)
@@ -0,0 +1,113 @@
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py'
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth'  # noqa
crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='PIDNet',
        in_channels=3,
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        align_corners=False,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
    decode_head=dict(
        type='PIDHead',
        in_channels=128,
        channels=128,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        align_corners=True,
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=class_weight,
                loss_weight=0.4),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(type='BoundaryLoss', loss_weight=20.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0)
        ]),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='GenerateEdge', edge_width=4),
    dict(type='PackSegInputs')
]
train_dataloader = dict(batch_size=6, dataset=dict(pipeline=train_pipeline))

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
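
To give a feel for the training schedule configured above, the sketch below evaluates a polynomial learning-rate decay with the values from this config (`lr=0.01`, `power=0.9`, `eta_min=0`, 120000 iterations) and estimates how many passes over the training set that corresponds to. It uses the standard closed form `(base_lr - eta_min) * (1 - t / T) ** power + eta_min` rather than calling MMEngine's `PolyLR`, so treat it as an approximation of the scheduler's behaviour. The helper names are hypothetical. The epoch estimate assumes that `2xb6` in the file name means 2 GPUs with 6 samples each and that the Cityscapes training split holds 2975 images; neither is stated in the config itself.

```python
def poly_lr(iteration: int,
            base_lr: float = 0.01,
            eta_min: float = 0.0,
            power: float = 0.9,
            max_iters: int = 120000) -> float:
    """Closed-form polynomial decay (illustrative, not MMEngine's code)."""
    progress = min(iteration, max_iters) / max_iters
    return (base_lr - eta_min) * (1.0 - progress) ** power + eta_min


def approx_epochs(max_iters: int = 120000,
                  num_gpus: int = 2,        # assumption read from "2xb6"
                  batch_per_gpu: int = 6,   # train_dataloader batch_size
                  dataset_size: int = 2975  # assumed Cityscapes train size
                  ) -> float:
    """Approximate number of passes over the training set."""
    return max_iters * num_gpus * batch_per_gpu / dataset_size


# Learning rate at each validation/checkpoint interval (iters // 10).
for it in range(0, 120001, 12000):
    print(f'iter {it:>6}: lr = {poly_lr(it):.6f}')

print(f'approx. epochs: {approx_epochs():.1f}')
```

Under these assumptions the schedule works out to roughly 484 epochs.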
@@ -0,0 +1,81 @@
Collections:
- Name: PIDNet
  Metadata:
    Training Data:
    - Cityscapes
  Paper:
    URL: https://arxiv.org/pdf/2206.02066.pdf
    Title: 'PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller'
  README: configs/pidnet/README.md
  Code:
    URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py
    Version: dev-1.x
  Converted From:
    Code: https://github.com/XuJiacong/PIDNet
Models:
- Name: pidnet-s_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-S
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 12.37
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 3.38
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 78.74
      mIoU(ms+flip): 80.87
  Config: configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth
- Name: pidnet-m_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-M
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 13.89
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 5.14
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 80.22
      mIoU(ms+flip): 82.05
  Config: configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth
- Name: pidnet-l_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-L
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 16.65
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 5.83
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 80.89
      mIoU(ms+flip): 82.37
  Config: configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth
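
The metafile reports inference time in milliseconds per image, while the README table reports frames per second. As a quick consistency check, the sketch below converts the README's fps figures with `ms = 1000 / fps` and prints them next to the metafile values. The function name `fps_to_ms` and the two dictionaries are introduced here for illustration; the numbers are copied from the README table and the metafile in this commit.

```python
def fps_to_ms(fps: float) -> float:
    """Convert frames per second to milliseconds per image."""
    return 1000.0 / fps


# "Inf time (fps)" column of the README results table.
readme_fps = {'PIDNet-S': 80.82, 'PIDNet-M': 71.98, 'PIDNet-L': 60.06}
# "inference time (ms/im)" values from the metafile.
metafile_ms = {'PIDNet-S': 12.37, 'PIDNet-M': 13.89, 'PIDNet-L': 16.65}

for name, fps in readme_fps.items():
    converted = round(fps_to_ms(fps), 2)
    print(f'{name}: {fps} fps -> {converted} ms/im '
          f'(metafile: {metafile_ms[name]} ms/im)')
```

The two sets of numbers agree to two decimal places, so the README and the metafile describe the same measurements.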