diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..eb796ac
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,179 @@
+### JupyterNotebooks template
+# gitignore template for Jupyter Notebooks
+# website: http://jupyter.org/
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Remove previous ipynb_checkpoints
+# git rm -r .ipynb_checkpoints/
+
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+.idea/
+.vscode/
+weights/
+outputs/
+pretrained/
+dataset.yaml
+
+# Custom Scripts
+*.ps1
+*.sh
diff --git a/README.md b/README.md
index ab7d4cf..150d7a2 100644
--- a/README.md
+++ b/README.md
@@ -8,13 +8,12 @@
```bibtex
-@misc{ZoomNeXt,
- title={ZoomNeXt: A Unified Collaborative Pyramid Network for Camouflaged Object Detection},
- author={Youwei Pang and Xiaoqi Zhao and Tian-Zhu Xiang and Lihe Zhang and Huchuan Lu},
- year={2023},
- eprint={2310.20208},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
+@ARTICLE{ZoomNeXt,
+  title   = {ZoomNeXt: A Unified Collaborative Pyramid Network for Camouflaged Object Detection},
+  author  = {Youwei Pang and Xiaoqi Zhao and Tian-Zhu Xiang and Lihe Zhang and Huchuan Lu},
+  journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
+  year    = {2024},
+  doi     = {10.1109/TPAMI.2024.3417329},
}
```
@@ -29,19 +28,152 @@
### Weights
-| Backbone | CAMO-TE | | | CHAMELEON | | | COD10K-TE | | | NC4K | | | Links |
-| ---------------- | ------- | -------------------- | ----- | --------- | -------------------- | ----- | --------- | -------------------- | ----- | ----- | -------------------- | ----- | ------------------------------------------------------------------------------------------------------- |
-| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | |
-| ResNet-50 | 0.833 | 0.774 | 0.065 | 0.908 | 0.858 | 0.021 | 0.861 | 0.768 | 0.026 | 0.874 | 0.816 | 0.037 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/resnet50-zoomnext.pth) |
-| EfficientNet-B4 | 0.867 | 0.824 | 0.046 | 0.911 | 0.865 | 0.020 | 0.875 | 0.797 | 0.021 | 0.884 | 0.837 | 0.032 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b4-zoomnext.pth) |
-| PVTv2-B2 | 0.874 | 0.839 | 0.047 | 0.922 | 0.884 | 0.017 | 0.887 | 0.818 | 0.019 | 0.892 | 0.852 | 0.030 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b2-zoomnext.pth) |
-| PVTv2-B3 | 0.885 | 0.854 | 0.042 | 0.927 | 0.898 | 0.017 | 0.895 | 0.829 | 0.018 | 0.900 | 0.861 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b3-zoomnext.pth) |
-| PVTv2-B4 | 0.888 | 0.859 | 0.040 | 0.925 | 0.897 | 0.016 | 0.898 | 0.838 | 0.017 | 0.900 | 0.865 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b4-zoomnext.pth) |
-| PVTv2-B5 | 0.889 | 0.857 | 0.041 | 0.924 | 0.885 | 0.018 | 0.898 | 0.827 | 0.018 | 0.903 | 0.863 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-zoomnext.pth) |
-| EfficientNet-B1 | | | | | | | | | | | | | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b1-zoomnext.pth) |
-| ConvNeXtV2-Large | | | | | | | | | | | | | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/convnextv2-l-zoomnext.pth) |
-
-| Backbone | CAD | | | | | MoCA-Mask-TE | | | | | Links |
-| --------------- | ----- | -------------------- | ----- | ----- | ----- | ------------ | -------------------- | ----- | ----- | ----- | ---------------------------------------------------------------------------------------------------------- |
-| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | |
-| PVTv2-B5 (T=5) | 0.757 | 0.593 | 0.020 | 0.599 | 0.510 | 0.734 | 0.476 | 0.010 | 0.497 | 0.422 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-5frame-zoomnext.pth) |
+| Backbone | CAMO-TE | | | CHAMELEON | | | COD10K-TE | | | NC4K | | | Links |
+| --------------- | ------- | -------------------- | ----- | --------- | -------------------- | ----- | --------- | -------------------- | ----- | ----- | -------------------- | ----- | --------------------------------------------------------------------------------------------------- |
+| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | |
+| ResNet-50 | 0.833 | 0.774 | 0.065 | 0.908 | 0.858 | 0.021 | 0.861 | 0.768 | 0.026 | 0.874 | 0.816 | 0.037 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/resnet50-zoomnext.pth) |
+| EfficientNet-B1 | 0.848 | 0.803 | 0.056 | 0.916 | 0.870 | 0.020 | 0.863 | 0.773 | 0.024 | 0.876 | 0.823 | 0.036 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b1-zoomnext.pth) |
+| EfficientNet-B4 | 0.867 | 0.824 | 0.046 | 0.911 | 0.865 | 0.020 | 0.875 | 0.797 | 0.021 | 0.884 | 0.837 | 0.032 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b4-zoomnext.pth) |
+| PVTv2-B2 | 0.874 | 0.839 | 0.047 | 0.922 | 0.884 | 0.017 | 0.887 | 0.818 | 0.019 | 0.892 | 0.852 | 0.030 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b2-zoomnext.pth) |
+| PVTv2-B3 | 0.885 | 0.854 | 0.042 | 0.927 | 0.898 | 0.017 | 0.895 | 0.829 | 0.018 | 0.900 | 0.861 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b3-zoomnext.pth) |
+| PVTv2-B4 | 0.888 | 0.859 | 0.040 | 0.925 | 0.897 | 0.016 | 0.898 | 0.838 | 0.017 | 0.900 | 0.865 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b4-zoomnext.pth) |
+| PVTv2-B5 | 0.889 | 0.857 | 0.041 | 0.924 | 0.885 | 0.018 | 0.898 | 0.827 | 0.018 | 0.903 | 0.863 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-zoomnext.pth) |
+
+| Backbone | CAD | | | | | MoCA-Mask-TE | | | | | Links |
+| -------------- | ----- | -------------------- | ----- | ----- | ----- | ------------ | -------------------- | ----- | ----- | ----- | ---------------------------------------------------------------------------------------------------------- |
+| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | |
+| PVTv2-B5 (T=5) | 0.757 | 0.593 | 0.020 | 0.599 | 0.510 | 0.734 | 0.476 | 0.010 | 0.497 | 0.422 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-5frame-zoomnext.pth) |
+
+
+## Prepare Data
+
+Put the information for all datasets into `dataset.yaml` as shown below.
+
+
+
+An example of the config file (`dataset.yaml`):
+
+
+```yaml
+# VCOD Datasets
+moca_mask_tr:
+ {
+ root: "YOUR-VCOD-DATASETS-ROOT/MoCA-Mask/MoCA_Video/TrainDataset_per_sq",
+ image: { path: "*/Imgs", suffix: ".jpg" },
+ mask: { path: "*/GT", suffix: ".png" },
+ }
+moca_mask_te:
+ {
+ root: "YOUR-VCOD-DATASETS-ROOT/MoCA-Mask/MoCA_Video/TestDataset_per_sq",
+ image: { path: "*/Imgs", suffix: ".jpg" },
+ mask: { path: "*/GT", suffix: ".png" },
+ }
+cad:
+ {
+ root: "YOUR-VCOD-DATASETS-ROOT/CamouflagedAnimalDataset",
+ image: { path: "original_data/*/frames", suffix: ".png" },
+ mask: { path: "converted_mask/*/groundtruth", suffix: ".png" },
+ }
+
+# ICOD Datasets
+cod10k_tr:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Train/COD10K-TR",
+ image: { path: "Image", suffix: ".jpg" },
+ mask: { path: "Mask", suffix: ".png" },
+ }
+camo_tr:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Train/CAMO-TR",
+ image: { path: "Image", suffix: ".jpg" },
+ mask: { path: "Mask", suffix: ".png" },
+ }
+cod10k_te:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Test/COD10K-TE",
+ image: { path: "Image", suffix: ".jpg" },
+ mask: { path: "Mask", suffix: ".png" },
+ }
+camo_te:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Test/CAMO-TE",
+ image: { path: "Image", suffix: ".jpg" },
+ mask: { path: "Mask", suffix: ".png" },
+ }
+chameleon:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Test/CHAMELEON",
+ image: { path: "Image", suffix: ".jpg" },
+ mask: { path: "Mask", suffix: ".png" },
+ }
+nc4k:
+ {
+ root: "YOUR-ICOD-DATASETS-ROOT/Test/NC4K",
+ image: { path: "Imgs", suffix: ".jpg" },
+ mask: { path: "GT", suffix: ".png" },
+ }
+```
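+
+For reference, the image datasets in `main_for_image.py` resolve each entry by joining `root` with the `path` fields, filtering files by `suffix`, and keeping only the stems that have both an image and a mask (the video datasets in `main_for_video.py` additionally expand the `*` in `path` with `glob`). The snippet below is a minimal sketch of that pairing logic for a single entry:
+
+```python
+import os
+
+import yaml
+
+with open("dataset.yaml", mode="r") as f:
+    dataset_infos = yaml.safe_load(f)
+
+info = dataset_infos["cod10k_te"]  # any ICOD entry from the example above
+image_dir = os.path.join(info["root"], info["image"]["path"])
+mask_dir = os.path.join(info["root"], info["mask"]["path"])
+image_suffix = info["image"]["suffix"]
+mask_suffix = info["mask"]["suffix"]
+
+# keep only the stems that exist as both an image and a mask
+image_names = [p[: -len(image_suffix)] for p in sorted(os.listdir(image_dir)) if p.endswith(image_suffix)]
+mask_names = [p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_dir)) if p.endswith(mask_suffix)]
+valid_names = sorted(set(image_names).intersection(mask_names))
+pairs = [(os.path.join(image_dir, n) + image_suffix, os.path.join(mask_dir, n) + mask_suffix) for n in valid_names]
+print(f"{len(pairs)} image/mask pairs found for cod10k_te")
+```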
+
+
+## Install Requirements
+
+- torch==2.1.2
+- torchvision==0.16.2
+- Others: `pip install -r requirements.txt`
+
+## Evaluation
+
+```shell
+# ICOD
+python main_for_image.py --config configs/icod_train.py --model-name <model-name> --evaluate --load-from <checkpoint-path>
+# VCOD
+python main_for_video.py --config configs/vcod_finetune.py --model-name <model-name> --evaluate --load-from <checkpoint-path>
+```
+
+> [!NOTE]
+>
+> Evaluating VCOD performance directly with the training scripts does not reproduce the numbers reported in the paper.
+> This is because the paper follows the evaluation protocol of the earlier work [SLT-Net](https://github.com/XuelianCheng/SLT-Net), which restricts the range of valid frames in each sequence.
+
+To reproduce the results reported in our paper, you can use [PySODEvalToolkit](https://github.com/lartpang/PySODEvalToolkit) with commands similar to the following:
+
+```shell
+python ./eval.py \
+  --dataset-json vcod-datasets.json \
+  --method-json vcod-methods.json \
+  --include-datasets CAD \
+  --include-methods videoPvtV2B5_ZoomNeXt \
+  --data-type video \
+  --valid-frame-start "0" \
+  --valid-frame-end "0" \
+  --metric-names "sm" "wfm" "mae" "fmeasure" "em" "dice" "iou"
+
+python ./eval.py \
+  --dataset-json vcod-datasets.json \
+  --method-json vcod-methods.json \
+  --include-datasets MOCA-MASK-TE \
+  --include-methods videoPvtV2B5_ZoomNeXt \
+  --data-type video \
+  --valid-frame-start "0" \
+  --valid-frame-end "-2" \
+  --metric-names "sm" "wfm" "mae" "fmeasure" "em" "dice" "iou"
+```
+
+## Training
+
+### Image Camouflaged Object Detection
+
+```shell
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name EffB1_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name EffB4_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B2_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B3_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B4_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B5_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name RN50_ZoomNeXt
+```
+
+### Video Camouflaged Object Detection
+
+1. Pretrain on COD10K-TR: `python main_for_image.py --config configs/icod_pretrain.py --info pretrain --model-name PvtV2B5_ZoomNeXt --pretrained`
+2. Finetune on MoCA-Mask-TR: `python main_for_video.py --config configs/vcod_finetune.py --info finetune --model-name videoPvtV2B5_ZoomNeXt --load-from <icod-pretrained-checkpoint>`
diff --git a/configs/icod_pretrain.py b/configs/icod_pretrain.py
new file mode 100644
index 0000000..1bb7337
--- /dev/null
+++ b/configs/icod_pretrain.py
@@ -0,0 +1,54 @@
+_base_ = ["icod_train.py"]
+
+has_test = False
+
+__BATCHSIZE = 4
+__NUM_EPOCHS = 150
+__NUM_TR_SAMPLES = 3040
+__ITER_PER_EPOCH = __NUM_TR_SAMPLES // __BATCHSIZE # drop_last is True
+__NUM_ITERS = __NUM_EPOCHS * __ITER_PER_EPOCH
+
+train = dict(
+ batch_size=__BATCHSIZE,
+ use_amp=True,
+ num_epochs=__NUM_EPOCHS,
+ lr=0.0001,
+ optimizer=dict(
+ mode="adam",
+ set_to_none=False,
+ group_mode="finetune",
+ cfg=dict(
+ weight_decay=0,
+ diff_factor=0.1,
+ ),
+ ),
+ sche_usebatch=True,
+ scheduler=dict(
+ warmup=dict(
+ num_iters=0,
+ initial_coef=0.01,
+ mode="linear",
+ ),
+ mode="step",
+ cfg=dict(
+ milestones=int(__NUM_ITERS * 2 / 3),
+ gamma=0.1,
+ ),
+ ),
+ bn=dict(
+ freeze_status=True,
+ freeze_affine=True,
+ freeze_encoder=False,
+ ),
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=["cod10k_tr"],
+ ),
+)
+
+test = dict(
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=[],
+ ),
+)
diff --git a/configs/icod_train.py b/configs/icod_train.py
new file mode 100644
index 0000000..a8f3358
--- /dev/null
+++ b/configs/icod_train.py
@@ -0,0 +1,63 @@
+has_test = True
+deterministic = True
+use_custom_worker_init = True
+log_interval = 20
+base_seed = 112358
+
+__BATCHSIZE = 4
+__NUM_EPOCHS = 150
+__NUM_TR_SAMPLES = 3040 + 1000
+__ITER_PER_EPOCH = __NUM_TR_SAMPLES // __BATCHSIZE # drop_last is True
+__NUM_ITERS = __NUM_EPOCHS * __ITER_PER_EPOCH
+
+train = dict(
+ batch_size=__BATCHSIZE,
+ num_workers=2,
+ use_amp=True,
+ num_epochs=__NUM_EPOCHS,
+ epoch_based=True,
+ num_iters=None,
+ lr=0.0001,
+ grad_acc_step=1,
+ optimizer=dict(
+ mode="adam",
+ set_to_none=False,
+ group_mode="finetune",
+ cfg=dict(
+ weight_decay=0,
+ diff_factor=0.1,
+ ),
+ ),
+ sche_usebatch=True,
+ scheduler=dict(
+ warmup=dict(
+ num_iters=0,
+ initial_coef=0.01,
+ mode="linear",
+ ),
+ mode="step",
+ cfg=dict(
+ milestones=int(__NUM_ITERS * 2 / 3),
+ gamma=0.1,
+ ),
+ ),
+ bn=dict(
+ freeze_status=True,
+ freeze_affine=True,
+ freeze_encoder=False,
+ ),
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=["cod10k_tr", "camo_tr"],
+ ),
+)
+
+test = dict(
+ batch_size=__BATCHSIZE,
+ num_workers=2,
+ clip_range=None,
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=["chameleon", "camo_te", "cod10k_te", "nc4k"],
+ ),
+)
diff --git a/configs/vcod_finetune.py b/configs/vcod_finetune.py
new file mode 100644
index 0000000..2c9fbd8
--- /dev/null
+++ b/configs/vcod_finetune.py
@@ -0,0 +1,44 @@
+_base_ = ["icod_train.py"]
+
+num_frames = 5
+
+__BATCHSIZE = 2
+
+train = dict(
+ batch_size=__BATCHSIZE,
+ use_amp=True,
+ num_epochs=10,
+ lr=0.0001,
+ optimizer=dict(
+ mode="adam",
+ set_to_none=False,
+ group_mode="finetune",
+ cfg=dict(
+ weight_decay=0,
+ diff_factor=0.1,
+ ),
+ ),
+ sche_usebatch=True,
+ scheduler=dict(
+ warmup=dict(num_iters=0),
+ mode="constant",
+ cfg=dict(coef=1),
+ ),
+ bn=dict(
+ freeze_status=True,
+ freeze_affine=True,
+ freeze_encoder=False,
+ ),
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=["moca_mask_tr"],
+ ),
+)
+
+test = dict(
+ batch_size=__BATCHSIZE,
+ data=dict(
+ shape=dict(h=384, w=384),
+ names=["cad", "moca_mask_te"],
+ ),
+)
diff --git a/main_for_image.py b/main_for_image.py
new file mode 100644
index 0000000..329c610
--- /dev/null
+++ b/main_for_image.py
@@ -0,0 +1,404 @@
+# -*- coding: utf-8 -*-
+import argparse
+import datetime
+import inspect
+import logging
+import os
+import shutil
+import time
+
+import albumentations as A
+import colorlog
+import cv2
+import numpy as np
+import torch
+import yaml
+from mmengine import Config
+from torch.utils import data
+from tqdm import tqdm
+
+import methods as model_zoo
+from utils import io, ops, pipeline, pt_utils, py_utils, recorder
+
+LOGGER = logging.getLogger("main")
+LOGGER.propagate = False
+LOGGER.setLevel(level=logging.DEBUG)
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.DEBUG)
+stream_handler.setFormatter(colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s"))
+LOGGER.addHandler(stream_handler)
+
+
+class ImageTestDataset(data.Dataset):
+ def __init__(self, dataset_info: dict, shape: dict):
+ super().__init__()
+ self.shape = shape
+
+ image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
+ image_suffix = dataset_info["image"]["suffix"]
+ mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
+ mask_suffix = dataset_info["mask"]["suffix"]
+
+ image_names = [p[: -len(image_suffix)] for p in sorted(os.listdir(image_path)) if p.endswith(image_suffix)]
+ mask_names = [p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_path)) if p.endswith(mask_suffix)]
+ valid_names = sorted(set(image_names).intersection(mask_names))
+ self.total_data_paths = [
+ (os.path.join(image_path, n) + image_suffix, os.path.join(mask_path, n) + mask_suffix) for n in valid_names
+ ]
+
+ def __getitem__(self, index):
+ image_path, mask_path = self.total_data_paths[index]
+ image = io.read_color_array(image_path)
+
+ base_h = self.shape["h"]
+ base_w = self.shape["w"]
+
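+        # build the three zoomed inputs (0.5x / 1.0x / 1.5x of the base shape) expected by ZoomNeXt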
+ images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w)
+ image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
+ image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
+ image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)
+
+ return dict(
+ data={"image_s": image_s, "image_m": image_m, "image_l": image_l},
+ info=dict(mask_path=mask_path, group_name="image"),
+ )
+
+ def __len__(self):
+ return len(self.total_data_paths)
+
+
+class ImageTrainDataset(data.Dataset):
+ def __init__(self, dataset_infos: dict, shape: dict):
+ super().__init__()
+ self.shape = shape
+
+ self.total_data_paths = []
+ for dataset_name, dataset_info in dataset_infos.items():
+ image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
+ image_suffix = dataset_info["image"]["suffix"]
+ mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
+ mask_suffix = dataset_info["mask"]["suffix"]
+
+ image_names = [p[: -len(image_suffix)] for p in sorted(os.listdir(image_path)) if p.endswith(image_suffix)]
+ mask_names = [p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_path)) if p.endswith(mask_suffix)]
+ valid_names = sorted(set(image_names).intersection(mask_names))
+ data_paths = [
+ (os.path.join(image_path, n) + image_suffix, os.path.join(mask_path, n) + mask_suffix)
+ for n in valid_names
+ ]
+ LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}")
+ self.total_data_paths.extend(data_paths)
+
+ self.trains = A.Compose(
+ [
+ A.HorizontalFlip(p=0.5),
+ A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REPLICATE),
+ A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
+ A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.5),
+ ]
+ )
+
+ def __getitem__(self, index):
+ image_path, mask_path = self.total_data_paths[index]
+ image = io.read_color_array(image_path)
+ mask = io.read_gray_array(mask_path, thr=0)
+ if image.shape[:2] != mask.shape:
+ h, w = mask.shape
+ image = ops.resize(image, height=h, width=w)
+
+ transformed = self.trains(image=image, mask=mask)
+ image = transformed["image"]
+ mask = transformed["mask"]
+
+ base_h = self.shape["h"]
+ base_w = self.shape["w"]
+
+ images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w)
+ image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
+ image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
+ image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)
+
+ mask = ops.resize(mask, height=base_h, width=base_w)
+ mask = torch.from_numpy(mask).unsqueeze(0)
+
+ return dict(
+ data={
+ "image_s": image_s,
+ "image_m": image_m,
+ "image_l": image_l,
+ "mask": mask,
+ }
+ )
+
+ def __len__(self):
+ return len(self.total_data_paths)
+
+
+class Evaluator:
+ def __init__(self, device, metric_names, clip_range=None):
+ self.device = device
+ self.clip_range = clip_range
+ self.metric_names = metric_names
+
+ @torch.no_grad()
+ def eval(self, model, data_loader, save_path=""):
+ model.eval()
+ all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names)
+
+ for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"):
+ batch_images = pt_utils.to_device(batch["data"], device=self.device)
+ logits = model(data=batch_images) # B,1,H,W
+ probs = logits.sigmoid().squeeze(1).cpu().detach().numpy()
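+            # min-max normalize the predicted maps to [0, 1] before resizing and metric computation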
+ probs = probs - probs.min()
+ probs = probs / (probs.max() + 1e-8)
+
+ mask_paths = batch["info"]["mask_path"]
+ group_names = batch["info"]["group_name"]
+ for pred_idx, pred in enumerate(probs):
+ mask_path = mask_paths[pred_idx]
+ mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+ mask_array[mask_array > 0] = 255
+ mask_h, mask_w = mask_array.shape
+ pred = ops.resize(pred, height=mask_h, width=mask_w)
+
+ if self.clip_range is not None:
+ pred = ops.clip_to_normalize(pred, clip_range=self.clip_range)
+
+ group_name = group_names[pred_idx]
+                if save_path:  # save_path already includes the dataset name
+ ops.save_array_as_image(
+ data_array=pred,
+ save_name=os.path.basename(mask_path),
+ save_dir=os.path.join(save_path, group_name),
+ )
+
+ pred = (pred * 255).astype(np.uint8)
+ all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path)
+ return all_metrics.show()
+
+
+def test(model, cfg):
+ test_wrapper = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range)
+
+ for te_name in cfg.test.data.names:
+ te_info = cfg.dataset_infos[te_name]
+ te_dataset = ImageTestDataset(dataset_info=te_info, shape=cfg.test.data.shape)
+ te_loader = data.DataLoader(
+ dataset=te_dataset, batch_size=cfg.test.batch_size, num_workers=cfg.test.num_workers, pin_memory=True
+ )
+ LOGGER.info(f"Testing with testset: {te_name}: {len(te_dataset)}")
+
+ if cfg.save_results:
+ save_path = os.path.join(cfg.path.save, te_name)
+ LOGGER.info(f"Results will be saved into {save_path}")
+ else:
+ save_path = ""
+
+ seg_results = test_wrapper.eval(model=model, data_loader=te_loader, save_path=save_path)
+ seg_results_str = ", ".join([f"{k}: {v:.03f}" for k, v in seg_results.items()])
+ LOGGER.info(f"({te_name}): {py_utils.mapping_to_str(te_info)}\n{seg_results_str}")
+
+
+def train(model, cfg):
+ tr_dataset = ImageTrainDataset(
+ dataset_infos={data_name: cfg.dataset_infos[data_name] for data_name in cfg.train.data.names},
+ shape=cfg.train.data.shape,
+ )
+ LOGGER.info(f"Total Length of Image Trainset: {len(tr_dataset)}")
+
+ tr_loader = data.DataLoader(
+ dataset=tr_dataset,
+ batch_size=cfg.train.batch_size,
+ num_workers=cfg.train.num_workers,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=True,
+ worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None,
+ )
+
+ counter = recorder.TrainingCounter(
+ epoch_length=len(tr_loader),
+ epoch_based=cfg.train.epoch_based,
+ num_epochs=cfg.train.num_epochs,
+ num_total_iters=cfg.train.num_iters,
+ )
+ optimizer = pipeline.construct_optimizer(
+ model=model,
+ initial_lr=cfg.train.lr,
+ mode=cfg.train.optimizer.mode,
+ group_mode=cfg.train.optimizer.group_mode,
+ cfg=cfg.train.optimizer.cfg,
+ )
+ scheduler = pipeline.Scheduler(
+ optimizer=optimizer,
+ num_iters=counter.num_total_iters,
+ epoch_length=counter.num_inner_iters,
+ scheduler_cfg=cfg.train.scheduler,
+ step_by_batch=cfg.train.sche_usebatch,
+ )
+ scheduler.record_lrs(param_groups=optimizer.param_groups)
+ scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log)
+ scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none)
+
+ LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}")
+
+ loss_recorder = recorder.HistoryBuffer()
+ iter_time_recorder = recorder.HistoryBuffer()
+
+ LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}")
+ if cfg.train.bn.freeze_encoder:
+ LOGGER.info(" >>> Freeze Backbone !!! <<< ")
+ model.encoder.requires_grad_(False)
+
+ train_start_time = time.perf_counter()
+ for _ in range(counter.num_epochs):
+ LOGGER.info(f"Exp_Name: {cfg.exp_name}")
+
+ model.train()
+ if cfg.train.bn.freeze_status:
+ pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine)
+
+ # an epoch starts
+ for batch_idx, batch in enumerate(tr_loader):
+ iter_start_time = time.perf_counter()
+ scheduler.step(curr_idx=counter.curr_iter) # update learning rate
+
+ data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device)
+ with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
+ outputs = model(data=data_batch, iter_percentage=counter.curr_percent)
+
+ loss = outputs["loss"]
+ loss_str = outputs["loss_str"]
+ loss = loss / cfg.train.grad_acc_step
+ scaler.calculate_grad(loss=loss)
+ if counter.every_n_iters(cfg.train.grad_acc_step): # Accumulates scaled gradients.
+ scaler.update_grad()
+
+ item_loss = loss.item()
+ data_shape = tuple(data_batch["mask"].shape)
+ loss_recorder.update(value=item_loss, num=data_shape[0])
+
+ if cfg.log_interval > 0 and (
+ counter.every_n_iters(cfg.log_interval)
+ or counter.is_first_inner_iter()
+ or counter.is_last_inner_iter()
+ or counter.is_last_total_iter()
+ ):
+ gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
+ eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1)
+ eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}"
+ progress = (
+ f"{counter.curr_iter}:{counter.num_total_iters} "
+ f"{batch_idx}/{counter.num_inner_iters} "
+ f"{counter.curr_epoch}/{counter.num_epochs}"
+ )
+ loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})"
+ lr_info = f"LR: {optimizer.lr_string()}"
+ LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}")
+ cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter)
+ cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter)
+ cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter)
+
+ if counter.curr_iter < 3: # plot some batches of the training phase
+ recorder.plot_results(
+ dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
+ save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"),
+ )
+
+ iter_time_recorder.update(value=time.perf_counter() - iter_start_time)
+ if counter.is_last_total_iter():
+ break
+ counter.update_iter_counter()
+
+ # an epoch ends
+ recorder.plot_results(
+ dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
+ save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"),
+ )
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+ counter.update_epoch_counter()
+
+ cfg.tb_logger.close_tb()
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+
+ total_train_time = time.perf_counter() - train_start_time
+ total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum))
+ LOGGER.info(
+ f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)"
+ )
+
+
+def parse_cfg():
+ parser = argparse.ArgumentParser("Training and evaluation script")
+ parser.add_argument("--config", required=True, type=str)
+ parser.add_argument("--data-cfg", type=str, default="./dataset.yaml")
+ parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys())
+ parser.add_argument("--output-dir", type=str, default="outputs")
+ parser.add_argument("--load-from", type=str)
+ parser.add_argument("--pretrained", action="store_true")
+ parser.add_argument(
+ "--metric-names",
+ nargs="+",
+ type=str,
+ default=["sm", "wfm", "mae", "em", "fmeasure"],
+ choices=recorder.GroupedMetricRecorder.supported_metrics,
+ )
+ parser.add_argument("--evaluate", action="store_true")
+ parser.add_argument("--save-results", action="store_true")
+ parser.add_argument("--info", type=str)
+ args = parser.parse_args()
+
+ cfg = Config.fromfile(args.config)
+ cfg.merge_from_dict(vars(args))
+
+ with open(cfg.data_cfg, mode="r") as f:
+ cfg.dataset_infos = yaml.safe_load(f)
+
+ cfg.proj_root = os.path.dirname(os.path.abspath(__file__))
+ cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg)
+ cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir)
+ cfg.path = py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name)
+ cfg.device = "cuda:0"
+
+ py_utils.pre_mkdir(cfg.path)
+ with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f:
+ f.write(cfg.pretty_text)
+ shutil.copy(__file__, cfg.path.trainer_copy)
+
+ file_handler = logging.FileHandler(cfg.path.log)
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s"))
+ LOGGER.addHandler(file_handler)
+ LOGGER.info(cfg.pretty_text)
+
+ cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb)
+ return cfg
+
+
+def main():
+ cfg = parse_cfg()
+ pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic)
+
+ model_class = model_zoo.__dict__.get(cfg.model_name)
+ assert model_class is not None, "Please check your --model-name"
+ model_code = inspect.getsource(model_class)
+ model = model_class(num_frames=1, pretrained=cfg.pretrained)
+ LOGGER.info(model_code)
+ model.to(cfg.device)
+
+ if cfg.load_from:
+ io.load_weight(model=model, load_path=cfg.load_from, strict=True)
+
+ LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}")
+ if not cfg.evaluate:
+ train(model=model, cfg=cfg)
+
+ if cfg.evaluate or cfg.has_test:
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+ test(model=model, cfg=cfg)
+
+ LOGGER.info("End training...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/main_for_video.py b/main_for_video.py
new file mode 100644
index 0000000..eab29a8
--- /dev/null
+++ b/main_for_video.py
@@ -0,0 +1,526 @@
+# -*- coding: utf-8 -*-
+import argparse
+import datetime
+import glob
+import inspect
+import logging
+import os
+import re
+import shutil
+import time
+
+import albumentations as A
+import colorlog
+import cv2
+import numpy as np
+import torch
+import yaml
+from mmengine import Config
+from torch.utils import data
+from tqdm import tqdm
+
+import methods as model_zoo
+from utils import io, ops, pipeline, pt_utils, py_utils, recorder
+
+LOGGER = logging.getLogger("main")
+LOGGER.propagate = False
+LOGGER.setLevel(level=logging.DEBUG)
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.DEBUG)
+stream_handler.setFormatter(colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s"))
+LOGGER.addHandler(stream_handler)
+
+
+def construct_frame_transform():
+ return A.Compose(
+ [
+ A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT101),
+ A.RandomBrightnessContrast(brightness_limit=0.02, contrast_limit=0.02, p=0.5),
+ ]
+ )
+
+
+def construct_video_transform():
+ return A.ReplayCompose(
+ [
+ A.HorizontalFlip(p=0.5),
+ A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
+ A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.5),
+ ]
+ )
+
+
+def get_number_from_tail(string):
+    # extract the trailing number from a file stem, used for numeric sorting of frames
+    return int(re.findall(pattern=r"\d+$", string=string)[0])
+
+
+class VideoTestDataset(data.Dataset):
+ def __init__(self, dataset_info: dict, shape: dict, num_frames: int = 1, overlap: int = 0):
+ super().__init__()
+ self.shape = shape
+ self.num_frames = num_frames
+ assert 0 <= overlap < num_frames
+
+ image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
+ image_suffix = dataset_info["image"]["suffix"]
+ mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
+ mask_suffix = dataset_info["mask"]["suffix"]
+
+ image_group_paths = sorted(glob.glob(image_path))
+ mask_group_paths = sorted(glob.glob(mask_path))
+ group_name_place = image_path.find("*")
+
+ self.total_data_paths = []
+ for image_group_path, mask_group_path in zip(image_group_paths, mask_group_paths):
+ group_name = image_group_path[group_name_place:].split("/")[0]
+ image_names = [
+ p[: -len(image_suffix)] for p in sorted(os.listdir(image_group_path)) if p.endswith(image_suffix)
+ ]
+ mask_names = [
+ p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_group_path)) if p.endswith(mask_suffix)
+ ]
+ valid_names = sorted(
+ set(image_names).intersection(mask_names), key=lambda item: get_number_from_tail(item)
+ )
+ assert len(valid_names) >= num_frames, image_group_path
+
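+            # slide a window of num_frames frames over the sequence; the last window is shifted back to stay in range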
+ i = 0
+ while i < len(valid_names):
+ if i + num_frames > len(valid_names):
+ i = len(valid_names) - num_frames
+ clip_info = [
+ (
+ os.path.join(image_group_path, n) + image_suffix,
+ os.path.join(mask_group_path, n) + mask_suffix,
+ group_name,
+ )
+ for n in valid_names[i : i + num_frames]
+ ]
+ if len(clip_info) < num_frames:
+ times, last = divmod(num_frames, len(clip_info))
+ clip_info.extend(clip_info * (times - 1))
+ clip_info.extend(clip_info[:last])
+ self.total_data_paths.append(clip_info)
+
+ i += num_frames - overlap
+
+ def __getitem__(self, index):
+ clip_info = self.total_data_paths[index]
+
+ base_h = self.shape["h"]
+ base_w = self.shape["w"]
+ image_ss = []
+ image_ms = []
+ image_ls = []
+ mask_paths = []
+ for image_path, mask_path, group_name in clip_info:
+ image = io.read_color_array(image_path)
+ mask_paths.append(mask_path)
+ images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w)
+ image_ss.append(torch.from_numpy(images[0]).div(255).permute(2, 0, 1))
+ image_ms.append(torch.from_numpy(images[1]).div(255).permute(2, 0, 1))
+ image_ls.append(torch.from_numpy(images[2]).div(255).permute(2, 0, 1))
+
+ return dict(
+ data={
+ "image_s": torch.stack(image_ss, dim=0),
+ "image_m": torch.stack(image_ms, dim=0),
+ "image_l": torch.stack(image_ls, dim=0),
+ },
+ info=dict(mask_path={f"frame_{i}": p for i, p in enumerate(mask_paths)}, group_name=group_name),
+ )
+
+ def __len__(self):
+ return len(self.total_data_paths)
+
+
+class VideoTrainDataset(data.Dataset):
+ def __init__(self, dataset_infos: dict, shape: dict, num_frames: int = 1):
+ super().__init__()
+ self.shape = shape
+ self.num_frames = num_frames
+ self.stride = num_frames - 1 if num_frames > 1 else 1
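+        # consecutive training clips start (num_frames - 1) frames apart, so adjacent clips share one frame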
+
+ self.total_data_paths = []
+ for dataset_name, dataset_info in dataset_infos.items():
+ image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
+ image_suffix = dataset_info["image"]["suffix"]
+ mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
+ mask_suffix = dataset_info["mask"]["suffix"]
+
+ image_group_paths = sorted(glob.glob(image_path))
+ mask_group_paths = sorted(glob.glob(mask_path))
+ group_name_place = image_path.find("*")
+
+ data_paths = []
+ for image_group_path, mask_group_path in zip(image_group_paths, mask_group_paths):
+ group_name = image_group_path[group_name_place:].split("/")[0]
+ image_names = [
+ p[: -len(image_suffix)] for p in sorted(os.listdir(image_group_path)) if p.endswith(image_suffix)
+ ]
+ mask_names = [
+ p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_group_path)) if p.endswith(mask_suffix)
+ ]
+ valid_names = sorted(
+ set(image_names).intersection(mask_names), key=lambda item: get_number_from_tail(item)
+ )
+
+ length_of_sequence = len(valid_names)
+                for clip_idx in range(0, length_of_sequence, self.stride):
+                    # range() already steps by self.stride, so clip_idx is the clip's starting frame index
+                    start_idx = clip_idx
+ if start_idx + num_frames > length_of_sequence:
+ start_idx = length_of_sequence - num_frames
+
+ clip_info = []
+ for i, n in enumerate(valid_names[start_idx : start_idx + num_frames]):
+ clip_info.append(
+ (
+ os.path.join(image_group_path, n) + image_suffix,
+ os.path.join(mask_group_path, n) + mask_suffix,
+ group_name,
+ i,
+ clip_idx,
+ )
+ )
+ if len(clip_info) < num_frames:
+ times, last = divmod(num_frames, len(clip_info))
+ clip_info.extend(clip_info * (times - 1))
+ clip_info.extend(clip_info[:last])
+ data_paths.append(clip_info)
+
+ LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}")
+ self.total_data_paths.extend(data_paths)
+
+ self.frame_specific_transformation = construct_frame_transform()
+ self.frame_share_transformation = construct_video_transform()
+
+ def __getitem__(self, index):
+ base_h = self.shape["h"]
+ base_w = self.shape["w"]
+ image_ss = []
+ image_ms = []
+ image_ls = []
+ masks = []
+ for image_path, mask_path, _, idx_in_group, _ in self.total_data_paths[index]:
+ image = io.read_color_array(image_path)
+ mask = io.read_gray_array(mask_path, thr=0)
+ if image.shape[:2] != mask.shape:
+ h, w = mask.shape
+ image = ops.resize(image, height=h, width=w)
+
+ if idx_in_group == 0:
+ shared_transformed = self.frame_share_transformation(image=image, mask=mask)
+ else:
+ shared_transformed = A.ReplayCompose.replay(
+ saved_augmentations=shared_transformed["replay"], image=image, mask=mask
+ )
+ specific_transformed = self.frame_specific_transformation(
+ image=shared_transformed["image"], mask=shared_transformed["mask"]
+ )
+ image = specific_transformed["image"]
+ mask = specific_transformed["mask"]
+
+ images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w)
+ mask = ops.resize(mask, height=base_h, width=base_w)
+
+ image_ss.append(torch.from_numpy(images[0]).div(255).permute(2, 0, 1))
+ image_ms.append(torch.from_numpy(images[1]).div(255).permute(2, 0, 1))
+ image_ls.append(torch.from_numpy(images[2]).div(255).permute(2, 0, 1))
+ masks.append(torch.from_numpy(mask).unsqueeze(0))
+
+ return dict(
+ data={
+ "image_s": torch.stack(image_ss, dim=0),
+ "image_m": torch.stack(image_ms, dim=0),
+ "image_l": torch.stack(image_ls, dim=0),
+ "mask": torch.stack(masks, dim=0),
+ }
+ )
+
+ def __len__(self):
+ return len(self.total_data_paths)
+
+
+class Evaluator:
+ def __init__(self, device, metric_names, clip_range=None):
+ self.device = device
+ self.clip_range = clip_range
+ self.metric_names = metric_names
+
+ @torch.no_grad()
+ def eval(self, model, data_loader, save_path=""):
+ model.eval()
+ all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names)
+
+ num_frames = data_loader.dataset.num_frames
+ for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"):
+ batch_images = pt_utils.to_device(batch["data"], device=self.device)
+ batch_images = {k: v.flatten(0, 1) for k, v in batch_images.items()} # B_T,C,H,W
+ logits = model(data=batch_images) # BT,1,H,W
+ probs = logits.sigmoid().squeeze(1).cpu().detach()
+ probs = probs - probs.min()
+ probs = probs / (probs.max() + 1e-8)
+ probs = torch.unflatten(probs, dim=0, sizes=(-1, num_frames))
+ probs = probs.numpy()
+
+ mask_paths = batch["info"]["mask_path"]
+ group_names = batch["info"]["group_name"]
+ for clip_idx, clip_pred in enumerate(probs):
+ for frame_idx, pred in enumerate(clip_pred):
+ mask_path = mask_paths[f"frame_{frame_idx}"][clip_idx]
+ mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+ mask_array[mask_array > 0] = 255
+ mask_h, mask_w = mask_array.shape
+ pred = ops.resize(pred, height=mask_h, width=mask_w)
+
+ if self.clip_range is not None:
+ pred = ops.clip_to_normalize(pred, clip_range=self.clip_range)
+
+ group_name = group_names[clip_idx]
+                        if save_path:  # save_path already includes the dataset name
+ ops.save_array_as_image(
+ data_array=pred,
+ save_name=os.path.basename(mask_path),
+ save_dir=os.path.join(save_path, group_name),
+ )
+
+ pred = (pred * 255).astype(np.uint8)
+ all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path)
+ seg_results, group_seg_results = all_metrics.show(return_group=True)
+ return seg_results, group_seg_results
+
+
+def test(model, cfg):
+ test_wrapper = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range)
+
+ for te_name in cfg.test.data.names:
+ te_info = cfg.dataset_infos[te_name]
+ te_dataset = VideoTestDataset(dataset_info=te_info, shape=cfg.test.data.shape, num_frames=cfg.num_frames)
+ te_loader = data.DataLoader(
+ dataset=te_dataset, batch_size=cfg.test.batch_size, num_workers=cfg.test.num_workers, pin_memory=True
+ )
+ LOGGER.info(f"Testing with testset: {te_name}: {len(te_dataset)}")
+
+ if cfg.save_results:
+ save_path = os.path.join(cfg.path.save, te_name)
+ LOGGER.info(f"Results will be saved into {save_path}")
+ else:
+ save_path = ""
+
+ seg_results, group_seg_results = test_wrapper.eval(model=model, data_loader=te_loader, save_path=save_path)
+ seg_results_str = ", ".join([f"{k}: {v:.03f}" for k, v in seg_results.items()])
+ LOGGER.info(f"({te_name}): {py_utils.mapping_to_str(te_info)}\n{seg_results_str}")
+
+ max_group_name_length = max([len(n) for n in group_seg_results.keys()])
+ for group_name, group_results in group_seg_results.items():
+ metric_str = ""
+ for metric_name, metric_value in group_results.items():
+ metric_str += "|" + metric_name + " " + str(metric_value).ljust(6)
+ LOGGER.info(f"{group_name.rjust(max_group_name_length)}: {metric_str}")
+
+
+def train(model, cfg):
+ tr_dataset = VideoTrainDataset(
+ dataset_infos={data_name: cfg.dataset_infos[data_name] for data_name in cfg.train.data.names},
+ shape=cfg.train.data.shape,
+ num_frames=cfg.num_frames,
+ )
+ LOGGER.info(f"Total Length of Video Trainset: {len(tr_dataset)}")
+
+ tr_loader = data.DataLoader(
+ dataset=tr_dataset,
+ batch_size=cfg.train.batch_size,
+ num_workers=cfg.train.num_workers,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=True,
+ worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None,
+ )
+
+ counter = recorder.TrainingCounter(
+ epoch_length=len(tr_loader),
+ epoch_based=cfg.train.epoch_based,
+ num_epochs=cfg.train.num_epochs,
+ num_total_iters=cfg.train.num_iters,
+ )
+ optimizer = pipeline.construct_optimizer(
+ model=model,
+ initial_lr=cfg.train.lr,
+ mode=cfg.train.optimizer.mode,
+ group_mode=cfg.train.optimizer.group_mode,
+ cfg=cfg.train.optimizer.cfg,
+ )
+ scheduler = pipeline.Scheduler(
+ optimizer=optimizer,
+ num_iters=counter.num_total_iters,
+ epoch_length=counter.num_inner_iters,
+ scheduler_cfg=cfg.train.scheduler,
+ step_by_batch=cfg.train.sche_usebatch,
+ )
+ scheduler.record_lrs(param_groups=optimizer.param_groups)
+ scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log)
+ scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none)
+
+ LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}")
+
+ loss_recorder = recorder.HistoryBuffer()
+ iter_time_recorder = recorder.HistoryBuffer()
+
+ LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}")
+ if cfg.train.bn.freeze_encoder:
+ LOGGER.info(" >>> Freeze Backbone !!! <<< ")
+ model.encoder.requires_grad_(False)
+
+ train_start_time = time.perf_counter()
+ for _ in range(counter.num_epochs):
+ LOGGER.info(f"Exp_Name: {cfg.exp_name}")
+
+ model.train()
+ if cfg.train.bn.freeze_status:
+ pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine)
+
+ # an epoch starts
+ for batch_idx, batch in enumerate(tr_loader):
+ iter_start_time = time.perf_counter()
+ scheduler.step(curr_idx=counter.curr_iter) # update learning rate
+
+ data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device)
+ data_batch = {k: v.flatten(0, 1) for k, v in data_batch.items()}
+ with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
+ outputs = model(data=data_batch, iter_percentage=counter.curr_percent)
+
+ loss = outputs["loss"]
+ loss_str = outputs["loss_str"]
+ loss = loss / cfg.train.grad_acc_step
+ scaler.calculate_grad(loss=loss)
+ if counter.every_n_iters(cfg.train.grad_acc_step): # Accumulates scaled gradients.
+ scaler.update_grad()
+
+ item_loss = loss.item()
+ data_shape = tuple(data_batch["mask"].shape)
+ loss_recorder.update(value=item_loss, num=data_shape[0])
+
+ if cfg.log_interval > 0 and (
+ counter.every_n_iters(cfg.log_interval)
+ or counter.is_first_inner_iter()
+ or counter.is_last_inner_iter()
+ or counter.is_last_total_iter()
+ ):
+ gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
+ eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1)
+ eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}"
+ progress = (
+ f"{counter.curr_iter}:{counter.num_total_iters} "
+ f"{batch_idx}/{counter.num_inner_iters} "
+ f"{counter.curr_epoch}/{counter.num_epochs}"
+ )
+ loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})"
+ lr_info = f"LR: {optimizer.lr_string()}"
+ LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}")
+ cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter)
+ cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter)
+ cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter)
+
+ if counter.curr_iter < 3: # plot some batches of the training phase
+ recorder.plot_results(
+ dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
+ save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"),
+ )
+
+ iter_time_recorder.update(value=time.perf_counter() - iter_start_time)
+ if counter.is_last_total_iter():
+ break
+ counter.update_iter_counter()
+
+ # an epoch ends
+ recorder.plot_results(
+ dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
+ save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"),
+ )
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+ counter.update_epoch_counter()
+
+ cfg.tb_logger.close_tb()
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+
+ total_train_time = time.perf_counter() - train_start_time
+ total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum))
+ LOGGER.info(
+ f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)"
+ )
+
+
+def parse_cfg():
+ parser = argparse.ArgumentParser("Training and evaluation script")
+ parser.add_argument("--config", required=True, type=str)
+ parser.add_argument("--data-cfg", type=str, default="./dataset.yaml")
+ parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys())
+ parser.add_argument("--output-dir", type=str, default="outputs")
+ parser.add_argument("--load-from", type=str)
+ parser.add_argument("--pretrained", action="store_true")
+ parser.add_argument(
+ "--metric-names",
+ nargs="+",
+ type=str,
+ default=["sm", "wfm", "mae", "em", "fmeasure", "iou", "dice"],
+ choices=recorder.GroupedMetricRecorder.supported_metrics,
+ )
+ parser.add_argument("--evaluate", action="store_true")
+ parser.add_argument("--save-results", action="store_true")
+ parser.add_argument("--info", type=str)
+ args = parser.parse_args()
+
+ cfg = Config.fromfile(args.config)
+ cfg.merge_from_dict(vars(args))
+
+ with open(cfg.data_cfg, mode="r") as f:
+ cfg.dataset_infos = yaml.safe_load(f)
+
+ cfg.proj_root = os.path.dirname(os.path.abspath(__file__))
+ cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg)
+ cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir)
+ cfg.path = py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name)
+ cfg.device = "cuda:0"
+
+ py_utils.pre_mkdir(cfg.path)
+ with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f:
+ f.write(cfg.pretty_text)
+ shutil.copy(__file__, cfg.path.trainer_copy)
+
+ file_handler = logging.FileHandler(cfg.path.log)
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s"))
+ LOGGER.addHandler(file_handler)
+ LOGGER.info(cfg.pretty_text)
+
+ cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb)
+ return cfg
+
+
+def main():
+ cfg = parse_cfg()
+ pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic)
+
+ model_class = model_zoo.__dict__.get(cfg.model_name)
+ assert model_class is not None, "Please check your --model-name"
+ model_code = inspect.getsource(model_class)
+ model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained)
+ LOGGER.info(model_code)
+ model.to(cfg.device)
+
+ if cfg.load_from:
+ io.load_weight(model=model, load_path=cfg.load_from, strict=True)
+
+ LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}")
+ if not cfg.evaluate:
+ train(model=model, cfg=cfg)
+
+ if cfg.evaluate or cfg.has_test:
+ io.save_weight(model=model, save_path=cfg.path.final_state_net)
+ test(model=model, cfg=cfg)
+
+ LOGGER.info("End training...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/methods/__init__.py b/methods/__init__.py
new file mode 100644
index 0000000..fa7e6ba
--- /dev/null
+++ b/methods/__init__.py
@@ -0,0 +1,10 @@
+from .zoomnext.zoomnext import (
+ EffB1_ZoomNeXt,
+ EffB4_ZoomNeXt,
+ PvtV2B2_ZoomNeXt,
+ PvtV2B3_ZoomNeXt,
+ PvtV2B4_ZoomNeXt,
+ PvtV2B5_ZoomNeXt,
+ RN50_ZoomNeXt,
+ videoPvtV2B5_ZoomNeXt,
+)
diff --git a/methods/backbone/__init__.py b/methods/backbone/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/methods/backbone/efficientnet.py b/methods/backbone/efficientnet.py
new file mode 100644
index 0000000..8e60b6b
--- /dev/null
+++ b/methods/backbone/efficientnet.py
@@ -0,0 +1,436 @@
+"""model.py - Model and module class for EfficientNet.
+They are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .efficientnet_utils import (
+ MemoryEfficientSwish,
+ Swish,
+ calculate_output_image_size,
+ drop_connect,
+ efficientnet_params,
+ get_model_params,
+ get_same_padding_conv2d,
+ load_pretrained_weights,
+ round_filters,
+ round_repeats,
+)
+
+VALID_MODELS = (
+ "efficientnet-b0",
+ "efficientnet-b1",
+ "efficientnet-b2",
+ "efficientnet-b3",
+ "efficientnet-b4",
+ "efficientnet-b5",
+ "efficientnet-b6",
+ "efficientnet-b7",
+ "efficientnet-b8",
+ # Support the construction of 'efficientnet-l2' without pretrained weights
+ "efficientnet-l2",
+)
+
+
+class MBConvBlock(nn.Module):
+ """Mobile Inverted Residual Bottleneck Block.
+
+ Args:
+ block_args (namedtuple): BlockArgs, defined in utils.py.
+ global_params (namedtuple): GlobalParam, defined in utils.py.
+ image_size (tuple or list): [image_height, image_width].
+
+ References:
+ [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+ [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+ [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+ """
+
+ def __init__(self, block_args, global_params, image_size=None):
+ super().__init__()
+ self._block_args = block_args
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
+ self._bn_eps = global_params.batch_norm_epsilon
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
+
+ # Expansion phase (Inverted Bottleneck)
+ inp = self._block_args.input_filters # number of input channels
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
+ if self._block_args.expand_ratio != 1:
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+ # Depthwise convolution phase
+ k = self._block_args.kernel_size
+ s = self._block_args.stride
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._depthwise_conv = Conv2d(
+ in_channels=oup,
+ out_channels=oup,
+ groups=oup, # groups makes it depthwise
+ kernel_size=k,
+ stride=s,
+ bias=False,
+ )
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+ image_size = calculate_output_image_size(image_size, s)
+
+ # Squeeze and Excitation layer, if desired
+ if self.has_se:
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+ # Pointwise convolution phase
+ final_oup = self._block_args.output_filters
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+ self._swish = MemoryEfficientSwish()
+
+ def forward(self, inputs, drop_connect_rate=None):
+ """MBConvBlock's forward function.
+
+ Args:
+ inputs (tensor): Input tensor.
+            drop_connect_rate (float): Drop connect rate (between 0 and 1).
+
+ Returns:
+ Output of this block after processing.
+ """
+
+ # Expansion and Depthwise Convolution
+ x = inputs
+ if self._block_args.expand_ratio != 1:
+ x = self._expand_conv(inputs)
+ x = self._bn0(x)
+ x = self._swish(x)
+
+ x = self._depthwise_conv(x)
+ x = self._bn1(x)
+ x = self._swish(x)
+
+ # Squeeze and Excitation
+ if self.has_se:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self._se_reduce(x_squeezed)
+ x_squeezed = self._swish(x_squeezed)
+ x_squeezed = self._se_expand(x_squeezed)
+ x = torch.sigmoid(x_squeezed) * x
+
+ # Pointwise Convolution
+ x = self._project_conv(x)
+ x = self._bn2(x)
+
+ # Skip connection and drop connect
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+ # The combination of skip connection and drop connect brings about stochastic depth.
+ if drop_connect_rate:
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
+ x = x + inputs # skip connection
+ return x
+
+ def set_swish(self, memory_efficient=True):
+ """Sets swish function as memory efficient (for training) or standard (for export).
+
+ Args:
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
+ """
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+ """EfficientNet model.
+ Most easily loaded with the .from_name or .from_pretrained methods.
+
+ Args:
+ blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+ global_params (namedtuple): A set of GlobalParams shared between blocks.
+
+ References:
+ [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+
+ Example:
+ >>> import torch
+ >>> from efficientnet.model1 import EfficientNet
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+ >>> model.eval()
+ >>> outputs = model(inputs)
+ """
+
+ def __init__(self, blocks_args=None, global_params=None):
+ super().__init__()
+ assert isinstance(blocks_args, list), "blocks_args should be a list"
+        assert len(blocks_args) > 0, "blocks_args must contain at least one block"
+ self._global_params = global_params
+ self._blocks_args = blocks_args
+
+ # Batch norm parameters
+ bn_mom = 1 - self._global_params.batch_norm_momentum
+ bn_eps = self._global_params.batch_norm_epsilon
+
+ # Get stem static or dynamic convolution depending on image size
+ image_size = global_params.image_size
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+ # Stem
+ in_channels = 3 # rgb
+ out_channels = round_filters(32, self._global_params) # number of output channels
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+ self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+ image_size = calculate_output_image_size(image_size, 2)
+
+ # Build blocks
+ self._blocks = nn.ModuleList([])
+ for block_args in self._blocks_args:
+ # Update block input and output filters based on depth multiplier.
+ block_args = block_args._replace(
+ input_filters=round_filters(block_args.input_filters, self._global_params),
+ output_filters=round_filters(block_args.output_filters, self._global_params),
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params),
+ )
+
+ # The first block needs to take care of stride and filter size increase.
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+ image_size = calculate_output_image_size(image_size, block_args.stride)
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+ for _ in range(block_args.num_repeat - 1):
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
+
+ # Head
+ in_channels = block_args.output_filters # output of final block
+ out_channels = round_filters(1280, self._global_params)
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+ # Final linear layer
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+ # if self._global_params.include_top:
+ # self._dropout = nn.Dropout(self._global_params.dropout_rate)
+ # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+ # set activation to memory efficient swish by default
+ self._swish = MemoryEfficientSwish()
+
+ def set_swish(self, memory_efficient=True):
+ """Sets swish function as memory efficient (for training) or standard (for export).
+
+ Args:
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
+ """
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+ for block in self._blocks:
+ block.set_swish(memory_efficient)
+
+ def extract_endpoints(self, inputs):
+ """Use convolution layer to extract features
+ from reduction levels i in [1, 2, 3, 4, 5].
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Dictionary of last intermediate features
+ with reduction levels i in [1, 2, 3, 4, 5].
+ Example:
+ >>> import torch
+ >>> from efficientnet.model1 import EfficientNet
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+ >>> endpoints = model.extract_endpoints(inputs)
+ >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
+ >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
+ >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
+ >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
+ >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
+            >>> # 'reduction_6' (the 1280-channel head output) is disabled in this backbone variant
+ """
+ endpoints = dict()
+
+ # Stem
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
+ prev_x = x
+
+ # Blocks
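+        # An endpoint is recorded just before each spatial reduction (prev_x), plus the output
+        # of the final block, yielding reduction_1 ... reduction_5 for this backbone variant.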
+ for idx, block in enumerate(self._blocks):
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop_connect_rate
+ x = block(x, drop_connect_rate=drop_connect_rate)
+ if prev_x.size(2) > x.size(2):
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
+ elif idx == len(self._blocks) - 1:
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = x
+ prev_x = x
+
+ # Head
+ # x = self._swish(self._bn1(self._conv_head(x)))
+ # endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+
+ return endpoints
+
+ def extract_features(self, inputs):
+ """use convolution layer to extract feature .
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Output of the final convolution
+ layer in the efficientnet model.
+ """
+ # Stem
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+ # Blocks
+ for idx, block in enumerate(self._blocks):
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop_connect_rate
+ x = block(x, drop_connect_rate=drop_connect_rate)
+
+ # Head
+ x = self._swish(self._bn1(self._conv_head(x)))
+
+ return x
+
+ def forward(self, inputs):
+ """EfficientNet's forward function.
+        Calls extract_features to extract features, applies global average pooling, and returns the pooled features.
+
+ Args:
+ inputs (tensor): Input tensor.
+
+ Returns:
+ Output of this model after processing.
+ """
+ # Convolution layers
+ x = self.extract_features(inputs)
+ # Pooling and final linear layer
+ x = self._avg_pooling(x)
+ # if self._global_params.include_top:
+ # x = x.flatten(start_dim=1)
+ # x = self._dropout(x)
+ # x = self._fc(x)
+ return x
+
+ @classmethod
+ def from_name(cls, model_name, in_channels=3, **override_params):
+ """Create an efficientnet model according to name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+ in_channels (int): Input data's channel number.
+ override_params (other key word params):
+ Params to override model's global_params.
+ Optional key:
+ 'width_coefficient', 'depth_coefficient',
+ 'image_size', 'dropout_rate',
+ 'num_classes', 'batch_norm_momentum',
+ 'batch_norm_epsilon', 'drop_connect_rate',
+ 'depth_divisor', 'min_depth'
+
+ Returns:
+ An efficientnet model.
+ """
+ cls._check_model_name_is_valid(model_name)
+ blocks_args, global_params = get_model_params(model_name, override_params)
+ model = cls(blocks_args, global_params)
+ model._change_in_channels(in_channels)
+ return model
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name,
+ pretrained=True,
+ weights_path=None,
+ advprop=False,
+ in_channels=3,
+ num_classes=1000,
+ **override_params,
+ ):
+ """Create an efficientnet model according to name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+ weights_path (None or str):
+ str: path to pretrained weights file on the local disk.
+ None: use pretrained weights downloaded from the Internet.
+ advprop (bool):
+ Whether to load pretrained weights
+ trained with advprop (valid when weights_path is None).
+ in_channels (int): Input data's channel number.
+ num_classes (int):
+ Number of categories for classification.
+ It controls the output size for final linear layer.
+ override_params (other key word params):
+ Params to override model's global_params.
+ Optional key:
+ 'width_coefficient', 'depth_coefficient',
+ 'image_size', 'dropout_rate',
+ 'batch_norm_momentum',
+ 'batch_norm_epsilon', 'drop_connect_rate',
+ 'depth_divisor', 'min_depth'
+
+ Returns:
+ A pretrained efficientnet model.
+ """
+ model = cls.from_name(model_name, num_classes=num_classes, **override_params)
+ if pretrained:
+ load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=False, advprop=advprop)
+ model._change_in_channels(in_channels)
+ return model
+
+ @classmethod
+ def get_image_size(cls, model_name):
+ """Get the input image size for a given efficientnet model.
+
+ Args:
+ model_name (str): Name for efficientnet.
+
+ Returns:
+ Input image size (resolution).
+ """
+ cls._check_model_name_is_valid(model_name)
+ _, _, res, _ = efficientnet_params(model_name)
+ return res
+
+ @classmethod
+ def _check_model_name_is_valid(cls, model_name):
+ """Validates model name.
+
+ Args:
+ model_name (str): Name for efficientnet.
+
+        Raises:
+            ValueError: If model_name is not one of the valid EfficientNet names.
+ """
+ if model_name not in VALID_MODELS:
+ raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS))
+
+ def _change_in_channels(self, in_channels):
+ """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+
+ Args:
+ in_channels (int): Input data's channel number.
+ """
+ if in_channels != 3:
+ Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+ out_channels = round_filters(32, self._global_params)
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
diff --git a/methods/backbone/efficientnet_utils.py b/methods/backbone/efficientnet_utils.py
new file mode 100644
index 0000000..f59e54a
--- /dev/null
+++ b/methods/backbone/efficientnet_utils.py
@@ -0,0 +1,616 @@
+"""utils.py - Helper functions for building the model and for loading model parameters.
+ These helper functions are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import re
+import math
+import collections
+from functools import partial
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import model_zoo
+
+
+################################################################################
+# Helper functions for model architecture
+################################################################################
+
+# GlobalParams and BlockArgs: Two namedtuples
+# Swish and MemoryEfficientSwish: Two implementations of the method
+# round_filters and round_repeats:
+# Functions to calculate params for scaling model width and depth ! ! !
+# get_width_and_height_from_size and calculate_output_image_size
+# drop_connect: A structural design
+# get_same_padding_conv2d:
+# Conv2dDynamicSamePadding
+# Conv2dStaticSamePadding
+# get_same_padding_maxPool2d:
+# MaxPool2dDynamicSamePadding
+# MaxPool2dStaticSamePadding
+# It's an additional function, not used in EfficientNet,
+# but can be used in other model (such as EfficientDet).
+
+# Parameters for the entire model (stem, all blocks, and head)
+GlobalParams = collections.namedtuple('GlobalParams', [
+ 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
+ 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
+ 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
+
+# Parameters for an individual model block
+BlockArgs = collections.namedtuple('BlockArgs', [
+ 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
+ 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
+
+# Set GlobalParams and BlockArgs's defaults
+GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
+
+# Swish activation function
+if hasattr(nn, 'SiLU'):
+ Swish = nn.SiLU
+else:
+ # For compatibility with old PyTorch versions
+ class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+# A memory-efficient implementation of Swish function
+class SwishImplementation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_tensors[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+ def forward(self, x):
+ return SwishImplementation.apply(x)
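+
+# Note: SwishImplementation stores only the input tensor and recomputes the sigmoid in
+# backward, which needs less activation memory during training than the autograd graph
+# built by the plain x * torch.sigmoid(x) formulation.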
+
+
+def round_filters(filters, global_params):
+ """Calculate and round number of filters based on width multiplier.
+ Use width_coefficient, depth_divisor and min_depth of global_params.
+
+ Args:
+ filters (int): Filters number to be calculated.
+ global_params (namedtuple): Global params of the model.
+
+ Returns:
+ new_filters: New filters number after calculating.
+ """
+ multiplier = global_params.width_coefficient
+ if not multiplier:
+ return filters
+ # TODO: modify the params names.
+ # maybe the names (width_divisor,min_width)
+ # are more suitable than (depth_divisor,min_depth).
+ divisor = global_params.depth_divisor
+ min_depth = global_params.min_depth
+ filters *= multiplier
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
+ # follow the formula transferred from official TensorFlow implementation
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
+ new_filters += divisor
+ return int(new_filters)
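+
+# Worked example: for efficientnet-b2 (width_coefficient=1.1, depth_divisor=8),
+# round_filters(32, global_params) -> int(32 * 1.1 + 4) // 8 * 8 = 32 (within the 10%
+# tolerance), while round_filters(320, global_params) -> 352.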
+
+
+def round_repeats(repeats, global_params):
+ """Calculate module's repeat number of a block based on depth multiplier.
+ Use depth_coefficient of global_params.
+
+ Args:
+ repeats (int): num_repeat to be calculated.
+ global_params (namedtuple): Global params of the model.
+
+ Returns:
+ new repeat: New repeat number after calculating.
+ """
+ multiplier = global_params.depth_coefficient
+ if not multiplier:
+ return repeats
+ # follow the formula transferred from official TensorFlow implementation
+ return int(math.ceil(multiplier * repeats))
+
+
+def drop_connect(inputs, p, training):
+ """Drop connect.
+
+ Args:
+        inputs (tensor: BCHW): Input of this structure.
+ p (float: 0.0~1.0): Probability of drop connection.
+ training (bool): The running mode.
+
+ Returns:
+ output: Output after drop connection.
+ """
+ assert 0 <= p <= 1, 'p must be in range of [0,1]'
+
+ if not training:
+ return inputs
+
+ batch_size = inputs.shape[0]
+ keep_prob = 1 - p
+
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+ random_tensor = keep_prob
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+ binary_tensor = torch.floor(random_tensor)
+
+ output = inputs / keep_prob * binary_tensor
+ return output
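+
+# Note: with p = 0.2, keep_prob = 0.8 and floor(0.8 + U[0, 1)) equals 1 with probability 0.8,
+# so each sample survives with probability keep_prob and surviving samples are rescaled by
+# 1 / keep_prob to keep the expected activation unchanged.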
+
+
+def get_width_and_height_from_size(x):
+ """Obtain height and width from x.
+
+ Args:
+ x (int, tuple or list): Data size.
+
+ Returns:
+ size: A tuple or list (H,W).
+ """
+ if isinstance(x, int):
+ return x, x
+ if isinstance(x, list) or isinstance(x, tuple):
+ return x
+ else:
+ raise TypeError()
+
+
+def calculate_output_image_size(input_image_size, stride):
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
+
+ Args:
+ input_image_size (int, tuple or list): Size of input image.
+ stride (int, tuple or list): Conv2d operation's stride.
+
+ Returns:
+ output_image_size: A list [H,W].
+ """
+ if input_image_size is None:
+ return None
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
+ stride = stride if isinstance(stride, int) else stride[0]
+ image_height = int(math.ceil(image_height / stride))
+ image_width = int(math.ceil(image_width / stride))
+ return [image_height, image_width]
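+
+# Example: calculate_output_image_size(224, 2) -> [112, 112]; a stride of 1 leaves the size unchanged.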
+
+
+# Note:
+# The following 'SamePadding' functions make output size equal ceil(input size/stride).
+# Only when stride equals 1, can the output size be the same as input size.
+# Don't be confused by their function names ! ! !
+
+def get_same_padding_conv2d(image_size=None):
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+ Static padding is necessary for ONNX exporting of models.
+
+ Args:
+ image_size (int or tuple): Size of the image.
+
+ Returns:
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+ """
+ if image_size is None:
+ return Conv2dDynamicSamePadding
+ else:
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
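+
+# Illustrative usage:
+#   Conv2d = get_same_padding_conv2d(image_size=(224, 224))  # static padding (ONNX-friendly)
+#   conv = Conv2d(3, 32, kernel_size=3, stride=2, bias=False)
+#   Conv2d = get_same_padding_conv2d(image_size=None)         # dynamic padding at runtime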
+
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+ """2D Convolutions like TensorFlow, for a dynamic image size.
+ The padding is operated in forward function by calculating dynamically.
+ """
+
+ # Tips for 'SAME' mode padding.
+ # Given the following:
+ # i: width or height
+ # s: stride
+ # k: kernel size
+ # d: dilation
+ # p: padding
+ # Output after Conv2d:
+ # o = floor((i+p-((k-1)*d+1))/s+1)
+    #     For 'SAME' padding we want o = ceil(i/s) (so o = i when s = 1):
+    #     o = floor((i+p-((k-1)*d+1))/s+1) => p = (o-1)*s+((k-1)*d+1)-i
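+    #     Worked example: i = 224, k = 3, s = 2, d = 1 gives o = ceil(224 / 2) = 112 and
+    #     p = (112 - 1) * 2 + 3 - 224 = 1, split below as (left, right) = (0, 1), matching
+    #     TensorFlow's asymmetric 'SAME' padding.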
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+ def forward(self, x):
+ ih, iw = x.size()[-2:]
+ kh, kw = self.weight.size()[-2:]
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class Conv2dStaticSamePadding(nn.Conv2d):
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is calculated in the constructor, then used in forward.
+ """
+
+ # With the same calculation as Conv2dDynamicSamePadding
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+ # Calculate padding based on image size and save it
+ assert image_size is not None
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+ kh, kw = self.weight.size()[-2:]
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
+ pad_h // 2, pad_h - pad_h // 2))
+ else:
+ self.static_padding = nn.Identity()
+
+ def forward(self, x):
+ x = self.static_padding(x)
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+def get_same_padding_maxPool2d(image_size=None):
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+ Static padding is necessary for ONNX exporting of models.
+
+ Args:
+ image_size (int or tuple): Size of the image.
+
+ Returns:
+ MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
+ """
+ if image_size is None:
+ return MaxPool2dDynamicSamePadding
+ else:
+ return partial(MaxPool2dStaticSamePadding, image_size=image_size)
+
+
+class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
+ The padding is operated in forward function by calculating dynamically.
+ """
+
+ def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
+ super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+ def forward(self, x):
+ ih, iw = x.size()[-2:]
+ kh, kw = self.kernel_size
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
+ self.dilation, self.ceil_mode, self.return_indices)
+
+
+class MaxPool2dStaticSamePadding(nn.MaxPool2d):
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
+    The padding module is calculated in the constructor, then used in forward.
+ """
+
+ def __init__(self, kernel_size, stride, image_size=None, **kwargs):
+ super().__init__(kernel_size, stride, **kwargs)
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+ # Calculate padding based on image size and save it
+ assert image_size is not None
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+ kh, kw = self.kernel_size
+ sh, sw = self.stride
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+ if pad_h > 0 or pad_w > 0:
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+ else:
+ self.static_padding = nn.Identity()
+
+ def forward(self, x):
+ x = self.static_padding(x)
+ x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
+ self.dilation, self.ceil_mode, self.return_indices)
+ return x
+
+
+################################################################################
+# Helper functions for loading model params
+################################################################################
+
+# BlockDecoder: A Class for encoding and decoding BlockArgs
+# efficientnet_params: A function to query compound coefficient
+# get_model_params and efficientnet:
+# Functions to get BlockArgs and GlobalParams for efficientnet
+# url_map and url_map_advprop: Dicts of url_map for pretrained weights
+# load_pretrained_weights: A function to load pretrained weights
+
+class BlockDecoder(object):
+ """Block Decoder for readability,
+ straight from the official TensorFlow repository.
+ """
+
+ @staticmethod
+ def _decode_block_string(block_string):
+ """Get a block through a string notation of arguments.
+
+ Args:
+ block_string (str): A string notation of arguments.
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
+
+ Returns:
+ BlockArgs: The namedtuple defined at the top of this file.
+ """
+ assert isinstance(block_string, str)
+
+ ops = block_string.split('_')
+ options = {}
+ for op in ops:
+ splits = re.split(r'(\d.*)', op)
+ if len(splits) >= 2:
+ key, value = splits[:2]
+ options[key] = value
+
+ # Check stride
+ assert (('s' in options and len(options['s']) == 1) or
+ (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+
+ return BlockArgs(
+ num_repeat=int(options['r']),
+ kernel_size=int(options['k']),
+ stride=[int(options['s'][0])],
+ expand_ratio=int(options['e']),
+ input_filters=int(options['i']),
+ output_filters=int(options['o']),
+ se_ratio=float(options['se']) if 'se' in options else None,
+ id_skip=('noskip' not in block_string))
+
+ @staticmethod
+ def _encode_block_string(block):
+ """Encode a block to a string.
+
+ Args:
+ block (namedtuple): A BlockArgs type argument.
+
+ Returns:
+ block_string: A String form of BlockArgs.
+ """
+ args = [
+ 'r%d' % block.num_repeat,
+ 'k%d' % block.kernel_size,
+            's%d%d' % (block.stride[0], block.stride[0]),  # BlockArgs stores 'stride' as a one-element list
+ 'e%s' % block.expand_ratio,
+ 'i%d' % block.input_filters,
+ 'o%d' % block.output_filters
+ ]
+ if 0 < block.se_ratio <= 1:
+ args.append('se%s' % block.se_ratio)
+ if block.id_skip is False:
+ args.append('noskip')
+ return '_'.join(args)
+
+ @staticmethod
+ def decode(string_list):
+ """Decode a list of string notations to specify blocks inside the network.
+
+ Args:
+ string_list (list[str]): A list of strings, each string is a notation of block.
+
+ Returns:
+ blocks_args: A list of BlockArgs namedtuples of block args.
+ """
+ assert isinstance(string_list, list)
+ blocks_args = []
+ for block_string in string_list:
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
+ return blocks_args
+
+ @staticmethod
+ def encode(blocks_args):
+ """Encode a list of BlockArgs to a list of strings.
+
+ Args:
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+
+ Returns:
+ block_strings: A list of strings, each string is a notation of block.
+ """
+ block_strings = []
+ for block in blocks_args:
+ block_strings.append(BlockDecoder._encode_block_string(block))
+ return block_strings
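+
+# Illustrative example: BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25']) returns
+# [BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1, input_filters=32,
+#            output_filters=16, se_ratio=0.25, id_skip=True)].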
+
+
+def efficientnet_params(model_name):
+ """Map EfficientNet model name to parameter coefficients.
+
+ Args:
+ model_name (str): Model name to be queried.
+
+ Returns:
+ params_dict[model_name]: A (width,depth,res,dropout) tuple.
+ """
+ params_dict = {
+ # Coefficients: width,depth,res,dropout
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
+ }
+ return params_dict[model_name]
+
+
+def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
+ dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
+ """Create BlockArgs and GlobalParams for efficientnet model.
+
+ Args:
+ width_coefficient (float)
+ depth_coefficient (float)
+ image_size (int)
+ dropout_rate (float)
+ drop_connect_rate (float)
+ num_classes (int)
+
+ Meaning as the name suggests.
+
+ Returns:
+ blocks_args, global_params.
+ """
+
+    # Block args for the whole model (efficientnet-b0 by default).
+    # They are modified in the EfficientNet constructor according to the chosen model.
+ blocks_args = [
+ 'r1_k3_s11_e1_i32_o16_se0.25',
+ 'r2_k3_s22_e6_i16_o24_se0.25',
+ 'r2_k5_s22_e6_i24_o40_se0.25',
+ 'r3_k3_s22_e6_i40_o80_se0.25',
+ 'r3_k5_s11_e6_i80_o112_se0.25',
+ 'r4_k5_s22_e6_i112_o192_se0.25',
+ 'r1_k3_s11_e6_i192_o320_se0.25',
+ ]
+ blocks_args = BlockDecoder.decode(blocks_args)
+
+ global_params = GlobalParams(
+ width_coefficient=width_coefficient,
+ depth_coefficient=depth_coefficient,
+ image_size=image_size,
+ dropout_rate=dropout_rate,
+
+ num_classes=num_classes,
+ batch_norm_momentum=0.99,
+ batch_norm_epsilon=1e-3,
+ drop_connect_rate=drop_connect_rate,
+ depth_divisor=8,
+ min_depth=None,
+ include_top=include_top,
+ )
+
+ return blocks_args, global_params
+
+
+def get_model_params(model_name, override_params):
+ """Get the block args and global params for a given model name.
+
+ Args:
+ model_name (str): Model's name.
+ override_params (dict): A dict to modify global_params.
+
+ Returns:
+ blocks_args, global_params
+ """
+ if model_name.startswith('efficientnet'):
+ w, d, s, p = efficientnet_params(model_name)
+ # note: all models have drop connect rate = 0.2
+ blocks_args, global_params = efficientnet(
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
+ else:
+ raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
+ if override_params:
+ # ValueError will be raised here if override_params has fields not included in global_params.
+ global_params = global_params._replace(**override_params)
+ return blocks_args, global_params
+
+
+# train with Standard methods
+# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
+url_map = {
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
+}
+
+# train with Adversarial Examples(AdvProp)
+# check more details in paper(Adversarial Examples Improve Image Recognition)
+url_map_advprop = {
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
+ 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
+}
+
+# TODO: add the pretrained weights url map of 'efficientnet-l2'
+
+
+def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
+ """Loads pretrained weights from weights path or download using url.
+
+ Args:
+ model (Module): The whole model of efficientnet.
+ model_name (str): Model name of efficientnet.
+ weights_path (None or str):
+ str: path to pretrained weights file on the local disk.
+ None: use pretrained weights downloaded from the Internet.
+ load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
+ advprop (bool): Whether to load pretrained weights
+ trained with advprop (valid when weights_path is None).
+ """
+ if isinstance(weights_path, str):
+ state_dict = torch.load(weights_path)
+ else:
+ # AutoAugment or Advprop (different preprocessing)
+ url_map_ = url_map_advprop if advprop else url_map
+ state_dict = model_zoo.load_url(url_map_[model_name])
+
+ if load_fc:
+ ret = model.load_state_dict(state_dict, strict=False)
+ assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
+ else:
+ state_dict.pop('_fc.weight')
+ state_dict.pop('_fc.bias')
+ ret = model.load_state_dict(state_dict, strict=False)
+ # assert set(ret.missing_keys) == set(
+ # ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
+        assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
+
+ if verbose:
+ print('Loaded pretrained weights for {}'.format(model_name))
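+
+
+# Illustrative usage: EfficientNet.from_pretrained calls this with load_fc=False so that
+# only the backbone weights are loaded:
+#   load_pretrained_weights(model, 'efficientnet-b0', load_fc=False)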
diff --git a/methods/backbone/pvt_v2_eff.py b/methods/backbone/pvt_v2_eff.py
new file mode 100644
index 0000000..6c4eb47
--- /dev/null
+++ b/methods/backbone/pvt_v2_eff.py
@@ -0,0 +1,552 @@
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.vision_transformer import _cfg
+from torch.backends import cuda
+from torch.hub import load_state_dict_from_url
+from torch.utils.checkpoint import checkpoint
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, linear=False
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ self.linear = linear
+ if self.linear:
+ self.relu = nn.ReLU(inplace=True)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ if self.linear:
+ x = self.relu(x)
+ x = self.dwconv(x, H, W)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1, linear=False
+ ):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.linear = linear
+ self.sr_ratio = sr_ratio
+ if not linear:
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+ else:
+ self.pool = nn.AdaptiveAvgPool2d(7)
+ self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
+ self.norm = nn.LayerNorm(dim)
+ self.act = nn.GELU()
+ self.apply(self._init_weights)
+
+        if torch.cuda.is_available():
+            device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+            if device_properties.major == 8 and device_properties.minor == 0:
+                # A100 GPU detected, use flash attention when the input tensor is on cuda
+                self.cuda_config = {"enable_flash": True, "enable_math": False, "enable_mem_efficient": False}
+            else:
+                # Non-A100 GPU detected, use math or memory-efficient attention when the input tensor is on cuda
+                self.cuda_config = {"enable_flash": False, "enable_math": True, "enable_mem_efficient": True}
+        else:
+            # No CUDA device available; keep the math/memory-efficient kernels enabled as a safe default
+            self.cuda_config = {"enable_flash": False, "enable_math": True, "enable_mem_efficient": True}
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ if not self.linear:
+ if self.sr_ratio > 1:
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+ x_ = self.norm(x_)
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
+ x_ = self.norm(x_)
+ x_ = self.act(x_)
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1]
+
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
+ # attn = attn.softmax(dim=-1)
+ # attn = self.attn_drop(attn)
+ # x = attn @ v
+ with cuda.sdp_kernel(**self.cuda_config):
+ # q: bs,nh,l,hd, k: bs,nh,s,hd, v: bs,nh,s,hd
+ # same as: (((q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5).softmax(dim=-1) @ v)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) # built-in scale
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ sr_ratio=1,
+ linear=False,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ sr_ratio=sr_ratio,
+ linear=linear,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+
+ return x
+
+
+class OverlapPatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ assert max(patch_size) > stride, "Set larger patch_size than stride"
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.H, self.W = img_size[0] // stride, img_size[1] // stride
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2),
+ )
+ self.norm = nn.LayerNorm(embed_dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+
+ return x, H, W
+
+
+class PyramidVisionTransformerV2(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ num_stages=4,
+ linear=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+ self.num_stages = num_stages
+ self.embed_dims = embed_dims
+ self.use_checkpoint = use_checkpoint
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+
+ for i in range(num_stages):
+ patch_embed = OverlapPatchEmbed(
+ img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
+ patch_size=7 if i == 0 else 3,
+ stride=4 if i == 0 else 2,
+ in_chans=in_chans if i == 0 else embed_dims[i - 1],
+ embed_dim=embed_dims[i],
+ )
+
+ block = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dims[i],
+ num_heads=num_heads[i],
+ mlp_ratio=mlp_ratios[i],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[cur + j],
+ norm_layer=norm_layer,
+ sr_ratio=sr_ratios[i],
+ linear=linear,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ norm = norm_layer(embed_dims[i])
+ cur += depths[i]
+
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
+ setattr(self, f"block{i + 1}", block)
+ setattr(self, f"norm{i + 1}", norm)
+
+ # classification head
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def freeze_patch_emb(self):
+ self.patch_embed1.requires_grad = False
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed1", "pos_embed2", "pos_embed3", "pos_embed4", "cls_token"} # has pos_embed may be better
+
+ def extract_endpoints(self, x):
+ B = x.shape[0]
+ endpoints = dict()
+ for i in range(self.num_stages):
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
+ block = getattr(self, f"block{i + 1}")
+ norm = getattr(self, f"norm{i + 1}")
+ x, H, W = patch_embed(x)
+ for blk in block:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, H, W)
+ else:
+ x = blk(x, H, W)
+ x = norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ # print(i + 2, x.shape)
+ endpoints["reduction_{}".format(i + 2)] = x
+ return endpoints
+
+ def forward(self, x):
+ endpoints = self.extract_endpoints(x)
+ return endpoints
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2).transpose(1, 2)
+
+ return x
+
+
+def _conv_filter(state_dict, patch_size=16):
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ for k, v in state_dict.items():
+ if "patch_embed.proj.weight" in k:
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+ out_dict[k] = v
+
+ return out_dict
+
+
+def pvt_v2_eff_b0(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[32, 64, 160, 256],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+ return model
+
+
+def pvt_v2_eff_b1(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+ return model
+
+
+def pvt_v2_eff_b2(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+ return model
+
+
+def pvt_v2_eff_b3(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[3, 4, 18, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+
+ return model
+
+
+def pvt_v2_eff_b4(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[3, 8, 27, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+ return model
+
+
+def pvt_v2_eff_b5(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[3, 6, 40, 3],
+ sr_ratios=[8, 4, 2, 1],
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ if pretrained:
+ state_dict = load_state_dict_from_url(
+ "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth", progress=True
+ )
+ state_dict.pop("head.weight")
+ state_dict.pop("head.bias")
+ model.load_state_dict(state_dict)
+ return model
+
+
+def pvt_v2_eff_b2_li(pretrained=False, **kwargs):
+ model = PyramidVisionTransformerV2(
+ patch_size=4,
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ linear=True,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+
+ return model
diff --git a/methods/zoomnext/__init__.py b/methods/zoomnext/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/methods/zoomnext/layers.py b/methods/zoomnext/layers.py
new file mode 100644
index 0000000..c232ebe
--- /dev/null
+++ b/methods/zoomnext/layers.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .ops import ConvBNReLU, resize_to
+
+
+class SimpleASPP(nn.Module):
+ def __init__(self, in_dim, out_dim, dilation=3):
+ """A simple ASPP variant.
+
+ Args:
+ in_dim (int): Input channels.
+ out_dim (int): Output channels.
+ dilation (int, optional): Dilation of the convolution operation. Defaults to 3.
+ """
+ super().__init__()
+ self.conv1x1_1 = ConvBNReLU(in_dim, 2 * out_dim, 1)
+ self.conv1x1_2 = ConvBNReLU(out_dim, out_dim, 1)
+ self.conv3x3_1 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation)
+ self.conv3x3_2 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation)
+ self.conv3x3_3 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation)
+ self.fuse = nn.Sequential(ConvBNReLU(5 * out_dim, out_dim, 1), ConvBNReLU(out_dim, out_dim, 3, 1, 1))
+
+ def forward(self, x):
+ y = self.conv1x1_1(x)
+ y1, y5 = y.chunk(2, dim=1)
+
+ # dilation branch
+ y2 = self.conv3x3_1(y1)
+ y3 = self.conv3x3_2(y2)
+ y4 = self.conv3x3_3(y3)
+
+ # global branch
+ y0 = torch.mean(y5, dim=(2, 3), keepdim=True)
+ y0 = self.conv1x1_2(y0)
+ y0 = resize_to(y0, tgt_hw=x.shape[-2:])
+ return self.fuse(torch.cat([y0, y1, y2, y3, y4], dim=1))
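+
+# Note: conv1x1_1 expands the input to 2 * out_dim, split into y1 (fed to the chained dilated
+# 3x3 branch) and y5 (pooled for the global branch); fuse then merges the five out_dim-channel
+# maps [y0, y1, y2, y3, y4] back to out_dim.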
+
+
+class DifferenceAwareOps(nn.Module):
+ def __init__(self, num_frames):
+ super().__init__()
+ self.num_frames = num_frames
+
+ self.temperal_proj_norm = nn.LayerNorm(num_frames, elementwise_affine=False)
+ self.temperal_proj_kv = nn.Linear(num_frames, 2 * num_frames, bias=False)
+ self.temperal_proj = nn.Sequential(
+ nn.Conv2d(num_frames, num_frames, 3, 1, 1, bias=False),
+ nn.ReLU(True),
+ nn.Conv2d(num_frames, num_frames, 3, 1, 1, bias=False),
+ )
+ for t in self.parameters():
+ nn.init.zeros_(t)
+
+ def forward(self, x):
+ if self.num_frames == 1:
+ return x
+
+ unshifted_x_tmp = rearrange(x, "(b t) c h w -> b c h w t", t=self.num_frames)
+ B, C, H, W, T = unshifted_x_tmp.shape
+ shifted_x_tmp = torch.roll(unshifted_x_tmp, shifts=1, dims=-1)
+ diff_q = shifted_x_tmp - unshifted_x_tmp # B,C,H,W,T
+        diff_q = self.temperal_proj_norm(diff_q)  # normalization along the temporal dimension
+
+ # merge all channels
+ diff_k, diff_v = self.temperal_proj_kv(diff_q).chunk(2, dim=-1)
+ diff_qk = torch.einsum("bxhwt, byhwt -> bxyt", diff_q, diff_k) * (H * W) ** -0.5
+ temperal_diff = torch.einsum("bxyt, byhwt -> bxhwt", diff_qk.softmax(dim=2), diff_v)
+
+ temperal_diff = rearrange(temperal_diff, "b c h w t -> (b c) t h w")
+        shifted_x_tmp = self.temperal_proj(temperal_diff)  # combine information across time steps
+ shifted_x_tmp = rearrange(shifted_x_tmp, "(b c) t h w -> (b t) c h w", c=x.shape[1])
+ return x + shifted_x_tmp
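+
+# Note: torch.roll shifts the clip by one frame along the temporal axis, so diff_q holds
+# frame-to-frame feature differences; channel attention over these differences, followed by a
+# small temporal convolution, produces a correction added to x as a residual. All projection
+# weights are zero-initialized, so the block starts as an identity mapping (and it is a no-op
+# when num_frames == 1).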
+
+
+class RGPU(nn.Module):
+ def __init__(self, in_c, num_groups=6, hidden_dim=None, num_frames=1):
+ super().__init__()
+ self.num_groups = num_groups
+
+ hidden_dim = hidden_dim or in_c // 2
+ expand_dim = hidden_dim * num_groups
+ self.expand_conv = ConvBNReLU(in_c, expand_dim, 1)
+ self.gate_genator = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Conv2d(num_groups * hidden_dim, hidden_dim, 1),
+ nn.ReLU(True),
+ nn.Conv2d(hidden_dim, num_groups * hidden_dim, 1),
+ nn.Softmax(dim=1),
+ )
+
+ self.interact = nn.ModuleDict()
+ self.interact["0"] = ConvBNReLU(hidden_dim, 3 * hidden_dim, 3, 1, 1)
+ for group_id in range(1, num_groups - 1):
+ self.interact[str(group_id)] = ConvBNReLU(2 * hidden_dim, 3 * hidden_dim, 3, 1, 1)
+ self.interact[str(num_groups - 1)] = ConvBNReLU(2 * hidden_dim, 2 * hidden_dim, 3, 1, 1)
+
+ self.fuse = nn.Sequential(
+ DifferenceAwareOps(num_frames=num_frames),
+ ConvBNReLU(num_groups * hidden_dim, in_c, 3, 1, 1, act_name=None),
+ )
+ self.final_relu = nn.ReLU(True)
+
+ def forward(self, x):
+ xs = self.expand_conv(x).chunk(self.num_groups, dim=1)
+
+ outs = []
+ gates = []
+
+ group_id = 0
+ curr_x = xs[group_id]
+ branch_out = self.interact[str(group_id)](curr_x)
+ curr_out, curr_fork, curr_gate = branch_out.chunk(3, dim=1)
+ outs.append(curr_out)
+ gates.append(curr_gate)
+
+ for group_id in range(1, self.num_groups - 1):
+ curr_x = torch.cat([xs[group_id], curr_fork], dim=1)
+ branch_out = self.interact[str(group_id)](curr_x)
+ curr_out, curr_fork, curr_gate = branch_out.chunk(3, dim=1)
+ outs.append(curr_out)
+ gates.append(curr_gate)
+
+ group_id = self.num_groups - 1
+ curr_x = torch.cat([xs[group_id], curr_fork], dim=1)
+ branch_out = self.interact[str(group_id)](curr_x)
+ curr_out, curr_gate = branch_out.chunk(2, dim=1)
+ outs.append(curr_out)
+ gates.append(curr_gate)
+
+ out = torch.cat(outs, dim=1)
+ gate = self.gate_genator(torch.cat(gates, dim=1))
+ out = self.fuse(out * gate)
+ return self.final_relu(out + x)
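+
+# Note: each group's branch output is chunked into (out, fork, gate); the fork is concatenated
+# into the next group's input, the collected gates drive a channel-attention vector via
+# gate_genator, and the gated fusion is added back to the input as a residual.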
+
+
+class MHSIU(nn.Module):
+ def __init__(self, in_dim, num_groups=4):
+ super().__init__()
+ self.conv_l_pre = ConvBNReLU(in_dim, in_dim, 3, 1, 1)
+ self.conv_s_pre = ConvBNReLU(in_dim, in_dim, 3, 1, 1)
+
+ self.conv_l = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch
+ self.conv_m = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch
+ self.conv_s = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch
+
+ self.conv_lms = ConvBNReLU(3 * in_dim, 3 * in_dim, 1) # inter-branch
+ self.initial_merge = ConvBNReLU(3 * in_dim, 3 * in_dim, 1) # inter-branch
+
+ self.num_groups = num_groups
+ self.trans = nn.Sequential(
+ ConvBNReLU(3 * in_dim // num_groups, in_dim // num_groups, 1),
+ ConvBNReLU(in_dim // num_groups, in_dim // num_groups, 3, 1, 1),
+ nn.Conv2d(in_dim // num_groups, 3, 1),
+ nn.Softmax(dim=1),
+ )
+
+ def forward(self, l, m, s):
+ tgt_size = m.shape[2:]
+
+ l = self.conv_l_pre(l)
+ l = F.adaptive_max_pool2d(l, tgt_size) + F.adaptive_avg_pool2d(l, tgt_size)
+ s = self.conv_s_pre(s)
+ s = resize_to(s, tgt_hw=m.shape[2:])
+
+ l = self.conv_l(l)
+ m = self.conv_m(m)
+ s = self.conv_s(s)
+ lms = torch.cat([l, m, s], dim=1) # BT,3C,H,W
+
+ attn = self.conv_lms(lms) # BT,3C,H,W
+ attn = rearrange(attn, "bt (nb ng d) h w -> (bt ng) (nb d) h w", nb=3, ng=self.num_groups)
+ attn = self.trans(attn) # BTG,3,H,W
+ attn = attn.unsqueeze(dim=2) # BTG,3,1,H,W
+
+ x = self.initial_merge(lms)
+ x = rearrange(x, "bt (nb ng d) h w -> (bt ng) nb d h w", nb=3, ng=self.num_groups)
+ x = (attn * x).sum(dim=1)
+ x = rearrange(x, "(bt ng) d h w -> bt (ng d) h w", ng=self.num_groups)
+ return x
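+
+# Note: MHSIU resizes the large (l) and small (s) scale features to the middle (m) resolution,
+# predicts a per-group softmax over the three scales (the 3-channel attn map), and uses it to
+# weight and sum the grouped features before regrouping the channels.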
diff --git a/methods/zoomnext/ops.py b/methods/zoomnext/ops.py
new file mode 100644
index 0000000..d97cbf1
--- /dev/null
+++ b/methods/zoomnext/ops.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import to_2tuple
+
+
+def rescale_2x(x: torch.Tensor, scale_factor=2):
+ return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
+
+
+def resize_to(x: torch.Tensor, tgt_hw: tuple):
+ return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False)
+
+
+def global_avgpool(x: torch.Tensor):
+ return x.mean((-1, -2), keepdim=True)
+
+
+def _get_act_fn(act_name, inplace=True):
+ if act_name == "relu":
+ return nn.ReLU(inplace=inplace)
+ elif act_name == "leaklyrelu":
+ return nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
+ elif act_name == "gelu":
+ return nn.GELU()
+ elif act_name == "sigmoid":
+ return nn.Sigmoid()
+ else:
+ raise NotImplementedError
+
+
+class ConvBN(nn.Module):
+ def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, g=1, bias=True):
+ super(ConvBN, self).__init__()
+ self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+ self.bn = nn.BatchNorm2d(out_dim)
+
+ def forward(self, x):
+ return self.bn(self.conv(x))
+
+
+class CBR(nn.Module):
+ def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, bias=True):
+ super().__init__()
+ self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p, dilation=d, bias=bias)
+ self.bn = nn.BatchNorm2d(out_dim)
+ self.relu = nn.ReLU(True)
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(
+ self,
+ in_planes,
+ out_planes,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=False,
+ act_name="relu",
+ is_transposed=False,
+ ):
+ """
+ Convolution-BatchNormalization-ActivationLayer
+
+ :param in_planes:
+ :param out_planes:
+ :param kernel_size:
+ :param stride:
+ :param padding:
+ :param dilation:
+ :param groups:
+ :param bias:
+        :param act_name: None means that no activation layer is used.
+ :param is_transposed: True -> nn.ConvTranspose2d, False -> nn.Conv2d
+ """
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+
+ if is_transposed:
+ conv_module = nn.ConvTranspose2d
+ else:
+ conv_module = nn.Conv2d
+ self.add_module(
+ name="conv",
+ module=conv_module(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=to_2tuple(stride),
+ padding=to_2tuple(padding),
+ dilation=to_2tuple(dilation),
+ groups=groups,
+ bias=bias,
+ ),
+ )
+ self.add_module(name="bn", module=nn.BatchNorm2d(out_planes))
+ if act_name is not None:
+ self.add_module(name=act_name, module=_get_act_fn(act_name=act_name))
+
+
+class ConvGNReLU(nn.Sequential):
+ def __init__(
+ self,
+ in_planes,
+ out_planes,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ gn_groups=8,
+ bias=False,
+ act_name="relu",
+ inplace=True,
+ ):
+ """
+        Conv2d => GroupNorm [=> Activation]
+
+        Args:
+            in_planes: number of input channels
+            out_planes: number of output channels
+            kernel_size: kernel size of the inner convolution
+            stride: stride of the convolution
+            padding: padding of the convolution
+            dilation: dilation rate of the convolution
+            groups: number of convolution groups (must satisfy PyTorch's own constraints)
+            gn_groups: number of groups for GroupNorm, 8 by default
+            bias: whether to use a bias in the convolution, False by default
+            act_name: activation function to use, "relu" by default; None disables the activation layer
+            inplace: the inplace flag passed to the activation function
+ """
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+
+ self.add_module(
+ name="conv",
+ module=nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=to_2tuple(stride),
+ padding=to_2tuple(padding),
+ dilation=to_2tuple(dilation),
+ groups=groups,
+ bias=bias,
+ ),
+ )
+ self.add_module(name="gn", module=nn.GroupNorm(num_groups=gn_groups, num_channels=out_planes))
+ if act_name is not None:
+ self.add_module(name=act_name, module=_get_act_fn(act_name=act_name, inplace=inplace))
+
+
+class PixelNormalizer(nn.Module):
+ def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ """Divide pixel values by 255 = 2**8 - 1, subtract mean per channel and divide by std per channel.
+
+ Args:
+ mean (tuple, optional): the mean value. Defaults to (0.485, 0.456, 0.406).
+ std (tuple, optional): the std value. Defaults to (0.229, 0.224, 0.225).
+ """
+ super().__init__()
+ # self.norm = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ self.register_buffer(name="mean", tensor=torch.Tensor(mean).reshape(3, 1, 1))
+ self.register_buffer(name="std", tensor=torch.Tensor(std).reshape(3, 1, 1))
+
+ def __repr__(self):
+ return self.__class__.__name__ + f"(mean={self.mean.flatten()}, std={self.std.flatten()})"
+
+ def forward(self, x):
+ """normalize x by the mean and std values
+
+ Args:
+ x (torch.Tensor): input tensor
+
+ Returns:
+ torch.Tensor: output tensor
+
+ Albumentations:
+
+ ```
+ mean = np.array(mean, dtype=np.float32)
+ mean *= max_pixel_value
+ std = np.array(std, dtype=np.float32)
+ std *= max_pixel_value
+ denominator = np.reciprocal(std, dtype=np.float32)
+
+ img = img.astype(np.float32)
+ img -= mean
+ img *= denominator
+ ```
+ """
+ x = x.sub(self.mean)
+ x = x.div(self.std)
+ return x
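+
+# A minimal usage sketch of PixelNormalizer (illustrative only; the tensor values are hypothetical):
+#   normalizer = PixelNormalizer()
+#   image = torch.rand(2, 3, 64, 64)   # a batch already scaled to [0, 1]
+#   normalized = normalizer(image)     # (x - mean) / std per channel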
+
+
+class LayerNorm2d(nn.Module):
+ """
+ From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
+ Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
+ """
+
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
diff --git a/methods/zoomnext/zoomnext.py b/methods/zoomnext/zoomnext.py
new file mode 100644
index 0000000..f048665
--- /dev/null
+++ b/methods/zoomnext/zoomnext.py
@@ -0,0 +1,376 @@
+import abc
+import logging
+
+import numpy as np
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..backbone.efficientnet import EfficientNet
+from ..backbone.pvt_v2_eff import pvt_v2_eff_b2, pvt_v2_eff_b3, pvt_v2_eff_b4, pvt_v2_eff_b5
+from .layers import MHSIU, RGPU, SimpleASPP
+from .ops import ConvBNReLU, PixelNormalizer, resize_to
+
+LOGGER = logging.getLogger("main")
+
+
+class _ZoomNeXt_Base(nn.Module):
+ @staticmethod
+ def get_coef(iter_percentage=1, method="cos", milestones=(0, 1)):
+ min_point, max_point = min(milestones), max(milestones)
+ min_coef, max_coef = 0, 1
+
+ ual_coef = 1.0
+ if iter_percentage < min_point:
+ ual_coef = min_coef
+ elif iter_percentage > max_point:
+ ual_coef = max_coef
+ else:
+ if method == "linear":
+ ratio = (max_coef - min_coef) / (max_point - min_point)
+ ual_coef = ratio * (iter_percentage - min_point)
+ elif method == "cos":
+ perc = (iter_percentage - min_point) / (max_point - min_point)
+ normalized_coef = (1 - np.cos(perc * np.pi)) / 2
+ ual_coef = normalized_coef * (max_coef - min_coef) + min_coef
+ return ual_coef
+
+ @abc.abstractmethod
+ def body(self):
+ pass
+
+ def forward(self, data, iter_percentage=1, **kwargs):
+ logits = self.body(data=data)
+
+ if self.training:
+ mask = data["mask"]
+ prob = logits.sigmoid()
+
+ losses = []
+ loss_str = []
+
+ sod_loss = F.binary_cross_entropy_with_logits(input=logits, target=mask, reduction="mean")
+ losses.append(sod_loss)
+ loss_str.append(f"bce: {sod_loss.item():.5f}")
+
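+            # UAL term (a descriptive note): (1 - (2 * prob - 1) ** 2) peaks at ambiguous predictions
+            # (prob ~= 0.5) and vanishes at confident ones, so minimizing it pushes outputs toward 0/1;
+            # its weight ramps up from 0 to 1 with the cosine coefficient below.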
+ ual_coef = self.get_coef(iter_percentage=iter_percentage, method="cos", milestones=(0, 1))
+ ual_loss = ual_coef * (1 - (2 * prob - 1).abs().pow(2)).mean()
+ losses.append(ual_loss)
+ loss_str.append(f"powual_{ual_coef:.5f}: {ual_loss.item():.5f}")
+ return dict(vis=dict(sal=prob), loss=sum(losses), loss_str=" ".join(loss_str))
+ else:
+ return logits
+
+ def get_grouped_params(self):
+ param_groups = {"pretrained": [], "fixed": [], "retrained": []}
+ for name, param in self.named_parameters():
+ if name.startswith("encoder.patch_embed1."):
+ param.requires_grad = False
+ param_groups["fixed"].append(param)
+ elif name.startswith("encoder."):
+ param_groups["pretrained"].append(param)
+ else:
+ if "clip." in name:
+ param.requires_grad = False
+ param_groups["fixed"].append(param)
+ else:
+ param_groups["retrained"].append(param)
+ LOGGER.info(
+ f"Parameter Groups:{{"
+ f"Pretrained: {len(param_groups['pretrained'])}, "
+ f"Fixed: {len(param_groups['fixed'])}, "
+ f"ReTrained: {len(param_groups['retrained'])}}}"
+ )
+ return param_groups
+
+
+class RN50_ZoomNeXt(_ZoomNeXt_Base):
+ def __init__(self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6):
+ super().__init__()
+ self.encoder = timm.create_model(
+ model_name="resnet50", features_only=True, out_indices=range(5), pretrained=False
+ )
+ if pretrained:
+ params = torch.hub.load_state_dict_from_url(
+ url="https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.1/resnet50-timm.pth",
+ map_location="cpu",
+ )
+ self.encoder.load_state_dict(params, strict=False)
+
+ self.tra_5 = SimpleASPP(in_dim=2048, out_dim=mid_dim)
+ self.siu_5 = MHSIU(mid_dim, siu_groups)
+ self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_4 = ConvBNReLU(1024, mid_dim, 3, 1, 1)
+ self.siu_4 = MHSIU(mid_dim, siu_groups)
+ self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_3 = ConvBNReLU(512, mid_dim, 3, 1, 1)
+ self.siu_3 = MHSIU(mid_dim, siu_groups)
+ self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_2 = ConvBNReLU(256, mid_dim, 3, 1, 1)
+ self.siu_2 = MHSIU(mid_dim, siu_groups)
+ self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_1 = ConvBNReLU(64, mid_dim, 3, 1, 1)
+ self.siu_1 = MHSIU(mid_dim, siu_groups)
+ self.hmu_1 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.normalizer = PixelNormalizer() if input_norm else nn.Identity()
+ self.predictor = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ ConvBNReLU(64, 32, 3, 1, 1),
+ nn.Conv2d(32, 1, 1),
+ )
+
+ def normalize_encoder(self, x):
+ x = self.normalizer(x)
+ c1, c2, c3, c4, c5 = self.encoder(x)
+ return c1, c2, c3, c4, c5
+
+ def body(self, data):
+ l_trans_feats = self.normalize_encoder(data["image_l"])
+ m_trans_feats = self.normalize_encoder(data["image_m"])
+ s_trans_feats = self.normalize_encoder(data["image_s"])
+
+ l, m, s = (
+ self.tra_5(l_trans_feats[4]),
+ self.tra_5(m_trans_feats[4]),
+ self.tra_5(s_trans_feats[4]),
+ )
+ lms = self.siu_5(l=l, m=m, s=s)
+ x = self.hmu_5(lms)
+
+ l, m, s = (
+ self.tra_4(l_trans_feats[3]),
+ self.tra_4(m_trans_feats[3]),
+ self.tra_4(s_trans_feats[3]),
+ )
+ lms = self.siu_4(l=l, m=m, s=s)
+ x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = (
+ self.tra_3(l_trans_feats[2]),
+ self.tra_3(m_trans_feats[2]),
+ self.tra_3(s_trans_feats[2]),
+ )
+ lms = self.siu_3(l=l, m=m, s=s)
+ x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = (
+ self.tra_2(l_trans_feats[1]),
+ self.tra_2(m_trans_feats[1]),
+ self.tra_2(s_trans_feats[1]),
+ )
+ lms = self.siu_2(l=l, m=m, s=s)
+ x = self.hmu_2(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = (
+ self.tra_1(l_trans_feats[0]),
+ self.tra_1(m_trans_feats[0]),
+ self.tra_1(s_trans_feats[0]),
+ )
+ lms = self.siu_1(l=l, m=m, s=s)
+ x = self.hmu_1(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ return self.predictor(x)
+
+
+class PvtV2B2_ZoomNeXt(_ZoomNeXt_Base):
+ def __init__(
+ self,
+ pretrained=True,
+ num_frames=1,
+ input_norm=True,
+ mid_dim=64,
+ siu_groups=4,
+ hmu_groups=6,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.set_backbone(pretrained=pretrained, use_checkpoint=use_checkpoint)
+
+ self.embed_dims = self.encoder.embed_dims
+ self.tra_5 = SimpleASPP(self.embed_dims[3], out_dim=mid_dim)
+ self.siu_5 = MHSIU(mid_dim, siu_groups)
+ self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_4 = ConvBNReLU(self.embed_dims[2], mid_dim, 3, 1, 1)
+ self.siu_4 = MHSIU(mid_dim, siu_groups)
+ self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_3 = ConvBNReLU(self.embed_dims[1], mid_dim, 3, 1, 1)
+ self.siu_3 = MHSIU(mid_dim, siu_groups)
+ self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_2 = ConvBNReLU(self.embed_dims[0], mid_dim, 3, 1, 1)
+ self.siu_2 = MHSIU(mid_dim, siu_groups)
+ self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_1 = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), ConvBNReLU(64, mid_dim, 3, 1, 1)
+ )
+
+ self.normalizer = PixelNormalizer() if input_norm else nn.Identity()
+ self.predictor = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ ConvBNReLU(64, 32, 3, 1, 1),
+ nn.Conv2d(32, 1, 1),
+ )
+
+ def set_backbone(self, pretrained: bool, use_checkpoint: bool):
+ self.encoder = pvt_v2_eff_b2(pretrained=pretrained, use_checkpoint=use_checkpoint)
+
+ def normalize_encoder(self, x):
+ x = self.normalizer(x)
+ features = self.encoder(x)
+ c2 = features["reduction_2"]
+ c3 = features["reduction_3"]
+ c4 = features["reduction_4"]
+ c5 = features["reduction_5"]
+ return c2, c3, c4, c5
+
+ def body(self, data):
+ l_trans_feats = self.normalize_encoder(data["image_l"])
+ m_trans_feats = self.normalize_encoder(data["image_m"])
+ s_trans_feats = self.normalize_encoder(data["image_s"])
+
+ l, m, s = self.tra_5(l_trans_feats[3]), self.tra_5(m_trans_feats[3]), self.tra_5(s_trans_feats[3])
+ lms = self.siu_5(l=l, m=m, s=s)
+ x = self.hmu_5(lms)
+
+ l, m, s = self.tra_4(l_trans_feats[2]), self.tra_4(m_trans_feats[2]), self.tra_4(s_trans_feats[2])
+ lms = self.siu_4(l=l, m=m, s=s)
+ x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = self.tra_3(l_trans_feats[1]), self.tra_3(m_trans_feats[1]), self.tra_3(s_trans_feats[1])
+ lms = self.siu_3(l=l, m=m, s=s)
+ x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = self.tra_2(l_trans_feats[0]), self.tra_2(m_trans_feats[0]), self.tra_2(s_trans_feats[0])
+ lms = self.siu_2(l=l, m=m, s=s)
+ x = self.hmu_2(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ x = self.tra_1(x)
+ return self.predictor(x)
+
+
+class PvtV2B3_ZoomNeXt(PvtV2B2_ZoomNeXt):
+ def set_backbone(self, pretrained: bool, use_checkpoint: bool):
+ self.encoder = pvt_v2_eff_b3(pretrained=pretrained, use_checkpoint=use_checkpoint)
+
+
+class PvtV2B4_ZoomNeXt(PvtV2B2_ZoomNeXt):
+ def set_backbone(self, pretrained: bool, use_checkpoint: bool):
+ self.encoder = pvt_v2_eff_b4(pretrained=pretrained, use_checkpoint=use_checkpoint)
+
+
+class PvtV2B5_ZoomNeXt(PvtV2B2_ZoomNeXt):
+ def set_backbone(self, pretrained: bool, use_checkpoint: bool):
+ self.encoder = pvt_v2_eff_b5(pretrained=pretrained, use_checkpoint=use_checkpoint)
+
+
+class videoPvtV2B5_ZoomNeXt(PvtV2B5_ZoomNeXt):
+ def get_grouped_params(self):
+ param_groups = {"pretrained": [], "fixed": [], "retrained": []}
+ for name, param in self.named_parameters():
+ if name.startswith("encoder.patch_embed1."):
+ param.requires_grad = False
+ param_groups["fixed"].append(param)
+ elif name.startswith("encoder."):
+ param_groups["pretrained"].append(param)
+ else:
+ if "temperal_proj" in name:
+ param_groups["retrained"].append(param)
+ else:
+ param_groups["pretrained"].append(param)
+
+ LOGGER.info(
+ f"Parameter Groups:{{"
+ f"Pretrained: {len(param_groups['pretrained'])}, "
+ f"Fixed: {len(param_groups['fixed'])}, "
+ f"ReTrained: {len(param_groups['retrained'])}}}"
+ )
+ return param_groups
+
+
+class EffB1_ZoomNeXt(_ZoomNeXt_Base):
+ def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6):
+ super().__init__()
+ self.set_backbone(pretrained)
+
+ self.tra_5 = SimpleASPP(self.embed_dims[4], out_dim=mid_dim)
+ self.siu_5 = MHSIU(mid_dim, siu_groups)
+ self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_4 = ConvBNReLU(self.embed_dims[3], mid_dim, 3, 1, 1)
+ self.siu_4 = MHSIU(mid_dim, siu_groups)
+ self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_3 = ConvBNReLU(self.embed_dims[2], mid_dim, 3, 1, 1)
+ self.siu_3 = MHSIU(mid_dim, siu_groups)
+ self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_2 = ConvBNReLU(self.embed_dims[1], mid_dim, 3, 1, 1)
+ self.siu_2 = MHSIU(mid_dim, siu_groups)
+ self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.tra_1 = ConvBNReLU(self.embed_dims[0], mid_dim, 3, 1, 1)
+ self.siu_1 = MHSIU(mid_dim, siu_groups)
+ self.hmu_1 = RGPU(mid_dim, hmu_groups, num_frames=num_frames)
+
+ self.normalizer = PixelNormalizer() if input_norm else nn.Identity()
+ self.predictor = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ ConvBNReLU(64, 32, 3, 1, 1),
+ nn.Conv2d(32, 1, 1),
+ )
+
+ def set_backbone(self, pretrained):
+ self.encoder = EfficientNet.from_pretrained("efficientnet-b1", pretrained=pretrained)
+ self.embed_dims = [16, 24, 40, 112, 320]
+
+ def normalize_encoder(self, x):
+ x = self.normalizer(x)
+ features = self.encoder.extract_endpoints(x)
+ c1 = features["reduction_1"]
+ c2 = features["reduction_2"]
+ c3 = features["reduction_3"]
+ c4 = features["reduction_4"]
+ c5 = features["reduction_5"]
+ return c1, c2, c3, c4, c5
+
+ def body(self, data):
+ l_trans_feats = self.normalize_encoder(data["image_l"])
+ m_trans_feats = self.normalize_encoder(data["image_m"])
+ s_trans_feats = self.normalize_encoder(data["image_s"])
+
+ l, m, s = self.tra_5(l_trans_feats[4]), self.tra_5(m_trans_feats[4]), self.tra_5(s_trans_feats[4])
+ lms = self.siu_5(l=l, m=m, s=s)
+ x = self.hmu_5(lms)
+
+ l, m, s = self.tra_4(l_trans_feats[3]), self.tra_4(m_trans_feats[3]), self.tra_4(s_trans_feats[3])
+ lms = self.siu_4(l=l, m=m, s=s)
+ x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = self.tra_3(l_trans_feats[2]), self.tra_3(m_trans_feats[2]), self.tra_3(s_trans_feats[2])
+ lms = self.siu_3(l=l, m=m, s=s)
+ x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = self.tra_2(l_trans_feats[1]), self.tra_2(m_trans_feats[1]), self.tra_2(s_trans_feats[1])
+ lms = self.siu_2(l=l, m=m, s=s)
+ x = self.hmu_2(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ l, m, s = self.tra_1(l_trans_feats[0]), self.tra_1(m_trans_feats[0]), self.tra_1(s_trans_feats[0])
+ lms = self.siu_1(l=l, m=m, s=s)
+ x = self.hmu_1(lms + resize_to(x, tgt_hw=lms.shape[-2:]))
+
+ return self.predictor(x)
+
+
+class EffB4_ZoomNeXt(EffB1_ZoomNeXt):
+ def set_backbone(self, pretrained):
+ self.encoder = EfficientNet.from_pretrained("efficientnet-b4", pretrained=pretrained)
+ self.embed_dims = [24, 32, 56, 160, 448]
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..536c539
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,72 @@
+# https://github.com/LongTengDao/TOML/
+
+[tool.isort]
+# https://pycqa.github.io/isort/docs/configuration/options/
+profile = "black"
+multi_line_output = 3
+filter_files = true
+supported_extensions = "py"
+
+[tool.black]
+line-length = 119
+include = '\.pyi?$'
+exclude = '''
+/(
+ \.eggs
+ | \.git
+ | \.idea
+ | \.vscode
+ | \.hg
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ | _build
+ | buck-out
+ | build
+ | dist
+ | output
+)/
+'''
+
+[tool.ruff]
+# Same as Black.
+line-length = 119
+indent-width = 4
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".ipynb_checkpoints",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pyenv",
+ ".pytest_cache",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ ".vscode",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "site-packages",
+ "venv",
+]
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..700894b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+# Automatically generated by https://github.com/damnever/pigar.
+
+adjustText==0.8
+albumentations==1.3.1
+colorlog==6.8.0
+einops==0.7.0
+matplotlib==3.8.2
+mmengine==0.10.2
+numpy==1.26.2
+opencv-python-headless==4.8.1.78
+pysodmetrics==1.4.2
+PyYAML==6.0
+scipy==1.11.4
+timm==0.9.12
+tqdm==4.66.1
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/io/__init__.py b/utils/io/__init__.py
new file mode 100644
index 0000000..7b33111
--- /dev/null
+++ b/utils/io/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/5/17
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+from .image import read_color_array, read_gray_array
+from .params import load_weight, save_weight
diff --git a/utils/io/image.py b/utils/io/image.py
new file mode 100644
index 0000000..2332ab4
--- /dev/null
+++ b/utils/io/image.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/5/17
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+import cv2
+import numpy as np
+
+from utils.ops import minmax
+
+
+def read_gray_array(path, div_255=False, to_normalize=False, thr=-1, dtype=np.float32) -> np.ndarray:
+ """
+ 1. read the binary image with the suffix `.jpg` or `.png`
+ into a grayscale ndarray
+ 2. (to_normalize=True) rescale the ndarray to [0, 1]
+ 3. (thr >= 0) binarize the ndarray with `thr`
+ 4. return a gray ndarray (np.float32)
+ """
+ assert path.endswith(".jpg") or path.endswith(".png"), path
+ assert not div_255 or not to_normalize, path
+ gray_array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+ assert gray_array is not None, f"Image Not Found: {path}"
+
+ if div_255:
+ gray_array = gray_array / 255
+
+ if to_normalize:
+ gray_array = minmax(gray_array, up_bound=255)
+
+ if thr >= 0:
+ gray_array = gray_array > thr
+
+ return gray_array.astype(dtype)
+
+
+def read_color_array(path: str):
+ assert path.endswith(".jpg") or path.endswith(".png")
+ bgr_array = cv2.imread(path, cv2.IMREAD_COLOR)
+ assert bgr_array is not None, f"Image Not Found: {path}"
+ rgb_array = cv2.cvtColor(bgr_array, cv2.COLOR_BGR2RGB)
+ return rgb_array
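+
+
+# A minimal usage sketch (the paths below are hypothetical):
+#   mask = read_gray_array("mask.png", to_normalize=True, thr=0.5)  # float32 array in {0.0, 1.0}
+#   image = read_color_array("image.jpg")                           # uint8 RGB array of shape (H, W, 3)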
diff --git a/utils/io/params.py b/utils/io/params.py
new file mode 100644
index 0000000..65aabc1
--- /dev/null
+++ b/utils/io/params.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+import os
+
+import torch
+
+
+def save_weight(save_path, model):
+ print(f"Saving weight '{save_path}'")
+ if isinstance(model, dict):
+ model_state = model
+ else:
+ model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
+ torch.save(model_state, save_path)
+ print(f"Saved weight '{save_path}' " f"(only contain the net's weight)")
+
+
+def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
+ assert os.path.exists(load_path), load_path
+ model_params = model.state_dict()
+ for k, v in torch.load(load_path, map_location="cpu").items():
+        if k.startswith("module."):  # strip the prefix added by DataParallel/DistributedDataParallel
+            k = k[7:]
+ if skip_unmatched_shape and v.shape != model_params[k].shape:
+ continue
+ model_params[k] = v
+ model.load_state_dict(model_params, strict=strict)
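+
+
+# A minimal save/load round-trip sketch (illustrative; `model` is any nn.Module):
+#   save_weight("checkpoint.pth", model)
+#   load_weight("checkpoint.pth", model, strict=True)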
diff --git a/utils/ops/__init__.py b/utils/ops/__init__.py
new file mode 100644
index 0000000..491a5a6
--- /dev/null
+++ b/utils/ops/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+
+from .array_ops import *
+from .tensor_ops import *
diff --git a/utils/ops/array_ops.py b/utils/ops/array_ops.py
new file mode 100644
index 0000000..1ffc1db
--- /dev/null
+++ b/utils/ops/array_ops.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+import os
+
+import cv2
+import numpy as np
+
+
+def minmax(data_array: np.ndarray, up_bound: float = None) -> np.ndarray:
+ """
+ ::
+
+ data_array = (data_array / up_bound)
+ if min_value != max_value:
+ data_array = (data_array - min_value) / (max_value - min_value)
+
+ :param data_array:
+    :param up_bound: if not None, data_array will be divided by it before the min-max operation.
+ :return:
+ """
+ if up_bound is not None:
+ data_array = data_array / up_bound
+ max_value = data_array.max()
+ min_value = data_array.min()
+ if max_value != min_value:
+ data_array = (data_array - min_value) / (max_value - min_value)
+ return data_array
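+
+# For example (an illustrative note): minmax(np.array([0.0, 128.0, 255.0]), up_bound=255) returns
+# approximately array([0.0, 0.502, 1.0]); the division by 255 happens first, and the subsequent
+# min-max rescaling is a no-op here because min=0 and max=1 already.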
+
+
+def clip_to_normalize(data_array: np.ndarray, clip_range: tuple = None) -> np.ndarray:
+ clip_range = sorted(clip_range)
+ if len(clip_range) == 3:
+ clip_min, clip_mid, clip_max = clip_range
+ assert 0 <= clip_min < clip_mid < clip_max <= 1, clip_range
+ lower_array = data_array[data_array < clip_mid]
+ higher_array = data_array[data_array > clip_mid]
+ if lower_array.size > 0:
+ lower_array = np.clip(lower_array, a_min=clip_min, a_max=1)
+ max_lower = lower_array.max()
+ lower_array = minmax(lower_array) * max_lower
+ data_array[data_array < clip_mid] = lower_array
+ if higher_array.size > 0:
+ higher_array = np.clip(higher_array, a_min=0, a_max=clip_max)
+ min_lower = higher_array.min()
+ higher_array = minmax(higher_array) * (1 - min_lower) + min_lower
+ data_array[data_array > clip_mid] = higher_array
+ elif len(clip_range) == 2:
+ clip_min, clip_max = clip_range
+ assert 0 <= clip_min < clip_max <= 1, clip_range
+ if clip_min != 0 and clip_max != 1:
+ data_array = np.clip(data_array, a_min=clip_min, a_max=clip_max)
+ data_array = minmax(data_array)
+ elif clip_range is None:
+ data_array = minmax(data_array)
+ else:
+ raise NotImplementedError
+ return data_array
+
+
+def save_array_as_image(data_array: np.ndarray, save_name: str, save_dir: str, to_minmax: bool = False):
+ """
+ save the ndarray as a image
+
+ Args:
+ data_array: np.float32 the max value is less than or equal to 1
+ save_name: with special suffix
+ save_dir: the dirname of the image path
+ to_minmax: minmax the array
+ """
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ save_path = os.path.join(save_dir, save_name)
+ if data_array.dtype != np.uint8:
+ if data_array.max() > 1:
+            raise Exception("the values of data_array should be within [0, 1]")
+ data_array = (data_array * 255).astype(np.uint8)
+ if to_minmax:
+ data_array = minmax(data_array, up_bound=255)
+ data_array = (data_array * 255).astype(np.uint8)
+ cv2.imwrite(save_path, data_array)
+
+
+def resize(image_array: np.ndarray, height, width, interpolation=cv2.INTER_LINEAR):
+ h, w = image_array.shape[:2]
+ if h == height and w == width:
+ return image_array
+
+ resized_image_array = cv2.resize(image_array, dsize=(width, height), interpolation=interpolation)
+ return resized_image_array
+
+
+def ms_resize(img, scales, base_h=None, base_w=None, interpolation=cv2.INTER_LINEAR):
+ assert isinstance(scales, (list, tuple))
+ if base_h is None:
+ base_h = img.shape[0]
+ if base_w is None:
+ base_w = img.shape[1]
+ return [resize(img, height=int(base_h * s), width=int(base_w * s), interpolation=interpolation) for s in scales]
diff --git a/utils/ops/tensor_ops.py b/utils/ops/tensor_ops.py
new file mode 100644
index 0000000..18aed62
--- /dev/null
+++ b/utils/ops/tensor_ops.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+import torch
+import torch.nn.functional as F
+
+
+def rescale_2x(x: torch.Tensor, scale_factor=2):
+ return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
+
+
+def resize_to(x: torch.Tensor, tgt_hw: tuple):
+ return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False)
+
+
+def clip_grad(params, mode, clip_cfg: dict):
+ if mode == "norm":
+ if "max_norm" not in clip_cfg:
+ raise ValueError("`clip_cfg` must contain `max_norm`.")
+ torch.nn.utils.clip_grad_norm_(
+ params,
+ max_norm=clip_cfg.get("max_norm"),
+ norm_type=clip_cfg.get("norm_type", 2.0),
+ )
+ elif mode == "value":
+ if "clip_value" not in clip_cfg:
+ raise ValueError("`clip_cfg` must contain `clip_value`.")
+ torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
+ else:
+ raise NotImplementedError
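+
+
+# A minimal usage sketch (hypothetical values), typically called after loss.backward():
+#   clip_grad(model.parameters(), mode="norm", clip_cfg={"max_norm": 1.0, "norm_type": 2.0})
+#   clip_grad(model.parameters(), mode="value", clip_cfg={"clip_value": 0.5})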
diff --git a/utils/pipeline/__init__.py b/utils/pipeline/__init__.py
new file mode 100644
index 0000000..f6eb4bb
--- /dev/null
+++ b/utils/pipeline/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/5/31
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+from .optimizer import construct_optimizer
+from .scaler import Scaler
+from .scheduler import Scheduler
diff --git a/utils/pipeline/optimizer.py b/utils/pipeline/optimizer.py
new file mode 100644
index 0000000..3fe96fa
--- /dev/null
+++ b/utils/pipeline/optimizer.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+import types
+
+import torchvision.models
+from torch import nn
+from torch.optim import SGD, Adam, AdamW
+
+
+def get_optimizer(mode, params, initial_lr, optim_cfg):
+ if mode == "sgd":
+ optimizer = SGD(
+ params=params,
+ lr=initial_lr,
+ momentum=optim_cfg["momentum"],
+ weight_decay=optim_cfg["weight_decay"],
+ nesterov=optim_cfg.get("nesterov", False),
+ )
+ elif mode == "adamw":
+ optimizer = AdamW(
+ params=params,
+ lr=initial_lr,
+ betas=optim_cfg.get("betas", (0.9, 0.999)),
+ weight_decay=optim_cfg.get("weight_decay", 0),
+ amsgrad=optim_cfg.get("amsgrad", False),
+ )
+ elif mode == "adam":
+ optimizer = Adam(
+ params=params,
+ lr=initial_lr,
+ betas=optim_cfg.get("betas", (0.9, 0.999)),
+ weight_decay=optim_cfg.get("weight_decay", 0),
+ amsgrad=optim_cfg.get("amsgrad", False),
+ )
+ else:
+ raise NotImplementedError(mode)
+ return optimizer
+
+
+def group_params(model: nn.Module, group_mode: str, initial_lr: float, optim_cfg: dict):
+ if group_mode == "yolov5":
+ """
+ norm, weight, bias = [], [], [] # optimizer parameter groups
+ for k, v in model.named_modules():
+ if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+ bias.append(v.bias) # biases
+ if isinstance(v, nn.BatchNorm2d):
+ norm.append(v.weight) # no decay
+ elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+ weight.append(v.weight) # apply decay
+
+ if opt.adam:
+ optimizer = optim.Adam(norm, lr=hyp["lr0"], betas=(hyp["momentum"], 0.999)) # adjust beta1 to momentum
+ else:
+ optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)
+
+ optimizer.add_param_group({"params": weight, "weight_decay": hyp["weight_decay"]}) # add weight with weight_decay
+ optimizer.add_param_group({"params": bias}) # add bias (biases)
+ """
+ norm, weight, bias = [], [], [] # optimizer parameter groups
+ for k, v in model.named_modules():
+ if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+ bias.append(v.bias) # conv bias and bn bias
+ if isinstance(v, nn.BatchNorm2d):
+ norm.append(v.weight) # bn weight
+ elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+ weight.append(v.weight) # conv weight
+ params = [
+ {"params": filter(lambda p: p.requires_grad, bias), "weight_decay": 0.0},
+ {"params": filter(lambda p: p.requires_grad, norm), "weight_decay": 0.0},
+ {"params": filter(lambda p: p.requires_grad, weight)},
+ ]
+ elif group_mode == "r3":
+ params = [
+            # Do not apply weight decay to the bias parameters. Weight decay reduces overfitting by
+            # constraining the layer parameters (L2 regularization encourages smoother weights), so
+            # here it is only applied to the non-bias parameters.
+ {
+ "params": [
+ param for name, param in model.named_parameters() if name[-4:] == "bias" and param.requires_grad
+ ],
+ "lr": 2 * initial_lr,
+ "weight_decay": 0,
+ },
+ {
+ "params": [
+ param for name, param in model.named_parameters() if name[-4:] != "bias" and param.requires_grad
+ ],
+ "lr": initial_lr,
+ "weight_decay": optim_cfg["weight_decay"],
+ },
+ ]
+ elif group_mode == "all":
+ params = model.parameters()
+ elif group_mode == "finetune":
+ if hasattr(model, "module"):
+ model = model.module
+ assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
+ params_groups = model.get_grouped_params()
+ params = [
+ {
+ "params": filter(lambda p: p.requires_grad, params_groups["pretrained"]),
+ "lr": optim_cfg.get("diff_factor", 0.1) * initial_lr,
+ },
+ {
+ "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
+ "lr": initial_lr,
+ },
+ ]
+ elif group_mode == "finetune2":
+ if hasattr(model, "module"):
+ model = model.module
+ assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
+ params_groups = model.get_grouped_params()
+ params = [
+ {
+ "params": filter(lambda p: p.requires_grad, params_groups["pretrained_backbone"]),
+ "lr": 0.1 * initial_lr,
+ },
+ {
+ "params": filter(lambda p: p.requires_grad, params_groups["pretrained_siamese"]),
+ "lr": 0.5 * initial_lr,
+ },
+ {
+ "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
+ "lr": initial_lr,
+ },
+ ]
+ else:
+ raise NotImplementedError
+ return params
+
+
+def construct_optimizer(model, initial_lr, mode, group_mode, cfg):
+ params = group_params(model, group_mode=group_mode, initial_lr=initial_lr, optim_cfg=cfg)
+ optimizer = get_optimizer(mode=mode, params=params, initial_lr=initial_lr, optim_cfg=cfg)
+ optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
+ optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)
+ return optimizer
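+
+
+# A minimal usage sketch (the cfg values below are hypothetical):
+#   optimizer = construct_optimizer(
+#       model, initial_lr=1e-4, mode="adamw", group_mode="finetune",
+#       cfg=dict(weight_decay=1e-4, diff_factor=0.1),
+#   )
+#   print(optimizer.lr_string())  # e.g. "1.000e-05,1.000e-04" (one entry per param group)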
+
+
+def get_lr_groups(self):
+ return [group["lr"] for group in self.param_groups]
+
+
+def get_lr_strings(self):
+ return ",".join([f"{group['lr']:.3e}" for group in self.param_groups])
+
+
+if __name__ == "__main__":
+ model = torchvision.models.vgg11_bn()
+ norm, weight, bias = [], [], [] # optimizer parameter groups
+ for k, v in model.named_modules():
+ if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+ bias.append(v.bias) # biases
+ if isinstance(v, nn.BatchNorm2d):
+ norm.append(v.weight) # no decay
+ elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+ weight.append(v.weight) # apply decay
+
+ optimizer = Adam(norm, lr=0.001, betas=(0.98, 0.999)) # adjust beta1 to momentum
+ # optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)
+
+ optimizer.add_param_group({"params": weight, "weight_decay": 1e-4}) # add weight with weight_decay
+ optimizer.add_param_group({"params": bias}) # add bias (biases)
+
+ print(optimizer)
diff --git a/utils/pipeline/scaler.py b/utils/pipeline/scaler.py
new file mode 100644
index 0000000..9a27cde
--- /dev/null
+++ b/utils/pipeline/scaler.py
@@ -0,0 +1,59 @@
+from functools import partial
+from itertools import chain
+
+from torch.cuda.amp import GradScaler, autocast
+
+from .. import ops
+
+
+class Scaler:
+ def __init__(
+ self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
+ ) -> None:
+ self.optimizer = optimizer
+ self.set_to_none = set_to_none
+ self.autocast = autocast(enabled=use_fp16)
+ self.scaler = GradScaler(enabled=use_fp16)
+
+ if clip_grad:
+ self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
+ else:
+ self.grad_clip_ops = None
+
+ def calculate_grad(self, loss):
+ self.scaler.scale(loss).backward()
+ if self.grad_clip_ops is not None:
+ self.scaler.unscale_(self.optimizer)
+ self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))
+
+ def update_grad(self):
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ self.optimizer.zero_grad(set_to_none=self.set_to_none)
+
+ def state_dict(self):
+ r"""
+ Returns the state of the scaler as a :class:`dict`. It contains five entries:
+
+ * ``"scale"`` - a Python float containing the current scale
+ * ``"growth_factor"`` - a Python float containing the current growth factor
+ * ``"backoff_factor"`` - a Python float containing the current backoff factor
+ * ``"growth_interval"`` - a Python int containing the current growth interval
+ * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
+
+ If this instance is not enabled, returns an empty dict.
+
+ .. note::
+ If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
+ should be called after :meth:`update`.
+ """
+ return self.scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ r"""
+ Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
+
+ Args:
+ state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
+ """
+ self.scaler.load_state_dict(state_dict)
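+
+
+# A minimal training-step sketch using Scaler (illustrative; `model`, `data`, and `optimizer` are assumed):
+#   scaler = Scaler(optimizer, use_fp16=True, clip_grad=True, clip_mode="norm", clip_cfg={"max_norm": 1.0})
+#   with scaler.autocast:
+#       outputs = model(data)
+#       loss = outputs["loss"]
+#   scaler.calculate_grad(loss)
+#   scaler.update_grad()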
diff --git a/utils/pipeline/scheduler.py b/utils/pipeline/scheduler.py
new file mode 100644
index 0000000..663ece1
--- /dev/null
+++ b/utils/pipeline/scheduler.py
@@ -0,0 +1,410 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+import copy
+import math
+import os.path
+import warnings
+from bisect import bisect_right
+
+import matplotlib
+import numpy as np
+import torch.optim
+from adjustText import adjust_text
+
+matplotlib.use("Agg")
+from matplotlib import pyplot as plt
+
+# helper function ----------------------------------------------------------------------
+
+
+def linear_increase(low_bound, up_bound, percentage):
+ """low_bound + [0, 1] * (up_bound - low_bound)"""
+ assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
+ return low_bound + (up_bound - low_bound) * percentage
+
+
+def cos_anneal(low_bound, up_bound, percentage):
+ assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
+ cos_percentage = (1 + math.cos(math.pi * percentage)) / 2.0
+ return linear_increase(low_bound, up_bound, percentage=cos_percentage)
+
+
+def poly_anneal(low_bound, up_bound, percentage, lr_decay):
+ assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
+ poly_percentage = pow((1 - percentage), lr_decay)
+ return linear_increase(low_bound, up_bound, percentage=poly_percentage)
+
+
+def linear_anneal(low_bound, up_bound, percentage):
+ assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
+ return linear_increase(low_bound, up_bound, percentage=1 - percentage)
+
+
+# coefficient function ----------------------------------------------------------------------
+
+
+def get_f3_coef_func(num_iters):
+ """
+ F3Net
+
+ :param num_iters: The number of iterations for the total process.
+ :return:
+ """
+
+ def get_f3_coef(curr_idx):
+ assert 0 <= curr_idx <= num_iters
+ return 1 - abs((curr_idx + 1) / (num_iters + 1) * 2 - 1)
+
+ return get_f3_coef
+
+
+def get_step_coef_func(gamma, milestones):
+ """
+ lr = baselr * gamma ** 0 if curr_idx < milestones[0]
+ lr = baselr * gamma ** 1 if milestones[0] <= epoch < milestones[1]
+ ...
+
+ :param gamma:
+ :param milestones:
+ :return: The function for generating the coefficient.
+ """
+ if isinstance(milestones, (tuple, list)):
+ milestones = list(sorted(milestones))
+ return lambda curr_idx: gamma ** bisect_right(milestones, curr_idx)
+ elif isinstance(milestones, int):
+ return lambda curr_idx: gamma ** ((curr_idx + 1) // milestones)
+ else:
+ raise ValueError(f"milestones only can be list/tuple/int, but now it is {type(milestones)}")
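+
+# For example (an illustrative note): get_step_coef_func(gamma=0.1, milestones=[30, 60]) yields a
+# coefficient of 1.0 for curr_idx < 30, 0.1 for 30 <= curr_idx < 60, and 0.01 afterwards.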
+
+
+def get_cos_coef_func(half_cycle, min_coef, max_coef=1):
+ """
+ :param half_cycle: The number of iterations in a half cycle.
+ :param min_coef: The minimum coefficient of the learning rate.
+ :param max_coef: The maximum coefficient of the learning rate.
+ :return: The function for generating the coefficient.
+ """
+
+ def get_cos_coef(curr_idx):
+ recomputed_idx = curr_idx % (half_cycle + 1)
+ # recomputed \in [0, half_cycle]
+ return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle)
+
+ return get_cos_coef
+
+
+def get_fatcos_coef_func(start_iter, half_cycle, min_coef, max_coef=1):
+ """
+    :param start_iter: The number of iterations to skip before the cosine annealing starts.
+    :param half_cycle: The number of iterations in a half cycle.
+ :param min_coef: The minimum coefficient of the learning rate.
+ :param max_coef: The maximum coefficient of the learning rate.
+ :return: The function for generating the coefficient.
+ """
+
+ def get_cos_coef(curr_idx):
+ curr_idx = max(0, curr_idx - start_iter)
+ recomputed_idx = curr_idx % (half_cycle + 1)
+ # recomputed \in [0, half_cycle]
+ return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle)
+
+ return get_cos_coef
+
+
+def get_poly_coef_func(num_iters, lr_decay, min_coef, max_coef=1):
+ """
+ :param num_iters: The number of iterations for the polynomial descent process.
+ :param lr_decay: The decay item of the polynomial descent process.
+ :param min_coef: The minimum coefficient of the learning rate.
+ :param max_coef: The maximum coefficient of the learning rate.
+ :return: The function for generating the coefficient.
+ """
+
+ def get_poly_coef(curr_idx):
+ assert 0 <= curr_idx <= num_iters, (curr_idx, num_iters)
+ return poly_anneal(low_bound=min_coef, up_bound=max_coef, percentage=curr_idx / num_iters, lr_decay=lr_decay)
+
+ return get_poly_coef
+
+
+# coefficient entry function ----------------------------------------------------------------------
+
+
+def get_scheduler_coef_func(mode, num_iters, cfg):
+ """
+ the region is a closed interval: [0, num_iters]
+ """
+ assert num_iters > 0
+ min_coef = cfg.get("min_coef", 1e-6)
+ if min_coef is None or min_coef == 0:
+ warnings.warn(f"The min_coef ({min_coef}) of the scheduler will be replaced with 1e-6")
+ min_coef = 1e-6
+
+ if mode == "step":
+ coef_func = get_step_coef_func(gamma=cfg["gamma"], milestones=cfg["milestones"])
+ elif mode == "cos":
+ if half_cycle := cfg.get("half_cycle"):
+ half_cycle -= 1
+ else:
+ half_cycle = num_iters
+ if (num_iters - half_cycle) % (half_cycle + 1) != 0:
+ # idx starts from 0
+ percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
+ warnings.warn(
+ f"The final annealing process ({percentage:.3f}%) is not complete. "
+ f"Please pay attention to the generated 'lr_coef_curve.png'."
+ )
+ coef_func = get_cos_coef_func(half_cycle=half_cycle, min_coef=min_coef)
+ elif mode == "fatcos":
+ assert 0 <= cfg.start_percent < 1, cfg.start_percent
+ start_iter = int(cfg.start_percent * num_iters)
+
+ num_iters -= start_iter
+ if half_cycle := cfg.get("half_cycle"):
+ half_cycle -= 1
+ else:
+ half_cycle = num_iters
+ if (num_iters - half_cycle) % (half_cycle + 1) != 0:
+ # idx starts from 0
+ percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
+ warnings.warn(
+ f"The final annealing process ({percentage:.3f}%) is not complete. "
+ f"Please pay attention to the generated 'lr_coef_curve.png'."
+ )
+ coef_func = get_fatcos_coef_func(start_iter=start_iter, half_cycle=half_cycle, min_coef=min_coef)
+ elif mode == "poly":
+ coef_func = get_poly_coef_func(num_iters=num_iters, lr_decay=cfg["lr_decay"], min_coef=min_coef)
+ elif mode == "constant":
+ coef_func = lambda x: cfg.get("coef", 1)
+ elif mode == "f3":
+ coef_func = get_f3_coef_func(num_iters=num_iters)
+ else:
+ raise NotImplementedError(f"{mode} must be in {Scheduler.supported_scheduler_modes}")
+ return coef_func
+
+
+def get_warmup_coef_func(num_iters, min_coef, max_coef=1, mode="linear"):
+ """
+ the region is a closed interval: [0, num_iters]
+ """
+ assert num_iters > 0
+ if mode == "cos":
+ anneal_func = cos_anneal
+ elif mode == "linear":
+ anneal_func = linear_anneal
+ else:
+ raise NotImplementedError(f"{mode} must be in {Scheduler.supported_warmup_modes}")
+
+ def get_warmup_coef(curr_idx):
+ return anneal_func(low_bound=min_coef, up_bound=max_coef, percentage=1 - curr_idx / num_iters)
+
+ return get_warmup_coef
+
+
+# main class ----------------------------------------------------------------------
+
+
+class Scheduler:
+ supported_scheduler_modes = ("step", "cos", "fatcos", "poly", "constant", "f3")
+ supported_warmup_modes = ("cos", "linear")
+
+ def __init__(self, optimizer, num_iters, epoch_length, scheduler_cfg, step_by_batch=True):
+ """A customized wrapper of the scheduler.
+
+ Args:
+            optimizer (torch.optim.Optimizer): The wrapped optimizer.
+ num_iters (int): The total number of the iterations.
+ epoch_length (int): The number of the iterations of one epoch.
+ scheduler_cfg (dict): The config of the scheduler.
+ step_by_batch (bool, optional): The mode of updating the scheduler. Defaults to True.
+
+ Raises:
+ NotImplementedError:
+ """
+ self.optimizer = optimizer
+ self.num_iters = num_iters
+ self.epoch_length = epoch_length
+ self.step_by_batch = step_by_batch
+
+ self.scheduler_cfg = copy.deepcopy(scheduler_cfg)
+ self.mode = scheduler_cfg["mode"]
+ if self.mode not in self.supported_scheduler_modes:
+ raise NotImplementedError(
+ f"{self.mode} is not implemented. Has been supported: {self.supported_scheduler_modes}"
+ )
+ warmup_cfg = scheduler_cfg.get("warmup", None)
+
+ num_warmup_iters = 0
+ if warmup_cfg is not None and isinstance(warmup_cfg, dict):
+ num_warmup_iters = warmup_cfg["num_iters"]
+ if num_warmup_iters > 0:
+ print("Will using warmup")
+ self.warmup_coef_func = get_warmup_coef_func(
+ num_warmup_iters,
+ min_coef=warmup_cfg.get("initial_coef", 0.01),
+ mode=warmup_cfg.get("mode", "linear"),
+ )
+ self.num_warmup_iters = num_warmup_iters
+
+ if step_by_batch:
+ num_scheduler_iters = num_iters - num_warmup_iters
+ else:
+ num_scheduler_iters = (num_iters - num_warmup_iters) // epoch_length
+ # the region is a closed interval
+ self.lr_coef_func = get_scheduler_coef_func(
+ mode=self.mode, num_iters=num_scheduler_iters - 1, cfg=scheduler_cfg["cfg"]
+ )
+ self.num_scheduler_iters = num_scheduler_iters
+
+ self.last_lr_coef = 0
+ self.initial_lrs = None
+
+ def __repr__(self):
+ formatted_string = [
+ f"{self.__class__.__name__}: (\n",
+ f"num_iters: {self.num_iters}\n",
+ f"epoch_length: {self.epoch_length}\n",
+ f"warmup_iter: [0, {self.num_warmup_iters})\n",
+ f"scheduler_iter: [{self.num_warmup_iters}, {self.num_iters - 1}]\n",
+ f"mode: {self.mode}\n",
+ f"scheduler_cfg: {self.scheduler_cfg}\n",
+ f"initial_lrs: {self.initial_lrs}\n",
+ f"step_by_batch: {self.step_by_batch}\n)",
+ ]
+ return " ".join(formatted_string)
+
+ def record_lrs(self, param_groups):
+ self.initial_lrs = [g["lr"] for g in param_groups]
+
+ def update(self, coef: float):
+ assert self.initial_lrs is not None, "Please run .record_lrs(optimizer) first."
+ for curr_group, initial_lr in zip(self.optimizer.param_groups, self.initial_lrs):
+ curr_group["lr"] = coef * initial_lr
+
+ def step(self, curr_idx):
+ if curr_idx < self.num_warmup_iters:
+ # get maximum value (1.0) when curr_idx == self.num_warmup_iters
+ self.update(coef=self.get_lr_coef(curr_idx))
+ else:
+ # Start from a value lower than 1 (curr_idx == self.num_warmup_iters)
+ if self.step_by_batch:
+ self.update(coef=self.get_lr_coef(curr_idx))
+ else:
+ if curr_idx % self.epoch_length == 0:
+ self.update(coef=self.get_lr_coef(curr_idx))
+
+ def get_lr_coef(self, curr_idx):
+ coef = None
+ if curr_idx < self.num_warmup_iters:
+ coef = self.warmup_coef_func(curr_idx)
+ else:
+ # when curr_idx == self.num_warmup_iters, coef == 1.0
+ # down from the largest coef (1.0)
+ if self.step_by_batch:
+ coef = self.lr_coef_func(curr_idx - self.num_warmup_iters)
+ else:
+ if curr_idx % self.epoch_length == 0 or curr_idx == self.num_warmup_iters:
+                    # Right after warmup ends, the epoch-based update has not taken effect yet, so the
+                    # coefficient must be reset to its maximum value here.
+ coef = self.lr_coef_func((curr_idx - self.num_warmup_iters) // self.epoch_length)
+ if coef is not None:
+ self.last_lr_coef = coef
+ return self.last_lr_coef
+
+ def plot_lr_coef_curve(self, save_path=""):
+ plt.rc("xtick", labelsize="small")
+ plt.rc("ytick", labelsize="small")
+
+ fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 4), dpi=600)
+ # give plot a title
+ ax.set_title("Learning Rate Coefficient Curve")
+ # make axis labels
+ ax.set_xlabel("Iteration")
+ ax.set_ylabel("Coefficient")
+
+ x_data = np.arange(self.num_iters)
+ y_data = np.array([self.get_lr_coef(x) for x in x_data])
+
+ # set lim
+ x_min, x_max = 0, self.num_iters - 1
+ dx = self.num_iters * 0.1
+ ax.set_xlim(x_min - dx, x_max + 2 * dx)
+
+ y_min, y_max = y_data.min(), y_data.max()
+ dy = (y_data.max() - y_data.min()) * 0.1
+ ax.set_ylim((y_min - dy, y_max + dy))
+
+ if self.step_by_batch:
+ marker_on = [0, -1]
+ key_point_xs = [0, self.num_iters - 1]
+ for idx in range(1, len(y_data) - 1):
+ prev_y = y_data[idx - 1]
+ curr_y = y_data[idx]
+ next_y = y_data[idx + 1]
+ if (
+ (curr_y > prev_y and curr_y >= next_y)
+ or (curr_y >= prev_y and curr_y > next_y)
+ or (curr_y <= prev_y and curr_y < next_y)
+ or (curr_y < prev_y and curr_y <= next_y)
+ ):
+ marker_on.append(idx)
+ key_point_xs.append(idx)
+
+ marker_on = sorted(set(marker_on))
+ key_point_xs = sorted(set(key_point_xs))
+ key_point_ys = []
+
+ texts = []
+ for x in key_point_xs:
+ y = y_data[x]
+ key_point_ys.append(y)
+
+ texts.append(ax.text(x=x, y=y, s=f"({x:d},{y:.3e})"))
+ adjust_text(texts, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"))
+
+ # set ticks
+ ax.set_xticks(key_point_xs)
+ # ax.set_yticks(key_point_ys)
+
+ ax.plot(x_data, y_data, marker="o", markevery=marker_on)
+ else:
+ ax.plot(x_data, y_data)
+
+ ax.spines["right"].set_visible(False)
+ ax.spines["top"].set_visible(False)
+ ax.spines["left"].set_visible(True)
+ ax.spines["bottom"].set_visible(True)
+
+ plt.tight_layout()
+ if save_path:
+ fig.savefig(os.path.join(save_path, "lr_coef.png"))
+ plt.close()
+
+
+if __name__ == "__main__":
+ model = torch.nn.Conv2d(10, 10, 3, 1, 1)
+ sche = Scheduler(
+ optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
+ num_iters=30300,
+ epoch_length=505,
+ scheduler_cfg=dict(
+ warmup=dict(
+ num_iters=6060,
+ initial_coef=0.01,
+ mode="cos",
+ ),
+ mode="cos",
+ cfg=dict(
+ half_cycle=6060,
+ lr_decay=0.9,
+ min_coef=0.001,
+ ),
+ ),
+ step_by_batch=True,
+ )
+ print(sche)
+    sche.plot_lr_coef_curve(
+        # save_path="/home/lart/Coding/SOD.torch",
+        save_path=".",  # plot_lr_coef_curve has no `show` argument; save lr_coef.png into the current directory
+    )
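+
+    # A minimal training-loop sketch (illustrative): record the initial lrs once, then step every iteration.
+    # sche.record_lrs(sche.optimizer.param_groups)
+    # for curr_iter in range(sche.num_iters):
+    #     sche.step(curr_iter)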
diff --git a/utils/pt_utils.py b/utils/pt_utils.py
new file mode 100644
index 0000000..65a2d0b
--- /dev/null
+++ b/utils/pt_utils.py
@@ -0,0 +1,66 @@
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from torch import nn
+from torch.backends import cudnn
+
+LOGGER = logging.getLogger("main")
+
+
+def customized_worker_init_fn(worker_id):
+ worker_seed = torch.initial_seed() % 2**32
+ np.random.seed(worker_seed)
+
+
+def set_seed_for_lib(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+    # Disable hash randomization so that experiments are reproducible.
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)  # set the seed for the CPU
+    torch.cuda.manual_seed(seed)  # set the seed for the current GPU
+    torch.cuda.manual_seed_all(seed)  # set the seed for all GPUs
+
+
+def initialize_seed_cudnn(seed, deterministic):
+ assert isinstance(deterministic, bool) and isinstance(seed, int)
+ if seed >= 0:
+ LOGGER.info(f"We will use a fixed seed {seed}")
+ else:
+ seed = np.random.randint(2**32)
+ LOGGER.info(f"We will use a random seed {seed}")
+ set_seed_for_lib(seed)
+ if not deterministic:
+ LOGGER.info("We will use `torch.backends.cudnn.benchmark`")
+ else:
+ LOGGER.info("We will not use `torch.backends.cudnn.benchmark`")
+ cudnn.enabled = True
+ cudnn.benchmark = not deterministic
+ cudnn.deterministic = deterministic
+
+
+def to_device(data, device="cuda"):
+ if isinstance(data, (tuple, list)):
+ return [to_device(item, device) for item in data]
+ elif isinstance(data, dict):
+ return {name: to_device(item, device) for name, item in data.items()}
+ elif isinstance(data, torch.Tensor):
+ return data.to(device=device, non_blocking=True)
+ else:
+ raise TypeError(f"Unsupported type {type(data)}. Only support Tensor or tuple/list/dict containing Tensors.")
+
+
+def frozen_bn_stats(model, freeze_affine=False):
+ """
+ Set all the bn layers to eval mode.
+ Args:
+ model (model): model to set bn layers to eval mode.
+ """
+ for m in model.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+ m.eval()
+ if freeze_affine:
+ m.requires_grad_(False)
diff --git a/utils/py_utils.py b/utils/py_utils.py
new file mode 100644
index 0000000..4844a1d
--- /dev/null
+++ b/utils/py_utils.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+import copy
+import logging
+import os
+import shutil
+from collections import OrderedDict, abc
+from datetime import datetime
+
+LOGGER = logging.getLogger("main")
+
+
+def construct_path(output_dir: str, exp_name: str) -> dict:
+ proj_root = os.path.join(output_dir, exp_name)
+ exp_idx = 0
+ exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
+ while os.path.exists(exp_output_dir):
+ exp_idx += 1
+ exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
+
+ tb_path = os.path.join(exp_output_dir, "tb")
+ save_path = os.path.join(exp_output_dir, "pre")
+ pth_path = os.path.join(exp_output_dir, "pth")
+
+ final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth")
+ final_state_path = os.path.join(pth_path, "state_final.pth")
+
+ log_path = os.path.join(exp_output_dir, f"log_{str(datetime.now())[:10]}.txt")
+    cfg_copy_path = os.path.join(exp_output_dir, "config.py")
+    trainer_copy_path = os.path.join(exp_output_dir, "trainer.txt")
+    excel_path = os.path.join(exp_output_dir, "results.xlsx")
+
+ path_config = {
+ "output_dir": output_dir,
+ "pth_log": exp_output_dir,
+ "tb": tb_path,
+ "save": save_path,
+ "pth": pth_path,
+ "final_full_net": final_full_model_path,
+ "final_state_net": final_state_path,
+ "log": log_path,
+ "cfg_copy": cfg_copy_path,
+ "excel": excel_path,
+ "trainer_copy": trainer_copy_path,
+ }
+
+ return path_config
+
+
+def construct_exp_name(model_name: str, cfg: dict):
+ # bs_16_lr_0.05_e30_noamp_2gpu_noms_352
+ focus_item = OrderedDict(
+ {
+ "train/batch_size": "bs",
+ "train/lr": "lr",
+ "train/num_epochs": "e",
+ "train/num_iters": "i",
+ "train/data/shape/h": "h",
+ "train/data/shape/w": "w",
+ "train/optimizer/mode": "opm",
+ "train/optimizer/group_mode": "opgm",
+ "train/scheduler/mode": "sc",
+ "train/scheduler/warmup/num_iters": "wu",
+ "train/use_amp": "amp",
+ }
+ )
+ config = copy.deepcopy(cfg)
+
+ def _format_item(_i):
+ if isinstance(_i, bool):
+ _i = "" if _i else "false"
+ elif isinstance(_i, (int, float)):
+ if _i == 0:
+ _i = "false"
+ elif isinstance(_i, (list, tuple)):
+ _i = "" if _i else "false" # 只是判断是否非空
+ elif isinstance(_i, str):
+ if "_" in _i:
+ _i = _i.replace("_", "").lower()
+ elif _i is None:
+ _i = "none"
+ # else: other types and values will be returned directly
+ return _i
+
+ if (epoch_based := config.train.get("epoch_based", None)) is not None and (not epoch_based):
+ focus_item.pop("train/num_epochs")
+ else:
+        # epoch-based training is the default
+ focus_item.pop("train/num_iters")
+
+ exp_names = [model_name]
+ for key, alias in focus_item.items():
+ item = get_value_recurse(keys=key.split("/"), info=config)
+ formatted_item = _format_item(item)
+ if formatted_item == "false":
+ continue
+ exp_names.append(f"{alias.upper()}{formatted_item}")
+
+ info = config.get("info", None)
+ if info:
+ exp_names.append(f"INFO{info.lower()}")
+
+ return "_".join(exp_names)
+
+
+def pre_mkdir(path_config):
+    # Create the log file in advance to avoid triggering a file-creation event when it would otherwise be created automatically.
+ check_mkdir(path_config["pth_log"])
+ make_log(path_config["log"], f"=== log {datetime.now()} ===")
+
+    # Create the folders for saving predictions and model checkpoints in advance.
+ check_mkdir(path_config["save"])
+ check_mkdir(path_config["pth"])
+
+
+def check_mkdir(dir_name, delete_if_exists=False):
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
+ else:
+ if delete_if_exists:
+ print(f"{dir_name} will be re-created!!!")
+ shutil.rmtree(dir_name)
+ os.makedirs(dir_name)
+
+
+def make_log(path, context):
+ with open(path, "a") as log:
+ log.write(f"{context}\n")
+
+
+def iterate_nested_sequence(nested_sequence):
+ """
+    Supports nested list/tuple/int/float/range() items. Do not nest too deeply, otherwise Python's default maximum recursion depth may be exceeded.
+
+    Example
+ ::
+
+ for x in iterate_nested_sequence([[1, (2, 3)], range(3, 10), 0]):
+ print(x)
+
+ 1
+ 2
+ 3
+ 3
+ 4
+ 5
+ 6
+ 7
+ 8
+ 9
+ 0
+
+    :param nested_sequence: a nested sequence
+ :return: generator
+ """
+ for item in nested_sequence:
+ if isinstance(item, (int, float)):
+ yield item
+ elif isinstance(item, (list, tuple, range)):
+ yield from iterate_nested_sequence(item)
+ else:
+ raise NotImplementedError
+
+
+def get_value_recurse(keys: list, info: dict):
+ curr_key, sub_keys = keys[0], keys[1:]
+
+ if (sub_info := info.get(curr_key, "NoKey")) == "NoKey":
+ raise KeyError(f"{curr_key} must be contained in {info}")
+
+ if sub_keys:
+ return get_value_recurse(keys=sub_keys, info=sub_info)
+ else:
+ return sub_info
+
+
+def mapping_to_str(mapping: abc.Mapping, *, prefix: str = " ", lvl: int = 0, max_lvl: int = 1) -> str:
+ """
+ Print the structural information of the dict.
+ """
+ sub_lvl = lvl + 1
+ cur_prefix = prefix * lvl
+ sub_prefix = prefix * sub_lvl
+
+ if lvl == max_lvl:
+ sub_items = str(mapping)
+ else:
+ sub_items = ["{"]
+ for k, v in mapping.items():
+ sub_item = sub_prefix + k + ": "
+ if isinstance(v, abc.Mapping):
+ sub_item += mapping_to_str(v, prefix=prefix, lvl=sub_lvl, max_lvl=max_lvl)
+ else:
+ sub_item += str(v)
+ sub_items.append(sub_item)
+ sub_items.append(cur_prefix + "}")
+ sub_items = "\n".join(sub_items)
+ return sub_items
diff --git a/utils/recorder/__init__.py b/utils/recorder/__init__.py
new file mode 100644
index 0000000..f40f9be
--- /dev/null
+++ b/utils/recorder/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .counter import TrainingCounter
+from .group_metric_caller import GroupedMetricRecorder
+from .logger import TBLogger
+from .meter_recorder import AvgMeter, HistoryBuffer
+from .visualize_results import plot_results
diff --git a/utils/recorder/counter.py b/utils/recorder/counter.py
new file mode 100644
index 0000000..b139573
--- /dev/null
+++ b/utils/recorder/counter.py
@@ -0,0 +1,75 @@
+import math
+
+
+class TrainingCounter:
+ def __init__(self, epoch_length, epoch_based=True, *, num_epochs=None, num_total_iters=None) -> None:
+ self.num_inner_iters = epoch_length
+ self._iter_counter = 0
+ self._epoch_counter = 0
+
+ if epoch_based:
+ assert num_epochs is not None
+ self.num_epochs = num_epochs
+ self.num_total_iters = num_epochs * epoch_length
+ else:
+ assert num_total_iters is not None
+ self.num_total_iters = num_total_iters
+ self.num_epochs = math.ceil(num_total_iters / epoch_length)
+
+ def set_start_epoch(self, start_epoch):
+ self._epoch_counter = start_epoch
+ self._iter_counter = start_epoch * self.num_inner_iters
+
+ def set_start_iterations(self, start_iteration):
+ self._iter_counter = start_iteration
+ self._epoch_counter = start_iteration // self.num_inner_iters
+
+ def every_n_epochs(self, n: int) -> bool:
+ return (self._epoch_counter + 1) % n == 0 if n > 0 else False
+
+ def every_n_iters(self, n: int) -> bool:
+ return (self._iter_counter + 1) % n == 0 if n > 0 else False
+
+ def is_first_epoch(self) -> bool:
+ return self._epoch_counter == 0
+
+ def is_last_epoch(self) -> bool:
+ return self._epoch_counter == self.num_epochs - 1
+
+ def is_first_inner_iter(self) -> bool:
+ return self._iter_counter % self.num_inner_iters == 0
+
+ def is_last_inner_iter(self) -> bool:
+ return (self._iter_counter + 1) % self.num_inner_iters == 0
+
+ def is_first_total_iter(self) -> bool:
+ return self._iter_counter == 0
+
+ def is_last_total_iter(self) -> bool:
+ return self._iter_counter == self.num_total_iters - 1
+
+ def update_iter_counter(self):
+ self._iter_counter += 1
+
+ def update_epoch_counter(self):
+ self._epoch_counter += 1
+
+ def reset_iter_all_counter(self):
+ self._iter_counter = 0
+ self._epoch_counter = 0
+
+ @property
+ def curr_iter(self):
+ return self._iter_counter
+
+ @property
+ def next_iter(self):
+ return self._iter_counter + 1
+
+ @property
+ def curr_epoch(self):
+ return self._epoch_counter
+
+ @property
+ def curr_percent(self):
+ return self._iter_counter / self.num_total_iters
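+
+
+# A minimal usage sketch (hypothetical epoch length and epoch count):
+#   counter = TrainingCounter(epoch_length=100, epoch_based=True, num_epochs=30)
+#   counter.num_total_iters  # 3000
+#   counter.set_start_epoch(5)  # resume training: curr_epoch == 5, curr_iter == 500
+#   # Call update_iter_counter() once per batch and update_epoch_counter() once per epoch.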
diff --git a/utils/recorder/group_metric_caller.py b/utils/recorder/group_metric_caller.py
new file mode 100644
index 0000000..62e9ec1
--- /dev/null
+++ b/utils/recorder/group_metric_caller.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/1/4
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+from collections import OrderedDict
+
+import numpy as np
+import py_sod_metrics
+
+
+def ndarray_to_basetype(data):
+ """
+    Convert a standalone ndarray, or the ndarrays contained in a tuple, list, or dict, into
+    basic Python types, i.e. lists (via ``.tolist()``) and Python scalars.
+ """
+
+ def _to_list_or_scalar(item):
+ listed_item = item.tolist()
+ if isinstance(listed_item, list) and len(listed_item) == 1:
+ listed_item = listed_item[0]
+ return listed_item
+
+ if isinstance(data, (tuple, list)):
+ results = [_to_list_or_scalar(item) for item in data]
+ elif isinstance(data, dict):
+ results = {k: _to_list_or_scalar(item) for k, item in data.items()}
+ else:
+ assert isinstance(data, np.ndarray)
+ results = _to_list_or_scalar(data)
+ return results
+
+
+def round_w_zero_padding(x, bit_width):
+ x = str(round(x, bit_width))
+ x += "0" * (bit_width - len(x.split(".")[-1]))
+ return x
+
+
+INDIVADUAL_METRIC_MAPPING = {
+ "sm": py_sod_metrics.Smeasure,
+ "wfm": py_sod_metrics.WeightedFmeasure,
+ "mae": py_sod_metrics.MAE,
+ "em": py_sod_metrics.Emeasure,
+}
+BINARY_CLASSIFICATION_METRIC_MAPPING = {
+ "fmeasure": {
+ "handler": py_sod_metrics.FmeasureHandler,
+ "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False, beta=0.3),
+ },
+ "iou": {
+ "handler": py_sod_metrics.IOUHandler,
+ "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False),
+ },
+ "dice": {
+ "handler": py_sod_metrics.DICEHandler,
+ "kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
+ },
+}
+
+
+class ImageMetricRecorder:
+ supported_metrics = sorted(INDIVADUAL_METRIC_MAPPING.keys()) + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())
+
+ def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
+ """
+        A recorder that accumulates various evaluation metrics over images.
+ """
+ if not metric_names:
+ metric_names = self.supported_metrics
+ assert all([m in self.supported_metrics for m in metric_names]), f"Only support: {self.supported_metrics}"
+
+ self.metric_objs = {}
+ has_existed = False
+ for metric_name in metric_names:
+ if metric_name in INDIVADUAL_METRIC_MAPPING:
+ self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
+ else: # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
+ if not has_existed: # only init once
+ self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
+ has_existed = True
+ metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
+ self.metric_objs["fmeasurev2"].add_handler(
+ handler_name=metric_name, metric_handler=metric_handler["handler"](**metric_handler["kwargs"])
+ )
+
+ def step(self, pre: np.ndarray, gt: np.ndarray, gt_path: str):
+ assert pre.shape == gt.shape, (pre.shape, gt.shape, gt_path)
+ assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype, gt_path)
+
+ for m_obj in self.metric_objs.values():
+ m_obj.step(pre, gt)
+
+ def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
+ sequential_results = {}
+ numerical_results = {}
+ for m_name, m_obj in self.metric_objs.items():
+ info = m_obj.get_results()
+ if m_name == "fmeasurev2":
+ for _name, results in info.items():
+ dynamic_results = results.get("dynamic")
+ adaptive_results = results.get("adaptive")
+ if dynamic_results is not None:
+ sequential_results[_name] = np.flip(dynamic_results)
+ numerical_results[f"max{_name}"] = dynamic_results.max()
+ numerical_results[f"avg{_name}"] = dynamic_results.mean()
+ if adaptive_results is not None:
+ numerical_results[f"adp{_name}"] = adaptive_results
+ else:
+ results = info[m_name]
+ if m_name in ("wfm", "sm", "mae"):
+ numerical_results[m_name] = results
+ elif m_name == "em":
+ sequential_results[m_name] = np.flip(results["curve"])
+ numerical_results.update(
+ {
+ "maxem": results["curve"].max(),
+ "avgem": results["curve"].mean(),
+ "adpem": results["adp"],
+ }
+ )
+ else:
+ raise NotImplementedError(m_name)
+
+ if num_bits is not None and isinstance(num_bits, int):
+ numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
+ if not return_ndarray:
+ sequential_results = ndarray_to_basetype(sequential_results)
+ numerical_results = ndarray_to_basetype(numerical_results)
+ return {"sequential": sequential_results, "numerical": numerical_results}
+
+ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
+ return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]
+
+
+class GroupedMetricRecorder(object):
+ supported_metrics = ["mae", "em", "sm", "wfm"] + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())
+
+ def __init__(self, group_names=None, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
+ self.group_names = group_names
+ self.metric_names = metric_names
+ self.zero()
+
+ def zero(self):
+ self.metric_recorders = {}
+ if self.group_names is not None:
+ self.metric_recorders.update(
+ {n: ImageMetricRecorder(metric_names=self.metric_names) for n in self.group_names}
+ )
+
+ def step(self, group_name: str, pre: np.ndarray, gt: np.ndarray, gt_path: str):
+ if group_name not in self.metric_recorders:
+ self.metric_recorders[group_name] = ImageMetricRecorder(metric_names=self.metric_names)
+ self.metric_recorders[group_name].step(pre, gt, gt_path)
+
+ def show(self, num_bits: int = 3, return_group: bool = False):
+ groups_metrics = {
+ n: r.get_all_results(num_bits=None, return_ndarray=True) for n, r in self.metric_recorders.items()
+ }
+
+ results = {}
+ for group_metrics in groups_metrics.values():
+ for metric_type, metric_group in group_metrics.items(): # sequential and numerical
+ results.setdefault(metric_type, {})
+ for metric_name, metric_array in metric_group.items():
+ results[metric_type].setdefault(metric_name, []).append(metric_array)
+
+ numerical_results = {}
+ for metric_type, metric_group in results.items():
+ for metric_name, metric_array in metric_group.items():
+ metric_array = np.mean(np.vstack(metric_array), axis=0) # average over all groups
+
+ if metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING or metric_name == "em":
+ if metric_type == "sequential":
+ numerical_results[f"max{metric_name}"] = metric_array.max()
+ numerical_results[f"avg{metric_name}"] = metric_array.mean()
+ else:
+ if metric_type == "numerical":
+ if metric_name.startswith(("max", "avg")):
+                            # skipped: max*/avg* values are recomputed from the group-averaged curves handled in the sequential branch
+ continue
+ numerical_results[metric_name] = metric_array
+
+ numerical_results = ndarray_to_basetype(numerical_results)
+ numerical_results = {k: round(v, num_bits) for k, v in numerical_results.items()}
+ numerical_results = self.sort_results(numerical_results)
+ if not return_group:
+ return numerical_results
+
+ group_numerical_results = {}
+ for group_name, group_metric in groups_metrics.items():
+ group_metric = {k: v.round(num_bits) for k, v in group_metric["numerical"].items()}
+ group_metric = ndarray_to_basetype(group_metric)
+ group_numerical_results[group_name] = self.sort_results(group_metric)
+ return numerical_results, group_numerical_results
+
+ def sort_results(self, results: dict) -> OrderedDict:
+ """for a single group of metrics"""
+ sorted_results = OrderedDict()
+ all_keys = sorted(results.keys(), key=lambda item: item[::-1])
+ for name in self.metric_names:
+ for key in all_keys:
+ if key.endswith(name):
+ sorted_results[key] = results[key]
+ return sorted_results
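+
+
+# A minimal usage sketch with random masks (group name, shapes, and path are hypothetical):
+#   recorder = GroupedMetricRecorder(metric_names=("sm", "mae"))
+#   pred = np.random.randint(0, 256, size=(64, 64), dtype=np.uint8)
+#   mask = np.random.randint(0, 2, size=(64, 64), dtype=np.uint8) * 255
+#   recorder.step(group_name="CAMO", pre=pred, gt=mask, gt_path="CAMO/gt/0001.png")
+#   print(recorder.show(num_bits=3))  # numerical metrics averaged over all groups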
diff --git a/utils/recorder/logger.py b/utils/recorder/logger.py
new file mode 100644
index 0000000..eb411d0
--- /dev/null
+++ b/utils/recorder/logger.py
@@ -0,0 +1,23 @@
+from torch.utils.tensorboard import SummaryWriter
+
+
+class TBLogger:
+ def __init__(self, tb_root):
+ self.tb_root = tb_root
+ self.tb = None
+
+ def write_to_tb(self, name, data, curr_iter):
+ assert self.tb_root is not None
+
+ if self.tb is None:
+ self.tb = SummaryWriter(self.tb_root)
+
+ if not isinstance(data, (tuple, list)):
+ self.tb.add_scalar(f"data/{name}", data, curr_iter)
+ else:
+ for idx, data_item in enumerate(data):
+ self.tb.add_scalar(f"data/{name}_{idx}", data_item, curr_iter)
+
+ def close_tb(self):
+ if self.tb is not None:
+ self.tb.close()
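+
+
+# A minimal usage sketch (the log directory is hypothetical):
+#   tb_logger = TBLogger("output/exp1/tb")
+#   tb_logger.write_to_tb("total_loss", 0.37, curr_iter=100)
+#   tb_logger.write_to_tb("lr", (0.01, 0.001), curr_iter=100)  # logged as data/lr_0 and data/lr_1
+#   tb_logger.close_tb()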
diff --git a/utils/recorder/meter_recorder.py b/utils/recorder/meter_recorder.py
new file mode 100644
index 0000000..888ad1f
--- /dev/null
+++ b/utils/recorder/meter_recorder.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+from collections import deque
+
+
+class AvgMeter(object):
+ __slots__ = ["value", "sum", "count"]
+
+ def __init__(self):
+ self.value = 0
+ self.sum = 0
+ self.count = 0
+
+ def reset(self):
+ self.value = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, value, num=1):
+ self.value = value
+ self.sum += value * num
+ self.count += num
+
+ @property
+ def avg(self):
+ return self.sum / self.count
+
+ def __repr__(self) -> str:
+ return f"{self.avg:.5f}"
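+
+
+# A minimal usage sketch (hypothetical loss values):
+#   loss_meter = AvgMeter()
+#   loss_meter.update(1.0, num=4)  # e.g. the mean loss of a batch of 4 samples
+#   loss_meter.update(3.0, num=4)
+#   loss_meter.avg  # (1.0 * 4 + 3.0 * 4) / 8 == 2.0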
+
+
+class HistoryBuffer:
+ """The class tracks a series of values and provides access to the smoothed
+ value over a window or the global average / sum of the sequence.
+
+ Args:
+ window_size (int): The maximal number of values that can
+ be stored in the buffer. Defaults to 20.
+
+ Example::
+
+ >>> his_buf = HistoryBuffer()
+ >>> his_buf.update(0.1)
+ >>> his_buf.update(0.2)
+ >>> his_buf.avg
+ 0.15
+ """
+
+ def __init__(self, window_size: int = 20) -> None:
+ self._history = deque(maxlen=window_size)
+ self._count: int = 0
+ self._sum: float = 0
+ self.reset()
+
+ def reset(self):
+ self._history.clear()
+ self._count = 0
+ self._sum = 0
+
+ def update(self, value: float, num: int = 1) -> None:
+ """Add a new scalar value. If the length of queue exceeds ``window_size``,
+ the oldest element will be removed from the queue.
+ """
+ self._history.append(value)
+ self._count += num
+ self._sum += value * num
+
+ @property
+ def latest(self) -> float:
+ """The latest value of the queue."""
+ return self._history[-1]
+
+ @property
+ def avg(self) -> float:
+ """The average over the window."""
+ if len(self._history) == 0:
+ return 0
+ else:
+ return sum(self._history) / len(self._history)
+
+ @property
+ def global_avg(self) -> float:
+ """The global average of the queue."""
+ if self._count == 0:
+ return 0
+ else:
+ return self._sum / self._count
+
+ @property
+ def global_sum(self) -> float:
+ """The global sum of the queue."""
+ return self._sum
diff --git a/utils/recorder/visualize_results.py b/utils/recorder/visualize_results.py
new file mode 100644
index 0000000..167e6e5
--- /dev/null
+++ b/utils/recorder/visualize_results.py
@@ -0,0 +1,43 @@
+import os
+
+import cv2
+import matplotlib
+
+matplotlib.use("Agg")
+import numpy as np
+import torchvision.transforms.functional as tv_tf
+from torchvision.utils import make_grid
+
+
+def plot_results(data_container, save_path, base_size=256, is_rgb=True):
+    """Plot the results corresponding to the batched images using the `make_grid` method from `torchvision`.
+
+ Args:
+ data_container (dict): Dict containing data you want to plot.
+ save_path (str): Path of the exported image.
+ """
+ font_cfg = dict(fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, thickness=2)
+
+ grids = []
+ for subplot_id, (name, data) in enumerate(data_container.items()):
+ if data.ndim == 3:
+ data = data.unsqueeze(1)
+
+ grid = make_grid(data, nrow=data.shape[0], padding=2, normalize=False)
+ grid = np.array(tv_tf.to_pil_image(grid.float()))
+ h, w = grid.shape[:2]
+ ratio = base_size / h
+ grid = cv2.resize(grid, dsize=None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
+
+ (text_w, text_h), baseline = cv2.getTextSize(text=name, **font_cfg)
+ text_xy = 20, 20 + text_h // 2 + baseline
+ cv2.putText(grid, text=name, org=text_xy, color=(255, 255, 255), **font_cfg)
+
+ grids.append(grid)
+ grids = np.concatenate(grids, axis=0) # H,W,C
+
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ if is_rgb:
+ grids = cv2.cvtColor(grids, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(save_path, grids)
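+
+
+# A minimal usage sketch (tensors and save path are hypothetical):
+#   import torch
+#   batch = {"image": torch.rand(4, 3, 64, 64), "mask": torch.rand(4, 64, 64)}
+#   plot_results(batch, save_path="outputs/vis/batch_0.png")  # one labeled row per dict entry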