diff --git a/configs/rscup/htc_next_3s_ft.py b/configs/rscup/htc_next_3s_ft.py index a5be38c..7c1c2af 100644 --- a/configs/rscup/htc_next_3s_ft.py +++ b/configs/rscup/htc_next_3s_ft.py @@ -309,6 +309,6 @@ dist_params = dict(backend='nccl') log_level = 'INFO' work_dir = './work_dirs/htc_next_3s_ft' -load_from = None +load_from = "./work_dirs/htc_next_3s/epoch_10.pth" resume_from = None workflow = [('train', 1)] diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py index 9482c99..c22c4f6 100644 --- a/mmdet/models/detectors/htc.py +++ b/mmdet/models/detectors/htc.py @@ -116,12 +116,10 @@ def _mask_forward_train(self, semantic_feat=None): mask_roi_extractor = self.mask_roi_extractor[stage] mask_head = self.mask_head[stage] - ic(len(sampling_results)) pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs], pos_rois) - ic(mask_feats.shape) # semantic feature fusion # element-wise sum for original features and pooled semantic features diff --git a/tools/merge_result.py b/tools/merge_result.py index 68fbb0e..5433527 100644 --- a/tools/merge_result.py +++ b/tools/merge_result.py @@ -116,11 +116,17 @@ def get_all_ann(filename, result, img_prefix, size, CLASS_NUM=18): def merge_result(config_file, result_file, anno_file, img_prefix, out_file=None, CLASS_NUM=18): - cfg = mmcv.Config.fromfile(config_file) results = mmcv.load(result_file) - dataset = build_dataset(cfg.data.test) - print(cfg.data.test) - img_infos = dataset.load_annotations(anno_file) + coco = COCO(anno_file) + img_ids = coco.getImgIds() + img_infos = [] + for i in img_ids: + info = coco.loadImgs([i])[0] + info['filename'] = info['file_name'] + info['height'] = int(info['height']) + info['width'] = int(info['width']) + img_infos.append(info) + ann = {} rets = [] pbar = mmcv.ProgressBar(len(results)) @@ -128,7 +134,7 @@ def merge_result(config_file, result_file, anno_file, img_prefix, out_file=None, def update(*a): pbar.update() - p = Pool(18) + p = Pool(8) for i in range(len(results)): filename = img_infos[i]['filename'] h = img_infos[i]['height'] @@ -429,14 +435,14 @@ def calcoverlaps(BBGT_keep, bb): def evaluate(iou_thresh): - p = Pool(18) - detpath = "./result/val_temp/{}.txt" + p = Pool(8) + detpath = "./result/detection/{}.txt" CLASS = ['tennis-court', 'container-crane', 'storage-tank', 'baseball-diamond', 'plane', 'ground-track-field', 'helicopter', 'airport', 'harbor', 'ship', 'large-vehicle', 'swimming-pool', 'soccer-ball-field', 'roundabout', 'basketball-court', 'bridge', 'small-vehicle', 'helipad'] iou_thresh = 0.5 aps = [] - coco = COCO("/home/xfr/rssid/data/annotation/annos_rscup_val.json") + coco = COCO("./gt_val.json") for classname in CLASS: aps.append(p.apply_async(eval_warpper, args=(classname, detpath, iou_thresh, coco))) ret = [] @@ -450,13 +456,12 @@ def evaluate(iou_thresh): CLASSES = ['tennis-court', 'container-crane', 'storage-tank', 'baseball-diamond', 'plane', 'ground-track-field', 'helicopter', 'airport', 'harbor', 'ship', 'large-vehicle', 'swimming-pool', 'soccer-ball-field', 'roundabout', 'basketball-court', 'bridge', 'small-vehicle', 'helipad'] - config_file = "./configs/rs_cascade_mask_rcnn_r50_fpn_ohem.py" result_file = "./batch_3s.pkl" - anno_file = "/home/xfr/mmdetection/data/rscup/annotation/annos_rscup_val.json" + anno_file = "./data/rscup/annotation/annos_rscup_val.json" out_file = "./result/eval_temp.pkl" img_prefix = "./data/rscup/val/" ann = merge_result(config_file, result_file, anno_file, img_prefix, out_file) ann = nms(ann, "poly", 0.5) mmcv.dump(ann, "./result/post_nms.pkl") - generate_submit(ann, "val_temp", CLASSES) - evaluate(0.5) \ No newline at end of file + generate_submit(ann, "detection", CLASSES) + #evaluate(0.5) \ No newline at end of file diff --git a/tools/requirements.txt b/tools/requirements.txt index 8cd31ea..14f7c29 100644 --- a/tools/requirements.txt +++ b/tools/requirements.txt @@ -67,7 +67,7 @@ matplotlib==3.0.1 mistune==0.7.4 mmcv==0.2.8 mmd==0.2.6 --e git+git@github.com:cizhenshi/mm.git@5d30d2907180eff5c4b3fc8f90771ddb16ea202d#egg=mmdet +-e git+git@github.com:cizhenshi/mm.git@0251178e52ceda1b4dc14ac2b1c56ae87ee93ff5#egg=mmdet msgpack==0.5.6 msgpack-numpy==0.4.3.2 munch==2.3.2