Skip to content

Commit

Permalink
Code for "High-Resolution Representations for Labeling Pixels and Reg…
Browse files Browse the repository at this point in the history
…ions" (#610)

* support HRNet

* add zip

* remove zip files

* remove zip datasets in config

* modify format and shorten lines

* fix line to long

* support conv_cfg and update conv layer

* revise the backbone network and neck

* update format and pretrained mode

* fix flake8 error

* update modules following review suggestions

* revert some changes for adapting to pretrained models

* update hrnet and hrfpn

* remove unused import

* remove unused import

* finish testing

* change pretrained model link to open-mmlab

* fix docstring and convert models

* update README and model links

* modify configs and README

* support loss evaluator

* update model urls

* format hrnet.py

* format hrfpn.py

* add 20e for cascade config
  • Loading branch information
Tianheng Cheng authored and hellock committed May 22, 2019
1 parent 57d3459 commit 3cb84ac
Show file tree
Hide file tree
Showing 11 changed files with 1,865 additions and 2 deletions.
54 changes: 54 additions & 0 deletions configs/hrnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# High-resolution networks (HRNets) for object detection

## Introduction

```
@inproceedings{SunXLW19,
title={Deep High-Resolution Representation Learning for Human Pose Estimation},
author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang},
booktitle={CVPR},
year={2019}
}
@article{SunZJCXLMWLW19,
title={High-Resolution Representations for Labeling Pixels and Regions},
author={Ke Sun and Yang Zhao and Borui Jiang and Tianheng Cheng and Bin Xiao
and Dong Liu and Yadong Mu and Xinggang Wang and Wenyu Liu and Jingdong Wang},
journal = {CoRR},
volume = {abs/1904.04514},
year={2019}
}
```

## Results and Models

Faster R-CNN

| Backbone|#Params|GFLOPs|Lr sched|mAP|Download|
| :--:|:--:|:--:|:--:|:--:|:--:|
| HRNetV2-W18 |26.2M|159.1| 1x | 36.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w18_fpn_1x_20190522-e368c387.pth)|
| HRNetV2-W18 |26.2M|159.1| 20-23-24e | 38.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w18_fpn_20_23_24e_20190522-ed3c0293.pth)|
| HRNetV2-W32 |45.0M|245.3| 1x | 39.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w32_fpn_1x_20190522-d22f1fef.pth)|
| HRNetV2-W32 |45.0M|245.3| 20-23-24e | 40.8 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w32_fpn_20_23_24e_20190522-2d67a5eb.pth)|
| HRNetV2-W40 |60.5M|314.9| 1x | 40.4 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w40_fpn_1x_20190522-30502318.pth)|
| HRNetV2-W40 |60.5M|314.9| 20-23-24e | 41.4 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/faster_rcnn_hrnetv2_w40_fpn_20_23_24e_20190522-050a7c7f.pth)|


Mask R-CNN

|Backbone|Lr sched|mask mAP|box mAP|Download|
|:--:|:--:|:--:|:--:|:--:|
| HRNetV2-W18 | 1x | 34.2 | 37.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/mask_rcnn_hrnetv2_w18_fpn_1x_20190522-c8ad459f.pth)|
| HRNetV2-W18 | 20-23-24e | 35.7 | 39.2 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/mask_rcnn_hrnetv2_w18_fpn_20_23_24e_20190522-5c11b7f2.pth)|
| HRNetV2-W32 | 1x | 36.8 | 40.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/mask_rcnn_hrnetv2_w32_fpn_1x_20190522-374aaa00.pth)|
| HRNetV2-W32 | 20-23-24e | 37.6 | 42.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/mask_rcnn_hrnetv2_w32_fpn_20_23_24e_20190522-4dd02a79.pth)|

Cascade R-CNN

|Backbone|Lr sched|mAP|Download|
|:--:|:--:|:--:|:--:|
| HRNetV2-W32 | 20e | 43.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/hrnet/cascade_rcnn_hrnetv2_w32_fpn_20e_20190522-55bec4ee.pth)|

**Note:**

- HRNetV2 ImageNet pretrained models are in [HRNets for Image Classification](https://github.com/HRNet/HRNet-Image-Classification).
268 changes: 268 additions & 0 deletions configs/hrnet/cascade_rcnn_hrnetv2p_w32_20e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
# model settings
model = dict(
type='CascadeRCNN',
num_stages=3,
pretrained='open-mmlab://msra/hrnetv2_w32',
backbone=dict(
type='HRNet',
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256)))),
neck=dict(
type='HRFPN',
in_channels=[32, 64, 128, 256],
out_channels=256),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(
type='RoIAlign',
out_size=7,
sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=1.0,
loss_weight=1.0)),
dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1],
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=1.0,
loss_weight=1.0)),
dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067],
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=1.0,
loss_weight=1.0)),
])
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.6,
min_pos_iou=0.6,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.7,
min_pos_iou=0.7,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)
],
stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100),
keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[16, 19])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 20
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cascade_rcnn_hrnetv2p_w32'
load_from = None
resume_from = None
workflow = [('train', 1)]
Loading

0 comments on commit 3cb84ac

Please sign in to comment.