diff --git a/README.md b/README.md index 150d7a2..54e9d90 100644 --- a/README.md +++ b/README.md @@ -177,3 +177,7 @@ python main_for_image.py --config configs/icod_train.py --pretrained --model-nam 1. Pretrain on COD10K-TR: `python main_for_image.py --config configs/icod_pretrain.py --info pretrain --model-name PvtV2B5_ZoomNeXt --pretrained` 2. Finetune on MoCA-Mask-TR: `python main_for_video.py --config configs/vcod_finetune.py --info finetune --model-name videoPvtV2B5_ZoomNeXt --load-from ` + +> [!note] +> If you run into an out-of-memory (OOM) error, try reducing the batch size or enabling the `--use-checkpoint` flag: +> `python main_for_image.py/main_for_video.py --use-checkpoint` diff --git a/main_for_image.py b/main_for_image.py index 329c610..3de6d3f 100644 --- a/main_for_image.py +++ b/main_for_image.py @@ -345,6 +345,7 @@ def parse_cfg(): ) parser.add_argument("--evaluate", action="store_true") parser.add_argument("--save-results", action="store_true") + parser.add_argument("--use-checkpoint", action="store_true") parser.add_argument("--info", type=str) args = parser.parse_args() @@ -382,7 +383,7 @@ def main(): model_class = model_zoo.__dict__.get(cfg.model_name) assert model_class is not None, "Please check your --model-name" model_code = inspect.getsource(model_class) - model = model_class(num_frames=1, pretrained=cfg.pretrained) + model = model_class(num_frames=1, pretrained=cfg.pretrained, use_checkpoint=cfg.use_checkpoint) LOGGER.info(model_code) model.to(cfg.device) diff --git a/main_for_video.py b/main_for_video.py index eab29a8..59efacd 100644 --- a/main_for_video.py +++ b/main_for_video.py @@ -467,6 +467,7 @@ def parse_cfg(): ) parser.add_argument("--evaluate", action="store_true") parser.add_argument("--save-results", action="store_true") + parser.add_argument("--use-checkpoint", action="store_true") parser.add_argument("--info", type=str) args = parser.parse_args() @@ -504,7 +505,7 @@ def main(): model_class = model_zoo.__dict__.get(cfg.model_name) 
assert model_class is not None, "Please check your --model-name" model_code = inspect.getsource(model_class) - model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained) + model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained, use_checkpoint=cfg.use_checkpoint) LOGGER.info(model_code) model.to(cfg.device) diff --git a/methods/zoomnext/zoomnext.py b/methods/zoomnext/zoomnext.py index f048665..6ed26db 100644 --- a/methods/zoomnext/zoomnext.py +++ b/methods/zoomnext/zoomnext.py @@ -86,7 +86,9 @@ def get_grouped_params(self): class RN50_ZoomNeXt(_ZoomNeXt_Base): - def __init__(self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6): + def __init__( + self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6, **kwargs + ): super().__init__() self.encoder = timm.create_model( model_name="resnet50", features_only=True, out_indices=range(5), pretrained=False @@ -297,7 +299,7 @@ def get_grouped_params(self): class EffB1_ZoomNeXt(_ZoomNeXt_Base): - def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6): + def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6, **kwargs): super().__init__() self.set_backbone(pretrained)