Skip to content

Commit

Permalink
Add the config about using checkpoint in PVTv2.
Browse files Browse the repository at this point in the history
  • Loading branch information
lartpang committed Oct 14, 2024
1 parent c3fc51c commit b2cb089
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,7 @@ python main_for_image.py --config configs/icod_train.py --pretrained --model-nam

1. Pretrain on COD10K-TR: `python main_for_image.py --config configs/icod_pretrain.py --info pretrain --model-name PvtV2B5_ZoomNeXt --pretrained`
2. Finetune on MoCA-Mask-TR: `python main_for_video.py --config configs/vcod_finetune.py --info finetune --model-name videoPvtV2B5_ZoomNeXt --load-from <PRETRAINED_WEIGHT>`

> [!note]
> If you run into an out-of-memory (OOM) error, try reducing the batch size or enabling the `--use-checkpoint` flag:
> `python main_for_image.py --config <your config> --use-checkpoint` (or the same with `main_for_video.py`)
3 changes: 2 additions & 1 deletion main_for_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def parse_cfg():
)
parser.add_argument("--evaluate", action="store_true")
parser.add_argument("--save-results", action="store_true")
parser.add_argument("--use-checkpoint", action="store_true")
parser.add_argument("--info", type=str)
args = parser.parse_args()

Expand Down Expand Up @@ -382,7 +383,7 @@ def main():
model_class = model_zoo.__dict__.get(cfg.model_name)
assert model_class is not None, "Please check your --model-name"
model_code = inspect.getsource(model_class)
model = model_class(num_frames=1, pretrained=cfg.pretrained)
model = model_class(num_frames=1, pretrained=cfg.pretrained, use_checkpoint=cfg.use_checkpoint)
LOGGER.info(model_code)
model.to(cfg.device)

Expand Down
3 changes: 2 additions & 1 deletion main_for_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def parse_cfg():
)
parser.add_argument("--evaluate", action="store_true")
parser.add_argument("--save-results", action="store_true")
parser.add_argument("--use-checkpoint", action="store_true")
parser.add_argument("--info", type=str)
args = parser.parse_args()

Expand Down Expand Up @@ -504,7 +505,7 @@ def main():
model_class = model_zoo.__dict__.get(cfg.model_name)
assert model_class is not None, "Please check your --model-name"
model_code = inspect.getsource(model_class)
model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained)
model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained, use_checkpoint=cfg.use_checkpoint)
LOGGER.info(model_code)
model.to(cfg.device)

Expand Down
6 changes: 4 additions & 2 deletions methods/zoomnext/zoomnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def get_grouped_params(self):


class RN50_ZoomNeXt(_ZoomNeXt_Base):
def __init__(self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6):
def __init__(
self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6, **kwargs
):
super().__init__()
self.encoder = timm.create_model(
model_name="resnet50", features_only=True, out_indices=range(5), pretrained=False
Expand Down Expand Up @@ -297,7 +299,7 @@ def get_grouped_params(self):


class EffB1_ZoomNeXt(_ZoomNeXt_Base):
def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6):
def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6, **kwargs):
super().__init__()
self.set_backbone(pretrained)

Expand Down

0 comments on commit b2cb089

Please sign in to comment.