diff --git a/configs/_base_/dataset/rgbcod.py b/configs/_base_/dataset/rgbcod.py deleted file mode 100644 index 2400ea8..0000000 --- a/configs/_base_/dataset/rgbcod.py +++ /dev/null @@ -1,14 +0,0 @@ -datasets = dict( - train=dict( - dataset_type="rgb_cod_tr", - shape=dict(h=256, w=256), - path=["cod10k_camo_tr"], - interp_cfg=dict(), - ), - test=dict( - dataset_type="rgb_cod_te", - shape=dict(h=256, w=256), - path=["camo_te", "chameleon", "cpd1k_te", "cod10k_te", "nc4k"], - interp_cfg=dict(), - ), -) diff --git a/configs/_base_/dataset/rgbsod.py b/configs/_base_/dataset/rgbsod.py deleted file mode 100644 index ac62538..0000000 --- a/configs/_base_/dataset/rgbsod.py +++ /dev/null @@ -1,14 +0,0 @@ -datasets = dict( - train=dict( - dataset_type="rgb_sod_tr", - shape=dict(h=256, w=256), - path=["dutstr"], - interp_cfg=dict(), - ), - test=dict( - dataset_type="rgb_sod_te", - shape=dict(h=256, w=256), - path=["pascal-s", "ecssd", "hku-is", "dutste", "dut-omron", "socte"], - interp_cfg=dict(), - ), -) diff --git a/configs/zoomnet/zoomnet.py b/configs/zoomnet/cod_zoomnet.py similarity index 85% rename from configs/zoomnet/zoomnet.py rename to configs/zoomnet/cod_zoomnet.py index 949a990..49c54d7 100644 --- a/configs/zoomnet/zoomnet.py +++ b/configs/zoomnet/cod_zoomnet.py @@ -2,13 +2,12 @@ "../_base_/common.py", "../_base_/train.py", "../_base_/test.py", - "../_base_/dataset/rgbcod.py", ] has_test = True deterministic = True use_custom_worker_init = False -model_name = 'ZoomNet' +model_name = "ZoomNet" train = dict( batch_size=8, @@ -52,9 +51,13 @@ train=dict( dataset_type="msi_cod_tr", shape=dict(h=384, w=384), + path=["cod10k_camo_tr"], + interp_cfg=dict(), ), test=dict( dataset_type="msi_cod_te", shape=dict(h=384, w=384), + path=["camo_te", "chameleon", "cpd1k_te", "cod10k_te", "nc4k"], + interp_cfg=dict(), ), ) diff --git a/configs/zoomnet/sod_zoomnet.py b/configs/zoomnet/sod_zoomnet.py new file mode 100644 index 0000000..047f9e1 --- /dev/null +++ b/configs/zoomnet/sod_zoomnet.py @@ -0,0 +1,67 @@ +_base_ = [ + "../_base_/common.py", + "../_base_/train.py", + "../_base_/test.py", +] + +has_test = True +deterministic = True +use_custom_worker_init = False +model_name = "ZoomNet" + +train = dict( + batch_size=22, + num_workers=4, + use_amp=True, + num_epochs=50, + epoch_based=True, + lr=0.05, + optimizer=dict( + mode="sgd", + set_to_none=True, + group_mode="finetune", + cfg=dict( + momentum=0.9, + weight_decay=5e-4, + nesterov=False, + ), + ), + sche_usebatch=True, + scheduler=dict( + warmup=dict( + num_iters=0, + initial_coef=0.01, + mode="linear", + ), + mode="f3", + cfg=dict( + lr_decay=0.9, + min_coef=None, + ), + ), + ms=dict( + enable=True, + extra_scales=[i / 352 for i in [224, 256, 288, 320, 352]], + ), +) + +test = dict( + batch_size=22, + num_workers=4, + show_bar=False, +) + +datasets = dict( + train=dict( + dataset_type="msi_sod_tr", + shape=dict(h=352, w=352), + path=["dutstr"], + interp_cfg=dict(), + ), + test=dict( + dataset_type="msi_sod_te", + shape=dict(h=352, w=352), + path=["pascal-s", "ecssd", "hku-is", "dutste", "dut-omron", "socte"], + interp_cfg=dict(), + ), +) diff --git a/dataset/__init__.py b/dataset/__init__.py index 960d1d7..2513857 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -1,3 +1,4 @@ # -*- coding: utf-8 -*- from .msi_cod import MSICOD_TestDataset, MSICOD_TrainDataset +from .msi_sod import MSISOD_TrainDataset, MSISOD_TestDataset diff --git a/dataset/msi_sod.py b/dataset/msi_sod.py new file mode 100644 index 0000000..c0546ed --- /dev/null +++ b/dataset/msi_sod.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +import random +from typing import Dict, List, Tuple + +from PIL import Image +from torchvision.transforms import transforms + +from dataset.base_dataset import _BaseSODDataset +from utils.builder import DATASETS +from utils.io.genaral import get_datasets_info_with_keys + + +class RandomHorizontallyFlip(object): + def __call__(self, img, mask): + if random.random() < 0.5: + return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) + return img, mask + + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, img, mask): + rotate_degree = random.random() * 2 * self.degree - self.degree + return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, mask): + assert img.size == mask.size + for t in self.transforms: + img, mask = t(img, mask) + return img, mask + + +@DATASETS.register(name="msi_sod_te") +class MSISOD_TestDataset(_BaseSODDataset): + def __init__(self, root: Tuple[str, dict], shape: Dict[str, int], interp_cfg: Dict = None): + super().__init__(base_shape=shape, interp_cfg=interp_cfg) + self.datasets = get_datasets_info_with_keys(dataset_infos=[root], extra_keys=["mask"]) + self.total_image_paths = self.datasets["image"] + self.total_mask_paths = self.datasets["mask"] + + self.to_tensor = transforms.ToTensor() + self.to_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + def __getitem__(self, index): + image_path = self.total_image_paths[index] + mask_path = self.total_mask_paths[index] + image = Image.open(image_path).convert("RGB") + + base_h = self.base_shape["h"] + base_w = self.base_shape["w"] + image_1_5 = image.resize((int(base_h * 1.5), int(base_w * 1.5)), resample=Image.BILINEAR) + image_1_0 = image.resize((base_h, base_w), resample=Image.BILINEAR) + image_0_5 = image.resize((int(base_h * 0.5), int(base_w * 0.5)), resample=Image.BILINEAR) + image_1_5 = self.to_normalize(self.to_tensor(image_1_5)) + image_1_0 = self.to_normalize(self.to_tensor(image_1_0)) + image_0_5 = self.to_normalize(self.to_tensor(image_0_5)) + + return dict( + data={ + "image1.5": image_1_5, + "image1.0": image_1_0, + "image0.5": image_0_5, + }, + info=dict( + mask_path=mask_path, + ), + ) + + def __len__(self): + return len(self.total_image_paths) + + +@DATASETS.register(name="msi_sod_tr") +class MSISOD_TrainDataset(_BaseSODDataset): + def __init__( + self, root: List[Tuple[str, dict]], shape: Dict[str, int], extra_scales: List = None, interp_cfg: Dict = None + ): + super().__init__(base_shape=shape, extra_scales=extra_scales, interp_cfg=interp_cfg) + self.datasets = get_datasets_info_with_keys(dataset_infos=root, extra_keys=["mask"]) + self.total_image_paths = self.datasets["image"] + self.total_mask_paths = self.datasets["mask"] + + self.joint_transform = Compose([RandomHorizontallyFlip(), RandomRotate(10)]) + self.to_tensor = transforms.ToTensor() + self.image_transform = transforms.ColorJitter(0.1, 0.1, 0.1) + self.to_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + + def __getitem__(self, index): + image_path = self.total_image_paths[index] + mask_path = self.total_mask_paths[index] + image = Image.open(image_path).convert("RGB") + mask = Image.open(mask_path).convert("L") + + image, mask = self.joint_transform(image, mask) + image = self.image_transform(image) + + base_h = self.base_shape["h"] + base_w = self.base_shape["w"] + image_1_5 = image.resize((int(base_h * 1.5), int(base_w * 1.5)), resample=Image.BILINEAR) + image_1_0 = image.resize((base_h, base_w), resample=Image.BILINEAR) + image_0_5 = image.resize((int(base_h * 0.5), int(base_w * 0.5)), resample=Image.BILINEAR) + image_1_5 = self.to_normalize(self.to_tensor(image_1_5)) + image_1_0 = self.to_normalize(self.to_tensor(image_1_0)) + image_0_5 = self.to_normalize(self.to_tensor(image_0_5)) + + mask_1_0 = mask.resize((base_h, base_w), resample=Image.BILINEAR) + mask_1_0 = self.to_tensor(mask_1_0) + mask_1_0 = mask_1_0.ge(0.5).float() # 二值化 + + return dict( + data={ + "image1.5": image_1_5, + "image1.0": image_1_0, + "image0.5": image_0_5, + "mask": mask_1_0, + } + ) + + def __len__(self): + return len(self.total_image_paths) diff --git a/methods/__init__.py b/methods/__init__.py index 9b9b116..c57b312 100755 --- a/methods/__init__.py +++ b/methods/__init__.py @@ -2,4 +2,4 @@ from .classic_methods.CPD import CPD_R50 from .classic_methods.HDFNet import HDFNet_Res50 from .classic_methods.MINet import MINet_Res50, MINet_VGG16 -from .zoomnet.zoomnet import ZoomNet +from .zoomnet.zoomnet import ZoomNet, ZoomNet_CK diff --git a/methods/zoomnet/zoomnet.py b/methods/zoomnet/zoomnet.py index ff9b0e7..054912e 100644 --- a/methods/zoomnet/zoomnet.py +++ b/methods/zoomnet/zoomnet.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch.utils.checkpoint import checkpoint from methods.module.base_model import BasicModelClass from methods.module.conv_block import ConvBNReLU @@ -137,6 +138,36 @@ def forward(self, x): return self.final_relu(out + x) +def get_coef(iter_percentage, method): + if method == "linear": + milestones = (0.3, 0.7) + coef_range = (0, 1) + min_point, max_point = min(milestones), max(milestones) + min_coef, max_coef = min(coef_range), max(coef_range) + if iter_percentage < min_point: + ual_coef = min_coef + elif iter_percentage > max_point: + ual_coef = max_coef + else: + ratio = (max_coef - min_coef) / (max_point - min_point) + ual_coef = ratio * (iter_percentage - min_point) + elif method == "cos": + coef_range = (0, 1) + min_coef, max_coef = min(coef_range), max(coef_range) + normalized_coef = (1 - np.cos(iter_percentage * np.pi)) / 2 + ual_coef = normalized_coef * (max_coef - min_coef) + min_coef + else: + ual_coef = 1.0 + return ual_coef + + +def cal_ual(seg_logits, seg_gts): + assert seg_logits.shape == seg_gts.shape, (seg_logits.shape, seg_gts.shape) + sigmoid_x = seg_logits.sigmoid() + loss_map = 1 - (2 * sigmoid_x - 1).abs().pow(2) + return loss_map.mean() + + @MODELS.register() class ZoomNet(BasicModelClass): def __init__(self): @@ -204,34 +235,8 @@ def test_forward(self, data, **kwargs): ) return output["seg"] - @staticmethod - def cal_smooth_sparse_loss(seg_logits, seg_gts): - assert seg_logits.shape == seg_gts.shape, (seg_logits.shape, seg_gts.shape) - sigmoid_x = seg_logits.sigmoid() - loss_map = 1 - (2 * sigmoid_x - 1).pow(2) - return loss_map.mean() - - def cal_loss(self, all_preds: dict, gts: torch.Tensor, iter_percentage: float = 0): - method = "cos" - if method == "linear": - milestones = (0.3, 0.7) - coef_range = (0, 1) - min_point, max_point = min(milestones), max(milestones) - min_coef, max_coef = min(coef_range), max(coef_range) - if iter_percentage < min_point: - dsl_coef = min_coef - elif iter_percentage > max_point: - dsl_coef = max_coef - else: - ratio = (max_coef - min_coef) / (max_point - min_point) - dsl_coef = ratio * (iter_percentage - min_point) - elif method == "cos": - coef_range = (0, 1) - min_coef, max_coef = min(coef_range), max(coef_range) - normalized_coef = (1 - np.cos(iter_percentage * np.pi)) / 2 - dsl_coef = normalized_coef * (max_coef - min_coef) + min_coef - else: - dsl_coef = 1.0 + def cal_loss(self, all_preds: dict, gts: torch.Tensor, method="cos", iter_percentage: float = 0): + ual_coef = get_coef(iter_percentage, method) losses = [] loss_str = [] @@ -243,10 +248,10 @@ def cal_loss(self, all_preds: dict, gts: torch.Tensor, iter_percentage: float = losses.append(sod_loss) loss_str.append(f"{name}_BCE: {sod_loss.item():.5f}") - cel_loss = self.cal_smooth_sparse_loss(seg_logits=preds, seg_gts=resized_gts) - cel_loss *= dsl_coef - losses.append(cel_loss) - loss_str.append(f"{name}_DSL_{dsl_coef:.5f}: {cel_loss.item():.5f}") + ual_loss = cal_ual(seg_logits=preds, seg_gts=resized_gts) + ual_loss *= ual_coef + losses.append(ual_loss) + loss_str.append(f"{name}_UAL_{ual_coef:.5f}: {ual_loss.item():.5f}") return sum(losses), " ".join(loss_str) def get_grouped_params(self): @@ -260,22 +265,48 @@ def get_grouped_params(self): param_groups.setdefault("retrained", []).append(param) return param_groups - @torch.no_grad() - def get_feature_maps(self, data): - l_scale = data["image1.5"] - m_scale = data["image1.0"] - s_scale = data["image0.5"] - l_trans_feats = self.encoder_translayer(l_scale) - m_trans_feats = self.encoder_translayer(m_scale) - s_trans_feats = self.encoder_translayer(s_scale) +@MODELS.register() +class ZoomNet_CK(ZoomNet): + def __init__(self): + super().__init__() + self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) - end_m_trans_feats = [] - for layer_id, (l, m, s, layer) in enumerate( - zip(l_trans_feats, m_trans_feats, s_trans_feats, self.merge_layers) - ): - siu_outs = layer(l=l, m=m, s=s) - end_m_trans_feats.append(siu_outs) + def encoder(self, x, dummy_arg=None): + assert dummy_arg is not None + x0, x1, x2, x3, x4 = self.shared_encoder(x) + return x0, x1, x2, x3, x4 - seg_logits = self.seg_head(end_m_trans_feats) - return dict(seg=seg_logits) + def trans(self, x0, x1, x2, x3, x4): + x5, x4, x3, x2, x1 = self.translayer([x0, x1, x2, x3, x4]) + return x5, x4, x3, x2, x1 + + def decoder(self, x5, x4, x3, x2, x1): + x = self.d5(x5) + x = cus_sample(x, mode="scale", factors=2) + x = self.d4(x + x4) + x = cus_sample(x, mode="scale", factors=2) + x = self.d3(x + x3) + x = cus_sample(x, mode="scale", factors=2) + x = self.d2(x + x2) + x = cus_sample(x, mode="scale", factors=2) + x = self.d1(x + x1) + x = cus_sample(x, mode="scale", factors=2) + logits = self.out_layer_01(self.out_layer_00(x)) + return logits + + def body(self, l_scale, m_scale, s_scale): + l_trans_feats = checkpoint(self.encoder, l_scale, self.dummy) + m_trans_feats = checkpoint(self.encoder, m_scale, self.dummy) + s_trans_feats = checkpoint(self.encoder, s_scale, self.dummy) + l_trans_feats = checkpoint(self.trans, *l_trans_feats) + m_trans_feats = checkpoint(self.trans, *m_trans_feats) + s_trans_feats = checkpoint(self.trans, *s_trans_feats) + + feats = [] + for layer_idx, (l, m, s) in enumerate(zip(l_trans_feats, m_trans_feats, s_trans_feats)): + siu_outs = checkpoint(self.merge_layers[layer_idx], l, m, s) + feats.append(siu_outs) + + logits = checkpoint(self.decoder, *feats) + return dict(seg=logits) diff --git a/readme.md b/readme.md index 35c06f1..9335a58 100644 --- a/readme.md +++ b/readme.md @@ -7,7 +7,9 @@ ## Changelog -* 2020/3/28: Initialize the repository. +* 2022-03-04: + - Initialize the repository. + - Add the model and configuration file for SOD. ## Usage diff --git a/requirements.txt b/requirements.txt index ac5d765..8e4a8de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,8 @@ # Automatically generated by https://github.com/damnever/pigar. +# ZoomNet/dataset/msi_sod.py: 5 +Pillow == 8.1.2 + # ZoomNet/utils/configurator.py: 15 addict == 2.4.0 @@ -7,10 +10,9 @@ addict == 2.4.0 # ZoomNet/dataset/transforms/composition.py: 3,5 # ZoomNet/dataset/transforms/resize.py: 1 # ZoomNet/dataset/transforms/rotate.py: 3 -# ZoomNet/test.py: 8 albumentations == 1.0.0 -# ZoomNet/utils/pipeline/scheduler.py: 279 +# ZoomNet/utils/pipeline/scheduler.py: 280 # ZoomNet/utils/recorder/visualize_results.py: 1,2 matplotlib == 3.4.2 @@ -18,8 +20,8 @@ matplotlib == 3.4.2 # ZoomNet/main.py: 10 # ZoomNet/methods/classic_methods/CPD.py: 8 # ZoomNet/methods/zoomnet/zoomnet.py: 1 -# ZoomNet/test.py: 10 -# ZoomNet/utils/io/image.py: 11 +# ZoomNet/test.py: 8 +# ZoomNet/utils/io/image.py: 6 # ZoomNet/utils/misc.py: 13 # ZoomNet/utils/ops/array_ops.py: 5 # ZoomNet/utils/pipeline/scheduler.py: 12 @@ -33,8 +35,7 @@ nvidia_ml_py3 == 7.352.0 # ZoomNet/dataset/msi_cod.py: 6 # ZoomNet/dataset/transforms/resize.py: 2 # ZoomNet/dataset/transforms/rotate.py: 4 -# ZoomNet/test.py: 9 -# ZoomNet/utils/io/image.py: 10 +# ZoomNet/utils/io/image.py: 5 # ZoomNet/utils/ops/array_ops.py: 4 opencv_python_headless == 4.5.1.48 @@ -50,11 +51,10 @@ scipy == 1.6.2 # ZoomNet/methods/classic_methods/HDFNet.py: 5 # ZoomNet/methods/classic_methods/MINet.py: 2 -# ZoomNet/methods/module/conv_block.py: 8 +# ZoomNet/methods/module/conv_block.py: 7 # ZoomNet/methods/zoomnet/zoomnet.py: 2 timm == 0.4.12 -# ZoomNet/check_model.py: 5 # ZoomNet/dataset/base_dataset.py: 9,10 # ZoomNet/dataset/msi_cod.py: 7 # ZoomNet/main.py: 11 @@ -63,30 +63,30 @@ timm == 0.4.12 # ZoomNet/methods/classic_methods/HDFNet.py: 6,7,8 # ZoomNet/methods/classic_methods/MINet.py: 3,4,5 # ZoomNet/methods/module/base_model.py: 7 -# ZoomNet/methods/module/conv_block.py: 7 -# ZoomNet/methods/zoomnet/zoomnet.py: 3,4,5 -# ZoomNet/test.py: 11,12 -# ZoomNet/utils/io/params.py: 8 +# ZoomNet/methods/module/conv_block.py: 6 +# ZoomNet/methods/zoomnet/zoomnet.py: 3,4,5,6 +# ZoomNet/test.py: 9 +# ZoomNet/utils/io/params.py: 8,9 # ZoomNet/utils/misc.py: 14,15,16 # ZoomNet/utils/ops/module_ops.py: 7 -# ZoomNet/utils/ops/tensor_ops.py: 8,9 +# ZoomNet/utils/ops/tensor_ops.py: 7,8 # ZoomNet/utils/pipeline/dataloader.py: 8 -# ZoomNet/utils/pipeline/ema.py: 8,9 -# ZoomNet/utils/pipeline/optimizer.py: 9 +# ZoomNet/utils/pipeline/ema.py: 7,8 +# ZoomNet/utils/pipeline/optimizer.py: 8 # ZoomNet/utils/pipeline/scheduler.py: 13 # ZoomNet/utils/recorder/tensorboard.py: 6 # ZoomNet/utils/recorder/visualize_results.py: 4 torch == 1.8.1 +# ZoomNet/dataset/msi_sod.py: 6 # ZoomNet/methods/classic_methods/CMWNet.py: 8 # ZoomNet/methods/classic_methods/CPD.py: 13 -# ZoomNet/utils/pipeline/optimizer.py: 7 # ZoomNet/utils/recorder/tensorboard.py: 7 # ZoomNet/utils/recorder/visualize_results.py: 5,6 torchvision == 0.9.1 # ZoomNet/main.py: 12 -# ZoomNet/test.py: 13 +# ZoomNet/test.py: 10 tqdm == 4.59.0 # ZoomNet/utils/pipeline/tta.py: 6 diff --git a/test.py b/test.py index 5f9d890..a7ad1ba 100755 --- a/test.py +++ b/test.py @@ -20,6 +20,7 @@ def parse_config(): parser.add_argument("--batch-size", type=int) parser.add_argument("--load-from", type=str) parser.add_argument("--save-path", type=str) + parser.add_argument("--minmax-results", action="store_true") parser.add_argument("--info", type=str) args = parser.parse_args() @@ -28,7 +29,7 @@ def parse_config(): if args.model_name is not None: config.model_name = args.model_name if args.batch_size is not None: - config.train.batch_size = args.batch_size + config.test.batch_size = args.batch_size if args.load_from is not None: config.load_from = args.load_from if args.info is not None: @@ -41,6 +42,7 @@ def parse_config(): print(f"{args.save_path} does not exist, create it.") os.makedirs(args.save_path) config.save_path = args.save_path + config.test.to_minmax = args.minmax_results with open(args.datasets_info, encoding="utf-8", mode="r") as f: datasets_info = json.load(f) @@ -110,7 +112,7 @@ def testing(model, cfg): pred_save_path = None for data_name, data_path, loader in pipeline.get_te_loader(cfg): if cfg.save_path: - pred_save_path = os.path.join(cfg.path.save, data_name) + pred_save_path = os.path.join(cfg.save_path, data_name) print(f"Results will be saved into {pred_save_path}") seg_results = test_once( model=model, diff --git a/test.sh b/test.sh index a1bd922..5c3ccf0 100755 --- a/test.sh +++ b/test.sh @@ -7,5 +7,16 @@ set -o pipefail # 确保只要一个子命令失败,整个管道命令就失 export CUDA_VISIBLE_DEVICES="$1" echo 'Excute the script on GPU: ' "$1" -python test.py --model-name ZoomNetV1 --batch-size 4 \ - --load-from ./output/ZoomNet_BS8_LR0.05_E40_H384_W384_OPMsgd_OPGMfinetune_SCf3_AMP_INFOdemo/pth/state_final.pth +echo 'For COD' +python test.py --model-name ZoomNet --batch-size 20 \ + --config ./configs/zoomnet/cod_zoomnet.py \ + --load-from output/ForSharing/cod_zoomnet_r50_bs8_e40_2022-03-04.pth \ + --save-path output/ForSharing/COD_Results \ + --minmax-results + +echo 'For SOD' +python test.py --model-name ZoomNet --batch-size 20 \ + --config ./configs/zoomnet/sod_zoomnet.py \ + --load-from output/ForSharing/sod_zoomnet_r50_bs22_e50_2022-03-04_fixed.pth \ + --save-path output/ForSharing/SOD_Results \ + --minmax-results diff --git a/tools/commands.txt b/tools/commands.txt index 8dc09e3..0b5c368 100644 --- a/tools/commands.txt +++ b/tools/commands.txt @@ -1 +1,2 @@ -main.py --model-name=ZoomNet --config=configs/zoomnet/zoomnet.py --datasets-info ./configs/_base_/dataset/dataset_configs.json --info demo +main.py --model-name=ZoomNet --config=configs/zoomnet/cod_zoomnet.py --datasets-info ./configs/_base_/dataset/dataset_configs.json +main.py --model-name=ZoomNet_CK --config=configs/zoomnet/sod_zoomnet.py --datasets-info ./configs/_base_/dataset/dataset_configs.json diff --git a/utils/io/genaral.py b/utils/io/genaral.py index 852657f..09038d7 100644 --- a/utils/io/genaral.py +++ b/utils/io/genaral.py @@ -32,7 +32,7 @@ def get_name_list_from_dir(path: str) -> list: return [os.path.splitext(x)[0] for x in os.listdir(path)] -def get_datasets_info_with_keys(dataset_infos: dict, extra_keys: list) -> dict: +def get_datasets_info_with_keys(dataset_infos: list, extra_keys: list) -> dict: """ 从给定的包含数据信息字典的列表中,依据给定的extra_kers和固定获取的key='image'来获取相应的路径 Args: @@ -85,7 +85,7 @@ def _get_info(dataset_info: dict, extra_keys: list, path_collection: defaultdict path_collection[k].append(os.path.join(infos[k]["dir"], name + infos[k]["ext"])) path_collection = defaultdict(list) - for dataset_name, dataset_info in dataset_infos.items(): + for dataset_name, dataset_info in dataset_infos: prev_num = len(path_collection["image"]) _get_info(dataset_info=dataset_info, extra_keys=extra_keys, path_collection=path_collection) curr_num = len(path_collection["image"]) diff --git a/utils/io/params.py b/utils/io/params.py index 23ab26d..4cea9e8 100644 --- a/utils/io/params.py +++ b/utils/io/params.py @@ -6,6 +6,7 @@ import os import torch +from torch import nn def save_params( @@ -111,7 +112,7 @@ def load_specific_params(load_path, names): return parmas_dict -def load_weight(load_path, model): +def load_weight(load_path, model: nn.Module): """ 从保存节点恢复模型 @@ -122,5 +123,11 @@ def load_weight(load_path, model): assert os.path.exists(load_path), load_path print(f"Loading weight '{load_path}'") - model.load_state_dict(torch.load(load_path, map_location="cpu")) + ckpt_dict = torch.load(load_path, map_location="cpu") + state_dict = model.state_dict() + ckpt_keys = ckpt_dict.keys() + state_keys = state_dict.keys() + print(f"Unique Keys in model: {sorted(set(state_keys).difference(ckpt_keys))}") + print(f"Unique Keys in ckpt: {sorted(set(ckpt_keys).difference(state_keys))}") + model.load_state_dict(ckpt_dict, strict=False) print(f"Loaded weight '{load_path}' " f"(only contains the net's weight)") diff --git a/utils/pipeline/dataloader.py b/utils/pipeline/dataloader.py index 56e8fa1..f77edb9 100644 --- a/utils/pipeline/dataloader.py +++ b/utils/pipeline/dataloader.py @@ -15,7 +15,7 @@ def get_tr_loader(cfg, shuffle=True, drop_last=True, pin_memory=True): registry_name="DATASETS", obj_name=cfg.datasets.train.dataset_type, obj_cfg=dict( - root=cfg.datasets.train.path, + root=[(name, path) for name, path in cfg.datasets.train.path.items()], shape=cfg.datasets.train.shape, extra_scales=cfg.train.ms.extra_scales if cfg.train.ms.enable else None, interp_cfg=cfg.datasets.train.get("interp_cfg", None), diff --git a/utils/recorder/metric_caller.py b/utils/recorder/metric_caller.py index 1811982..3fa3617 100644 --- a/utils/recorder/metric_caller.py +++ b/utils/recorder/metric_caller.py @@ -4,7 +4,7 @@ # @GitHub : https://github.com/lartpang import numpy as np -from py_sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure +from py_sod_metrics.sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure class CalTotalMetric(object):