Skip to content

Commit

Permalink
Add the model and configuration file for SOD.
Browse files Browse the repository at this point in the history
  • Loading branch information
lartpang committed Mar 4, 2022
1 parent 31bc2e2 commit d21a3fc
Show file tree
Hide file tree
Showing 17 changed files with 332 additions and 108 deletions.
14 changes: 0 additions & 14 deletions configs/_base_/dataset/rgbcod.py

This file was deleted.

14 changes: 0 additions & 14 deletions configs/_base_/dataset/rgbsod.py

This file was deleted.

7 changes: 5 additions & 2 deletions configs/zoomnet/zoomnet.py → configs/zoomnet/cod_zoomnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
"../_base_/common.py",
"../_base_/train.py",
"../_base_/test.py",
"../_base_/dataset/rgbcod.py",
]

has_test = True
deterministic = True
use_custom_worker_init = False
model_name = 'ZoomNet'
model_name = "ZoomNet"

train = dict(
batch_size=8,
Expand Down Expand Up @@ -52,9 +51,13 @@
train=dict(
dataset_type="msi_cod_tr",
shape=dict(h=384, w=384),
path=["cod10k_camo_tr"],
interp_cfg=dict(),
),
test=dict(
dataset_type="msi_cod_te",
shape=dict(h=384, w=384),
path=["camo_te", "chameleon", "cpd1k_te", "cod10k_te", "nc4k"],
interp_cfg=dict(),
),
)
67 changes: 67 additions & 0 deletions configs/zoomnet/sod_zoomnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# ZoomNet configuration for the SOD datasets listed at the bottom of this file.
# NOTE(review): `_base_` is presumably resolved and merged by the project's
# config loader -- confirm against the loader implementation.
_base_ = [
    "../_base_/common.py",
    "../_base_/train.py",
    "../_base_/test.py",
]

has_test = True
deterministic = True
use_custom_worker_init = False
model_name = "ZoomNet"

# Training schedule, optimizer and multi-scale settings.
train = {
    "batch_size": 22,
    "num_workers": 4,
    "use_amp": True,
    "num_epochs": 50,
    "epoch_based": True,
    "lr": 0.05,
    "optimizer": {
        "mode": "sgd",
        "set_to_none": True,
        "group_mode": "finetune",
        "cfg": {
            "momentum": 0.9,
            "weight_decay": 5e-4,
            "nesterov": False,
        },
    },
    "sche_usebatch": True,
    "scheduler": {
        "warmup": {
            "num_iters": 0,
            "initial_coef": 0.01,
            "mode": "linear",
        },
        "mode": "f3",
        "cfg": {
            "lr_decay": 0.9,
            "min_coef": None,
        },
    },
    "ms": {
        "enable": True,
        # Scales are expressed relative to 352, the side length used in
        # `datasets` below.
        "extra_scales": [side / 352 for side in (224, 256, 288, 320, 352)],
    },
}

# Evaluation loader settings.
test = {
    "batch_size": 22,
    "num_workers": 4,
    "show_bar": False,
}

# Dataset types, base input shape and dataset names for each split.
datasets = {
    "train": {
        "dataset_type": "msi_sod_tr",
        "shape": {"h": 352, "w": 352},
        "path": ["dutstr"],
        "interp_cfg": {},
    },
    "test": {
        "dataset_type": "msi_sod_te",
        "shape": {"h": 352, "w": 352},
        "path": ["pascal-s", "ecssd", "hku-is", "dutste", "dut-omron", "socte"],
        "interp_cfg": {},
    },
}
1 change: 1 addition & 0 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-

from .msi_cod import MSICOD_TestDataset, MSICOD_TrainDataset
from .msi_sod import MSISOD_TrainDataset, MSISOD_TestDataset
127 changes: 127 additions & 0 deletions dataset/msi_sod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
import random
from typing import Dict, List, Tuple

from PIL import Image
from torchvision.transforms import transforms

from dataset.base_dataset import _BaseSODDataset
from utils.builder import DATASETS
from utils.io.genaral import get_datasets_info_with_keys


class RandomHorizontallyFlip(object):
    """Joint transform that mirrors an image and its mask left-to-right at random."""

    def __call__(self, img, mask):
        """Return ``(img, mask)``, both flipped horizontally when the draw is below 0.5.

        A single random draw decides the outcome for both inputs, so the image
        and the mask are always flipped together or not at all.
        """
        if random.random() >= 0.5:
            return img, mask
        mirrored_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mirrored_mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return mirrored_img, mirrored_mask


class RandomRotate(object):
    """Joint transform that rotates an image and its mask by the same random angle."""

    def __init__(self, degree):
        """Store ``degree``, the bound of the sampled angle range ``[-degree, degree)``."""
        self.degree = degree

    def __call__(self, img, mask):
        """Return ``(img, mask)`` rotated by one shared, randomly sampled angle.

        The image is resampled bilinearly, while the mask uses nearest-neighbour
        resampling.
        """
        limit = self.degree
        angle = random.random() * 2 * limit - limit
        rotated_img = img.rotate(angle, Image.BILINEAR)
        rotated_mask = mask.rotate(angle, Image.NEAREST)
        return rotated_img, rotated_mask


class Compose(object):
    """Chain several joint ``(img, mask)`` transforms into a single callable."""

    def __init__(self, transforms):
        """Keep the sequence of callables; each takes and returns an ``(img, mask)`` pair."""
        self.transforms = transforms

    def __call__(self, img, mask):
        """Apply every stored transform in order and return the final ``(img, mask)``.

        The two inputs must report the same ``size`` before any transform runs.
        """
        assert img.size == mask.size
        current_img, current_mask = img, mask
        for step in self.transforms:
            current_img, current_mask = step(current_img, current_mask)
        return current_img, current_mask


@DATASETS.register(name="msi_sod_te")
class MSISOD_TestDataset(_BaseSODDataset):
    """Test split that serves every image at 1.5x, 1.0x and 0.5x of the base shape.

    Masks are not loaded here; only the mask path is returned with each sample.
    """

    def __init__(self, root: Tuple[str, dict], shape: Dict[str, int], interp_cfg: Dict = None):
        """Collect image/mask paths for a single dataset and build the tensor transforms.

        Args:
            root: One dataset description; it is wrapped in a list and handed to
                ``get_datasets_info_with_keys``.
            shape: Base shape with keys ``"h"`` and ``"w"``, forwarded to the
                base class as ``base_shape``.
            interp_cfg: Forwarded unchanged to the base class.
        """
        super().__init__(base_shape=shape, interp_cfg=interp_cfg)
        self.datasets = get_datasets_info_with_keys(dataset_infos=[root], extra_keys=["mask"])
        self.total_image_paths = self.datasets["image"]
        self.total_mask_paths = self.datasets["mask"]

        self.to_tensor = transforms.ToTensor()
        self.to_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _build_scaled_images(self, image):
        """Return the normalized 1.5x / 1.0x / 0.5x tensors keyed ``image<scale>``."""
        # NOTE(review): ``self.base_shape`` is presumably set by _BaseSODDataset
        # from the ``base_shape`` argument -- the base class is not visible here.
        base_h = self.base_shape["h"]
        base_w = self.base_shape["w"]
        scaled = {}
        for tag, factor in (("1.5", 1.5), ("1.0", 1.0), ("0.5", 0.5)):
            # PIL expects (width, height); passing (h, w) would transpose the
            # output for any non-square base shape.
            target_size = (int(base_w * factor), int(base_h * factor))
            resized = image.resize(target_size, resample=Image.BILINEAR)
            scaled["image" + tag] = self.to_normalize(self.to_tensor(resized))
        return scaled

    def __getitem__(self, index):
        """Return ``dict(data=<three scaled image tensors>, info=<mask path>)``."""
        image_path = self.total_image_paths[index]
        mask_path = self.total_mask_paths[index]
        # convert() yields an independent, fully loaded image, so the file
        # handle can be released immediately instead of waiting for GC.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")

        return dict(
            data=self._build_scaled_images(image),
            info=dict(
                mask_path=mask_path,
            ),
        )

    def __len__(self):
        """Return the number of images in this dataset."""
        return len(self.total_image_paths)


@DATASETS.register(name="msi_sod_tr")
class MSISOD_TrainDataset(_BaseSODDataset):
    """Training split yielding augmented images at three scales plus a binary mask."""

    def __init__(
        self, root: List[Tuple[str, dict]], shape: Dict[str, int], extra_scales: List = None, interp_cfg: Dict = None
    ):
        """Collect image/mask paths and build the augmentation and tensor transforms.

        Args:
            root: Dataset descriptions handed to ``get_datasets_info_with_keys``.
            shape: Base shape with keys ``"h"`` and ``"w"``, forwarded to the
                base class as ``base_shape``.
            extra_scales: Forwarded unchanged to the base class.
            interp_cfg: Forwarded unchanged to the base class.
        """
        super().__init__(base_shape=shape, extra_scales=extra_scales, interp_cfg=interp_cfg)
        self.datasets = get_datasets_info_with_keys(dataset_infos=root, extra_keys=["mask"])
        self.total_image_paths = self.datasets["image"]
        self.total_mask_paths = self.datasets["mask"]

        # Geometric augmentations are applied jointly to image and mask;
        # colour jitter is applied to the image only.
        self.joint_transform = Compose([RandomHorizontallyFlip(), RandomRotate(10)])
        self.to_tensor = transforms.ToTensor()
        self.image_transform = transforms.ColorJitter(0.1, 0.1, 0.1)
        self.to_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _build_scaled_images(self, image):
        """Return the normalized 1.5x / 1.0x / 0.5x tensors keyed ``image<scale>``."""
        # NOTE(review): ``self.base_shape`` is presumably set by _BaseSODDataset
        # from the ``base_shape`` argument -- the base class is not visible here.
        base_h = self.base_shape["h"]
        base_w = self.base_shape["w"]
        scaled = {}
        for tag, factor in (("1.5", 1.5), ("1.0", 1.0), ("0.5", 0.5)):
            # PIL expects (width, height); passing (h, w) would transpose the
            # output for any non-square base shape.
            target_size = (int(base_w * factor), int(base_h * factor))
            resized = image.resize(target_size, resample=Image.BILINEAR)
            scaled["image" + tag] = self.to_normalize(self.to_tensor(resized))
        return scaled

    def __getitem__(self, index):
        """Return ``dict(data=...)`` with three scaled image tensors and the mask."""
        image_path = self.total_image_paths[index]
        mask_path = self.total_mask_paths[index]
        # convert() yields independent, fully loaded images, so both file
        # handles can be released immediately instead of waiting for GC.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        image, mask = self.joint_transform(image, mask)
        image = self.image_transform(image)

        data = self._build_scaled_images(image)

        # The mask is produced at the base (1.0x) resolution only; size is
        # given as (width, height) to match the 1.0x image.
        mask_size = (self.base_shape["w"], self.base_shape["h"])
        mask_1_0 = mask.resize(mask_size, resample=Image.BILINEAR)
        mask_1_0 = self.to_tensor(mask_1_0)
        mask_1_0 = mask_1_0.ge(0.5).float()  # binarize
        data["mask"] = mask_1_0

        return dict(data=data)

    def __len__(self):
        """Return the number of training images."""
        return len(self.total_image_paths)
2 changes: 1 addition & 1 deletion methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .classic_methods.CPD import CPD_R50
from .classic_methods.HDFNet import HDFNet_Res50
from .classic_methods.MINet import MINet_Res50, MINet_VGG16
from .zoomnet.zoomnet import ZoomNet
from .zoomnet.zoomnet import ZoomNet, ZoomNet_CK
Loading

0 comments on commit d21a3fc

Please sign in to comment.