diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..eb796ac --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +### JupyterNotebooks template +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +.idea/ +.vscode/ +weights/ +outputs/ +pretrained/ +dataset.yaml + +# Custom Scripts +*.ps1 +*.sh diff --git a/README.md b/README.md index ab7d4cf..150d7a2 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,12 @@ ```bibtex -@misc{ZoomNeXt, - title={ZoomNeXt: A Unified Collaborative Pyramid Network for Camouflaged Object Detection}, - author={Youwei Pang and Xiaoqi Zhao and Tian-Zhu Xiang and Lihe Zhang and Huchuan Lu}, - year={2023}, - eprint={2310.20208}, - archivePrefix={arXiv}, - primaryClass={cs.CV} +@ARTICLE {ZoomNeXt, + title = {ZoomNeXt: A Unified Collaborative Pyramid Network for Camouflaged Object Detection}, + author ={Youwei Pang and Xiaoqi Zhao and Tian-Zhu Xiang and Lihe Zhang and Huchuan Lu}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, + year = {2024}, + doi = {10.1109/TPAMI.2024.3417329}, } ``` @@ -29,19 +28,152 @@ ### Weights -| Backbone | CAMO-TE | | | CHAMELEON | | | COD10K-TE | | | NC4K | | | Links | -| ---------------- | ------- | -------------------- | ----- | --------- | -------------------- | ----- | --------- | -------------------- | ----- | ----- | -------------------- | ----- | ------------------------------------------------------------------------------------------------------- | -| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | | -| ResNet-50 | 0.833 | 0.774 | 0.065 | 0.908 | 0.858 | 0.021 | 0.861 | 0.768 | 0.026 | 0.874 | 0.816 | 0.037 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/resnet50-zoomnext.pth) | -| EfficientNet-B4 | 0.867 | 0.824 | 0.046 | 0.911 | 0.865 | 0.020 | 0.875 | 0.797 | 0.021 | 0.884 | 0.837 | 0.032 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b4-zoomnext.pth) | -| PVTv2-B2 | 0.874 | 0.839 | 0.047 | 0.922 | 0.884 | 0.017 | 0.887 | 0.818 | 0.019 | 0.892 | 0.852 | 0.030 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b2-zoomnext.pth) | -| PVTv2-B3 | 0.885 | 0.854 | 0.042 | 0.927 | 0.898 | 0.017 | 0.895 | 0.829 | 0.018 | 0.900 | 0.861 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b3-zoomnext.pth) | -| PVTv2-B4 | 0.888 | 0.859 | 0.040 | 0.925 | 0.897 | 0.016 | 0.898 | 0.838 | 0.017 | 0.900 | 0.865 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b4-zoomnext.pth) | -| PVTv2-B5 | 0.889 | 0.857 | 0.041 | 0.924 | 0.885 | 0.018 | 0.898 | 0.827 | 0.018 | 0.903 | 0.863 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-zoomnext.pth) | -| EfficientNet-B1 | | | | | | | | | | | | | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b1-zoomnext.pth) | -| ConvNeXtV2-Large | | | | | | | | | | | | | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/convnextv2-l-zoomnext.pth) | - -| Backbone | CAD | 
| | | | MoCA-Mask-TE | | | | | Links | -| --------------- | ----- | -------------------- | ----- | ----- | ----- | ------------ | -------------------- | ----- | ----- | ----- | ---------------------------------------------------------------------------------------------------------- | -| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | | -| PVTv2-B5 (T=5) | 0.757 | 0.593 | 0.020 | 0.599 | 0.510 | 0.734 | 0.476 | 0.010 | 0.497 | 0.422 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-5frame-zoomnext.pth) | +| Backbone | CAMO-TE | | | CHAMELEON | | | COD10K-TE | | | NC4K | | | Links | +| --------------- | ------- | -------------------- | ----- | --------- | -------------------- | ----- | --------- | -------------------- | ----- | ----- | -------------------- | ----- | --------------------------------------------------------------------------------------------------- | +| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | $S_m$ | $F^{\omega}_{\beta}$ | MAE | | +| ResNet-50 | 0.833 | 0.774 | 0.065 | 0.908 | 0.858 | 0.021 | 0.861 | 0.768 | 0.026 | 0.874 | 0.816 | 0.037 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/resnet50-zoomnext.pth) | +| EfficientNet-B1 | 0.848 | 0.803 | 0.056 | 0.916 | 0.870 | 0.020 | 0.863 | 0.773 | 0.024 | 0.876 | 0.823 | 0.036 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b1-zoomnext.pth) | +| EfficientNet-B4 | 0.867 | 0.824 | 0.046 | 0.911 | 0.865 | 0.020 | 0.875 | 0.797 | 0.021 | 0.884 | 0.837 | 0.032 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/eff-b4-zoomnext.pth) | +| PVTv2-B2 | 0.874 | 0.839 | 0.047 | 0.922 | 0.884 | 0.017 | 0.887 | 0.818 | 0.019 | 0.892 | 0.852 | 0.030 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b2-zoomnext.pth) | +| PVTv2-B3 | 0.885 | 0.854 | 0.042 | 0.927 | 0.898 | 0.017 | 0.895 | 0.829 | 0.018 | 0.900 | 0.861 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b3-zoomnext.pth) | +| PVTv2-B4 | 0.888 | 0.859 | 0.040 | 0.925 | 0.897 | 0.016 | 0.898 | 0.838 | 0.017 | 0.900 | 0.865 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b4-zoomnext.pth) | +| PVTv2-B5 | 0.889 | 0.857 | 0.041 | 0.924 | 0.885 | 0.018 | 0.898 | 0.827 | 0.018 | 0.903 | 0.863 | 0.028 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-zoomnext.pth) | + +| Backbone | CAD | | | | | MoCA-Mask-TE | | | | | Links | +| -------------- | ----- | -------------------- | ----- | ----- | ----- | ------------ | -------------------- | ----- | ----- | ----- | ---------------------------------------------------------------------------------------------------------- | +| | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | $S_m$ | $F^{\omega}_{\beta}$ | MAE | mDice | mIoU | | +| PVTv2-B5 (T=5) | 0.757 | 0.593 | 0.020 | 0.599 | 0.510 | 0.734 | 0.476 | 0.010 | 0.497 | 0.422 | [Weight](https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.2/pvtv2-b5-5frame-zoomnext.pth) | + + +## Prepare Data + +Set all dataset information to the `dataset.yaml` as follows. + +
+ +Example of the config file (dataset.yaml): + + +```yaml +# VCOD Datasets +moca_mask_tr: + { + root: "YOUR-VCOD-DATASETS-ROOT/MoCA-Mask/MoCA_Video/TrainDataset_per_sq", + image: { path: "*/Imgs", suffix: ".jpg" }, + mask: { path: "*/GT", suffix: ".png" }, + } +moca_mask_te: + { + root: "YOUR-VCOD-DATASETS-ROOT/MoCA-Mask/MoCA_Video/TestDataset_per_sq", + image: { path: "*/Imgs", suffix: ".jpg" }, + mask: { path: "*/GT", suffix: ".png" }, + } +cad: + { + root: "YOUR-VCOD-DATASETS-ROOT/CamouflagedAnimalDataset", + image: { path: "original_data/*/frames", suffix: ".png" }, + mask: { path: "converted_mask/*/groundtruth", suffix: ".png" }, + } + +# ICOD Datasets +cod10k_tr: + { + root: "YOUR-ICOD-DATASETS-ROOT/Train/COD10K-TR", + image: { path: "Image", suffix: ".jpg" }, + mask: { path: "Mask", suffix: ".png" }, + } +camo_tr: + { + root: "YOUR-ICOD-DATASETS-ROOT/Train/CAMO-TR", + image: { path: "Image", suffix: ".jpg" }, + mask: { path: "Mask", suffix: ".png" }, + } +cod10k_te: + { + root: "YOUR-ICOD-DATASETS-ROOT/Test/COD10K-TE", + image: { path: "Image", suffix: ".jpg" }, + mask: { path: "Mask", suffix: ".png" }, + } +camo_te: + { + root: "YOUR-ICOD-DATASETS-ROOT/Test/CAMO-TE", + image: { path: "Image", suffix: ".jpg" }, + mask: { path: "Mask", suffix: ".png" }, + } +chameleon: + { + root: "YOUR-ICOD-DATASETS-ROOT/Test/CHAMELEON", + image: { path: "Image", suffix: ".jpg" }, + mask: { path: "Mask", suffix: ".png" }, + } +nc4k: + { + root: "YOUR-ICOD-DATASETS-ROOT/Test/NC4K", + image: { path: "Imgs", suffix: ".jpg" }, + mask: { path: "GT", suffix: ".png" }, + } +``` +
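+
+For reference, the sketch below shows one way such an entry can be resolved into matched image/mask pairs: each `path` is joined to `root` (the VCOD entries use a `*` to enumerate per-sequence folders), and files are paired by their stem after stripping the configured suffixes. The helper name `list_pairs` is only illustrative; the actual loading is handled by the dataset classes in `main_for_image.py` and `main_for_video.py`.
+
+```python
+import glob
+import os
+
+import yaml
+
+
+def list_pairs(dataset_info):
+    # Join the configured sub-paths to the dataset root; glob expands the
+    # optional "*" used by the per-sequence video datasets.
+    image_dirs = sorted(glob.glob(os.path.join(dataset_info["root"], dataset_info["image"]["path"])))
+    mask_dirs = sorted(glob.glob(os.path.join(dataset_info["root"], dataset_info["mask"]["path"])))
+    image_suffix = dataset_info["image"]["suffix"]
+    mask_suffix = dataset_info["mask"]["suffix"]
+
+    pairs = []
+    for image_dir, mask_dir in zip(image_dirs, mask_dirs):
+        image_names = {p[: -len(image_suffix)] for p in os.listdir(image_dir) if p.endswith(image_suffix)}
+        mask_names = {p[: -len(mask_suffix)] for p in os.listdir(mask_dir) if p.endswith(mask_suffix)}
+        # Only stems present in both folders are kept, mirroring the dataset classes.
+        for name in sorted(image_names & mask_names):
+            pairs.append((os.path.join(image_dir, name + image_suffix), os.path.join(mask_dir, name + mask_suffix)))
+    return pairs
+
+
+if __name__ == "__main__":
+    with open("dataset.yaml", mode="r") as f:
+        dataset_infos = yaml.safe_load(f)
+    print(f"cod10k_te: {len(list_pairs(dataset_infos['cod10k_te']))} image/mask pairs")
+```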
+
+## Install Requirements
+
+- torch==2.1.2
+- torchvision==0.16.2
+- Others: `pip install -r requirements.txt`
+
+## Evaluation
+
+```shell
+# ICOD
+python main_for_image.py --config configs/icod_train.py --model-name <model-name> --evaluate --load-from <checkpoint-path>
+# VCOD
+python main_for_video.py --config configs/vcod_finetune.py --model-name <model-name> --evaluate --load-from <checkpoint-path>
+```
+
+> [!NOTE]
+>
+> Evaluating performance on the VCOD datasets directly with the training scripts does not match the results reported in the paper.
+> This is because the evaluation protocol in the paper follows the strategy of the earlier work [SLT-Net](https://github.com/XuelianCheng/SLT-Net), which adjusts the range of valid frames in each sequence.
+
+To reproduce the results in our paper, you can use [PySODEvalToolkit](https://github.com/lartpang/PySODEvalToolkit) with commands similar to the following:
+
+```shell
+python ./eval.py `
+  --dataset-json vcod-datasets.json `
+  --method-json vcod-methods.json `
+  --include-datasets CAD `
+  --include-methods videoPvtV2B5_ZoomNeXt `
+  --data-type video `
+  --valid-frame-start "0" `
+  --valid-frame-end "0" `
+  --metric-names "sm" "wfm" "mae" "fmeasure" "em" "dice" "iou"
+
+python ./eval.py `
+  --dataset-json vcod-datasets.json `
+  --method-json vcod-methods.json `
+  --include-datasets MOCA-MASK-TE `
+  --include-methods videoPvtV2B5_ZoomNeXt `
+  --data-type video `
+  --valid-frame-start "0" `
+  --valid-frame-end "-2" `
+  --metric-names "sm" "wfm" "mae" "fmeasure" "em" "dice" "iou"
+```
+
+## Training
+
+### Image Camouflaged Object Detection
+
+```shell
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name EffB1_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name EffB4_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B2_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B3_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B4_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name PvtV2B5_ZoomNeXt
+python main_for_image.py --config configs/icod_train.py --pretrained --model-name RN50_ZoomNeXt
+```
+
+### Video Camouflaged Object Detection
+
+1. Pretrain on COD10K-TR: `python main_for_image.py --config configs/icod_pretrain.py --info pretrain --model-name PvtV2B5_ZoomNeXt --pretrained`
+2. 
Finetune on MoCA-Mask-TR: `python main_for_video.py --config configs/vcod_finetune.py --info finetune --model-name videoPvtV2B5_ZoomNeXt --load-from ` diff --git a/configs/icod_pretrain.py b/configs/icod_pretrain.py new file mode 100644 index 0000000..1bb7337 --- /dev/null +++ b/configs/icod_pretrain.py @@ -0,0 +1,54 @@ +_base_ = ["icod_train.py"] + +has_test = False + +__BATCHSIZE = 4 +__NUM_EPOCHS = 150 +__NUM_TR_SAMPLES = 3040 +__ITER_PER_EPOCH = __NUM_TR_SAMPLES // __BATCHSIZE # drop_last is True +__NUM_ITERS = __NUM_EPOCHS * __ITER_PER_EPOCH + +train = dict( + batch_size=__BATCHSIZE, + use_amp=True, + num_epochs=__NUM_EPOCHS, + lr=0.0001, + optimizer=dict( + mode="adam", + set_to_none=False, + group_mode="finetune", + cfg=dict( + weight_decay=0, + diff_factor=0.1, + ), + ), + sche_usebatch=True, + scheduler=dict( + warmup=dict( + num_iters=0, + initial_coef=0.01, + mode="linear", + ), + mode="step", + cfg=dict( + milestones=int(__NUM_ITERS * 2 / 3), + gamma=0.1, + ), + ), + bn=dict( + freeze_status=True, + freeze_affine=True, + freeze_encoder=False, + ), + data=dict( + shape=dict(h=384, w=384), + names=["cod10k_tr"], + ), +) + +test = dict( + data=dict( + shape=dict(h=384, w=384), + names=[], + ), +) diff --git a/configs/icod_train.py b/configs/icod_train.py new file mode 100644 index 0000000..a8f3358 --- /dev/null +++ b/configs/icod_train.py @@ -0,0 +1,63 @@ +has_test = True +deterministic = True +use_custom_worker_init = True +log_interval = 20 +base_seed = 112358 + +__BATCHSIZE = 4 +__NUM_EPOCHS = 150 +__NUM_TR_SAMPLES = 3040 + 1000 +__ITER_PER_EPOCH = __NUM_TR_SAMPLES // __BATCHSIZE # drop_last is True +__NUM_ITERS = __NUM_EPOCHS * __ITER_PER_EPOCH + +train = dict( + batch_size=__BATCHSIZE, + num_workers=2, + use_amp=True, + num_epochs=__NUM_EPOCHS, + epoch_based=True, + num_iters=None, + lr=0.0001, + grad_acc_step=1, + optimizer=dict( + mode="adam", + set_to_none=False, + group_mode="finetune", + cfg=dict( + weight_decay=0, + diff_factor=0.1, + ), + ), + sche_usebatch=True, + scheduler=dict( + warmup=dict( + num_iters=0, + initial_coef=0.01, + mode="linear", + ), + mode="step", + cfg=dict( + milestones=int(__NUM_ITERS * 2 / 3), + gamma=0.1, + ), + ), + bn=dict( + freeze_status=True, + freeze_affine=True, + freeze_encoder=False, + ), + data=dict( + shape=dict(h=384, w=384), + names=["cod10k_tr", "camo_tr"], + ), +) + +test = dict( + batch_size=__BATCHSIZE, + num_workers=2, + clip_range=None, + data=dict( + shape=dict(h=384, w=384), + names=["chameleon", "camo_te", "cod10k_te", "nc4k"], + ), +) diff --git a/configs/vcod_finetune.py b/configs/vcod_finetune.py new file mode 100644 index 0000000..2c9fbd8 --- /dev/null +++ b/configs/vcod_finetune.py @@ -0,0 +1,44 @@ +_base_ = ["icod_train.py"] + +num_frames = 5 + +__BATCHSIZE = 2 + +train = dict( + batch_size=__BATCHSIZE, + use_amp=True, + num_epochs=10, + lr=0.0001, + optimizer=dict( + mode="adam", + set_to_none=False, + group_mode="finetune", + cfg=dict( + weight_decay=0, + diff_factor=0.1, + ), + ), + sche_usebatch=True, + scheduler=dict( + warmup=dict(num_iters=0), + mode="constant", + cfg=dict(coef=1), + ), + bn=dict( + freeze_status=True, + freeze_affine=True, + freeze_encoder=False, + ), + data=dict( + shape=dict(h=384, w=384), + names=["moca_mask_tr"], + ), +) + +test = dict( + batch_size=__BATCHSIZE, + data=dict( + shape=dict(h=384, w=384), + names=["cad", "moca_mask_te"], + ), +) diff --git a/main_for_image.py b/main_for_image.py new file mode 100644 index 0000000..329c610 --- /dev/null +++ b/main_for_image.py @@ -0,0 +1,404 
@@ +# -*- coding: utf-8 -*- +import argparse +import datetime +import inspect +import logging +import os +import shutil +import time + +import albumentations as A +import colorlog +import cv2 +import numpy as np +import torch +import yaml +from mmengine import Config +from torch.utils import data +from tqdm import tqdm + +import methods as model_zoo +from utils import io, ops, pipeline, pt_utils, py_utils, recorder + +LOGGER = logging.getLogger("main") +LOGGER.propagate = False +LOGGER.setLevel(level=logging.DEBUG) +stream_handler = logging.StreamHandler() +stream_handler.setLevel(logging.DEBUG) +stream_handler.setFormatter(colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s")) +LOGGER.addHandler(stream_handler) + + +class ImageTestDataset(data.Dataset): + def __init__(self, dataset_info: dict, shape: dict): + super().__init__() + self.shape = shape + + image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"]) + image_suffix = dataset_info["image"]["suffix"] + mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"]) + mask_suffix = dataset_info["mask"]["suffix"] + + image_names = [p[: -len(image_suffix)] for p in sorted(os.listdir(image_path)) if p.endswith(image_suffix)] + mask_names = [p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_path)) if p.endswith(mask_suffix)] + valid_names = sorted(set(image_names).intersection(mask_names)) + self.total_data_paths = [ + (os.path.join(image_path, n) + image_suffix, os.path.join(mask_path, n) + mask_suffix) for n in valid_names + ] + + def __getitem__(self, index): + image_path, mask_path = self.total_data_paths[index] + image = io.read_color_array(image_path) + + base_h = self.shape["h"] + base_w = self.shape["w"] + + images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w) + image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1) + image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1) + image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1) + + return dict( + data={"image_s": image_s, "image_m": image_m, "image_l": image_l}, + info=dict(mask_path=mask_path, group_name="image"), + ) + + def __len__(self): + return len(self.total_data_paths) + + +class ImageTrainDataset(data.Dataset): + def __init__(self, dataset_infos: dict, shape: dict): + super().__init__() + self.shape = shape + + self.total_data_paths = [] + for dataset_name, dataset_info in dataset_infos.items(): + image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"]) + image_suffix = dataset_info["image"]["suffix"] + mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"]) + mask_suffix = dataset_info["mask"]["suffix"] + + image_names = [p[: -len(image_suffix)] for p in sorted(os.listdir(image_path)) if p.endswith(image_suffix)] + mask_names = [p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_path)) if p.endswith(mask_suffix)] + valid_names = sorted(set(image_names).intersection(mask_names)) + data_paths = [ + (os.path.join(image_path, n) + image_suffix, os.path.join(mask_path, n) + mask_suffix) + for n in valid_names + ] + LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}") + self.total_data_paths.extend(data_paths) + + self.trains = A.Compose( + [ + A.HorizontalFlip(p=0.5), + A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REPLICATE), + A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, 
val_shift_limit=10, p=0.5), + ] + ) + + def __getitem__(self, index): + image_path, mask_path = self.total_data_paths[index] + image = io.read_color_array(image_path) + mask = io.read_gray_array(mask_path, thr=0) + if image.shape[:2] != mask.shape: + h, w = mask.shape + image = ops.resize(image, height=h, width=w) + + transformed = self.trains(image=image, mask=mask) + image = transformed["image"] + mask = transformed["mask"] + + base_h = self.shape["h"] + base_w = self.shape["w"] + + images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w) + image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1) + image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1) + image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1) + + mask = ops.resize(mask, height=base_h, width=base_w) + mask = torch.from_numpy(mask).unsqueeze(0) + + return dict( + data={ + "image_s": image_s, + "image_m": image_m, + "image_l": image_l, + "mask": mask, + } + ) + + def __len__(self): + return len(self.total_data_paths) + + +class Evaluator: + def __init__(self, device, metric_names, clip_range=None): + self.device = device + self.clip_range = clip_range + self.metric_names = metric_names + + @torch.no_grad() + def eval(self, model, data_loader, save_path=""): + model.eval() + all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names) + + for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"): + batch_images = pt_utils.to_device(batch["data"], device=self.device) + logits = model(data=batch_images) # B,1,H,W + probs = logits.sigmoid().squeeze(1).cpu().detach().numpy() + probs = probs - probs.min() + probs = probs / (probs.max() + 1e-8) + + mask_paths = batch["info"]["mask_path"] + group_names = batch["info"]["group_name"] + for pred_idx, pred in enumerate(probs): + mask_path = mask_paths[pred_idx] + mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask_array[mask_array > 0] = 255 + mask_h, mask_w = mask_array.shape + pred = ops.resize(pred, height=mask_h, width=mask_w) + + if self.clip_range is not None: + pred = ops.clip_to_normalize(pred, clip_range=self.clip_range) + + group_name = group_names[pred_idx] + if save_path: # 这里的save_path包含了数据集名字 + ops.save_array_as_image( + data_array=pred, + save_name=os.path.basename(mask_path), + save_dir=os.path.join(save_path, group_name), + ) + + pred = (pred * 255).astype(np.uint8) + all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path) + return all_metrics.show() + + +def test(model, cfg): + test_wrapper = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range) + + for te_name in cfg.test.data.names: + te_info = cfg.dataset_infos[te_name] + te_dataset = ImageTestDataset(dataset_info=te_info, shape=cfg.test.data.shape) + te_loader = data.DataLoader( + dataset=te_dataset, batch_size=cfg.test.batch_size, num_workers=cfg.test.num_workers, pin_memory=True + ) + LOGGER.info(f"Testing with testset: {te_name}: {len(te_dataset)}") + + if cfg.save_results: + save_path = os.path.join(cfg.path.save, te_name) + LOGGER.info(f"Results will be saved into {save_path}") + else: + save_path = "" + + seg_results = test_wrapper.eval(model=model, data_loader=te_loader, save_path=save_path) + seg_results_str = ", ".join([f"{k}: {v:.03f}" for k, v in seg_results.items()]) + LOGGER.info(f"({te_name}): {py_utils.mapping_to_str(te_info)}\n{seg_results_str}") + + +def train(model, cfg): + tr_dataset = ImageTrainDataset( + dataset_infos={data_name: 
cfg.dataset_infos[data_name] for data_name in cfg.train.data.names}, + shape=cfg.train.data.shape, + ) + LOGGER.info(f"Total Length of Image Trainset: {len(tr_dataset)}") + + tr_loader = data.DataLoader( + dataset=tr_dataset, + batch_size=cfg.train.batch_size, + num_workers=cfg.train.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None, + ) + + counter = recorder.TrainingCounter( + epoch_length=len(tr_loader), + epoch_based=cfg.train.epoch_based, + num_epochs=cfg.train.num_epochs, + num_total_iters=cfg.train.num_iters, + ) + optimizer = pipeline.construct_optimizer( + model=model, + initial_lr=cfg.train.lr, + mode=cfg.train.optimizer.mode, + group_mode=cfg.train.optimizer.group_mode, + cfg=cfg.train.optimizer.cfg, + ) + scheduler = pipeline.Scheduler( + optimizer=optimizer, + num_iters=counter.num_total_iters, + epoch_length=counter.num_inner_iters, + scheduler_cfg=cfg.train.scheduler, + step_by_batch=cfg.train.sche_usebatch, + ) + scheduler.record_lrs(param_groups=optimizer.param_groups) + scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log) + scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none) + + LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}") + + loss_recorder = recorder.HistoryBuffer() + iter_time_recorder = recorder.HistoryBuffer() + + LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}") + if cfg.train.bn.freeze_encoder: + LOGGER.info(" >>> Freeze Backbone !!! <<< ") + model.encoder.requires_grad_(False) + + train_start_time = time.perf_counter() + for _ in range(counter.num_epochs): + LOGGER.info(f"Exp_Name: {cfg.exp_name}") + + model.train() + if cfg.train.bn.freeze_status: + pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine) + + # an epoch starts + for batch_idx, batch in enumerate(tr_loader): + iter_start_time = time.perf_counter() + scheduler.step(curr_idx=counter.curr_iter) # update learning rate + + data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device) + with torch.cuda.amp.autocast(enabled=cfg.train.use_amp): + outputs = model(data=data_batch, iter_percentage=counter.curr_percent) + + loss = outputs["loss"] + loss_str = outputs["loss_str"] + loss = loss / cfg.train.grad_acc_step + scaler.calculate_grad(loss=loss) + if counter.every_n_iters(cfg.train.grad_acc_step): # Accumulates scaled gradients. 
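+                # grad_acc_step is 1 in the provided configs, so this steps the optimizer every iteration;
+                # larger values let the scaled gradients accumulate over several batches before stepping.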
+ scaler.update_grad() + + item_loss = loss.item() + data_shape = tuple(data_batch["mask"].shape) + loss_recorder.update(value=item_loss, num=data_shape[0]) + + if cfg.log_interval > 0 and ( + counter.every_n_iters(cfg.log_interval) + or counter.is_first_inner_iter() + or counter.is_last_inner_iter() + or counter.is_last_total_iter() + ): + gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) + eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1) + eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}" + progress = ( + f"{counter.curr_iter}:{counter.num_total_iters} " + f"{batch_idx}/{counter.num_inner_iters} " + f"{counter.curr_epoch}/{counter.num_epochs}" + ) + loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})" + lr_info = f"LR: {optimizer.lr_string()}" + LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}") + cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter) + cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter) + cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter) + + if counter.curr_iter < 3: # plot some batches of the training phase + recorder.plot_results( + dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]), + save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"), + ) + + iter_time_recorder.update(value=time.perf_counter() - iter_start_time) + if counter.is_last_total_iter(): + break + counter.update_iter_counter() + + # an epoch ends + recorder.plot_results( + dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]), + save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"), + ) + io.save_weight(model=model, save_path=cfg.path.final_state_net) + counter.update_epoch_counter() + + cfg.tb_logger.close_tb() + io.save_weight(model=model, save_path=cfg.path.final_state_net) + + total_train_time = time.perf_counter() - train_start_time + total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum)) + LOGGER.info( + f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)" + ) + + +def parse_cfg(): + parser = argparse.ArgumentParser("Training and evaluation script") + parser.add_argument("--config", required=True, type=str) + parser.add_argument("--data-cfg", type=str, default="./dataset.yaml") + parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys()) + parser.add_argument("--output-dir", type=str, default="outputs") + parser.add_argument("--load-from", type=str) + parser.add_argument("--pretrained", action="store_true") + parser.add_argument( + "--metric-names", + nargs="+", + type=str, + default=["sm", "wfm", "mae", "em", "fmeasure"], + choices=recorder.GroupedMetricRecorder.supported_metrics, + ) + parser.add_argument("--evaluate", action="store_true") + parser.add_argument("--save-results", action="store_true") + parser.add_argument("--info", type=str) + args = parser.parse_args() + + cfg = Config.fromfile(args.config) + cfg.merge_from_dict(vars(args)) + + with open(cfg.data_cfg, mode="r") as f: + cfg.dataset_infos = yaml.safe_load(f) + + cfg.proj_root = os.path.dirname(os.path.abspath(__file__)) + cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg) + cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir) + cfg.path = 
py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name) + cfg.device = "cuda:0" + + py_utils.pre_mkdir(cfg.path) + with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f: + f.write(cfg.pretty_text) + shutil.copy(__file__, cfg.path.trainer_copy) + + file_handler = logging.FileHandler(cfg.path.log) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s")) + LOGGER.addHandler(file_handler) + LOGGER.info(cfg.pretty_text) + + cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb) + return cfg + + +def main(): + cfg = parse_cfg() + pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic) + + model_class = model_zoo.__dict__.get(cfg.model_name) + assert model_class is not None, "Please check your --model-name" + model_code = inspect.getsource(model_class) + model = model_class(num_frames=1, pretrained=cfg.pretrained) + LOGGER.info(model_code) + model.to(cfg.device) + + if cfg.load_from: + io.load_weight(model=model, load_path=cfg.load_from, strict=True) + + LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}") + if not cfg.evaluate: + train(model=model, cfg=cfg) + + if cfg.evaluate or cfg.has_test: + io.save_weight(model=model, save_path=cfg.path.final_state_net) + test(model=model, cfg=cfg) + + LOGGER.info("End training...") + + +if __name__ == "__main__": + main() diff --git a/main_for_video.py b/main_for_video.py new file mode 100644 index 0000000..eab29a8 --- /dev/null +++ b/main_for_video.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +import argparse +import datetime +import glob +import inspect +import logging +import os +import re +import shutil +import time + +import albumentations as A +import colorlog +import cv2 +import numpy as np +import torch +import yaml +from mmengine import Config +from torch.utils import data +from tqdm import tqdm + +import methods as model_zoo +from utils import io, ops, pipeline, pt_utils, py_utils, recorder + +LOGGER = logging.getLogger("main") +LOGGER.propagate = False +LOGGER.setLevel(level=logging.DEBUG) +stream_handler = logging.StreamHandler() +stream_handler.setLevel(logging.DEBUG) +stream_handler.setFormatter(colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s")) +LOGGER.addHandler(stream_handler) + + +def construct_frame_transform(): + return A.Compose( + [ + A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT101), + A.RandomBrightnessContrast(brightness_limit=0.02, contrast_limit=0.02, p=0.5), + ] + ) + + +def construct_video_transform(): + return A.ReplayCompose( + [ + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.5), + ] + ) + + +def get_number_from_tail(string): + return int(re.findall(pattern="\d+$", string=string)[0]) + + +class VideoTestDataset(data.Dataset): + def __init__(self, dataset_info: dict, shape: dict, num_frames: int = 1, overlap: int = 0): + super().__init__() + self.shape = shape + self.num_frames = num_frames + assert 0 <= overlap < num_frames + + image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"]) + image_suffix = dataset_info["image"]["suffix"] + mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"]) + mask_suffix = dataset_info["mask"]["suffix"] + + image_group_paths = sorted(glob.glob(image_path)) + mask_group_paths = 
sorted(glob.glob(mask_path)) + group_name_place = image_path.find("*") + + self.total_data_paths = [] + for image_group_path, mask_group_path in zip(image_group_paths, mask_group_paths): + group_name = image_group_path[group_name_place:].split("/")[0] + image_names = [ + p[: -len(image_suffix)] for p in sorted(os.listdir(image_group_path)) if p.endswith(image_suffix) + ] + mask_names = [ + p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_group_path)) if p.endswith(mask_suffix) + ] + valid_names = sorted( + set(image_names).intersection(mask_names), key=lambda item: get_number_from_tail(item) + ) + assert len(valid_names) >= num_frames, image_group_path + + i = 0 + while i < len(valid_names): + if i + num_frames > len(valid_names): + i = len(valid_names) - num_frames + clip_info = [ + ( + os.path.join(image_group_path, n) + image_suffix, + os.path.join(mask_group_path, n) + mask_suffix, + group_name, + ) + for n in valid_names[i : i + num_frames] + ] + if len(clip_info) < num_frames: + times, last = divmod(num_frames, len(clip_info)) + clip_info.extend(clip_info * (times - 1)) + clip_info.extend(clip_info[:last]) + self.total_data_paths.append(clip_info) + + i += num_frames - overlap + + def __getitem__(self, index): + clip_info = self.total_data_paths[index] + + base_h = self.shape["h"] + base_w = self.shape["w"] + image_ss = [] + image_ms = [] + image_ls = [] + mask_paths = [] + for image_path, mask_path, group_name in clip_info: + image = io.read_color_array(image_path) + mask_paths.append(mask_path) + images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w) + image_ss.append(torch.from_numpy(images[0]).div(255).permute(2, 0, 1)) + image_ms.append(torch.from_numpy(images[1]).div(255).permute(2, 0, 1)) + image_ls.append(torch.from_numpy(images[2]).div(255).permute(2, 0, 1)) + + return dict( + data={ + "image_s": torch.stack(image_ss, dim=0), + "image_m": torch.stack(image_ms, dim=0), + "image_l": torch.stack(image_ls, dim=0), + }, + info=dict(mask_path={f"frame_{i}": p for i, p in enumerate(mask_paths)}, group_name=group_name), + ) + + def __len__(self): + return len(self.total_data_paths) + + +class VideoTrainDataset(data.Dataset): + def __init__(self, dataset_infos: dict, shape: dict, num_frames: int = 1): + super().__init__() + self.shape = shape + self.num_frames = num_frames + self.stride = num_frames - 1 if num_frames > 1 else 1 + + self.total_data_paths = [] + for dataset_name, dataset_info in dataset_infos.items(): + image_path = os.path.join(dataset_info["root"], dataset_info["image"]["path"]) + image_suffix = dataset_info["image"]["suffix"] + mask_path = os.path.join(dataset_info["root"], dataset_info["mask"]["path"]) + mask_suffix = dataset_info["mask"]["suffix"] + + image_group_paths = sorted(glob.glob(image_path)) + mask_group_paths = sorted(glob.glob(mask_path)) + group_name_place = image_path.find("*") + + data_paths = [] + for image_group_path, mask_group_path in zip(image_group_paths, mask_group_paths): + group_name = image_group_path[group_name_place:].split("/")[0] + image_names = [ + p[: -len(image_suffix)] for p in sorted(os.listdir(image_group_path)) if p.endswith(image_suffix) + ] + mask_names = [ + p[: -len(mask_suffix)] for p in sorted(os.listdir(mask_group_path)) if p.endswith(mask_suffix) + ] + valid_names = sorted( + set(image_names).intersection(mask_names), key=lambda item: get_number_from_tail(item) + ) + + length_of_sequence = len(valid_names) + for clip_idx in range(0, length_of_sequence, self.stride): + start_idx = 
clip_idx * self.stride + if start_idx + num_frames > length_of_sequence: + start_idx = length_of_sequence - num_frames + + clip_info = [] + for i, n in enumerate(valid_names[start_idx : start_idx + num_frames]): + clip_info.append( + ( + os.path.join(image_group_path, n) + image_suffix, + os.path.join(mask_group_path, n) + mask_suffix, + group_name, + i, + clip_idx, + ) + ) + if len(clip_info) < num_frames: + times, last = divmod(num_frames, len(clip_info)) + clip_info.extend(clip_info * (times - 1)) + clip_info.extend(clip_info[:last]) + data_paths.append(clip_info) + + LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}") + self.total_data_paths.extend(data_paths) + + self.frame_specific_transformation = construct_frame_transform() + self.frame_share_transformation = construct_video_transform() + + def __getitem__(self, index): + base_h = self.shape["h"] + base_w = self.shape["w"] + image_ss = [] + image_ms = [] + image_ls = [] + masks = [] + for image_path, mask_path, _, idx_in_group, _ in self.total_data_paths[index]: + image = io.read_color_array(image_path) + mask = io.read_gray_array(mask_path, thr=0) + if image.shape[:2] != mask.shape: + h, w = mask.shape + image = ops.resize(image, height=h, width=w) + + if idx_in_group == 0: + shared_transformed = self.frame_share_transformation(image=image, mask=mask) + else: + shared_transformed = A.ReplayCompose.replay( + saved_augmentations=shared_transformed["replay"], image=image, mask=mask + ) + specific_transformed = self.frame_specific_transformation( + image=shared_transformed["image"], mask=shared_transformed["mask"] + ) + image = specific_transformed["image"] + mask = specific_transformed["mask"] + + images = ops.ms_resize(image, scales=(0.5, 1.0, 1.5), base_h=base_h, base_w=base_w) + mask = ops.resize(mask, height=base_h, width=base_w) + + image_ss.append(torch.from_numpy(images[0]).div(255).permute(2, 0, 1)) + image_ms.append(torch.from_numpy(images[1]).div(255).permute(2, 0, 1)) + image_ls.append(torch.from_numpy(images[2]).div(255).permute(2, 0, 1)) + masks.append(torch.from_numpy(mask).unsqueeze(0)) + + return dict( + data={ + "image_s": torch.stack(image_ss, dim=0), + "image_m": torch.stack(image_ms, dim=0), + "image_l": torch.stack(image_ls, dim=0), + "mask": torch.stack(masks, dim=0), + } + ) + + def __len__(self): + return len(self.total_data_paths) + + +class Evaluator: + def __init__(self, device, metric_names, clip_range=None): + self.device = device + self.clip_range = clip_range + self.metric_names = metric_names + + @torch.no_grad() + def eval(self, model, data_loader, save_path=""): + model.eval() + all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names) + + num_frames = data_loader.dataset.num_frames + for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"): + batch_images = pt_utils.to_device(batch["data"], device=self.device) + batch_images = {k: v.flatten(0, 1) for k, v in batch_images.items()} # B_T,C,H,W + logits = model(data=batch_images) # BT,1,H,W + probs = logits.sigmoid().squeeze(1).cpu().detach() + probs = probs - probs.min() + probs = probs / (probs.max() + 1e-8) + probs = torch.unflatten(probs, dim=0, sizes=(-1, num_frames)) + probs = probs.numpy() + + mask_paths = batch["info"]["mask_path"] + group_names = batch["info"]["group_name"] + for clip_idx, clip_pred in enumerate(probs): + for frame_idx, pred in enumerate(clip_pred): + mask_path = mask_paths[f"frame_{frame_idx}"][clip_idx] + mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + 
mask_array[mask_array > 0] = 255 + mask_h, mask_w = mask_array.shape + pred = ops.resize(pred, height=mask_h, width=mask_w) + + if self.clip_range is not None: + pred = ops.clip_to_normalize(pred, clip_range=self.clip_range) + + group_name = group_names[clip_idx] + if save_path: # 这里的save_path包含了数据集名字 + ops.save_array_as_image( + data_array=pred, + save_name=os.path.basename(mask_path), + save_dir=os.path.join(save_path, group_name), + ) + + pred = (pred * 255).astype(np.uint8) + all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path) + seg_results, group_seg_results = all_metrics.show(return_group=True) + return seg_results, group_seg_results + + +def test(model, cfg): + test_wrapper = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range) + + for te_name in cfg.test.data.names: + te_info = cfg.dataset_infos[te_name] + te_dataset = VideoTestDataset(dataset_info=te_info, shape=cfg.test.data.shape, num_frames=cfg.num_frames) + te_loader = data.DataLoader( + dataset=te_dataset, batch_size=cfg.test.batch_size, num_workers=cfg.test.num_workers, pin_memory=True + ) + LOGGER.info(f"Testing with testset: {te_name}: {len(te_dataset)}") + + if cfg.save_results: + save_path = os.path.join(cfg.path.save, te_name) + LOGGER.info(f"Results will be saved into {save_path}") + else: + save_path = "" + + seg_results, group_seg_results = test_wrapper.eval(model=model, data_loader=te_loader, save_path=save_path) + seg_results_str = ", ".join([f"{k}: {v:.03f}" for k, v in seg_results.items()]) + LOGGER.info(f"({te_name}): {py_utils.mapping_to_str(te_info)}\n{seg_results_str}") + + max_group_name_length = max([len(n) for n in group_seg_results.keys()]) + for group_name, group_results in group_seg_results.items(): + metric_str = "" + for metric_name, metric_value in group_results.items(): + metric_str += "|" + metric_name + " " + str(metric_value).ljust(6) + LOGGER.info(f"{group_name.rjust(max_group_name_length)}: {metric_str}") + + +def train(model, cfg): + tr_dataset = VideoTrainDataset( + dataset_infos={data_name: cfg.dataset_infos[data_name] for data_name in cfg.train.data.names}, + shape=cfg.train.data.shape, + num_frames=cfg.num_frames, + ) + LOGGER.info(f"Total Length of Video Trainset: {len(tr_dataset)}") + + tr_loader = data.DataLoader( + dataset=tr_dataset, + batch_size=cfg.train.batch_size, + num_workers=cfg.train.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None, + ) + + counter = recorder.TrainingCounter( + epoch_length=len(tr_loader), + epoch_based=cfg.train.epoch_based, + num_epochs=cfg.train.num_epochs, + num_total_iters=cfg.train.num_iters, + ) + optimizer = pipeline.construct_optimizer( + model=model, + initial_lr=cfg.train.lr, + mode=cfg.train.optimizer.mode, + group_mode=cfg.train.optimizer.group_mode, + cfg=cfg.train.optimizer.cfg, + ) + scheduler = pipeline.Scheduler( + optimizer=optimizer, + num_iters=counter.num_total_iters, + epoch_length=counter.num_inner_iters, + scheduler_cfg=cfg.train.scheduler, + step_by_batch=cfg.train.sche_usebatch, + ) + scheduler.record_lrs(param_groups=optimizer.param_groups) + scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log) + scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none) + + LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}") + + loss_recorder = recorder.HistoryBuffer() + iter_time_recorder = 
recorder.HistoryBuffer() + + LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}") + if cfg.train.bn.freeze_encoder: + LOGGER.info(" >>> Freeze Backbone !!! <<< ") + model.encoder.requires_grad_(False) + + train_start_time = time.perf_counter() + for _ in range(counter.num_epochs): + LOGGER.info(f"Exp_Name: {cfg.exp_name}") + + model.train() + if cfg.train.bn.freeze_status: + pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine) + + # an epoch starts + for batch_idx, batch in enumerate(tr_loader): + iter_start_time = time.perf_counter() + scheduler.step(curr_idx=counter.curr_iter) # update learning rate + + data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device) + data_batch = {k: v.flatten(0, 1) for k, v in data_batch.items()} + with torch.cuda.amp.autocast(enabled=cfg.train.use_amp): + outputs = model(data=data_batch, iter_percentage=counter.curr_percent) + + loss = outputs["loss"] + loss_str = outputs["loss_str"] + loss = loss / cfg.train.grad_acc_step + scaler.calculate_grad(loss=loss) + if counter.every_n_iters(cfg.train.grad_acc_step): # Accumulates scaled gradients. + scaler.update_grad() + + item_loss = loss.item() + data_shape = tuple(data_batch["mask"].shape) + loss_recorder.update(value=item_loss, num=data_shape[0]) + + if cfg.log_interval > 0 and ( + counter.every_n_iters(cfg.log_interval) + or counter.is_first_inner_iter() + or counter.is_last_inner_iter() + or counter.is_last_total_iter() + ): + gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB) + eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1) + eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}" + progress = ( + f"{counter.curr_iter}:{counter.num_total_iters} " + f"{batch_idx}/{counter.num_inner_iters} " + f"{counter.curr_epoch}/{counter.num_epochs}" + ) + loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})" + lr_info = f"LR: {optimizer.lr_string()}" + LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}") + cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter) + cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter) + cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter) + + if counter.curr_iter < 3: # plot some batches of the training phase + recorder.plot_results( + dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]), + save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"), + ) + + iter_time_recorder.update(value=time.perf_counter() - iter_start_time) + if counter.is_last_total_iter(): + break + counter.update_iter_counter() + + # an epoch ends + recorder.plot_results( + dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]), + save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"), + ) + io.save_weight(model=model, save_path=cfg.path.final_state_net) + counter.update_epoch_counter() + + cfg.tb_logger.close_tb() + io.save_weight(model=model, save_path=cfg.path.final_state_net) + + total_train_time = time.perf_counter() - train_start_time + total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum)) + LOGGER.info( + f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)" + ) + + +def parse_cfg(): + parser = 
argparse.ArgumentParser("Training and evaluation script") + parser.add_argument("--config", required=True, type=str) + parser.add_argument("--data-cfg", type=str, default="./dataset.yaml") + parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys()) + parser.add_argument("--output-dir", type=str, default="outputs") + parser.add_argument("--load-from", type=str) + parser.add_argument("--pretrained", action="store_true") + parser.add_argument( + "--metric-names", + nargs="+", + type=str, + default=["sm", "wfm", "mae", "em", "fmeasure", "iou", "dice"], + choices=recorder.GroupedMetricRecorder.supported_metrics, + ) + parser.add_argument("--evaluate", action="store_true") + parser.add_argument("--save-results", action="store_true") + parser.add_argument("--info", type=str) + args = parser.parse_args() + + cfg = Config.fromfile(args.config) + cfg.merge_from_dict(vars(args)) + + with open(cfg.data_cfg, mode="r") as f: + cfg.dataset_infos = yaml.safe_load(f) + + cfg.proj_root = os.path.dirname(os.path.abspath(__file__)) + cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg) + cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir) + cfg.path = py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name) + cfg.device = "cuda:0" + + py_utils.pre_mkdir(cfg.path) + with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f: + f.write(cfg.pretty_text) + shutil.copy(__file__, cfg.path.trainer_copy) + + file_handler = logging.FileHandler(cfg.path.log) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s")) + LOGGER.addHandler(file_handler) + LOGGER.info(cfg.pretty_text) + + cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb) + return cfg + + +def main(): + cfg = parse_cfg() + pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic) + + model_class = model_zoo.__dict__.get(cfg.model_name) + assert model_class is not None, "Please check your --model-name" + model_code = inspect.getsource(model_class) + model = model_class(num_frames=cfg.num_frames, pretrained=cfg.pretrained) + LOGGER.info(model_code) + model.to(cfg.device) + + if cfg.load_from: + io.load_weight(model=model, load_path=cfg.load_from, strict=True) + + LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}") + if not cfg.evaluate: + train(model=model, cfg=cfg) + + if cfg.evaluate or cfg.has_test: + io.save_weight(model=model, save_path=cfg.path.final_state_net) + test(model=model, cfg=cfg) + + LOGGER.info("End training...") + + +if __name__ == "__main__": + main() diff --git a/methods/__init__.py b/methods/__init__.py new file mode 100644 index 0000000..fa7e6ba --- /dev/null +++ b/methods/__init__.py @@ -0,0 +1,10 @@ +from .zoomnext.zoomnext import ( + EffB1_ZoomNeXt, + EffB4_ZoomNeXt, + PvtV2B2_ZoomNeXt, + PvtV2B3_ZoomNeXt, + PvtV2B4_ZoomNeXt, + PvtV2B5_ZoomNeXt, + RN50_ZoomNeXt, + videoPvtV2B5_ZoomNeXt, +) diff --git a/methods/backbone/__init__.py b/methods/backbone/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/methods/backbone/efficientnet.py b/methods/backbone/efficientnet.py new file mode 100644 index 0000000..8e60b6b --- /dev/null +++ b/methods/backbone/efficientnet.py @@ -0,0 +1,436 @@ +"""model.py - Model and module class for EfficientNet. +They are built to mirror those in the official TensorFlow implementation. 
+""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +import torch +from torch import nn +from torch.nn import functional as F + +from .efficientnet_utils import ( + MemoryEfficientSwish, + Swish, + calculate_output_image_size, + drop_connect, + efficientnet_params, + get_model_params, + get_same_padding_conv2d, + load_pretrained_weights, + round_filters, + round_repeats, +) + +VALID_MODELS = ( + "efficientnet-b0", + "efficientnet-b1", + "efficientnet-b2", + "efficientnet-b3", + "efficientnet-b4", + "efficientnet-b5", + "efficientnet-b6", + "efficientnet-b7", + "efficientnet-b8", + # Support the construction of 'efficientnet-l2' without pretrained weights + "efficientnet-l2", +) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + + Args: + block_args (namedtuple): BlockArgs, defined in utils.py. + global_params (namedtuple): GlobalParam, defined in utils.py. + image_size (tuple or list): [image_height, image_width]. + + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # whether to use skip connection and drop connect + + # Expansion phase (Inverted Bottleneck) + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, # groups makes it depthwise + kernel_size=k, + stride=s, + bias=False, + ) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + # Squeeze and Excitation layer, if desired + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Pointwise convolution phase + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + 
+ def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + + Returns: + Output of this block after processing. + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + + Example: + >>> import torch + >>> from efficientnet.model1 import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), "blocks_args should be a list" + assert len(blocks_args) > 0, "block args must be greater than 0" + self._global_params = global_params + self._blocks_args = blocks_args + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + # Update block input and output filters based on depth multiplier. 
+ block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params), + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + # if self._global_params.include_top: + # self._dropout = nn.Dropout(self._global_params.dropout_rate) + # self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + # set activation to memory efficient swish by default + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + + Args: + inputs (tensor): Input tensor. + + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. 
+ Example: + >>> import torch + >>> from efficientnet.model1 import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints["reduction_{}".format(len(endpoints) + 1)] = x + prev_x = x + + # Head + # x = self._swish(self._bn1(self._conv_head(x))) + # endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of this model after processing. + """ + # Convolution layers + x = self.extract_features(inputs) + # Pooling and final linear layer + x = self._avg_pooling(x) + # if self._global_params.include_top: + # x = x.flatten(start_dim=1) + # x = self._dropout(x) + # x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + An efficientnet model. 
+ """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained( + cls, + model_name, + pretrained=True, + weights_path=None, + advprop=False, + in_channels=3, + num_classes=1000, + **override_params, + ): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name(model_name, num_classes=num_classes, **override_params) + if pretrained: + load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=False, advprop=advprop) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + + Args: + model_name (str): Name for efficientnet. + + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + + Args: + model_name (str): Name for efficientnet. + + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/methods/backbone/efficientnet_utils.py b/methods/backbone/efficientnet_utils.py new file mode 100644 index 0000000..f59e54a --- /dev/null +++ b/methods/backbone/efficientnet_utils.py @@ -0,0 +1,616 @@ +"""utils.py - Helper functions for building the model and for loading model parameters. + These helper functions are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). 
+ +import re +import math +import collections +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + + +################################################################################ +# Help functions for model architecture +################################################################################ + +# GlobalParams and BlockArgs: Two namedtuples +# Swish and MemoryEfficientSwish: Two implementations of the method +# round_filters and round_repeats: +# Functions to calculate params for scaling model width and depth ! ! ! +# get_width_and_height_from_size and calculate_output_image_size +# drop_connect: A structural design +# get_same_padding_conv2d: +# Conv2dDynamicSamePadding +# Conv2dStaticSamePadding +# get_same_padding_maxPool2d: +# MaxPool2dDynamicSamePadding +# MaxPool2dStaticSamePadding +# It's an additional function, not used in EfficientNet, +# but can be used in other model (such as EfficientDet). + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', + 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + +# Swish activation function +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + # For compatibility with old PyTorch versions + class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +# A memory-efficient implementation of Swish function +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + # TODO: modify the params names. + # maybe the names (width_divisor,min_width) + # are more suitable than (depth_divisor,min_depth). 
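+    # Numeric trace of the rounding below (illustrative): for 'efficientnet-b4'
+    # (width_coefficient=1.4, depth_divisor=8), round_filters(32) computes
+    # 32 * 1.4 = 44.8 -> int(44.8 + 8 / 2) // 8 * 8 = 48, and 48 >= 0.9 * 44.8,
+    # so the result is 48.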
+ divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor # pay attention to this line when using min_depth + # follow the formula transferred from official TensorFlow implementation + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + # follow the formula transferred from official TensorFlow implementation + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + + Args: + x (int, tuple or list): Data size. + + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size(input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +# Note: +# The following 'SamePadding' functions make output size equal ceil(input size/stride). +# Only when stride equals 1, can the output size be the same as input size. +# Don't be confused by their function names ! ! ! + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 
+ """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + # Tips for 'SAME' mode padding. + # Given the following: + # i: width or height + # s: stride + # k: kernel size + # d: dilation + # p: padding + # Output after Conv2d: + # o = floor((i+p-((k-1)*d+1))/s+1) + # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), + # => p = (i-1)*s+((k-1)*d+1)-i + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + # With the same calculation as Conv2dDynamicSamePadding + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, + pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. 
+ """ + + def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +################################################################################ +# Helper functions for loading model params +################################################################################ + +# BlockDecoder: A Class for encoding and decoding BlockArgs +# efficientnet_params: A function to query compound coefficient +# get_model_params and efficientnet: +# Functions to get BlockArgs and GlobalParams for efficientnet +# url_map and url_map_advprop: Dicts of url_map for pretrained weights +# load_pretrained_weights: A function to load pretrained weights + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + + Returns: + BlockArgs: The namedtuple defined at the top of this file. 
+ """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + + Args: + block (namedtuple): A BlockArgs type argument. + + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + + Args: + model_name (str): Model name to be queried. + + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, + dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + + Meaning as the name suggests. + + Returns: + blocks_args, global_params. 
+ """ + + # Blocks args for the whole model(efficientnet-b0 by default) + # It will be modified in the construction of EfficientNet Class according to model + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +# train with Standard methods +# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) +url_map = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', +} + +# train with Adversarial Examples(AdvProp) +# check more details in paper(Adversarial Examples Improve Image Recognition) +url_map_advprop = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 
'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', +} + +# TODO: add the petrained weights url map of 'efficientnet-l2' + + +def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): + """Loads pretrained weights from weights path or download using url. + + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + # assert set(ret.missing_keys) == set( + # ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/methods/backbone/pvt_v2_eff.py b/methods/backbone/pvt_v2_eff.py new file mode 100644 index 0000000..6c4eb47 --- /dev/null +++ b/methods/backbone/pvt_v2_eff.py @@ -0,0 +1,552 @@ +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.vision_transformer import _cfg +from torch.backends import cuda +from torch.hub import load_state_dict_from_url +from torch.utils.checkpoint import checkpoint + + +class Mlp(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, linear=False + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.linear = linear + if self.linear: + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + 
elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + if self.linear: + x = self.relu(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1, linear=False + ): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.linear = linear + self.sr_ratio = sr_ratio + if not linear: + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.pool = nn.AdaptiveAvgPool2d(7) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = nn.LayerNorm(dim) + self.act = nn.GELU() + self.apply(self._init_weights) + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + if device_properties.major == 8 and device_properties.minor == 0: + # print("A100 GPU detected, using flash attention if input tensor is on cuda") + self.cuda_config = {"enable_flash": True, "enable_math": False, "enable_mem_efficient": False} + else: + # print("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda") + self.cuda_config = {"enable_flash": False, "enable_math": True, "enable_mem_efficient": True} + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if not self.linear: + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + x_ = self.act(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + # attn = (q @ k.transpose(-2, -1)) * self.scale + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + with cuda.sdp_kernel(**self.cuda_config): + # q: bs,nh,l,hd, k: bs,nh,s,hd, v: 
bs,nh,s,hd + # same as: (((q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5).softmax(dim=-1) @ v) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) # built-in scale + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + linear=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + linear=linear, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert max(patch_size) > stride, "Set larger patch_size than stride" + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // stride, img_size[1] // stride + self.num_patches = self.H * self.W + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class PyramidVisionTransformerV2(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + 
depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + num_stages=4, + linear=False, + use_checkpoint=False, + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.num_stages = num_stages + self.embed_dims = embed_dims + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + patch_embed = OverlapPatchEmbed( + img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_chans=in_chans if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + ) + + block = nn.ModuleList( + [ + Block( + dim=embed_dims[i], + num_heads=num_heads[i], + mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + j], + norm_layer=norm_layer, + sr_ratio=sr_ratios[i], + linear=linear, + ) + for j in range(depths[i]) + ] + ) + norm = norm_layer(embed_dims[i]) + cur += depths[i] + + setattr(self, f"patch_embed{i + 1}", patch_embed) + setattr(self, f"block{i + 1}", block) + setattr(self, f"norm{i + 1}", norm) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed1", "pos_embed2", "pos_embed3", "pos_embed4", "cls_token"} # has pos_embed may be better + + def extract_endpoints(self, x): + B = x.shape[0] + endpoints = dict() + for i in range(self.num_stages): + patch_embed = getattr(self, f"patch_embed{i + 1}") + block = getattr(self, f"block{i + 1}") + norm = getattr(self, f"norm{i + 1}") + x, H, W = patch_embed(x) + for blk in block: + if self.use_checkpoint: + x = checkpoint(blk, x, H, W) + else: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + # print(i + 2, x.shape) + endpoints["reduction_{}".format(i + 2)] = x + return endpoints + + def forward(self, x): + endpoints = self.extract_endpoints(x) + return endpoints + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + + +def _conv_filter(state_dict, patch_size=16): + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + + return out_dict + + +def pvt_v2_eff_b0(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + 
qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + return model + + +def pvt_v2_eff_b1(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + return model + + +def pvt_v2_eff_b2(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + return model + + +def pvt_v2_eff_b3(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + + return model + + +def pvt_v2_eff_b4(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + return model + + +def pvt_v2_eff_b5(pretrained=False, **kwargs): + model = PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + **kwargs, + ) + model.default_cfg = _cfg() + if pretrained: + state_dict = load_state_dict_from_url( + "https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth", progress=True + ) + state_dict.pop("head.weight") + state_dict.pop("head.bias") + model.load_state_dict(state_dict) + return model + + +def pvt_v2_eff_b2_li(pretrained=False, **kwargs): + model = 
PyramidVisionTransformerV2( + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + linear=True, + **kwargs, + ) + model.default_cfg = _cfg() + + return model diff --git a/methods/zoomnext/__init__.py b/methods/zoomnext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/methods/zoomnext/layers.py b/methods/zoomnext/layers.py new file mode 100644 index 0000000..c232ebe --- /dev/null +++ b/methods/zoomnext/layers.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .ops import ConvBNReLU, resize_to + + +class SimpleASPP(nn.Module): + def __init__(self, in_dim, out_dim, dilation=3): + """A simple ASPP variant. + + Args: + in_dim (int): Input channels. + out_dim (int): Output channels. + dilation (int, optional): Dilation of the convolution operation. Defaults to 3. + """ + super().__init__() + self.conv1x1_1 = ConvBNReLU(in_dim, 2 * out_dim, 1) + self.conv1x1_2 = ConvBNReLU(out_dim, out_dim, 1) + self.conv3x3_1 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation) + self.conv3x3_2 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation) + self.conv3x3_3 = ConvBNReLU(out_dim, out_dim, 3, dilation=dilation, padding=dilation) + self.fuse = nn.Sequential(ConvBNReLU(5 * out_dim, out_dim, 1), ConvBNReLU(out_dim, out_dim, 3, 1, 1)) + + def forward(self, x): + y = self.conv1x1_1(x) + y1, y5 = y.chunk(2, dim=1) + + # dilation branch + y2 = self.conv3x3_1(y1) + y3 = self.conv3x3_2(y2) + y4 = self.conv3x3_3(y3) + + # global branch + y0 = torch.mean(y5, dim=(2, 3), keepdim=True) + y0 = self.conv1x1_2(y0) + y0 = resize_to(y0, tgt_hw=x.shape[-2:]) + return self.fuse(torch.cat([y0, y1, y2, y3, y4], dim=1)) + + +class DifferenceAwareOps(nn.Module): + def __init__(self, num_frames): + super().__init__() + self.num_frames = num_frames + + self.temperal_proj_norm = nn.LayerNorm(num_frames, elementwise_affine=False) + self.temperal_proj_kv = nn.Linear(num_frames, 2 * num_frames, bias=False) + self.temperal_proj = nn.Sequential( + nn.Conv2d(num_frames, num_frames, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(num_frames, num_frames, 3, 1, 1, bias=False), + ) + for t in self.parameters(): + nn.init.zeros_(t) + + def forward(self, x): + if self.num_frames == 1: + return x + + unshifted_x_tmp = rearrange(x, "(b t) c h w -> b c h w t", t=self.num_frames) + B, C, H, W, T = unshifted_x_tmp.shape + shifted_x_tmp = torch.roll(unshifted_x_tmp, shifts=1, dims=-1) + diff_q = shifted_x_tmp - unshifted_x_tmp # B,C,H,W,T + diff_q = self.temperal_proj_norm(diff_q) # normalization along the time + + # merge all channels + diff_k, diff_v = self.temperal_proj_kv(diff_q).chunk(2, dim=-1) + diff_qk = torch.einsum("bxhwt, byhwt -> bxyt", diff_q, diff_k) * (H * W) ** -0.5 + temperal_diff = torch.einsum("bxyt, byhwt -> bxhwt", diff_qk.softmax(dim=2), diff_v) + + temperal_diff = rearrange(temperal_diff, "b c h w t -> (b c) t h w") + shifted_x_tmp = self.temperal_proj(temperal_diff) # combine different time step + shifted_x_tmp = rearrange(shifted_x_tmp, "(b c) t h w -> (b t) c h w", c=x.shape[1]) + return x + shifted_x_tmp + + +class RGPU(nn.Module): + def __init__(self, in_c, num_groups=6, hidden_dim=None, num_frames=1): + super().__init__() + self.num_groups = num_groups + + hidden_dim = hidden_dim or in_c // 2 + expand_dim = hidden_dim * num_groups 
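+        # Group-wise cascade built below (see forward): the expanded features are split
+        # into `num_groups` chunks; group 0 emits (out, fork, gate), every middle group
+        # consumes (its chunk, previous fork) and emits (out, fork, gate) again, and the
+        # last group emits only (out, gate). The pooled gates weight the concatenated
+        # outs, which are then fused back to `in_c` channels and added to the input.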
+ self.expand_conv = ConvBNReLU(in_c, expand_dim, 1) + self.gate_genator = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(num_groups * hidden_dim, hidden_dim, 1), + nn.ReLU(True), + nn.Conv2d(hidden_dim, num_groups * hidden_dim, 1), + nn.Softmax(dim=1), + ) + + self.interact = nn.ModuleDict() + self.interact["0"] = ConvBNReLU(hidden_dim, 3 * hidden_dim, 3, 1, 1) + for group_id in range(1, num_groups - 1): + self.interact[str(group_id)] = ConvBNReLU(2 * hidden_dim, 3 * hidden_dim, 3, 1, 1) + self.interact[str(num_groups - 1)] = ConvBNReLU(2 * hidden_dim, 2 * hidden_dim, 3, 1, 1) + + self.fuse = nn.Sequential( + DifferenceAwareOps(num_frames=num_frames), + ConvBNReLU(num_groups * hidden_dim, in_c, 3, 1, 1, act_name=None), + ) + self.final_relu = nn.ReLU(True) + + def forward(self, x): + xs = self.expand_conv(x).chunk(self.num_groups, dim=1) + + outs = [] + gates = [] + + group_id = 0 + curr_x = xs[group_id] + branch_out = self.interact[str(group_id)](curr_x) + curr_out, curr_fork, curr_gate = branch_out.chunk(3, dim=1) + outs.append(curr_out) + gates.append(curr_gate) + + for group_id in range(1, self.num_groups - 1): + curr_x = torch.cat([xs[group_id], curr_fork], dim=1) + branch_out = self.interact[str(group_id)](curr_x) + curr_out, curr_fork, curr_gate = branch_out.chunk(3, dim=1) + outs.append(curr_out) + gates.append(curr_gate) + + group_id = self.num_groups - 1 + curr_x = torch.cat([xs[group_id], curr_fork], dim=1) + branch_out = self.interact[str(group_id)](curr_x) + curr_out, curr_gate = branch_out.chunk(2, dim=1) + outs.append(curr_out) + gates.append(curr_gate) + + out = torch.cat(outs, dim=1) + gate = self.gate_genator(torch.cat(gates, dim=1)) + out = self.fuse(out * gate) + return self.final_relu(out + x) + + +class MHSIU(nn.Module): + def __init__(self, in_dim, num_groups=4): + super().__init__() + self.conv_l_pre = ConvBNReLU(in_dim, in_dim, 3, 1, 1) + self.conv_s_pre = ConvBNReLU(in_dim, in_dim, 3, 1, 1) + + self.conv_l = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch + self.conv_m = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch + self.conv_s = ConvBNReLU(in_dim, in_dim, 3, 1, 1) # intra-branch + + self.conv_lms = ConvBNReLU(3 * in_dim, 3 * in_dim, 1) # inter-branch + self.initial_merge = ConvBNReLU(3 * in_dim, 3 * in_dim, 1) # inter-branch + + self.num_groups = num_groups + self.trans = nn.Sequential( + ConvBNReLU(3 * in_dim // num_groups, in_dim // num_groups, 1), + ConvBNReLU(in_dim // num_groups, in_dim // num_groups, 3, 1, 1), + nn.Conv2d(in_dim // num_groups, 3, 1), + nn.Softmax(dim=1), + ) + + def forward(self, l, m, s): + tgt_size = m.shape[2:] + + l = self.conv_l_pre(l) + l = F.adaptive_max_pool2d(l, tgt_size) + F.adaptive_avg_pool2d(l, tgt_size) + s = self.conv_s_pre(s) + s = resize_to(s, tgt_hw=m.shape[2:]) + + l = self.conv_l(l) + m = self.conv_m(m) + s = self.conv_s(s) + lms = torch.cat([l, m, s], dim=1) # BT,3C,H,W + + attn = self.conv_lms(lms) # BT,3C,H,W + attn = rearrange(attn, "bt (nb ng d) h w -> (bt ng) (nb d) h w", nb=3, ng=self.num_groups) + attn = self.trans(attn) # BTG,3,H,W + attn = attn.unsqueeze(dim=2) # BTG,3,1,H,W + + x = self.initial_merge(lms) + x = rearrange(x, "bt (nb ng d) h w -> (bt ng) nb d h w", nb=3, ng=self.num_groups) + x = (attn * x).sum(dim=1) + x = rearrange(x, "(bt ng) d h w -> bt (ng d) h w", ng=self.num_groups) + return x diff --git a/methods/zoomnext/ops.py b/methods/zoomnext/ops.py new file mode 100644 index 0000000..d97cbf1 --- /dev/null +++ b/methods/zoomnext/ops.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 
-*- +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import to_2tuple + + +def rescale_2x(x: torch.Tensor, scale_factor=2): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def resize_to(x: torch.Tensor, tgt_hw: tuple): + return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False) + + +def global_avgpool(x: torch.Tensor): + return x.mean((-1, -2), keepdim=True) + + +def _get_act_fn(act_name, inplace=True): + if act_name == "relu": + return nn.ReLU(inplace=inplace) + elif act_name == "leaklyrelu": + return nn.LeakyReLU(negative_slope=0.1, inplace=inplace) + elif act_name == "gelu": + return nn.GELU() + elif act_name == "sigmoid": + return nn.Sigmoid() + else: + raise NotImplementedError + + +class ConvBN(nn.Module): + def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, g=1, bias=True): + super(ConvBN, self).__init__() + self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p, dilation=d, groups=g, bias=bias) + self.bn = nn.BatchNorm2d(out_dim) + + def forward(self, x): + return self.bn(self.conv(x)) + + +class CBR(nn.Module): + def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, bias=True): + super().__init__() + self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p, dilation=d, bias=bias) + self.bn = nn.BatchNorm2d(out_dim) + self.relu = nn.ReLU(True) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class ConvBNReLU(nn.Sequential): + def __init__( + self, + in_planes, + out_planes, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + act_name="relu", + is_transposed=False, + ): + """ + Convolution-BatchNormalization-ActivationLayer + + :param in_planes: + :param out_planes: + :param kernel_size: + :param stride: + :param padding: + :param dilation: + :param groups: + :param bias: + :param act_name: None denote it doesn't use the activation layer. 
+        :param is_transposed: True -> nn.ConvTranspose2d, False -> nn.Conv2d
+        """
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+
+        if is_transposed:
+            conv_module = nn.ConvTranspose2d
+        else:
+            conv_module = nn.Conv2d
+        self.add_module(
+            name="conv",
+            module=conv_module(
+                in_planes,
+                out_planes,
+                kernel_size=kernel_size,
+                stride=to_2tuple(stride),
+                padding=to_2tuple(padding),
+                dilation=to_2tuple(dilation),
+                groups=groups,
+                bias=bias,
+            ),
+        )
+        self.add_module(name="bn", module=nn.BatchNorm2d(out_planes))
+        if act_name is not None:
+            self.add_module(name=act_name, module=_get_act_fn(act_name=act_name))
+
+
+class ConvGNReLU(nn.Sequential):
+    def __init__(
+        self,
+        in_planes,
+        out_planes,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        gn_groups=8,
+        bias=False,
+        act_name="relu",
+        inplace=True,
+    ):
+        """
+        Execution flow: Conv2d => GroupNormalization [=> Activation]
+
+        Args:
+            in_planes: number of input channels of the module
+            out_planes: number of output channels of the module
+            kernel_size: kernel size of the internal convolution
+            stride: stride of the convolution
+            padding: padding of the convolution
+            dilation: dilation rate of the convolution
+            groups: number of convolution groups; must satisfy PyTorch's own constraints
+            gn_groups: number of groups used by GroupNorm, 8 by default
+            bias: whether to enable the convolution bias, False by default
+            act_name: activation function to use, "relu" by default; if set to None, no activation layer is used
+            inplace: the inplace flag passed to the activation function
+        """
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+
+        self.add_module(
+            name="conv",
+            module=nn.Conv2d(
+                in_planes,
+                out_planes,
+                kernel_size=kernel_size,
+                stride=to_2tuple(stride),
+                padding=to_2tuple(padding),
+                dilation=to_2tuple(dilation),
+                groups=groups,
+                bias=bias,
+            ),
+        )
+        self.add_module(name="gn", module=nn.GroupNorm(num_groups=gn_groups, num_channels=out_planes))
+        if act_name is not None:
+            self.add_module(name=act_name, module=_get_act_fn(act_name=act_name, inplace=inplace))
+
+
+class PixelNormalizer(nn.Module):
+    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        """Divide pixel values by 255 = 2**8 - 1, subtract mean per channel and divide by std per channel.
+
+        Args:
+            mean (tuple, optional): the mean value. Defaults to (0.485, 0.456, 0.406).
+            std (tuple, optional): the std value. Defaults to (0.229, 0.224, 0.225).
+ """ + super().__init__() + # self.norm = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + self.register_buffer(name="mean", tensor=torch.Tensor(mean).reshape(3, 1, 1)) + self.register_buffer(name="std", tensor=torch.Tensor(std).reshape(3, 1, 1)) + + def __repr__(self): + return self.__class__.__name__ + f"(mean={self.mean.flatten()}, std={self.std.flatten()})" + + def forward(self, x): + """normalize x by the mean and std values + + Args: + x (torch.Tensor): input tensor + + Returns: + torch.Tensor: output tensor + + Albumentations: + + ``` + mean = np.array(mean, dtype=np.float32) + mean *= max_pixel_value + std = np.array(std, dtype=np.float32) + std *= max_pixel_value + denominator = np.reciprocal(std, dtype=np.float32) + + img = img.astype(np.float32) + img -= mean + img *= denominator + ``` + """ + x = x.sub(self.mean) + x = x.div(self.std) + return x + + +class LayerNorm2d(nn.Module): + """ + From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py + Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 + """ + + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/methods/zoomnext/zoomnext.py b/methods/zoomnext/zoomnext.py new file mode 100644 index 0000000..f048665 --- /dev/null +++ b/methods/zoomnext/zoomnext.py @@ -0,0 +1,376 @@ +import abc +import logging + +import numpy as np +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..backbone.efficientnet import EfficientNet +from ..backbone.pvt_v2_eff import pvt_v2_eff_b2, pvt_v2_eff_b3, pvt_v2_eff_b4, pvt_v2_eff_b5 +from .layers import MHSIU, RGPU, SimpleASPP +from .ops import ConvBNReLU, PixelNormalizer, resize_to + +LOGGER = logging.getLogger("main") + + +class _ZoomNeXt_Base(nn.Module): + @staticmethod + def get_coef(iter_percentage=1, method="cos", milestones=(0, 1)): + min_point, max_point = min(milestones), max(milestones) + min_coef, max_coef = 0, 1 + + ual_coef = 1.0 + if iter_percentage < min_point: + ual_coef = min_coef + elif iter_percentage > max_point: + ual_coef = max_coef + else: + if method == "linear": + ratio = (max_coef - min_coef) / (max_point - min_point) + ual_coef = ratio * (iter_percentage - min_point) + elif method == "cos": + perc = (iter_percentage - min_point) / (max_point - min_point) + normalized_coef = (1 - np.cos(perc * np.pi)) / 2 + ual_coef = normalized_coef * (max_coef - min_coef) + min_coef + return ual_coef + + @abc.abstractmethod + def body(self): + pass + + def forward(self, data, iter_percentage=1, **kwargs): + logits = self.body(data=data) + + if self.training: + mask = data["mask"] + prob = logits.sigmoid() + + losses = [] + loss_str = [] + + sod_loss = F.binary_cross_entropy_with_logits(input=logits, target=mask, reduction="mean") + losses.append(sod_loss) + loss_str.append(f"bce: {sod_loss.item():.5f}") + + ual_coef = self.get_coef(iter_percentage=iter_percentage, method="cos", milestones=(0, 1)) + ual_loss = ual_coef * (1 - (2 * prob - 1).abs().pow(2)).mean() + losses.append(ual_loss) + 
loss_str.append(f"powual_{ual_coef:.5f}: {ual_loss.item():.5f}") + return dict(vis=dict(sal=prob), loss=sum(losses), loss_str=" ".join(loss_str)) + else: + return logits + + def get_grouped_params(self): + param_groups = {"pretrained": [], "fixed": [], "retrained": []} + for name, param in self.named_parameters(): + if name.startswith("encoder.patch_embed1."): + param.requires_grad = False + param_groups["fixed"].append(param) + elif name.startswith("encoder."): + param_groups["pretrained"].append(param) + else: + if "clip." in name: + param.requires_grad = False + param_groups["fixed"].append(param) + else: + param_groups["retrained"].append(param) + LOGGER.info( + f"Parameter Groups:{{" + f"Pretrained: {len(param_groups['pretrained'])}, " + f"Fixed: {len(param_groups['fixed'])}, " + f"ReTrained: {len(param_groups['retrained'])}}}" + ) + return param_groups + + +class RN50_ZoomNeXt(_ZoomNeXt_Base): + def __init__(self, pretrained=True, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6): + super().__init__() + self.encoder = timm.create_model( + model_name="resnet50", features_only=True, out_indices=range(5), pretrained=False + ) + if pretrained: + params = torch.hub.load_state_dict_from_url( + url="https://github.com/lartpang/ZoomNeXt/releases/download/weights-v0.1/resnet50-timm.pth", + map_location="cpu", + ) + self.encoder.load_state_dict(params, strict=False) + + self.tra_5 = SimpleASPP(in_dim=2048, out_dim=mid_dim) + self.siu_5 = MHSIU(mid_dim, siu_groups) + self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_4 = ConvBNReLU(1024, mid_dim, 3, 1, 1) + self.siu_4 = MHSIU(mid_dim, siu_groups) + self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_3 = ConvBNReLU(512, mid_dim, 3, 1, 1) + self.siu_3 = MHSIU(mid_dim, siu_groups) + self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_2 = ConvBNReLU(256, mid_dim, 3, 1, 1) + self.siu_2 = MHSIU(mid_dim, siu_groups) + self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_1 = ConvBNReLU(64, mid_dim, 3, 1, 1) + self.siu_1 = MHSIU(mid_dim, siu_groups) + self.hmu_1 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.normalizer = PixelNormalizer() if input_norm else nn.Identity() + self.predictor = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + ConvBNReLU(64, 32, 3, 1, 1), + nn.Conv2d(32, 1, 1), + ) + + def normalize_encoder(self, x): + x = self.normalizer(x) + c1, c2, c3, c4, c5 = self.encoder(x) + return c1, c2, c3, c4, c5 + + def body(self, data): + l_trans_feats = self.normalize_encoder(data["image_l"]) + m_trans_feats = self.normalize_encoder(data["image_m"]) + s_trans_feats = self.normalize_encoder(data["image_s"]) + + l, m, s = ( + self.tra_5(l_trans_feats[4]), + self.tra_5(m_trans_feats[4]), + self.tra_5(s_trans_feats[4]), + ) + lms = self.siu_5(l=l, m=m, s=s) + x = self.hmu_5(lms) + + l, m, s = ( + self.tra_4(l_trans_feats[3]), + self.tra_4(m_trans_feats[3]), + self.tra_4(s_trans_feats[3]), + ) + lms = self.siu_4(l=l, m=m, s=s) + x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = ( + self.tra_3(l_trans_feats[2]), + self.tra_3(m_trans_feats[2]), + self.tra_3(s_trans_feats[2]), + ) + lms = self.siu_3(l=l, m=m, s=s) + x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = ( + self.tra_2(l_trans_feats[1]), + self.tra_2(m_trans_feats[1]), + self.tra_2(s_trans_feats[1]), + ) + lms = self.siu_2(l=l, m=m, s=s) + x = self.hmu_2(lms + resize_to(x, 
tgt_hw=lms.shape[-2:])) + + l, m, s = ( + self.tra_1(l_trans_feats[0]), + self.tra_1(m_trans_feats[0]), + self.tra_1(s_trans_feats[0]), + ) + lms = self.siu_1(l=l, m=m, s=s) + x = self.hmu_1(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + return self.predictor(x) + + +class PvtV2B2_ZoomNeXt(_ZoomNeXt_Base): + def __init__( + self, + pretrained=True, + num_frames=1, + input_norm=True, + mid_dim=64, + siu_groups=4, + hmu_groups=6, + use_checkpoint=False, + ): + super().__init__() + self.set_backbone(pretrained=pretrained, use_checkpoint=use_checkpoint) + + self.embed_dims = self.encoder.embed_dims + self.tra_5 = SimpleASPP(self.embed_dims[3], out_dim=mid_dim) + self.siu_5 = MHSIU(mid_dim, siu_groups) + self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_4 = ConvBNReLU(self.embed_dims[2], mid_dim, 3, 1, 1) + self.siu_4 = MHSIU(mid_dim, siu_groups) + self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_3 = ConvBNReLU(self.embed_dims[1], mid_dim, 3, 1, 1) + self.siu_3 = MHSIU(mid_dim, siu_groups) + self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_2 = ConvBNReLU(self.embed_dims[0], mid_dim, 3, 1, 1) + self.siu_2 = MHSIU(mid_dim, siu_groups) + self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), ConvBNReLU(64, mid_dim, 3, 1, 1) + ) + + self.normalizer = PixelNormalizer() if input_norm else nn.Identity() + self.predictor = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + ConvBNReLU(64, 32, 3, 1, 1), + nn.Conv2d(32, 1, 1), + ) + + def set_backbone(self, pretrained: bool, use_checkpoint: bool): + self.encoder = pvt_v2_eff_b2(pretrained=pretrained, use_checkpoint=use_checkpoint) + + def normalize_encoder(self, x): + x = self.normalizer(x) + features = self.encoder(x) + c2 = features["reduction_2"] + c3 = features["reduction_3"] + c4 = features["reduction_4"] + c5 = features["reduction_5"] + return c2, c3, c4, c5 + + def body(self, data): + l_trans_feats = self.normalize_encoder(data["image_l"]) + m_trans_feats = self.normalize_encoder(data["image_m"]) + s_trans_feats = self.normalize_encoder(data["image_s"]) + + l, m, s = self.tra_5(l_trans_feats[3]), self.tra_5(m_trans_feats[3]), self.tra_5(s_trans_feats[3]) + lms = self.siu_5(l=l, m=m, s=s) + x = self.hmu_5(lms) + + l, m, s = self.tra_4(l_trans_feats[2]), self.tra_4(m_trans_feats[2]), self.tra_4(s_trans_feats[2]) + lms = self.siu_4(l=l, m=m, s=s) + x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = self.tra_3(l_trans_feats[1]), self.tra_3(m_trans_feats[1]), self.tra_3(s_trans_feats[1]) + lms = self.siu_3(l=l, m=m, s=s) + x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = self.tra_2(l_trans_feats[0]), self.tra_2(m_trans_feats[0]), self.tra_2(s_trans_feats[0]) + lms = self.siu_2(l=l, m=m, s=s) + x = self.hmu_2(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + x = self.tra_1(x) + return self.predictor(x) + + +class PvtV2B3_ZoomNeXt(PvtV2B2_ZoomNeXt): + def set_backbone(self, pretrained: bool, use_checkpoint: bool): + self.encoder = pvt_v2_eff_b3(pretrained=pretrained, use_checkpoint=use_checkpoint) + + +class PvtV2B4_ZoomNeXt(PvtV2B2_ZoomNeXt): + def set_backbone(self, pretrained: bool, use_checkpoint: bool): + self.encoder = pvt_v2_eff_b4(pretrained=pretrained, use_checkpoint=use_checkpoint) + + +class PvtV2B5_ZoomNeXt(PvtV2B2_ZoomNeXt): + def set_backbone(self, pretrained: bool, use_checkpoint: 
bool): + self.encoder = pvt_v2_eff_b5(pretrained=pretrained, use_checkpoint=use_checkpoint) + + +class videoPvtV2B5_ZoomNeXt(PvtV2B5_ZoomNeXt): + def get_grouped_params(self): + param_groups = {"pretrained": [], "fixed": [], "retrained": []} + for name, param in self.named_parameters(): + if name.startswith("encoder.patch_embed1."): + param.requires_grad = False + param_groups["fixed"].append(param) + elif name.startswith("encoder."): + param_groups["pretrained"].append(param) + else: + if "temperal_proj" in name: + param_groups["retrained"].append(param) + else: + param_groups["pretrained"].append(param) + + LOGGER.info( + f"Parameter Groups:{{" + f"Pretrained: {len(param_groups['pretrained'])}, " + f"Fixed: {len(param_groups['fixed'])}, " + f"ReTrained: {len(param_groups['retrained'])}}}" + ) + return param_groups + + +class EffB1_ZoomNeXt(_ZoomNeXt_Base): + def __init__(self, pretrained, num_frames=1, input_norm=True, mid_dim=64, siu_groups=4, hmu_groups=6): + super().__init__() + self.set_backbone(pretrained) + + self.tra_5 = SimpleASPP(self.embed_dims[4], out_dim=mid_dim) + self.siu_5 = MHSIU(mid_dim, siu_groups) + self.hmu_5 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_4 = ConvBNReLU(self.embed_dims[3], mid_dim, 3, 1, 1) + self.siu_4 = MHSIU(mid_dim, siu_groups) + self.hmu_4 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_3 = ConvBNReLU(self.embed_dims[2], mid_dim, 3, 1, 1) + self.siu_3 = MHSIU(mid_dim, siu_groups) + self.hmu_3 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_2 = ConvBNReLU(self.embed_dims[1], mid_dim, 3, 1, 1) + self.siu_2 = MHSIU(mid_dim, siu_groups) + self.hmu_2 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.tra_1 = ConvBNReLU(self.embed_dims[0], mid_dim, 3, 1, 1) + self.siu_1 = MHSIU(mid_dim, siu_groups) + self.hmu_1 = RGPU(mid_dim, hmu_groups, num_frames=num_frames) + + self.normalizer = PixelNormalizer() if input_norm else nn.Identity() + self.predictor = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + ConvBNReLU(64, 32, 3, 1, 1), + nn.Conv2d(32, 1, 1), + ) + + def set_backbone(self, pretrained): + self.encoder = EfficientNet.from_pretrained("efficientnet-b1", pretrained=pretrained) + self.embed_dims = [16, 24, 40, 112, 320] + + def normalize_encoder(self, x): + x = self.normalizer(x) + features = self.encoder.extract_endpoints(x) + c1 = features["reduction_1"] + c2 = features["reduction_2"] + c3 = features["reduction_3"] + c4 = features["reduction_4"] + c5 = features["reduction_5"] + return c1, c2, c3, c4, c5 + + def body(self, data): + l_trans_feats = self.normalize_encoder(data["image_l"]) + m_trans_feats = self.normalize_encoder(data["image_m"]) + s_trans_feats = self.normalize_encoder(data["image_s"]) + + l, m, s = self.tra_5(l_trans_feats[4]), self.tra_5(m_trans_feats[4]), self.tra_5(s_trans_feats[4]) + lms = self.siu_5(l=l, m=m, s=s) + x = self.hmu_5(lms) + + l, m, s = self.tra_4(l_trans_feats[3]), self.tra_4(m_trans_feats[3]), self.tra_4(s_trans_feats[3]) + lms = self.siu_4(l=l, m=m, s=s) + x = self.hmu_4(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = self.tra_3(l_trans_feats[2]), self.tra_3(m_trans_feats[2]), self.tra_3(s_trans_feats[2]) + lms = self.siu_3(l=l, m=m, s=s) + x = self.hmu_3(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = self.tra_2(l_trans_feats[1]), self.tra_2(m_trans_feats[1]), self.tra_2(s_trans_feats[1]) + lms = self.siu_2(l=l, m=m, s=s) + x = self.hmu_2(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + l, m, s = 
self.tra_1(l_trans_feats[0]), self.tra_1(m_trans_feats[0]), self.tra_1(s_trans_feats[0]) + lms = self.siu_1(l=l, m=m, s=s) + x = self.hmu_1(lms + resize_to(x, tgt_hw=lms.shape[-2:])) + + return self.predictor(x) + + +class EffB4_ZoomNeXt(EffB1_ZoomNeXt): + def set_backbone(self, pretrained): + self.encoder = EfficientNet.from_pretrained("efficientnet-b4", pretrained=pretrained) + self.embed_dims = [24, 32, 56, 160, 448] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..536c539 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,72 @@ +# https://github.com/LongTengDao/TOML/ + +[tool.isort] +# https://pycqa.github.io/isort/docs/configuration/options/ +profile = "black" +multi_line_output = 3 +filter_files = true +supported_extensions = "py" + +[tool.black] +line-length = 119 +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.idea + | \.vscode + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | output +)/ +''' + +[tool.ruff] +# Same as Black. +line-length = 119 +indent-width = 4 +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..700894b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +# Automatically generated by https://github.com/damnever/pigar. + +adjustText==0.8 +albumentations==1.3.1 +colorlog==6.8.0 +einops==0.7.0 +matplotlib==3.8.2 +mmengine==0.10.2 +numpy==1.26.2 +opencv-python-headless==4.8.1.78 +pysodmetrics==1.4.2 +PyYAML==6.0 +scipy==1.11.4 +timm==0.9.12 +tqdm==4.66.1 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/io/__init__.py b/utils/io/__init__.py new file mode 100644 index 0000000..7b33111 --- /dev/null +++ b/utils/io/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/5/17 +# @Author : Lart Pang +# @GitHub : https://github.com/lartpang + +from .image import read_color_array, read_gray_array +from .params import load_weight, save_weight diff --git a/utils/io/image.py b/utils/io/image.py new file mode 100644 index 0000000..2332ab4 --- /dev/null +++ b/utils/io/image.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/5/17 +# @Author : Lart Pang +# @GitHub : https://github.com/lartpang +import cv2 +import numpy as np + +from utils.ops import minmax + + +def read_gray_array(path, div_255=False, to_normalize=False, thr=-1, dtype=np.float32) -> np.ndarray: + """ + 1. read the binary image with the suffix `.jpg` or `.png` + into a grayscale ndarray + 2. (to_normalize=True) rescale the ndarray to [0, 1] + 3. (thr >= 0) binarize the ndarray with `thr` + 4. 
return a gray ndarray with the given dtype (np.float32 by default)
+    """
+    assert path.endswith(".jpg") or path.endswith(".png"), path
+    assert not div_255 or not to_normalize, path
+    gray_array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+    assert gray_array is not None, f"Image Not Found: {path}"
+
+    if div_255:
+        gray_array = gray_array / 255
+
+    if to_normalize:
+        gray_array = minmax(gray_array, up_bound=255)
+
+    if thr >= 0:
+        gray_array = gray_array > thr
+
+    return gray_array.astype(dtype)
+
+
+def read_color_array(path: str):
+    assert path.endswith(".jpg") or path.endswith(".png")
+    bgr_array = cv2.imread(path, cv2.IMREAD_COLOR)
+    assert bgr_array is not None, f"Image Not Found: {path}"
+    rgb_array = cv2.cvtColor(bgr_array, cv2.COLOR_BGR2RGB)
+    return rgb_array
diff --git a/utils/io/params.py b/utils/io/params.py
new file mode 100644
index 0000000..65aabc1
--- /dev/null
+++ b/utils/io/params.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+import os
+
+import torch
+
+
+def save_weight(save_path, model):
+    print(f"Saving weight '{save_path}'")
+    if isinstance(model, dict):
+        model_state = model
+    else:
+        model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
+    torch.save(model_state, save_path)
+    print(f"Saved weight '{save_path}' " f"(only contains the net's weights)")
+
+
+def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
+    assert os.path.exists(load_path), load_path
+    model_params = model.state_dict()
+    for k, v in torch.load(load_path, map_location="cpu").items():
+        if k.startswith("module."):  # strip the prefix added by DataParallel/DistributedDataParallel
+            k = k[7:]
+        if skip_unmatched_shape and v.shape != model_params[k].shape:
+            continue
+        model_params[k] = v
+    model.load_state_dict(model_params, strict=strict)
diff --git a/utils/ops/__init__.py b/utils/ops/__init__.py
new file mode 100644
index 0000000..491a5a6
--- /dev/null
+++ b/utils/ops/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020/12/19
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+
+from .array_ops import *
+from .tensor_ops import *
diff --git a/utils/ops/array_ops.py b/utils/ops/array_ops.py
new file mode 100644
index 0000000..1ffc1db
--- /dev/null
+++ b/utils/ops/array_ops.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+import os
+
+import cv2
+import numpy as np
+
+
+def minmax(data_array: np.ndarray, up_bound: float = None) -> np.ndarray:
+    """
+    ::
+
+        data_array = (data_array / up_bound)
+        if min_value != max_value:
+            data_array = (data_array - min_value) / (max_value - min_value)
+
+    :param data_array:
+    :param up_bound: if it is not None, data_array will be divided by it before the min-max ops.
+    :return:
+    """
+    if up_bound is not None:
+        data_array = data_array / up_bound
+    max_value = data_array.max()
+    min_value = data_array.min()
+    if max_value != min_value:
+        data_array = (data_array - min_value) / (max_value - min_value)
+    return data_array
+
+
+def clip_to_normalize(data_array: np.ndarray, clip_range: tuple = None) -> np.ndarray:
+    if clip_range is None:  # no clipping requested, only min-max rescaling
+        return minmax(data_array)
+    clip_range = sorted(clip_range)
+    if len(clip_range) == 3:
+        clip_min, clip_mid, clip_max = clip_range
+        assert 0 <= clip_min < clip_mid < clip_max <= 1, clip_range
+        lower_array = data_array[data_array < clip_mid]
+        higher_array = data_array[data_array > clip_mid]
+        if lower_array.size > 0:
+            lower_array = np.clip(lower_array, a_min=clip_min, a_max=1)
+            max_lower = lower_array.max()
+            lower_array = minmax(lower_array) * max_lower
+            data_array[data_array < clip_mid] = lower_array
+        if higher_array.size > 0:
+            higher_array = np.clip(higher_array, a_min=0, a_max=clip_max)
+            min_lower = higher_array.min()
+            higher_array = minmax(higher_array) * (1 - min_lower) + min_lower
+            data_array[data_array > clip_mid] = higher_array
+    elif len(clip_range) == 2:
+        clip_min, clip_max = clip_range
+        assert 0 <= clip_min < clip_max <= 1, clip_range
+        if clip_min != 0 and clip_max != 1:
+            data_array = np.clip(data_array, a_min=clip_min, a_max=clip_max)
+        data_array = minmax(data_array)
+    else:
+        raise NotImplementedError
+    return data_array
+
+
+def save_array_as_image(data_array: np.ndarray, save_name: str, save_dir: str, to_minmax: bool = False):
+    """
+    save the ndarray as an image
+
+    Args:
+        data_array: np.float32 array whose max value is less than or equal to 1
+        save_name: file name including its suffix
+        save_dir: the dirname of the image path
+        to_minmax: min-max normalize the array before saving
+    """
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    save_path = os.path.join(save_dir, save_name)
+    if data_array.dtype != np.uint8:
+        if data_array.max() > 1:
+            raise Exception("the range of data_array has some errors")
+        data_array = (data_array * 255).astype(np.uint8)
+    if to_minmax:
+        data_array = minmax(data_array, up_bound=255)
+        data_array = (data_array * 255).astype(np.uint8)
+    cv2.imwrite(save_path, data_array)
+
+
+def resize(image_array: np.ndarray, height, width, interpolation=cv2.INTER_LINEAR):
+    h, w = image_array.shape[:2]
+    if h == height and w == width:
+        return image_array
+
+    resized_image_array = cv2.resize(image_array, dsize=(width, height), interpolation=interpolation)
+    return resized_image_array
+
+
+def ms_resize(img, scales, base_h=None, base_w=None, interpolation=cv2.INTER_LINEAR):
+    assert isinstance(scales, (list, tuple))
+    if base_h is None:
+        base_h = img.shape[0]
+    if base_w is None:
+        base_w = img.shape[1]
+    return [resize(img, height=int(base_h * s), width=int(base_w * s), interpolation=interpolation) for s in scales]
diff --git a/utils/ops/tensor_ops.py b/utils/ops/tensor_ops.py
new file mode 100644
index 0000000..18aed62
--- /dev/null
+++ b/utils/ops/tensor_ops.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# @Time : 2020
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+import torch
+import torch.nn.functional as F
+
+
+def rescale_2x(x: torch.Tensor, scale_factor=2):
+    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
+
+
+def resize_to(x: torch.Tensor, tgt_hw: tuple):
+    return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False)
+
+
+def clip_grad(params, mode, clip_cfg: dict):
+    if mode == "norm":
+        if "max_norm" not in
clip_cfg: + raise ValueError("`clip_cfg` must contain `max_norm`.") + torch.nn.utils.clip_grad_norm_( + params, + max_norm=clip_cfg.get("max_norm"), + norm_type=clip_cfg.get("norm_type", 2.0), + ) + elif mode == "value": + if "clip_value" not in clip_cfg: + raise ValueError("`clip_cfg` must contain `clip_value`.") + torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value")) + else: + raise NotImplementedError diff --git a/utils/pipeline/__init__.py b/utils/pipeline/__init__.py new file mode 100644 index 0000000..f6eb4bb --- /dev/null +++ b/utils/pipeline/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/5/31 +# @Author : Lart Pang +# @GitHub : https://github.com/lartpang + +from .optimizer import construct_optimizer +from .scaler import Scaler +from .scheduler import Scheduler diff --git a/utils/pipeline/optimizer.py b/utils/pipeline/optimizer.py new file mode 100644 index 0000000..3fe96fa --- /dev/null +++ b/utils/pipeline/optimizer.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/12/19 +# @Author : Lart Pang +# @GitHub : https://github.com/lartpang +import types + +import torchvision.models +from torch import nn +from torch.optim import SGD, Adam, AdamW + + +def get_optimizer(mode, params, initial_lr, optim_cfg): + if mode == "sgd": + optimizer = SGD( + params=params, + lr=initial_lr, + momentum=optim_cfg["momentum"], + weight_decay=optim_cfg["weight_decay"], + nesterov=optim_cfg.get("nesterov", False), + ) + elif mode == "adamw": + optimizer = AdamW( + params=params, + lr=initial_lr, + betas=optim_cfg.get("betas", (0.9, 0.999)), + weight_decay=optim_cfg.get("weight_decay", 0), + amsgrad=optim_cfg.get("amsgrad", False), + ) + elif mode == "adam": + optimizer = Adam( + params=params, + lr=initial_lr, + betas=optim_cfg.get("betas", (0.9, 0.999)), + weight_decay=optim_cfg.get("weight_decay", 0), + amsgrad=optim_cfg.get("amsgrad", False), + ) + else: + raise NotImplementedError(mode) + return optimizer + + +def group_params(model: nn.Module, group_mode: str, initial_lr: float, optim_cfg: dict): + if group_mode == "yolov5": + """ + norm, weight, bias = [], [], [] # optimizer parameter groups + for k, v in model.named_modules(): + if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter): + bias.append(v.bias) # biases + if isinstance(v, nn.BatchNorm2d): + norm.append(v.weight) # no decay + elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter): + weight.append(v.weight) # apply decay + + if opt.adam: + optimizer = optim.Adam(norm, lr=hyp["lr0"], betas=(hyp["momentum"], 0.999)) # adjust beta1 to momentum + else: + optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True) + + optimizer.add_param_group({"params": weight, "weight_decay": hyp["weight_decay"]}) # add weight with weight_decay + optimizer.add_param_group({"params": bias}) # add bias (biases) + """ + norm, weight, bias = [], [], [] # optimizer parameter groups + for k, v in model.named_modules(): + if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter): + bias.append(v.bias) # conv bias and bn bias + if isinstance(v, nn.BatchNorm2d): + norm.append(v.weight) # bn weight + elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter): + weight.append(v.weight) # conv weight + params = [ + {"params": filter(lambda p: p.requires_grad, bias), "weight_decay": 0.0}, + {"params": filter(lambda p: p.requires_grad, norm), "weight_decay": 0.0}, + {"params": filter(lambda p: p.requires_grad, weight)}, + ] + elif group_mode == "r3": + params = [ + # 
bias parameters are excluded from weight decay; weight decay mainly reduces overfitting
+            # by regularizing (L2) the network parameters (weights and biases) to keep them
+            # smooth, so it is only applied to the non-bias group below.
+            {
+                "params": [
+                    param for name, param in model.named_parameters() if name[-4:] == "bias" and param.requires_grad
+                ],
+                "lr": 2 * initial_lr,
+                "weight_decay": 0,
+            },
+            {
+                "params": [
+                    param for name, param in model.named_parameters() if name[-4:] != "bias" and param.requires_grad
+                ],
+                "lr": initial_lr,
+                "weight_decay": optim_cfg["weight_decay"],
+            },
+        ]
+    elif group_mode == "all":
+        params = model.parameters()
+    elif group_mode == "finetune":
+        if hasattr(model, "module"):
+            model = model.module
+        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
+        params_groups = model.get_grouped_params()
+        params = [
+            {
+                "params": filter(lambda p: p.requires_grad, params_groups["pretrained"]),
+                "lr": optim_cfg.get("diff_factor", 0.1) * initial_lr,
+            },
+            {
+                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
+                "lr": initial_lr,
+            },
+        ]
+    elif group_mode == "finetune2":
+        if hasattr(model, "module"):
+            model = model.module
+        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
+        params_groups = model.get_grouped_params()
+        params = [
+            {
+                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_backbone"]),
+                "lr": 0.1 * initial_lr,
+            },
+            {
+                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_siamese"]),
+                "lr": 0.5 * initial_lr,
+            },
+            {
+                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
+                "lr": initial_lr,
+            },
+        ]
+    else:
+        raise NotImplementedError
+    return params
+
+
+def construct_optimizer(model, initial_lr, mode, group_mode, cfg):
+    params = group_params(model, group_mode=group_mode, initial_lr=initial_lr, optim_cfg=cfg)
+    optimizer = get_optimizer(mode=mode, params=params, initial_lr=initial_lr, optim_cfg=cfg)
+    optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
+    optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)
+    return optimizer
+
+
+def get_lr_groups(self):
+    return [group["lr"] for group in self.param_groups]
+
+
+def get_lr_strings(self):
+    return ",".join([f"{group['lr']:.3e}" for group in self.param_groups])
+
+
+if __name__ == "__main__":
+    model = torchvision.models.vgg11_bn()
+    norm, weight, bias = [], [], []  # optimizer parameter groups
+    for k, v in model.named_modules():
+        if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+            bias.append(v.bias)  # biases
+        if isinstance(v, nn.BatchNorm2d):
+            norm.append(v.weight)  # no decay
+        elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+            weight.append(v.weight)  # apply decay
+
+    optimizer = Adam(norm, lr=0.001, betas=(0.98, 0.999))  # adjust beta1 to momentum
+    # optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)
+
+    optimizer.add_param_group({"params": weight, "weight_decay": 1e-4})  # add weight with weight_decay
+    optimizer.add_param_group({"params": bias})  # add bias (biases)
+
+    print(optimizer)
diff --git a/utils/pipeline/scaler.py b/utils/pipeline/scaler.py
new file mode 100644
index 0000000..9a27cde
--- /dev/null
+++ b/utils/pipeline/scaler.py
@@ -0,0 +1,59 @@
+from functools import partial
+from itertools import chain
+
+from torch.cuda.amp import GradScaler, autocast
+
+from ..
import ops + + +class Scaler: + def __init__( + self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None + ) -> None: + self.optimizer = optimizer + self.set_to_none = set_to_none + self.autocast = autocast(enabled=use_fp16) + self.scaler = GradScaler(enabled=use_fp16) + + if clip_grad: + self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg) + else: + self.grad_clip_ops = None + + def calculate_grad(self, loss): + self.scaler.scale(loss).backward() + if self.grad_clip_ops is not None: + self.scaler.unscale_(self.optimizer) + self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups])) + + def update_grad(self): + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad(set_to_none=self.set_to_none) + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return self.scaler.state_dict() + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 
+ """ + self.scaler.load_state_dict(state_dict) diff --git a/utils/pipeline/scheduler.py b/utils/pipeline/scheduler.py new file mode 100644 index 0000000..663ece1 --- /dev/null +++ b/utils/pipeline/scheduler.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/12/19 +# @Author : Lart Pang +# @GitHub : https://github.com/lartpang + +import copy +import math +import os.path +import warnings +from bisect import bisect_right + +import matplotlib +import numpy as np +import torch.optim +from adjustText import adjust_text + +matplotlib.use("Agg") +from matplotlib import pyplot as plt + +# helper function ---------------------------------------------------------------------- + + +def linear_increase(low_bound, up_bound, percentage): + """low_bound + [0, 1] * (up_bound - low_bound)""" + assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]" + return low_bound + (up_bound - low_bound) * percentage + + +def cos_anneal(low_bound, up_bound, percentage): + assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]" + cos_percentage = (1 + math.cos(math.pi * percentage)) / 2.0 + return linear_increase(low_bound, up_bound, percentage=cos_percentage) + + +def poly_anneal(low_bound, up_bound, percentage, lr_decay): + assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]" + poly_percentage = pow((1 - percentage), lr_decay) + return linear_increase(low_bound, up_bound, percentage=poly_percentage) + + +def linear_anneal(low_bound, up_bound, percentage): + assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]" + return linear_increase(low_bound, up_bound, percentage=1 - percentage) + + +# coefficient function ---------------------------------------------------------------------- + + +def get_f3_coef_func(num_iters): + """ + F3Net + + :param num_iters: The number of iterations for the total process. + :return: + """ + + def get_f3_coef(curr_idx): + assert 0 <= curr_idx <= num_iters + return 1 - abs((curr_idx + 1) / (num_iters + 1) * 2 - 1) + + return get_f3_coef + + +def get_step_coef_func(gamma, milestones): + """ + lr = baselr * gamma ** 0 if curr_idx < milestones[0] + lr = baselr * gamma ** 1 if milestones[0] <= epoch < milestones[1] + ... + + :param gamma: + :param milestones: + :return: The function for generating the coefficient. + """ + if isinstance(milestones, (tuple, list)): + milestones = list(sorted(milestones)) + return lambda curr_idx: gamma ** bisect_right(milestones, curr_idx) + elif isinstance(milestones, int): + return lambda curr_idx: gamma ** ((curr_idx + 1) // milestones) + else: + raise ValueError(f"milestones only can be list/tuple/int, but now it is {type(milestones)}") + + +def get_cos_coef_func(half_cycle, min_coef, max_coef=1): + """ + :param half_cycle: The number of iterations in a half cycle. + :param min_coef: The minimum coefficient of the learning rate. + :param max_coef: The maximum coefficient of the learning rate. + :return: The function for generating the coefficient. + """ + + def get_cos_coef(curr_idx): + recomputed_idx = curr_idx % (half_cycle + 1) + # recomputed \in [0, half_cycle] + return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle) + + return get_cos_coef + + +def get_fatcos_coef_func(start_iter, half_cycle, min_coef, max_coef=1): + """ + :param half_cycle: The number of iterations in a half cycle. + :param min_coef: The minimum coefficient of the learning rate. + :param max_coef: The maximum coefficient of the learning rate. 
+ :return: The function for generating the coefficient. + """ + + def get_cos_coef(curr_idx): + curr_idx = max(0, curr_idx - start_iter) + recomputed_idx = curr_idx % (half_cycle + 1) + # recomputed \in [0, half_cycle] + return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle) + + return get_cos_coef + + +def get_poly_coef_func(num_iters, lr_decay, min_coef, max_coef=1): + """ + :param num_iters: The number of iterations for the polynomial descent process. + :param lr_decay: The decay item of the polynomial descent process. + :param min_coef: The minimum coefficient of the learning rate. + :param max_coef: The maximum coefficient of the learning rate. + :return: The function for generating the coefficient. + """ + + def get_poly_coef(curr_idx): + assert 0 <= curr_idx <= num_iters, (curr_idx, num_iters) + return poly_anneal(low_bound=min_coef, up_bound=max_coef, percentage=curr_idx / num_iters, lr_decay=lr_decay) + + return get_poly_coef + + +# coefficient entry function ---------------------------------------------------------------------- + + +def get_scheduler_coef_func(mode, num_iters, cfg): + """ + the region is a closed interval: [0, num_iters] + """ + assert num_iters > 0 + min_coef = cfg.get("min_coef", 1e-6) + if min_coef is None or min_coef == 0: + warnings.warn(f"The min_coef ({min_coef}) of the scheduler will be replaced with 1e-6") + min_coef = 1e-6 + + if mode == "step": + coef_func = get_step_coef_func(gamma=cfg["gamma"], milestones=cfg["milestones"]) + elif mode == "cos": + if half_cycle := cfg.get("half_cycle"): + half_cycle -= 1 + else: + half_cycle = num_iters + if (num_iters - half_cycle) % (half_cycle + 1) != 0: + # idx starts from 0 + percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100 + warnings.warn( + f"The final annealing process ({percentage:.3f}%) is not complete. " + f"Please pay attention to the generated 'lr_coef_curve.png'." + ) + coef_func = get_cos_coef_func(half_cycle=half_cycle, min_coef=min_coef) + elif mode == "fatcos": + assert 0 <= cfg.start_percent < 1, cfg.start_percent + start_iter = int(cfg.start_percent * num_iters) + + num_iters -= start_iter + if half_cycle := cfg.get("half_cycle"): + half_cycle -= 1 + else: + half_cycle = num_iters + if (num_iters - half_cycle) % (half_cycle + 1) != 0: + # idx starts from 0 + percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100 + warnings.warn( + f"The final annealing process ({percentage:.3f}%) is not complete. " + f"Please pay attention to the generated 'lr_coef_curve.png'." 
+ ) + coef_func = get_fatcos_coef_func(start_iter=start_iter, half_cycle=half_cycle, min_coef=min_coef) + elif mode == "poly": + coef_func = get_poly_coef_func(num_iters=num_iters, lr_decay=cfg["lr_decay"], min_coef=min_coef) + elif mode == "constant": + coef_func = lambda x: cfg.get("coef", 1) + elif mode == "f3": + coef_func = get_f3_coef_func(num_iters=num_iters) + else: + raise NotImplementedError(f"{mode} must be in {Scheduler.supported_scheduler_modes}") + return coef_func + + +def get_warmup_coef_func(num_iters, min_coef, max_coef=1, mode="linear"): + """ + the region is a closed interval: [0, num_iters] + """ + assert num_iters > 0 + if mode == "cos": + anneal_func = cos_anneal + elif mode == "linear": + anneal_func = linear_anneal + else: + raise NotImplementedError(f"{mode} must be in {Scheduler.supported_warmup_modes}") + + def get_warmup_coef(curr_idx): + return anneal_func(low_bound=min_coef, up_bound=max_coef, percentage=1 - curr_idx / num_iters) + + return get_warmup_coef + + +# main class ---------------------------------------------------------------------- + + +class Scheduler: + supported_scheduler_modes = ("step", "cos", "fatcos", "poly", "constant", "f3") + supported_warmup_modes = ("cos", "linear") + + def __init__(self, optimizer, num_iters, epoch_length, scheduler_cfg, step_by_batch=True): + """A customized wrapper of the scheduler. + + Args: + optimizer (): Optimizer. + num_iters (int): The total number of the iterations. + epoch_length (int): The number of the iterations of one epoch. + scheduler_cfg (dict): The config of the scheduler. + step_by_batch (bool, optional): The mode of updating the scheduler. Defaults to True. + + Raises: + NotImplementedError: + """ + self.optimizer = optimizer + self.num_iters = num_iters + self.epoch_length = epoch_length + self.step_by_batch = step_by_batch + + self.scheduler_cfg = copy.deepcopy(scheduler_cfg) + self.mode = scheduler_cfg["mode"] + if self.mode not in self.supported_scheduler_modes: + raise NotImplementedError( + f"{self.mode} is not implemented. 
Has been supported: {self.supported_scheduler_modes}" + ) + warmup_cfg = scheduler_cfg.get("warmup", None) + + num_warmup_iters = 0 + if warmup_cfg is not None and isinstance(warmup_cfg, dict): + num_warmup_iters = warmup_cfg["num_iters"] + if num_warmup_iters > 0: + print("Will using warmup") + self.warmup_coef_func = get_warmup_coef_func( + num_warmup_iters, + min_coef=warmup_cfg.get("initial_coef", 0.01), + mode=warmup_cfg.get("mode", "linear"), + ) + self.num_warmup_iters = num_warmup_iters + + if step_by_batch: + num_scheduler_iters = num_iters - num_warmup_iters + else: + num_scheduler_iters = (num_iters - num_warmup_iters) // epoch_length + # the region is a closed interval + self.lr_coef_func = get_scheduler_coef_func( + mode=self.mode, num_iters=num_scheduler_iters - 1, cfg=scheduler_cfg["cfg"] + ) + self.num_scheduler_iters = num_scheduler_iters + + self.last_lr_coef = 0 + self.initial_lrs = None + + def __repr__(self): + formatted_string = [ + f"{self.__class__.__name__}: (\n", + f"num_iters: {self.num_iters}\n", + f"epoch_length: {self.epoch_length}\n", + f"warmup_iter: [0, {self.num_warmup_iters})\n", + f"scheduler_iter: [{self.num_warmup_iters}, {self.num_iters - 1}]\n", + f"mode: {self.mode}\n", + f"scheduler_cfg: {self.scheduler_cfg}\n", + f"initial_lrs: {self.initial_lrs}\n", + f"step_by_batch: {self.step_by_batch}\n)", + ] + return " ".join(formatted_string) + + def record_lrs(self, param_groups): + self.initial_lrs = [g["lr"] for g in param_groups] + + def update(self, coef: float): + assert self.initial_lrs is not None, "Please run .record_lrs(optimizer) first." + for curr_group, initial_lr in zip(self.optimizer.param_groups, self.initial_lrs): + curr_group["lr"] = coef * initial_lr + + def step(self, curr_idx): + if curr_idx < self.num_warmup_iters: + # get maximum value (1.0) when curr_idx == self.num_warmup_iters + self.update(coef=self.get_lr_coef(curr_idx)) + else: + # Start from a value lower than 1 (curr_idx == self.num_warmup_iters) + if self.step_by_batch: + self.update(coef=self.get_lr_coef(curr_idx)) + else: + if curr_idx % self.epoch_length == 0: + self.update(coef=self.get_lr_coef(curr_idx)) + + def get_lr_coef(self, curr_idx): + coef = None + if curr_idx < self.num_warmup_iters: + coef = self.warmup_coef_func(curr_idx) + else: + # when curr_idx == self.num_warmup_iters, coef == 1.0 + # down from the largest coef (1.0) + if self.step_by_batch: + coef = self.lr_coef_func(curr_idx - self.num_warmup_iters) + else: + if curr_idx % self.epoch_length == 0 or curr_idx == self.num_warmup_iters: + # warmup结束后尚未开始按照epoch进行调整的学习率调整,此时需要将系数调整为最大。 + coef = self.lr_coef_func((curr_idx - self.num_warmup_iters) // self.epoch_length) + if coef is not None: + self.last_lr_coef = coef + return self.last_lr_coef + + def plot_lr_coef_curve(self, save_path=""): + plt.rc("xtick", labelsize="small") + plt.rc("ytick", labelsize="small") + + fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 4), dpi=600) + # give plot a title + ax.set_title("Learning Rate Coefficient Curve") + # make axis labels + ax.set_xlabel("Iteration") + ax.set_ylabel("Coefficient") + + x_data = np.arange(self.num_iters) + y_data = np.array([self.get_lr_coef(x) for x in x_data]) + + # set lim + x_min, x_max = 0, self.num_iters - 1 + dx = self.num_iters * 0.1 + ax.set_xlim(x_min - dx, x_max + 2 * dx) + + y_min, y_max = y_data.min(), y_data.max() + dy = (y_data.max() - y_data.min()) * 0.1 + ax.set_ylim((y_min - dy, y_max + dy)) + + if self.step_by_batch: + marker_on = [0, -1] + key_point_xs = [0, self.num_iters 
- 1]
+            for idx in range(1, len(y_data) - 1):
+                prev_y = y_data[idx - 1]
+                curr_y = y_data[idx]
+                next_y = y_data[idx + 1]
+                if (
+                    (curr_y > prev_y and curr_y >= next_y)
+                    or (curr_y >= prev_y and curr_y > next_y)
+                    or (curr_y <= prev_y and curr_y < next_y)
+                    or (curr_y < prev_y and curr_y <= next_y)
+                ):
+                    marker_on.append(idx)
+                    key_point_xs.append(idx)
+
+            marker_on = sorted(set(marker_on))
+            key_point_xs = sorted(set(key_point_xs))
+            key_point_ys = []
+
+            texts = []
+            for x in key_point_xs:
+                y = y_data[x]
+                key_point_ys.append(y)
+
+                texts.append(ax.text(x=x, y=y, s=f"({x:d},{y:.3e})"))
+            adjust_text(texts, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"))
+
+            # set ticks
+            ax.set_xticks(key_point_xs)
+            # ax.set_yticks(key_point_ys)
+
+            ax.plot(x_data, y_data, marker="o", markevery=marker_on)
+        else:
+            ax.plot(x_data, y_data)
+
+        ax.spines["right"].set_visible(False)
+        ax.spines["top"].set_visible(False)
+        ax.spines["left"].set_visible(True)
+        ax.spines["bottom"].set_visible(True)
+
+        plt.tight_layout()
+        if save_path:
+            fig.savefig(os.path.join(save_path, "lr_coef.png"))
+        plt.close()
+
+
+if __name__ == "__main__":
+    model = torch.nn.Conv2d(10, 10, 3, 1, 1)
+    sche = Scheduler(
+        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
+        num_iters=30300,
+        epoch_length=505,
+        scheduler_cfg=dict(
+            warmup=dict(
+                num_iters=6060,
+                initial_coef=0.01,
+                mode="cos",
+            ),
+            mode="cos",
+            cfg=dict(
+                half_cycle=6060,
+                lr_decay=0.9,
+                min_coef=0.001,
+            ),
+        ),
+        step_by_batch=True,
+    )
+    print(sche)
+    sche.plot_lr_coef_curve(
+        # save_path="/home/lart/Coding/SOD.torch",
+    )
diff --git a/utils/pt_utils.py b/utils/pt_utils.py
new file mode 100644
index 0000000..65a2d0b
--- /dev/null
+++ b/utils/pt_utils.py
@@ -0,0 +1,66 @@
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from torch import nn
+from torch.backends import cudnn
+
+LOGGER = logging.getLogger("main")
+
+
+def customized_worker_init_fn(worker_id):
+    worker_seed = torch.initial_seed() % 2**32
+    np.random.seed(worker_seed)
+
+
+def set_seed_for_lib(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    # disable hash randomization to make the experiments reproducible
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)  # set the seed for the CPU
+    torch.cuda.manual_seed(seed)  # set the seed for the current GPU
+    torch.cuda.manual_seed_all(seed)  # set the seed for all GPUs
+
+
+def initialize_seed_cudnn(seed, deterministic):
+    assert isinstance(deterministic, bool) and isinstance(seed, int)
+    if seed >= 0:
+        LOGGER.info(f"We will use a fixed seed {seed}")
+    else:
+        seed = np.random.randint(2**32)
+        LOGGER.info(f"We will use a random seed {seed}")
+    set_seed_for_lib(seed)
+    if not deterministic:
+        LOGGER.info("We will use `torch.backends.cudnn.benchmark`")
+    else:
+        LOGGER.info("We will not use `torch.backends.cudnn.benchmark`")
+    cudnn.enabled = True
+    cudnn.benchmark = not deterministic
+    cudnn.deterministic = deterministic
+
+
+def to_device(data, device="cuda"):
+    if isinstance(data, (tuple, list)):
+        return [to_device(item, device) for item in data]
+    elif isinstance(data, dict):
+        return {name: to_device(item, device) for name, item in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return data.to(device=device, non_blocking=True)
+    else:
+        raise TypeError(f"Unsupported type {type(data)}. Only support Tensor or tuple/list/dict containing Tensors.")
+
+
+def frozen_bn_stats(model, freeze_affine=False):
+    """
+    Set all the bn layers to eval mode.
+    Args:
+        model (nn.Module): model whose bn layers will be set to eval mode.
+    """
+    for m in model.modules():
+        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+            m.eval()
+            if freeze_affine:
+                m.requires_grad_(False)
diff --git a/utils/py_utils.py b/utils/py_utils.py
new file mode 100644
index 0000000..4844a1d
--- /dev/null
+++ b/utils/py_utils.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+import copy
+import logging
+import os
+import shutil
+from collections import OrderedDict, abc
+from datetime import datetime
+
+LOGGER = logging.getLogger("main")
+
+
+def construct_path(output_dir: str, exp_name: str) -> dict:
+    proj_root = os.path.join(output_dir, exp_name)
+    exp_idx = 0
+    exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
+    while os.path.exists(exp_output_dir):
+        exp_idx += 1
+        exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
+
+    tb_path = os.path.join(exp_output_dir, "tb")
+    save_path = os.path.join(exp_output_dir, "pre")
+    pth_path = os.path.join(exp_output_dir, "pth")
+
+    final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth")
+    final_state_path = os.path.join(pth_path, "state_final.pth")
+
+    log_path = os.path.join(exp_output_dir, f"log_{str(datetime.now())[:10]}.txt")
+    cfg_copy_path = os.path.join(exp_output_dir, f"config.py")
+    trainer_copy_path = os.path.join(exp_output_dir, f"trainer.txt")
+    excel_path = os.path.join(exp_output_dir, f"results.xlsx")
+
+    path_config = {
+        "output_dir": output_dir,
+        "pth_log": exp_output_dir,
+        "tb": tb_path,
+        "save": save_path,
+        "pth": pth_path,
+        "final_full_net": final_full_model_path,
+        "final_state_net": final_state_path,
+        "log": log_path,
+        "cfg_copy": cfg_copy_path,
+        "excel": excel_path,
+        "trainer_copy": trainer_copy_path,
+    }
+
+    return path_config
+
+
+def construct_exp_name(model_name: str, cfg: dict):
+    # bs_16_lr_0.05_e30_noamp_2gpu_noms_352
+    focus_item = OrderedDict(
+        {
+            "train/batch_size": "bs",
+            "train/lr": "lr",
+            "train/num_epochs": "e",
+            "train/num_iters": "i",
+            "train/data/shape/h": "h",
+            "train/data/shape/w": "w",
+            "train/optimizer/mode": "opm",
+            "train/optimizer/group_mode": "opgm",
+            "train/scheduler/mode": "sc",
+            "train/scheduler/warmup/num_iters": "wu",
+            "train/use_amp": "amp",
+        }
+    )
+    config = copy.deepcopy(cfg)
+
+    def _format_item(_i):
+        if isinstance(_i, bool):
+            _i = "" if _i else "false"
+        elif isinstance(_i, (int, float)):
+            if _i == 0:
+                _i = "false"
+        elif isinstance(_i, (list, tuple)):
+            _i = "" if _i else "false"  # only check whether it is non-empty
+        elif isinstance(_i, str):
+            if "_" in _i:
+                _i = _i.replace("_", "").lower()
+        elif _i is None:
+            _i = "none"
+        # else: other types and values will be returned directly
+        return _i
+
+    if (epoch_based := config.train.get("epoch_based", None)) is not None and (not epoch_based):
+        focus_item.pop("train/num_epochs")
+    else:
+        # epoch-based training by default
+        focus_item.pop("train/num_iters")
+
+    exp_names = [model_name]
+    for key, alias in focus_item.items():
+        item = get_value_recurse(keys=key.split("/"), info=config)
+        formatted_item = _format_item(item)
+        if formatted_item == "false":
+            continue
+        exp_names.append(f"{alias.upper()}{formatted_item}")
+
+    info = config.get("info", None)
+    if info:
+        exp_names.append(f"INFO{info.lower()}")
+
+    return "_".join(exp_names)
+
+
+def pre_mkdir(path_config):
+    # create the log file in advance to avoid triggering file-creation events when it would be created automatically
+    check_mkdir(path_config["pth_log"])
+    make_log(path_config["log"], f"=== log {datetime.now()} ===")
+
+    # create the folders for predictions and model weights in advance
+    check_mkdir(path_config["save"])
+    check_mkdir(path_config["pth"])
+
+
+def check_mkdir(dir_name, delete_if_exists=False):
+    if not os.path.exists(dir_name):
+        os.makedirs(dir_name)
+    else:
+        if delete_if_exists:
+            print(f"{dir_name} will be re-created!!!")
+            shutil.rmtree(dir_name)
+            os.makedirs(dir_name)
+
+
+def make_log(path, context):
+    with open(path, "a") as log:
+        log.write(f"{context}\n")
+
+
+def iterate_nested_sequence(nested_sequence):
+    """
+    Flatten arbitrarily nested list/tuple/int/float/range() items; do not nest too deeply, or Python's default recursion limit may be exceeded.
+
+    Example
+    ::
+
+        for x in iterate_nested_sequence([[1, (2, 3)], range(3, 10), 0]):
+            print(x)
+
+        1
+        2
+        3
+        3
+        4
+        5
+        6
+        7
+        8
+        9
+        0
+
+    :param nested_sequence: a nested sequence
+    :return: generator
+    """
+    for item in nested_sequence:
+        if isinstance(item, (int, float)):
+            yield item
+        elif isinstance(item, (list, tuple, range)):
+            yield from iterate_nested_sequence(item)
+        else:
+            raise NotImplementedError
+
+
+def get_value_recurse(keys: list, info: dict):
+    curr_key, sub_keys = keys[0], keys[1:]
+
+    if (sub_info := info.get(curr_key, "NoKey")) == "NoKey":
+        raise KeyError(f"{curr_key} must be contained in {info}")
+
+    if sub_keys:
+        return get_value_recurse(keys=sub_keys, info=sub_info)
+    else:
+        return sub_info
+
+
+def mapping_to_str(mapping: abc.Mapping, *, prefix: str = " ", lvl: int = 0, max_lvl: int = 1) -> str:
+    """
+    Format the structural information of the dict as a string.
+    """
+    sub_lvl = lvl + 1
+    cur_prefix = prefix * lvl
+    sub_prefix = prefix * sub_lvl
+
+    if lvl == max_lvl:
+        sub_items = str(mapping)
+    else:
+        sub_items = ["{"]
+        for k, v in mapping.items():
+            sub_item = sub_prefix + k + ": "
+            if isinstance(v, abc.Mapping):
+                sub_item += mapping_to_str(v, prefix=prefix, lvl=sub_lvl, max_lvl=max_lvl)
+            else:
+                sub_item += str(v)
+            sub_items.append(sub_item)
+        sub_items.append(cur_prefix + "}")
+        sub_items = "\n".join(sub_items)
+    return sub_items
diff --git a/utils/recorder/__init__.py b/utils/recorder/__init__.py
new file mode 100644
index 0000000..f40f9be
--- /dev/null
+++ b/utils/recorder/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .counter import TrainingCounter
+from .group_metric_caller import GroupedMetricRecorder
+from .logger import TBLogger
+from .meter_recorder import AvgMeter, HistoryBuffer
+from .visualize_results import plot_results
diff --git a/utils/recorder/counter.py b/utils/recorder/counter.py
new file mode 100644
index 0000000..b139573
--- /dev/null
+++ b/utils/recorder/counter.py
@@ -0,0 +1,75 @@
+import math
+
+
+class TrainingCounter:
+    def __init__(self, epoch_length, epoch_based=True, *, num_epochs=None, num_total_iters=None) -> None:
+        self.num_inner_iters = epoch_length
+        self._iter_counter = 0
+        self._epoch_counter = 0
+
+        if epoch_based:
+            assert num_epochs is not None
+            self.num_epochs = num_epochs
+            self.num_total_iters = num_epochs * epoch_length
+        else:
+            assert num_total_iters is not None
+            self.num_total_iters = num_total_iters
+            self.num_epochs = math.ceil(num_total_iters / epoch_length)
+
+    def set_start_epoch(self, start_epoch):
+        self._epoch_counter = start_epoch
+        self._iter_counter = start_epoch * self.num_inner_iters
+
+    def set_start_iterations(self, start_iteration):
+        self._iter_counter = start_iteration
+        self._epoch_counter = start_iteration // self.num_inner_iters
+
+    def every_n_epochs(self, n: int) -> bool:
+        return (self._epoch_counter + 1) % n == 0 if n > 0 else False
+
+    def every_n_iters(self, n: int) -> bool:
+        return (self._iter_counter + 1) % n == 0 if n > 0 else False
+
+    def is_first_epoch(self) -> bool:
+        return self._epoch_counter == 0
+
+    def is_last_epoch(self) -> bool:
+        return self._epoch_counter == self.num_epochs - 1
+
+    def is_first_inner_iter(self) -> bool:
+        return self._iter_counter % self.num_inner_iters == 0
+
+    def is_last_inner_iter(self) -> bool:
+        return (self._iter_counter + 1) % self.num_inner_iters == 0
+
+    def is_first_total_iter(self) -> bool:
+        return self._iter_counter == 0
+
+    def is_last_total_iter(self) -> bool:
+        return self._iter_counter == self.num_total_iters - 1
+
+    def update_iter_counter(self):
+        self._iter_counter += 1
+
+    def update_epoch_counter(self):
+        self._epoch_counter += 1
+
+    def reset_iter_all_counter(self):
+        self._iter_counter = 0
+        self._epoch_counter = 0
+
+    @property
+    def curr_iter(self):
+        return self._iter_counter
+
+    @property
+    def next_iter(self):
+        return self._iter_counter + 1
+
+    @property
+    def curr_epoch(self):
+        return self._epoch_counter
+
+    @property
+    def curr_percent(self):
+        return self._iter_counter / self.num_total_iters
diff --git a/utils/recorder/group_metric_caller.py b/utils/recorder/group_metric_caller.py
new file mode 100644
index 0000000..62e9ec1
--- /dev/null
+++ b/utils/recorder/group_metric_caller.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# @Time : 2021/1/4
+# @Author : Lart Pang
+# @GitHub : https://github.com/lartpang
+
+from collections import OrderedDict
+
+import numpy as np
+import py_sod_metrics
+
+
+def ndarray_to_basetype(data):
+    """
+    Convert a single ndarray, or the ndarrays contained in a tuple, list or dict,
+    into basic Python types, i.e. lists (via .tolist()) and Python scalars.
+    """
+
+    def _to_list_or_scalar(item):
+        listed_item = item.tolist()
+        if isinstance(listed_item, list) and len(listed_item) == 1:
+            listed_item = listed_item[0]
+        return listed_item
+
+    if isinstance(data, (tuple, list)):
+        results = [_to_list_or_scalar(item) for item in data]
+    elif isinstance(data, dict):
+        results = {k: _to_list_or_scalar(item) for k, item in data.items()}
+    else:
+        assert isinstance(data, np.ndarray)
+        results = _to_list_or_scalar(data)
+    return results
+
+
+def round_w_zero_padding(x, bit_width):
+    x = str(round(x, bit_width))
+    x += "0" * (bit_width - len(x.split(".")[-1]))
+    return x
+
+
+INDIVADUAL_METRIC_MAPPING = {
+    "sm": py_sod_metrics.Smeasure,
+    "wfm": py_sod_metrics.WeightedFmeasure,
+    "mae": py_sod_metrics.MAE,
+    "em": py_sod_metrics.Emeasure,
+}
+BINARY_CLASSIFICATION_METRIC_MAPPING = {
+    "fmeasure": {
+        "handler": py_sod_metrics.FmeasureHandler,
+        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False, beta=0.3),
+    },
+    "iou": {
+        "handler": py_sod_metrics.IOUHandler,
+        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False),
+    },
+    "dice": {
+        "handler": py_sod_metrics.DICEHandler,
+        "kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
+    },
+}
+
+
+class ImageMetricRecorder:
+    supported_metrics = sorted(INDIVADUAL_METRIC_MAPPING.keys()) + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())
+
+    def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
+        """
+        A recorder class for computing various metrics.
+        """
+        if not metric_names:
+            metric_names = self.supported_metrics
+        assert all([m in self.supported_metrics for m in metric_names]), f"Only support: {self.supported_metrics}"
+
+        self.metric_objs = {}
+        has_existed = False
+        for metric_name in metric_names:
+            if metric_name in INDIVADUAL_METRIC_MAPPING:
+                self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
+            else:  # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
+                if not has_existed:  # only init once
+                    self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
+                    has_existed = True
+                metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
+
+
+
+class GroupedMetricRecorder(object):
+    supported_metrics = ["mae", "em", "sm", "wfm"] + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())
+
+    def __init__(self, group_names=None, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
+        self.group_names = group_names
+        self.metric_names = metric_names
+        self.zero()
+
+    def zero(self):
+        self.metric_recorders = {}
+        if self.group_names is not None:
+            self.metric_recorders.update(
+                {n: ImageMetricRecorder(metric_names=self.metric_names) for n in self.group_names}
+            )
+
+    def step(self, group_name: str, pre: np.ndarray, gt: np.ndarray, gt_path: str):
+        if group_name not in self.metric_recorders:
+            self.metric_recorders[group_name] = ImageMetricRecorder(metric_names=self.metric_names)
+        self.metric_recorders[group_name].step(pre, gt, gt_path)
+
+    def show(self, num_bits: int = 3, return_group: bool = False):
+        groups_metrics = {
+            n: r.get_all_results(num_bits=None, return_ndarray=True) for n, r in self.metric_recorders.items()
+        }
+
+        results = {}
+        for group_metrics in groups_metrics.values():
+            for metric_type, metric_group in group_metrics.items():  # sequential and numerical
+                results.setdefault(metric_type, {})
+                for metric_name, metric_array in metric_group.items():
+                    results[metric_type].setdefault(metric_name, []).append(metric_array)
+
+        numerical_results = {}
+        for metric_type, metric_group in results.items():
+            for metric_name, metric_array in metric_group.items():
+                metric_array = np.mean(np.vstack(metric_array), axis=0)  # average over all groups
+
+                if metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING or metric_name == "em":
+                    if metric_type == "sequential":
+                        numerical_results[f"max{metric_name}"] = metric_array.max()
+                        numerical_results[f"avg{metric_name}"] = metric_array.mean()
+                else:
+                    if metric_type == "numerical":
+                        if metric_name.startswith(("max", "avg")):
+                            # these metrics (maxfm, avgfm, maxem, avgem) will be recomputed within the group
+                            continue
+                        numerical_results[metric_name] = metric_array
+
+        numerical_results = ndarray_to_basetype(numerical_results)
+        numerical_results = {k: round(v, num_bits) for k, v in numerical_results.items()}
+        numerical_results = self.sort_results(numerical_results)
+        if not return_group:
+            return numerical_results
+
+        group_numerical_results = {}
+        for group_name, group_metric in groups_metrics.items():
+            group_metric = {k: v.round(num_bits) for k, v in group_metric["numerical"].items()}
+            group_metric = ndarray_to_basetype(group_metric)
+            group_numerical_results[group_name] = self.sort_results(group_metric)
+        return numerical_results, group_numerical_results
+
+    def sort_results(self, results: dict) -> OrderedDict:
+        """for a single group of metrics"""
+        sorted_results = OrderedDict()
+        all_keys = sorted(results.keys(), key=lambda item: item[::-1])
+        for name in self.metric_names:
+            for key in all_keys:
+                if key.endswith(name):
+                    sorted_results[key] = results[key]
+        return sorted_results
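A sketch of per-dataset evaluation with `GroupedMetricRecorder`: each group gets its own `ImageMetricRecorder`, and `show(return_group=True)` returns both the group-averaged and the per-group numbers. The group names match datasets used elsewhere in this repository; the evaluation loop `iterate_predictions()` is a hypothetical placeholder::

    from utils.recorder import GroupedMetricRecorder

    grouped = GroupedMetricRecorder(group_names=["CAMO-TE", "NC4K"], metric_names=("sm", "wfm", "mae", "em"))
    for group, pred, gt, gt_path in iterate_predictions():  # hypothetical data source yielding uint8 arrays
        grouped.step(group_name=group, pre=pred, gt=gt, gt_path=gt_path)
    overall, per_group = grouped.show(num_bits=3, return_group=True)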
diff --git a/utils/recorder/logger.py b/utils/recorder/logger.py
new file mode 100644
index 0000000..eb411d0
--- /dev/null
+++ b/utils/recorder/logger.py
@@ -0,0 +1,23 @@
+from torch.utils.tensorboard import SummaryWriter
+
+
+class TBLogger:
+    def __init__(self, tb_root):
+        self.tb_root = tb_root
+        self.tb = None
+
+    def write_to_tb(self, name, data, curr_iter):
+        assert self.tb_root is not None
+
+        if self.tb is None:
+            self.tb = SummaryWriter(self.tb_root)
+
+        if not isinstance(data, (tuple, list)):
+            self.tb.add_scalar(f"data/{name}", data, curr_iter)
+        else:
+            for idx, data_item in enumerate(data):
+                self.tb.add_scalar(f"data/{name}_{idx}", data_item, curr_iter)
+
+    def close_tb(self):
+        if self.tb is not None:
+            self.tb.close()
diff --git a/utils/recorder/meter_recorder.py b/utils/recorder/meter_recorder.py
new file mode 100644
index 0000000..888ad1f
--- /dev/null
+++ b/utils/recorder/meter_recorder.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+from collections import deque
+
+
+class AvgMeter(object):
+    __slots__ = ["value", "sum", "count"]
+
+    def __init__(self):
+        self.value = 0
+        self.sum = 0
+        self.count = 0
+
+    def reset(self):
+        self.value = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, value, num=1):
+        self.value = value
+        self.sum += value * num
+        self.count += num
+
+    @property
+    def avg(self):
+        return self.sum / self.count
+
+    def __repr__(self) -> str:
+        return f"{self.avg:.5f}"
+
+
+class HistoryBuffer:
+    """The class tracks a series of values and provides access to the smoothed
+    value over a window or the global average / sum of the sequence.
+
+    Args:
+        window_size (int): The maximal number of values that can
+            be stored in the buffer. Defaults to 20.
+
+    Example::
+
+        >>> his_buf = HistoryBuffer()
+        >>> his_buf.update(0.1)
+        >>> his_buf.update(0.2)
+        >>> his_buf.avg
+        0.15
+    """
+
+    def __init__(self, window_size: int = 20) -> None:
+        self._history = deque(maxlen=window_size)
+        self._count: int = 0
+        self._sum: float = 0
+        self.reset()
+
+    def reset(self):
+        self._history.clear()
+        self._count = 0
+        self._sum = 0
+
+    def update(self, value: float, num: int = 1) -> None:
+        """Add a new scalar value. If the length of queue exceeds ``window_size``,
+        the oldest element will be removed from the queue.
+        """
+        self._history.append(value)
+        self._count += num
+        self._sum += value * num
+
+    @property
+    def latest(self) -> float:
+        """The latest value of the queue."""
+        return self._history[-1]
+
+    @property
+    def avg(self) -> float:
+        """The average over the window."""
+        if len(self._history) == 0:
+            return 0
+        else:
+            return sum(self._history) / len(self._history)
+
+    @property
+    def global_avg(self) -> float:
+        """The global average of the queue."""
+        if self._count == 0:
+            return 0
+        else:
+            return self._sum / self._count
+
+    @property
+    def global_sum(self) -> float:
+        """The global sum of the queue."""
+        return self._sum
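The meter and the TensorBoard logger are typically combined along these lines; the log directory and the loss values below are made up for the sketch::

    from utils.recorder import HistoryBuffer, TBLogger

    tb_logger = TBLogger(tb_root="./output/tb")  # hypothetical log directory
    loss_meter = HistoryBuffer(window_size=20)
    for curr_iter, loss in enumerate([0.9, 0.7, 0.6]):  # dummy loss values
        loss_meter.update(loss)
        tb_logger.write_to_tb("loss", loss_meter.avg, curr_iter)  # windowed average
    print(loss_meter.latest, loss_meter.global_avg)
    tb_logger.close_tb()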
diff --git a/utils/recorder/visualize_results.py b/utils/recorder/visualize_results.py
new file mode 100644
index 0000000..167e6e5
--- /dev/null
+++ b/utils/recorder/visualize_results.py
@@ -0,0 +1,43 @@
+import os
+
+import cv2
+import matplotlib
+
+matplotlib.use("Agg")
+import numpy as np
+import torchvision.transforms.functional as tv_tf
+from torchvision.utils import make_grid
+
+
+def plot_results(data_container, save_path, base_size=256, is_rgb=True):
+    """Plot the results corresponding to the batched images based on the `make_grid` method from `torchvision`.
+
+    Args:
+        data_container (dict): Dict containing data you want to plot.
+        save_path (str): Path of the exported image.
+        base_size (int): Height each row is resized to.
+        is_rgb (bool): Whether the data is RGB and needs converting to BGR before saving with cv2.
+    """
+    font_cfg = dict(fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, thickness=2)
+
+    grids = []
+    for subplot_id, (name, data) in enumerate(data_container.items()):
+        if data.ndim == 3:
+            data = data.unsqueeze(1)
+
+        grid = make_grid(data, nrow=data.shape[0], padding=2, normalize=False)
+        grid = np.array(tv_tf.to_pil_image(grid.float()))
+        h, w = grid.shape[:2]
+        ratio = base_size / h
+        grid = cv2.resize(grid, dsize=None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
+
+        (text_w, text_h), baseline = cv2.getTextSize(text=name, **font_cfg)
+        text_xy = 20, 20 + text_h // 2 + baseline
+        cv2.putText(grid, text=name, org=text_xy, color=(255, 255, 255), **font_cfg)
+
+        grids.append(grid)
+    grids = np.concatenate(grids, axis=0)  # H,W,C
+
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    if is_rgb:
+        grids = cv2.cvtColor(grids, cv2.COLOR_RGB2BGR)
+    cv2.imwrite(save_path, grids)
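A minimal sketch of calling `plot_results` on a batch of tensors; the tensor shapes, key names, and the output path are illustrative assumptions, and values are expected in [0, 1]::

    import torch

    from utils.recorder import plot_results

    batch = {
        "image": torch.rand(4, 3, 96, 96),  # B,C,H,W, dummy images
        "mask": torch.rand(4, 96, 96),      # B,H,W, expanded to B,1,H,W inside plot_results
    }
    plot_results(batch, save_path="./output/vis/demo.png", base_size=192, is_rgb=True)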