support hmdb51 (open-mmlab#134)
* Update INSTALL.md

libboost-all-dev is required to build dense_flow

* fix dense_flow path error

* support database hmdb51
loveunk authored Mar 16, 2020
1 parent c480d83 commit 909e90d
Showing 13 changed files with 475 additions and 4 deletions.
3 changes: 2 additions & 1 deletion DATASET.md
@@ -16,6 +16,7 @@ The supported datasets are listed below.
We provide shell scripts for data preparation under the path `$MMACTION/data_tools/`.
To ease usage, we provide tutorials of data deployment for each dataset.

- [HMDB51](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/): See [PREPARING_HMDB51.md](https://github.com/open-mmlab/mmaction/tree/master/data_tools/hmdb51/PREPARING_HMDB51.md)
- [UCF101](https://www.crcv.ucf.edu/data/UCF101.php): See [PREPARING_UCF101.md](https://github.com/open-mmlab/mmaction/tree/master/data_tools/ucf101/PREPARING_UCF101.md)
- [Kinetics400](https://deepmind.com/research/open-source/open-source-datasets/kinetics/): See [PREPARING_KINETICS400.md](https://github.com/open-mmlab/mmaction/tree/master/data_tools/kinetics400/PREPARING_KINETICS400.md)
- [THUMOS14](https://www.crcv.ucf.edu/THUMOS14/download.html): See [PREPARING_TH14.md](https://github.com/open-mmlab/mmaction/tree/master/data_tools/thumos14/PREPARING_TH14.md)
@@ -64,5 +65,5 @@ cd $MMACTION
python data_tools/build_file_list.py ${DATASET} ${SRC_FOLDER} --level {1, 2} --format {rawframes, videos}
```
- `${SRC_FOLDER}` should point to the folder corresponding to the data format:
- "$MMACTION/data/$DATASET/rawframes" `--format rawframes`
- "$MMACTION/data/$DATASET/rawframes" if `--format rawframes`
- "$MMACTION/data/$DATASET/videos" if `--format videos`
2 changes: 1 addition & 1 deletion INSTALL.md
@@ -133,7 +133,7 @@ You can skip this argument to speed up the compilation if you do not intend to u
```shell
cd third_party/dense_flow
# dense_flow dependencies
sudo apt-get -qq install libzip-dev
sudo apt-get -qq install libzip-dev libboost-all-dev
mkdir build && cd build
# deprecated:
# OpenCV_DIR=../../opencv-2.4.13/build cmake ..
125 changes: 125 additions & 0 deletions configs/hmdb51/tsn_flow_bninception.py
@@ -0,0 +1,125 @@
# model settings
model = dict(
type='TSN2D',
modality='Flow',
in_channels=10,
backbone=dict(
type='BNInception',
pretrained='open-mmlab://bninception_caffe',
bn_eval=False,
partial_bn=True),
spatial_temporal_module=dict(
type='SimpleSpatialModule',
spatial_type='avg',
spatial_size=7),
segmental_consensus=dict(
type='SimpleConsensus',
consensus_type='avg'),
cls_head=dict(
type='ClsHead',
with_avg_pool=False,
temporal_feature_size=1,
spatial_feature_size=1,
dropout_ratio=0.7,
in_channels=1024,
num_classes=51))
train_cfg = None
test_cfg = None
# dataset settings
dataset_type = 'RawFramesDataset'
data_root = 'data/hmdb51/rawframes'
img_norm_cfg = dict(
mean=[128], std=[1], to_rgb=False)
data = dict(
videos_per_gpu=32,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_train_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=3,
new_length=5,
new_step=1,
random_shift=True,
modality='Flow',
image_tmpl='flow_{}_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0.5,
resize_keep_ratio=True,
oversample=None,
random_crop=False,
more_fix_crop=False,
multiscale_crop=True,
scales=[1, 0.875, 0.75, 0.66],
max_distort=1,
test_mode=False),
val=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_val_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=3,
new_length=5,
new_step=1,
random_shift=False,
modality='Flow',
image_tmpl='flow_{}_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0,
resize_keep_ratio=True,
oversample=None,
random_crop=False,
more_fix_crop=False,
multiscale_crop=False,
test_mode=False),
test=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_val_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=25,
new_length=5,
new_step=1,
random_shift=False,
modality='Flow',
image_tmpl='flow_{}_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0,
resize_keep_ratio=True,
oversample='ten_crop',
random_crop=False,
more_fix_crop=False,
multiscale_crop=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[190, 300])
checkpoint_config = dict(interval=1)
# workflow = [('train', 5), ('val', 1)]
workflow = [('train', 1)]
# yapf:disable
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 340
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_2d_flow_bninception_seg_3_f1s1_b32_g8_lr_0.005'
load_from = None
resume_from = None
125 changes: 125 additions & 0 deletions configs/hmdb51/tsn_rgb_bninception.py
@@ -0,0 +1,125 @@
# model settings
model = dict(
type='TSN2D',
backbone=dict(
type='BNInception',
pretrained='open-mmlab://bninception_caffe',
bn_eval=False,
partial_bn=True),
spatial_temporal_module=dict(
type='SimpleSpatialModule',
spatial_type='avg',
spatial_size=7),
segmental_consensus=dict(
type='SimpleConsensus',
consensus_type='avg'),
cls_head=dict(
type='ClsHead',
with_avg_pool=False,
temporal_feature_size=1,
spatial_feature_size=1,
dropout_ratio=0.8,
in_channels=1024,
init_std=0.001,
num_classes=51))
train_cfg = None
test_cfg = None
# dataset settings
dataset_type = 'RawFramesDataset'
data_root = 'data/hmdb51/rawframes'
img_norm_cfg = dict(
mean=[104, 117, 128], std=[1, 1, 1], to_rgb=False)

data = dict(
videos_per_gpu=32,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_train_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=3,
new_length=1,
new_step=1,
random_shift=True,
modality='RGB',
image_tmpl='img_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0.5,
resize_keep_ratio=True,
oversample=None,
random_crop=False,
more_fix_crop=False,
multiscale_crop=True,
scales=[1, 0.875, 0.75, 0.66],
max_distort=1,
test_mode=False),
val=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_val_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=3,
new_length=1,
new_step=1,
random_shift=False,
modality='RGB',
image_tmpl='img_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0,
resize_keep_ratio=True,
oversample=None,
random_crop=False,
more_fix_crop=False,
multiscale_crop=False,
test_mode=False),
test=dict(
type=dataset_type,
ann_file='data/hmdb51/hmdb51_val_split_1_rawframes.txt',
img_prefix=data_root,
img_norm_cfg=img_norm_cfg,
num_segments=25,
new_length=1,
new_step=1,
random_shift=False,
modality='RGB',
image_tmpl='img_{:05d}.jpg',
img_scale=256,
input_size=224,
div_255=False,
flip_ratio=0,
resize_keep_ratio=True,
oversample='ten_crop',
random_crop=False,
more_fix_crop=False,
multiscale_crop=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[30, 60])
checkpoint_config = dict(interval=1)
# workflow = [('train', 5), ('val', 1)]
workflow = [('train', 1)]
# yapf:disable
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 80
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_2d_rgb_bninception_seg_3_f1s1_b32_g8'
load_from = None
resume_from = None
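Once the HMDB51 file lists and raw frames are in place, these two configs can be trained like any other MMAction recognizer config; a sketch, assuming the usual `tools/dist_train_recognizer.sh` launcher and the 8-GPU setup implied by the `_g8` suffix in `work_dir` (see GETTING_STARTED.md for the authoritative commands):
```shell
cd $MMACTION
# distributed training of the new HMDB51 RGB config on 8 GPUs; the launcher path and
# the --validate flag are assumptions based on the repository's usual conventions
./tools/dist_train_recognizer.sh configs/hmdb51/tsn_rgb_bninception.py 8 --validate
```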
5 changes: 4 additions & 1 deletion data_tools/build_file_list.py
@@ -2,6 +2,7 @@
import os.path as osp
import glob
from mmaction.datasets.utils import (parse_directory,
parse_hmdb51_splits,
parse_ucf101_splits,
parse_kinetics_splits,
build_split_list)
@@ -10,7 +11,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Build file list')
parser.add_argument('dataset', type=str, choices=[
'ucf101', 'kinetics400'])
'hmdb51', 'ucf101', 'kinetics400'])
parser.add_argument('frame_path', type=str,
help='root directory for the frames')
parser.add_argument('--rgb_prefix', type=str, default='img_')
@@ -52,6 +53,8 @@ def key_func(x): return x.split('/')[-1]
frame_info = {osp.relpath(
x.split('.')[0], args.frame_path): (x, -1, -1) for x in video_list}

if args.dataset == 'hmdb51':
split_tp = parse_hmdb51_splits(args.level)
if args.dataset == 'ucf101':
split_tp = parse_ucf101_splits(args.level)
elif args.dataset == 'kinetics400':
2 changes: 1 addition & 1 deletion data_tools/build_rawframes.py
@@ -95,7 +95,7 @@ def parse_args():
parser.add_argument('--flow_type', type=str,
default=None, choices=[None, 'tvl1', 'warp_tvl1'])
parser.add_argument('--df_path', type=str,
default='../mmaction/third_party/dense_flow')
default='../../mmaction/third_party/dense_flow')
parser.add_argument("--out_format", type=str, default='dir',
choices=['dir', 'zip'], help='output format')
parser.add_argument("--ext", type=str, default='avi',
80 changes: 80 additions & 0 deletions data_tools/hmdb51/PREPARING_HMDB51.md
@@ -0,0 +1,80 @@
## Preparing HMDB51

For more details, please refer to the official [website](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). We provide scripts with documentation. Before we start, please make sure that your current directory is `$MMACTION/data_tools/hmdb51/`.

### Prepare annotations
First of all, run the following script to prepare annotations.
```shell
bash download_annotations.sh
```

### Prepare videos
Then, use the following script to prepare videos.
```shell
bash download_videos.sh
```

### Extract frames
Now it is time to extract frames from videos.
Before extraction, please refer to `INSTALL.md` for installing [dense_flow](https://github.com/yjxiong/dense_flow).
If you have an SSD, we recommend extracting frames there for better I/O performance. The extracted frames (RGB + Flow) will take up about 24 GB.
```shell
# execute these two lines (assuming the SSD is mounted at "/mnt/SSD/")
mkdir /mnt/SSD/hmdb51_extracted/
ln -s /mnt/SSD/hmdb51_extracted/ ../data/hmdb51/rawframes
```

If you didn't install dense_flow during installation, or only want to play with RGB frames (extracting optical flow can be both time-consuming and space-hogging), run the following script to extract **RGB-only** frames.
```shell
bash extract_rgb_frames.sh
```

If both RGB and optical flow frames are required, run the following script instead.
```shell
bash extract_frames.sh
```
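Both scripts are thin wrappers around `data_tools/build_rawframes.py` (whose default `--df_path` this commit also fixes). A rough manual equivalent for the RGB + flow case might look as follows; the positional source/output directories and the `--level` flag are assumptions, while `--flow_type`, `--ext` and `--out_format` match the argparse options visible in this commit's diff:
```shell
cd $MMACTION/data_tools
# extract RGB frames plus TV-L1 optical flow into the (possibly SSD-backed) rawframes dir
python build_rawframes.py ../data/hmdb51/videos/ ../data/hmdb51/rawframes/ --level 2 \
    --flow_type tvl1 --ext avi --out_format dir
```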

### Generate filelist
Run the following script to generate file lists in both the rawframes and videos formats.
```shell
bash generate_filelist.sh
```
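The generated lists are what the new configs reference through `ann_file` (e.g. `data/hmdb51/hmdb51_train_split_1_rawframes.txt`). A quick sanity check, run from `$MMACTION` and assuming the layout shown in the folder structure below, might look like this:
```shell
cd $MMACTION
# count entries per split and peek at the first few lines of a rawframes list
wc -l data/hmdb51/hmdb51_train_split_{1,2,3}_rawframes.txt
head -n 3 data/hmdb51/hmdb51_train_split_1_rawframes.txt
```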

### Folder structure
In the context of the whole project (for HMDB51 only), the folder structure will look like:
```
mmaction
├── mmaction
├── tools
├── configs
├── data
│ ├── hmdb51
│ │ ├── hmdb51_{train,val}_split_{1,2,3}_rawframes.txt
│ │ ├── hmdb51_{train,val}_split_{1,2,3}_videos.txt
│ │ ├── annotations
│ │ ├── videos
│ │ │ ├── brush_hair
│ │ │ │ ├── April_09_brush_hair_u_nm_np1_ba_goo_0.avi
│ │ │ ├── wave
│ │ │ │ ├── 20060723sfjffbartsinger_wave_f_cm_np1_ba_med_0.avi
│ │ ├── rawframes
│ │ │ ├── brush_hair
│ │ │ │ ├── April_09_brush_hair_u_nm_np1_ba_goo_0
│ │ │ │ │ ├── img_00001.jpg
│ │ │ │ │ ├── img_00002.jpg
│ │ │ │ │ ├── ...
│ │ │ │ │ ├── flow_x_00001.jpg
│ │ │ │ │ ├── flow_x_00002.jpg
│ │ │ │ │ ├── ...
│ │ │ │ │ ├── flow_y_00001.jpg
│ │ │ │ │ ├── flow_y_00002.jpg
│ │ │ ├── ...
│ │ │ ├── wave
│ │ │ │ ├── 20060723sfjffbartsinger_wave_f_cm_np1_ba_med_0
│ │ │ │ ├── ...
│ │ │ │ ├── winKen_wave_u_cm_np1_ri_bad_1
```

For training and evaluating on HMDB51, please refer to [GETTING_STARTED.md](https://github.com/open-mmlab/mmaction/blob/master/GETTING_STARTED.md).
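As a rough pointer, evaluation with one of the new configs would follow the usual recognizer testing flow; a sketch, assuming the conventional `tools/test_recognizer.py` entry point and its typical arguments (check GETTING_STARTED.md for the exact flags):
```shell
cd $MMACTION
# test a trained RGB model on HMDB51 split 1; the checkpoint path, --gpus and --out
# values are illustrative assumptions, not taken from this commit
python tools/test_recognizer.py configs/hmdb51/tsn_rgb_bninception.py \
    work_dirs/tsn_2d_rgb_bninception_seg_3_f1s1_b32_g8/latest.pth --gpus 1 --out result.pkl
```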