Support Monocular 3D Detector CaDDN (open-mmlab#538)
* Added CaDDN detector and support for image, depth map, and 2D GT box dataloading

* Moved image flip augmentation to augmentor_utils

* Updated default get item list to include points

* Moved utils functions into transform_utils

* Combined FFE + F2V into ImageVFE, renamed FFE to FFN, moved depth downsample into data_processor

* Updated README with updated CaDDN weights

* Updated comments for image vfe
codyreading authored May 20, 2021
1 parent e3bec15 commit aaf9cbe
Showing 37 changed files with 1,379 additions and 26 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -8,8 +8,10 @@ venv/
*.idea/
*.so
*.yaml
*.sh
*.sh
*.pth
*.pkl
*.zip
*.bin
output
version.py
3 changes: 3 additions & 0 deletions README.md
@@ -18,6 +18,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18


## Changelog
[2021-05-14] Added support for the monocular 3D object detection model [`CaDDN`](#KITTI-3D-Object-Detection-Baselines)

[2020-11-27] **Bug fixed:** Please re-prepare the validation infos of the Waymo dataset (version 1.2) if you would like to
use our provided Waymo evaluation tool (see [PR](https://github.com/open-mmlab/OpenPCDet/pull/383)).
Note that you do not need to re-prepare the training data and ground-truth database.
@@ -104,6 +106,7 @@ Selected supported methods are shown in the below table. The results are the 3D
| [Part-A^2-Free](tools/cfgs/kitti_models/PartA2_free.yaml) | ~3.8 hours| 78.72 | 65.99 | 74.29 | [model-226M](https://drive.google.com/file/d/1lcUUxF8mJgZ_e-tZhP1XNQtTBuC-R0zr/view?usp=sharing) |
| [Part-A^2-Anchor](tools/cfgs/kitti_models/PartA2.yaml) | ~4.3 hours| 79.40 | 60.05 | 69.90 | [model-244M](https://drive.google.com/file/d/10GK1aCkLqxGNeX3lVu8cLZyE0G8002hY/view?usp=sharing) |
| [PV-RCNN](tools/cfgs/kitti_models/pv_rcnn.yaml) | ~5 hours| 83.61 | 57.90 | 70.47 | [model-50M](https://drive.google.com/file/d/1lIOq4Hxr0W3qsX83ilQv0nk1Cls6KAr-/view?usp=sharing) |
| [CaDDN](tools/cfgs/kitti_models/CaDDN.yaml) |~15 hours| 21.38 | 13.02 | 9.76 | [model-774M](https://drive.google.com/file/d/1OQTO2PtXT8GGr35W9m2GZGuqgb6fyU1V/view?usp=sharing) |

### NuScenes 3D Object Detection Baselines
All models are trained with 8 GTX 1080Ti GPUs and are available for download.
14 changes: 13 additions & 1 deletion docs/GETTING_STARTED.md
@@ -9,6 +9,7 @@ Currently we provide the dataloader of KITTI dataset and NuScenes dataset, and t

### KITTI Dataset
* Please download the official [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) dataset and organize the downloaded files as follows (the road planes could be downloaded from [[road plane]](https://drive.google.com/file/d/1d5mq0RXRnvHPVeKx6Q612z0YRO1t2wAp/view?usp=sharing), which are optional for data augmentation during training):
* If you would like to train [CaDDN](../tools/cfgs/kitti_models/CaDDN.yaml), download the precomputed [depth maps](https://drive.google.com/file/d/1qFZux7KC_gJ0UHEg-qGJKqteE9Ivojin/view?usp=sharing) for the KITTI training set.
* NOTE: If you already have the data infos from `pcdet v0.1`, you can choose to use the old infos and set the DATABASE_WITH_FAKELIDAR option in tools/cfgs/dataset_configs/kitti_dataset.yaml to True. Alternatively, you can create the infos and gt database again and leave the config unchanged.

```
@@ -17,7 +18,7 @@ OpenPCDet
│ ├── kitti
│ │ │── ImageSets
│ │ │── training
│ │ │ ├──calib & velodyne & label_2 & image_2 & (optional: planes)
│ │ │ ├──calib & velodyne & label_2 & image_2 & (optional: planes) & (optional: depth_2)
│ │ │── testing
│ │ │ ├──calib & velodyne & image_2
├── pcdet
@@ -94,6 +95,17 @@ python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos \

Note that you do not need to install `waymo-open-dataset` if you have already processed the data before and do not need to evaluate with official Waymo Metrics.

## Pretrained Models
If you would like to train [CaDDN](../tools/cfgs/kitti_models/CaDDN.yaml), download the pretrained [DeepLabV3 model](https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth) and place it within the `checkpoints` directory:
```
OpenPCDet
├── checkpoints
│ ├── deeplabv3_resnet101_coco-586e9e4e.pth
├── data
├── pcdet
├── tools
```
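
A quick, optional sanity check of the downloaded checkpoint can be done with stock torchvision. This snippet is not part of the repo, and the `aux_loss=True` flag is an assumption that matches the auxiliary-classifier weights shipped in the COCO checkpoint:

```python
# Optional sketch (not part of OpenPCDet): verify that the downloaded
# DeepLabV3 weights load into the stock torchvision model.
import torch
import torchvision

state_dict = torch.load('checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth',
                        map_location='cpu')
# The COCO checkpoint includes the auxiliary classifier, so build the model
# with aux_loss=True (or call load_state_dict with strict=False instead).
model = torchvision.models.segmentation.deeplabv3_resnet101(
    pretrained=False, aux_loss=True, num_classes=21)
model.load_state_dict(state_dict)
print('Loaded %d tensors from the DeepLabV3 checkpoint' % len(state_dict))
```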

## Training & Testing


40 changes: 40 additions & 0 deletions pcdet/datasets/augmentor/augmentor_utils.py
@@ -1,3 +1,4 @@
import copy
import numpy as np

from ...utils import common_utils
@@ -76,3 +77,42 @@ def global_scaling(gt_boxes, points, scale_range):
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale
return gt_boxes, points

def random_image_flip_horizontal(image, depth_map, gt_boxes, calib):
"""
Performs random horizontal flip augmentation
Args:
image: (H_image, W_image, 3), Image
depth_map: (H_depth, W_depth), Depth map
gt_boxes: (N, 7), 3D box labels in LiDAR coordinates [x, y, z, w, l, h, ry]
calib: calibration.Calibration, Calibration object
Returns:
aug_image: (H_image, W_image, 3), Augmented image
aug_depth_map: (H_depth, W_depth), Augmented depth map
aug_gt_boxes: (N, 7), Augmented 3D box labels in LiDAR coordinates [x, y, z, w, l, h, ry]
"""
# Randomly augment with 50% chance
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])

if enable:
# Flip images
aug_image = np.fliplr(image)
aug_depth_map = np.fliplr(depth_map)

# Flip 3D gt_boxes by flipping the centroids in image space
aug_gt_boxes = copy.copy(gt_boxes)
locations = aug_gt_boxes[:, :3]
img_pts, img_depth = calib.lidar_to_img(locations)
W = image.shape[1]
img_pts[:, 0] = W - img_pts[:, 0]
pts_rect = calib.img_to_rect(u=img_pts[:, 0], v=img_pts[:, 1], depth_rect=img_depth)
pts_lidar = calib.rect_to_lidar(pts_rect)
aug_gt_boxes[:, :3] = pts_lidar
aug_gt_boxes[:, 6] = -1 * aug_gt_boxes[:, 6]

else:
aug_image = image
aug_depth_map = depth_map
aug_gt_boxes = gt_boxes

return aug_image, aug_depth_map, aug_gt_boxes
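
A small numeric illustration of the centroid flip performed above (the numbers are made up purely for illustration): mirroring about the vertical image axis maps a projected pixel column u to W - u while the projected depth is unchanged, which is exactly what is fed back into `calib.img_to_rect`.

```python
import numpy as np

# Hypothetical values for illustration only: an image 1242 px wide and two
# box centroids that project to pixel columns 300 and 900.
W = 1242
u = np.array([300.0, 900.0])
u_flipped = W - u  # -> array([942., 342.]); the projected depth is untouched
```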
24 changes: 23 additions & 1 deletion pcdet/datasets/augmentor/data_augmentor.py
@@ -39,7 +39,7 @@ def __getstate__(self):

def __setstate__(self, d):
self.__dict__.update(d)

def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_flip, config=config)
@@ -78,6 +78,25 @@ def random_world_scaling(self, data_dict=None, config=None):
data_dict['points'] = points
return data_dict

def random_image_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_image_flip, config=config)
images = data_dict["images"]
depth_maps = data_dict["depth_maps"]
gt_boxes = data_dict['gt_boxes']
gt_boxes2d = data_dict["gt_boxes2d"]
calib = data_dict["calib"]
for cur_axis in config['ALONG_AXIS_LIST']:
assert cur_axis in ['horizontal']
images, depth_maps, gt_boxes = getattr(augmentor_utils, 'random_image_flip_%s' % cur_axis)(
images, depth_maps, gt_boxes, calib,
)

data_dict['images'] = images
data_dict['depth_maps'] = depth_maps
data_dict['gt_boxes'] = gt_boxes
return data_dict

def forward(self, data_dict):
"""
Args:
@@ -103,5 +122,8 @@ def forward(self, data_dict):
gt_boxes_mask = data_dict['gt_boxes_mask']
data_dict['gt_boxes'] = data_dict['gt_boxes'][gt_boxes_mask]
data_dict['gt_names'] = data_dict['gt_names'][gt_boxes_mask]
if 'gt_boxes2d' in data_dict:
data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][gt_boxes_mask]

data_dict.pop('gt_boxes_mask')
return data_dict
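
For reference, a hypothetical configuration entry that would route through `random_image_flip` above — the actual keys live in the CaDDN yaml configs, which are not part of this excerpt; only `NAME` and `ALONG_AXIS_LIST` are read by the method:

```python
# Hypothetical augmentor config entry, mirroring the pattern used by
# random_world_flip; the real values are defined in the CaDDN yaml configs.
from easydict import EasyDict

image_flip_cfg = EasyDict({
    'NAME': 'random_image_flip',
    'ALONG_AXIS_LIST': ['horizontal'],
})
```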
50 changes: 48 additions & 2 deletions pcdet/datasets/dataset.py
@@ -39,6 +39,11 @@ def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=
self.total_epochs = 0
self._merge_all_iters_to_one_epoch = False

if hasattr(self.data_processor, "depth_downsample_factor"):
self.depth_downsample_factor = self.data_processor.depth_downsample_factor
else:
self.depth_downsample_factor = None

@property
def mode(self):
return 'train' if self.training else 'test'
@@ -97,7 +102,7 @@ def prepare_data(self, data_dict):
"""
Args:
data_dict:
points: (N, 3 + C_in)
points: optional, (N, 3 + C_in)
gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
gt_names: optional, (N), string
...
@@ -133,7 +138,11 @@ def prepare_data(self, data_dict):
gt_boxes = np.concatenate((data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1)
data_dict['gt_boxes'] = gt_boxes

data_dict = self.point_feature_encoder.forward(data_dict)
if data_dict.get('gt_boxes2d', None) is not None:
data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected]

if data_dict.get('points', None) is not None:
data_dict = self.point_feature_encoder.forward(data_dict)

data_dict = self.data_processor.forward(
data_dict=data_dict
@@ -172,6 +181,43 @@ def collate_batch(batch_list, _unused=False):
for k in range(batch_size):
batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
ret[key] = batch_gt_boxes3d
elif key in ['gt_boxes2d']:
max_boxes = max([len(x) for x in val])
batch_boxes2d = np.zeros((batch_size, max_boxes, val[0].shape[-1]), dtype=np.float32)
for k in range(batch_size):
if val[k].size > 0:
batch_boxes2d[k, :val[k].__len__(), :] = val[k]
ret[key] = batch_boxes2d
elif key in ["images", "depth_maps"]:
# Get largest image size (H, W)
max_h = 0
max_w = 0
for image in val:
max_h = max(max_h, image.shape[0])
max_w = max(max_w, image.shape[1])

# Change size of images
images = []
for image in val:
pad_h = common_utils.get_pad_params(desired_size=max_h, cur_size=image.shape[0])
pad_w = common_utils.get_pad_params(desired_size=max_w, cur_size=image.shape[1])
pad_width = (pad_h, pad_w)
# Pad with nan, to be replaced later in the pipeline.
pad_value = np.nan

if key == "images":
pad_width = (pad_h, pad_w, (0, 0))
elif key == "depth_maps":
pad_width = (pad_h, pad_w)

image_pad = np.pad(image,
pad_width=pad_width,
mode='constant',
constant_values=pad_value)

images.append(image_pad)
ret[key] = np.stack(images, axis=0)
else:
ret[key] = np.stack(val, axis=0)
except:
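
A minimal sketch of the padding step used in the `images` / `depth_maps` branch above, assuming `common_utils.get_pad_params` returns a `(pad_before, pad_after)` tuple suitable for `np.pad`; how the real helper distributes the padding is not shown in this diff.

```python
import numpy as np

def get_pad_params(desired_size, cur_size):
    # Assumed behaviour: pad only at the bottom/right up to the batch maximum.
    assert desired_size >= cur_size
    return (0, desired_size - cur_size)

# Pad a 370x1220 image up to a 375x1242 batch maximum with NaN, as in
# collate_batch; the NaNs are replaced later in the pipeline.
image = np.random.rand(370, 1220, 3).astype(np.float32)
pad_h = get_pad_params(desired_size=375, cur_size=image.shape[0])
pad_w = get_pad_params(desired_size=1242, cur_size=image.shape[1])
image_pad = np.pad(image, pad_width=(pad_h, pad_w, (0, 0)),
                   mode='constant', constant_values=np.nan)
assert image_pad.shape == (375, 1242, 3)
```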
64 changes: 54 additions & 10 deletions pcdet/datasets/kitti/kitti_dataset.py
@@ -4,6 +4,7 @@
import numpy as np
from skimage import io

from . import kitti_utils
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, calibration_kitti, common_utils, object3d_kitti
from ..dataset import DatasetTemplate
@@ -64,6 +65,21 @@ def get_lidar(self, idx):
assert lidar_file.exists()
return np.fromfile(str(lidar_file), dtype=np.float32).reshape(-1, 4)

def get_image(self, idx):
"""
Loads image for a sample
Args:
idx: str, Sample index
Returns:
image: (H, W, 3), RGB Image
"""
img_file = self.root_split_path / 'image_2' / ('%s.png' % idx)
assert img_file.exists()
image = io.imread(img_file)
image = image.astype(np.float32)
image /= 255.0
return image

def get_image_shape(self, idx):
img_file = self.root_split_path / 'image_2' / ('%s.png' % idx)
assert img_file.exists()
@@ -74,6 +90,21 @@ def get_label(self, idx):
assert label_file.exists()
return object3d_kitti.get_objects_from_label(label_file)

def get_depth_map(self, idx):
"""
Loads depth map for a sample
Args:
idx: str, Sample index
Returns:
depth: (H, W), Depth map
"""
depth_file = self.root_split_path / 'depth_2' / ('%s.png' % idx)
assert depth_file.exists()
depth = io.imread(depth_file)
depth = depth.astype(np.float32)
depth /= 256.0
return depth
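
The division by 256 above follows the usual KITTI convention of storing depth as uint16 PNGs in 1/256 m units, with 0 marking missing values. For anyone generating their own `depth_2` maps, a sketch of the inverse operation (not part of the repo, file layout assumed from the GETTING_STARTED.md tree) could look like this:

```python
import numpy as np
from skimage import io

def save_depth_map(path, depth):
    # Inverse of get_depth_map above: metres -> uint16 PNG scaled by 256.
    depth_png = np.clip(depth * 256.0, 0, np.iinfo(np.uint16).max)
    io.imsave(path, depth_png.astype(np.uint16))
```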

def get_calib(self, idx):
calib_file = self.root_split_path / 'calib' / ('%s.txt' % idx)
assert calib_file.exists()
@@ -277,7 +308,7 @@ def generate_single_sample_dict(batch_index, box_dict):
return pred_dict

calib = batch_dict['calib'][batch_index]
image_shape = batch_dict['image_shape'][batch_index]
image_shape = batch_dict['image_shape'][batch_index].cpu().numpy()
pred_boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera(pred_boxes, calib)
pred_boxes_img = box_utils.boxes3d_kitti_camera_to_imageboxes(
pred_boxes_camera, calib, image_shape=image_shape
@@ -345,18 +376,11 @@ def __getitem__(self, index):
info = copy.deepcopy(self.kitti_infos[index])

sample_idx = info['point_cloud']['lidar_idx']

points = self.get_lidar(sample_idx)
calib = self.get_calib(sample_idx)

img_shape = info['image']['image_shape']
if self.dataset_cfg.FOV_POINTS_ONLY:
pts_rect = calib.lidar_to_rect(points[:, 0:3])
fov_flag = self.get_fov_flag(pts_rect, img_shape, calib)
points = points[fov_flag]
calib = self.get_calib(sample_idx)
get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points'])

input_dict = {
'points': points,
'frame_id': sample_idx,
'calib': calib,
}
@@ -373,10 +397,30 @@
'gt_names': gt_names,
'gt_boxes': gt_boxes_lidar
})
if "gt_boxes2d" in get_item_list:
input_dict['gt_boxes2d'] = annos["bbox"]

road_plane = self.get_road_plane(sample_idx)
if road_plane is not None:
input_dict['road_plane'] = road_plane

if "points" in get_item_list:
points = self.get_lidar(sample_idx)
if self.dataset_cfg.FOV_POINTS_ONLY:
pts_rect = calib.lidar_to_rect(points[:, 0:3])
fov_flag = self.get_fov_flag(pts_rect, img_shape, calib)
points = points[fov_flag]
input_dict['points'] = points

if "images" in get_item_list:
input_dict['images'] = self.get_image(sample_idx)

if "depth_maps" in get_item_list:
input_dict['depth_maps'] = self.get_depth_map(sample_idx)

if "calib_matricies" in get_item_list:
input_dict["trans_lidar_to_cam"], input_dict["trans_cam_to_img"] = kitti_utils.calib_to_matricies(calib)

data_dict = self.prepare_data(data_dict=input_dict)

data_dict['image_shape'] = img_shape
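
The branches above are driven by `GET_ITEM_LIST`. A hypothetical fragment of the dataset config — the real values live in the CaDDN dataset yaml, which is not shown in this diff — might request the camera-only items like this:

```python
# Hypothetical dataset_cfg fragment; only GET_ITEM_LIST is read by the
# __getitem__ branches above, and the default falls back to ['points'].
dataset_cfg_fragment = {
    'GET_ITEM_LIST': ['images', 'depth_maps', 'calib_matricies', 'gt_boxes2d'],
}
get_item_list = dataset_cfg_fragment.get('GET_ITEM_LIST', ['points'])
```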
17 changes: 17 additions & 0 deletions pcdet/datasets/kitti/kitti_utils.py
@@ -42,3 +42,20 @@ def transform_annotations_to_kitti_format(annos, map_name_to_kitti=None, info_wi
anno['rotation_y'] = anno['alpha'] = np.zeros(0)

return annos


def calib_to_matricies(calib):
"""
Converts calibration object to transformation matrices
Args:
calib: calibration.Calibration, Calibration object
Returns:
V2R: (4, 4), Lidar to rectified camera transformation matrix
P2: (3, 4), Camera projection matrix
"""
V2C = np.vstack((calib.V2C, np.array([0, 0, 0, 1], dtype=np.float32))) # (4, 4)
R0 = np.hstack((calib.R0, np.zeros((3, 1), dtype=np.float32))) # (3, 4)
R0 = np.vstack((R0, np.array([0, 0, 0, 1], dtype=np.float32))) # (4, 4)
V2R = R0 @ V2C
P2 = calib.P2
return V2R, P2
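
A short usage sketch (not part of the repo) showing how the returned matrices compose to project a LiDAR point into the image_2 plane:

```python
import numpy as np

def project_lidar_point(pt_lidar, V2R, P2):
    # Homogeneous LiDAR point -> rectified camera frame -> pixel coordinates.
    pt_h = np.append(np.asarray(pt_lidar, dtype=np.float32), 1.0)  # (4,)
    pt_rect = V2R @ pt_h                                           # (4,)
    uvw = P2 @ pt_rect                                             # (3,)
    u, v = uvw[0] / uvw[2], uvw[1] / uvw[2]                        # perspective divide
    return u, v, uvw[2]                                            # pixel coords + depth
```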
