diff --git a/.circleci/config.yml b/.circleci/config.yml
index ac6568d996..e7650e63ea 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -26,6 +26,7 @@ workflows:
tools/.* lint_only false
configs/.* lint_only false
.circleci/.* lint_only false
+ .dev_scripts/.* lint_only true
base-revision: 1.x
# this is the path of the configuration we should trigger once
# path filtering and pipeline parameter value updates are
diff --git a/.circleci/test.yml b/.circleci/test.yml
index d1a3489802..b6974c6456 100644
--- a/.circleci/test.yml
+++ b/.circleci/test.yml
@@ -62,6 +62,7 @@ jobs:
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
+ mim install 'mmdet >= 3.0.0rc2'
pip install -r requirements/tests.txt
- run:
name: Build and install
@@ -98,13 +99,14 @@ jobs:
name: Build Docker image
command: |
docker build .circleci/docker -t mmedit:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
- docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -w /mmedit --name mmedit mmedit:gpu
+ docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -w /mmedit --name mmedit mmedit:gpu
- run:
name: Install mmedit dependencies
command: |
docker exec mmedit pip install -e /mmengine
docker exec mmedit pip install -U openmim
docker exec mmedit mim install 'mmcv >= 2.0.0rc1'
+ docker exec mmedit mim install 'mmdet >= 3.0.0rc2'
docker exec mmedit pip install -r requirements/tests.txt
- run:
name: Build and install
diff --git a/.dev_scripts/README.md b/.dev_scripts/README.md
index 7ba0f9fe12..deb066b41f 100644
--- a/.dev_scripts/README.md
+++ b/.dev_scripts/README.md
@@ -9,6 +9,9 @@
- [4. Monitor your training](#4-monitor-your-training)
- [5. Train with a list of models](#5-train-with-a-list-of-models)
- [6. Train with skipping a list of models](#6-train-with-skipping-a-list-of-models)
+- [7. Train failed or canceled jobs](#7-train-failed-or-canceled-jobs)
+- [8. `deterministic` training](#8-deterministic-training)
+- [9. Automatically check links](#9-automatically-check-links)
## 1. Check UT
@@ -128,7 +129,7 @@ python .dev_scripts/train_benchmark.py mm_lol \
--quotatype=auto
```
-# 4. Monitor your training
+## 4. Monitor your training
After you submitting jobs following [3-Train-all-the-models](#3-train-all-the-models), you will find a `xxx.log` file.
This log file list all the job name of job id you have submitted. With this log file, you can monitor your training by running `.dev_scripts/job_watcher.py`.
@@ -141,7 +142,7 @@ python .dev_scripts/job_watcher.py --work-dir work_dirs/benchmark_fp32/ --log 20
Then, you will find `20220923-140317.csv`, which reports the status and recent log of each job.
-# 5. Train with a list of models
+## 5. Train with a list of models
If you only need to run some of the models, you can list all the models' name in a file, and specify the models when using `train_benchmark.py`.
@@ -162,7 +163,7 @@ python .dev_scripts/train_benchmark.py mm_lol \
Specifically, you need to enable `--rerun`, and specify the list of models to rerun by `--rerun-list`
-# 6. Train with skipping a list of models
+## 6. Train with skipping a list of models
If you want to train all the models while skipping some models, you can also list all the models' name in a file, and specify the models when running `train_benchmark.py`.
@@ -182,3 +183,81 @@
```
Specifically, you need to enable `--skip`, and specify the list of models to skip by `--skip-list`
+
+## 7. Train failed or canceled jobs
+
+If you want to rerun failed or canceled jobs from the last run, you can combine the `--rerun` flag with the `--rerun-fail` and `--rerun-cancel` flags.
+
+For example, if the log file of the last run is `train-20221009-211904.log` and you want to rerun the failed jobs, you can use the following command:
+
+```bash
+python .dev_scripts/train_benchmark.py mm_lol \
+ --job-name RERUN \
+ --rerun train-20221009-211904.log \
+ --rerun-fail \
+ --run
+```
+
+You can combine `--rerun-fail` and `--rerun-cancel` with the `--models` flag to rerun a **subset** of the failed or canceled models.
+
+```bash
+python .dev_scripts/train_benchmark.py mm_lol \
+ --job-name RERUN \
+ --rerun train-20221009-211904.log \
+ --rerun-fail \
+    --models sagan \
+    --run  # only rerun 'sagan' models among the failed tasks
+```
+
+Specifically, `--rerun-fail` and `--rerun-cancel` can be used together to rerun both failed and canceled jobs.
+
+## 8. `deterministic` training
+
+Setting `torch.backends.cudnn.deterministic = True` and `torch.backends.cudnn.benchmark = False` removes the randomness introduced by cuDNN operations in PyTorch training. You can add the `--deterministic` flag when starting your benchmark training to remove the influence of this randomness.
+
+```shell
+python .dev_scripts/train_benchmark.py mm_lol --job-name xzn --models pix2pix --cpus-per-job 16 --run --deterministic
+```
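+
+Under the hood, `--deterministic` makes the generated job script export `CUBLAS_WORKSPACE_CONFIG=:4096:8` and pass `--cfg-options randomness.deterministic=True` to the training script. The block below is a minimal, illustrative sketch of the equivalent manual settings, not the exact implementation:
+
+```python
+import os
+
+import torch
+
+# the cuDNN switches mentioned above
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+# required for deterministic cuBLAS kernels on CUDA >= 10.2; the generated
+# job script exports the same variable when `--deterministic` is set
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+```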
+
+## 9. Automatically check links
+
+Use the following script to check whether the links in the documentation are valid:
+
+```shell
+python3 .github/scripts/doc_link_checker.py --target docs/zh_cn
+python3 .github/scripts/doc_link_checker.py --target README_zh-CN.md
+python3 .github/scripts/doc_link_checker.py --target docs/en
+python3 .github/scripts/doc_link_checker.py --target README.md
+```
+
+You can specify `--target` as either a single file or a directory.
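+
+The same check can also be run programmatically. Below is a small sketch using the functions defined in `doc_link_checker.py` (assuming the script's directory is on your `PYTHONPATH` and the repository root is the working directory):
+
+```python
+from doc_link_checker import traverse
+
+# check every non-symlinked markdown file under a directory
+traverse('docs/en')
+
+# a single markdown file also works; an Exception is raised on broken links
+traverse('README.md')
+```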
+
+**Note:** DO NOT use this script in CI, because issuing too many HTTP requests from CI can trigger 503 errors and the CI job will probably fail.
diff --git a/.dev_scripts/create_ceph_configs.py b/.dev_scripts/create_ceph_configs.py
index b324b10f63..3c68d988d3 100644
--- a/.dev_scripts/create_ceph_configs.py
+++ b/.dev_scripts/create_ceph_configs.py
@@ -40,7 +40,7 @@ def convert_data_config(data_cfg):
dataset: dict = data_cfg['dataset']
dataset_type: str = dataset['type']
- if 'mmcls' in dataset_type:
+ if dataset_type in ['ImageNet', 'CIFAR10']:
repo_name = 'classification'
else:
repo_name = 'editing'
@@ -112,8 +112,6 @@ def convert_data_config(data_cfg):
bg_dir_path = bg_dir_path.replace(dataroot_prefix,
ceph_dataroot_prefix)
bg_dir_path = bg_dir_path.replace(repo_name, 'detection')
- bg_dir_path = bg_dir_path.replace('openmmlab:',
- 'sproject:')
pipeline['bg_dir'] = bg_dir_path
elif type_ == 'CompositeFg':
fg_dir_path = pipeline['fg_dirs']
@@ -188,9 +186,10 @@ def update_ceph_config(filename, args, dry_run=False):
# 2. change visualizer
-    if hasattr(config, 'vis_backends'):
-        for vis_cfg in config['vis_backends']:
-            if vis_cfg['type'] == 'GenVisBackend':
-                vis_cfg['ceph_path'] = work_dir
+    # TODO: support upload to ceph
+    # if hasattr(config, 'vis_backends'):
+    #     for vis_cfg in config['vis_backends']:
+    #         if vis_cfg['type'] == 'GenVisBackend':
+    #             vis_cfg['ceph_path'] = work_dir
# add pavi config
if args.add_pavi:
diff --git a/.dev_scripts/doc_link_checker.py b/.dev_scripts/doc_link_checker.py
new file mode 100644
index 0000000000..f9fdd4e42e
--- /dev/null
+++ b/.dev_scripts/doc_link_checker.py
@@ -0,0 +1,85 @@
+# Copyright (c) MegFlow. All rights reserved.
+# /bin/python3
+
+import argparse
+import os
+import re
+
+
+def make_parser():
+ parser = argparse.ArgumentParser('Doc link checker')
+ parser.add_argument(
+ '--http', default=False, type=bool, help='check http or not ')
+ parser.add_argument(
+ '--target',
+ default='./docs',
+ type=str,
+ help='the directory or file to check')
+ return parser
+
+
+pattern = re.compile(r'\[.*?\]\(.*?\)')
+
+
+def analyze_doc(home, path):
+ print('analyze {}'.format(path))
+ problem_list = []
+ code_block = 0
+ with open(path) as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip()
+ if line.startswith('```'):
+                code_block = 1 - code_block  # toggle in/out of a fenced block
+
+ if code_block > 0:
+ continue
+
+ if '[' in line and ']' in line and '(' in line and ')' in line:
+ all = pattern.findall(line)
+ for item in all:
+ # skip ![]()
+ if item.find('[') == item.find(']') - 1:
+ continue
+
+ # process the case [text()]()
+ offset = item.find('](')
+ if offset == -1:
+ continue
+ item = item[offset:]
+ start = item.find('(')
+ end = item.find(')')
+ ref = item[start + 1:end]
+
+ if ref.startswith('http') or ref.startswith('#'):
+ continue
+ if '.md#' in ref:
+                ref = ref[:ref.find('#')]  # keep only the file path
+ fullpath = os.path.join(home, ref)
+ if not os.path.exists(fullpath):
+ problem_list.append(ref)
+ else:
+ continue
+ if len(problem_list) > 0:
+ print(f'{path}:')
+ for item in problem_list:
+ print(f'\t {item}')
+ print('\n')
+ raise Exception('found link error')
+
+
+def traverse(target):
+ if os.path.isfile(target):
+ analyze_doc(os.path.dirname(target), target)
+ return
+ for home, dirs, files in os.walk(target):
+ for filename in files:
+ if filename.endswith('.md'):
+ path = os.path.join(home, filename)
+ if os.path.islink(path) is False:
+ analyze_doc(home, path)
+
+
+if __name__ == '__main__':
+ args = make_parser().parse_args()
+ traverse(args.target)
diff --git a/.dev_scripts/download_models.py b/.dev_scripts/download_models.py
index ead0796c62..f3349aaf15 100644
--- a/.dev_scripts/download_models.py
+++ b/.dev_scripts/download_models.py
@@ -76,6 +76,7 @@ def download(args):
http_prefix_long = 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/' # noqa
http_prefix_short = 'https://download.openmmlab.com/mmediting/'
+ http_prefix_gen = 'https://download.openmmlab.com/mmgen/'
# load model list
if args.model_list:
@@ -112,6 +113,11 @@ def download(args):
model_name = model_weight_url[len(http_prefix_long):]
elif model_weight_url.startswith(http_prefix_short):
model_name = model_weight_url[len(http_prefix_short):]
+ elif model_weight_url.startswith(http_prefix_gen):
+ model_name = model_weight_url[len(http_prefix_gen):]
+ elif model_weight_url == '':
+ print(f'{model_info.Name} weight is missing')
+ return None
else:
raise ValueError(f'Unknown url prefix. \'{model_weight_url}\'')
diff --git a/.dev_scripts/job_watcher.py b/.dev_scripts/job_watcher.py
index eaf4d3cfba..2ceaff5e1c 100644
--- a/.dev_scripts/job_watcher.py
+++ b/.dev_scripts/job_watcher.py
@@ -9,7 +9,7 @@
from pygments.util import ClassNotFound
from simple_term_menu import TerminalMenu
-CACHE_DIR = '~/.task_watcher'
+CACHE_DIR = osp.join(osp.expanduser('~'), '.task_watcher')
def show_job_out(name, root, job_name_list):
diff --git a/.dev_scripts/metric_mapping.py b/.dev_scripts/metric_mapping.py
index 37094059ef..8759c47103 100644
--- a/.dev_scripts/metric_mapping.py
+++ b/.dev_scripts/metric_mapping.py
@@ -1,5 +1,10 @@
# key-in-metafile: key-in-results.pkl
METRICS_MAPPING = {
+ 'FID': {
+ 'keys': ['FID-Full-50k/fid'],
+ 'tolerance': 0.5,
+ 'rule': 'less'
+ },
'PSNR': {
'keys': ['PSNR'],
'tolerance': 0.1,
diff --git a/.dev_scripts/test_benchmark.py b/.dev_scripts/test_benchmark.py
index 50bc913390..ae36afb13b 100644
--- a/.dev_scripts/test_benchmark.py
+++ b/.dev_scripts/test_benchmark.py
@@ -100,12 +100,18 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
http_prefix_short = 'https://download.openmmlab.com/mmediting/'
http_prefix_long = 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/' # noqa
+ http_prefix_gen = 'https://download.openmmlab.com/mmgen/'
model_weight_url = model_info.weights
if model_weight_url.startswith(http_prefix_long):
model_name = model_weight_url[len(http_prefix_long):]
elif model_weight_url.startswith(http_prefix_short):
model_name = model_weight_url[len(http_prefix_short):]
+ elif model_weight_url.startswith(http_prefix_gen):
+ model_name = model_weight_url[len(http_prefix_gen):]
+ elif model_weight_url == '':
+ print(f'{fname} weight is missing')
+ return None
else:
raise ValueError(f'Unknown url prefix. \'{model_weight_url}\'')
diff --git a/.dev_scripts/train_benchmark.py b/.dev_scripts/train_benchmark.py
index d8f5eca71a..acb900cc15 100644
--- a/.dev_scripts/train_benchmark.py
+++ b/.dev_scripts/train_benchmark.py
@@ -13,6 +13,7 @@
from rich.syntax import Syntax
from rich.table import Table
from tqdm import tqdm
+from utils import filter_jobs, parse_job_list_from_file
console = Console()
MMEDIT_ROOT = Path(__file__).absolute().parents[1]
@@ -91,8 +92,13 @@ def parse_args():
parser.add_argument('--skip', type=str, default=None)
parser.add_argument('--skip-list', default=None)
parser.add_argument('--rerun', type=str, default=None)
+ parser.add_argument(
+ '--rerun-fail', action='store_true', help='only rerun failed tasks')
+ parser.add_argument(
+ '--rerun-cancel', action='store_true', help='only rerun cancel tasks')
parser.add_argument('--rerun-list', default=None)
parser.add_argument('--gpus-per-job', type=int, default=None)
+ parser.add_argument('--cpus-per-job', type=int, default=16)
parser.add_argument(
'--amp', action='store_true', help='Whether to use amp.')
parser.add_argument(
@@ -111,6 +117,10 @@ def parse_args():
'--work-dir',
default='work_dirs/benchmark_train',
help='the dir to save metric')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+        help='Whether to set `deterministic` during training.')
parser.add_argument(
'--run', action='store_true', help='run script directly')
parser.add_argument(
@@ -145,11 +155,22 @@ def parse_args():
args.skip_list = skip_list
print('skip_list: ', args.skip_list)
elif args.rerun is not None:
- with open(args.rerun, 'r') as fp:
- rerun_list = fp.readlines()
- rerun_list = [j.split('\n')[0] for j in rerun_list]
- args.rerun_list = rerun_list
- print('rerun_list: ', args.rerun_list)
+ job_id_list_full, job_name_list_full = parse_job_list_from_file(
+ args.rerun)
+ filter_target = []
+
+ if args.rerun_fail:
+ filter_target += ['FAILED']
+ if args.rerun_cancel:
+ filter_target += ['CANCELLED']
+
+ _, job_name_list = filter_jobs(
+ job_id_list_full,
+ job_name_list_full,
+ filter_target,
+ show_table=True,
+ table_name='Rerun List')
+ args.rerun_list = job_name_list
return args
@@ -171,14 +192,19 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
config = Path(config)
assert config.exists(), f'{fname}: {config} not found.'
- # get n gpus
try:
n_gpus = int(model_info.metadata.data['GPUs'].split()[0])
except Exception:
if 'official' in model_info.config:
return None
else:
- n_gpus = 1
+                pattern = r'\d+xb\d+'  # e.g. '4xb16' means 4 GPUs
+ parse_res = re.search(pattern, config.name)
+ if not parse_res:
+ # defaults to use 1 gpu
+ n_gpus = 1
+ else:
+ n_gpus = int(parse_res.group().split('x')[0])
if args.gpus_per_job is not None:
n_gpus = min(args.gpus_per_job, n_gpus)
@@ -217,10 +243,14 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
job_script += (f'#SBATCH --gres=gpu:{n_gpus}\n'
f'#SBATCH --ntasks-per-node={min(n_gpus, 8)}\n'
f'#SBATCH --ntasks={n_gpus}\n'
- f'#SBATCH --cpus-per-task=5\n\n')
+ f'#SBATCH --cpus-per-task={args.cpus_per_job}\n'
+ f'#SBATCH --kill-on-bad-exit=1\n\n')
else:
job_script += '\n\n' + 'export CUDA_VISIBLE_DEVICES=-1\n'
+ if args.deterministic:
+ job_script += 'export CUBLAS_WORKSPACE_CONFIG=:4096:8\n'
+
job_script += (f'export MASTER_PORT={port}\n'
f'{runner} -u {script_name} {config} '
f'--work-dir={work_dir} '
@@ -232,6 +262,9 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
if args.amp:
job_script += ' --amp '
+ if args.deterministic:
+ job_script += ' --cfg-options randomness.deterministic=True'
+
job_script += '\n'
with open(work_dir / 'job.sh', 'w') as f:
diff --git a/.dev_scripts/utils.py b/.dev_scripts/utils.py
new file mode 100644
index 0000000000..42198a5475
--- /dev/null
+++ b/.dev_scripts/utils.py
@@ -0,0 +1,122 @@
+import os
+import os.path as osp
+from typing import Tuple
+
+from rich import print as pprint
+from rich.table import Table
+
+
+def parse_job_list(job_list) -> Tuple[list, list]:
+ """Parse task name and job id from list. All elements in `job_list` must.
+
+ be formatted as `JOBID @ JOBNAME`.
+
+ Args:
+ job_list (list[str]): Job list.
+
+ Returns:
+ Tuple[list, list]: Job ID list and Job name list.
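+
+    Example (a hypothetical job entry, for illustration only):
+        >>> parse_job_list(['1234567 @ sagan-benchmark'])
+        (['1234567'], ['sagan-benchmark'])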
+ """
+ assert all([
+ ' @ ' in job for job in job_list
+ ]), ('Each line of job list must be formatted like \'JOBID @ JOBNAME\'.')
+ job_id_list, job_name_list = [], []
+ for job_info in job_list:
+ job_id, job_name = job_info.split(' @ ')
+ job_id_list.append(job_id)
+ job_name_list.append(job_name)
+ return job_id_list, job_name_list
+
+
+def parse_job_list_from_file(job_list_file: str) -> Tuple[list, list]:
+ """Parse job list from file and return a tuple contains list of job id and
+ job name.
+
+ Args:
+        job_list_file (str): The path to the job list file.
+
+ Returns:
+        Tuple[list, list]: A tuple containing the job ids and job names.
+ """
+ if not osp.exists(job_list_file):
+        return [], []
+ with open(job_list_file, 'r') as file:
+ job_list = [job.strip() for job in file.readlines()]
+ return parse_job_list(job_list)
+
+
+def get_info_from_id(job_id: str) -> dict:
+ """Get the basic information of a job id with `swatch examine` command.
+
+ Args:
+ job_id (str): The ID of the job.
+
+ Returns:
+        dict: A dict containing information about the corresponding job.
+ """
+ # NOTE: do not have exception handling here
+ info_stream = os.popen(f'swatch examine {job_id}')
+ info_str = [line.strip() for line in info_stream.readlines()]
+ status_info = info_str[2].split()
+ try:
+ status_dict = {
+ 'JobID': status_info[0],
+ 'JobName': status_info[1],
+ 'Partition': status_info[2],
+ 'NNodes': status_info[3],
+ 'AllocCPUS': status_info[4],
+ 'State': status_info[5]
+ }
+    except Exception:
+        print(f'Failed to parse job info for {job_id}: {info_str}')
+        raise
+ return status_dict
+
+
+def filter_jobs(job_id_list: list,
+ job_name_list: list,
+ select: list = ['FAILED'],
+ show_table: bool = False,
+ table_name: str = 'Filter Results') -> Tuple[list, list]:
+ """Filter the job which status not belong to :attr:`select`.
+
+ Args:
+ job_id_list (list): The list of job ids.
+ job_name_list (list): The list of job names.
+ select (list, optional): Which kind of jobs will be selected.
+ Defaults to ['FAILED'].
+        show_table (bool, optional): Whether to show the filter result in a
+ table. Defaults to False.
+ table_name (str, optional): The name of the table. Defaults to
+ 'Filter Results'.
+
+ Returns:
+        Tuple[list, list]: The selected job ids and job names.
+ """
+    # if `select` is empty, return the original id list and name list
+ if not select:
+ return job_id_list, job_name_list
+ filtered_id_list, filtered_name_list = [], []
+ job_info_list = []
+ for id_, name_ in zip(job_id_list, job_name_list):
+ info = get_info_from_id(id_)
+ job_info_list.append(info)
+ if info['State'] in select:
+ filtered_id_list.append(id_)
+ filtered_name_list.append(name_)
+
+ if show_table:
+ filter_table = Table(title=table_name)
+ for field in ['Name', 'ID', 'State', 'Is Selected']:
+ filter_table.add_column(field)
+ for id_, name_, info_ in zip(job_id_list, job_name_list,
+ job_info_list):
+ selected = '[green]True' \
+ if info_['State'] in select else '[red]False'
+ filter_table.add_row(name_, id_, info_['State'], selected)
+ pprint(filter_table)
+ return filtered_id_list, filtered_name_list
diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index d85c59596f..f5bf7d8c0d 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -8,6 +8,8 @@ on:
- 'docs/**'
- '.dev_scripts/**'
- '.circleci/**'
+ - 'configs/**'
+
branches:
- dev-1.x
- test-1.x
@@ -39,10 +41,11 @@ jobs:
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMEngine
run: pip install git+https://github.com/open-mmlab/mmengine.git@main
- - name: Install MMCV
+ - name: Install MMCV and MMDet
run: |
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
+ mim install 'mmdet >= 3.0.0rc2'
- name: Install other dependencies
run: pip install -r requirements/tests.txt
- name: Build and install
@@ -86,10 +89,11 @@ jobs:
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMEngine
run: pip install git+https://github.com/open-mmlab/mmengine.git@main
- - name: Install MMCV
+ - name: Install MMCV and MMDet
run: |
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
+ mim install 'mmdet >= 3.0.0rc2'
- name: Install other dependencies
run: pip install -r requirements/tests.txt
- name: Build and install
@@ -165,20 +169,21 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
- run: pip install pip --upgrade
+ run: python -m pip install pip --upgrade
- name: Install lmdb
- run: pip install lmdb
+ run: python -m pip install lmdb
- name: Install PyTorch
- run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+ run: python -m pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install mmediting dependencies
run: |
- pip install git+https://github.com/open-mmlab/mmengine.git@main
- pip install -U openmim
+ python -m pip install git+https://github.com/open-mmlab/mmengine.git@main
+ python -m pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
- pip install -r requirements/tests.txt
+ mim install 'mmdet >= 3.0.0rc2'
+ python -m pip install -r requirements/tests.txt
- name: Build and install
run: |
- pip install -e .
+ python -m pip install -e .
- name: Run unittests and generate coverage report
run: |
pytest tests/
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index e99ed3b190..e27703b379 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -8,6 +8,7 @@ on:
- 'docs/**'
- '.dev_scripts/**'
- '.circleci/**'
+ - 'configs/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -34,10 +35,11 @@ jobs:
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMEngine
run: pip install git+https://github.com/open-mmlab/mmengine.git@main
- - name: Install MMCV
+ - name: Install MMCV and MMDet
run: |
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
+ mim install 'mmdet >= 3.0.0rc2'
- name: Install other dependencies
run: pip install -r requirements/tests.txt
- name: Build and install
@@ -93,6 +95,7 @@ jobs:
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
+ mim install 'mmdet >= 3.0.0rc2'
pip install -r requirements/tests.txt
- name: Build and install
run: |
@@ -116,20 +119,21 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
- run: pip install pip --upgrade
+ run: python -m pip install pip --upgrade
- name: Install lmdb
- run: pip install lmdb
+ run: python -m pip install lmdb
- name: Install PyTorch
- run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+ run: python -m pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install mmedit dependencies
run: |
- pip install git+https://github.com/open-mmlab/mmengine.git@main
- pip install -U openmim
+ python -m pip install git+https://github.com/open-mmlab/mmengine.git@main
+ python -m pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
- pip install -r requirements/tests.txt
+ mim install 'mmdet >= 3.0.0rc2'
+ python -m pip install -r requirements/tests.txt
- name: Build and install
run: |
- pip install -e .
+ python -m pip install -e .
- name: Run unittests and generate coverage report
run: |
pytest tests/
diff --git a/README.md b/README.md
index f79f5b4551..9ace05994b 100644
--- a/README.md
+++ b/README.md
@@ -104,22 +104,26 @@ hope MMEditing could provide better experience.
## What's New
-- \[2022-09-13\] 🎉[MMGeneration](https://github.com/open-mmlab/mmgeneration/tree/1.x) was merged into MMEditing! And we are calling for your [suggestion](https://github.com/open-mmlab/mmediting/discussions/1108)!
-- \[2022-08-31\] v1.0.0rc0 was released. This release introduced a brand new and flexible training & test engine, but it's still in progress. Welcome
- to try according to [the documentation](https://mmediting.readthedocs.io/en/1.x/).
-- \[2022-06-01\] v0.15.0 was released.
- - Support FLAVR
- - Support AOT-GAN
- - Support CAIN with ReduceLROnPlateau Scheduler
-- \[2022-04-01\] v0.14.0 was released.
- - Support TOFlow in video frame interpolation
-- \[2022-03-01\] v0.13.0 was released.
- - Support CAIN
- - Support EDVR-L
- - Support running in Windows
-- \[2022-02-11\] Switch to **PyTorch 1.5+**. The compatibility to earlier versions of PyTorch will no longer be guaranteed.
-
-Please refer to [changelog.md](docs/en/notes/3_changelog.md) for details and release history.
+### 🌟 Preview of 1.x version
+
+A brand new version of [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) was released on 02/11/2022:
+
+- Support all the tasks, models, metrics, and losses in [MMGeneration](https://github.com/open-mmlab/mmgeneration) 😍.
+- Unify interfaces of all components based on [MMEngine](https://github.com/open-mmlab/mmengine).
+- Support patch-based and slider-based image and video comparison viewer.
+- Support image colorization.
+
+Find more new features in [1.x branch](https://github.com/open-mmlab/mmediting/tree/1.x). Issues and PRs are welcome!
+
+### 💎 Stable version
+
+**0.16.0** was released on 31/10/2022:
+
+- `VisualizationHook` is deprecated. Users should use `MMEditVisualizationHook` instead.
+- Fix FLAVR register.
+- Fix the number of channels in RDB.
+
+Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
## Installation
@@ -215,6 +219,13 @@ Supported algorithms:
+
+Image Colorization
+
+- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020)
+
+
+
Unconditional GANs
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 7f820b5456..64fca3222b 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -37,7 +37,7 @@
[English](/README.md) | 简体中文
-## Introduction
+## 介绍
MMEditing 是基于 PyTorch 的图像&视频编辑和生成开源工具箱。是 [OpenMMLab](https://openmmlab.com/) 项目的成员之一。
@@ -101,24 +101,28 @@ https://user-images.githubusercontent.com/12756472/158972813-d8d0f19c-f49c-4618-
需要注意的是 **MMSR** 已作为 MMEditing 的一部分并入本仓库。
MMEditing 缜密地设计新的框架并将其精心实现,希望能够为您带来更好的体验。
-## 最新消息
-
-- \[2022-09-13\] 🎉 [MMGeneration](<(https://github.com/open-mmlab/mmgeneration/tree/1.x)>) 合入 MMEditing! 对于该合入计划,我们期待您的 [建议](https://github.com/open-mmlab/mmediting/discussions/1108)!
-- \[2022-08-31\] v1.0.0rc0 版本发布
- 这个版本引入一个全新的,可扩展性强的训练和测试引擎,但目前仍在开发中。欢迎根据[文档](https://mmediting.readthedocs.io/en/1.x/)进行试用。
-- \[2022-06-01\] v0.15.0 版本发布
- - 支持 FLAVR
- - 支持 AOT-GAN
- - 新版 CAIN,支持 ReduceLROnPlateau 策略
-- \[2022-04-01\] v0.14.0 版本发布
- - 支持视频插帧算法 TOFlow
-- \[2022-03-01\] v0.13.0 版本发布
- - 支持 CAIN
- - 支持 EDVR-L
- - 支持在 Windows 系统中运行
-- \[2022-02-11\] 切换到 **PyTorch 1.5+**. 将不再保证与早期版本的 PyTorch 的兼容性
-
-请查看 [changelog.md](docs/zh_cn/notes/3_changelog.md) 以获取更多细节与发版记录
+## 最新进展
+
+### 🌟 1.x 预览版本
+
+全新的 [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) 已经在 02/11/2022 发布:
+
+- 支持[MMGeneration](https://github.com/open-mmlab/mmgeneration)中的全量任务、模型、损失函数和评价指标 😍。
+- 基于[MMEngine](https://github.com/open-mmlab/mmengine)统一了各组件接口。
+- 支持基于图像子块以及滑动条的图像和视频比较可视化工具。
+- 支持图像上色任务。
+
+在[1.x 分支](https://github.com/open-mmlab/mmediting/tree/1.x)中发现更多特性!欢迎提 Issues 和 PRs!
+
+### 💎 稳定版本
+
+最新的 **0.16.0** 版本已经在 31/10/2022 发布:
+
+- `VisualizationHook` 将被弃用，建议用户使用 `MMEditVisualizationHook`。
+- 修复 FLAVR 的注册问题。
+- 修正 RDB 模型中的通道数。
+
+如果想了解更多版本更新细节和历史信息，请阅读[更新日志](docs/en/changelog.md)。
## 安装
@@ -213,6 +217,13 @@ pip3 install -e .
+
+图像上色
+
+- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020)
+
+
+
Unconditional GANs
diff --git a/configs/_base_/datasets/basicvsr_test_config.py b/configs/_base_/datasets/basicvsr_test_config.py
index 42e9cf7c91..7f16b094a3 100644
--- a/configs/_base_/datasets/basicvsr_test_config.py
+++ b/configs/_base_/datasets/basicvsr_test_config.py
@@ -109,7 +109,6 @@
metainfo=dict(dataset_type='udm10', task_name='vsr'),
data_root=udm10_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
- num_input_frames=15,
pipeline=udm10_pipeline))
udm10_evaluator = [
@@ -138,8 +137,6 @@
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
- num_input_frames=7,
- fixed_seq_len=7,
pipeline=vid4_pipeline))
vid4_bi_dataloader = dict(
@@ -154,8 +151,6 @@
data_prefix=dict(img='BIx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
- num_input_frames=7,
- fixed_seq_len=7,
pipeline=vid4_pipeline))
vid4_bd_evaluator = [
diff --git a/configs/_base_/datasets/cifar10_nopad.py b/configs/_base_/datasets/cifar10_nopad.py
index 141e404ea5..16dfec7d0c 100644
--- a/configs/_base_/datasets/cifar10_nopad.py
+++ b/configs/_base_/datasets/cifar10_nopad.py
@@ -1,7 +1,5 @@
-# custom_imports = dict(
-# imports=['mmcls.datasets.transforms'], allow_failed_imports=False)
cifar_pipeline = [
- dict(type='Flip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='Flip', keys=['img'], flip_ratio=0.5, direction='horizontal'),
dict(type='PackEditInputs')
]
cifar_dataset = dict(
diff --git a/configs/_base_/datasets/imagenet_128.py b/configs/_base_/datasets/imagenet_128.py
index 39310f4bae..88a15db1f0 100644
--- a/configs/_base_/datasets/imagenet_128.py
+++ b/configs/_base_/datasets/imagenet_128.py
@@ -7,7 +7,7 @@
dict(type='LoadImageFromFile', key='img'),
dict(type='RandomCropLongEdge', keys=['img']),
dict(type='Resize', scale=(128, 128), keys=['img'], backend='pillow'),
- dict(type='Flip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='Flip', keys=['img'], flip_ratio=0.5, direction='horizontal'),
dict(type='PackEditInputs')
]
diff --git a/configs/_base_/datasets/liif_test_config.py b/configs/_base_/datasets/liif_test_config.py
index 6187899d49..9569c8e8f1 100644
--- a/configs/_base_/datasets/liif_test_config.py
+++ b/configs/_base_/datasets/liif_test_config.py
@@ -29,8 +29,8 @@
pipeline=test_pipeline)) for test_pipeline in test_pipelines
]
set5_evaluators = [[
- dict(type='PSNR', crop_border=2, prefix=f'Set5x{scale}'),
- dict(type='SSIM', crop_border=2, prefix=f'Set5x{scale}'),
+ dict(type='PSNR', crop_border=scale, prefix=f'Set5x{scale}'),
+ dict(type='SSIM', crop_border=scale, prefix=f'Set5x{scale}'),
] for scale in scale_test_list]
# test config for Set14
@@ -48,8 +48,8 @@
pipeline=test_pipeline)) for test_pipeline in test_pipelines
]
set14_evaluators = [[
- dict(type='PSNR', crop_border=2, prefix=f'Set14x{scale}'),
- dict(type='SSIM', crop_border=2, prefix=f'Set14x{scale}'),
+ dict(type='PSNR', crop_border=scale, prefix=f'Set14x{scale}'),
+ dict(type='SSIM', crop_border=scale, prefix=f'Set14x{scale}'),
] for scale in scale_test_list]
# test config for DIV2K
@@ -69,8 +69,8 @@
pipeline=test_pipeline)) for test_pipeline in test_pipelines
]
div2k_evaluators = [[
- dict(type='PSNR', crop_border=2, prefix=f'DIV2Kx{scale}'),
- dict(type='SSIM', crop_border=2, prefix=f'DIV2Kx{scale}'),
+ dict(type='PSNR', crop_border=scale, prefix=f'DIV2Kx{scale}'),
+ dict(type='SSIM', crop_border=scale, prefix=f'DIV2Kx{scale}'),
] for scale in scale_test_list]
# test config
diff --git a/configs/_base_/datasets/paired_imgs_256x256_crop.py b/configs/_base_/datasets/paired_imgs_256x256_crop.py
index 38a609fa66..6770420cf5 100644
--- a/configs/_base_/datasets/paired_imgs_256x256_crop.py
+++ b/configs/_base_/datasets/paired_imgs_256x256_crop.py
@@ -10,18 +10,11 @@
domain_b='B',
color_type='color'),
dict(
- type='TransformBroadcaster',
- mapping={'img': ['img_A', 'img_B']},
- auto_remap=True,
- share_random_params=True,
- transforms=[
- dict(
- type='Resize',
- keys='img',
- scale=(286, 286),
- interpolation='bicubic'),
- dict(type='FixedCrop', keys=['img'], crop_size=(256, 256))
- ]),
+ type='Resize',
+ keys=['img_A', 'img_B'],
+ scale=(286, 286),
+ interpolation='bicubic'),
+ dict(type='FixedCrop', keys=['img_A', 'img_B'], crop_size=(256, 256)),
dict(type='Flip', keys=['img_A', 'img_B'], direction='horizontal'),
# NOTE: users should implement their own keyMapper and Pack operation
# dict(
@@ -76,7 +69,7 @@
# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
@@ -86,7 +79,7 @@
pipeline=train_pipeline))
val_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
@@ -96,7 +89,7 @@
persistent_workers=True)
test_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
diff --git a/configs/_base_/datasets/sisr_x2_test_config.py b/configs/_base_/datasets/sisr_x2_test_config.py
index 394ea23c4e..491c5327df 100644
--- a/configs/_base_/datasets/sisr_x2_test_config.py
+++ b/configs/_base_/datasets/sisr_x2_test_config.py
@@ -41,7 +41,7 @@
dataset=dict(
type='BasicImageDataset',
metainfo=dict(dataset_type='set14', task_name='sisr'),
- data_root=set5_data_root,
+ data_root=set14_data_root,
data_prefix=dict(img='LRbicx2', gt='GTmod12'),
pipeline=test_pipeline))
set14_evaluator = [
diff --git a/configs/_base_/datasets/sisr_x3_test_config.py b/configs/_base_/datasets/sisr_x3_test_config.py
index ecd62accf7..35de971440 100644
--- a/configs/_base_/datasets/sisr_x3_test_config.py
+++ b/configs/_base_/datasets/sisr_x3_test_config.py
@@ -28,8 +28,8 @@
data_prefix=dict(img='LRbicx3', gt='GTmod12'),
pipeline=test_pipeline))
set5_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='Set5'),
- dict(type='SSIM', crop_border=2, prefix='Set5'),
+ dict(type='PSNR', crop_border=3, prefix='Set5'),
+ dict(type='SSIM', crop_border=3, prefix='Set5'),
]
set14_data_root = 'data/Set14'
@@ -41,12 +41,12 @@
dataset=dict(
type='BasicImageDataset',
metainfo=dict(dataset_type='set14', task_name='sisr'),
- data_root=set5_data_root,
+ data_root=set14_data_root,
data_prefix=dict(img='LRbicx3', gt='GTmod12'),
pipeline=test_pipeline))
set14_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='Set14'),
- dict(type='SSIM', crop_border=2, prefix='Set14'),
+ dict(type='PSNR', crop_border=3, prefix='Set14'),
+ dict(type='SSIM', crop_border=3, prefix='Set14'),
]
# test config for DIV2K
@@ -65,8 +65,8 @@
img='DIV2K_train_LR_bicubic/X3_sub', gt='DIV2K_train_HR_sub'),
pipeline=test_pipeline))
div2k_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='DIV2K'),
- dict(type='SSIM', crop_border=2, prefix='DIV2K'),
+ dict(type='PSNR', crop_border=3, prefix='DIV2K'),
+ dict(type='SSIM', crop_border=3, prefix='DIV2K'),
]
# test config
diff --git a/configs/_base_/datasets/sisr_x4_test_config.py b/configs/_base_/datasets/sisr_x4_test_config.py
index b040c6a465..bc637b7a64 100644
--- a/configs/_base_/datasets/sisr_x4_test_config.py
+++ b/configs/_base_/datasets/sisr_x4_test_config.py
@@ -28,8 +28,8 @@
data_prefix=dict(img='LRbicx4', gt='GTmod12'),
pipeline=test_pipeline))
set5_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='Set5'),
- dict(type='SSIM', crop_border=2, prefix='Set5'),
+ dict(type='PSNR', crop_border=4, prefix='Set5'),
+ dict(type='SSIM', crop_border=4, prefix='Set5'),
]
set14_data_root = 'data/Set14'
@@ -41,12 +41,12 @@
dataset=dict(
type='BasicImageDataset',
metainfo=dict(dataset_type='set14', task_name='sisr'),
- data_root=set5_data_root,
+ data_root=set14_data_root,
data_prefix=dict(img='LRbicx4', gt='GTmod12'),
pipeline=test_pipeline))
set14_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='Set14'),
- dict(type='SSIM', crop_border=2, prefix='Set14'),
+ dict(type='PSNR', crop_border=4, prefix='Set14'),
+ dict(type='SSIM', crop_border=4, prefix='Set14'),
]
# test config for DIV2K
@@ -66,8 +66,8 @@
# filename_tmpl=dict(img='{}_x4', gt='{}'),
pipeline=test_pipeline))
div2k_evaluator = [
- dict(type='PSNR', crop_border=2, prefix='DIV2K'),
- dict(type='SSIM', crop_border=2, prefix='DIV2K'),
+ dict(type='PSNR', crop_border=4, prefix='DIV2K'),
+ dict(type='SSIM', crop_border=4, prefix='DIV2K'),
]
# test config
diff --git a/configs/_base_/datasets/unpaired_imgs_256x256.py b/configs/_base_/datasets/unpaired_imgs_256x256.py
index 6c3624ce1f..1b55ccc51d 100644
--- a/configs/_base_/datasets/unpaired_imgs_256x256.py
+++ b/configs/_base_/datasets/unpaired_imgs_256x256.py
@@ -46,7 +46,6 @@
transforms=dict(
type='Resize', scale=(256, 256), interpolation='bicubic'),
),
-
# NOTE: users should implement their own keyMapper and Pack operation
# dict(
# type='KeyMapper',
@@ -65,7 +64,7 @@
# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
@@ -75,21 +74,23 @@
pipeline=train_pipeline))
val_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root=None, # set by user
+ test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True)
test_dataloader = dict(
- batch_size=1,
+ batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root=None, # set by user
+ test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True)
diff --git a/configs/_base_/gen_default_runtime.py b/configs/_base_/gen_default_runtime.py
index a13f62f701..64a510e42e 100644
--- a/configs/_base_/gen_default_runtime.py
+++ b/configs/_base_/gen_default_runtime.py
@@ -19,8 +19,9 @@
type='CheckpointHook',
interval=10000,
by_epoch=False,
- less_keys=['FID-Full-50k/fid'],
- greater_keys=['IS-50k/is'],
+ max_keep_ckpts=20,
+ less_keys=['FID-Full-50k/fid', 'swd/avg'],
+ greater_keys=['IS-50k/is', 'ms-ssim/avg'],
save_optimizer=True))
# config for environment
diff --git a/configs/_base_/matting_default_runtime.py b/configs/_base_/matting_default_runtime.py
index 9d9f09cd6c..e544dd8677 100644
--- a/configs/_base_/matting_default_runtime.py
+++ b/configs/_base_/matting_default_runtime.py
@@ -26,7 +26,7 @@
fn_key='trimap_path',
img_keys=['pred_alpha', 'trimap', 'gt_merged', 'gt_alpha'],
bgr2rgb=True)
-custom_hooks = [dict(type='BasicVisualizationHook', interval=1)]
+custom_hooks = [dict(type='BasicVisualizationHook', interval=2000)]
log_level = 'INFO'
log_processor = dict(type='LogProcessor', by_epoch=False)
diff --git a/configs/_base_/models/base_glean.py b/configs/_base_/models/base_glean.py
index d857ab5e67..50faaba19a 100644
--- a/configs/_base_/models/base_glean.py
+++ b/configs/_base_/models/base_glean.py
@@ -40,7 +40,8 @@
save_optimizer=True,
by_epoch=False,
out_dir=save_dir,
- ),
+ save_best=['MAE', 'PSNR', 'SSIM'],
+ rule=['less', 'greater', 'greater']),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
diff --git a/configs/_base_/models/base_pix2pix.py b/configs/_base_/models/base_pix2pix.py
index d64c275f8b..8fde5a1de5 100644
--- a/configs/_base_/models/base_pix2pix.py
+++ b/configs/_base_/models/base_pix2pix.py
@@ -20,6 +20,7 @@
num_conv=3,
norm_cfg=dict(type='BN'),
init_cfg=dict(type='normal', gain=0.02)),
+ loss_config=dict(pixel_loss_weight=100.0),
default_domain=target_domain,
reachable_domains=[target_domain],
related_domains=[target_domain, source_domain])
diff --git a/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py b/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py
index 2e75628ce1..8b9fb167d7 100644
--- a/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py
+++ b/configs/basicvsr/basicvsr_2xb4_vimeo90k-bd.py
@@ -68,8 +68,6 @@
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
- num_input_frames=7,
- fixed_seq_len=7,
pipeline=val_pipeline))
val_evaluator = [
diff --git a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py
index 0b15728127..daef0c8702 100644
--- a/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py
+++ b/configs/basicvsr_pp/basicvsr-pp_c128n25_600k_ntire-decompress-track1.py
@@ -52,7 +52,6 @@
metainfo=dict(dataset_type='ntire21_track1', task_name='vsr'),
data_root='data/NTIRE21_decompression_track1',
data_prefix=dict(img='LQ', gt='GT'),
- num_input_frames=15,
pipeline=test_pipeline))
test_evaluator = [
diff --git a/configs/cyclegan/README.md b/configs/cyclegan/README.md
index 9e62e3c263..b72a0fefcf 100644
--- a/configs/cyclegan/README.md
+++ b/configs/cyclegan/README.md
@@ -27,18 +27,18 @@
We use `FID` and `IS` metrics to evaluate the generation performance of CycleGAN.1
| Models | Dataset | FID | IS | Config | Download |
| :----: | :---------------: | :------: | :---: | :-------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------: |
-| Ours | facades | 124.8033 | 1.792 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_80k_facades_20210902_165905-5e2c0876.pth) \| [log](https://download.openmmlab.com/mmgen/cyclegan/base_cyclegan_in_1x1_80k_facades_20210317_160938.log.json) 2 |
+| Ours | facades | 124.8033 | 1.792 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_80k_facades_20210902_165905-5e2c0876.pth) \| [log](https://download.openmmlab.com/mmgen/cyclegan/base_cyclegan_in_1x1_80k_facades_20210317_160938.log.json) 2 |
| Ours | facades-id0 | 125.1694 | 1.905 | [config](/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades_convert-bgr_20210902_164411-d8e72b45.pth) |
-| Ours | summer2winter | 83.7177 | 2.771 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth) |
+| Ours | summer2winter | 83.7177 | 2.771 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth) |
| Ours | summer2winter-id0 | 83.1418 | 2.720 | [config](/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_id0_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165640-8b825581.pth) |
-| Ours | winter2summer | 72.8025 | 3.129 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth) |
+| Ours | winter2summer | 72.8025 | 3.129 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth) |
| Ours | winter2summer-id0 | 73.5001 | 3.107 | [config](/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_id0_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165640-8b825581.pth) |
-| Ours | horse2zebra | 64.5225 | 1.418 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth) |
+| Ours | horse2zebra | 64.5225 | 1.418 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth) |
| Ours | horse2zebra-id0 | 74.7770 | 1.542 | [config](/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_id0_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_165724-77c9c806.pth) |
-| Ours | zebra2horse | 141.1517 | 3.154 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth) |
+| Ours | zebra2horse | 141.1517 | 3.154 | [config](/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth) |
| Ours | zebra2horse-id0 | 134.3728 | 3.091 | [config](/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py) | [model](https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_id0_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_165724-77c9c806.pth) |
`FID` comparison with official:
diff --git a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py
index 0ea3571ca7..cd131f24ca 100644
--- a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py
+++ b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py
@@ -67,21 +67,39 @@
]
# testA: 309, testB:238
-num_images = 238
+num_images_a = 309
+num_images_b = 238
metrics = [
dict(
type='TransIS',
- prefix='IS-Full',
- fake_nums=num_images,
- fake_key='fake_winter',
+ prefix=f'IS-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ fake_key=f'fake_{domain_b}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
+ dict(
+ type='TransIS',
+ prefix=f'IS-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
+ fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
+ inception_style='PyTorch'),
+ dict(
+ type='TransFID',
+ prefix=f'FID-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ inception_style='PyTorch',
+ real_key=f'img_{domain_b}',
+ fake_key=f'fake_{domain_b}'),
dict(
type='TransFID',
- prefix='FID-Full',
- fake_nums=num_images,
+ prefix=f'FID-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
inception_style='PyTorch',
- real_key='img_winter',
- fake_key='fake_winter')
+ real_key=f'img_{domain_a}',
+ fake_key=f'fake_{domain_a}')
]
val_evaluator = dict(metrics=metrics)
diff --git a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py
index 05663ecd58..4bad2f6df1 100644
--- a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py
+++ b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py
@@ -66,21 +66,39 @@
])
]
-num_images = 140
+num_images_a = 120
+num_images_b = 140
metrics = [
dict(
type='TransIS',
- prefix='IS-Full',
- fake_nums=num_images,
- fake_key='fake_zebra',
+ prefix=f'IS-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ fake_key=f'fake_{domain_b}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
+ dict(
+ type='TransIS',
+ prefix=f'IS-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
+ fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
+ inception_style='PyTorch'),
+ dict(
+ type='TransFID',
+ prefix=f'FID-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ inception_style='PyTorch',
+ real_key=f'img_{domain_b}',
+ fake_key=f'fake_{domain_b}'),
dict(
type='TransFID',
- prefix='FID-Full',
- fake_nums=num_images,
+ prefix=f'FID-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
inception_style='PyTorch',
- real_key='img_zebra',
- fake_key='fake_zebra')
+ real_key=f'img_{domain_a}',
+ fake_key=f'fake_{domain_a}')
]
val_evaluator = dict(metrics=metrics)
diff --git a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py
index 6bc8243171..c09373a8c4 100644
--- a/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py
+++ b/configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py
@@ -57,7 +57,7 @@
custom_hooks = [
dict(
type='GenVisualizationHook',
- interval=50,
+ interval=5000,
fixed_input=True,
vis_kwargs_list=[
dict(type='Translation', name='trans'),
@@ -65,7 +65,6 @@
])
]
-# learning policy
num_images = 106
metrics = [
dict(
@@ -73,6 +72,8 @@
prefix='IS-Full',
fake_nums=num_images,
fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
dict(
type='TransFID',
diff --git a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py
index c6503a5c6b..a410ad64ac 100644
--- a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py
+++ b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-250kiters_summer2winter.py
@@ -66,21 +66,39 @@
])
]
-num_images = 238
+num_images_a = 309
+num_images_b = 238
metrics = [
dict(
type='TransIS',
- prefix='IS-Full',
- fake_nums=num_images,
- fake_key='fake_winter',
+ prefix=f'IS-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ fake_key=f'fake_{domain_b}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
+ dict(
+ type='TransIS',
+ prefix=f'IS-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
+ fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
+ inception_style='PyTorch'),
+ dict(
+ type='TransFID',
+ prefix=f'FID-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ inception_style='PyTorch',
+ real_key=f'img_{domain_b}',
+ fake_key=f'fake_{domain_b}'),
dict(
type='TransFID',
- prefix='FID-Full',
- fake_nums=num_images,
+ prefix=f'FID-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
inception_style='PyTorch',
- real_key='img_winter',
- fake_key='fake_winter')
+ real_key=f'img_{domain_a}',
+ fake_key=f'fake_{domain_a}')
]
val_evaluator = dict(metrics=metrics)
diff --git a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py
index afbb87a372..341e9e768a 100644
--- a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py
+++ b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-270kiters_horse2zebra.py
@@ -7,6 +7,7 @@
domain_a = 'horse'
domain_b = 'zebra'
+
model = dict(
loss_config=dict(cycle_loss_weight=10., id_loss_weight=0.5),
default_domain=domain_b,
@@ -65,21 +66,39 @@
])
]
-num_images = 140
+num_images_a = 120
+num_images_b = 140
metrics = [
dict(
type='TransIS',
- prefix='IS-Full',
- fake_nums=num_images,
- fake_key='fake_zebra',
+ prefix=f'IS-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ fake_key=f'fake_{domain_b}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
+ inception_style='PyTorch'),
+ dict(
+ type='TransIS',
+ prefix=f'IS-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
+ fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
dict(
type='TransFID',
- prefix='FID-Full',
- fake_nums=num_images,
+ prefix=f'FID-{domain_a}-to-{domain_b}',
+ fake_nums=num_images_b,
+ inception_style='PyTorch',
+ real_key=f'img_{domain_b}',
+ fake_key=f'fake_{domain_b}'),
+ dict(
+ type='TransFID',
+ prefix=f'FID-{domain_b}-to-{domain_a}',
+ fake_nums=num_images_a,
inception_style='PyTorch',
- real_key='img_zebra',
- fake_key='fake_zebra')
+ real_key=f'img_{domain_a}',
+ fake_key=f'fake_{domain_a}')
]
val_evaluator = dict(metrics=metrics)
diff --git a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py
index 5b6eed0879..3ebaa6d32f 100644
--- a/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py
+++ b/configs/cyclegan/cyclegan_lsgan-resnet-in_1xb1-80kiters_facades.py
@@ -11,8 +11,17 @@
loss_config=dict(cycle_loss_weight=10., id_loss_weight=0.5),
default_domain=domain_a,
reachable_domains=[domain_a, domain_b],
- related_domains=[domain_a, domain_b],
-)
+ related_domains=[domain_a, domain_b])
+
+param_scheduler = dict(
+ type='LinearLrInterval',
+ interval=400,
+ by_epoch=False,
+ start_factor=0.0002,
+ end_factor=0,
+ begin=40000,
+ end=80000)
+
dataroot = './data/cyclegan/facades'
train_pipeline = _base_.train_dataloader.dataset.pipeline
val_pipeline = _base_.val_dataloader.dataset.pipeline
@@ -45,15 +54,6 @@
discriminators=dict(
optimizer=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))))
-param_scheduler = dict(
- type='LinearLrInterval',
- interval=400,
- by_epoch=False,
- start_factor=0.0002,
- end_factor=0,
- begin=40000,
- end=80000)
-
custom_hooks = [
dict(
type='GenVisualizationHook',
@@ -65,8 +65,6 @@
])
]
-total_iters = 80000
-
num_images = 106
metrics = [
dict(
@@ -74,6 +72,8 @@
prefix='IS-Full',
fake_nums=num_images,
fake_key=f'fake_{domain_a}',
+ use_pillow_resize=False,
+ resize_method='bilinear',
inception_style='PyTorch'),
dict(
type='TransFID',
diff --git a/configs/cyclegan/metafile.yml b/configs/cyclegan/metafile.yml
index 76f0241735..482fd9f12a 100644
--- a/configs/cyclegan/metafile.yml
+++ b/configs/cyclegan/metafile.yml
@@ -21,7 +21,7 @@ Models:
FID: 124.8033
IS: 1.792
Task: Cyclegan
- Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_80k_facades_20210902_165905-5e2c0876.pth
+ Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_80k_facades_20210902_165905-5e2c0876.pth
- Config: configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-80kiters_facades.py
In Collection: 'CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent
Adversarial Networks'
@@ -47,7 +47,7 @@ Models:
FID: 83.7177
IS: 2.771
Task: Cyclegan
- Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth
+ Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth
- Config: configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py
In Collection: 'CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent
Adversarial Networks'
@@ -73,7 +73,7 @@ Models:
FID: 72.8025
IS: 3.129
Task: Cyclegan
- Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth
+ Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_246200_summer2winter_convert-bgr_20210902_165932-fcf08dc1.pth
- Config: configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-250kiters_summer2winter.py
In Collection: 'CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent
Adversarial Networks'
@@ -99,7 +99,7 @@ Models:
FID: 64.5225
IS: 1.418
Task: Cyclegan
- Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth
+ Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth
- Config: configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py
In Collection: 'CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent
Adversarial Networks'
@@ -125,7 +125,7 @@ Models:
FID: 141.1517
IS: 3.154
Task: Cyclegan
- Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/base_cyclegan_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth
+ Weights: https://download.openmmlab.com/mmgen/cyclegan/refactor/cyclegan_lsgan_resnet_in_1x1_266800_horse2zebra_convert-bgr_20210902_170004-a32c733a.pth
- Config: configs/cyclegan/cyclegan_lsgan-id0-resnet-in_1xb1-270kiters_horse2zebra.py
In Collection: 'CycleGAN: Unpaired Image-to-Image Translation Using Cycle-Consistent
Adversarial Networks'
diff --git a/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py b/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py
index 78f7c33c36..3eb1b4ad99 100644
--- a/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py
+++ b/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py
@@ -30,7 +30,11 @@
test_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))
-default_hooks = dict(checkpoint=dict(interval=500))
+default_hooks = dict(
+ checkpoint=dict(
+ interval=500,
+ save_best=['swd/avg', 'ms-ssim/avg'],
+ rule=['less', 'greater']))
# VIS_HOOK
custom_hooks = [
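The checkpoint hook above now tracks two validation metrics at once. As used throughout these configs, each entry of `rule` pairs positionally with the `save_best` entry at the same index, so `swd/avg` is treated as lower-is-better and `ms-ssim/avg` as higher-is-better. A small plain-Python sketch of that pairing (the hook semantics are assumed from how the configs use MMEngine's CheckpointHook):

```python
# rule[i] applies to save_best[i].
checkpoint = dict(
    interval=500,
    save_best=['swd/avg', 'ms-ssim/avg'],
    rule=['less', 'greater'])

for key, rule in zip(checkpoint['save_best'], checkpoint['rule']):
    best = 'smallest' if rule == 'less' else 'largest'
    print(f'{key}: keep the checkpoint with the {best} value seen so far')
```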
diff --git a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py
index 593753e342..ec5f7f4e09 100644
--- a/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py
+++ b/configs/edsr/edsr_x2c64b16_1xb16-300k_div2k.py
@@ -73,7 +73,9 @@
data_root = 'data'
train_dataloader = dict(
- num_workers=4,
+ num_workers=8,
+ batch_size=16,
+ drop_last=True,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -112,7 +114,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
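The optimizer change above (repeated in the other EDSR, RDN and SRCNN configs below) restores `betas=(0.9, 0.999)`, which are PyTorch's Adam defaults; the larger `beta2` averages the squared-gradient estimate over roughly `1 / (1 - beta2) = 1000` steps instead of 100. A minimal sketch of the optimizer this setting resolves to (the conv layer is only a stand-in for the real model):

```python
import torch

model = torch.nn.Conv2d(3, 64, 3, padding=1)  # stand-in for the EDSR generator
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
print(optimizer.defaults['betas'])  # (0.9, 0.999), the PyTorch default
```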
diff --git a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py
index be54ae082d..3ac6695d05 100644
--- a/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py
+++ b/configs/edsr/edsr_x3c64b16_1xb16-300k_div2k.py
@@ -43,7 +43,7 @@
channel_order='rgb',
imdecode_backend='cv2'),
dict(type='SetValues', dictionary=dict(scale=scale)),
- dict(type='PairedRandomCrop', gt_patch_size=96),
+ dict(type='PairedRandomCrop', gt_patch_size=144),
dict(
type='Flip',
keys=['img', 'gt'],
@@ -75,7 +75,9 @@
data_root = 'data'
train_dataloader = dict(
- num_workers=4,
+ num_workers=8,
+ batch_size=16,
+ drop_last=True,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -114,7 +116,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py
index b57c5797d9..1e878cc725 100644
--- a/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py
+++ b/configs/edsr/edsr_x4c64b16_1xb16-300k_div2k.py
@@ -43,7 +43,7 @@
channel_order='rgb',
imdecode_backend='cv2'),
dict(type='SetValues', dictionary=dict(scale=scale)),
- dict(type='PairedRandomCrop', gt_patch_size=96),
+ dict(type='PairedRandomCrop', gt_patch_size=196),
dict(
type='Flip',
keys=['img', 'gt'],
@@ -75,7 +75,9 @@
data_root = 'data'
train_dataloader = dict(
- num_workers=4,
+ num_workers=8,
+ batch_size=16,
+ drop_last=True,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -114,7 +116,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py b/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py
index 82fed4c8d5..54d0ff4b91 100644
--- a/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py
+++ b/configs/flavr/flavr_in4out1_8xb4_vimeo90k-septuplet.py
@@ -92,7 +92,7 @@
num_workers=16,
batch_size=4, # 8 gpu
persistent_workers=False,
- sampler=dict(type='InfiniteSampler', shuffle=True),
+ sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=train_dataset_type,
ann_file='txt/sep_trainlist.txt',
diff --git a/configs/ggan/README.md b/configs/ggan/README.md
index e719cfc6e9..d282f0c258 100644
--- a/configs/ggan/README.md
+++ b/configs/ggan/README.md
@@ -30,7 +30,7 @@ Generative Adversarial Nets (GANs) represent an important milestone for effectiv
| :----------: | :------------: | :-----------------------------: | :-----: | :-----: | :-------------------------------------------------------------: | :----------------------------------------------------------------: |
| GGAN 64x64 | CelebA-Cropped | 11.18, 12.21, 39.16/20.85 | 0.3318 | 20.1797 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py) | [model](https://download.openmmlab.com/mmgen/ggan/ggan_celeba-cropped_dcgan-archi_lr-1e-3_64_b128x1_12m.pth) \| [log](https://download.openmmlab.com/mmgen/ggan/ggan_celeba-cropped_dcgan-archi_lr-1e-3_64_b128x1_12m_20210430_113839.log.json) |
| GGAN 128x128 | CelebA-Cropped | 9.81, 11.29, 19.22, 47.79/22.03 | 0.3149 | 18.7647 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py) | [model](https://download.openmmlab.com/mmgen/ggan/ggan_celeba-cropped_dcgan-archi_lr-1e-4_128_b64x1_10m_20210430_143027-516423dc.pth) \| [log](https://download.openmmlab.com/mmgen/ggan/ggan_celeba-cropped_dcgan-archi_lr-1e-4_128_b64x1_10m_20210423_154258.log.json) |
-| GGAN 64x64 | LSUN-Bedroom | 9.1, 6.2, 12.27/9.19 | 0.0649 | 85.6629 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py) | [model](https://download.openmmlab.com/mmgen/ggan/ggan_lsun-bedroom_lsgan_archi_lr-1e-4_64_b128x1_20m_20210430_143114-5d99b76c.pth) \| [log](https://download.openmmlab.com/mmgen/ggan/ggan_lsun-bedroom_lsgan_archi_lr-1e-4_64_b128x1_20m_20210428_202027.log.json) |
+| GGAN 64x64 | LSUN-Bedroom | 9.1, 6.2, 12.27/9.19 | 0.0649 | 39.9261 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py) | [model](https://download.openmmlab.com/mmgen/ggan/ggan_lsun-bedroom_lsgan_archi_lr-1e-4_64_b128x1_20m_20210430_143114-5d99b76c.pth) \| [log](https://download.openmmlab.com/mmgen/ggan/ggan_lsun-bedroom_lsgan_archi_lr-1e-4_64_b128x1_20m_20210428_202027.log.json) |
Note: The original implementation of [GGAN](https://github.com/lim0606/pytorch-geometric-gan) sets `G_iters` to 10. However, our framework does not currently support `G_iters`, so we dropped that setting and ran several experiments with our own settings. The results shown above are those with the lowest `fid` score. \
Original settings and our settings:
diff --git a/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py b/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
index c0b003b33d..a42716a2d9 100644
--- a/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
+++ b/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
@@ -30,6 +30,9 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
+default_hooks = dict(
+ checkpoint=dict(
+ max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less'))
# METRICS
metrics = [
diff --git a/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py b/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
index d86badf7bc..4173260d8b 100644
--- a/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
+++ b/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
@@ -22,7 +22,11 @@
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))
-default_hooks = dict(checkpoint=dict(max_keep_ckpts=20))
+train_cfg = dict(max_iters=160000)
+
+default_hooks = dict(
+ checkpoint=dict(
+ max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less'))
# VIS_HOOK
custom_hooks = [
@@ -33,8 +37,6 @@
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-train_cfg = dict(max_iters=160000)
-
# METRICS
metrics = [
dict(
@@ -54,14 +56,5 @@
image_shape=(3, 128, 128))
]
-val_metrics = [
- dict(
- type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
- fake_nums=50000,
- inception_style='StyleGAN',
- sample_model='orig'),
-]
-
-val_evaluator = dict(metrics=val_metrics)
+val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
diff --git a/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py b/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py
index 2820c554c7..876d2a0607 100644
--- a/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py
+++ b/configs/ggan/ggan_lsgan-archi_lr1e-4-1xb128-20Mimgs_lsun-bedroom-64x64.py
@@ -26,7 +26,11 @@
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))
-default_hooks = dict(checkpoint=dict(max_keep_ckpts=20))
+default_hooks = dict(
+ checkpoint=dict(
+ max_keep_ckpts=20,
+ save_best=['FID-Full-50k/fid', 'swd/avg', 'ms-ssim/avg'],
+ rule=['less', 'less', 'greater']))
# VIS_HOOK
custom_hooks = [
diff --git a/configs/ggan/metafile.yml b/configs/ggan/metafile.yml
index 1fc1966179..8db4dc0f3c 100644
--- a/configs/ggan/metafile.yml
+++ b/configs/ggan/metafile.yml
@@ -39,7 +39,7 @@ Models:
Results:
- Dataset: Others
Metrics:
- FID: 85.6629
+ FID: 39.9261
MS-SSIM: 0.0649
Task: Ggan
Weights: https://download.openmmlab.com/mmgen/ggan/ggan_lsun-bedroom_lsgan_archi_lr-1e-4_64_b128x1_20m_20210430_143114-5d99b76c.pth
diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md
new file mode 100644
index 0000000000..fdfdfb9e50
--- /dev/null
+++ b/configs/inst_colorization/README.md
@@ -0,0 +1,55 @@
+# Instance-aware Image Colorization (CVPR'2020)
+
+> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html)
+
+> **Task**: Colorization
+
+
+
+## Abstract
+
+
+
+Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.
+
+
+
+
+
+
+
+## Results and models
+
+| Method | Download |
+| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: |
+| [instance_aware_colorization_official](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) |
+
+## Quick Start
+
+
+Colorization demo
+
+You can use the following commands to colorize an image.
+
+```shell
+
+python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg
+```
+
+For more demos, you can refer to [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html).
+
+
+
+
+Instance-aware Image Colorization (CVPR'2020)
+
+```bibtex
+@inproceedings{Su-CVPR-2020,
+ author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin},
+ title = {Instance-aware Image Colorization},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2020}
+}
+```
+
+
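As a quick sanity check of the new config, the file can be parsed without downloading the checkpoint or installing the detector weights. A sketch, assuming it is run from the repository root with `mmengine` installed:

```python
from mmengine import Config

# Only parses the config; no model weights are needed for this.
cfg = Config.fromfile(
    'configs/inst_colorization/'
    'inst-colorizatioon_full_official_cocostuff-256x256.py')
print(cfg.model.type)                                # InstColorization
print([step['type'] for step in cfg.test_pipeline])  # transforms used at test time
```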
diff --git a/configs/inst_colorization/README_zh-CN.md b/configs/inst_colorization/README_zh-CN.md
new file mode 100644
index 0000000000..19e59c64fc
--- /dev/null
+++ b/configs/inst_colorization/README_zh-CN.md
@@ -0,0 +1,54 @@
+# Instance-aware Image Colorization (CVPR'2020)
+
+> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html)
+
+> **Task**: Colorization
+
+
+
+## Abstract
+
+
+
+Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.
+
+
+
+
+
+
+
+## Results and models
+
+| Method | Download |
+| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: |
+| [instance_aware_colorization_official](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) |
+
+## Quick Start
+
+
+Colorization demo
+
+You can use the following command to colorize an image.
+
+```shell
+python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg
+```
+
+For more details, you can refer to [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html).
+
+
+
+
+Instance-aware Image Colorization (CVPR'2020)
+
+```bibtex
+@inproceedings{Su-CVPR-2020,
+ author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin},
+ title = {Instance-aware Image Colorization},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2020}
+}
+```
+
+
diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py
new file mode 100644
index 0000000000..952bc74cda
--- /dev/null
+++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py
@@ -0,0 +1,59 @@
+_base_ = ['../_base_/default_runtime.py']
+
+experiment_name = 'inst-colorization_full_official_cocostuff_256x256'
+work_dir = f'./work_dirs/{experiment_name}'
+save_dir = './work_dirs/'
+
+stage = 'full'
+
+model = dict(
+ type='InstColorization',
+ data_preprocessor=dict(
+ type='EditDataPreprocessor',
+ mean=[127.5],
+ std=[127.5],
+ ),
+ image_model=dict(
+ type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
+ instance_model=dict(
+ type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
+ fusion_model=dict(
+ type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'),
+ color_data_opt=dict(
+ ab_thresh=0,
+ p=1.0,
+ sample_PS=[
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ ],
+ ab_norm=110,
+ ab_max=110.,
+ ab_quant=10.,
+ l_norm=100.,
+ l_cent=50.,
+ mask_cent=0.5),
+ which_direction='AtoB',
+ loss=dict(type='HuberLoss', delta=.01))
+
+# yapf: disable
+test_pipeline = [
+ dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
+ dict(
+ type='InstanceCrop',
+ config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa
+ finesize=256,
+ box_num_upbound=5),
+ dict(
+ type='Resize',
+ keys=['img', 'cropped_img'],
+ scale=(256, 256),
+ keep_ratio=False),
+ dict(type='PackEditInputs'),
+]
diff --git a/configs/inst_colorization/metafile.yml b/configs/inst_colorization/metafile.yml
new file mode 100644
index 0000000000..c13dabfb11
--- /dev/null
+++ b/configs/inst_colorization/metafile.yml
@@ -0,0 +1,19 @@
+Collections:
+- Metadata:
+ Architecture:
+ - Instance-aware Image Colorization
+ Name: Instance-aware Image Colorization
+ Paper:
+ - https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html
+ README: configs/inst_colorization/README.md
+Models:
+- Config: configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py
+ In Collection: Instance-aware Image Colorization
+ Metadata:
+ Training Data: Others
+ Name: inst-colorizatioon_full_official_cocostuff-256x256
+ Results:
+ - Dataset: Others
+ Metrics: {}
+ Task: Inst_colorization
+ Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth
diff --git a/configs/lsgan/lsgan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py b/configs/lsgan/lsgan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
index 8f7b1650de..4e5267d641 100644
--- a/configs/lsgan/lsgan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
+++ b/configs/lsgan/lsgan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py
@@ -32,6 +32,9 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
+default_hooks = dict(
+ checkpoint=dict(
+ save_best=['FID-Full-50k/fid', 'IS-50k/is'], rule=['less', 'greater']))
# METRICS
metrics = [
diff --git a/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb128-12Mimgs_lsun-bedroom-64x64.py b/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb128-12Mimgs_lsun-bedroom-64x64.py
index 218f8abb77..0c9fe0eeff 100644
--- a/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb128-12Mimgs_lsun-bedroom-64x64.py
+++ b/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb128-12Mimgs_lsun-bedroom-64x64.py
@@ -26,6 +26,8 @@
generator=dict(optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))),
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
diff --git a/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py b/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
index 6926462e0a..f920ca4622 100644
--- a/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
+++ b/configs/lsgan/lsgan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py
@@ -28,6 +28,10 @@
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))
+default_hooks = dict(
+ checkpoint=dict(
+ save_best=['FID-Full-50k/fid', 'IS-50k/is'], rule=['less', 'greater']))
+
# METRICS
metrics = [
dict(
diff --git a/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py b/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py
index c22214f75d..7bab2ddd13 100644
--- a/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py
+++ b/configs/lsgan/lsgan_lsgan-archi_lr1e-4-1xb64-10Mimgs_lsun-bedroom-128x128.py
@@ -35,7 +35,8 @@
generator=dict(optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))),
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# adjust running config
# METRICS
metrics = [
diff --git a/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py b/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py
index 302f63d54f..8aed86a8a6 100644
--- a/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py
+++ b/configs/partial_conv/pconv_stage1_8xb1_celeba-256x256.py
@@ -79,7 +79,6 @@
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
optimizer=dict(type='Adam', lr=0.0002))
-param_scheduler = dict(policy='Fixed', by_epoch=False)
checkpoint = dict(
type='CheckpointHook', interval=50000, by_epoch=False, out_dir=save_dir)
diff --git a/configs/pggan/pggan_8xb4-12Mimg_celeba-hq-1024x1024.py b/configs/pggan/pggan_8xb4-12Mimg_celeba-hq-1024x1024.py
index afda435661..93399eed21 100644
--- a/configs/pggan/pggan_8xb4-12Mimg_celeba-hq-1024x1024.py
+++ b/configs/pggan/pggan_8xb4-12Mimg_celeba-hq-1024x1024.py
@@ -102,6 +102,11 @@
target_keys=['ema', 'orig'])),
dict(type='PGGANFetchDataHook')
]
+default_hooks = dict(
+ checkpoint=dict(
+ max_keep_ckpts=20,
+ save_best=['swd/avg', 'ms-ssim/avg'],
+ rule=['less', 'greater']))
# METRICS
metrics = [
diff --git a/configs/pggan/pggan_8xb4-12Mimgs_celeba-cropped-128x128.py b/configs/pggan/pggan_8xb4-12Mimgs_celeba-cropped-128x128.py
index 4c29b1ea4e..bd73feb068 100644
--- a/configs/pggan/pggan_8xb4-12Mimgs_celeba-cropped-128x128.py
+++ b/configs/pggan/pggan_8xb4-12Mimgs_celeba-cropped-128x128.py
@@ -84,6 +84,11 @@
target_keys=['ema', 'orig'])),
dict(type='PGGANFetchDataHook')
]
+default_hooks = dict(
+ checkpoint=dict(
+ max_keep_ckpts=20,
+ save_best=['swd/avg', 'ms-ssim/avg'],
+ rule=['less', 'greater']))
# METRICS
metrics = [
diff --git a/configs/pix2pix/README.md b/configs/pix2pix/README.md
index d35ac59037..99636fa555 100644
--- a/configs/pix2pix/README.md
+++ b/configs/pix2pix/README.md
@@ -32,7 +32,7 @@ We use `FID` and `IS` metrics to evaluate the generation performance of pix2pix.
| Ours | facades | 124.9773 | 1.620 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-80kiters_facades.py) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210902_170442-c0958d50.pth) \| [log](https://download.openmmlab.com/mmgen/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210317_172625.log.json)2 |
| Ours | aerial2maps | 122.5856 | 3.137 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_aerial2maps.py) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_a2b_1x1_219200_maps_convert-bgr_20210902_170729-59a31517.pth) |
| Ours | maps2aerial | 88.4635 | 3.310 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_maps2aerial.py) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_b2a_1x1_219200_maps_convert-bgr_20210902_170814-6d2eac4a.pth) |
-| Ours | edges2shoes | 84.3750 | 2.815 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes.py) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth) |
+| Ours | edges2shoes | 84.3750 | 2.815 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes.py) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth) |
`FID` comparison with official:
diff --git a/configs/pix2pix/metafile.yml b/configs/pix2pix/metafile.yml
index 84f0b43dc4..522bb5b86e 100644
--- a/configs/pix2pix/metafile.yml
+++ b/configs/pix2pix/metafile.yml
@@ -43,11 +43,11 @@ Models:
IS: 3.31
Task: Pix2pix
Weights: https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_b2a_1x1_219200_maps_convert-bgr_20210902_170814-6d2eac4a.pth
-- Config: https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes.py
+- Config: https://github.com/open-mmlab/mmediting/tree/master/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes.py
In Collection: Pix2Pix
Metadata:
Training Data: EDGES2SHOES
- Name: pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes
+ Name: pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes
Results:
- Dataset: EDGES2SHOES
Metrics:
diff --git a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_aerial2maps.py b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_aerial2maps.py
index 6c8b667105..5f800951e0 100644
--- a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_aerial2maps.py
+++ b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_aerial2maps.py
@@ -3,8 +3,11 @@
'../_base_/datasets/paired_imgs_256x256_crop.py',
'../_base_/gen_default_runtime.py'
]
-source_domain = domain_b = 'aerial'
-target_domain = domain_a = 'map'
+# deterministic training can improve the performance of Pix2Pix
+randomness = dict(deterministic=True)
+
+source_domain = domain_a = 'aerial'
+target_domain = domain_b = 'map'
# model settings
model = dict(
default_domain=target_domain,
@@ -61,6 +64,9 @@
])
]
+# save multi best checkpoints
+default_hooks = dict(checkpoint=dict(save_best='FID-Full/fid', rule='less'))
+
fake_nums = 1098
metrics = [
dict(
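The new `randomness = dict(deterministic=True)` line above asks the runner for deterministic execution, which the comment notes can improve Pix2Pix results. Roughly speaking, this amounts to the usual PyTorch-level switches sketched below; the exact handling lives inside MMEngine's seeding utilities and may differ in detail:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 2022, deterministic: bool = True) -> None:
    """Rough equivalent of what a deterministic run sets up."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything()
```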
diff --git a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_maps2aerial.py b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_maps2aerial.py
index 864a71f267..120ef3f255 100644
--- a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_maps2aerial.py
+++ b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-220kiters_maps2aerial.py
@@ -3,6 +3,9 @@
'../_base_/datasets/paired_imgs_256x256_crop.py',
'../_base_/gen_default_runtime.py'
]
+# deterministic training can improve the performance of Pix2Pix
+randomness = dict(deterministic=True)
+
source_domain = domain_b = 'map'
target_domain = domain_a = 'aerial'
# model settings
@@ -41,6 +44,15 @@
dataset=dict(data_root=dataroot, test_dir='val', test_mode=True))
test_dataloader = val_dataloader
+# optimizer
+optim_wrapper = dict(
+ generators=dict(
+ type='OptimWrapper',
+ optimizer=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))),
+ discriminators=dict(
+ type='OptimWrapper',
+ optimizer=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))))
+
custom_hooks = [
dict(
type='GenVisualizationHook',
@@ -52,14 +64,8 @@
])
]
-# optimizer
-optim_wrapper = dict(
- generators=dict(
- type='OptimWrapper',
- optimizer=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))),
- discriminators=dict(
- type='OptimWrapper',
- optimizer=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999))))
+# save multi best checkpoints
+default_hooks = dict(checkpoint=dict(save_best='FID-Full/fid', rule='less'))
fake_nums = 1098
metrics = [
@@ -74,7 +80,7 @@
type='TransFID',
prefix='FID-Full',
fake_nums=fake_nums,
- inception_style='StyleGAN',
+ inception_style='PyTorch',
real_key=f'img_{target_domain}',
fake_key=f'fake_{target_domain}',
sample_model='orig')
diff --git a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-80kiters_facades.py b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-80kiters_facades.py
index 5788b7a5bd..4fbc787cb8 100644
--- a/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-80kiters_facades.py
+++ b/configs/pix2pix/pix2pix_vanilla-unet-bn_1xb1-80kiters_facades.py
@@ -3,6 +3,9 @@
'../_base_/datasets/paired_imgs_256x256_crop.py',
'../_base_/gen_default_runtime.py'
]
+# deterministic training can improve the performance of Pix2Pix
+randomness = dict(deterministic=True)
+
source_domain = domain_b = 'mask'
target_domain = domain_a = 'photo'
# model settings
@@ -11,7 +14,7 @@
reachable_domains=[target_domain],
related_domains=[target_domain, source_domain])
-train_cfg = dict(max_iters=80000, val_interval=100)
+train_cfg = dict(max_iters=80000)
# dataset settings
dataroot = 'data/pix2pix/facades'
@@ -52,7 +55,7 @@
custom_hooks = [
dict(
type='GenVisualizationHook',
- interval=100,
+ interval=5000,
fixed_input=True,
vis_kwargs_list=[
dict(type='Translation', name='trans'),
@@ -60,6 +63,9 @@
])
]
+# save multi best checkpoints
+default_hooks = dict(checkpoint=dict(save_best='FID-Full/fid', rule='less'))
+
fake_nums = 106
metrics = [
dict(
diff --git a/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes.py b/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes.py
similarity index 74%
rename from configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes.py
rename to configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes.py
index 48ab2e3953..8acb0efb0d 100644
--- a/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-4xb1-190kiters_edges2shoes.py
+++ b/configs/pix2pix/pix2pix_vanilla-unet-bn_wo-jitter-flip-1xb4-190kiters_edges2shoes.py
@@ -3,8 +3,11 @@
'../_base_/datasets/paired_imgs_256x256_crop.py',
'../_base_/gen_default_runtime.py',
]
-source_domain = domain_b = 'edges'
-target_domain = domain_a = 'photo'
+# deterministic training can improve the performance of Pix2Pix
+randomness = dict(deterministic=True)
+
+source_domain = domain_a = 'edges'
+target_domain = domain_b = 'photo'
# model settings
model = dict(
default_domain=target_domain,
@@ -15,6 +18,20 @@
# dataset settings
dataroot = './data/pix2pix/edges2shoes'
+# overwrite train pipeline since we do not use flip and crop
+_base_.train_dataloader.dataset.pipeline = [
+ dict(
+ type='LoadPairedImageFromFile',
+ key='pair',
+ domain_a='A',
+ domain_b='B',
+ color_type='color'),
+ dict(
+ type='Resize',
+ keys=['img_A', 'img_B'],
+ scale=(256, 256),
+ interpolation='bicubic'),
+]
train_pipeline = _base_.train_dataloader.dataset.pipeline
val_pipeline = _base_.val_dataloader.dataset.pipeline
test_pipeline = _base_.test_dataloader.dataset.pipeline
@@ -36,7 +53,8 @@
val_pipeline += [key_mapping, pack_input]
test_pipeline += [key_mapping, pack_input]
-train_dataloader = dict(dataset=dict(data_root=dataroot, test_dir='val'))
+train_dataloader = dict(
+ batch_size=4, dataset=dict(data_root=dataroot, test_dir='val'))
val_dataloader = dict(
dataset=dict(data_root=dataroot, test_dir='val', test_mode=True))
test_dataloader = val_dataloader
@@ -61,6 +79,9 @@
])
]
+# save multi best checkpoints
+default_hooks = dict(checkpoint=dict(save_best='FID-Full/fid', rule='less'))
+
fake_nums = 200
metrics = [
dict(
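The edges2shoes config above drops jitter and flip by assigning a new list to `_base_.train_dataloader.dataset.pipeline` instead of redefining the whole dataloader. A minimal, self-contained sketch of that inheritance idiom (the two temporary files are illustrative, and the behaviour is assumed to match what the config above relies on):

```python
import tempfile
from pathlib import Path

from mmengine import Config

tmp = Path(tempfile.mkdtemp())
(tmp / 'base.py').write_text(
    "train_dataloader = dict(\n"
    "    batch_size=1, dataset=dict(pipeline=['flip', 'crop']))\n")
(tmp / 'child.py').write_text(
    "_base_ = ['./base.py']\n"
    "_base_.train_dataloader.dataset.pipeline = ['resize']  # overwrite in place\n"
    "train_pipeline = _base_.train_dataloader.dataset.pipeline\n")

cfg = Config.fromfile(str(tmp / 'child.py'))
print(cfg.train_pipeline)                     # ['resize']
print(cfg.train_dataloader.dataset.pipeline)  # ['resize']
```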
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-c_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-c_c2_8xb3-1100kiters_ffhq-256-512.py
index f01f24340a..5c62fdbc67 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-c_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-c_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -37,13 +37,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -70,14 +76,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
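The MSPIE metrics above now use the plain PyTorch InceptionV3 (`inception_style='pytorch'`, with `real_nums` stated explicitly) and report under the `FID-50k` prefix. Evaluator results are logged as `<prefix>/<metric name>`, and that composed string is what a checkpoint hook's `save_best` entry has to name; a trivial sketch of the convention, assumed from how these configs pair metrics and hooks:

```python
# Evaluator keys are '<prefix>/<name>'; CheckpointHook's `save_best`
# must use the same composed string.
metric = dict(type='FrechetInceptionDistance', prefix='FID-50k')
logged_key = f"{metric['prefix']}/fid"
print(logged_key)  # 'FID-50k/fid'
```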
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-d_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-d_c2_8xb3-1100kiters_ffhq-256-512.py
index a75161821b..5d89ec7c59 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-d_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-d_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -36,13 +36,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -69,14 +75,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-e_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-e_c2_8xb3-1100kiters_ffhq-256-512.py
index be2bd63c96..05800d3260 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-e_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-e_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -42,13 +42,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -75,14 +81,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c1_8xb2-1600kiters_ffhq-256-1024.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c1_8xb2-1600kiters_ffhq-256-1024.py
index f0ffe871ee..fb766571f5 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c1_8xb2-1600kiters_ffhq-256-1024.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c1_8xb2-1600kiters_ffhq-256-1024.py
@@ -11,13 +11,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
ema_half_life = 10.
ema_config = dict(
@@ -73,14 +79,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-512.py
index ee813aae22..8727ababea 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -40,13 +40,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -73,14 +79,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-896.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-896.py
index 0dc416acc2..9a6129d253 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-896.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-f_c2_8xb3-1100kiters_ffhq-256-896.py
@@ -40,13 +40,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -73,14 +79,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-g_c1_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-g_c1_8xb3-1100kiters_ffhq-256-512.py
index 9f908370df..51d8d304dd 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-g_c1_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-g_c1_8xb3-1100kiters_ffhq-256-512.py
@@ -44,13 +44,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -77,14 +83,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-h_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-h_c2_8xb3-1100kiters_ffhq-256-512.py
index c65ef00b8b..407a420a9c 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-h_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-h_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -39,13 +39,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -72,14 +78,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-i_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-i_c2_8xb3-1100kiters_ffhq-256-512.py
index 3dc1edaddd..664eb19cd8 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-i_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-i_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -38,13 +38,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -71,14 +77,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-j_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-j_c2_8xb3-1100kiters_ffhq-256-512.py
index a16b1b78ad..1b620af8f4 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-j_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-j_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -44,13 +44,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -77,14 +83,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/positional_encoding_in_gans/mspie-stylegan2-config-k_c2_8xb3-1100kiters_ffhq-256-512.py b/configs/positional_encoding_in_gans/mspie-stylegan2-config-k_c2_8xb3-1100kiters_ffhq-256-512.py
index 7c6ffe01e8..281afb77aa 100644
--- a/configs/positional_encoding_in_gans/mspie-stylegan2-config-k_c2_8xb3-1100kiters_ffhq-256-512.py
+++ b/configs/positional_encoding_in_gans/mspie-stylegan2-config-k_c2_8xb3-1100kiters_ffhq-256-512.py
@@ -43,13 +43,19 @@
batch_size=batch_size, # set by user
dataset=dict(data_root=data_root))
+pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='Resize', scale=(256, 256)),
+ dict(type='PackEditInputs', keys=['img'])
+]
+
val_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
test_dataloader = dict(
batch_size=batch_size, # set by user
- dataset=dict(data_root=data_root))
+ dataset=dict(data_root=data_root, pipeline=pipeline))
# define optimizer
d_reg_interval = 16
@@ -76,14 +82,16 @@
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
-
+default_hooks = dict(
+ checkpoint=dict(save_best=['FID-Full-50k/fid'], rule=['less']))
# METRICS
metrics = [
dict(
type='FrechetInceptionDistance',
- prefix='FID-Full-50k',
+ prefix='FID-50k',
fake_nums=50000,
- inception_style='StyleGAN',
+ real_nums=50000,
+ inception_style='pytorch',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=10000, prefix='PR-10K')
]
diff --git a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py
index c5213dc154..d7c058baa2 100644
--- a/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py
+++ b/configs/rdn/rdn_x2c64b16_1xb16-1000k_div2k.py
@@ -71,6 +71,7 @@
train_dataloader = dict(
num_workers=4,
+ batch_size=16,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -109,7 +110,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py
index 42c9a42a06..dcbcc06bdc 100644
--- a/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py
+++ b/configs/rdn/rdn_x3c64b16_1xb16-1000k_div2k.py
@@ -71,6 +71,7 @@
train_dataloader = dict(
num_workers=4,
+ batch_size=16,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -109,7 +110,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py
index f7566975ff..d6ca1345a0 100644
--- a/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py
+++ b/configs/rdn/rdn_x4c64b16_1xb16-1000k_div2k.py
@@ -71,6 +71,7 @@
train_dataloader = dict(
num_workers=4,
+ batch_size=16,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -110,7 +111,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/real_basicvsr/realbasicvsr_c64b20-1x30x8_8xb1-lr5e-5-150k_reds.py b/configs/real_basicvsr/realbasicvsr_c64b20-1x30x8_8xb1-lr5e-5-150k_reds.py
index 617b193d61..faa44890be 100644
--- a/configs/real_basicvsr/realbasicvsr_c64b20-1x30x8_8xb1-lr5e-5-150k_reds.py
+++ b/configs/real_basicvsr/realbasicvsr_c64b20-1x30x8_8xb1-lr5e-5-150k_reds.py
@@ -50,10 +50,11 @@
is_use_sharpened_gt_in_pixel=True,
is_use_sharpened_gt_in_percep=True,
is_use_sharpened_gt_in_gan=False,
+ is_use_ema=True,
data_preprocessor=dict(
type='EditDataPreprocessor',
- mean=[0, 0, 0],
- std=[1, 1, 1],
+ mean=[0., 0., 0.],
+ std=[255., 255., 255.],
input_view=(1, -1, 1, 1),
output_view=(1, -1, 1, 1),
))
diff --git a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py
index 7e12c1e425..cba60a8dcf 100644
--- a/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py
+++ b/configs/real_basicvsr/realbasicvsr_wogan-c64b20-2x30x8_8xb2-lr1e-4-300k_reds.py
@@ -22,10 +22,11 @@
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
cleaning_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
is_use_sharpened_gt_in_pixel=True,
+ is_use_ema=True,
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0., 0., 0.],
- std=[1., 1., 1.],
+ std=[255., 255., 255.],
input_view=(1, -1, 1, 1),
output_view=(1, -1, 1, 1),
))
@@ -203,7 +204,6 @@
filename_tmpl='{:04d}.png'),
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
- dict(type='RescaleToZeroOne', keys=['img', 'gt']),
dict(type='PackEditInputs')
]
@@ -214,7 +214,6 @@
filename_tmpl='{:08d}.png'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
- dict(type='RescaleToZeroOne', keys=['img']),
dict(type='PackEditInputs')
]
@@ -244,7 +243,6 @@
metainfo=dict(dataset_type='udm10', task_name='vsr'),
data_root=f'{data_root}/UDM10',
data_prefix=dict(img='BIx4', gt='GT'),
- num_input_frames=15,
pipeline=val_pipeline))
test_dataloader = dict(
@@ -257,14 +255,13 @@
metainfo=dict(dataset_type='video_lq', task_name='vsr'),
data_root=f'{data_root}/VideoLQ',
data_prefix=dict(img='', gt=''),
- num_input_frames=15,
pipeline=test_pipeline))
val_evaluator = [
dict(type='PSNR'),
dict(type='SSIM'),
]
-# test_evaluator = [dict(type='NIQE', convert_to='Y')]
+
test_evaluator = [dict(type='NIQE', input_order='CHW', convert_to='Y')]
# test_evaluator = val_evaluator
@@ -275,9 +272,10 @@
# optimizer
optim_wrapper = dict(
- constructor='DefaultOptimWrapperConstructor',
- type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99)))
+ constructor='MultiOptimWrapperConstructor',
+ generator=dict(
+ type='OptimWrapper',
+ optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99))))
# NO learning policy
@@ -297,6 +295,16 @@
sampler_seed=dict(type='DistSamplerSeedHook'),
)
+custom_hooks = [
+ dict(type='BasicVisualizationHook', interval=5),
+ dict(
+ type='ExponentialMovingAverageHook',
+ module_keys=('generator_ema'),
+ interval=1,
+ interp_cfg=dict(momentum=0.999),
+ )
+]
+
model_wrapper_cfg = dict(
type='MMSeparateDistributedDataParallel',
broadcast_buffers=False,
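The RealBasicVSR changes above move input scaling out of the pipeline (the `RescaleToZeroOne` steps are removed) and into the data preprocessor via `std=[255., 255., 255.]`; with `mean=0`, dividing by 255 performs the same rescaling, just in one place. A minimal arithmetic check with illustrative tensors only:

```python
import torch

frames = torch.randint(0, 256, (1, 3, 64, 64)).float()  # uint8-range input

rescaled = frames / 255.0              # old: RescaleToZeroOne in the pipeline
preprocessed = (frames - 0.0) / 255.0  # new: mean=0, std=255 in the preprocessor

assert torch.allclose(rescaled, preprocessed)
print(preprocessed.min().item(), preprocessed.max().item())  # values in [0, 1]
```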
diff --git a/configs/real_esrgan/realesrgan_c64b23g32_4xb12-lr1e-4-400k_df2k-ost.py b/configs/real_esrgan/realesrgan_c64b23g32_4xb12-lr1e-4-400k_df2k-ost.py
index 4999e6f5d7..b27a9adf79 100644
--- a/configs/real_esrgan/realesrgan_c64b23g32_4xb12-lr1e-4-400k_df2k-ost.py
+++ b/configs/real_esrgan/realesrgan_c64b23g32_4xb12-lr1e-4-400k_df2k-ost.py
@@ -47,12 +47,13 @@
is_use_sharpened_gt_in_pixel=True,
is_use_sharpened_gt_in_percep=True,
is_use_sharpened_gt_in_gan=False,
+ is_use_ema=True,
train_cfg=dict(start_iter=1000000),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
- mean=[0, 0, 0],
- std=[1, 1, 1],
+ mean=[0., 0., 0.],
+ std=[255., 255., 255.],
))
train_cfg = dict(
diff --git a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py
index 46b72fd757..ed8eb9381b 100644
--- a/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py
+++ b/configs/real_esrgan/realesrnet_c64b23g32_4xb12-lr2e-4-1000k_df2k-ost.py
@@ -23,12 +23,13 @@
upscale_factor=scale),
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
is_use_sharpened_gt_in_pixel=True,
+ is_use_ema=True,
train_cfg=dict(),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
- mean=[0, 0, 0],
- std=[1, 1, 1],
+ mean=[0., 0., 0.],
+ std=[255., 255., 255.],
))
train_pipeline = [
@@ -177,7 +178,6 @@
val_pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
- dict(type='RescaleToZeroOne', keys=['img', 'gt']),
dict(type='PackEditInputs')
]
@@ -205,14 +205,13 @@
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='set5', task_name='real_sr'),
- data_root='data/set5',
- data_prefix=dict(gt='HR', img='bicLRx4'),
+ data_root='data/Set5',
+ data_prefix=dict(gt='GTmod12', img='LRbicx4'),
pipeline=val_pipeline))
test_dataloader = val_dataloader
val_evaluator = [
- dict(type='MAE'),
dict(type='PSNR'),
dict(type='SSIM'),
]
@@ -253,7 +252,7 @@
vis_backends=vis_backends,
fn_key='gt_path',
img_keys=['gt_img', 'input', 'pred_img'],
- bgr2rgb=True)
+ bgr2rgb=False)
custom_hooks = [
dict(type='BasicVisualizationHook', interval=1),
dict(
diff --git a/configs/singan/singan_fish.py b/configs/singan/singan_fish.py
index df2f946408..1ffbf734eb 100644
--- a/configs/singan/singan_fish.py
+++ b/configs/singan/singan_fish.py
@@ -41,7 +41,7 @@
dataset_type = 'SinGANDataset'
data_root = './data/singan/fish-crop.jpg'
-pipeline = [dict(type='PackEditInputs')]
+pipeline = [dict(type='PackEditInputs', pack_all=True)]
dataset = dict(
type=dataset_type,
data_root=data_root,
diff --git a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py
index e2032424a5..cf39811509 100644
--- a/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py
+++ b/configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py
@@ -71,6 +71,7 @@
train_dataloader = dict(
num_workers=4,
+ batch_size=16,
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
@@ -109,7 +110,7 @@
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=2e-4, betas=(0.9, 0.99)))
+ optimizer=dict(type='Adam', lr=2e-4, betas=(0.9, 0.999)))
# learning policy
param_scheduler = dict(
diff --git a/configs/styleganv3/README.md b/configs/styleganv3/README.md
index fb5636b7bf..bcb1c677d4 100644
--- a/configs/styleganv3/README.md
+++ b/configs/styleganv3/README.md
@@ -43,8 +43,6 @@ For user convenience, we also offer the converted version of official weights.
| stylegan3-t | ffhq 1024x1024 | 490000 | 3.37\* | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/styleganv3/stylegan3-t_gamma32.8_8xb4-fp16-noaug_ffhq-1024x1024.py) | [log](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8_20220322_090417.log.json) | [model](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8_best_fid_iter_490000_20220401_120733-4ff83434.pth) |
| stylegan3-t-ada | metface 1024x1024 | 130000 | 15.09 | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/styleganv3/stylegan3-t_ada-gamma6.6_8xb4-fp16_metfaces-1024x1024.py) | [log](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_20220328_142211.log.json) | [model](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth) |
-Note\*: This setting still needs a few days to run through, we put out currently the best checkpoint, and we will update the results the first time on the end of the experiment.
-
### Experimental Settings
| Model | Dataset | Iter | FID50k | Config | Log | Download |
diff --git a/configs/ttsr/README.md b/configs/ttsr/README.md
index 2eadc57530..bf9c63bd8f 100644
--- a/configs/ttsr/README.md
+++ b/configs/ttsr/README.md
@@ -23,7 +23,7 @@ We study on image super-resolution (SR), which aims to recover realistic texture
Evaluated on CUFED dataset (RGB channels), `scale` pixels in each border are cropped before evaluation.
The metrics are `PSNR` and `SSIM`.
-| Method | scale | SSIM | SSIM | GPU Info | Download |
+| Method | scale | PSNR | SSIM | GPU Info | Download |
| :----------------------------------------------------------------------------------: | :---: | :-----: | :----: | :----------: | :------------------------------------------------------------------------------------: |
| [ttsr-rec_x4_c64b16_g1_200k_CUFED](/configs/ttsr/ttsr-rec_x4c64b16_1xb9-200k_CUFED.py) | x4 | 25.2433 | 0.7491 | 1 (TITAN Xp) | [model](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.log.json) |
| [ttsr-gan_x4_c64b16_g1_500k_CUFED](/configs/ttsr/ttsr-gan_x4c64b16_1xb9-500k_CUFED.py) | x4 | 24.6075 | 0.7234 | 1 (TITAN Xp) | [model](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED_20210626-2ab28ca0.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED_20210626-2ab28ca0.log.json) |
diff --git a/configs/ttsr/metafile.yml b/configs/ttsr/metafile.yml
index a5dbdaf1c4..4450bef023 100644
--- a/configs/ttsr/metafile.yml
+++ b/configs/ttsr/metafile.yml
@@ -16,6 +16,7 @@ Models:
Results:
- Dataset: CUFED
Metrics:
+ PSNR: 25.2433
SSIM: 0.7491
Task: Ttsr
Weights: https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-rec_x4_c64b16_g1_200k_CUFED_20210525-b0dba584.pth
@@ -28,6 +29,7 @@ Models:
Results:
- Dataset: CUFED
Metrics:
+ PSNR: 24.6075
SSIM: 0.7234
Task: Ttsr
Weights: https://download.openmmlab.com/mmediting/restorers/ttsr/ttsr-gan_x4_c64b16_g1_500k_CUFED_20210626-2ab28ca0.pth
diff --git a/configs/ttsr/ttsr-gan_x4c64b16_1xb9-500k_CUFED.py b/configs/ttsr/ttsr-gan_x4c64b16_1xb9-500k_CUFED.py
index 68f3e17bbd..659edd82ec 100644
--- a/configs/ttsr/ttsr-gan_x4c64b16_1xb9-500k_CUFED.py
+++ b/configs/ttsr/ttsr-gan_x4c64b16_1xb9-500k_CUFED.py
@@ -61,7 +61,7 @@
optimizer=dict(type='Adam', lr=1e-5, betas=(0.9, 0.999))),
discriminator=dict(
type='OptimWrapper',
- optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))))
+ optimizer=dict(type='Adam', lr=1e-5, betas=(0.9, 0.999))))
# learning policy
param_scheduler = dict(
diff --git a/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py b/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py
index 9e3a41c200..3888b4516b 100644
--- a/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py
+++ b/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py
@@ -26,7 +26,7 @@
loss_config=loss_config)
# `batch_size` and `data_root` need to be set.
-batch_size = 4
+batch_size = 64
data_root = './data/celeba-cropped/cropped_images_aligned_png/'
train_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))
@@ -47,7 +47,7 @@
custom_hooks = [
dict(
type='GenVisualizationHook',
- interval=1000,
+ interval=5000,
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]
@@ -65,5 +65,8 @@
image_shape=(3, 128, 128))
]
+# save multiple best checkpoints
+default_hooks = dict(checkpoint=dict(save_best='swd/avg'))
+
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py
new file mode 100644
index 0000000000..926515f637
--- /dev/null
+++ b/demo/colorization_demo.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import mmcv
+import torch
+
+from mmedit.apis import colorization_inference, init_model
+from mmedit.utils import modify_args, tensor2img
+
+
+def parse_args():
+ modify_args()
+    parser = argparse.ArgumentParser(description='Colorization demo')
+ parser.add_argument('config', help='test config file path')
+ parser.add_argument('checkpoints', help='checkpoints file path')
+ parser.add_argument('img_path', help='path to input image file')
+ parser.add_argument('save_path', help='path to save generation result')
+ parser.add_argument(
+        '--imshow', action='store_true', help='whether to show image with opencv')
+ parser.add_argument('--device', type=int, default=0, help='CUDA device id')
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.device < 0 or not torch.cuda.is_available():
+ device = torch.device('cpu')
+ else:
+ device = torch.device('cuda', args.device)
+
+ model = init_model(args.config, args.checkpoints, device=device)
+ output = colorization_inference(model, args.img_path)
+ result = tensor2img(output)
+ mmcv.imwrite(result, args.save_path)
+
+ if args.imshow:
+        mmcv.imshow(result, 'predicted generation result')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demo/singan_demo.py b/demo/singan_demo.py
new file mode 100644
index 0000000000..9633adfe1a
--- /dev/null
+++ b/demo/singan_demo.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import sys
+
+import mmcv
+import torch
+from mmengine import Config, print_log
+from mmengine.logging import MMLogger
+from mmengine.runner import load_checkpoint, set_random_seed
+
+# yapf: disable
+sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa
+
+from mmedit.engine import * # isort:skip # noqa: F401,F403,E402
+from mmedit.datasets import * # isort:skip # noqa: F401,F403,E402
+from mmedit.models import * # isort:skip # noqa: F401,F403,E402
+
+from mmedit.registry import MODELS # isort:skip # noqa
+
+# yapf: enable
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='SinGAN demo')
+ parser.add_argument('config', help='evaluation config file path')
+ parser.add_argument('checkpoint', help='checkpoint file')
+ parser.add_argument('--seed', type=int, default=2021, help='random seed')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+ help='whether to set deterministic options for CUDNN backend.')
+ parser.add_argument(
+ '--samples-path',
+ type=str,
+ default='./',
+        help='path to store images. If not given, images will be removed after\
+        evaluation finishes')
+ parser.add_argument(
+ '--save-prev-res',
+ action='store_true',
+ help='whether to store the results from previous stages')
+ parser.add_argument(
+ '--num-samples',
+ type=int,
+ default=10,
+ help='the number of synthesized samples')
+ args = parser.parse_args()
+ return args
+
+
+def _tensor2img(img):
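+    # CHW -> HWC, then map values from [-1, 1] to [0, 255] and cast to uint8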
+ img = img.permute(1, 2, 0)
+ img = ((img + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
+
+ return img.cpu().numpy()
+
+
+@torch.no_grad()
+def main():
+ MMLogger.get_instance('mmedit')
+
+ args = parse_args()
+ cfg = Config.fromfile(args.config)
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+
+ # set random seeds
+ if args.seed is not None:
+ set_random_seed(args.seed, deterministic=args.deterministic)
+
+ # set scope manually
+ cfg.model['_scope_'] = 'mmedit'
+ # build the model and load checkpoint
+ model = MODELS.build(cfg.model)
+
+ model.eval()
+
+ # load ckpt
+ print_log(f'Loading ckpt from {args.checkpoint}')
+ _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+ # add dp wrapper
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ for sample_iter in range(args.num_samples):
+ outputs = model.test_step(
+ dict(inputs=dict(num_batches=1, get_prev_res=args.save_prev_res)))
+
+ # store results from previous stages
+ if args.save_prev_res:
+ fake_img = outputs[0].fake_img.data
+ prev_res_list = outputs[0].prev_res_list
+ prev_res_list.append(fake_img)
+ for i, img in enumerate(prev_res_list):
+ img = _tensor2img(img)
+ mmcv.imwrite(
+ img,
+ os.path.join(args.samples_path, f'stage{i}',
+ f'rand_sample_{sample_iter}.png'))
+ # just store the final result
+ else:
+ img = _tensor2img(outputs[0].fake_img.data)
+ mmcv.imwrite(
+ img,
+ os.path.join(args.samples_path,
+ f'rand_sample_{sample_iter}.png'))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docs/en/.dev_scripts/update_dataset_zoo.sh b/docs/en/.dev_scripts/update_dataset_zoo.sh
index 5f8c02a5e5..ae087ca36e 100644
--- a/docs/en/.dev_scripts/update_dataset_zoo.sh
+++ b/docs/en/.dev_scripts/update_dataset_zoo.sh
@@ -1,39 +1,39 @@
# generate all tasks dataset_zoo
-cat ../../../tools/dataset_converters/super-resolution/README.md > ../dataset_zoo/1_super_resolution_datasets.md
-cat ../../../tools/dataset_converters/inpainting/README.md > ../dataset_zoo/2_inpainting_datasets.md
-cat ../../../tools/dataset_converters/matting/README.md > ../dataset_zoo/3_matting_datasets.md
-cat ../../../tools/dataset_converters/video-interpolation/README.md > ../dataset_zoo/4_video_interpolation_datasets.md
-cat ../../../tools/dataset_converters/unconditional_gans/README.md > ../dataset_zoo/5_unconditional_gans_datasets.md
-cat ../../../tools/dataset_converters/image_translation/README.md > ../dataset_zoo/6_image_translation_datasets.md
+cat ../../../tools/dataset_converters/super-resolution/README.md > dataset_zoo/1_super_resolution_datasets.md
+cat ../../../tools/dataset_converters/inpainting/README.md > dataset_zoo/2_inpainting_datasets.md
+cat ../../../tools/dataset_converters/matting/README.md > dataset_zoo/3_matting_datasets.md
+cat ../../../tools/dataset_converters/video-interpolation/README.md > dataset_zoo/4_video_interpolation_datasets.md
+cat ../../../tools/dataset_converters/unconditional_gans/README.md > dataset_zoo/5_unconditional_gans_datasets.md
+cat ../../../tools/dataset_converters/image_translation/README.md > dataset_zoo/6_image_translation_datasets.md
# generate markdown TOC
-sed -i -e 's/](comp1k\(\/README.md)\)/](composition-1k\1/g' ../dataset_zoo/3_matting_datasets.md
+sed -i -e 's/](comp1k\(\/README.md)\)/](composition-1k\1/g' dataset_zoo/3_matting_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/1_super_resolution_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/2_inpainting_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/3_matting_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/4_video_interpolation_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/5_unconditional_gans_datasets.md
-sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' ../dataset_zoo/6_image_translation_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/1_super_resolution_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/2_inpainting_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/3_matting_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/4_video_interpolation_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/5_unconditional_gans_datasets.md
+sed -i -e 's/](\(.*\)\/README.md)/](#\1-dataset)/g' dataset_zoo/6_image_translation_datasets.md
# gather all datasets
-cat ../../../tools/dataset_converters/super-resolution/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/1_super_resolution_datasets.md
-cat ../../../tools/dataset_converters/inpainting/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/2_inpainting_datasets.md
-cat ../../../tools/dataset_converters/matting/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/3_matting_datasets.md
-cat ../../../tools/dataset_converters/video-interpolation/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/4_video_interpolation_datasets.md
-cat ../../../tools/dataset_converters/unconditional_gans/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/5_unconditional_gans_datasets.md
-cat ../../../tools/dataset_converters/image_translation/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> ../dataset_zoo/6_image_translation_datasets.md
+cat ../../../tools/dataset_converters/super-resolution/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/1_super_resolution_datasets.md
+cat ../../../tools/dataset_converters/inpainting/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/2_inpainting_datasets.md
+cat ../../../tools/dataset_converters/matting/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/3_matting_datasets.md
+cat ../../../tools/dataset_converters/video-interpolation/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/4_video_interpolation_datasets.md
+cat ../../../tools/dataset_converters/unconditional_gans/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/5_unconditional_gans_datasets.md
+cat ../../../tools/dataset_converters/image_translation/*/README.md | sed 's/# Preparing /\n# /g' | sed "s/#/#&/" >> dataset_zoo/6_image_translation_datasets.md
-echo '# Overview' > ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Super-Resolution Datasets](./1_super_resolution_datasets.md)' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/1_super_resolution_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/1_super_resolution_datasets.md#/g' >> ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Inpainting Datasets](./2_inpainting_datasets.md)' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/2_inpainting_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/2_inpainting_datasets.md#/g' >> ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Matting Datasets](./3_matting_datasets.md)\n' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/3_matting_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/3_matting_datasets.md#/g' >> ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Video Frame Interpolation Datasets](./4_video_interpolation_datasets.md)' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/4_video_interpolation_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/4_video_interpolation_datasets.md#/g' >> ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Unconditional GANs Datasets](./5_unconditional_gans_datasets.md)' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/5_unconditional_gans_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/5_unconditional_gans_datasets.md#/g' >> ../dataset_zoo/0_overview.md
-echo '\n- [Prepare Image Translation Datasets](./6_image_translation_datasets.md)' >> ../dataset_zoo/0_overview.md
-cat ../dataset_zoo/6_image_translation_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed '$a\n' |sed 's/- \[/ - \[/g' | sed 's/(#/(.\/6_image_translation_datasets.md#/g' >> ../dataset_zoo/0_overview.md
+echo '# Overview' > dataset_zoo/0_overview.md
+echo '\n- [Prepare Super-Resolution Datasets](./1_super_resolution_datasets.md)' >> dataset_zoo/0_overview.md
+cat dataset_zoo/1_super_resolution_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/1_super_resolution_datasets.md#/g' >> dataset_zoo/0_overview.md
+echo '\n- [Prepare Inpainting Datasets](./2_inpainting_datasets.md)' >> dataset_zoo/0_overview.md
+cat dataset_zoo/2_inpainting_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/2_inpainting_datasets.md#/g' >> dataset_zoo/0_overview.md
+echo '\n- [Prepare Matting Datasets](./3_matting_datasets.md)\n' >> dataset_zoo/0_overview.md
+cat dataset_zoo/3_matting_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/3_matting_datasets.md#/g' >> dataset_zoo/0_overview.md
+echo '\n- [Prepare Video Frame Interpolation Datasets](./4_video_interpolation_datasets.md)' >> dataset_zoo/0_overview.md
+cat dataset_zoo/4_video_interpolation_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/4_video_interpolation_datasets.md#/g' >> dataset_zoo/0_overview.md
+echo '\n- [Prepare Unconditional GANs Datasets](./5_unconditional_gans_datasets.md)' >> dataset_zoo/0_overview.md
+cat dataset_zoo/5_unconditional_gans_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed 's/- \[/ - \[/g' | sed 's/(#/(.\/5_unconditional_gans_datasets.md#/g' >> dataset_zoo/0_overview.md
+echo '\n- [Prepare Image Translation Datasets](./6_image_translation_datasets.md)' >> dataset_zoo/0_overview.md
+cat dataset_zoo/6_image_translation_datasets.md | grep -oP '(- \[.*-dataset.*)' | sed '$a\n' |sed 's/- \[/ - \[/g' | sed 's/(#/(.\/6_image_translation_datasets.md#/g' >> dataset_zoo/0_overview.md
diff --git a/docs/en/3_model_zoo.md b/docs/en/3_model_zoo.md
index eba6085eaa..af3547569a 100644
--- a/docs/en/3_model_zoo.md
+++ b/docs/en/3_model_zoo.md
@@ -1,19 +1,20 @@
# Overview
-- Number of checkpoints: 168
-- Number of configs: 168
-- Number of papers: 41
- - ALGORITHM: 42
+- Number of checkpoints: 169
+- Number of configs: 169
+- Number of papers: 42
+ - ALGORITHM: 43
- Tasks:
- - video super-resolution
- image2image translation
- video interpolation
+ - unconditional gans
+ - image super-resolution
+ - internal learning
- conditional gans
- inpainting
- - image super-resolution
+ - video super-resolution
+ - colorization
- matting
- - unconditional gans
- - internal learning
For supported datasets, see [datasets overview](dataset_zoo/0_overview.md).
@@ -185,6 +186,14 @@ For supported datasets, see [datasets overview](dataset_zoo/0_overview.md).
- Number of papers: 1
- \[ALGORITHM\] Indices Matter: Learning to Index for Deep Image Matting ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/indexnet/README.md#citation))
+## Instance-aware Image Colorization (CVPR'2020)
+
+- Tasks: colorization
+- Number of checkpoints: 1
+- Number of configs: 1
+- Number of papers: 1
+ - \[ALGORITHM\] Instance-Aware Image Colorization ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/inst_colorization/README.md#quick-start))
+
## LIIF (CVPR'2021)
- Tasks: image super-resolution
diff --git a/docs/en/advanced_guides/3_transforms.md b/docs/en/advanced_guides/3_transforms.md
index 4b1fcf9617..565a3499ab 100644
--- a/docs/en/advanced_guides/3_transforms.md
+++ b/docs/en/advanced_guides/3_transforms.md
@@ -45,10 +45,10 @@ The input and output types of transformations are both dict.
dict_keys(['pair_path', 'pair', 'pair_ori_shape', 'img_mask', 'img_photo', 'img_mask_path', 'img_photo_path', 'img_mask_ori_shape', 'img_photo_ori_shape'])
```
-Generally, the last step of the transforms pipeline must be `PackGenInputs`.
-`PackGenInputs` will pack the processed data into a dict containing two fields: `inputs` and `data_samples`.
+Generally, the last step of the transforms pipeline must be `PackEditInputs`.
+`PackEditInputs` will pack the processed data into a dict containing two fields: `inputs` and `data_samples`.
`inputs` is the variable you want to use as the model's input, which can be a `torch.Tensor`, a dict of `torch.Tensor`, or any type you want.
-`data_samples` is a list of `GenDataSample`. Each `GenDataSample` contains groundtruth and necessary information for corresponding input.
+`data_samples` is a list of `EditDataSample`. Each `EditDataSample` contains the ground truth and the necessary information for the corresponding input.
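+
+As a minimal sketch (the key name `img` here is only illustrative), a pipeline that ends with `PackEditInputs` can look like:
+
+```python
+# A minimal, illustrative pipeline ending with PackEditInputs.
+pipeline = [
+    dict(type='LoadImageFromFile', key='img'),
+    dict(type='PackEditInputs', keys=['img']),
+]
+
+# After packing, each sample is roughly a dict of the form
+# {
+#     'inputs': ...,        # e.g. a torch.Tensor or a dict of torch.Tensor
+#     'data_samples': ...,  # EditDataSample(s) holding ground truth and meta info
+# }
+```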
### An example of BasicVSR
@@ -121,15 +121,8 @@ pipeline = [
keys=[f'img_{domain_a}', f'img_{domain_b}'],
direction='horizontal'),
dict(
- type='PackGenInputs',
- keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'],
- meta_keys=[
- 'pair_path', 'sample_idx', 'pair_ori_shape',
- f'img_{domain_a}_path', f'img_{domain_b}_path',
- f'img_{domain_a}_ori_shape', f'img_{domain_b}_ori_shape', 'flip',
- 'flip_direction'
- ])
-]
+ type='PackEditInputs',
+ keys=[f'img_{domain_a}', f'img_{domain_b}', 'pair'])
```
## Supported transforms in MMEditing
diff --git a/docs/en/conf.py b/docs/en/conf.py
index 5878769da6..05455c4dcc 100644
--- a/docs/en/conf.py
+++ b/docs/en/conf.py
@@ -77,6 +77,24 @@
'name': 'GitHub',
'url': 'https://github.com/open-mmlab/mmediting',
},
+ {
+ 'name':
+ 'Version',
+ 'children': [
+ {
+ 'name': 'MMEditing 0.x',
+ 'url': 'https://mmediting.readthedocs.io/en/latest/',
+ 'description': 'Main branch'
+ },
+ {
+ 'name': 'MMEditing 1.x',
+ 'url': 'https://mmediting.readthedocs.io/en/1.x/',
+ 'description': '1.x branch'
+ },
+ ],
+ 'active':
+ True,
+ },
],
'menu_lang':
'en'
@@ -98,7 +116,7 @@
def builder_inited_handler(app):
- subprocess.run(['python', './.dev_scripts/update_datasest_zoo.py'])
+ subprocess.run(['bash', './.dev_scripts/update_dataset_zoo.sh'])
subprocess.run(['python', './.dev_scripts/update_model_zoo.py'])
diff --git a/docs/en/dataset_zoo/0_overview.md b/docs/en/dataset_zoo/0_overview.md
index db700e20bd..f5c6f41be7 100644
--- a/docs/en/dataset_zoo/0_overview.md
+++ b/docs/en/dataset_zoo/0_overview.md
@@ -1,27 +1,8 @@
# Overview
-- [Prepare Super-Resolution Datasets](./1_super_resolution_datasets.md)
-
- - [DF2K_OST](./1_super_resolution_datasets.md#df2k_ost-dataset) \[ [Homepage](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/Training.md) \]
- - [DIV2K](./1_super_resolution_datasets.md#div2k-dataset) \[ [Homepage](https://data.vision.ee.ethz.ch/cvl/DIV2K/) \]
- - [REDS](./1_super_resolution_datasets.md#reds-dataset) \[ [Homepage](https://seungjunnah.github.io/Datasets/reds.html) \]
- - [Vid4](./1_super_resolution_datasets.md#vid4-dataset) \[ [Homepage](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view) \]
- - [Vimeo90K](./1_super_resolution_datasets.md#vimeo90k-dataset) \[ [Homepage](http://toflow.csail.mit.edu) \]
-
-- [Prepare Inpainting Datasets](./2_inpainting_datasets.md)
-
- - [CelebA-HQ](./2_inpainting_datasets.md#celeba-hq-dataset) \[ [Homepage](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training) \]
- - [Paris Street View](./2_inpainting_datasets.md#paris-street-view-dataset) \[ [Homepage](https://github.com/pathak22/context-encoder/issues/24) \]
- - [Places365](./2_inpainting_datasets.md#places365-dataset) \[ [Homepage](http://places2.csail.mit.edu/) \]
-
-- [Prepare Matting Datasets](./3_matting_datasets.md)
-
- - [Composition-1k](./3_matting_datasets.md#composition-1k-dataset) \[ [Homepage](https://sites.google.com/view/deepimagematting) \]
-
-- [Prepare Video Frame Interpolation Datasets](./4_video_interpolation_datasets.md)
-
- - [Vimeo90K-triplet](./4_video_interpolation_datasets.md#vimeo90k-triplet-dataset) \[ [Homepage](http://toflow.csail.mit.edu) \]
-
-- [Prepare Unconditional GANs Datasets](./5_unconditional_gans_datasets.md)
-
-- [Prepare Image Translation Datasets](./6_image_translation_datasets.md)
+\\n- [Prepare Super-Resolution Datasets](./1_super_resolution_datasets.md)
+\\n- [Prepare Inpainting Datasets](./2_inpainting_datasets.md)
+\\n- [Prepare Matting Datasets](./3_matting_datasets.md)\\n
+\\n- [Prepare Video Frame Interpolation Datasets](./4_video_interpolation_datasets.md)
+\\n- [Prepare Unconditional GANs Datasets](./5_unconditional_gans_datasets.md)
+\\n- [Prepare Image Translation Datasets](./6_image_translation_datasets.md)
diff --git a/docs/en/dataset_zoo/1_super_resolution_datasets.md b/docs/en/dataset_zoo/1_super_resolution_datasets.md
index e6afa8ff6d..e69de29bb2 100644
--- a/docs/en/dataset_zoo/1_super_resolution_datasets.md
+++ b/docs/en/dataset_zoo/1_super_resolution_datasets.md
@@ -1,356 +0,0 @@
-# Super-Resolution Datasets
-
-It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files.
-
-MMEditing supported super-resolution datasets:
-
-- Image Super-Resolution
- - [DF2K_OST](#df2k_ost-dataset) \[ [Homepage](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/Training.md) \]
- - [DIV2K](#div2k-dataset) \[ [Homepage](https://data.vision.ee.ethz.ch/cvl/DIV2K/) \]
-- Video Super-Resolution
- - [REDS](#reds-dataset) \[ [Homepage](https://seungjunnah.github.io/Datasets/reds.html) \]
- - [Vid4](#vid4-dataset) \[ [Homepage](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view) \]
- - [Vimeo90K](#vimeo90k-dataset) \[ [Homepage](http://toflow.csail.mit.edu) \]
-
-## DF2K_OST Dataset
-
-
-
-```bibtex
-@inproceedings{wang2021real,
- title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
- author={Wang, Xintao and Xie, Liangbin and Dong, Chao and Shan, Ying},
- booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
- pages={1905--1914},
- year={2021}
-}
-```
-
-- The DIV2K dataset can be downloaded from [here](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (We use the training set only).
-- The Flickr2K dataset can be downloaded [here](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (We use the training set only).
-- The OST dataset can be downloaded [here](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip) (We use the training set only).
-
-Please first put all the images into the `GT` folder (naming does not need to be in order):
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── df2k_ost
-│ │ ├── GT
-│ │ │ ├── 0001.png
-│ │ │ ├── 0002.png
-│ │ │ ├── ...
-...
-```
-
-### Crop sub-images
-
-For faster IO, we recommend to crop the images to sub-images. We provide such a script:
-
-```shell
-python tools/dataset_converters/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost
-```
-
-The generated data is stored under `df2k_ost` and the data structure is as follows, where `_sub` indicates the sub-images.
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── df2k_ost
-│ │ ├── GT
-│ │ ├── GT_sub
-...
-```
-
-### Prepare LMDB dataset for DF2K_OST
-
-If you want to use LMDB datasets for faster IO speed, you can make LMDB files by:
-
-```shell
-python tools/dataset_converters/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost --make-lmdb
-```
-
-## DIV2K Dataset
-
-
-
-```bibtex
-@InProceedings{Agustsson_2017_CVPR_Workshops,
- author = {Agustsson, Eirikur and Timofte, Radu},
- title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study},
- booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
- month = {July},
- year = {2017}
-}
-```
-
-- Training dataset: [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/).
-- Validation dataset: Set5 and Set14.
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── DIV2K
-│ │ ├── DIV2K_train_HR
-│ │ ├── DIV2K_train_LR_bicubic
-│ │ │ ├── X2
-│ │ │ ├── X3
-│ │ │ ├── X4
-│ │ ├── DIV2K_valid_HR
-│ │ ├── DIV2K_valid_LR_bicubic
-│ │ │ ├── X2
-│ │ │ ├── X3
-│ │ │ ├── X4
-│ ├── Set5
-│ │ ├── GTmod12
-│ │ ├── LRbicx2
-│ │ ├── LRbicx3
-│ │ ├── LRbicx4
-│ ├── Set14
-│ │ ├── GTmod12
-│ │ ├── LRbicx2
-│ │ ├── LRbicx3
-│ │ ├── LRbicx4
-```
-
-### Crop sub-images
-
-For faster IO, we recommend to crop the DIV2K images to sub-images. We provide such a script:
-
-```shell
-python tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K
-```
-
-The generated data is stored under `DIV2K` and the data structure is as follows, where `_sub` indicates the sub-images.
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── DIV2K
-│ │ ├── DIV2K_train_HR
-│ │ ├── DIV2K_train_HR_sub
-│ │ ├── DIV2K_train_LR_bicubic
-│ │ │ ├── X2
-│ │ │ ├── X3
-│ │ │ ├── X4
-│ │ │ ├── X2_sub
-│ │ │ ├── X3_sub
-│ │ │ ├── X4_sub
-│ │ ├── DIV2K_valid_HR
-│ │ ├── ...
-...
-```
-
-### Prepare annotation list
-
-If you use the annotation mode for the dataset, you first need to prepare a specific `txt` file.
-
-Each line in the annotation file contains the image names and image shape (usually for the ground-truth images), separated by a white space.
-
-Example of an annotation file:
-
-```text
-0001_s001.png (480,480,3)
-0001_s002.png (480,480,3)
-```
-
-### Prepare LMDB dataset for DIV2K
-
-If you want to use LMDB datasets for faster IO speed, you can make LMDB files by:
-
-```shell
-python tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K --make-lmdb
-```
-
-## REDS Dataset
-
-
-
-```bibtex
-@InProceedings{Nah_2019_CVPR_Workshops_REDS,
- author = {Nah, Seungjun and Baik, Sungyong and Hong, Seokil and Moon, Gyeongsik and Son, Sanghyun and Timofte, Radu and Lee, Kyoung Mu},
- title = {NTIRE 2019 Challenge on Video Deblurring and Super-Resolution: Dataset and Study},
- booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
- month = {June},
- year = {2019}
-}
-```
-
-- Training dataset: [REDS dataset](https://seungjunnah.github.io/Datasets/reds.html).
-- Validation dataset: [REDS dataset](https://seungjunnah.github.io/Datasets/reds.html) and Vid4.
-
-Note that we merge train and val datasets in REDS for easy switching between REDS4 partition (used in EDVR) and the official validation partition.
-The original val dataset (clip names from 000 to 029) are modified to avoid conflicts with training dataset (total 240 clips). Specifically, the clip names are changed to 240, 241, ... 269.
-
-You can prepare the REDS dataset by running:
-
-```shell
-python tools/dataset_converters/super-resolution/reds/preprocess_reds_dataset.py --root-path ./data/REDS
-```
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── REDS
-│ │ ├── train_sharp
-│ │ │ ├── 000
-│ │ │ ├── 001
-│ │ │ ├── ...
-│ │ ├── train_sharp_bicubic
-│ │ │ ├── 000
-│ │ │ ├── 001
-│ │ │ ├── ...
-│ ├── REDS4
-│ │ ├── GT
-│ │ ├── sharp_bicubic
-```
-
-### Prepare LMDB dataset for REDS
-
-If you want to use LMDB datasets for faster IO speed, you can make LMDB files by:
-
-```shell
-python tools/dataset_converters/super-resolution/reds/preprocess_reds_dataset.py --root-path ./data/REDS --make-lmdb
-```
-
-### Crop to sub-images
-
-MMEditing also support cropping REDS images to sub-images for faster IO. We provide such a script:
-
-```shell
-python tools/dataset_converters/super-resolution/reds/crop_sub_images.py --data-root ./data/REDS -scales 4
-```
-
-The generated data is stored under `REDS` and the data structure is as follows, where `_sub` indicates the sub-images.
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── REDS
-│ │ ├── train_sharp
-│ │ │ ├── 000
-│ │ │ ├── 001
-│ │ │ ├── ...
-│ │ ├── train_sharp_sub
-│ │ │ ├── 000_s001
-│ │ │ ├── 000_s002
-│ │ │ ├── ...
-│ │ │ ├── 001_s001
-│ │ │ ├── ...
-│ │ ├── train_sharp_bicubic
-│ │ │ ├── X4
-│ │ │ │ ├── 000
-│ │ │ │ ├── 001
-│ │ │ │ ├── ...
-│ │ │ ├── X4_sub
-│ │ │ ├── 000_s001
-│ │ │ ├── 000_s002
-│ │ │ ├── ...
-│ │ │ ├── 001_s001
-│ │ │ ├── ...
-```
-
-Note that by default `preprocess_reds_dataset.py` does not make lmdb and annotation file for the cropped dataset. You may need to modify the scripts a little bit for such operations.
-
-## Vid4 Dataset
-
-
-
-```bibtex
-@article{xue2019video,
- title={On Bayesian adaptive video super resolution},
- author={Liu, Ce and Sun, Deqing},
- journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
- volume={36},
- number={2},
- pages={346--360},
- year={2013},
- publisher={IEEE}
-}
-```
-
-The Vid4 dataset can be downloaded from [here](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view?usp=sharing). There are two degradations in the dataset.
-
-1. BIx4 contains images downsampled by bicubic interpolation
-2. BDx4 contains images blurred by Gaussian kernel with σ=1.6, followed by a subsampling every four pixels.
-
-## Vimeo90K Dataset
-
-
-
-```bibtex
-@article{xue2019video,
- title={Video Enhancement with Task-Oriented Flow},
- author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T},
- journal={International Journal of Computer Vision (IJCV)},
- volume={127},
- number={8},
- pages={1106--1125},
- year={2019},
- publisher={Springer}
-}
-```
-
-The training and test datasets can be download from [here](http://toflow.csail.mit.edu/).
-
-The Vimeo90K dataset has a `clip/sequence/img` folder structure:
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── vimeo_triplet
-│ │ ├── BDx4
-│ │ │ ├── 00001
-│ │ │ │ ├── 0001
-│ │ │ │ │ ├── im1.png
-│ │ │ │ │ ├── im2.png
-│ │ │ │ │ ├── ...
-│ │ │ │ ├── 0002
-│ │ │ │ ├── 0003
-│ │ │ │ ├── ...
-│ │ │ ├── 00002
-│ │ │ ├── ...
-│ │ ├── BIx4
-│ │ ├── GT
-│ │ ├── meta_info_Vimeo90K_test_GT.txt
-│ │ ├── meta_info_Vimeo90K_train_GT.txt
-```
-
-### Prepare the annotation files for Vimeo90K dataset
-
-To prepare the annotation file for training, you need to download the official training list path for Vimeo90K from the official website, and run the following command:
-
-```shell
-python tools/dataset_converters/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt
-```
-
-The annotation file for test is generated similarly.
-
-### Prepare LMDB dataset for Vimeo90K
-
-If you want to use LMDB datasets for faster IO speed, you can make LMDB files by:
-
-```shell
-python tools/dataset_converters/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt --gt-path ./data/Vimeo90K/GT --lq-path ./data/Vimeo90K/LQ --make-lmdb
-```
diff --git a/docs/en/dataset_zoo/2_inpainting_datasets.md b/docs/en/dataset_zoo/2_inpainting_datasets.md
index f5dac5abac..e69de29bb2 100644
--- a/docs/en/dataset_zoo/2_inpainting_datasets.md
+++ b/docs/en/dataset_zoo/2_inpainting_datasets.md
@@ -1,114 +0,0 @@
-# Inpainting Datasets
-
-It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files.
-
-MMEditing supported inpainting datasets:
-
-- [CelebA-HQ](#celeba-hq-dataset) \[ [Homepage](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training) \]
-- [Paris Street View](#paris-street-view-dataset) \[ [Homepage](https://github.com/pathak22/context-encoder/issues/24) \]
-- [Places365](#places365-dataset) \[ [Homepage](http://places2.csail.mit.edu/) \]
-
-As we only need images for inpainting task, further preparation is not necessary and the folder structure can be different from the example. You can utilize the information provided by the original dataset like `Place365` (e.g. `meta`). Also, you can easily scan the data set and list all of the images to a specific `txt` file. Here is an example for the `Places365_val.txt` from Places365 and we will only use the image name information in inpainting.
-
-```
-Places365_val_00000001.jpg 165
-Places365_val_00000002.jpg 358
-Places365_val_00000003.jpg 93
-Places365_val_00000004.jpg 164
-Places365_val_00000005.jpg 289
-Places365_val_00000006.jpg 106
-Places365_val_00000007.jpg 81
-Places365_val_00000008.jpg 121
-Places365_val_00000009.jpg 150
-Places365_val_00000010.jpg 302
-Places365_val_00000011.jpg 42
-```
-
-## CelebA-HQ Dataset
-
-
-
-```bibtex
-@article{karras2017progressive,
- title={Progressive growing of gans for improved quality, stability, and variation},
- author={Karras, Tero and Aila, Timo and Laine, Samuli and Lehtinen, Jaakko},
- journal={arXiv preprint arXiv:1710.10196},
- year={2017}
-}
-```
-
-Follow the instructions [here](https://github.com/tkarras/progressive_growing_of_gans##preparing-datasets-for-training) to prepare the dataset.
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── CelebA-HQ
-│ │ ├── train_256
-| | ├── test_256
-| | ├── train_celeba_img_list.txt
-| | ├── val_celeba_img_list.txt
-
-```
-
-## Paris Street View Dataset
-
-
-
-```bibtex
-@inproceedings{pathak2016context,
- title={Context encoders: Feature learning by inpainting},
- author={Pathak, Deepak and Krahenbuhl, Philipp and Donahue, Jeff and Darrell, Trevor and Efros, Alexei A},
- booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
- pages={2536--2544},
- year={2016}
-}
-```
-
-Obtain the dataset [here](https://github.com/pathak22/context-encoder/issues/24).
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── paris_street_view
-│ │ ├── train
-| | ├── val
-
-```
-
-## Places365 Dataset
-
-
-
-```bibtex
- @article{zhou2017places,
- title={Places: A 10 million Image Database for Scene Recognition},
- author={Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio},
- journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
- year={2017},
- publisher={IEEE}
- }
-
-```
-
-Prepare the data from [Places365](http://places2.csail.mit.edu/download.html).
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── Places
-│ │ ├── data_large
-│ │ ├── val_large
-| | ├── meta
-| | | ├── places365_train_challenge.txt
-| | | ├── places365_val.txt
-
-```
diff --git a/docs/en/dataset_zoo/3_matting_datasets.md b/docs/en/dataset_zoo/3_matting_datasets.md
index cc9c95f924..e69de29bb2 100644
--- a/docs/en/dataset_zoo/3_matting_datasets.md
+++ b/docs/en/dataset_zoo/3_matting_datasets.md
@@ -1,147 +0,0 @@
-# Matting Datasets
-
-It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files.
-
-MMEditing supported matting datasets:
-
-- [Composition-1k](#composition-1k-dataset) \[ [Homepage](https://sites.google.com/view/deepimagematting) \]
-
-## Composition-1k Dataset
-
-### Introduction
-
-
-
-```bibtex
-@inproceedings{xu2017deep,
- title={Deep image matting},
- author={Xu, Ning and Price, Brian and Cohen, Scott and Huang, Thomas},
- booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
- pages={2970--2979},
- year={2017}
-}
-```
-
-The Adobe Composition-1k dataset consists of foreground images and their corresponding alpha images.
-To get the full dataset, one need to composite the foregrounds with selected backgrounds from the COCO dataset and the Pascal VOC dataset.
-
-### Obtain and Extract
-
-Please follow the instructions of [paper authors](https://sites.google.com/view/deepimagematting) to obtain the Composition-1k (comp1k) dataset.
-
-### Composite the full dataset
-
-The Adobe composition-1k dataset contains only `alpha` and `fg` (and `trimap` in test set).
-It is needed to merge `fg` with COCO data (training) or VOC data (test) before training or evaluation.
-Use the following script to perform image composition and generate annotation files for training or testing:
-
-```shell
-## The script is run under the root folder of MMEditing
-python tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py data/adobe_composition-1k data/coco data/VOCdevkit --composite
-```
-
-The generated data is stored under `adobe_composition-1k/Training_set` and `adobe_composition-1k/Test_set` respectively.
-If you only want to composite test data (since compositing training data is time-consuming), you can skip compositing the training set by removing the `--composite` option:
-
-```shell
-## skip compositing training set
-python tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py data/adobe_composition-1k data/coco data/VOCdevkit
-```
-
-If you only want to preprocess test data, i.e. for FBA, you can skip the train set by adding the `--skip-train` option:
-
-```shell
-## skip preprocessing training set
-python tools/data/matting/comp1k/preprocess_comp1k_dataset.py data/adobe_composition-1k data/coco data/VOCdevkit --skip-train
-```
-
-> Currently, `GCA` and `FBA` support online composition of training data. But you can modify the data pipeline of other models to perform online composition instead of loading composited images (we called it `merged` in our data pipeline).
-
-### Check Directory Structure for DIM
-
-The result folder structure should look like:
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── adobe_composition-1k
-│ │ ├── Test_set
-│ │ │ ├── Adobe-licensed images
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ │ ├── trimaps
-│ │ │ ├── merged (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ │ ├── bg (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ ├── Training_set
-│ │ │ ├── Adobe-licensed images
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ ├── Other
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ ├── merged (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ │ ├── bg (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ ├── test_list.json (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ ├── training_list.json (generated by tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py)
-│ ├── coco
-│ │ ├── train2014 (or train2017)
-│ ├── VOCdevkit
-│ │ ├── VOC2012
-```
-
-### Prepare the dataset for FBA
-
-FBA adopts dynamic dataset augmentation proposed in [Learning-base Sampling for Natural Image Matting](https://openaccess.thecvf.com/content_CVPR_2019/papers/Tang_Learning-Based_Sampling_for_Natural_Image_Matting_CVPR_2019_paper.pdf).
-In addition, to reduce artifacts during augmentation, it uses the extended version of foreground as foreground.
-We provide scripts to estimate foregrounds.
-
-Prepare the test set as follows:
-
-```shell
-## skip preprocessing training set, as it composites online during training
-python tools/dataset_converters/matting/comp1k/preprocess_comp1k_dataset.py data/adobe_composition-1k data/coco data/VOCdevkit --skip-train
-```
-
-Extend the foreground of training set as follows:
-
-```shell
-python tools/dataset_converters/matting/comp1k/extend_fg.py data/adobe_composition-1k
-```
-
-### Check Directory Structure for DIM
-
-The final folder structure should look like:
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── adobe_composition-1k
-│ │ ├── Test_set
-│ │ │ ├── Adobe-licensed images
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ │ ├── trimaps
-│ │ │ ├── merged (generated by tools/data/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ │ ├── bg (generated by tools/data/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ ├── Training_set
-│ │ │ ├── Adobe-licensed images
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ │ ├── fg_extended (generated by tools/data/matting/comp1k/extend_fg.py)
-│ │ │ ├── Other
-│ │ │ │ ├── alpha
-│ │ │ │ ├── fg
-│ │ │ │ ├── fg_extended (generated by tools/data/matting/comp1k/extend_fg.py)
-│ │ ├── test_list.json (generated by tools/data/matting/comp1k/preprocess_comp1k_dataset.py)
-│ │ ├── training_list_fba.json (generated by tools/data/matting/comp1k/extend_fg.py)
-│ ├── coco
-│ │ ├── train2014 (or train2017)
-│ ├── VOCdevkit
-│ │ ├── VOC2012
-```
diff --git a/docs/en/dataset_zoo/4_video_interpolation_datasets.md b/docs/en/dataset_zoo/4_video_interpolation_datasets.md
index 9889939c91..e69de29bb2 100644
--- a/docs/en/dataset_zoo/4_video_interpolation_datasets.md
+++ b/docs/en/dataset_zoo/4_video_interpolation_datasets.md
@@ -1,50 +0,0 @@
-# Video Frame Interpolation Datasets
-
-It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files.
-
-MMEditing supported video frame interpolation datasets:
-
-- [Vimeo90K-triplet](#vimeo90k-triplet-dataset) \[ [Homepage](http://toflow.csail.mit.edu) \]
-
-## Vimeo90K-triplet Dataset
-
-
-
-```bibtex
-@article{xue2019video,
- title={Video Enhancement with Task-Oriented Flow},
- author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T},
- journal={International Journal of Computer Vision (IJCV)},
- volume={127},
- number={8},
- pages={1106--1125},
- year={2019},
- publisher={Springer}
-}
-```
-
-The training and test datasets can be download from [here](http://toflow.csail.mit.edu/).
-
-The Vimeo90K-triplet dataset has a `clip/sequence/img` folder structure:
-
-```text
-mmediting
-├── mmedit
-├── tools
-├── configs
-├── data
-│ ├── vimeo_triplet
-│ │ ├── tri_testlist.txt
-│ │ ├── tri_trainlist.txt
-│ │ ├── sequences
-│ │ │ ├── 00001
-│ │ │ │ ├── 0001
-│ │ │ │ │ ├── im1.png
-│ │ │ │ │ ├── im2.png
-│ │ │ │ │ └── im3.png
-│ │ │ │ ├── 0002
-│ │ │ │ ├── 0003
-│ │ │ │ ├── ...
-│ │ │ ├── 00002
-│ │ │ ├── ...
-```
diff --git a/docs/en/dataset_zoo/5_unconditional_gans_datasets.md b/docs/en/dataset_zoo/5_unconditional_gans_datasets.md
index 70d5524f4f..e69de29bb2 100644
--- a/docs/en/dataset_zoo/5_unconditional_gans_datasets.md
+++ b/docs/en/dataset_zoo/5_unconditional_gans_datasets.md
@@ -1,91 +0,0 @@
-# Unconditional GANs Datasets
-
-**Data preparation for unconditional model** is simple. What you need to do is downloading the images and put them into a directory. Next, you should set a symlink in the `data` directory. For standard unconditional gans with static architectures, like DCGAN and StyleGAN2, `UnconditionalImageDataset` is designed to train such unconditional models. Here is an example config for FFHQ dataset:
-
-```python
-dataset_type = 'UnconditionalImageDataset'
-
-train_pipeline = [
- dict(type='LoadImageFromFile', key='img'),
- dict(type='Flip', keys=['img'], direction='horizontal'),
- dict(type='PackGenInputs', keys=['img'], meta_keys=['img_path'])
-]
-
-# `batch_size` and `data_root` need to be set.
-train_dataloader = dict(
- batch_size=4,
- num_workers=8,
- persistent_workers=True,
- sampler=dict(type='InfiniteSampler', shuffle=True),
- dataset=dict(
- type=dataset_type,
- data_root=None, # set by user
- pipeline=train_pipeline))
-```
-
-Here, we adopt `InfinitySampler` to avoid frequent dataloader reloading, which will accelerate the training procedure. As shown in the example, `pipeline` provides important data pipeline to process images, including loading from file system, resizing, cropping, transferring to `torch.Tensor` and packing to `GenDataSample`. All of supported data pipelines can be found in `mmedit/datasets/transforms`.
-
-For unconditional GANs with dynamic architectures like PGGAN and StyleGANv1, `GrowScaleImgDataset` is recommended to use for training. Since such dynamic architectures need real images in different scales, directly adopting `UnconditionalImageDataset` will bring heavy I/O cost for loading multiple high-resolution images. Here is an example we use for training PGGAN in CelebA-HQ dataset:
-
-```python
-dataset_type = 'GrowScaleImgDataset'
-
-pipeline = [
- dict(type='LoadImageFromFile', key='img'),
- dict(type='Flip', keys=['img'], direction='horizontal'),
- dict(type='PackGenInputs')
-]
-
-# `samples_per_gpu` and `imgs_root` need to be set.
-train_dataloader = dict(
- num_workers=4,
- batch_size=64,
- dataset=dict(
- type='GrowScaleImgDataset',
- data_roots={
- '1024': './data/ffhq/images',
- '256': './data/ffhq/ffhq_imgs/ffhq_256',
- '64': './data/ffhq/ffhq_imgs/ffhq_64'
- },
- gpu_samples_base=4,
- # note that this should be changed with total gpu number
- gpu_samples_per_scale={
- '4': 64,
- '8': 32,
- '16': 16,
- '32': 8,
- '64': 4,
- '128': 4,
- '256': 4,
- '512': 4,
- '1024': 4
- },
- len_per_stage=300000,
- pipeline=pipeline),
- sampler=dict(type='InfiniteSampler', shuffle=True))
-```
-
-In this dataset, you should provide a dictionary of image paths to the `data_roots`. Thus, you should resize the images in the dataset in advance.
-For the resizing methods in the data pre-processing, we adopt bilinear interpolation methods in all of the experiments studied in MMEditing.
-
-Note that this dataset should be used with `PGGANFetchDataHook`. In this config file, this hook should be added in the customized hooks, as shown below.
-
-```python
-custom_hooks = [
- dict(
- type='GenVisualizationHook',
- interval=5000,
- fixed_input=True,
- # vis ema and orig at the same time
- vis_kwargs_list=dict(
- type='Noise',
- name='fake_img',
- sample_model='ema/orig',
- target_keys=['ema', 'orig'])),
- dict(type='PGGANFetchDataHook')
-]
-```
-
-This fetching data hook helps the dataloader update the status of dataset to change the data source and batch size during training.
-
-Here, we provide several download links of datasets frequently used in unconditional models: [LSUN](http://dl.yf.io/lsun/), [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), [CelebA-HQ](https://drive.google.com/drive/folders/11Vz0fqHS2rXDb5pprgTjpD7S2BAJhi1P), [FFHQ](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP).
diff --git a/docs/en/dataset_zoo/6_image_translation_datasets.md b/docs/en/dataset_zoo/6_image_translation_datasets.md
index 85cec47c16..e69de29bb2 100644
--- a/docs/en/dataset_zoo/6_image_translation_datasets.md
+++ b/docs/en/dataset_zoo/6_image_translation_datasets.md
@@ -1,149 +0,0 @@
-# Image Translation Datasets
-
-**Data preparation for translation model** needs a little attention. You should organize the files in the way we told you in `quick_run.md`. Fortunately, for most official datasets like facades and summer2winter_yosemite, they already have the right format. Also, you should set a symlink in the `data` directory. For paired-data trained translation model like Pix2Pix , `PairedImageDataset` is designed to train such translation models. Here is an example config for facades dataset:
-
-```python
-train_dataset_type = 'PairedImageDataset'
-val_dataset_type = 'PairedImageDataset'
-img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-train_pipeline = [
- dict(
- type='LoadPairedImageFromFile',
- io_backend='disk',
- key='pair',
- domain_a=domain_a,
- domain_b=domain_b,
- flag='color'),
- dict(
- type='Resize',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- scale=(286, 286),
- interpolation='bicubic')
-]
-test_pipeline = [
- dict(
- type='LoadPairedImageFromFile',
- io_backend='disk',
- key='image',
- domain_a=domain_a,
- domain_b=domain_b,
- flag='color'),
- dict(
- type='Resize',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- scale=(256, 256),
- interpolation='bicubic')
-]
-dataroot = 'data/paired/facades'
-train_dataloader = dict(
- batch_size=1,
- num_workers=4,
- persistent_workers=True,
- sampler=dict(type='InfiniteSampler', shuffle=True),
- dataset=dict(
- type=dataset_type,
- data_root=dataroot, # set by user
- pipeline=train_pipeline))
-
-val_dataloader = dict(
- batch_size=1,
- num_workers=4,
- dataset=dict(
- type=dataset_type,
- data_root=dataroot, # set by user
- pipeline=test_pipeline),
- sampler=dict(type='DefaultSampler', shuffle=False),
- persistent_workers=True)
-
-test_dataloader = dict(
- batch_size=1,
- num_workers=4,
- dataset=dict(
- type=dataset_type,
- data_root=dataroot, # set by user
- pipeline=test_pipeline),
- sampler=dict(type='DefaultSampler', shuffle=False),
- persistent_workers=True)
-```
-
-Here, we adopt `LoadPairedImageFromFile` to load a paired image as the common loader does and crops
-it into two images with the same shape in different domains. As shown in the example, `pipeline` provides important data pipeline to process images, including loading from file system, resizing, cropping, flipping, transferring to `torch.Tensor` and packing to `GenDataSample`. All of supported data pipelines can be found in `mmedit/datasets/transforms`.
-
-For unpaired-data trained translation model like CycleGAN , `UnpairedImageDataset` is designed to train such translation models. Here is an example config for horse2zebra dataset:
-
-```python
-train_dataset_type = 'UnpairedImageDataset'
-val_dataset_type = 'UnpairedImageDataset'
-img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-domain_a, domain_b = 'horse', 'zebra'
-train_pipeline = [
- dict(
- type='LoadImageFromFile',
- io_backend='disk',
- key=f'img_{domain_a}',
- flag='color'),
- dict(
- type='LoadImageFromFile',
- io_backend='disk',
- key=f'img_{domain_b}',
- flag='color'),
- dict(
- type='TransformBroadcaster',
- mapping={'img': [f'img_{domain_a}', f'img_{domain_b}']},
- auto_remap=True,
- share_random_params=True,
- transforms=[
- dict(type='Resize', scale=(286, 286), interpolation='bicubic'),
- dict(type='Crop', crop_size=(256, 256), random_crop=True),
- ]),
- dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'),
- dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'),
- dict(
- type='PackGenInputs',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
-]
-test_pipeline = [
- dict(type='LoadImageFromFile', io_backend='disk', key='img', flag='color'),
- dict(type='Resize', scale=(256, 256), interpolation='bicubic'),
- dict(
- type='PackGenInputs',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
-]
-data_root = './data/horse2zebra/'
-# `batch_size` and `data_root` need to be set.
-train_dataloader = dict(
- batch_size=1,
- num_workers=4,
- persistent_workers=True,
- sampler=dict(type='InfiniteSampler', shuffle=True),
- dataset=dict(
- type=dataset_type,
- data_root=data_root, # set by user
- pipeline=train_pipeline))
-
-val_dataloader = dict(
- batch_size=None,
- num_workers=4,
- dataset=dict(
- type=dataset_type,
- data_root=data_root, # set by user
- pipeline=test_pipeline),
- sampler=dict(type='DefaultSampler', shuffle=False),
- persistent_workers=True)
-
-test_dataloader = dict(
- batch_size=None,
- num_workers=4,
- dataset=dict(
- type=dataset_type,
- data_root=data_root, # set by user
- pipeline=test_pipeline),
- sampler=dict(type='DefaultSampler', shuffle=False),
- persistent_workers=True)
-```
-
-`UnpairedImageDataset` will load both images (domain A and B) from different paths and transform them at the same time.
-
-Here, we provide download links of datasets used in [Pix2Pix](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/) and [CycleGAN](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/).
diff --git a/docs/en/notes/3_changelog.md b/docs/en/notes/3_changelog.md
index be177b9f25..8a618a5788 100644
--- a/docs/en/notes/3_changelog.md
+++ b/docs/en/notes/3_changelog.md
@@ -1,5 +1,53 @@
# Changelog
+## v1.0.0rc2 (02/11/2022)
+
+**Highlights**
+
+We are excited to announce the release of MMEditing 1.0.0rc2. This release supports 43+ models, 170+ configs and 169+ checkpoints in MMGeneration and MMEditing. We highlight the following new features:
+
+- Patch-based and slider-based image and video comparison viewer.
+- Image colorization.
+
+We want to sincerely thank our community for continuously improving MMEditing.
+
+**New Features & Improvements**
+
+- Support qualitative comparison tools. (#1303)
+- Support instance-aware colorization. (#1370)
+- Support multi-metrics with different sample-model. (#1171)
+- Improve the implementation
+  - Refactor evaluation metrics. (#1164)
+ - Save gt images in PGGAN's `forward`. (#1332)
+ - Improve type and change default number of `preprocess_div2k_dataset.py`. (#1380)
+ - Support pixel value clip in visualizer. (#1365)
+ - Support SinGAN Dataset and SinGAN demo. (#1363)
+  - Avoid casting int and float in `GenDataPreprocessor`. (#1385)
+- Improve the documentation
+ - Update a menu switcher. (#1162)
+ - Fix TTSR's README. (#1325)
+ - Revise docs (change `PackGenInputs` and `GenDataSample`). (#1382)
+
+**Bug Fixes**
+
+- Fix PPL bug. (#1172)
+- Fix RDN number of channels. (#1328)
+- Fix types of exceptions in demos. (#1372)
+- Fix realesrgan ema. (#1341)
+- Improve the assertion to ensure `GenerateFacialHeatmap` outputs `np.float32`. (#1310)
+- Fix sampling behavior of `unpaired_dataset.py` and URLs in CycleGAN's README. (#1308)
+- Fix VSR models in pytorch2onnx. (#1300)
+- Fix incorrect settings in configs. (#1167,#1200,#1236,#1293,#1302,#1304,#1319,#1331,#1336,#1349,#1352,#1353,#1358,#1364,#1367,#1384,#1386,#1391,#1392,#1393)
+
+**New Contributors**
+
+- @gaoyang07 made their first contribution in https://github.com/open-mmlab/mmediting/pull/1372
+
+**Contributors**
+
+A total of 7 developers contributed to this release.
+Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @plyfager, @ryanxingql, @ruoningYu, @gaoyang07.
+
## v1.0.0rc1(23/9/2022)
MMEditing 1.0.0rc1 has merged MMGeneration 1.x.
diff --git a/docs/en/user_guides/1_config.md b/docs/en/user_guides/1_config.md
index b7ffbbae1a..9bf3cbbdcd 100644
--- a/docs/en/user_guides/1_config.md
+++ b/docs/en/user_guides/1_config.md
@@ -331,11 +331,11 @@ data_root = './data/ffhq/' # Root path of data
train_pipeline = [ # Training data process pipeline
dict(type='LoadImageFromFile', key='img'), # First pipeline to load images from file path
dict(type='Flip', keys=['img'], direction='horizontal'), # Argumentation pipeline that flip the images
- dict(type='PackGenInputs', keys=['img'], meta_keys=['img_path']) # The last pipeline that formats the annotation data (if have) and decides which keys in the data should be packed into data_samples
+ dict(type='PackEditInputs', keys=['img']) # The last pipeline that formats the annotation data (if have) and decides which keys in the data should be packed into data_samples
]
val_pipeline = [
dict(type='LoadImageFromFile', key='img'), # First pipeline to load images from file path
- dict(type='PackGenInputs', keys=['img'], meta_keys=['img_path']) # The last pipeline that formats the annotation data (if have) and decides which keys in the data should be packed into data_samples
+ dict(type='PackEditInputs', keys=['img']) # The last pipeline that formats the annotation data (if have) and decides which keys in the data should be packed into data_samples
]
train_dataloader = dict( # The config of train dataloader
batch_size=4, # Batch size of a single GPU
diff --git a/docs/en/user_guides/4_train_test.md b/docs/en/user_guides/4_train_test.md
index c48b6e94ba..d6595111ff 100644
--- a/docs/en/user_guides/4_train_test.md
+++ b/docs/en/user_guides/4_train_test.md
@@ -196,7 +196,7 @@ val_dataloader = dict(
pipeline=[
dict(type='LoadImageFromFile', key='img'),
dict(type='Resize', scale=(64, 64)),
- dict(type='PackGenInputs', meta_keys=[])
+ dict(type='PackEditInputs')
]),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True)
diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py
index 25ea047d1d..9d397fccb2 100644
--- a/docs/zh_cn/conf.py
+++ b/docs/zh_cn/conf.py
@@ -77,6 +77,24 @@
'name': 'GitHub',
'url': 'https://github.com/open-mmlab/mmediting',
},
+ {
+ 'name':
+ '版本',
+ 'children': [
+ {
+ 'name': 'MMEditing 0.x',
+ 'url': 'https://mmediting.readthedocs.io/en/latest/',
+ 'description': 'Main 分支文档'
+ },
+ {
+ 'name': 'MMEditing 1.x',
+ 'url': 'https://mmediting.readthedocs.io/en/1.x/',
+ 'description': '1.x 分支文档'
+ },
+ ],
+ 'active':
+ True,
+ },
],
'menu_lang':
'cn'
diff --git a/mmedit/apis/__init__.py b/mmedit/apis/__init__.py
index 63989da131..75033721b0 100644
--- a/mmedit/apis/__init__.py
+++ b/mmedit/apis/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .colorization_inference import colorization_inference
from .gan_inference import sample_conditional_model, sample_unconditional_model
from .inference import delete_cfg, init_model, set_random_seed
from .inpainting_inference import inpainting_inference
@@ -10,16 +11,10 @@
from .video_interpolation_inference import video_interpolation_inference
__all__ = [
- 'init_model',
- 'delete_cfg',
- 'set_random_seed',
- 'matting_inference',
- 'inpainting_inference',
- 'restoration_inference',
- 'restoration_video_inference',
- 'restoration_face_inference',
- 'video_interpolation_inference',
- 'sample_conditional_model',
- 'sample_unconditional_model',
- 'sample_img2img_model',
+ 'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference',
+ 'inpainting_inference', 'restoration_inference',
+ 'restoration_video_inference', 'restoration_face_inference',
+ 'video_interpolation_inference', 'sample_conditional_model',
+ 'sample_unconditional_model', 'sample_img2img_model',
+ 'colorization_inference'
]
diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py
new file mode 100644
index 0000000000..ddef7ef587
--- /dev/null
+++ b/mmedit/apis/colorization_inference.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.dataset import Compose
+from mmengine.dataset.utils import default_collate as collate
+from torch.nn.parallel import scatter
+
+
+def colorization_inference(model, img):
+ """Inference image with the model.
+
+ Args:
+ model (nn.Module): The loaded model.
+ img (str): Image file path.
+
+ Returns:
+ Tensor: The predicted colorization result.
+ """
+ device = next(model.parameters()).device
+
+ # build the data pipeline
+ test_pipeline = Compose(model.cfg.test_pipeline)
+ # prepare data
+ data = dict(img_path=img)
+ _data = test_pipeline(data)
+ data = dict()
+ data['inputs'] = _data['inputs'] / 255.0
+ data = collate([data])
+ data['data_samples'] = [_data['data_samples']]
+ if 'cuda' in str(device):
+ data = scatter(data, [device])[0]
+ if not data['data_samples'][0].empty_box:
+ data['data_samples'][0].cropped_img.data = scatter(
+ data['data_samples'][0].cropped_img.data, [device])[0] / 255.0
+
+ data['data_samples'][0].box_info.data = scatter(
+ data['data_samples'][0].box_info.data, [device])[0]
+
+ data['data_samples'][0].box_info_2x.data = scatter(
+ data['data_samples'][0].box_info_2x.data, [device])[0]
+
+ data['data_samples'][0].box_info_4x.data = scatter(
+ data['data_samples'][0].box_info_4x.data, [device])[0]
+
+ data['data_samples'][0].box_info_8x.data = scatter(
+ data['data_samples'][0].box_info_8x.data, [device])[0]
+
+ # forward the model
+ with torch.no_grad():
+ result = model(mode='tensor', **data)
+
+ return result
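
For reference, a minimal usage sketch of the new `colorization_inference` API could look like the snippet below. The config and checkpoint paths are hypothetical placeholders, and `init_model` is the existing helper exported from `mmedit.apis`; only the call pattern follows the code above.

```python
# Minimal sketch (not part of this diff). Paths are placeholders.
from mmedit.apis import colorization_inference, init_model

model = init_model(
    'configs/inst_colorization/example_config.py',  # placeholder config
    'checkpoints/inst_colorization.pth',            # placeholder checkpoint
    device='cuda:0')
result = colorization_inference(model, 'demo/gray_image.jpg')
print(result.shape)  # predicted colorization as a tensor
```
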
diff --git a/mmedit/apis/restoration_video_inference.py b/mmedit/apis/restoration_video_inference.py
index f046507bcb..293a85e5d7 100644
--- a/mmedit/apis/restoration_video_inference.py
+++ b/mmedit/apis/restoration_video_inference.py
@@ -84,7 +84,7 @@ def restoration_video_inference(model,
tmp_pipeline = []
for pipeline in test_pipeline:
if pipeline['type'] not in [
- 'GenerateSegmentIndices', 'LoadImageFromFileList'
+ 'GenerateSegmentIndices', 'LoadImageFromFile'
]:
tmp_pipeline.append(pipeline)
test_pipeline = tmp_pipeline
diff --git a/mmedit/apis/video_interpolation_inference.py b/mmedit/apis/video_interpolation_inference.py
index 83d310054c..3f3e9e620b 100644
--- a/mmedit/apis/video_interpolation_inference.py
+++ b/mmedit/apis/video_interpolation_inference.py
@@ -99,8 +99,7 @@ def video_interpolation_inference(model,
tmp_pipeline = []
for pipeline in test_pipeline:
if pipeline['type'] not in [
- 'GenerateSegmentIndices', 'LoadImageFromFileList',
- 'LoadImageFromFile'
+ 'GenerateSegmentIndices', 'LoadImageFromFile'
]:
tmp_pipeline.append(pipeline)
test_pipeline = tmp_pipeline
diff --git a/mmedit/datasets/__init__.py b/mmedit/datasets/__init__.py
index 0744816ca6..dd8a74e5da 100644
--- a/mmedit/datasets/__init__.py
+++ b/mmedit/datasets/__init__.py
@@ -7,6 +7,7 @@
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .paired_image_dataset import PairedImageDataset
+from .singan_dataset import SinGANDataset
from .unpaired_image_dataset import UnpairedImageDataset
__all__ = [
@@ -19,4 +20,5 @@
'ImageNet',
'CIFAR10',
'GrowScaleImgDataset',
+ 'SinGANDataset',
]
diff --git a/mmedit/datasets/singan_dataset.py b/mmedit/datasets/singan_dataset.py
new file mode 100644
index 0000000000..e03d4567b8
--- /dev/null
+++ b/mmedit/datasets/singan_dataset.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import mmcv
+import numpy as np
+from mmengine.dataset import BaseDataset
+
+from mmedit.registry import DATASETS
+
+
+def create_real_pyramid(real, min_size, max_size, scale_factor_init):
+ """Create image pyramid.
+
+ This function is modified from the official implementation:
+ https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221
+
+    In this implementation, we adopt the rescaling function from MMCV.
+
+    Args:
+        real (np.ndarray): The real image array.
+        min_size (int): The minimum size for the image pyramid.
+        max_size (int): The maximum size for the image pyramid.
+        scale_factor_init (float): The initial scale factor.
+
+    Returns:
+        Tuple[list, float, int]: The image pyramid (from the coarsest to the
+            finest scale), the actual scale factor used and the stop scale.
+    """
+
+ num_scales = int(
+ np.ceil(
+ np.log(np.power(min_size / min(real.shape[0], real.shape[1]), 1)) /
+ np.log(scale_factor_init))) + 1
+
+ scale2stop = int(
+ np.ceil(
+ np.log(
+ min([max_size, max([real.shape[0], real.shape[1]])]) /
+ max([real.shape[0], real.shape[1]])) /
+ np.log(scale_factor_init)))
+
+ stop_scale = num_scales - scale2stop
+
+ scale1 = min(max_size / max([real.shape[0], real.shape[1]]), 1)
+ real_max = mmcv.imrescale(real, scale1)
+ scale_factor = np.power(
+ min_size / (min(real_max.shape[0], real_max.shape[1])),
+ 1 / (stop_scale))
+
+ scale2stop = int(
+ np.ceil(
+ np.log(
+ min([max_size, max([real.shape[0], real.shape[1]])]) /
+ max([real.shape[0], real.shape[1]])) /
+ np.log(scale_factor_init)))
+ stop_scale = num_scales - scale2stop
+
+ reals = []
+ for i in range(stop_scale + 1):
+ scale = np.power(scale_factor, stop_scale - i)
+ curr_real = mmcv.imrescale(real, scale)
+ reals.append(curr_real)
+
+ return reals, scale_factor, stop_scale
+
+
+@DATASETS.register_module()
+class SinGANDataset(BaseDataset):
+ """SinGAN Dataset.
+
+ In this dataset, we create an image pyramid and save it in the cache.
+
+ Args:
+        data_root (str): Path to the single training image file.
+ min_size (int): Min size of the image pyramid. Here, the number will be
+ set to the ``min(H, W)``.
+ max_size (int): Max size of the image pyramid. Here, the number will be
+ set to the ``max(H, W)``.
+ scale_factor_init (float): Rescale factor. Note that the actual factor
+ we use may be a little bit different from this value.
+        pipeline (list[dict]): A sequence of data transforms.
+        num_samples (int, optional): The number of samples (length) in this
+            dataset. Defaults to -1.
+ """
+
+ def __init__(self,
+ data_root,
+ min_size,
+ max_size,
+ scale_factor_init,
+ pipeline,
+ num_samples=-1):
+ self.min_size = min_size
+ self.max_size = max_size
+ self.scale_factor_init = scale_factor_init
+ self.num_samples = num_samples
+ super().__init__(data_root=data_root, pipeline=pipeline)
+
+ def full_init(self):
+ """Skip the full init process for SinGANDataset."""
+
+ self.load_data_list(self.min_size, self.max_size,
+ self.scale_factor_init)
+
+ def load_data_list(self, min_size, max_size, scale_factor_init):
+ """Load annatations for SinGAN Dataset.
+
+ Args:
+ min_size (int): The minimum size for the image pyramid.
+ max_size (int): The maximum size for the image pyramid.
+ scale_factor_init (float): The initial scale factor.
+ """
+ real = mmcv.imread(self.data_root)
+ self.reals, self.scale_factor, self.stop_scale = create_real_pyramid(
+ real, min_size, max_size, scale_factor_init)
+
+ self.data_dict = {}
+
+ for i, real in enumerate(self.reals):
+ self.data_dict[f'real_scale{i}'] = real
+
+ self.data_dict['input_sample'] = np.zeros_like(
+ self.data_dict['real_scale0']).astype(np.float32)
+
+ def __getitem__(self, index):
+ """Get `:attr:self.data_dict`. For SinGAN, we use single image with
+ different resolution to train the model.
+
+ Args:
+ idx (int): This will be ignored in `:class:SinGANDataset`.
+
+ Returns:
+ dict: Dict contains input image in different resolution.
+ ``self.pipeline``.
+ """
+ return self.pipeline(deepcopy(self.data_dict))
+
+ def __len__(self):
+ """Get the length of filtered dataset and automatically call
+ ``full_init`` if the dataset has not been fully init.
+
+ Returns:
+ int: The length of filtered dataset.
+ """
+ return int(1e6) if self.num_samples < 0 else self.num_samples
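
As a rough sketch of how the new dataset could be wired into a config, the snippet below follows the constructor signature above; the image path, pyramid sizes, sampler and pipeline are illustrative assumptions rather than values taken from this diff.

```python
# Illustrative dataloader config for SinGANDataset (all values are assumed).
train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    sampler=dict(type='InfiniteSampler', shuffle=False),
    dataset=dict(
        type='SinGANDataset',
        data_root='./data/singan/balloons.png',  # the single training image
        min_size=25,
        max_size=250,
        scale_factor_init=0.75,
        pipeline=[dict(type='PackEditInputs', pack_all=True)],
        num_samples=-1))
```

Because `load_data_list` caches the whole pyramid, every `__getitem__` call returns the same multi-resolution dict after being processed by the pipeline.
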
diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py
index f5eb2a02c6..39543085f1 100644
--- a/mmedit/datasets/transforms/__init__.py
+++ b/mmedit/datasets/transforms/__init__.py
@@ -6,8 +6,9 @@
from .aug_shape import (Flip, NumpyPad, RandomRotation, RandomTransposeHW,
Resize)
from .crop import (CenterCropLongEdge, Crop, CropAroundCenter, CropAroundFg,
- CropAroundUnknown, CropLike, FixedCrop, ModCrop,
- PairedRandomCrop, RandomCropLongEdge, RandomResizedCrop)
+ CropAroundUnknown, CropLike, FixedCrop, InstanceCrop,
+ ModCrop, PairedRandomCrop, RandomCropLongEdge,
+ RandomResizedCrop)
from .fgbg import (CompositeFg, MergeFgAndBg, PerturbBg, RandomJitter,
RandomLoadResizeBg)
from .formatting import PackEditInputs, ToTensor
@@ -45,5 +46,5 @@
'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap',
'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg',
'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile',
- 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad'
+ 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop'
]
diff --git a/mmedit/datasets/transforms/aug_shape.py b/mmedit/datasets/transforms/aug_shape.py
index 6e62c3a4a9..fb50fe1134 100644
--- a/mmedit/datasets/transforms/aug_shape.py
+++ b/mmedit/datasets/transforms/aug_shape.py
@@ -319,24 +319,31 @@ def _resize(self, img):
Returns:
img (np.ndarray): Resized image.
"""
-
- if self.keep_ratio:
- img, self.scale_factor = mmcv.imrescale(
- img,
- self.scale,
- return_scale=True,
- interpolation=self.interpolation,
- backend=self.backend)
+ if isinstance(img, list):
+ for i, image in enumerate(img):
+ size, img[i] = self._resize(image)
+ return size, img
else:
- img, w_scale, h_scale = mmcv.imresize(
- img,
- self.scale,
- return_scale=True,
- interpolation=self.interpolation,
- backend=self.backend)
- self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
-
- return img
+ if self.keep_ratio:
+ img, self.scale_factor = mmcv.imrescale(
+ img,
+ self.scale,
+ return_scale=True,
+ interpolation=self.interpolation,
+ backend=self.backend)
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ img,
+ self.scale,
+ return_scale=True,
+ interpolation=self.interpolation,
+ backend=self.backend)
+ self.scale_factor = np.array((w_scale, h_scale),
+ dtype=np.float32)
+
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, axis=2)
+ return img.shape, img
def transform(self, results: Dict) -> Dict:
"""Transform function to resize images.
@@ -358,11 +365,11 @@ def transform(self, results: Dict) -> Dict:
new_w = min(self.max_size - (self.max_size % self.size_factor),
new_w)
self.scale = (new_w, new_h)
+
for key, out_key in zip(self.keys, self.output_keys):
- results[out_key] = self._resize(results[key])
- if len(results[out_key].shape) == 2:
- results[out_key] = np.expand_dims(results[out_key], axis=2)
- results[f'{out_key}_shape'] = results[out_key].shape
+ if key in results:
+ size, results[out_key] = self._resize(results[key])
+ results[f'{out_key}_shape'] = size
results['scale_factor'] = self.scale_factor
results['keep_ratio'] = self.keep_ratio
diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py
index 906955254f..8554b72018 100644
--- a/mmedit/datasets/transforms/crop.py
+++ b/mmedit/datasets/transforms/crop.py
@@ -2,14 +2,19 @@
import math
import random
+import cv2 as cv
import mmcv
import numpy as np
+import torch
from mmcv.transforms import BaseTransform
+from mmdet.apis import inference_detector, init_detector
+from mmengine.hub import get_config
+from mmengine.registry import DefaultScope
from mmengine.utils import is_list_of, is_tuple_of
from torch.nn.modules.utils import _pair
from mmedit.registry import TRANSFORMS
-from mmedit.utils import random_choose_unknown
+from mmedit.utils import get_box_info, random_choose_unknown
@TRANSFORMS.register_module()
@@ -916,3 +921,108 @@ def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(keys={self.keys})')
return repr_str
+
+
+@TRANSFORMS.register_module()
+class InstanceCrop(BaseTransform):
+ """Use maskrcnn to detect instances on image.
+
+ Mask R-CNN is used to detect the instance on the image
+ pred_bbox is used to segment the instance on the image
+
+ Args:
+ config_file (str): config file name relative to detectron2's "configs/"
+ key (str): Unused
+ box_num_upbound (int):The upper limit on the number of instances
+ in the figure
+ """
+
+ def __init__(self,
+ config_file,
+ key='img',
+ box_num_upbound=-1,
+ finesize=256):
+
+ cfg = get_config(config_file, pretrained=True)
+ with DefaultScope.overwrite_default_scope('mmdet'):
+ self.predictor = init_detector(cfg, cfg.model_path)
+
+ self.key = key
+ self.box_num_upbound = box_num_upbound
+ self.final_size = finesize
+
+ def transform(self, results: dict) -> dict:
+ """The transform function of InstanceCrop.
+
+ Args:
+            results (dict): A dict containing the necessary information and
+                data for conversion.
+
+ Returns:
+ results (dict): A dict containing the processed data
+ and information.
+ """
+ # get consistent box prediction based on L channel
+
+ full_img = results['img']
+ full_img_size = results['ori_img_shape'][:-1][::-1]
+ pred_bbox, pred_scores = self.predict_bbox(full_img)
+
+ if self.box_num_upbound > 0 and pred_bbox.shape[
+ 0] > self.box_num_upbound:
+ index_mask = np.argsort(pred_scores, axis=0)
+ index_mask = index_mask[pred_scores.shape[0] -
+ self.box_num_upbound:pred_scores.shape[0]]
+ pred_bbox = pred_bbox[index_mask]
+
+ # get cropped images and box info
+ cropped_img_list = []
+ index_list = range(len(pred_bbox))
+ box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros(
+ (4, len(index_list), 6))
+ for i in index_list:
+ startx, starty, endx, endy = pred_bbox[i]
+ cropped_img = full_img[starty:endy, startx:endx, :]
+ cropped_img_list.append(cropped_img)
+ box_info[i] = np.array(
+ get_box_info(pred_bbox[i], full_img_size, self.final_size))
+ box_info_2x[i] = np.array(
+ get_box_info(pred_bbox[i], full_img_size,
+ self.final_size // 2))
+ box_info_4x[i] = np.array(
+ get_box_info(pred_bbox[i], full_img_size,
+ self.final_size // 4))
+ box_info_8x[i] = np.array(
+ get_box_info(pred_bbox[i], full_img_size,
+ self.final_size // 8))
+
+ # update results
+ if len(pred_bbox) > 0:
+ results['cropped_img'] = cropped_img_list
+ results['box_info'] = torch.from_numpy(box_info).type(torch.long)
+ results['box_info_2x'] = torch.from_numpy(box_info_2x).type(
+ torch.long)
+ results['box_info_4x'] = torch.from_numpy(box_info_4x).type(
+ torch.long)
+ results['box_info_8x'] = torch.from_numpy(box_info_8x).type(
+ torch.long)
+ results['empty_box'] = False
+ else:
+ results['empty_box'] = True
+ return results
+
+ def predict_bbox(self, image):
+ lab_image = cv.cvtColor(image, cv.COLOR_BGR2LAB)
+ l_channel, _, _ = cv.split(lab_image)
+ l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
+
+ with DefaultScope.overwrite_default_scope('mmdet'):
+ with torch.no_grad():
+ results = inference_detector(self.predictor, l_stack)
+
+ bboxes = results.pred_instances.bboxes.cpu().numpy().astype(np.int32)
+ scores = results.pred_instances.scores.cpu().numpy()
+ index_mask = [i for i, x in enumerate(scores) if x >= 0.7]
+ scores = np.array(scores[index_mask])
+ bboxes = np.array(bboxes[index_mask])
+ return bboxes, scores
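
A hedged example of how `InstanceCrop` might appear in a data pipeline is shown below; the detector config name is an assumed placeholder to be resolved through `mmengine.hub.get_config`, not a value taken from this diff.

```python
# Sketch only: the mmdet config name below is a hypothetical placeholder.
train_pipeline = [
    dict(type='LoadImageFromFile', key='img'),
    dict(
        type='InstanceCrop',
        config_file='mmdet::mask_rcnn/some_mask_rcnn_config.py',  # assumed
        box_num_upbound=8,     # keep at most the 8 highest-scoring boxes
        finesize=256),
    dict(type='PackEditInputs', keys=['img']),
]
```
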
diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py
index f271c31bca..5df741884d 100644
--- a/mmedit/datasets/transforms/formatting.py
+++ b/mmedit/datasets/transforms/formatting.py
@@ -76,6 +76,27 @@ def images_to_tensor(value):
return tensor
+def can_convert_to_image(value):
+ """Judge whether the input value can be converted to image tensor via
+ :func:`images_to_tensor` function.
+
+ Args:
+ value (any): The input value.
+
+ Returns:
+        bool: Whether the input value can be converted to an image tensor
+            via :func:`images_to_tensor`.
+ """
+ if isinstance(value, (List, Tuple)):
+ return all([can_convert_to_image(v) for v in value])
+ elif isinstance(value, np.ndarray):
+ return True
+ elif isinstance(value, torch.Tensor):
+ return True
+ else:
+ return False
+
+
@TRANSFORMS.register_module()
class PackEditInputs(BaseTransform):
"""Pack the inputs data for SR, VFI, matting and inpainting.
@@ -83,11 +104,17 @@ class PackEditInputs(BaseTransform):
Keys for images include ``img``, ``gt``, ``ref``, ``mask``, ``gt_heatmap``,
``trimap``, ``gt_alpha``, ``gt_fg``, ``gt_bg``. All of them will be
packed into data field of EditDataSample.
+    pack_all (bool): Whether to pack all variables in ``results`` into the
+        ``inputs`` dict. This is useful when the keys of the input dict are
+        not fixed. Please be careful when using this option: only values
+        that can be converted to image tensors are packed, and the rest are
+        kept as metainfo. Defaults to False.
Others will be packed into metainfo field of EditDataSample.
"""
- def __init__(self, keys: Tuple[List[str], str, None] = None):
+ def __init__(self,
+ keys: Tuple[List[str], str, None] = None,
+ pack_all: bool = False):
if keys is not None:
if isinstance(keys, list):
self.keys = keys
@@ -95,6 +122,7 @@ def __init__(self, keys: Tuple[List[str], str, None] = None):
self.keys = [keys]
else:
self.keys = None
+ self.pack_all = pack_all
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
@@ -113,14 +141,14 @@ def transform(self, results: dict) -> dict:
packed_results = dict()
data_sample = EditDataSample()
- if self.keys is not None:
+ pack_keys = [k for k in results.keys()] if self.pack_all else self.keys
+ if pack_keys is not None:
packed_results['inputs'] = dict()
- for key in self.keys:
- img = results.pop(key)
- if len(img.shape) < 3:
- img = np.expand_dims(img, -1)
- img = np.ascontiguousarray(img.transpose(2, 0, 1))
- packed_results['inputs'][key] = to_tensor(img)
+ for key in pack_keys:
+ val = results[key]
+ if can_convert_to_image(val):
+ packed_results['inputs'][key] = images_to_tensor(val)
+ results.pop(key)
elif 'img' in results:
img = results.pop('img')
@@ -201,6 +229,21 @@ def transform(self, results: dict) -> dict:
gt_bg_tensor = images_to_tensor(gt_bg)
data_sample.gt_bg = PixelData(data=gt_bg_tensor)
+ if 'rgb_img' in results:
+ gt_rgb = results.pop('rgb_img')
+ gt_rgb_tensor = images_to_tensor(gt_rgb)
+ data_sample.gt_rgb = PixelData(data=gt_rgb_tensor)
+
+ if 'gray_img' in results:
+ gray = results.pop('gray_img')
+ gray_tensor = images_to_tensor(gray)
+ data_sample.gray = PixelData(data=gray_tensor)
+
+ if 'cropped_img' in results:
+ cropped_img = results.pop('cropped_img')
+ cropped_img = images_to_tensor(cropped_img)
+ data_sample.cropped_img = PixelData(data=cropped_img)
+
metainfo = dict()
for key in results:
metainfo[key] = results[key]
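
The `pack_all` option can be exercised roughly as follows. This is a sketch assuming the packed tensors are returned under `inputs` and that non-image entries are kept as metainfo, as the loop above suggests.

```python
# Sketch of pack_all=True: array-like entries go to `inputs`, the rest is
# treated as metainfo.
import numpy as np
from mmedit.datasets.transforms import PackEditInputs

results = {
    'real_scale0': np.zeros((25, 32, 3), dtype=np.float32),
    'real_scale1': np.zeros((33, 43, 3), dtype=np.float32),
    'num_scales': 2,  # not convertible to an image tensor
}
packed = PackEditInputs(pack_all=True).transform(results)
print(sorted(packed['inputs']))  # ['real_scale0', 'real_scale1']
```
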
diff --git a/mmedit/datasets/transforms/generate_assistant.py b/mmedit/datasets/transforms/generate_assistant.py
index 34a96bed54..202f95fcce 100644
--- a/mmedit/datasets/transforms/generate_assistant.py
+++ b/mmedit/datasets/transforms/generate_assistant.py
@@ -190,7 +190,7 @@ def transform(self, results):
else:
heatmap = self.generate_heatmap_from_img(img)
- results[f'{self.image_key}_heatmap'] = heatmap
+ results[f'{self.image_key}_heatmap'] = heatmap.astype(np.float32)
return results
diff --git a/mmedit/datasets/transforms/generate_frame_indices.py b/mmedit/datasets/transforms/generate_frame_indices.py
index b665b509d9..705f325c58 100644
--- a/mmedit/datasets/transforms/generate_frame_indices.py
+++ b/mmedit/datasets/transforms/generate_frame_indices.py
@@ -260,6 +260,8 @@ def transform(self, results):
self.sequence_length = results['sequence_length']
num_input_frames = results.get('num_input_frames',
self.sequence_length)
+ if num_input_frames is None:
+ num_input_frames = self.sequence_length
# randomly select a frame as start
if self.sequence_length - num_input_frames * interval < 0:
diff --git a/mmedit/datasets/unpaired_image_dataset.py b/mmedit/datasets/unpaired_image_dataset.py
index f18ccc0180..204c4a29e8 100644
--- a/mmedit/datasets/unpaired_image_dataset.py
+++ b/mmedit/datasets/unpaired_image_dataset.py
@@ -104,7 +104,7 @@ def get_data_info(self, idx) -> dict:
dict: The idx-th annotation of the dataset.
"""
img_a_path = self.data_infos_a[idx % self.len_a]['path']
- if self.test_mode:
+ if not self.test_mode:
idx_b = np.random.randint(0, self.len_b)
img_b_path = self.data_infos_b[idx_b]['path']
else:
diff --git a/mmedit/evaluation/evaluator.py b/mmedit/evaluation/evaluator.py
index a7b297db06..f51d945f5b 100644
--- a/mmedit/evaluation/evaluator.py
+++ b/mmedit/evaluation/evaluator.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import hashlib
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
@@ -8,6 +9,7 @@
from mmedit.registry import EVALUATORS
from mmedit.structures import EditDataSample
+from .metrics.base_gen_metric import GenMetric
@EVALUATORS.register_module()
@@ -70,6 +72,19 @@ def prepare_metrics(self, module: BaseModel, dataloader: DataLoader):
metric.prepare(module, dataloader)
self.is_ready = True
+ @staticmethod
+ def _cal_metric_hash(metric: GenMetric):
+ """Calculate a unique hash value based on the `SAMPLER_MODE` and
+ `sample_model`."""
+ sampler_mode = metric.SAMPLER_MODE
+ sample_model = metric.sample_model
+ metric_dict = {
+ 'SAMPLER_MODE': sampler_mode,
+ 'sample_model': sample_model
+ }
+ md5 = hashlib.md5(repr(metric_dict).encode('utf-8')).hexdigest()
+ return md5
+
def prepare_samplers(self, module: BaseModel, dataloader: DataLoader
) -> List[Tuple[List[BaseMetric], Iterator]]:
"""Prepare for the sampler for metrics whose sampling mode are
@@ -91,11 +106,11 @@ def prepare_samplers(self, module: BaseModel, dataloader: DataLoader
List[Tuple[List[BaseMetric], Iterator]]: A list of "metrics-shared
sampler" pair.
"""
-
- # grouping metrics based on `SAMPLER_MODE`.
+        # grouping metrics based on `SAMPLER_MODE` and `sample_model`
metric_mode_dict = defaultdict(list)
for metric in self.metrics:
- metric_mode_dict[metric.SAMPLER_MODE].append(metric)
+ metric_md5 = self._cal_metric_hash(metric)
+ metric_mode_dict[metric_md5].append(metric)
metrics_sampler_list = []
for metrics in metric_mode_dict.values():
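
The effect of the new grouping key can be illustrated with a small self-contained sketch that mirrors `_cal_metric_hash`: metrics sharing a `SAMPLER_MODE` but differing in `sample_model` now land in different groups, so each group gets its own shared sampler.

```python
# Mirrors the hashing logic above to show why metrics with different
# `sample_model` values no longer share a sampler.
import hashlib

def metric_hash(sampler_mode: str, sample_model: str) -> str:
    metric_dict = {'SAMPLER_MODE': sampler_mode, 'sample_model': sample_model}
    return hashlib.md5(repr(metric_dict).encode('utf-8')).hexdigest()

print(metric_hash('normal', 'ema') == metric_hash('normal', 'ema'))   # True
print(metric_hash('normal', 'ema') == metric_hash('normal', 'orig'))  # False
```
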
diff --git a/mmedit/evaluation/metrics/__init__.py b/mmedit/evaluation/metrics/__init__.py
index d89a216e07..e6a0aa3e01 100644
--- a/mmedit/evaluation/metrics/__init__.py
+++ b/mmedit/evaluation/metrics/__init__.py
@@ -1,40 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .connectivity_error import ConnectivityError
from .equivariance import Equivariance
from .fid import FrechetInceptionDistance, TransFID
+from .gradient_error import GradientError
from .inception_score import InceptionScore, TransIS
-from .matting import SAD, ConnectivityError, GradientError, MattingMSE
+from .mae import MAE
+from .matting_mse import MattingMSE
from .ms_ssim import MultiScaleStructureSimilarity
+from .mse import MSE
from .niqe import NIQE, niqe
-from .pixel_metrics import MAE, MSE, PSNR, SNR, psnr, snr
from .ppl import PerceptualPathLength
from .precision_and_recall import PrecisionAndRecall
+from .psnr import PSNR, psnr
+from .sad import SAD
+from .snr import SNR, snr
from .ssim import SSIM, ssim
from .swd import SlicedWassersteinDistance
__all__ = [
- 'ConnectivityError',
- 'GradientError',
'MAE',
- 'MattingMSE',
'MSE',
- 'NIQE',
- 'niqe',
'PSNR',
'psnr',
- 'SAD',
'SNR',
'snr',
'SSIM',
'ssim',
- 'Equivariance',
+ 'MultiScaleStructureSimilarity',
'FrechetInceptionDistance',
+ 'TransFID',
'InceptionScore',
- 'MultiScaleStructureSimilarity',
+ 'TransIS',
+ 'SAD',
+ 'MattingMSE',
+ 'ConnectivityError',
+ 'GradientError',
'PerceptualPathLength',
- 'MultiScaleStructureSimilarity',
'PrecisionAndRecall',
'SlicedWassersteinDistance',
- 'TransFID',
- 'TransIS',
+ 'NIQE',
+ 'niqe',
+ 'Equivariance',
]
diff --git a/mmedit/evaluation/metrics/connectivity_error.py b/mmedit/evaluation/metrics/connectivity_error.py
new file mode 100644
index 0000000000..507359db90
--- /dev/null
+++ b/mmedit/evaluation/metrics/connectivity_error.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Evaluation metrics used in Image Matting."""
+
+from typing import List, Sequence
+
+import cv2
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class ConnectivityError(BaseMetric):
+ """Connectivity error for evaluating alpha matte prediction.
+
+ .. note::
+
+        The current implementation assumes image / alpha / trimap arrays are
+        in numpy format with pixel values ranging from 0 to 255.
+
+ .. note::
+
+ pred_alpha should be masked by trimap before passing
+ into this metric
+
+ Args:
+        step (float): Step of threshold when computing intersection between
+            `alpha` and `pred_alpha`. Defaults to 0.1.
+        norm_constant (int): Divide the result to reduce its magnitude.
+            Defaults to 1000.
+
+ Default prefix: ''
+
+ Metrics:
+ - ConnectivityError (float): Connectivity Error
+ """
+
+ def __init__(
+ self,
+ step=0.1,
+ norm_constant=1000,
+ **kwargs,
+ ) -> None:
+ self.step = step
+ self.norm_constant = norm_constant
+ super().__init__(**kwargs)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
+ """
+
+ for data_sample in data_samples:
+ pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
+
+ thresh_steps = np.arange(0, 1 + self.step, self.step)
+ round_down_map = -np.ones_like(gt_alpha)
+ for i in range(1, len(thresh_steps)):
+ gt_alpha_thresh = gt_alpha >= thresh_steps[i]
+ pred_alpha_thresh = pred_alpha >= thresh_steps[i]
+ intersection = gt_alpha_thresh & pred_alpha_thresh
+ intersection = intersection.astype(np.uint8)
+
+ # connected components
+ _, output, stats, _ = cv2.connectedComponentsWithStats(
+ intersection, connectivity=4)
+ # start from 1 in dim 0 to exclude background
+ size = stats[1:, -1]
+
+ # largest connected component of the intersection
+ omega = np.zeros_like(gt_alpha)
+ if len(size) != 0:
+ max_id = np.argmax(size)
+ # plus one to include background
+ omega[output == max_id + 1] = 1
+
+ mask = (round_down_map == -1) & (omega == 0)
+ round_down_map[mask] = thresh_steps[i - 1]
+ round_down_map[round_down_map == -1] = 1
+
+ gt_alpha_diff = gt_alpha - round_down_map
+ pred_alpha_diff = pred_alpha - round_down_map
+ # only calculate difference larger than or equal to 0.15
+ gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
+ pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)
+
+ connectivity_error = np.sum(
+ np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128))
+
+ # divide by 1000 to reduce the magnitude of the result
+ connectivity_error /= self.norm_constant
+
+ self.results.append({'conn_err': connectivity_error})
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (dict): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+
+ conn_err = average(results, 'conn_err')
+
+ return {'ConnectivityError': conn_err}
diff --git a/mmedit/evaluation/metrics/gradient_error.py b/mmedit/evaluation/metrics/gradient_error.py
new file mode 100644
index 0000000000..de5a15dccc
--- /dev/null
+++ b/mmedit/evaluation/metrics/gradient_error.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence
+
+import cv2
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from ..functional import gauss_gradient
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class GradientError(BaseMetric):
+ """Gradient error for evaluating alpha matte prediction.
+
+ .. note::
+
+        The current implementation assumes image / alpha / trimap arrays are
+        in numpy format with pixel values ranging from 0 to 255.
+
+ .. note::
+
+ pred_alpha should be masked by trimap before passing
+ into this metric
+
+ Args:
+        sigma (float): Standard deviation of the gaussian kernel.
+            Defaults to 1.4.
+        norm_constant (int): Divide the result to reduce its magnitude.
+            Defaults to 1000.
+
+ Default prefix: ''
+
+ Metrics:
+ - GradientError (float): Gradient Error
+ """
+
+ def __init__(
+ self,
+ sigma=1.4,
+ norm_constant=1000,
+ **kwargs,
+ ) -> None:
+ self.sigma = sigma
+ self.norm_constant = norm_constant
+ super().__init__(**kwargs)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
+ """
+
+ for data_sample in data_samples:
+ pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
+
+ gt_alpha_normed = np.zeros_like(gt_alpha)
+ pred_alpha_normed = np.zeros_like(pred_alpha)
+
+ cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
+ cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
+ cv2.NORM_MINMAX)
+
+ gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
+ pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
+ # this is the sum over n samples
+ grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
+ (trimap == 128)).sum()
+
+ # divide by 1000 to reduce the magnitude of the result
+ grad_loss /= self.norm_constant
+
+ self.results.append({'grad_err': grad_loss})
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (dict): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+
+ grad_err = average(results, 'grad_err')
+
+ return {'GradientError': grad_err}
diff --git a/mmedit/evaluation/metrics/mae.py b/mmedit/evaluation/metrics/mae.py
new file mode 100644
index 0000000000..0acd972a63
--- /dev/null
+++ b/mmedit/evaluation/metrics/mae.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Evaluation metrics based on pixels."""
+
+import numpy as np
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+
+
+@METRICS.register_module()
+class MAE(BaseSampleWiseMetric):
+ """Mean Absolute Error metric for image.
+
+ mean(abs(a-b))
+
+ Args:
+
+ gt_key (str): Key of ground-truth. Default: 'gt_img'
+ pred_key (str): Key of prediction. Default: 'pred_img'
+ mask_key (str, optional): Key of mask, if mask_key is None, calculate
+ all regions. Default: None
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None
+
+ Metrics:
+ - MAE (float): Mean of Absolute Error
+ """
+
+ metric = 'MAE'
+
+ def process_image(self, gt, pred, mask):
+ """Process an image.
+
+ Args:
+ gt (Tensor | np.ndarray): GT image.
+ pred (Tensor | np.ndarray): Pred image.
+ mask (Tensor | np.ndarray): Mask of evaluation.
+ Returns:
+ result (np.ndarray): MAE result.
+ """
+
+ gt = gt / 255.
+ pred = pred / 255.
+
+ diff = gt - pred
+ diff = abs(diff)
+
+ if self.mask_key is not None:
+ diff *= mask # broadcast for channel dimension
+ scale = np.prod(diff.shape) / np.prod(mask.shape)
+ result = diff.sum() / (mask.sum() * scale + 1e-12)
+ else:
+ result = diff.mean()
+
+ return result
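
The masked branch above can be sanity-checked with plain NumPy; the snippet reproduces the arithmetic of `process_image` on toy data, where `scale` compensates for the mask being broadcast over the channel dimension.

```python
# Reproduces the masked MAE arithmetic above on toy data.
import numpy as np

gt = np.full((4, 4, 3), 200, dtype=np.float64)
pred = np.full((4, 4, 3), 180, dtype=np.float64)
mask = np.zeros((4, 4, 1))
mask[:2] = 1  # evaluate only the top half of the image

diff = np.abs(gt / 255. - pred / 255.) * mask
scale = np.prod(diff.shape) / np.prod(mask.shape)  # channel broadcast factor
mae = diff.sum() / (mask.sum() * scale + 1e-12)
print(round(mae, 6))  # 0.078431, i.e. 20 / 255 over the masked region
```
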
diff --git a/mmedit/evaluation/metrics/matting.py b/mmedit/evaluation/metrics/matting.py
deleted file mode 100644
index ba4c6e30f9..0000000000
--- a/mmedit/evaluation/metrics/matting.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-"""Evaluation metrics used in Image Matting."""
-
-from typing import List, Sequence
-
-import cv2
-import numpy as np
-from mmengine.evaluator import BaseMetric
-
-from mmedit.registry import METRICS
-from ..functional import gauss_gradient
-from .metrics_utils import average
-
-
-def _assert_ndim(input, name, ndim, shape_hint):
- if input.ndim != ndim:
- raise ValueError(
- f'{name} should be of shape {shape_hint}, but got {input.shape}.')
-
-
-def _assert_masked(pred_alpha, trimap):
- if (pred_alpha[trimap == 0] != 0).any() or (pred_alpha[trimap == 255] !=
- 255).any():
- raise ValueError(
- 'pred_alpha should be masked by trimap before evaluation')
-
-
-def _fetch_data_and_check(data_samples):
- """Fetch and check data from one item of data_batch and predictions.
-
- Args:
- data_batch (dict): One item of data_batch.
- predictions (dict): One item of predictions.
-
- Returns:
- pred_alpha (Tensor): Pred_alpha data of predictions.
- ori_alpha (Tensor): Ori_alpha data of data_batch.
- ori_trimap (Tensor): Ori_trimap data of data_batch.
- """
- ori_trimap = data_samples['ori_trimap'][:, :, 0]
- ori_alpha = data_samples['ori_alpha'][:, :, 0]
- pred_alpha = data_samples['output']['pred_alpha']['data'] # 2D tensor
- pred_alpha = pred_alpha.cpu().numpy()
-
- _assert_ndim(ori_trimap, 'trimap', 2, 'HxW')
- _assert_ndim(ori_alpha, 'gt_alpha', 2, 'HxW')
- _assert_ndim(pred_alpha, 'pred_alpha', 2, 'HxW')
- _assert_masked(pred_alpha, ori_trimap)
-
- # dtype uint8 -> float64
- pred_alpha = pred_alpha / 255.0
- ori_alpha = ori_alpha / 255.0
- # test shows that using float32 vs float64 differs final results at 1e-4
- # speed are comparable, so we choose float64 for accuracy
-
- return pred_alpha, ori_alpha, ori_trimap
-
-
-@METRICS.register_module()
-class SAD(BaseMetric):
- """Sum of Absolute Differences metric for image matting.
-
- This metric compute per-pixel absolute difference and sum across all
- pixels.
- i.e. sum(abs(a-b)) / norm_const
-
- .. note::
-
- Current implementation assume image / alpha / trimap array in numpy
- format and with pixel value ranging from 0 to 255.
-
- .. note::
-
- pred_alpha should be masked by trimap before passing
- into this metric
-
- Default prefix: ''
-
- Args:
- norm_const (int): Divide the result to reduce its magnitude.
- Default to 1000.
-
- Metrics:
- - SAD (float): Sum of Absolute Differences
- """
-
- default_prefix = ''
-
- def __init__(
- self,
- norm_const=1000,
- **kwargs,
- ) -> None:
- self.norm_const = norm_const
- super().__init__(**kwargs)
-
- def process(self, data_batch: Sequence[dict],
- data_samples: Sequence[dict]) -> None:
- """Process one batch of data and predictions.
-
- Args:
- data_batch (Sequence[Tuple[Any, dict]]): A batch of data
- from the dataloader.
- predictions (Sequence[dict]): A batch of outputs from
- the model.
- """
- for data_sample in data_samples:
- pred_alpha, gt_alpha, _ = _fetch_data_and_check(data_sample)
-
- # divide by 1000 to reduce the magnitude of the result
- sad_sum = np.abs(pred_alpha - gt_alpha).sum() / self.norm_const
-
- result = {'sad': sad_sum}
-
- self.results.append(result)
-
- def compute_metrics(self, results: List):
- """Compute the metrics from processed results.
-
- Args:
- results (dict): The processed results of each batch.
-
- Returns:
- Dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
-
- sad = average(results, 'sad')
-
- return {'SAD': sad}
-
-
-@METRICS.register_module()
-class MattingMSE(BaseMetric):
- """Mean Squared Error metric for image matting.
-
- This metric compute per-pixel squared error average across all
- pixels.
- i.e. mean((a-b)^2) / norm_const
-
- .. note::
-
- Current implementation assume image / alpha / trimap array in numpy
- format and with pixel value ranging from 0 to 255.
-
- .. note::
-
- pred_alpha should be masked by trimap before passing
- into this metric
-
- Default prefix: ''
-
- Args:
- norm_const (int): Divide the result to reduce its magnitude.
- Default to 1000.
-
- Metrics:
- - MattingMSE (float): Mean of Squared Error
- """
-
- default_prefix = ''
-
- def __init__(
- self,
- norm_const=1000,
- **kwargs,
- ) -> None:
- self.norm_const = norm_const
- super().__init__(**kwargs)
-
- def process(self, data_batch: Sequence[dict],
- data_samples: Sequence[dict]) -> None:
- """Process one batch of data and predictions.
-
- Args:
- data_batch (Sequence[dict]): A batch of data
- from the dataloader.
- data_samples (Sequence[dict]): A batch of outputs from
- the model.
- """
- for data_sample in data_samples:
- pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
-
- weight_sum = (trimap == 128).sum()
- if weight_sum != 0:
- mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum
- else:
- mse_result = 0
-
- self.results.append({'mse': mse_result})
-
- def compute_metrics(self, results: List):
- """Compute the metrics from processed results.
-
- Args:
- results (dict): The processed results of each batch.
-
- Returns:
- Dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
-
- mse = average(results, 'mse')
-
- return {'MattingMSE': mse}
-
-
-@METRICS.register_module()
-class GradientError(BaseMetric):
- """Gradient error for evaluating alpha matte prediction.
-
- .. note::
-
- Current implementation assume image / alpha / trimap array in numpy
- format and with pixel value ranging from 0 to 255.
-
- .. note::
-
- pred_alpha should be masked by trimap before passing
- into this metric
-
- Args:
- sigma (float): Standard deviation of the gaussian kernel.
- Defaults to 1.4 .
- norm_const (int): Divide the result to reduce its magnitude.
- Defaults to 1000 .
-
- Default prefix: ''
-
- Metrics:
- - GradientError (float): Gradient Error
- """
-
- def __init__(
- self,
- sigma=1.4,
- norm_constant=1000,
- **kwargs,
- ) -> None:
- self.sigma = sigma
- self.norm_constant = norm_constant
- super().__init__(**kwargs)
-
- def process(self, data_batch: Sequence[dict],
- data_samples: Sequence[dict]) -> None:
- """Process one batch of data samples and predictions. The processed
- results should be stored in ``self.results``, which will be used to
- compute the metrics when all batches have been processed.
-
- Args:
- data_batch (Sequence[dict]): A batch of data from the dataloader.
- predictions (Sequence[dict]): A batch of outputs from
- the model.
- """
-
- for data_sample in data_samples:
- pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
-
- gt_alpha_normed = np.zeros_like(gt_alpha)
- pred_alpha_normed = np.zeros_like(pred_alpha)
-
- cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
- cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
- cv2.NORM_MINMAX)
-
- gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
- pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
- # this is the sum over n samples
- grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
- (trimap == 128)).sum()
-
- # divide by 1000 to reduce the magnitude of the result
- grad_loss /= self.norm_constant
-
- self.results.append({'grad_err': grad_loss})
-
- def compute_metrics(self, results: List):
- """Compute the metrics from processed results.
-
- Args:
- results (dict): The processed results of each batch.
-
- Returns:
- Dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
-
- grad_err = average(results, 'grad_err')
-
- return {'GradientError': grad_err}
-
-
-@METRICS.register_module()
-class ConnectivityError(BaseMetric):
- """Connectivity error for evaluating alpha matte prediction.
-
- .. note::
-
- Current implementation assume image / alpha / trimap array in numpy
- format and with pixel value ranging from 0 to 255.
-
- .. note::
-
- pred_alpha should be masked by trimap before passing
- into this metric
-
- Args:
- step (float): Step of threshold when computing intersection between
- `alpha` and `pred_alpha`. Default to 0.1 .
- norm_const (int): Divide the result to reduce its magnitude.
- Default to 1000.
-
- Default prefix: ''
-
- Metrics:
- - ConnectivityError (float): Connectivity Error
- """
-
- def __init__(
- self,
- step=0.1,
- norm_constant=1000,
- **kwargs,
- ) -> None:
- self.step = step
- self.norm_constant = norm_constant
- super().__init__(**kwargs)
-
- def process(self, data_batch: Sequence[dict],
- data_samples: Sequence[dict]) -> None:
- """Process one batch of data samples and predictions. The processed
- results should be stored in ``self.results``, which will be used to
- compute the metrics when all batches have been processed.
-
- Args:
- data_batch (Sequence[dict]): A batch of data from the dataloader.
- predictions (Sequence[dict]): A batch of outputs from
- the model.
- """
-
- for data_sample in data_samples:
- pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
-
- thresh_steps = np.arange(0, 1 + self.step, self.step)
- round_down_map = -np.ones_like(gt_alpha)
- for i in range(1, len(thresh_steps)):
- gt_alpha_thresh = gt_alpha >= thresh_steps[i]
- pred_alpha_thresh = pred_alpha >= thresh_steps[i]
- intersection = gt_alpha_thresh & pred_alpha_thresh
- intersection = intersection.astype(np.uint8)
-
- # connected components
- _, output, stats, _ = cv2.connectedComponentsWithStats(
- intersection, connectivity=4)
- # start from 1 in dim 0 to exclude background
- size = stats[1:, -1]
-
- # largest connected component of the intersection
- omega = np.zeros_like(gt_alpha)
- if len(size) != 0:
- max_id = np.argmax(size)
- # plus one to include background
- omega[output == max_id + 1] = 1
-
- mask = (round_down_map == -1) & (omega == 0)
- round_down_map[mask] = thresh_steps[i - 1]
- round_down_map[round_down_map == -1] = 1
-
- gt_alpha_diff = gt_alpha - round_down_map
- pred_alpha_diff = pred_alpha - round_down_map
- # only calculate difference larger than or equal to 0.15
- gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
- pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)
-
- connectivity_error = np.sum(
- np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128))
-
- # divide by 1000 to reduce the magnitude of the result
- connectivity_error /= self.norm_constant
-
- self.results.append({'conn_err': connectivity_error})
-
- def compute_metrics(self, results: List):
- """Compute the metrics from processed results.
-
- Args:
- results (dict): The processed results of each batch.
-
- Returns:
- Dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
-
- conn_err = average(results, 'conn_err')
-
- return {'ConnectivityError': conn_err}
diff --git a/mmedit/evaluation/metrics/matting_mse.py b/mmedit/evaluation/metrics/matting_mse.py
new file mode 100644
index 0000000000..d734c01bb2
--- /dev/null
+++ b/mmedit/evaluation/metrics/matting_mse.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence
+
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class MattingMSE(BaseMetric):
+ """Mean Squared Error metric for image matting.
+
+    This metric computes the per-pixel squared error averaged across all
+    pixels, i.e. mean((a-b)^2) / norm_const.
+
+ .. note::
+
+        The current implementation assumes image / alpha / trimap arrays are
+        in numpy format with pixel values ranging from 0 to 255.
+
+ .. note::
+
+ pred_alpha should be masked by trimap before passing
+ into this metric
+
+ Default prefix: ''
+
+ Args:
+ norm_const (int): Divide the result to reduce its magnitude.
+            Defaults to 1000.
+
+ Metrics:
+ - MattingMSE (float): Mean of Squared Error
+ """
+
+ default_prefix = ''
+
+ def __init__(
+ self,
+ norm_const=1000,
+ **kwargs,
+ ) -> None:
+ self.norm_const = norm_const
+ super().__init__(**kwargs)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data and predictions.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for data_sample in data_samples:
+ pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
+
+ weight_sum = (trimap == 128).sum()
+ if weight_sum != 0:
+ mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum
+ else:
+ mse_result = 0
+
+ self.results.append({'mse': mse_result})
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (dict): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+
+ mse = average(results, 'mse')
+
+ return {'MattingMSE': mse}
diff --git a/mmedit/evaluation/metrics/metrics_utils.py b/mmedit/evaluation/metrics/metrics_utils.py
index 7519f54111..4ed5602a64 100644
--- a/mmedit/evaluation/metrics/metrics_utils.py
+++ b/mmedit/evaluation/metrics/metrics_utils.py
@@ -6,6 +6,50 @@
from mmedit.utils import reorder_image
+def _assert_ndim(input, name, ndim, shape_hint):
+ if input.ndim != ndim:
+ raise ValueError(
+ f'{name} should be of shape {shape_hint}, but got {input.shape}.')
+
+
+def _assert_masked(pred_alpha, trimap):
+ if (pred_alpha[trimap == 0] != 0).any() or (pred_alpha[trimap == 255] !=
+ 255).any():
+ raise ValueError(
+ 'pred_alpha should be masked by trimap before evaluation')
+
+
+def _fetch_data_and_check(data_samples):
+ """Fetch and check data from one item of data_batch and predictions.
+
+    Args:
+        data_samples (dict): One data sample, containing the original trimap,
+            the original alpha and the model prediction.
+
+    Returns:
+        pred_alpha (np.ndarray): Predicted alpha matte, normalized to [0, 1].
+        ori_alpha (np.ndarray): Ground-truth alpha matte, normalized to [0, 1].
+        ori_trimap (np.ndarray): Original trimap with values in {0, 128, 255}.
+ """
+ ori_trimap = data_samples['ori_trimap'][:, :, 0]
+ ori_alpha = data_samples['ori_alpha'][:, :, 0]
+ pred_alpha = data_samples['output']['pred_alpha']['data'] # 2D tensor
+ pred_alpha = pred_alpha.cpu().numpy()
+
+ _assert_ndim(ori_trimap, 'trimap', 2, 'HxW')
+ _assert_ndim(ori_alpha, 'gt_alpha', 2, 'HxW')
+ _assert_ndim(pred_alpha, 'pred_alpha', 2, 'HxW')
+ _assert_masked(pred_alpha, ori_trimap)
+
+ # dtype uint8 -> float64
+ pred_alpha = pred_alpha / 255.0
+ ori_alpha = ori_alpha / 255.0
+ # test shows that using float32 vs float64 differs final results at 1e-4
+ # speed are comparable, so we choose float64 for accuracy
+
+ return pred_alpha, ori_alpha, ori_trimap
+
+
def average(results, key):
"""Average of key in results(list[dict]).
diff --git a/mmedit/evaluation/metrics/mse.py b/mmedit/evaluation/metrics/mse.py
new file mode 100644
index 0000000000..863fbb97cf
--- /dev/null
+++ b/mmedit/evaluation/metrics/mse.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Evaluation metrics based on pixels."""
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+
+
+@METRICS.register_module()
+class MSE(BaseSampleWiseMetric):
+ """Mean Squared Error metric for image.
+
+ mean((a-b)^2)
+
+ Args:
+
+ gt_key (str): Key of ground-truth. Default: 'gt_img'
+ pred_key (str): Key of prediction. Default: 'pred_img'
+ mask_key (str, optional): Key of mask, if mask_key is None, calculate
+ all regions. Default: None
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None
+
+ Metrics:
+ - MSE (float): Mean of Squared Error
+ """
+
+ metric = 'MSE'
+
+ def process_image(self, gt, pred, mask):
+ """Process an image.
+
+ Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+ Returns:
+ result (np.ndarray): MSE result.
+ """
+
+ gt = gt / 255.
+ pred = pred / 255.
+
+ diff = gt - pred
+ diff *= diff
+
+ if self.mask_key is not None:
+ diff *= mask
+ result = diff.sum() / mask.sum()
+ else:
+ result = diff.mean()
+
+ return result
diff --git a/mmedit/evaluation/metrics/pixel_metrics.py b/mmedit/evaluation/metrics/pixel_metrics.py
deleted file mode 100644
index 19479bc4a6..0000000000
--- a/mmedit/evaluation/metrics/pixel_metrics.py
+++ /dev/null
@@ -1,360 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-"""Evaluation metrics based on pixels."""
-
-from typing import Optional
-
-import numpy as np
-
-from mmedit.registry import METRICS
-from .base_sample_wise_metric import BaseSampleWiseMetric
-from .metrics_utils import img_transform
-
-
-@METRICS.register_module()
-class MAE(BaseSampleWiseMetric):
- """Mean Absolute Error metric for image.
-
- mean(abs(a-b))
-
- Args:
-
- gt_key (str): Key of ground-truth. Default: 'gt_img'
- pred_key (str): Key of prediction. Default: 'pred_img'
- mask_key (str, optional): Key of mask, if mask_key is None, calculate
- all regions. Default: None
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be 'cpu' or
- 'gpu'. Defaults to 'cpu'.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, self.default_prefix
- will be used instead. Default: None
-
- Metrics:
- - MAE (float): Mean of Absolute Error
- """
-
- metric = 'MAE'
-
- def process_image(self, gt, pred, mask):
- """Process an image.
-
- Args:
- gt (Tensor | np.ndarray): GT image.
- pred (Tensor | np.ndarray): Pred image.
- mask (Tensor | np.ndarray): Mask of evaluation.
- Returns:
- result (np.ndarray): MAE result.
- """
-
- gt = gt / 255.
- pred = pred / 255.
-
- diff = gt - pred
- diff = abs(diff)
-
- if self.mask_key is not None:
- diff *= mask # broadcast for channel dimension
- scale = np.prod(diff.shape) / np.prod(mask.shape)
- result = diff.sum() / (mask.sum() * scale + 1e-12)
- else:
- result = diff.mean()
-
- return result
-
-
-@METRICS.register_module()
-class MSE(BaseSampleWiseMetric):
- """Mean Squared Error metric for image.
-
- mean((a-b)^2)
-
- Args:
-
- gt_key (str): Key of ground-truth. Default: 'gt_img'
- pred_key (str): Key of prediction. Default: 'pred_img'
- mask_key (str, optional): Key of mask, if mask_key is None, calculate
- all regions. Default: None
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be 'cpu' or
- 'gpu'. Defaults to 'cpu'.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, self.default_prefix
- will be used instead. Default: None
-
- Metrics:
- - MSE (float): Mean of Squared Error
- """
-
- metric = 'MSE'
-
- def process_image(self, gt, pred, mask):
- """Process an image.
-
- Args:
- gt (Torch | np.ndarray): GT image.
- pred (Torch | np.ndarray): Pred image.
- mask (Torch | np.ndarray): Mask of evaluation.
- Returns:
- result (np.ndarray): MSE result.
- """
-
- gt = gt / 255.
- pred = pred / 255.
-
- diff = gt - pred
- diff *= diff
-
- if self.mask_key is not None:
- diff *= mask
- result = diff.sum() / mask.sum()
- else:
- result = diff.mean()
-
- return result
-
-
-@METRICS.register_module()
-class PSNR(BaseSampleWiseMetric):
- """Peak Signal-to-Noise Ratio.
-
- Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
- Args:
-
- gt_key (str): Key of ground-truth. Default: 'gt_img'
- pred_key (str): Key of prediction. Default: 'pred_img'
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be 'cpu' or
- 'gpu'. Defaults to 'cpu'.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, self.default_prefix
- will be used instead. Default: None
- crop_border (int): Cropped pixels in each edges of an image. These
- pixels are not involved in the PSNR calculation. Default: 0.
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
- Default: 'CHW'.
- convert_to (str): Whether to convert the images to other color models.
- If None, the images are not altered. When computing for 'Y',
- the images are assumed to be in BGR order. Options are 'Y' and
- None. Default: None.
-
- Metrics:
- - PSNR (float): Peak Signal-to-Noise Ratio
- """
-
- metric = 'PSNR'
-
- def __init__(self,
- gt_key: str = 'gt_img',
- pred_key: str = 'pred_img',
- collect_device: str = 'cpu',
- prefix: Optional[str] = None,
- crop_border=0,
- input_order='CHW',
- convert_to=None) -> None:
- super().__init__(
- gt_key=gt_key,
- pred_key=pred_key,
- mask_key=None,
- collect_device=collect_device,
- prefix=prefix)
-
- self.crop_border = crop_border
- self.input_order = input_order
- self.convert_to = convert_to
-
- def process_image(self, gt, pred, mask):
- """Process an image.
-
- Args:
- gt (Torch | np.ndarray): GT image.
- pred (Torch | np.ndarray): Pred image.
- mask (Torch | np.ndarray): Mask of evaluation.
- Returns:
- np.ndarray: PSNR result.
- """
-
- return psnr(
- img1=gt,
- img2=pred,
- crop_border=self.crop_border,
- input_order=self.input_order,
- convert_to=self.convert_to,
- channel_order=self.channel_order)
-
-
-@METRICS.register_module()
-class SNR(BaseSampleWiseMetric):
- """Signal-to-Noise Ratio.
-
- Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
-
- Args:
-
- gt_key (str): Key of ground-truth. Default: 'gt_img'
- pred_key (str): Key of prediction. Default: 'pred_img'
- collect_device (str): Device name used for collecting results from
- different ranks during distributed training. Must be 'cpu' or
- 'gpu'. Defaults to 'cpu'.
- prefix (str, optional): The prefix that will be added in the metric
- names to disambiguate homonymous metrics of different evaluators.
- If prefix is not provided in the argument, self.default_prefix
- will be used instead. Default: None
- crop_border (int): Cropped pixels in each edges of an image. These
- pixels are not involved in the SNR calculation. Default: 0.
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
- Default: 'CHW'.
- convert_to (str): Whether to convert the images to other color models.
- If None, the images are not altered. When computing for 'Y',
- the images are assumed to be in BGR order. Options are 'Y' and
- None. Default: None.
-
- Metrics:
- - SNR (float): Signal-to-Noise Ratio
- """
-
- metric = 'SNR'
-
- def __init__(self,
- gt_key: str = 'gt_img',
- pred_key: str = 'pred_img',
- collect_device: str = 'cpu',
- prefix: Optional[str] = None,
- crop_border=0,
- input_order='CHW',
- convert_to=None) -> None:
- super().__init__(
- gt_key=gt_key,
- pred_key=pred_key,
- mask_key=None,
- collect_device=collect_device,
- prefix=prefix)
-
- self.crop_border = crop_border
- self.input_order = input_order
- self.convert_to = convert_to
-
- def process_image(self, gt, pred, mask):
- """Process an image.
-
- Args:
- gt (Torch | np.ndarray): GT image.
- pred (Torch | np.ndarray): Pred image.
- mask (Torch | np.ndarray): Mask of evaluation.
- Returns:
- np.ndarray: SNR result.
- """
-
- return snr(
- gt=gt,
- pred=pred,
- crop_border=self.crop_border,
- input_order=self.input_order,
- convert_to=self.convert_to,
- channel_order=self.channel_order)
-
-
-def psnr(img1,
- img2,
- crop_border=0,
- input_order='HWC',
- convert_to=None,
- channel_order='rgb'):
- """Calculate PSNR (Peak Signal-to-Noise Ratio).
-
- Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
- Args:
- img1 (ndarray): Images with range [0, 255].
- img2 (ndarray): Images with range [0, 255].
- crop_border (int): Cropped pixels in each edges of an image. These
- pixels are not involved in the PSNR calculation. Default: 0.
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
- Default: 'HWC'.
- convert_to (str): Whether to convert the images to other color models.
- If None, the images are not altered. When computing for 'Y',
- the images are assumed to be in BGR order. Options are 'Y' and
- None. Default: None.
- channel_order (str): The channel order of image. Default: 'rgb'.
-
- Returns:
- result (float): PSNR result.
- """
-
- assert img1.shape == img2.shape, (
- f'Image shapes are different: {img1.shape}, {img2.shape}.')
-
- img1 = img_transform(
- img1,
- crop_border=crop_border,
- input_order=input_order,
- convert_to=convert_to,
- channel_order=channel_order)
- img2 = img_transform(
- img2,
- crop_border=crop_border,
- input_order=input_order,
- convert_to=convert_to,
- channel_order=channel_order)
-
- mse_value = ((img1 - img2)**2).mean()
- if mse_value == 0:
- result = float('inf')
- else:
- result = 20. * np.log10(255. / np.sqrt(mse_value))
-
- return result
-
-
-def snr(gt,
- pred,
- crop_border=0,
- input_order='HWC',
- convert_to=None,
- channel_order='rgb'):
- """Calculate PSNR (Peak Signal-to-Noise Ratio).
-
- Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
- Args:
- gt (ndarray): Images with range [0, 255].
- pred (ndarray): Images with range [0, 255].
- crop_border (int): Cropped pixels in each edges of an image. These
- pixels are not involved in the PSNR calculation. Default: 0.
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
- Default: 'HWC'.
- convert_to (str): Whether to convert the images to other color models.
- If None, the images are not altered. When computing for 'Y',
- the images are assumed to be in BGR order. Options are 'Y' and
- None. Default: None.
- channel_order (str): The channel order of image. Default: 'rgb'.
-
- Returns:
- float: SNR result.
- """
-
- assert gt.shape == pred.shape, (
- f'Image shapes are different: {gt.shape}, {pred.shape}.')
-
- gt = img_transform(
- gt,
- crop_border=crop_border,
- input_order=input_order,
- convert_to=convert_to,
- channel_order=channel_order)
- pred = img_transform(
- pred,
- crop_border=crop_border,
- input_order=input_order,
- convert_to=convert_to,
- channel_order=channel_order)
-
- signal = ((gt)**2).mean()
- noise = ((gt - pred)**2).mean()
-
- result = 10. * np.log10(signal / noise)
-
- return result
diff --git a/mmedit/evaluation/metrics/ppl.py b/mmedit/evaluation/metrics/ppl.py
index ba6500de3a..7e22ef1ec3 100644
--- a/mmedit/evaluation/metrics/ppl.py
+++ b/mmedit/evaluation/metrics/ppl.py
@@ -240,6 +240,7 @@ def __next__(self):
if self.idx >= len(self.batch_sizes):
raise StopIteration
batch = self.batch_sizes[self.idx]
+ injected_noise = self.generator.make_injected_noise()
inputs = torch.randn([batch * 2, self.latent_dim],
device=self.device)
if self.sampling == 'full':
@@ -270,6 +271,7 @@ def __next__(self):
inputs=dict(
noise=latent_e,
sample_kwargs=dict(
+ injected_noise=injected_noise,
input_is_latent=(self.space == 'W'))))
ppl_sampler = PPLSampler(
diff --git a/mmedit/evaluation/metrics/psnr.py b/mmedit/evaluation/metrics/psnr.py
new file mode 100644
index 0000000000..9aec992b1a
--- /dev/null
+++ b/mmedit/evaluation/metrics/psnr.py
@@ -0,0 +1,131 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import numpy as np
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+from .metrics_utils import img_transform
+
+
+@METRICS.register_module()
+class PSNR(BaseSampleWiseMetric):
+ """Peak Signal-to-Noise Ratio.
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+
+ gt_key (str): Key of ground-truth. Default: 'gt_img'
+ pred_key (str): Key of prediction. Default: 'pred_img'
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None
+        crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation. Default: 0.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'CHW'.
+ convert_to (str): Whether to convert the images to other color models.
+ If None, the images are not altered. When computing for 'Y',
+ the images are assumed to be in BGR order. Options are 'Y' and
+ None. Default: None.
+
+ Metrics:
+ - PSNR (float): Peak Signal-to-Noise Ratio
+ """
+
+ metric = 'PSNR'
+
+ def __init__(self,
+ gt_key: str = 'gt_img',
+ pred_key: str = 'pred_img',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None,
+ crop_border=0,
+ input_order='CHW',
+ convert_to=None) -> None:
+ super().__init__(
+ gt_key=gt_key,
+ pred_key=pred_key,
+ mask_key=None,
+ collect_device=collect_device,
+ prefix=prefix)
+
+ self.crop_border = crop_border
+ self.input_order = input_order
+ self.convert_to = convert_to
+
+ def process_image(self, gt, pred, mask):
+ """Process an image.
+
+ Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+ Returns:
+ np.ndarray: PSNR result.
+ """
+
+ return psnr(
+ img1=gt,
+ img2=pred,
+ crop_border=self.crop_border,
+ input_order=self.input_order,
+ convert_to=self.convert_to,
+ channel_order=self.channel_order)
+
+
+def psnr(img1,
+ img2,
+ crop_border=0,
+ input_order='HWC',
+ convert_to=None,
+ channel_order='rgb'):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation. Default: 0.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ convert_to (str): Whether to convert the images to other color models.
+ If None, the images are not altered. When computing for 'Y',
+ the images are assumed to be in BGR order. Options are 'Y' and
+ None. Default: None.
+ channel_order (str): The channel order of image. Default: 'rgb'.
+
+ Returns:
+ result (float): PSNR result.
+ """
+
+ assert img1.shape == img2.shape, (
+ f'Image shapes are different: {img1.shape}, {img2.shape}.')
+
+ img1 = img_transform(
+ img1,
+ crop_border=crop_border,
+ input_order=input_order,
+ convert_to=convert_to,
+ channel_order=channel_order)
+ img2 = img_transform(
+ img2,
+ crop_border=crop_border,
+ input_order=input_order,
+ convert_to=convert_to,
+ channel_order=channel_order)
+
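+    # PSNR = 20 * log10(MAX_I / sqrt(MSE)), with MAX_I = 255 for 8-bit images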
+ mse_value = ((img1 - img2)**2).mean()
+ if mse_value == 0:
+ result = float('inf')
+ else:
+ result = 20. * np.log10(255. / np.sqrt(mse_value))
+
+ return result
diff --git a/mmedit/evaluation/metrics/sad.py b/mmedit/evaluation/metrics/sad.py
new file mode 100644
index 0000000000..05abb6153f
--- /dev/null
+++ b/mmedit/evaluation/metrics/sad.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class SAD(BaseMetric):
+ """Sum of Absolute Differences metric for image matting.
+
+    This metric computes the per-pixel absolute difference and sums it
+    across all pixels, i.e. sum(abs(a-b)) / norm_const.
+
+ .. note::
+
+        The current implementation assumes image / alpha / trimap arrays are
+        in numpy format with pixel values ranging from 0 to 255.
+
+ .. note::
+
+        pred_alpha should be masked by the trimap before being passed
+        into this metric.
+
+ Default prefix: ''
+
+ Args:
+        norm_const (int): The result is divided by this constant to reduce
+            its magnitude. Defaults to 1000.
+
+ Metrics:
+ - SAD (float): Sum of Absolute Differences
+ """
+
+ default_prefix = ''
+
+ def __init__(
+ self,
+ norm_const=1000,
+ **kwargs,
+ ) -> None:
+ self.norm_const = norm_const
+ super().__init__(**kwargs)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data and predictions.
+
+ Args:
+            data_batch (Sequence[dict]): A batch of data
+                from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
+ """
+ for data_sample in data_samples:
+ pred_alpha, gt_alpha, _ = _fetch_data_and_check(data_sample)
+
+            # divide by norm_const (default 1000) to reduce the magnitude
+ sad_sum = np.abs(pred_alpha - gt_alpha).sum() / self.norm_const
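+            # e.g. a mean error of 10 over a 100x100 alpha gives 1e5 / 1000 = 100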
+
+ result = {'sad': sad_sum}
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+            results (List): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+
+ sad = average(results, 'sad')
+
+ return {'SAD': sad}
diff --git a/mmedit/evaluation/metrics/snr.py b/mmedit/evaluation/metrics/snr.py
new file mode 100644
index 0000000000..e94d35c4f6
--- /dev/null
+++ b/mmedit/evaluation/metrics/snr.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import numpy as np
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+from .metrics_utils import img_transform
+
+
+@METRICS.register_module()
+class SNR(BaseSampleWiseMetric):
+ """Signal-to-Noise Ratio.
+
+ Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+ Args:
+
+ gt_key (str): Key of ground-truth. Default: 'gt_img'
+ pred_key (str): Key of prediction. Default: 'pred_img'
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None
+        crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SNR calculation. Default: 0.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'CHW'.
+ convert_to (str): Whether to convert the images to other color models.
+ If None, the images are not altered. When computing for 'Y',
+ the images are assumed to be in BGR order. Options are 'Y' and
+ None. Default: None.
+
+ Metrics:
+ - SNR (float): Signal-to-Noise Ratio
+ """
+
+ metric = 'SNR'
+
+ def __init__(self,
+ gt_key: str = 'gt_img',
+ pred_key: str = 'pred_img',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None,
+ crop_border=0,
+ input_order='CHW',
+ convert_to=None) -> None:
+ super().__init__(
+ gt_key=gt_key,
+ pred_key=pred_key,
+ mask_key=None,
+ collect_device=collect_device,
+ prefix=prefix)
+
+ self.crop_border = crop_border
+ self.input_order = input_order
+ self.convert_to = convert_to
+
+ def process_image(self, gt, pred, mask):
+ """Process an image.
+
+ Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+ Returns:
+ np.ndarray: SNR result.
+ """
+
+ return snr(
+ gt=gt,
+ pred=pred,
+ crop_border=self.crop_border,
+ input_order=self.input_order,
+ convert_to=self.convert_to,
+ channel_order=self.channel_order)
+
+
+def snr(gt,
+ pred,
+ crop_border=0,
+ input_order='HWC',
+ convert_to=None,
+ channel_order='rgb'):
+    """Calculate SNR (Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+ Args:
+ gt (ndarray): Images with range [0, 255].
+ pred (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the SNR calculation. Default: 0.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ convert_to (str): Whether to convert the images to other color models.
+ If None, the images are not altered. When computing for 'Y',
+ the images are assumed to be in BGR order. Options are 'Y' and
+ None. Default: None.
+ channel_order (str): The channel order of image. Default: 'rgb'.
+
+ Returns:
+ float: SNR result.
+ """
+
+ assert gt.shape == pred.shape, (
+ f'Image shapes are different: {gt.shape}, {pred.shape}.')
+
+ gt = img_transform(
+ gt,
+ crop_border=crop_border,
+ input_order=input_order,
+ convert_to=convert_to,
+ channel_order=channel_order)
+ pred = img_transform(
+ pred,
+ crop_border=crop_border,
+ input_order=input_order,
+ convert_to=convert_to,
+ channel_order=channel_order)
+
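+    # SNR = 10 * log10(mean(gt^2) / mean((gt - pred)^2))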
+ signal = ((gt)**2).mean()
+ noise = ((gt - pred)**2).mean()
+
+ result = 10. * np.log10(signal / noise)
+
+ return result
diff --git a/mmedit/models/base_models/__init__.py b/mmedit/models/base_models/__init__.py
index ff1107fb67..0ec81d6d5a 100644
--- a/mmedit/models/base_models/__init__.py
+++ b/mmedit/models/base_models/__init__.py
@@ -10,7 +10,14 @@
from .two_stage import TwoStageInpaintor
__all__ = [
- 'BaseEditModel', 'BaseGAN', 'BaseConditionalGAN', 'BaseMattor',
- 'BasicInterpolator', 'BaseTranslationModel', 'OneStageInpaintor',
- 'TwoStageInpaintor', 'ExponentialMovingAverage', 'RampUpEMA'
+ 'BaseEditModel',
+ 'BaseGAN',
+ 'BaseConditionalGAN',
+ 'BaseMattor',
+ 'BasicInterpolator',
+ 'BaseTranslationModel',
+ 'OneStageInpaintor',
+ 'TwoStageInpaintor',
+ 'ExponentialMovingAverage',
+ 'RampUpEMA',
]
diff --git a/mmedit/models/base_models/base_translation_model.py b/mmedit/models/base_models/base_translation_model.py
index 870cf3a418..7450816477 100644
--- a/mmedit/models/base_models/base_translation_model.py
+++ b/mmedit/models/base_models/base_translation_model.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from copy import deepcopy
-from typing import List
+from typing import List, Optional
import torch.nn as nn
from mmengine.model import BaseModel, is_model_wrapper
@@ -52,7 +52,8 @@ def __init__(self,
data_preprocessor,
discriminator_steps: int = 1,
disc_init_steps: int = 0,
- real_img_key: str = 'real_img'):
+ real_img_key: str = 'real_img',
+ loss_config: Optional[dict] = None):
super().__init__(data_preprocessor)
self._default_domain = default_domain
self._reachable_domains = reachable_domains
@@ -80,6 +81,7 @@ def __init__(self,
else:
self.discriminators = None
+ self.loss_config = dict() if loss_config is None else loss_config
self.init_weights()
def init_weights(self, pretrained=None):
diff --git a/mmedit/models/data_preprocessors/gen_preprocessor.py b/mmedit/models/data_preprocessors/gen_preprocessor.py
index 25c5ac8d94..ff8206a70d 100644
--- a/mmedit/models/data_preprocessors/gen_preprocessor.py
+++ b/mmedit/models/data_preprocessors/gen_preprocessor.py
@@ -89,7 +89,7 @@ def cast_data(self, data: CastData) -> CastData:
Returns:
CollatedResult: Inputs and data sample at target device.
"""
- if isinstance(data, str):
+ if isinstance(data, (str, int, float)):
return data
return super().cast_data(data)
diff --git a/mmedit/models/editors/__init__.py b/mmedit/models/editors/__init__.py
index a5dd2252b2..2ecb5668e6 100644
--- a/mmedit/models/editors/__init__.py
+++ b/mmedit/models/editors/__init__.py
@@ -28,6 +28,7 @@
from .indexnet import (DepthwiseIndexBlock, HolisticIndexBlock,
IndexedUpsample, IndexNet, IndexNetDecoder,
IndexNetEncoder)
+from .inst_colorization import InstColorization
from .liif import LIIF, MLPRefiner
from .lsgan import LSGAN
from .mspie import MSPIEStyleGAN2, PESinGAN
@@ -45,7 +46,7 @@
from .srgan import SRGAN, ModifiedVGG, MSRResNet
from .stylegan1 import StyleGAN1
from .stylegan2 import StyleGAN2
-from .stylegan3 import StyleGAN3
+from .stylegan3 import StyleGAN3, StyleGAN3Generator
from .tdan import TDAN, TDANNet
from .tof import TOFlowVFINet, TOFlowVSRNet, ToFResBlock
from .ttsr import LTE, TTSR, SearchTransformer, TTSRDiscriminator, TTSRNet
@@ -73,5 +74,5 @@
'FBADecoder', 'WGANGP', 'CycleGAN', 'SAGAN', 'LSGAN', 'GGAN', 'Pix2Pix',
'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN',
'ProgressiveGrowingGAN', 'SinGAN', 'IDLossModel', 'PESinGAN',
- 'MSPIEStyleGAN2'
+ 'MSPIEStyleGAN2', 'StyleGAN3Generator', 'InstColorization'
]
diff --git a/mmedit/models/editors/basicvsr_plusplus_net/basicvsr_plusplus_net.py b/mmedit/models/editors/basicvsr_plusplus_net/basicvsr_plusplus_net.py
index 45742b31e5..1bf41e5207 100644
--- a/mmedit/models/editors/basicvsr_plusplus_net/basicvsr_plusplus_net.py
+++ b/mmedit/models/editors/basicvsr_plusplus_net/basicvsr_plusplus_net.py
@@ -98,9 +98,6 @@ def __init__(self,
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- # check if the sequence is augmented by flipping
- self.is_mirror_extended = False
-
def check_if_mirror_extended(self, lqs):
"""Check whether the input is a mirror-extended sequence.
@@ -112,6 +109,8 @@ def check_if_mirror_extended(self, lqs):
shape (n, t, c, h, w).
"""
+ # check if the sequence is augmented by flipping
+ self.is_mirror_extended = False
if lqs.size(1) % 2 == 0:
lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py
new file mode 100644
index 0000000000..434ebe14d0
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorization_net import ColorizationNet
+from .fusion_net import FusionNet
+from .inst_colorization import InstColorization
+
+__all__ = [
+ 'InstColorization',
+ 'ColorizationNet',
+ 'FusionNet',
+]
diff --git a/mmedit/models/editors/inst_colorization/color_utils.py b/mmedit/models/editors/inst_colorization/color_utils.py
new file mode 100644
index 0000000000..6ecc57b72f
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/color_utils.py
@@ -0,0 +1,307 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+
+def xyz2rgb(xyz):
+    """Convert images from xyz to rgb.
+
+    Args:
+        xyz (tensor): The images to be converted.
+
+    Returns:
+        out (tensor): The converted image.
+ """
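+    # XYZ -> linear sRGB via the inverse sRGB matrix, then apply the sRGB
+    # gamma curve (linear below 0.0031308, exponent 1/2.4 above)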
+ r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \
+ - 0.49853633 * xyz[:, 2, :, :]
+ g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \
+ + .04155593 * xyz[:, 2, :, :]
+ b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \
+ + 1.05731107 * xyz[:, 2, :, :]
+
+ # sometimes reaches a small negative number, which causes NaNs
+ rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]),
+ dim=1)
+ rgb = torch.max(rgb, torch.zeros_like(rgb))
+
+ mask = (rgb > .0031308).type(torch.FloatTensor)
+ if rgb.is_cuda:
+ mask = mask.cuda()
+
+ rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
+ return rgb
+
+
+def lab2xyz(lab):
+    """Convert images from lab to xyz.
+
+    Args:
+        lab (tensor): The images to be converted.
+
+    Returns:
+        out (tensor): The converted image.
+ """
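+    # invert the CIELAB f() function: f^-1(t) = t^3 above the 6/29 threshold,
+    # and the linear segment (t - 16/116) / 7.787 below it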
+ y_int = (lab[:, 0, :, :] + 16.) / 116.
+ x_int = (lab[:, 1, :, :] / 500.) + y_int
+ z_int = y_int - (lab[:, 2, :, :] / 200.)
+ if (z_int.is_cuda):
+ z_int = torch.max(torch.Tensor((0, )).cuda(), z_int)
+ else:
+ z_int = torch.max(torch.Tensor((0, )), z_int)
+
+ out = torch.cat(
+ (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]),
+ dim=1)
+ mask = (out > .2068966).type(torch.FloatTensor)
+ if (out.is_cuda):
+ mask = mask.cuda()
+
+ out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
+
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
+ sc = sc.to(out.device)
+
+ out = out * sc
+ return out
+
+
+def lab2rgb(lab_rs, color_data_opt):
+    """Convert images from lab to rgb.
+
+    Args:
+        lab_rs (tensor): The images to be converted.
+        color_data_opt (dict): Config for image colorspace transformation.
+            Include: l_norm, ab_norm, l_cent
+
+    Returns:
+        out (tensor): The converted image.
+ """
+ L = lab_rs[:,
+ [0], :, :] * color_data_opt['l_norm'] + color_data_opt['l_cent']
+ AB = lab_rs[:, 1:, :, :] * color_data_opt['ab_norm']
+ lab = torch.cat((L, AB), dim=1)
+ out = xyz2rgb(lab2xyz(lab))
+ return out
+
+
+def encode_ab_ind(data_ab, color_data_opt):
+ """Encode ab value into an index.
+
+ Args:
+ data_ab: Nx2xHxW from [-1,1]
+ color_data_opt: Config for image colorspace transformation.
+            Include: ab_max, ab_quant, ab_norm
+ Returns:
+ Nx1xHxW from [0,Q)
+ """
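+    # quantize a and b into bins of size ab_quant and flatten the 2D bin
+    # index into a single index: q = a_bin * A + b_bin, where A is the
+    # number of bins per axis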
+ A = 2 * color_data_opt['ab_max'] / color_data_opt['ab_quant'] + 1
+ data_ab_rs = torch.round(
+ (data_ab * color_data_opt['ab_norm'] + color_data_opt['ab_max']) /
+ color_data_opt['ab_quant']) # normalized bin number
+ data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
+ return data_q
+
+
+def rgb2xyz(rgb):
+    """Convert images from rgb to xyz.
+ rgb from [0,1]
+ xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
+ [0.212671, 0.715160, 0.072169],
+ [0.019334, 0.119193, 0.950227]])
+ Args:
+ rgb (Tensor): image in rgb colorspace
+
+ Returns:
+ xyz (Tensor): image in xyz colorspace
+
+ """
+ mask = (rgb > .04045).type(torch.FloatTensor)
+ if (rgb.is_cuda):
+ mask = mask.cuda()
+
+ rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
+
+ x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \
+ + .180423 * rgb[:, 2, :, :]
+ y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \
+ + .072169 * rgb[:, 2, :, :]
+ z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \
+ + .950227 * rgb[:, 2, :, :]
+ out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]),
+ dim=1)
+
+ return out
+
+
+def xyz2lab(xyz):
+    """Convert images from xyz to lab.
+ xyz from [0,1]
+ factors: 0.95047, 1., 1.08883
+
+ Args:
+ xyz (Tensor): image in xyz colorspace
+
+ Returns:
+ out (Tensor): Image in lab colorspace
+ """
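+    # apply the CIELAB f() function: t^(1/3) above the (6/29)^3 threshold,
+    # and the linear approximation 7.787 * t + 16/116 below it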
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
+ if (xyz.is_cuda):
+ sc = sc.cuda()
+
+ xyz_scale = xyz / sc
+
+ mask = (xyz_scale > .008856).type(torch.FloatTensor)
+ if (xyz_scale.is_cuda):
+ mask = mask.cuda()
+
+ xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale +
+ 16. / 116.) * (1 - mask)
+
+ L = 116. * xyz_int[:, 1, :, :] - 16.
+ a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
+ b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
+ out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]),
+ dim=1)
+
+ return out
+
+
+def rgb2lab(rgb, color_opt):
+    """Convert images from rgb to lab.
+
+    Args:
+        rgb (tensor): The images to be converted.
+        color_opt (dict): Config for image colorspace transformation.
+            Include: l_cent, l_norm, ab_norm
+
+    Returns:
+        out (tensor): The converted image.
+ """
+ lab = xyz2lab(rgb2xyz(rgb))
+ l_rs = (lab[:, [0], :, :] - color_opt['l_cent']) / color_opt['l_norm']
+ ab_rs = lab[:, 1:, :, :] / color_opt['ab_norm']
+ out = torch.cat((l_rs, ab_rs), dim=1)
+ return out
+
+
+def get_colorization_data(data_raw, color_opt, num_points=None):
+    """Prepare colorization data from rgb images.
+
+    Args:
+        data_raw (tensor): The images to be converted.
+        color_opt (dict): Config for image colorspace transformation.
+            Include: ab_thresh, ab_norm, l_cent, l_norm, p, sample_PS,
+            mask_cent
+        num_points (int, optional): Certain number of hint points.
+            Default: None.
+
+    Returns:
+        results (dict): Output of add_color_patches_rand_gt.
+ """
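+    # 'A' holds the normalized L channel and 'B' the normalized ab channels;
+    # color hints are added later by add_color_patches_rand_gt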
+ data = {}
+ data_lab = rgb2lab(data_raw, color_opt)
+ data['A'] = data_lab[:, [
+ 0,
+ ], :, :]
+ data['B'] = data_lab[:, 1:, :, :]
+
+ # mask out grayscale images
+ if color_opt['ab_thresh'] > 0:
+ thresh = 1. * color_opt['ab_thresh'] / color_opt['ab_norm']
+ mask = torch.sum(
+ torch.abs(
+ torch.max(torch.max(data['B'], dim=3)[0], dim=2)[0] -
+ torch.min(torch.min(data['B'], dim=3)[0], dim=2)[0]),
+ dim=1) >= thresh
+ data['A'] = data['A'][mask, :, :, :]
+ data['B'] = data['B'][mask, :, :, :]
+ if torch.sum(mask) == 0:
+ return None
+
+ return add_color_patches_rand_gt(
+ data, color_opt, p=color_opt['p'], num_points=num_points)
+
+
+def add_color_patches_rand_gt(data,
+ color_opt,
+ p=.125,
+ num_points=None,
+ use_avg=True,
+ samp='normal'):
+    """Add random color points sampled from the ground truth.
+
+    Number of points:
+
+    - if num_points is None, then sample from a geometric distribution,
+      drawn from probability p
+    - if num_points > 0, then sample that number of points
+
+    Location of points:
+
+    - if samp is 'normal', draw from N(0.5, 0.25) of the image
+    - otherwise, draw from U[0, 1] of the image
+
+    Args:
+        data (dict): Dict with keys 'A' (L channel) and 'B' (ab channels).
+        color_opt (dict): Config for image colorspace transformation.
+            Include: ab_thresh, ab_norm, sample_PS, mask_cent
+        p (float): Parameter of the geometric distribution used for sampling;
+            1.0 means no hints. Default: 0.125.
+        num_points (int): Sample a certain number of points. Default: None.
+        use_avg (bool): Whether to use the patch mean when adding a color
+            point. Default: True.
+        samp (str): 'normal' or 'uniform' distribution when sampling the
+            hint location. Default: 'normal'.
+
+    Returns:
+        results (dict): The input dict with 'hint_B' and 'mask_B' added.
+    """
+ N, C, H, W = data['B'].shape
+
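+    # hint_B stores the revealed ab values and mask_B marks where hints exist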
+ data['hint_B'] = torch.zeros_like(data['B'])
+ data['mask_B'] = torch.zeros_like(data['A'])
+
+ for nn in range(N):
+ pp = 0
+ cont_cond = True
+ while cont_cond:
+ # draw from geometric
+ if num_points is None:
+ cont_cond = np.random.rand() < (1 - p)
+ else:
+ # add certain number of points
+ cont_cond = pp < num_points
+ # skip out of loop if condition not met
+ if not cont_cond:
+ continue
+
+ # patch size
+ P = np.random.choice(color_opt['sample_PS'])
+            # sample location
+ if samp == 'normal':
+ h = int(
+ np.clip(
+ np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.),
+ 0, H - P))
+ w = int(
+ np.clip(
+ np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.),
+ 0, W - P))
+ else: # uniform distribution
+ h = np.random.randint(H - P + 1)
+ w = np.random.randint(W - P + 1)
+
+ # add color point
+ if use_avg:
+ data['hint_B'][nn, :, h:h + P, w:w + P] = torch.mean(
+ torch.mean(
+ data['B'][nn, :, h:h + P, w:w + P],
+ dim=2,
+ keepdim=True),
+ dim=1,
+ keepdim=True).view(1, C, 1, 1)
+ else:
+ data['hint_B'][nn, :, h:h + P, w:w + P] = \
+ data['B'][nn, :, h:h + P, w:w + P]
+
+ data['mask_B'][nn, :, h:h + P, w:w + P] = 1
+
+ # increment counter
+ pp += 1
+
+ data['mask_B'] -= color_opt['mask_cent']
+
+ return data
diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py
new file mode 100644
index 0000000000..6d62209e07
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/colorization_net.py
@@ -0,0 +1,313 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmedit.registry import MODULES
+from .weight_layer import get_norm_layer
+
+
+@MODULES.register_module()
+class ColorizationNet(BaseModule):
+    """Real-Time User-Guided Image Colorization with Learned Deep Priors.
+
+    The backbone used for user-guided image colorization.
+
+ https://arxiv.org/abs/1705.02999
+
+ Codes adapted from 'https://github.com/ericsujw/InstColorization.git'
+ 'InstColorization/blob/master/models/networks.py#L108'
+
+ Args:
+ input_nc (int): input image channels
+ output_nc (int): output image channels
+ norm_type (str): instance normalization or batch normalization
+        use_tanh (bool): Whether to use nn.Tanh(). Default: True.
+        classification (bool): Whether to backprop the trunk through the
+            classification loss; otherwise use the regression loss.
+            Default: True.
+ """
+
+ def __init__(self,
+ input_nc,
+ output_nc,
+ norm_type,
+ use_tanh=True,
+ classification=True):
+ super().__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.classification = classification
+
+ norm_layer = get_norm_layer(norm_type)
+
+ use_bias = True
+
+ # Conv1
+ self.model1 = nn.Sequential(
+ nn.Conv2d(
+ input_nc,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(64),
+ )
+
+ # Conv2
+ self.model2 = nn.Sequential(
+ nn.Conv2d(
+ 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(128),
+ )
+
+ # Conv3
+ self.model3 = nn.Sequential(
+ nn.Conv2d(
+ 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(256),
+ )
+
+ # Conv4
+ self.model4 = nn.Sequential(
+ nn.Conv2d(
+ 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ # Conv5
+ self.model5 = nn.Sequential(
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ # Conv6
+ self.model6 = nn.Sequential(
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ # Conv7
+ self.model7 = nn.Sequential(
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ # Conv8
+ self.model8up = nn.ConvTranspose2d(
+ 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model3short8 = nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias)
+
+ self.model8 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(256),
+ )
+
+ # Conv9
+ self.model9up = nn.ConvTranspose2d(
+ 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model2short9 = nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)
+ self.model9 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(128),
+ )
+
+ # Conv10
+ self.model10up = nn.ConvTranspose2d(
+ 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model1short10 = nn.Conv2d(
+ 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)
+
+ self.model10 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128,
+ 128,
+ kernel_size=3,
+ dilation=1,
+ stride=1,
+ padding=1,
+ bias=use_bias),
+ nn.LeakyReLU(negative_slope=.2),
+ )
+
+ # classification output
+ self.model_class = nn.Conv2d(
+ 256,
+ 529,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ stride=1,
+ bias=use_bias)
+
+ # regression output
+ model_out = [
+ nn.Conv2d(
+ 128,
+ 2,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ stride=1,
+ bias=use_bias),
+ ]
+ if (use_tanh):
+ model_out += [nn.Tanh()]
+ self.model_out = nn.Sequential(*model_out)
+
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest')
+ self.softmax = nn.Softmax(dim=1)
+
+ def forward(self, input_A, input_B, mask_B):
+ """Forward function.
+
+ Args:
+            input_A (tensor): L channel of the image in Lab color space.
+            input_B (tensor): Color hint patches (ab values).
+            mask_B (tensor): Mask of the color hint patches.
+
+ Returns:
+ out_class (tensor): Classification output
+ out_reg (tensor): Regression output
+ feature_map (dict): The full-image feature
+ """
+ conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1))
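+        # feature maps are downsampled between blocks by stride-2 slicing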
+ conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
+ conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
+ conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
+ conv5_3 = self.model5(conv4_3)
+ conv6_3 = self.model6(conv5_3)
+ conv7_3 = self.model7(conv6_3)
+ conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+ conv8_3 = self.model8(conv8_up)
+
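+        # with classification enabled, the decoder path runs on detached
+        # features so only the classification head backprops into the trunk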
+ if (self.classification):
+ out_class = self.model_class(conv8_3)
+ conv9_up = self.model9up(conv8_3.detach()) + self.model2short9(
+ conv2_2.detach())
+ conv9_3 = self.model9(conv9_up)
+ conv10_up = self.model10up(conv9_3) + self.model1short10(
+ conv1_2.detach())
+ else:
+ out_class = self.model_class(conv8_3.detach())
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+ conv9_3 = self.model9(conv9_up)
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+
+ conv10_2 = self.model10(conv10_up)
+ out_reg = self.model_out(conv10_2)
+
+ feature_map = {}
+ feature_map['conv1_2'] = conv1_2
+ feature_map['conv2_2'] = conv2_2
+ feature_map['conv3_3'] = conv3_3
+ feature_map['conv4_3'] = conv4_3
+ feature_map['conv5_3'] = conv5_3
+ feature_map['conv6_3'] = conv6_3
+ feature_map['conv7_3'] = conv7_3
+ feature_map['conv8_up'] = conv8_up
+ feature_map['conv8_3'] = conv8_3
+ feature_map['conv9_up'] = conv9_up
+ feature_map['conv9_3'] = conv9_3
+ feature_map['conv10_up'] = conv10_up
+ feature_map['conv10_2'] = conv10_2
+ feature_map['out_reg'] = out_reg
+
+ return (out_class, out_reg, feature_map)
diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py
new file mode 100644
index 0000000000..10c5732680
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/fusion_net.py
@@ -0,0 +1,353 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmedit.registry import MODULES
+from .weight_layer import WeightLayer, get_norm_layer
+
+
+@MODULES.register_module()
+class FusionNet(BaseModule):
+ """Instance-aware Image Colorization.
+
+ https://arxiv.org/abs/2005.10825
+
+ Codes adapted from 'https://github.com/ericsujw/InstColorization.git'
+ 'InstColorization/blob/master/models/networks.py#L314'
+ FusionNet: the full image model with weight layer for fusion.
+
+ Args:
+ input_nc (int): input image channels
+ output_nc (int): output image channels
+ norm_type (str): instance normalization or batch normalization
+        use_tanh (bool): Whether to use nn.Tanh(). Default: True.
+        classification (bool): Whether to backprop the trunk through the
+            classification loss; otherwise use the regression loss.
+            Default: True.
+ """
+
+ def __init__(self,
+ input_nc,
+ output_nc,
+ norm_type,
+ use_tanh=True,
+ classification=True):
+ super().__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.classification = classification
+
+ norm_layer = get_norm_layer(norm_type)
+ use_bias = True
+
+ # Conv1
+ self.model1 = nn.Sequential(
+ nn.Conv2d(
+ input_nc,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(64),
+ )
+
+ self.weight_layer = WeightLayer(64)
+
+ # Conv2
+ self.model2 = nn.Sequential(
+ nn.Conv2d(
+ 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(128),
+ )
+
+ self.weight_layer2 = WeightLayer(128)
+
+ # Conv3
+ self.model3 = nn.Sequential(
+ nn.Conv2d(
+ 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(256),
+ )
+
+ self.weight_layer3 = WeightLayer(256)
+
+ # Conv4
+ self.model4 = nn.Sequential(
+ nn.Conv2d(
+ 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ self.weight_layer4 = WeightLayer(512)
+
+ # Conv5
+ self.model5 = nn.Sequential(
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ self.weight_layer5 = WeightLayer(512)
+
+ # Conv6
+ self.model6 = nn.Sequential(
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512,
+ 512,
+ kernel_size=3,
+ dilation=2,
+ stride=1,
+ padding=2,
+ bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ self.weight_layer6 = WeightLayer(512)
+
+ # Conv7
+ self.model7 = nn.Sequential(
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(512),
+ )
+
+ self.weight_layer7 = WeightLayer(512)
+
+ # Conv8
+ self.model8up = nn.ConvTranspose2d(
+ 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model3short8 = nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias)
+
+ self.weight_layer8_1 = WeightLayer(256)
+
+ self.model8 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ nn.Conv2d(
+ 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(256),
+ )
+
+ self.weight_layer8_2 = WeightLayer(256)
+
+ # Conv9
+ self.model9up = nn.ConvTranspose2d(
+ 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model2short9 = nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)
+
+ self.weight_layer9_1 = WeightLayer(128)
+
+ self.model9 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ nn.ReLU(True),
+ norm_layer(128),
+ )
+
+ self.weight_layer9_2 = WeightLayer(128)
+
+ # Conv10
+ self.model10up = nn.ConvTranspose2d(
+ 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias)
+
+ self.model1short10 = nn.Conv2d(
+ 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias)
+
+ self.weight_layer10_1 = WeightLayer(128)
+
+ self.model10 = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(
+ 128,
+ 128,
+ kernel_size=3,
+ dilation=1,
+ stride=1,
+ padding=1,
+ bias=use_bias),
+ nn.LeakyReLU(negative_slope=.2),
+ )
+
+ self.weight_layer10_2 = WeightLayer(128)
+
+ # classification output
+ self.model_class = nn.Conv2d(
+ 256,
+ 529,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ stride=1,
+ bias=use_bias)
+
+ # regression output
+ model_out = [
+ nn.Conv2d(
+ 128,
+ 2,
+ kernel_size=1,
+ padding=0,
+ dilation=1,
+ stride=1,
+ bias=use_bias),
+ ]
+ if (use_tanh):
+ model_out += [nn.Tanh()]
+ self.model_out = nn.Sequential(*model_out)
+
+ self.weight_layerout = WeightLayer(2)
+
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest')
+ self.softmax = nn.Softmax(dim=1)
+
+ def forward(self, input_A, input_B, mask_B, instance_feature,
+ box_info_list):
+ """Forward function.
+
+ Args:
+            input_A (tensor): L channel of the image in Lab color space.
+            input_B (tensor): Color hint patches (ab values).
+            mask_B (tensor): Mask of the color hint patches.
+ instance_feature (dict): A bunch of instance features
+ box_info_list (list): Bounding box information corresponding
+ to the instance
+
+ Returns:
+ out_reg (tensor): Regression output
+ """
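+        # mirror ColorizationNet's full-image pass, but after each block fuse
+        # in the instance features through a WeightLayer at that resolution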
+ conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1))
+ conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2,
+ box_info_list[0])
+
+ conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
+ conv2_2 = self.weight_layer2(instance_feature['conv2_2'], conv2_2,
+ box_info_list[1])
+
+ conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
+ conv3_3 = self.weight_layer3(instance_feature['conv3_3'], conv3_3,
+ box_info_list[2])
+
+ conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
+ conv4_3 = self.weight_layer4(instance_feature['conv4_3'], conv4_3,
+ box_info_list[3])
+
+ conv5_3 = self.model5(conv4_3)
+ conv5_3 = self.weight_layer5(instance_feature['conv5_3'], conv5_3,
+ box_info_list[3])
+
+ conv6_3 = self.model6(conv5_3)
+ conv6_3 = self.weight_layer6(instance_feature['conv6_3'], conv6_3,
+ box_info_list[3])
+
+ conv7_3 = self.model7(conv6_3)
+ conv7_3 = self.weight_layer7(instance_feature['conv7_3'], conv7_3,
+ box_info_list[3])
+
+ conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+ conv8_up = self.weight_layer8_1(instance_feature['conv8_up'], conv8_up,
+ box_info_list[2])
+
+ conv8_3 = self.model8(conv8_up)
+ conv8_3 = self.weight_layer8_2(instance_feature['conv8_3'], conv8_3,
+ box_info_list[2])
+
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+ conv9_up = self.weight_layer9_1(instance_feature['conv9_up'], conv9_up,
+ box_info_list[1])
+
+ conv9_3 = self.model9(conv9_up)
+ conv9_3 = self.weight_layer9_2(instance_feature['conv9_3'], conv9_3,
+ box_info_list[1])
+
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+ conv10_up = self.weight_layer10_1(instance_feature['conv10_up'],
+ conv10_up, box_info_list[0])
+
+ conv10_2 = self.model10(conv10_up)
+ conv10_2 = self.weight_layer10_2(instance_feature['conv10_2'],
+ conv10_2, box_info_list[0])
+
+ out_reg = self.model_out(conv10_2)
+ return out_reg
diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py
new file mode 100644
index 0000000000..4c63aac225
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/inst_colorization.py
@@ -0,0 +1,238 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import torch
+from mmengine.config import Config
+from mmengine.model import BaseModel
+from mmengine.optim import OptimWrapperDict
+
+from mmedit.registry import MODULES
+from mmedit.structures import EditDataSample, PixelData
+from .color_utils import get_colorization_data, lab2rgb
+
+
+@MODULES.register_module()
+class InstColorization(BaseModel):
+    """InstColorization model for image colorization.
+
+ This Colorization is implemented according to the paper:
+ Instance-aware Image Colorization, CVPR 2020
+
+ Adapted from 'https://github.com/ericsujw/InstColorization.git'
+ 'InstColorization/models/train_model'
+ Copyright (c) 2020, Su, under MIT License.
+
+ Args:
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`BaseDataPreprocessor`.
+ image_model (dict): Config for single image model
+ instance_model (dict): Config for instance model
+ fusion_model (dict): Config for fusion model
+ color_data_opt (dict): Option for colorspace conversion
+ which_direction (str): AtoB or BtoA
+ loss (dict): Config for loss.
+ init_cfg (str): Initialization config dict. Default: None.
+ train_cfg (dict): Config for training. Default: None.
+ test_cfg (dict): Config for testing. Default: None.
+ """
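+    # Note: color_data_opt is expected to provide the keys consumed by
+    # color_utils.py, e.g. l_norm, l_cent, ab_norm, ab_max, ab_quant,
+    # ab_thresh, p, sample_PS and mask_cent; exact values come from the config.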
+
+ def __init__(self,
+ data_preprocessor: Union[dict, Config],
+ image_model,
+ instance_model,
+ fusion_model,
+ color_data_opt,
+ which_direction='AtoB',
+ loss=None,
+ init_cfg=None,
+ train_cfg=None,
+ test_cfg=None):
+
+ super().__init__(
+ init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+ # colorization networks
+ # image_model: used to colorize a single image
+ self.image_model = MODULES.build(image_model)
+
+ # instance model: used to colorize cropped instance
+ self.instance_model = MODULES.build(instance_model)
+
+ # fusion model: input a single image with related instance features
+ self.fusion_model = MODULES.build(fusion_model)
+
+ self.color_data_opt = color_data_opt
+ self.which_direction = which_direction
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[List[EditDataSample]] = None,
+ mode: str = 'tensor',
+ **kwargs):
+ """Returns losses or predictions of training, validation, testing, and
+ simple inference process.
+
+ ``forward`` method of BaseModel is an abstract method, its subclasses
+ must implement this method.
+
+ Accepts ``inputs`` and ``data_samples`` processed by
+ :attr:`data_preprocessor`, and returns results according to mode
+ arguments.
+
+ During non-distributed training, validation, and testing process,
+ ``forward`` will be called by ``BaseModel.train_step``,
+        ``BaseModel.val_step`` and ``BaseModel.test_step`` directly.
+
+ During distributed data parallel training process,
+ ``MMSeparateDistributedDataParallel.train_step`` will first call
+ ``DistributedDataParallel.forward`` to enable automatic
+ gradient synchronization, and then call ``forward`` to get training
+ loss.
+
+ Args:
+ inputs (torch.Tensor): batch input tensor collated by
+ :attr:`data_preprocessor`.
+ data_samples (List[BaseDataElement], optional):
+ data samples collated by :attr:`data_preprocessor`.
+ mode (str): mode should be one of ``loss``, ``predict`` and
+ ``tensor``. Default: 'tensor'.
+
+ - ``loss``: Called by ``train_step`` and return loss ``dict``
+ used for logging
+ - ``predict``: Called by ``val_step`` and ``test_step``
+ and return list of ``BaseDataElement`` results used for
+ computing metric.
+ - ``tensor``: Called by custom use to get ``Tensor`` type
+ results.
+
+ Returns:
+ ForwardResults:
+
+ - If ``mode == loss``, return a ``dict`` of loss tensor used
+ for backward and logging.
+ - If ``mode == predict``, return a ``list`` of
+ :obj:`BaseDataElement` for computing metric
+ and getting inference result.
+ - If ``mode == tensor``, return a tensor or ``tuple`` of tensor
+              or ``dict`` of tensor for custom use.
+ """
+
+ if mode == 'tensor':
+ return self.forward_tensor(inputs, data_samples, **kwargs)
+
+ elif mode == 'predict':
+ predictions = self.forward_inference(inputs, data_samples,
+ **kwargs)
+ predictions = self.convert_to_datasample(data_samples, predictions)
+ return predictions
+
+ elif mode == 'loss':
+ return self.forward_train(inputs, data_samples, **kwargs)
+
+ def convert_to_datasample(self, inputs, data_samples):
+ for data_sample, output in zip(inputs, data_samples):
+ data_sample.output = output
+ return inputs
+
+ def forward_train(self, inputs, data_samples=None, **kwargs):
+ """Forward function for training."""
+ raise NotImplementedError(
+ 'Instance Colorization has not supported training.')
+
+ def train_step(self, data: List[dict],
+ optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
+ """Train step function.
+
+ Args:
+ data (List[dict]): Batch of data as input.
+ optim_wrapper (dict[torch.optim.Optimizer]): Dict with optimizers
+ for generator and discriminator (if have).
+ Returns:
+ dict: Dict with loss, information for logger, the number of
+ samples and results for visualization.
+ """
+ raise NotImplementedError(
+ 'Instance Colorization has not supported training.')
+
+ def forward_inference(self, inputs, data_samples=None, **kwargs):
+ """Forward inference. Returns predictions of validation, testing.
+
+ Args:
+ inputs (torch.Tensor): batch input tensor collated by
+ :attr:`data_preprocessor`.
+ data_samples (List[BaseDataElement], optional):
+ data samples collated by :attr:`data_preprocessor`.
+
+ Returns:
+ List[EditDataSample]: predictions.
+ """
+ feats = self.forward_tensor(inputs, data_samples, **kwargs)
+ predictions = []
+ for idx in range(feats.shape[0]):
+ batch_tensor = feats[idx] * 127.5 + 127.5
+ pred_img = PixelData(data=batch_tensor.to('cpu'))
+ predictions.append(
+ EditDataSample(
+ pred_img=pred_img, metainfo=data_samples[idx].metainfo))
+
+ return predictions
+
+ def forward_tensor(self, inputs, data_samples):
+ """Forward function in tensor mode.
+
+ Args:
+            inputs (torch.Tensor): Input tensor.
+            data_samples (List[EditDataSample]): List of data samples.
+
+        Returns:
+            Tensor: The colorized rgb image, clamped to [0, 1].
+ """
+
+ # prepare data
+
+ assert len(data_samples) == 1, \
+ 'fusion model supports only one image due to different numbers '\
+ 'of instances of different images'
+
+ full_img_data = get_colorization_data(inputs, self.color_data_opt)
+ AtoB = self.which_direction == 'AtoB'
+
+ # preprocess input for a single image
+ full_real_A = full_img_data['A' if AtoB else 'B']
+ full_hint_B = full_img_data['hint_B']
+ full_mask_B = full_img_data['mask_B']
+
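+        # if instance boxes exist, colorize each crop with the instance model
+        # and fuse its features into the full-image pass via the fusion model;
+        # otherwise fall back to the single-image model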
+ if not data_samples[0].empty_box:
+ # preprocess instance input
+ cropped_img = data_samples[0].cropped_img.data
+ box_info_list = [
+ data_samples[0].box_info, data_samples[0].box_info_2x,
+ data_samples[0].box_info_4x, data_samples[0].box_info_8x
+ ]
+ cropped_data = get_colorization_data(cropped_img,
+ self.color_data_opt)
+
+ real_A = cropped_data['A' if AtoB else 'B']
+ hint_B = cropped_data['hint_B']
+ mask_B = cropped_data['mask_B']
+
+ # network forward
+ _, output, feature_map = self.instance_model(
+ real_A, hint_B, mask_B)
+ output = self.fusion_model(full_real_A, full_hint_B, full_mask_B,
+ feature_map, box_info_list)
+
+ else:
+ _, output, _ = self.image_model(full_real_A, full_hint_B,
+ full_mask_B)
+
+ output = [
+ full_real_A.type(torch.cuda.FloatTensor),
+ output.type(torch.cuda.FloatTensor)
+ ]
+ output = torch.cat(output, dim=1)
+ output = torch.clamp(lab2rgb(output, self.color_data_opt), 0.0, 1.0)
+ return output
diff --git a/mmedit/models/editors/inst_colorization/weight_layer.py b/mmedit/models/editors/inst_colorization/weight_layer.py
new file mode 100644
index 0000000000..c2f05b34f0
--- /dev/null
+++ b/mmedit/models/editors/inst_colorization/weight_layer.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+from mmengine.model import BaseModule
+from torch import nn
+
+from mmedit.registry import MODULES
+
+
+def get_norm_layer(norm_type='instance'):
+ """Gets the normalization layer.
+
+ Args:
+ norm_type (str): Type of the normalization layer.
+
+ Returns:
+ norm_layer (BatchNorm2d or InstanceNorm2d or None):
+ normalization layer. Default: instance
+ """
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+ elif norm_type == 'none':
+ norm_layer = None
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' %
+ norm_type)
+ return norm_layer
+
+
+@MODULES.register_module()
+class WeightLayer(BaseModule):
+ """Weight layer of the fusion_net. A small neural network with three
+    convolutional layers to predict the full-image weight map and the
+    per-instance weight maps.
+
+ Args:
+ input_ch (int): Number of channels in the input image.
+ inner_ch (int): Number of channels produced by the convolution.
+            Default: 16.
+ """
+
+ def __init__(self, input_ch, inner_ch=16):
+ super().__init__()
+ self.simple_instance_conv = nn.Sequential(
+ nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ )
+
+ self.simple_bg_conv = nn.Sequential(
+ nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ )
+
+ self.normalize = nn.Softmax(1)
+
+ def resize_and_pad(self, feauture_maps, info_array):
+ """Resize the instance feature as well as the weight map to match the
+ size of full-image and do zero padding on both of them.
+
+ Args:
+ feauture_maps (tensor): Feature map
+ info_array (tensor): The bounding box
+
+ Returns:
+ feauture_maps (tensor): Feature maps after resize and padding
+ """
+ feauture_maps = torch.nn.functional.interpolate(
+ feauture_maps,
+ size=(info_array[5], info_array[4]),
+ mode='bilinear')
+ feauture_maps = torch.nn.functional.pad(feauture_maps,
+ (info_array[0], info_array[1],
+ info_array[2], info_array[3]),
+ 'constant', 0)
+ return feauture_maps
+
+ def forward(self, instance_feature, bg_feature, box_info):
+ """Forward function.
+
+ Args:
+ instance_feature (tensor): Instance features obtained from the
+ colorization_net.
+ bg_feature (tensor): Full-image feature.
+ box_info (tensor): Bounding box information of each instance.
+
+ Returns:
+ out (tensor): Fused feature.
+ """
+ mask_list = []
+ featur_map_list = []
+ mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1]
+ for i in range(instance_feature.shape[0]):
+ tmp_crop = torch.unsqueeze(instance_feature[i], 0)
+ conv_tmp_crop = self.simple_instance_conv(tmp_crop)
+ pred_mask = self.resize_and_pad(conv_tmp_crop, box_info[i])
+
+ tmp_crop = self.resize_and_pad(tmp_crop, box_info[i])
+
+ mask = torch.zeros_like(bg_feature)[:1, :1]
+ mask[0, 0, box_info[i][2]:box_info[i][2] + box_info[i][5],
+ box_info[i][0]:box_info[i][0] + box_info[i][4]] = 1.0
+ mask = mask.float()
+
+ mask_sum_for_pred = torch.clamp(mask_sum_for_pred + mask, 0.0, 1.0)
+
+ mask_list.append(pred_mask)
+ featur_map_list.append(tmp_crop)
+
+ pred_bg_mask = self.simple_bg_conv(bg_feature)
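+ # outside every instance box, add a large constant so the softmax
+ # assigns (almost) all weight to the background prediction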
+ mask_list.append(pred_bg_mask + (1 - mask_sum_for_pred) * 100000.0)
+ mask_list = self.normalize(torch.cat(mask_list, 1))
+
+ mask_list_maskout = mask_list.clone()
+
+ featur_map_list.append(bg_feature)
+ featur_map_list = torch.cat(featur_map_list, 0)
+ mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous()
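+ # fuse: weighted sum of per-instance features and the background
+ # feature using the normalized weight maps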
+ out = featur_map_list * mask_list_maskout
+ out = torch.sum(out, 0, keepdim=True)
+ return out
diff --git a/mmedit/models/editors/mspie/pe_singan_generator.py b/mmedit/models/editors/mspie/pe_singan_generator.py
index f40c74e6f0..d4cd33bdf2 100644
--- a/mmedit/models/editors/mspie/pe_singan_generator.py
+++ b/mmedit/models/editors/mspie/pe_singan_generator.py
@@ -167,9 +167,12 @@ def forward(self,
noise_list = []
if input_sample is None:
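+ # when noise_with_pad is enabled, the fixed noises carry extra
+ # padding; strip it to recover the image-sized input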
+ h, w = fixed_noises[0].shape[-2:]
+ if self.noise_with_pad:
+ h -= 2 * self.pad_head
+ w -= 2 * self.pad_head
input_sample = torch.zeros(
- (num_batches, 3, fixed_noises[0].shape[-2],
- fixed_noises[0].shape[-1])).to(fixed_noises[0])
+ (num_batches, 3, h, w)).to(fixed_noises[0])
g_res = input_sample
diff --git a/mmedit/models/editors/pggan/pggan.py b/mmedit/models/editors/pggan/pggan.py
index e56339c961..1e1ec7e08b 100644
--- a/mmedit/models/editors/pggan/pggan.py
+++ b/mmedit/models/editors/pggan/pggan.py
@@ -175,6 +175,8 @@ def forward(self,
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
+ if isinstance(inputs, dict) and 'img' in inputs:
+ gen_sample.gt_img = PixelData(data=inputs['img'][idx])
if isinstance(outputs, dict):
gen_sample.ema = EditDataSample(
fake_img=PixelData(data=outputs['ema'][idx]),
diff --git a/mmedit/models/editors/pix2pix/pix2pix.py b/mmedit/models/editors/pix2pix/pix2pix.py
index c251ca4d64..1efc1380d4 100644
--- a/mmedit/models/editors/pix2pix/pix2pix.py
+++ b/mmedit/models/editors/pix2pix/pix2pix.py
@@ -20,6 +20,7 @@ class Pix2Pix(BaseTranslationModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
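+ # weight of the generator's L1 (pixel) loss, read from loss_config
+ # (default: 100)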
+ self.pixel_loss_weight = self.loss_config.get('pixel_loss_weight', 100)
def forward_test(self, img, target_domain, **kwargs):
"""Forward function for testing.
@@ -92,6 +93,12 @@ def _get_gen_loss(self, outputs):
losses['loss_gan_g'] = F.binary_cross_entropy_with_logits(
fake_pred, 1. * torch.ones_like(fake_pred))
+ # L1 loss for generator
+ losses['loss_pixel'] = self.pixel_loss_weight * F.l1_loss(
+ outputs[f'real_{target_domain}'],
+ outputs[f'fake_{target_domain}'],
+ reduction='mean')
+
loss_g, log_vars_g = self.parse_losses(losses)
return loss_g, log_vars_g
@@ -141,11 +148,9 @@ def train_step(self, data, optim_wrapper=None):
# discriminator
set_requires_grad(self.discriminators, True)
# optimize
- disc_optimizer_wrapper.zero_grad()
loss_d, log_vars_d = self._get_disc_loss(outputs)
+ disc_optimizer_wrapper.update_params(loss_d)
log_vars.update(log_vars_d)
- disc_optimizer_wrapper.backward(loss_d)
- disc_optimizer_wrapper.step()
# generator, no updates to discriminator parameters.
gen_optimizer_wrapper = optim_wrapper['generators']
@@ -154,11 +159,9 @@ def train_step(self, data, optim_wrapper=None):
set_requires_grad(self.discriminators, False)
# optimize
with gen_optimizer_wrapper.optim_context(self.generators):
- gen_optimizer_wrapper.zero_grad()
loss_g, log_vars_g = self._get_gen_loss(outputs)
+ gen_optimizer_wrapper.update_params(loss_g)
log_vars.update(log_vars_g)
- gen_optimizer_wrapper.backward(loss_g)
- gen_optimizer_wrapper.step()
return log_vars
diff --git a/mmedit/models/editors/rdn/rdn_net.py b/mmedit/models/editors/rdn/rdn_net.py
index 29fa19ce70..24eeb912f8 100644
--- a/mmedit/models/editors/rdn/rdn_net.py
+++ b/mmedit/models/editors/rdn/rdn_net.py
@@ -15,6 +15,11 @@ class RDNNet(BaseModule):
'RDN-pytorch/blob/master/models.py'
Copyright (c) 2021, JaeYun Yeo, under MIT License.
+ Most of the implementation follows the reference code in:
+ 'https://github.com/sanghyun-son/EDSR-PyTorch.git'
+ 'EDSR-PyTorch/blob/master/src/model/rdn.py'
+ Copyright (c) 2017, sanghyun-son, under MIT License.
+
Args:
in_channels (int): Channel number of inputs.
out_channels (int): Channel number of outputs.
@@ -51,16 +56,15 @@ def __init__(self,
mid_channels, mid_channels, kernel_size=3, padding=3 // 2)
# residual dense blocks
- self.rdbs = nn.ModuleList(
- [RDB(self.mid_channels, self.channel_growth, self.num_layers)])
- for _ in range(self.num_blocks - 1):
+ self.rdbs = nn.ModuleList()
+ for _ in range(self.num_blocks):
self.rdbs.append(
- RDB(self.channel_growth, self.channel_growth, self.num_layers))
+ RDB(self.mid_channels, self.channel_growth, self.num_layers))
# global feature fusion
self.gff = nn.Sequential(
nn.Conv2d(
- self.channel_growth * self.num_blocks,
+ self.mid_channels * self.num_blocks,
self.mid_channels,
kernel_size=1),
nn.Conv2d(
@@ -165,7 +169,7 @@ def __init__(self, in_channels, channel_growth, num_layers):
# local feature fusion
self.lff = nn.Conv2d(
in_channels + channel_growth * num_layers,
- channel_growth,
+ in_channels,
kernel_size=1)
def forward(self, x):
diff --git a/mmedit/models/editors/real_basicvsr/real_basicvsr.py b/mmedit/models/editors/real_basicvsr/real_basicvsr.py
index bcff67569a..8bb6a140e5 100644
--- a/mmedit/models/editors/real_basicvsr/real_basicvsr.py
+++ b/mmedit/models/editors/real_basicvsr/real_basicvsr.py
@@ -58,6 +58,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
+ is_use_ema=False,
train_cfg=None,
test_cfg=None,
init_cfg=None,
@@ -72,6 +73,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=is_use_sharpened_gt_in_pixel,
is_use_sharpened_gt_in_percep=is_use_sharpened_gt_in_percep,
is_use_sharpened_gt_in_gan=is_use_sharpened_gt_in_gan,
+ is_use_ema=is_use_ema,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
diff --git a/mmedit/models/editors/real_esrgan/real_esrgan.py b/mmedit/models/editors/real_esrgan/real_esrgan.py
index 987e8a3cf0..b16b31477a 100644
--- a/mmedit/models/editors/real_esrgan/real_esrgan.py
+++ b/mmedit/models/editors/real_esrgan/real_esrgan.py
@@ -33,6 +33,8 @@ class RealESRGAN(SRGAN):
is_use_sharpened_gt_in_gan (bool, optional): Whether to use the
image sharpened by unsharp masking as the GT for adversarial loss.
Default: False.
+ is_use_ema (bool, optional): Whether to apply exponential moving average
+ on the network weights. Default: True.
train_cfg (dict): Config for training. Default: None.
You may change the training of gan by setting:
`disc_steps`: how many discriminator updates after one generate
@@ -56,6 +58,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
+ is_use_ema=True,
train_cfg=None,
test_cfg=None,
init_cfg=None,
@@ -75,12 +78,34 @@ def __init__(self,
self.is_use_sharpened_gt_in_pixel = is_use_sharpened_gt_in_pixel
self.is_use_sharpened_gt_in_percep = is_use_sharpened_gt_in_percep
self.is_use_sharpened_gt_in_gan = is_use_sharpened_gt_in_gan
+ self.is_use_ema = is_use_ema
if train_cfg is not None: # used for initializing from ema model
self.start_iter = train_cfg.get('start_iter', -1)
else:
self.start_iter = -1
+ def forward_tensor(self, inputs, data_samples=None, training=False):
+ """Forward tensor. Returns result of simple forward.
+
+ Args:
+ inputs (torch.Tensor): batch input tensor collated by
+ :attr:`data_preprocessor`.
+ data_samples (List[BaseDataElement], optional):
+ data samples collated by :attr:`data_preprocessor`.
+ training (bool): Whether in training mode. Default: False.
+
+ Returns:
+ Tensor: result of simple forward.
+ """
+
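+ # during training, or when EMA is disabled, use the raw generator;
+ # otherwise use the EMA generator for inference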
+ if training or not self.is_use_ema:
+ feats = self.generator(inputs)
+ else:
+ feats = self.generator_ema(inputs)
+
+ return feats
+
def g_step(self, batch_outputs, batch_gt_data):
"""G step of GAN: Calculate losses of generator.
diff --git a/mmedit/models/editors/singan/singan.py b/mmedit/models/editors/singan/singan.py
index 391f55c460..53e39c0834 100644
--- a/mmedit/models/editors/singan/singan.py
+++ b/mmedit/models/editors/singan/singan.py
@@ -16,7 +16,7 @@
from mmedit.models.utils import get_module_device
from mmedit.registry import MODELS
from mmedit.structures import EditDataSample, PixelData
-from mmedit.utils import SampleList
+from mmedit.utils import ForwardInputs, SampleList
from ...base_models import BaseGAN
from ...utils import set_requires_grad
@@ -171,31 +171,33 @@ def construct_fixed_noises(self):
self.fixed_noises.append(noise)
def forward(self,
- batch_inputs: dict,
+ inputs: ForwardInputs,
data_samples: Optional[list] = None,
mode=None) -> List[EditDataSample]:
- """Forward function for SinGAN. For SinGAN, `batch_inputs` should be a
- dict contains 'num_batches', 'mode' and other input arguments for the
+ """Forward function for SinGAN. For SinGAN, `inputs` should be a dict
+ containing 'num_batches', 'mode' and other input arguments for the
generator.
Args:
- batch_inputs (dict): Dict containing the necessary information
+ inputs (dict): Dict containing the necessary information
(e.g., noise, num_batches, mode) to generate image.
data_samples (Optional[list]): Data samples collated by
:attr:`data_preprocessor`. Defaults to None.
mode (Optional[str]): `mode` is not used in
:class:`BaseConditionalGAN`. Defaults to None.
"""
- sample_model = self._get_valid_model(batch_inputs)
# handle batch_inputs
- assert isinstance(batch_inputs, dict), (
- 'SinGAN only support dict type batch_inputs in forward function.')
- gen_kwargs = deepcopy(batch_inputs)
+ assert isinstance(inputs, dict), (
+ 'SinGAN only supports dict type inputs in forward function.')
+ gen_kwargs = deepcopy(inputs)
num_batches = gen_kwargs.pop('num_batches', 1)
assert num_batches == 1, (
'SinGAN only support \'num_batches\' as 1, but receive '
f'{num_batches}.')
+ sample_model = self._get_valid_model(inputs)
+ gen_kwargs.pop('sample_model', None) # remove sample_model
+
mode = gen_kwargs.pop('mode', mode)
mode = 'rand' if mode is None else mode
curr_scale = gen_kwargs.pop('curr_scale', self.curr_stage)
@@ -235,14 +237,27 @@ def forward(self,
gen_sample = EditDataSample()
if data_samples:
gen_sample.update(data_samples[idx])
- if isinstance(outputs, dict):
- gen_sample.ema = EditDataSample(
- fake_img=PixelData(data=outputs['ema'][idx]),
- sample_model='ema')
- gen_sample.orig = EditDataSample(
- fake_img=PixelData(data=outputs['orig'][idx]),
- sample_model='orig')
- gen_sample.sample_model = 'ema/orig'
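+ # 'ema/orig' sampling: pack the EMA and the original generator
+ # outputs into separate sub-samples of the data sample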
+ if sample_model == 'ema/orig':
+ for model_ in ['ema', 'orig']:
+ model_sample_ = EditDataSample()
+ output_ = outputs[model_]
+ if isinstance(output_, dict):
+ fake_img = PixelData(data=output_['fake_img'][idx])
+ prev_res_list = [
+ r[idx] for r in outputs[model_]['prev_res_list']
+ ]
+ model_sample_.prev_res_list = prev_res_list
+ else:
+ fake_img = PixelData(data=output_[idx])
+ model_sample_.fake_img = fake_img
+ model_sample_.sample_model = sample_model
+ gen_sample.set_field(model_sample_, model_)
+ elif isinstance(outputs, dict):
+ gen_sample.fake_img = PixelData(data=outputs['fake_img'][idx])
+ gen_sample.prev_res_list = [
+ r[idx] for r in outputs['prev_res_list']
+ ]
+ gen_sample.sample_model = sample_model
else:
gen_sample.fake_img = PixelData(data=outputs[idx])
gen_sample.sample_model = sample_model
diff --git a/mmedit/models/editors/stylegan1/stylegan1.py b/mmedit/models/editors/stylegan1/stylegan1.py
index c37be47b32..c203ad6103 100644
--- a/mmedit/models/editors/stylegan1/stylegan1.py
+++ b/mmedit/models/editors/stylegan1/stylegan1.py
@@ -16,6 +16,7 @@
@MODELS.register_module('StyleGANV1')
+@MODELS.register_module('StyleGANv1')
@MODELS.register_module()
class StyleGAN1(ProgressiveGrowingGAN):
"""Implementation of `A Style-Based Generator Architecture for Generative
diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py
index 4027c72013..df66126d39 100644
--- a/mmedit/models/losses/__init__.py
+++ b/mmedit/models/losses/__init__.py
@@ -18,13 +18,35 @@
from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss
__all__ = [
- 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss',
- 'MSECompositionLoss', 'CharbonnierCompLoss', 'GANLoss', 'GaussianBlur',
- 'GradientPenaltyLoss', 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss',
- 'mask_reduce_loss', 'DiscShiftLoss', 'MaskedTVLoss', 'GradientLoss',
- 'TransferalPerceptualLoss', 'LightCNNFeatureLoss', 'gradient_penalty_loss',
- 'r1_gradient_penalty_loss', 'gen_path_regularizer', 'FaceIdLoss',
- 'CLIPLoss', 'CLIPLossComps', 'DiscShiftLossComps', 'FaceIdLossComps',
- 'GANLossComps', 'GeneratorPathRegularizerComps',
- 'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss'
+ 'L1Loss',
+ 'MSELoss',
+ 'CharbonnierLoss',
+ 'L1CompositionLoss',
+ 'MSECompositionLoss',
+ 'CharbonnierCompLoss',
+ 'GANLoss',
+ 'GaussianBlur',
+ 'GradientPenaltyLoss',
+ 'PerceptualLoss',
+ 'PerceptualVGG',
+ 'reduce_loss',
+ 'mask_reduce_loss',
+ 'DiscShiftLoss',
+ 'MaskedTVLoss',
+ 'GradientLoss',
+ 'TransferalPerceptualLoss',
+ 'LightCNNFeatureLoss',
+ 'gradient_penalty_loss',
+ 'r1_gradient_penalty_loss',
+ 'gen_path_regularizer',
+ 'FaceIdLoss',
+ 'CLIPLoss',
+ 'CLIPLossComps',
+ 'DiscShiftLossComps',
+ 'FaceIdLossComps',
+ 'GANLossComps',
+ 'GeneratorPathRegularizerComps',
+ 'GradientPenaltyLossComps',
+ 'R1GradientPenaltyComps',
+ 'disc_shift_loss',
]
diff --git a/mmedit/models/utils/__init__.py b/mmedit/models/utils/__init__.py
index b579869d60..6d98db512e 100644
--- a/mmedit/models/utils/__init__.py
+++ b/mmedit/models/utils/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+
from .bbox_utils import extract_around_bbox, extract_bbox_patch
from .flow_warp import flow_warp
from .model_utils import (default_init_weights, generation_init_weights,
@@ -8,17 +9,9 @@
from .tensor_utils import get_unknown_tensor
__all__ = [
- 'default_init_weights',
- 'make_layer',
- 'flow_warp',
- 'generation_init_weights',
- 'set_requires_grad',
- 'extract_bbox_patch',
- 'extract_around_bbox',
- 'get_unknown_tensor',
- 'noise_sample_fn',
- 'label_sample_fn',
- 'get_valid_num_batches',
- 'get_valid_noise_size',
- 'get_module_device',
+ 'default_init_weights', 'make_layer', 'flow_warp',
+ 'generation_init_weights', 'set_requires_grad', 'extract_bbox_patch',
+ 'extract_around_bbox', 'get_unknown_tensor', 'noise_sample_fn',
+ 'label_sample_fn', 'get_valid_num_batches', 'get_valid_noise_size',
+ 'get_module_device'
]
diff --git a/mmedit/utils/__init__.py b/mmedit/utils/__init__.py
index 3fdaa607f2..1533e91fc7 100644
--- a/mmedit/utils/__init__.py
+++ b/mmedit/utils/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cli import modify_args
-from .img_utils import reorder_image, tensor2img, to_numpy
+from .img_utils import get_box_info, reorder_image, tensor2img, to_numpy
from .io_utils import MMEDIT_CACHE_DIR, download_from_url
# TODO replace with engine's API
from .logger import print_colored_log
@@ -17,5 +17,5 @@
'download_from_url', 'get_sampler', 'tensor2img', 'random_choose_unknown',
'add_gaussian_noise', 'adjust_gamma', 'make_coord', 'bbox2mask',
'brush_stroke_mask', 'get_irregular_mask', 'random_bbox', 'reorder_image',
- 'to_numpy'
+ 'to_numpy', 'get_box_info'
]
diff --git a/mmedit/utils/img_utils.py b/mmedit/utils/img_utils.py
index bf420910b5..ff5b2c1f03 100644
--- a/mmedit/utils/img_utils.py
+++ b/mmedit/utils/img_utils.py
@@ -125,3 +125,40 @@ def to_numpy(img, dtype=np.float64):
img = img.astype(dtype)
return img
+
+
+def get_box_info(pred_bbox, original_shape, final_size):
+ """
+
+ Args:
+ pred_bbox: The bounding box for the instance
+ original_shape: Original image shape
+ final_size: Size of the final output
+
+ Returns:
+ List: [L_pad, R_pad, T_pad, B_pad, rh, rw]
+ """
+ assert len(pred_bbox) == 4
+ resize_startx = int(pred_bbox[0] / original_shape[0] * final_size)
+ resize_starty = int(pred_bbox[1] / original_shape[1] * final_size)
+ resize_endx = int(pred_bbox[2] / original_shape[0] * final_size)
+ resize_endy = int(pred_bbox[3] / original_shape[1] * final_size)
+ rh = resize_endx - resize_startx
+ rw = resize_endy - resize_starty
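+ # make sure the resized box spans at least one pixel in each dimension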
+ if rh < 1:
+ if final_size - resize_endx > 1:
+ resize_endx += 1
+ else:
+ resize_startx -= 1
+ rh = 1
+ if rw < 1:
+ if final_size - resize_endy > 1:
+ resize_endy += 1
+ else:
+ resize_starty -= 1
+ rw = 1
+ L_pad = resize_startx
+ R_pad = final_size - resize_endx
+ T_pad = resize_starty
+ B_pad = final_size - resize_endy
+ return [L_pad, R_pad, T_pad, B_pad, rh, rw]
diff --git a/mmedit/utils/io_utils.py b/mmedit/utils/io_utils.py
index de965c1908..bd6128f707 100644
--- a/mmedit/utils/io_utils.py
+++ b/mmedit/utils/io_utils.py
@@ -3,7 +3,7 @@
import os
import click
-import mmcv
+import mmengine
import requests
import torch.distributed as dist
from mmengine.dist import get_dist_info
@@ -67,7 +67,7 @@ def download_from_url(url,
if rank == 0:
# mkdir
_dir = os.path.dirname(dest_path)
- mmcv.mkdir_or_exist(_dir)
+ mmengine.mkdir_or_exist(_dir)
if hash_prefix is not None:
sha256 = hashlib.sha256()
diff --git a/mmedit/version.py b/mmedit/version.py
index b367889550..da34a5e738 100644
--- a/mmedit/version.py
+++ b/mmedit/version.py
@@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '1.0.0rc1'
+__version__ = '1.0.0rc2'
def parse_version_info(version_str):
diff --git a/mmedit/visualization/gen_visualizer.py b/mmedit/visualization/gen_visualizer.py
index 13d976e2b3..0ffdc11cde 100644
--- a/mmedit/visualization/gen_visualizer.py
+++ b/mmedit/visualization/gen_visualizer.py
@@ -54,7 +54,7 @@ def _post_process_image(image: Tensor,
mean: mean_std_type = None,
std: mean_std_type = None) -> Tensor:
"""Post process images. First convert image to `rgb` order. And then
- de-norm image to fid `mean` and `std` if `mean` and `std` is passed.
+ de-normalize the image with `mean` and `std` if they are passed.
Args:
image (Tensor): Image to pose process.
@@ -73,7 +73,7 @@ def _post_process_image(image: Tensor,
image = image[:, [2, 1, 0], ...]
if mean is not None and std is not None:
image = image * std + mean
- return image
+ return image.clamp(0, 255)
@staticmethod
def _get_n_row_and_padding(samples: Tuple[dict, Tensor],
diff --git a/model-index.yml b/model-index.yml
index 6373a6dc4d..b0ce511cac 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -20,6 +20,7 @@ Import:
- configs/global_local/metafile.yml
- configs/iconvsr/metafile.yml
- configs/indexnet/metafile.yml
+- configs/inst_colorization/metafile.yml
- configs/liif/metafile.yml
- configs/lsgan/metafile.yml
- configs/partial_conv/metafile.yml
diff --git a/requirements/optional.txt b/requirements/optional.txt
new file mode 100644
index 0000000000..95688066cc
--- /dev/null
+++ b/requirements/optional.txt
@@ -0,0 +1 @@
+mmdet
diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py
new file mode 100644
index 0000000000..3e574633bb
--- /dev/null
+++ b/tests/test_apis/test_colorization_inference.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import unittest
+
+import pytest
+import torch
+from mmengine import Config
+from mmengine.runner import load_checkpoint
+
+from mmedit.apis import colorization_inference
+from mmedit.registry import MODELS
+from mmedit.utils import register_all_modules, tensor2img
+
+
+@pytest.mark.skipif(
+ 'win' in platform.system().lower() and 'cu' in torch.__version__,
+ reason='skip on windows-cuda due to limited RAM.')
+def test_colorization_inference():
+ register_all_modules()
+
+ if not torch.cuda.is_available():
+ # RoI pooling is only supported on GPU
+ return unittest.skip('test requires GPU and torch+cuda')
+
+ if torch.cuda.is_available():
+ device = torch.device('cuda', 0)
+ else:
+ device = torch.device('cpu')
+
+ config = osp.join(
+ osp.dirname(__file__),
+ '../..',
+ 'configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py' # noqa
+ )
+ checkpoint = None
+
+ cfg = Config.fromfile(config)
+ model = MODELS.build(cfg.model)
+
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint)
+
+ model.cfg = cfg
+ model.to(device)
+ model.eval()
+
+ img_path = osp.join(
+ osp.dirname(__file__), '..', 'data/image/img_root/horse/horse.jpeg')
+
+ result = colorization_inference(model, img_path)
+ assert tensor2img(result)[..., ::-1].shape == (256, 256, 3)
diff --git a/tests/test_datasets/test_singan_dataset.py b/tests/test_datasets/test_singan_dataset.py
new file mode 100644
index 0000000000..a872d4242f
--- /dev/null
+++ b/tests/test_datasets/test_singan_dataset.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmedit.datasets import SinGANDataset
+from mmedit.utils import register_all_modules
+
+register_all_modules()
+
+
+class TestSinGANDataset(object):
+
+ @classmethod
+ def setup_class(cls):
+ cls.imgs_root = osp.join(
+ osp.dirname(osp.dirname(__file__)), 'data/image/gt/baboon.png')
+ cls.min_size = 25
+ cls.max_size = 250
+ cls.scale_factor_init = 0.75
+ cls.pipeline = [dict(type='PackEditInputs', pack_all=True)]
+
+ def test_singan_dataset(self):
+ dataset = SinGANDataset(
+ self.imgs_root,
+ min_size=self.min_size,
+ max_size=self.max_size,
+ scale_factor_init=self.scale_factor_init,
+ pipeline=self.pipeline)
+ assert len(dataset) == 1000000
+
+ data_dict = dataset[0]['inputs']
+ assert all([f'real_scale{i}' in data_dict for i in range(10)])
diff --git a/tests/test_datasets/test_transforms/test_crop.py b/tests/test_datasets/test_transforms/test_crop.py
index b755e264ff..703e7657c6 100644
--- a/tests/test_datasets/test_transforms/test_crop.py
+++ b/tests/test_datasets/test_transforms/test_crop.py
@@ -1,10 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
+import os.path as osp
+import unittest
+import cv2
import numpy as np
import pytest
+import torch
-from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, ModCrop,
+from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop,
+ InstanceCrop, ModCrop,
PairedRandomCrop, RandomResizedCrop)
@@ -350,3 +355,31 @@ def test_crop_like():
assert results['gt'].shape == (512, 512)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0]))
assert sum_diff < 1e-6
+
+
+def test_instance_crop():
+
+ if not torch.cuda.is_available():
+ # RoI pooling is only supported on GPU
+ return unittest.skip('test requires GPU and torch+cuda')
+
+ croper = InstanceCrop(
+ key='img',
+ finesize=256,
+ box_num_upbound=2,
+ config_file='mmdet::mask_rcnn/'
+ 'mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py') # noqa
+
+ img_path = osp.join(
+ osp.dirname(__file__), '..', '..',
+ 'data/image/img_root/horse/horse.jpeg')
+ img = cv2.imread(img_path)
+ data = dict(img=img, ori_img_shape=img.shape, img_channel_order='rgb')
+
+ results = croper(data)
+
+ assert 'empty_box' in results
+ cropped_img = results['cropped_img']
+ if results['empty_box']:
+ assert len(cropped_img) == 0
+ assert len(cropped_img) <= 2
diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py
index 55b633e2e8..c9ac37ddf1 100644
--- a/tests/test_datasets/test_transforms/test_formatting.py
+++ b/tests/test_datasets/test_transforms/test_formatting.py
@@ -4,7 +4,8 @@
from mmcv.transforms import to_tensor
from mmedit.datasets.transforms import PackEditInputs, ToTensor
-from mmedit.datasets.transforms.formatting import images_to_tensor
+from mmedit.datasets.transforms.formatting import (can_convert_to_image,
+ images_to_tensor)
from mmedit.structures.edit_data_sample import EditDataSample
@@ -117,6 +118,18 @@ def test_pack_edit_inputs():
assert data_sample.metainfo['img_shape'] == (64, 64)
assert data_sample.metainfo['a'] == 'b'
+ # test pack_all
+ pack_edit_inputs = PackEditInputs(pack_all=True)
+ results = ori_results.copy()
+ packed_results = pack_edit_inputs(results)
+ print(packed_results['inputs'].keys())
+
+ target_keys = [
+ 'img', 'gt', 'img_lq', 'ref', 'ref_lq', 'mask', 'gt_heatmap',
+ 'gt_unsharp', 'merged', 'trimap', 'alpha', 'fg', 'bg'
+ ]
+ assert all([k in target_keys for k in packed_results['inputs']])
+
def test_to_tensor():
@@ -135,3 +148,14 @@ def test_to_tensor():
assert set(keys).issubset(results.keys())
for _, v in results.items():
assert isinstance(v, torch.Tensor)
+
+
+def test_can_convert_to_image():
+ values = [
+ np.random.rand(64, 64, 3),
+ [np.random.rand(64, 61, 3),
+ np.random.rand(64, 61, 3)], (64, 64), 'b'
+ ]
+ targets = [True, True, False, False]
+ for val, tar in zip(values, targets):
+ assert can_convert_to_image(val) == tar
diff --git a/tests/test_datasets/test_transforms/test_generate_frame_indices.py b/tests/test_datasets/test_transforms/test_generate_frame_indices.py
index 4106c89ac4..0cc4a6dc48 100644
--- a/tests/test_datasets/test_transforms/test_generate_frame_indices.py
+++ b/tests/test_datasets/test_transforms/test_generate_frame_indices.py
@@ -221,3 +221,15 @@ def test_frame_index_generation_for_recurrent(self):
frame_index_generator = GenerateSegmentIndices(interval_list=[10])
with pytest.raises(ValueError):
frame_index_generator(copy.deepcopy(results))
+
+ # num_input_frames is None
+ results = dict(
+ img_path='fake_img_root',
+ gt_path='fake_gt_root',
+ key='000',
+ num_input_frames=None,
+ sequence_length=100)
+
+ frame_index_generator = GenerateSegmentIndices(interval_list=[1])
+ rlt = frame_index_generator(copy.deepcopy(results))
+ assert len(rlt['img_path']) == 100
diff --git a/tests/test_evaluation/test_evaluator.py b/tests/test_evaluation/test_evaluator.py
index ea9290efef..50d790f7c9 100644
--- a/tests/test_evaluation/test_evaluator.py
+++ b/tests/test_evaluation/test_evaluator.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
from unittest import TestCase
from unittest.mock import MagicMock, patch
@@ -77,6 +78,26 @@ def test_prepare_samplers(self):
self.assertEqual(metric_sampler_list[0][1].max_length, 11)
self.assertEqual(len(metric_sampler_list[0][1]), 6)
+ # test prepare samplers with different `sample_model`
+ cfg = deepcopy(self.metrics)
+ cfg.append(
+ dict(
+ type='FrechetInceptionDistance',
+ fake_nums=12,
+ inception_style='pytorch',
+ sample_model='ema'))
+ evaluator = GenEvaluator(cfg)
+
+ # mock model and dataloader
+ model = MagicMock()
+ model.data_preprocessor.device = 'cpu'
+
+ dataloader = MagicMock()
+ dataloader.batch_size = 2
+
+ metric_sampler_list = evaluator.prepare_samplers(model, dataloader)
+ self.assertEqual(len(metric_sampler_list), 3)
+
@patch(is_loading_str, loading_mock)
@patch(fid_loading_str, loading_mock)
def test_process(self):
diff --git a/tests/test_evaluation/test_metrics/test_connectivity_error.py b/tests/test_evaluation/test_metrics/test_connectivity_error.py
new file mode 100644
index 0000000000..29ce9be702
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_connectivity_error.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.datasets.transforms import LoadImageFromFile
+from mmedit.evaluation.metrics import ConnectivityError
+
+
+class TestMattingMetrics:
+
+ @classmethod
+ def setup_class(cls):
+ # Make sure these values are immutable across different test cases.
+
+ # This test depends on the interface of loading
+ # if loading is changed, data should be changed accordingly.
+ test_path = Path(__file__).parent.parent.parent
+ alpha_path = (
+ test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg')
+
+ results = dict(alpha_path=alpha_path)
+ config = dict(key='alpha')
+ image_loader = LoadImageFromFile(**config)
+ results = image_loader(results)
+ assert results['alpha'].ndim == 3
+
+ gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
+ trimap = np.zeros((32, 32), dtype=np.uint8)
+ trimap[:16, :16] = 128
+ trimap[16:, 16:] = 255
+ # non-masked pred_alpha
+ pred_alpha = torch.zeros((32, 32), dtype=torch.uint8)
+ # masked pred_alpha
+ masked_pred_alpha = pred_alpha.clone()
+ masked_pred_alpha[trimap == 0] = 0
+ masked_pred_alpha[trimap == 255] = 255
+
+ gt_alpha = gt_alpha[..., None]
+ trimap = trimap[..., None]
+ # pred_alpha = pred_alpha.unsqueeze(0)
+ # masked_pred_alpha = masked_pred_alpha.unsqueeze(0)
+
+ cls.data_batch = [{
+ 'inputs': [],
+ 'data_samples': {
+ 'ori_trimap': trimap,
+ 'ori_alpha': gt_alpha,
+ },
+ }]
+
+ cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch]
+
+ cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}]
+ # pred_alpha should be masked by trimap before evaluation
+ cls.bad_preds1 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds1, cls.bad_preds1_):
+ d['output'] = p
+
+ cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}]
+ # pred_alpha should be 3 dimensional
+ cls.bad_preds2 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds2, cls.bad_preds2_):
+ d['output'] = p
+
+ cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}]
+ cls.good_preds = copy.deepcopy((cls.data_samples))
+ for d, p in zip(cls.good_preds, cls.good_preds_):
+ d['output'] = p
+
+ def test_connectivity_error(self):
+ """Test connectivity error for evaluating predicted alpha matte."""
+
+ data_batch, bad_pred1, bad_pred2, good_pred = (
+ self.data_batch,
+ self.bad_preds1,
+ self.bad_preds2,
+ self.good_preds,
+ )
+
+ conn_err = ConnectivityError()
+
+ with pytest.raises(ValueError):
+ conn_err.process(data_batch, bad_pred1)
+
+ with pytest.raises(ValueError):
+ conn_err.process(data_batch, bad_pred2)
+
+ # process 2 batches
+ conn_err.process(data_batch, good_pred)
+ conn_err.process(data_batch, good_pred)
+
+ assert conn_err.results == [
+ {
+ 'conn_err': 0.256,
+ },
+ {
+ 'conn_err': 0.256,
+ },
+ ]
+
+ res = conn_err.compute_metrics(conn_err.results)
+
+ assert list(res.keys()) == ['ConnectivityError']
+ assert np.allclose(res['ConnectivityError'], 0.256)
diff --git a/tests/test_evaluation/test_metrics/test_matting.py b/tests/test_evaluation/test_metrics/test_gradient_error.py
similarity index 54%
rename from tests/test_evaluation/test_metrics/test_matting.py
rename to tests/test_evaluation/test_metrics/test_gradient_error.py
index 6777912a3b..f6720585c5 100644
--- a/tests/test_evaluation/test_metrics/test_matting.py
+++ b/tests/test_evaluation/test_metrics/test_gradient_error.py
@@ -7,8 +7,7 @@
import torch
from mmedit.datasets.transforms import LoadImageFromFile
-from mmedit.evaluation.metrics import (SAD, ConnectivityError, GradientError,
- MattingMSE)
+from mmedit.evaluation.metrics import GradientError
class TestMattingMetrics:
@@ -72,78 +71,6 @@ def setup_class(cls):
for d, p in zip(cls.good_preds, cls.good_preds_):
d['output'] = p
- def test_sad(self):
- """Test SAD for evaluating predicted alpha matte."""
-
- data_batch, bad_pred1, bad_pred2, good_pred = (
- self.data_batch,
- self.bad_preds1,
- self.bad_preds2,
- self.good_preds,
- )
-
- sad = SAD()
-
- with pytest.raises(ValueError):
- sad.process(data_batch, bad_pred1)
-
- with pytest.raises(ValueError):
- sad.process(data_batch, bad_pred2)
-
- # process 2 batches
- sad.process(data_batch, good_pred)
- sad.process(data_batch, good_pred)
-
- assert sad.results == [
- {
- 'sad': 0.768,
- },
- {
- 'sad': 0.768,
- },
- ]
-
- res = sad.compute_metrics(sad.results)
-
- assert list(res.keys()) == ['SAD']
- np.testing.assert_almost_equal(res['SAD'], 0.768)
-
- def test_mse(self):
- """Test MattingMSE for evaluating predicted alpha matte."""
-
- data_batch, bad_pred1, bad_pred2, good_pred = (
- self.data_batch,
- self.bad_preds1,
- self.bad_preds2,
- self.good_preds,
- )
-
- mse = MattingMSE()
-
- with pytest.raises(ValueError):
- mse.process(data_batch, bad_pred1)
-
- with pytest.raises(ValueError):
- mse.process(data_batch, bad_pred2)
-
- # process 2 batches
- mse.process(data_batch, good_pred)
- mse.process(data_batch, good_pred)
-
- assert mse.results == [
- {
- 'mse': 3.0,
- },
- {
- 'mse': 3.0,
- },
- ]
-
- res = mse.compute_metrics(mse.results)
-
- assert list(res.keys()) == ['MattingMSE']
- np.testing.assert_almost_equal(res['MattingMSE'], 3.0)
-
def test_gradient_error(self):
"""Test gradient error for evaluating predicted alpha matte."""
@@ -176,39 +103,3 @@ def test_gradient_error(self):
assert list(res.keys()) == ['GradientError']
np.testing.assert_almost_equal(el['grad_err'], 0.0028887)
# assert np.allclose(res['GradientError'], 0.0028887)
-
- def test_connectivity_error(self):
- """Test connectivity error for evaluating predicted alpha matte."""
-
- data_batch, bad_pred1, bad_pred2, good_pred = (
- self.data_batch,
- self.bad_preds1,
- self.bad_preds2,
- self.good_preds,
- )
-
- conn_err = ConnectivityError()
-
- with pytest.raises(ValueError):
- conn_err.process(data_batch, bad_pred1)
-
- with pytest.raises(ValueError):
- conn_err.process(data_batch, bad_pred2)
-
- # process 2 batches
- conn_err.process(data_batch, good_pred)
- conn_err.process(data_batch, good_pred)
-
- assert conn_err.results == [
- {
- 'conn_err': 0.256,
- },
- {
- 'conn_err': 0.256,
- },
- ]
-
- res = conn_err.compute_metrics(conn_err.results)
-
- assert list(res.keys()) == ['ConnectivityError']
- assert np.allclose(res['ConnectivityError'], 0.256)
diff --git a/tests/test_evaluation/test_metrics/test_mae.py b/tests/test_evaluation/test_metrics/test_mae.py
new file mode 100644
index 0000000000..ef03773d1b
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_mae.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from mmedit.evaluation.metrics import MAE
+
+
+class TestPixelMetrics:
+
+ @classmethod
+ def setup_class(cls):
+
+ mask = np.ones((32, 32, 3)) * 2
+ mask[:16] *= 0
+ gt = np.ones((32, 32, 3)) * 2
+ data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+ cls.data_batch = [dict(data_samples=data_sample)]
+ cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+ cls.data_batch.append(
+ dict(
+ data_samples=dict(
+ gt_img=torch.from_numpy(gt),
+ mask=torch.from_numpy(mask),
+ img_channel_order='bgr')))
+ cls.predictions.append({
+ k: torch.from_numpy(deepcopy(v))
+ for (k, v) in cls.predictions[0].items()
+ })
+
+ for d, p in zip(cls.data_batch, cls.predictions):
+ d['output'] = p
+ cls.predictions = cls.data_batch
+
+ def test_mae(self):
+
+ # Single MAE
+ mae = MAE()
+ mae.process(self.data_batch, self.predictions)
+ result = mae.compute_metrics(mae.results)
+ assert 'MAE' in result
+ np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
+
+ # Masked MAE
+ mae = MAE(mask_key='mask', prefix='MAE')
+ mae.process(self.data_batch, self.predictions)
+ result = mae.compute_metrics(mae.results)
+ assert 'MAE' in result
+ np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
diff --git a/tests/test_evaluation/test_metrics/test_matting_mse.py b/tests/test_evaluation/test_metrics/test_matting_mse.py
new file mode 100644
index 0000000000..532f248928
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_matting_mse.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.datasets.transforms import LoadImageFromFile
+from mmedit.evaluation.metrics import MattingMSE
+
+
+class TestMattingMetrics:
+
+ @classmethod
+ def setup_class(cls):
+ # Make sure these values are immutable across different test cases.
+
+ # This test depends on the interface of loading
+ # if loading is changed, data should be changed accordingly.
+ test_path = Path(__file__).parent.parent.parent
+ alpha_path = (
+ test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg')
+
+ results = dict(alpha_path=alpha_path)
+ config = dict(key='alpha')
+ image_loader = LoadImageFromFile(**config)
+ results = image_loader(results)
+ assert results['alpha'].ndim == 3
+
+ gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
+ trimap = np.zeros((32, 32), dtype=np.uint8)
+ trimap[:16, :16] = 128
+ trimap[16:, 16:] = 255
+ # non-masked pred_alpha
+ pred_alpha = torch.zeros((32, 32), dtype=torch.uint8)
+ # masked pred_alpha
+ masked_pred_alpha = pred_alpha.clone()
+ masked_pred_alpha[trimap == 0] = 0
+ masked_pred_alpha[trimap == 255] = 255
+
+ gt_alpha = gt_alpha[..., None]
+ trimap = trimap[..., None]
+ # pred_alpha = pred_alpha.unsqueeze(0)
+ # masked_pred_alpha = masked_pred_alpha.unsqueeze(0)
+
+ cls.data_batch = [{
+ 'inputs': [],
+ 'data_samples': {
+ 'ori_trimap': trimap,
+ 'ori_alpha': gt_alpha,
+ },
+ }]
+
+ cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch]
+
+ cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}]
+ # pred_alpha should be masked by trimap before evaluation
+ cls.bad_preds1 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds1, cls.bad_preds1_):
+ d['output'] = p
+
+ cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}]
+ # pred_alpha should be 3 dimensional
+ cls.bad_preds2 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds2, cls.bad_preds2_):
+ d['output'] = p
+
+ cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}]
+ cls.good_preds = copy.deepcopy((cls.data_samples))
+ for d, p in zip(cls.good_preds, cls.good_preds_):
+ d['output'] = p
+
+ def test_mse(self):
+ """Test MattingMSE for evaluating predicted alpha matte."""
+
+ data_batch, bad_pred1, bad_pred2, good_pred = (
+ self.data_batch,
+ self.bad_preds1,
+ self.bad_preds2,
+ self.good_preds,
+ )
+
+ mse = MattingMSE()
+
+ with pytest.raises(ValueError):
+ mse.process(data_batch, bad_pred1)
+
+ with pytest.raises(ValueError):
+ mse.process(data_batch, bad_pred2)
+
+ # process 2 batches
+ mse.process(data_batch, good_pred)
+ mse.process(data_batch, good_pred)
+
+ assert mse.results == [
+ {
+ 'mse': 3.0,
+ },
+ {
+ 'mse': 3.0,
+ },
+ ]
+
+ res = mse.compute_metrics(mse.results)
+
+ assert list(res.keys()) == ['MattingMSE']
+ np.testing.assert_almost_equal(res['MattingMSE'], 3.0)
diff --git a/tests/test_evaluation/test_metrics/test_metrics_utils.py b/tests/test_evaluation/test_metrics/test_metrics_utils.py
index 2cf7417815..290682a3b0 100644
--- a/tests/test_evaluation/test_metrics/test_metrics_utils.py
+++ b/tests/test_evaluation/test_metrics/test_metrics_utils.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
+import pytest
from mmedit.evaluation.metrics import metrics_utils
+from mmedit.evaluation.metrics.metrics_utils import reorder_image
def test_average():
@@ -29,3 +31,21 @@ def test_obtain_data():
data_sample = {'data_samples': {key: img}}
result = metrics_utils.obtain_data(data_sample, key)
assert not (result - img).any()
+
+
+def test_reorder_image():
+ img_hw = np.ones((32, 32))
+ img_hwc = np.ones((32, 32, 3))
+ img_chw = np.ones((3, 32, 32))
+
+ with pytest.raises(ValueError):
+ reorder_image(img_hw, 'HH')
+
+ output = reorder_image(img_hw)
+ assert output.shape == (32, 32, 1)
+
+ output = reorder_image(img_hwc)
+ assert output.shape == (32, 32, 3)
+
+ output = reorder_image(img_chw, input_order='CHW')
+ assert output.shape == (32, 32, 3)
diff --git a/tests/test_evaluation/test_metrics/test_mse.py b/tests/test_evaluation/test_metrics/test_mse.py
new file mode 100644
index 0000000000..5b0fdfeeab
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_mse.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from mmedit.evaluation.metrics import MSE
+
+
+class TestPixelMetrics:
+
+ @classmethod
+ def setup_class(cls):
+
+ mask = np.ones((32, 32, 3)) * 2
+ mask[:16] *= 0
+ gt = np.ones((32, 32, 3)) * 2
+ data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+ cls.data_batch = [dict(data_samples=data_sample)]
+ cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+ cls.data_batch.append(
+ dict(
+ data_samples=dict(
+ gt_img=torch.from_numpy(gt),
+ mask=torch.from_numpy(mask),
+ img_channel_order='bgr')))
+ cls.predictions.append({
+ k: torch.from_numpy(deepcopy(v))
+ for (k, v) in cls.predictions[0].items()
+ })
+
+ for d, p in zip(cls.data_batch, cls.predictions):
+ d['output'] = p
+ cls.predictions = cls.data_batch
+
+ def test_mse(self):
+
+ # Single MSE
+ mse = MSE()
+ mse.process(self.data_batch, self.predictions)
+ result = mse.compute_metrics(mse.results)
+ assert 'MSE' in result
+ np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
+
+ # Masked MSE
+ mse = MSE(mask_key='mask', prefix='MSE')
+ mse.process(self.data_batch, self.predictions)
+ result = mse.compute_metrics(mse.results)
+ assert 'MSE' in result
+ np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
diff --git a/tests/test_evaluation/test_metrics/test_pixel_metrics.py b/tests/test_evaluation/test_metrics/test_pixel_metrics.py
deleted file mode 100644
index 32476fb24c..0000000000
--- a/tests/test_evaluation/test_metrics/test_pixel_metrics.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy
-
-import numpy as np
-import pytest
-import torch
-
-from mmedit.evaluation.metrics import MAE, MSE, PSNR, SNR, psnr
-from mmedit.evaluation.metrics.metrics_utils import reorder_image
-
-
-class TestPixelMetrics:
-
- @classmethod
- def setup_class(cls):
-
- mask = np.ones((32, 32, 3)) * 2
- mask[:16] *= 0
- gt = np.ones((32, 32, 3)) * 2
- data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
- cls.data_batch = [dict(data_samples=data_sample)]
- cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
-
- cls.data_batch.append(
- dict(
- data_samples=dict(
- gt_img=torch.from_numpy(gt),
- mask=torch.from_numpy(mask),
- img_channel_order='bgr')))
- cls.predictions.append({
- k: torch.from_numpy(deepcopy(v))
- for (k, v) in cls.predictions[0].items()
- })
-
- for d, p in zip(cls.data_batch, cls.predictions):
- d['output'] = p
- cls.predictions = cls.data_batch
-
- def test_mae(self):
-
- # Single MAE
- mae = MAE()
- mae.process(self.data_batch, self.predictions)
- result = mae.compute_metrics(mae.results)
- assert 'MAE' in result
- np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
-
- # Masked MAE
- mae = MAE(mask_key='mask', prefix='MAE')
- mae.process(self.data_batch, self.predictions)
- result = mae.compute_metrics(mae.results)
- assert 'MAE' in result
- np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
-
- def test_mse(self):
-
- # Single MSE
- mae = MSE()
- mae.process(self.data_batch, self.predictions)
- result = mae.compute_metrics(mae.results)
- assert 'MSE' in result
- np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
-
- # Masked MSE
- mae = MSE(mask_key='mask', prefix='MSE')
- mae.process(self.data_batch, self.predictions)
- result = mae.compute_metrics(mae.results)
- assert 'MSE' in result
- np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
-
- def test_psnr(self):
-
- psnr_ = PSNR()
- psnr_.process(self.data_batch, self.predictions)
- result = psnr_.compute_metrics(psnr_.results)
- assert 'PSNR' in result
- np.testing.assert_almost_equal(result['PSNR'], 48.1308036)
-
- def test_snr(self):
-
- snr_ = SNR()
- snr_.process(self.data_batch, self.predictions)
- result = snr_.compute_metrics(snr_.results)
- assert 'SNR' in result
- np.testing.assert_almost_equal(result['SNR'], 6.0206001996994)
-
-
-def test_reorder_image():
- img_hw = np.ones((32, 32))
- img_hwc = np.ones((32, 32, 3))
- img_chw = np.ones((3, 32, 32))
-
- with pytest.raises(ValueError):
- reorder_image(img_hw, 'HH')
-
- output = reorder_image(img_hw)
- assert output.shape == (32, 32, 1)
-
- output = reorder_image(img_hwc)
- assert output.shape == (32, 32, 3)
-
- output = reorder_image(img_chw, input_order='CHW')
- assert output.shape == (32, 32, 3)
-
-
-def test_psnr():
- img_hw_1 = np.ones((32, 32))
- img_hwc_1 = np.ones((32, 32, 3))
- img_chw_1 = np.ones((3, 32, 32))
- img_hw_2 = np.ones((32, 32)) * 2
- img_hwc_2 = np.ones((32, 32, 3)) * 2
- img_chw_2 = np.ones((3, 32, 32)) * 2
-
- with pytest.raises(ValueError):
- psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH')
-
- with pytest.raises(ValueError):
- psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC')
-
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
-
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
-
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y')
- np.testing.assert_almost_equal(psnr_result, 49.4527218)
-
- # test float inf
- psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0)
- assert psnr_result == float('inf')
-
- # test uint8
- img_hw_1 = np.zeros((32, 32), dtype=np.uint8)
- img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
- assert psnr_result == 0
-
-
-def test_snr():
- img_hw_1 = np.ones((32, 32))
- img_hwc_1 = np.ones((32, 32, 3))
- img_chw_1 = np.ones((3, 32, 32))
- img_hw_2 = np.ones((32, 32)) * 2
- img_hwc_2 = np.ones((32, 32, 3)) * 2
- img_chw_2 = np.ones((3, 32, 32)) * 2
-
- with pytest.raises(ValueError):
- psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH')
-
- with pytest.raises(ValueError):
- psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC')
-
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
-
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW')
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
-
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None)
- np.testing.assert_almost_equal(psnr_result, 48.1308036)
- psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y')
- np.testing.assert_almost_equal(psnr_result, 49.4527218)
-
- # test float inf
- psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0)
- assert psnr_result == float('inf')
-
- # test uint8
- img_hw_1 = np.zeros((32, 32), dtype=np.uint8)
- img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255
- psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
- assert psnr_result == 0
-
-
-t = TestPixelMetrics()
-t.setup_class()
-t.test_mae()
-t.test_mse()
-t.test_psnr()
-t.test_snr()
diff --git a/tests/test_evaluation/test_metrics/test_psnr.py b/tests/test_evaluation/test_metrics/test_psnr.py
new file mode 100644
index 0000000000..b220c651aa
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_psnr.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.evaluation.metrics import PSNR, psnr
+
+
+class TestPixelMetrics:
+
+ @classmethod
+ def setup_class(cls):
+
+ mask = np.ones((32, 32, 3)) * 2
+ mask[:16] *= 0
+ gt = np.ones((32, 32, 3)) * 2
+ data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+ cls.data_batch = [dict(data_samples=data_sample)]
+ cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+ cls.data_batch.append(
+ dict(
+ data_samples=dict(
+ gt_img=torch.from_numpy(gt),
+ mask=torch.from_numpy(mask),
+ img_channel_order='bgr')))
+ cls.predictions.append({
+ k: torch.from_numpy(deepcopy(v))
+ for (k, v) in cls.predictions[0].items()
+ })
+
+ for d, p in zip(cls.data_batch, cls.predictions):
+ d['output'] = p
+ cls.predictions = cls.data_batch
+
+ def test_psnr(self):
+
+ psnr_ = PSNR()
+ psnr_.process(self.data_batch, self.predictions)
+ result = psnr_.compute_metrics(psnr_.results)
+ assert 'PSNR' in result
+ np.testing.assert_almost_equal(result['PSNR'], 48.1308036)
+
+
+def test_psnr():
+ img_hw_1 = np.ones((32, 32))
+ img_hwc_1 = np.ones((32, 32, 3))
+ img_chw_1 = np.ones((3, 32, 32))
+ img_hw_2 = np.ones((32, 32)) * 2
+ img_hwc_2 = np.ones((32, 32, 3)) * 2
+ img_chw_2 = np.ones((3, 32, 32)) * 2
+
+ with pytest.raises(ValueError):
+ psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH')
+
+ with pytest.raises(ValueError):
+ psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC')
+
+ psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+ psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC')
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+ psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW')
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+
+ psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2)
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+ psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC')
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+ psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW')
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+
+ psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None)
+ np.testing.assert_almost_equal(psnr_result, 48.1308036)
+ psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y')
+ np.testing.assert_almost_equal(psnr_result, 49.4527218)
+
+ # test float inf
+ psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0)
+ assert psnr_result == float('inf')
+
+ # test uint8
+ img_hw_1 = np.zeros((32, 32), dtype=np.uint8)
+ img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255
+ psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0)
+ assert psnr_result == 0
diff --git a/tests/test_evaluation/test_metrics/test_sad.py b/tests/test_evaluation/test_metrics/test_sad.py
new file mode 100644
index 0000000000..f05dd3d5b8
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_sad.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.datasets.transforms import LoadImageFromFile
+from mmedit.evaluation.metrics import SAD
+
+
+class TestMattingMetrics:
+
+ @classmethod
+ def setup_class(cls):
+ # Make sure these values are immutable across different test cases.
+
+ # This test depends on the interface of loading
+ # if loading is changed, data should be changed accordingly.
+ test_path = Path(__file__).parent.parent.parent
+ alpha_path = (
+ test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg')
+
+ results = dict(alpha_path=alpha_path)
+ config = dict(key='alpha')
+ image_loader = LoadImageFromFile(**config)
+ results = image_loader(results)
+ assert results['alpha'].ndim == 3
+
+ gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
+ trimap = np.zeros((32, 32), dtype=np.uint8)
+ trimap[:16, :16] = 128
+ trimap[16:, 16:] = 255
+ # non-masked pred_alpha
+ pred_alpha = torch.zeros((32, 32), dtype=torch.uint8)
+ # masked pred_alpha
+ masked_pred_alpha = pred_alpha.clone()
+ masked_pred_alpha[trimap == 0] = 0
+ masked_pred_alpha[trimap == 255] = 255
+
+ gt_alpha = gt_alpha[..., None]
+ trimap = trimap[..., None]
+ # pred_alpha = pred_alpha.unsqueeze(0)
+ # masked_pred_alpha = masked_pred_alpha.unsqueeze(0)
+
+ cls.data_batch = [{
+ 'inputs': [],
+ 'data_samples': {
+ 'ori_trimap': trimap,
+ 'ori_alpha': gt_alpha,
+ },
+ }]
+
+ cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch]
+
+ cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}]
+ # pred_alpha should be masked by trimap before evaluation
+ cls.bad_preds1 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds1, cls.bad_preds1_):
+ d['output'] = p
+
+ cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}]
+ # pred_alpha should be 3 dimensional
+ cls.bad_preds2 = copy.deepcopy(cls.data_samples)
+ for d, p in zip(cls.bad_preds2, cls.bad_preds2_):
+ d['output'] = p
+
+ cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}]
+ cls.good_preds = copy.deepcopy((cls.data_samples))
+ for d, p in zip(cls.good_preds, cls.good_preds_):
+ d['output'] = p
+
+ def test_sad(self):
+ """Test SAD for evaluating predicted alpha matte."""
+
+ data_batch, bad_pred1, bad_pred2, good_pred = (
+ self.data_batch,
+ self.bad_preds1,
+ self.bad_preds2,
+ self.good_preds,
+ )
+
+ sad = SAD()
+
+ with pytest.raises(ValueError):
+ sad.process(data_batch, bad_pred1)
+
+ with pytest.raises(ValueError):
+ sad.process(data_batch, bad_pred2)
+
+ # process 2 batches
+ sad.process(data_batch, good_pred)
+ sad.process(data_batch, good_pred)
+
+ assert sad.results == [
+ {
+ 'sad': 0.768,
+ },
+ {
+ 'sad': 0.768,
+ },
+ ]
+
+ res = sad.compute_metrics(sad.results)
+
+ assert list(res.keys()) == ['SAD']
+ np.testing.assert_almost_equal(res['SAD'], 0.768)
diff --git a/tests/test_evaluation/test_metrics/test_snr.py b/tests/test_evaluation/test_metrics/test_snr.py
new file mode 100644
index 0000000000..4d87eda969
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_snr.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.evaluation.metrics import SNR, snr
+
+
+class TestPixelMetrics:
+
+ @classmethod
+ def setup_class(cls):
+
+ mask = np.ones((32, 32, 3)) * 2
+ mask[:16] *= 0
+ gt = np.ones((32, 32, 3)) * 2
+ data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+ cls.data_batch = [dict(data_samples=data_sample)]
+ cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+ cls.data_batch.append(
+ dict(
+ data_samples=dict(
+ gt_img=torch.from_numpy(gt),
+ mask=torch.from_numpy(mask),
+ img_channel_order='bgr')))
+ cls.predictions.append({
+ k: torch.from_numpy(deepcopy(v))
+ for (k, v) in cls.predictions[0].items()
+ })
+
+ for d, p in zip(cls.data_batch, cls.predictions):
+ d['output'] = p
+ cls.predictions = cls.data_batch
+
+ def test_snr(self):
+
+ snr_ = SNR()
+ snr_.process(self.data_batch, self.predictions)
+ result = snr_.compute_metrics(snr_.results)
+ assert 'SNR' in result
+ np.testing.assert_almost_equal(result['SNR'], 6.0206001996994)
+
+
+def test_snr():
+ img_hw_1 = np.ones((32, 32)) * 2
+ img_hwc_1 = np.ones((32, 32, 3)) * 2
+ img_chw_1 = np.ones((3, 32, 32)) * 2
+ img_hw_2 = np.ones((32, 32))
+ img_hwc_2 = np.ones((32, 32, 3))
+ img_chw_2 = np.ones((3, 32, 32))
+
+ with pytest.raises(ValueError):
+ snr(img_hw_1, img_hw_2, crop_border=0, input_order='HH')
+
+ with pytest.raises(ValueError):
+ snr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC')
+
+ snr_result = snr(img_hw_1, img_hw_2, crop_border=0)
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ snr_result = snr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC')
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ print(snr_result)
+ snr_result = snr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW')
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ print(snr_result)
+
+ snr_result = snr(img_hw_1, img_hw_2, crop_border=2)
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ snr_result = snr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC')
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ snr_result = snr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW')
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+
+ snr_result = snr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None)
+ np.testing.assert_almost_equal(snr_result, 6.020600199699402)
+ snr_result = snr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y')
+ np.testing.assert_almost_equal(snr_result, 26.290040016174316)
+
+ # test float inf
+ snr_result = snr(img_hw_1, img_hw_1, crop_border=0)
+ assert snr_result == float('inf')
+
+ # test uint8
+ img_hw_1 = np.ones((32, 32), dtype=np.uint8)
+ img_hw_2 = np.zeros((32, 32), dtype=np.uint8)
+ snr_result = snr(img_hw_1, img_hw_2, crop_border=0)
+ assert snr_result == 0
diff --git a/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py
new file mode 100644
index 0000000000..ff113d0917
--- /dev/null
+++ b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py
@@ -0,0 +1,153 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+from mmedit.models.editors.inst_colorization import color_utils
+
+
+class TestColorUtils:
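+    # Options shared by the tests below: L is shifted by ``l_cent`` and
+    # scaled by ``l_norm``, while the ab channels are scaled by ``ab_norm``
+    # (see test_rgb2lab / test_lab2rgb).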
+ color_data_opt = dict(
+ ab_thresh=0,
+ p=1.0,
+ sample_PS=[
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ ],
+ ab_norm=110,
+ ab_max=110.,
+ ab_quant=10.,
+ l_norm=100.,
+ l_cent=50.,
+ mask_cent=0.5)
+
+ def test_xyz2lab(self):
+ xyz = torch.rand(1, 3, 8, 8)
+ lab = color_utils.xyz2lab(xyz)
+
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
+ xyz_scale = xyz / sc
+ mask = (xyz_scale > .008856).type(torch.FloatTensor)
+
+ xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale +
+ 16. / 116.) * (1 - mask)
+ L = 116. * xyz_int[:, 1, :, :] - 16.
+ a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
+ b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
+
+ assert lab.shape == (1, 3, 8, 8)
+ assert lab.equal(
+ torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]),
+ dim=1))
+
+ def test_rgb2xyz(self):
+ rgb = torch.rand(1, 3, 8, 8)
+ xyz = color_utils.rgb2xyz(rgb)
+
+ mask = (rgb > .04045).type(torch.FloatTensor)
+ rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
+
+ x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \
+ + .180423 * rgb[:, 2, :, :]
+ y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \
+ + .072169 * rgb[:, 2, :, :]
+ z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \
+ + .950227 * rgb[:, 2, :, :]
+
+ assert xyz.shape == (1, 3, 8, 8)
+ assert xyz.equal(
+ torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]),
+ dim=1))
+
+ def test_rgb2lab(self):
+ rgb = torch.rand(1, 3, 8, 8)
+ lab = color_utils.rgb2lab(rgb, self.color_data_opt)
+ _lab = color_utils.xyz2lab(color_utils.rgb2xyz(rgb))
+
+ l_rs = (_lab[:, [0], :, :] -
+ self.color_data_opt['l_cent']) / self.color_data_opt['l_norm']
+ ab_rs = _lab[:, 1:, :, :] / self.color_data_opt['ab_norm']
+
+ assert lab.shape == (1, 3, 8, 8)
+ assert lab.equal(torch.cat((l_rs, ab_rs), dim=1))
+
+ def test_lab2rgb(self):
+ lab = torch.rand(1, 3, 8, 8)
+ rgb = color_utils.lab2rgb(lab, self.color_data_opt)
+
+ L = lab[:, [0], :, :] * self.color_data_opt[
+ 'l_norm'] + self.color_data_opt['l_cent']
+ AB = lab[:, 1:, :, :] * self.color_data_opt['ab_norm']
+
+ lab = torch.cat((L, AB), dim=1)
+
+ assert rgb.shape == (1, 3, 8, 8)
+ assert rgb.equal(color_utils.xyz2rgb(color_utils.lab2xyz(lab)))
+
+ def test_lab2xyz(self):
+ lab = torch.rand(1, 3, 8, 8)
+ xyz = color_utils.lab2xyz(lab)
+ y_int = (lab[:, 0, :, :] + 16.) / 116.
+ x_int = (lab[:, 1, :, :] / 500.) + y_int
+ z_int = y_int - (lab[:, 2, :, :] / 200.)
+ z_int = torch.max(torch.Tensor((0, )), z_int)
+
+ out = torch.cat(
+ (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]),
+ dim=1)
+ mask = (out > .2068966).type(torch.FloatTensor)
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None]
+ out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
+ target = sc * out
+ assert xyz.shape == (1, 3, 8, 8)
+ assert xyz.equal(target)
+
+ def test_xyz2rgb(self):
+ xyz = torch.rand(1, 3, 8, 8)
+
+ rgb = color_utils.xyz2rgb(xyz)
+
+ r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \
+ - 0.49853633 * xyz[:, 2, :, :]
+ g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \
+ + .04155593 * xyz[:, 2, :, :]
+ b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \
+ + 1.05731107 * xyz[:, 2, :, :]
+
+ _rgb = torch.cat(
+ (r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1)
+ _rgb = torch.max(_rgb, torch.zeros_like(_rgb))
+
+ mask = (_rgb > .0031308).type(torch.FloatTensor)
+
+ assert rgb.shape == (1, 3, 8, 8) and mask.shape == (1, 3, 8, 8)
+ assert rgb.equal((1.055 * (_rgb**(1. / 2.4)) - 0.055) * mask +
+ 12.92 * _rgb * (1 - mask))
+
+ def test_get_colorization_data(self):
+ data_raw = torch.rand(1, 3, 8, 8)
+
+ res = color_utils.get_colorization_data(data_raw, self.color_data_opt)
+
+ assert isinstance(res, dict)
+ assert 'A' in res.keys() and 'B' in res.keys() \
+ and 'hint_B' in res.keys() and 'mask_B' in res.keys()
+ assert res['A'].shape == res['mask_B'].shape == (1, 1, 8, 8)
+ assert res['hint_B'].shape == res['B'].shape == (1, 2, 8, 8)
+
+ def test_encode_ab_ind(self):
+ data_ab = torch.rand(1, 2, 8, 8)
+ data_q = color_utils.encode_ab_ind(data_ab, self.color_data_opt)
+ A = 2 * 110. / 10. + 1
+
+ data_ab_rs = torch.round((data_ab * 110 + 110.) / 10.)
+
+ assert data_q.shape == (1, 1, 8, 8)
+ assert data_q.equal(data_ab_rs[:, [0], :, :] * A +
+ data_ab_rs[:, [1], :, :])
diff --git a/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py
new file mode 100644
index 0000000000..c6d4454cab
--- /dev/null
+++ b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmedit.registry import MODULES
+
+
+def test_colorization_net():
+
+ model_cfg = dict(
+ type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch')
+
+ # build model
+ model = MODULES.build(model_cfg)
+
+ # test attributes
+ assert model.__class__.__name__ == 'ColorizationNet'
+
+ # prepare data
+ input_A = torch.rand(1, 1, 256, 256)
+ input_B = torch.rand(1, 2, 256, 256)
+ mask_B = torch.rand(1, 1, 256, 256)
+
+ target_shape = (1, 2, 256, 256)
+
+ # test on cpu
+ (out_class, out_reg, feature_map) = model(input_A, input_B, mask_B)
+ assert isinstance(feature_map, dict)
+ assert feature_map['conv1_2'].shape == (1, 64, 256, 256) \
+ and feature_map['out_reg'].shape == target_shape
+
+ # test on gpu
+ if torch.cuda.is_available():
+ model = model.cuda()
+ input_A = input_A.cuda()
+ input_B = input_B.cuda()
+ mask_B = mask_B.cuda()
+ (out_class, out_reg, feature_map) = \
+ model(input_A, input_B, mask_B)
+
+ assert isinstance(feature_map, dict)
+ for item in feature_map.keys():
+ assert torch.is_tensor(feature_map[item])
diff --git a/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py
new file mode 100644
index 0000000000..929c9c8eb1
--- /dev/null
+++ b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+from mmedit.registry import MODULES
+
+
+def test_fusion_net():
+
+ model_cfg = dict(
+ type='FusionNet', input_nc=4, output_nc=2, norm_type='batch')
+
+ # build model
+ model = MODULES.build(model_cfg)
+
+ # test attributes
+ assert model.__class__.__name__ == 'FusionNet'
+
+ # prepare data
+ input_A = torch.rand(1, 1, 256, 256)
+ input_B = torch.rand(1, 2, 256, 256)
+ mask_B = torch.rand(1, 1, 256, 256)
+
+ instance_feature = dict(
+ conv1_2=torch.rand(1, 64, 256, 256),
+ conv2_2=torch.rand(1, 128, 128, 128),
+ conv3_3=torch.rand(1, 256, 64, 64),
+ conv4_3=torch.rand(1, 512, 32, 32),
+ conv5_3=torch.rand(1, 512, 32, 32),
+ conv6_3=torch.rand(1, 512, 32, 32),
+ conv7_3=torch.rand(1, 512, 32, 32),
+ conv8_up=torch.rand(1, 256, 64, 64),
+ conv8_3=torch.rand(1, 256, 64, 64),
+ conv9_up=torch.rand(1, 128, 128, 128),
+ conv9_3=torch.rand(1, 128, 128, 128),
+ conv10_up=torch.rand(1, 128, 256, 256),
+ conv10_2=torch.rand(1, 128, 256, 256),
+ )
+
+ target_shape = (1, 2, 256, 256)
+
+ box_info_box = [
+ torch.tensor([[175, 29, 96, 54, 52, 106], [14, 191, 84, 61, 51, 111],
+ [117, 64, 115, 46, 75, 95], [41, 165, 121, 47, 50, 88],
+ [46, 136, 94, 45, 74, 117], [79, 124, 62, 115, 53, 79],
+ [156, 64, 77, 138, 36, 41], [200, 48, 114, 131, 8, 11],
+ [115, 78, 92, 81, 63, 83]]),
+ torch.tensor([[87, 15, 48, 27, 26, 53], [7, 96, 42, 31, 25, 55],
+ [58, 32, 57, 23, 38, 48], [20, 83, 60, 24, 25, 44],
+ [23, 68, 47, 23, 37, 58], [39, 62, 31, 58, 27, 39],
+ [78, 32, 38, 69, 18, 21], [100, 24, 57, 66, 4, 5],
+ [57, 39, 46, 41, 32, 41]]),
+ torch.tensor([[43, 8, 24, 14, 13, 26], [3, 48, 21, 16, 13, 27],
+ [29, 16, 28, 12, 19, 24], [10, 42, 30, 12, 12, 22],
+ [11, 34, 23, 12, 19, 29], [19, 31, 15, 29, 14, 20],
+ [39, 16, 19, 35, 9, 10], [50, 12, 28, 33, 2, 3],
+ [28, 20, 23, 21, 16, 20]]),
+ torch.tensor([[21, 4, 12, 7, 7, 13], [1, 24, 10, 8, 7, 14],
+ [14, 8, 14, 6, 10, 12], [5, 21, 15, 6, 6, 11],
+ [5, 17, 11, 6, 10, 15], [9, 16, 7, 15, 7, 10],
+ [19, 8, 9, 18, 5, 5], [25, 6, 14, 17, 1, 1],
+ [14, 10, 11, 11, 8, 10]])
+ ]
+
+ # test on cpu
+ out = model(input_A, input_B, mask_B, instance_feature, box_info_box)
+ assert torch.is_tensor(out)
+ assert out.shape == target_shape
+
+ # test on gpu
+ if torch.cuda.is_available():
+ model = model.cuda()
+ input_A = input_A.cuda()
+ input_B = input_B.cuda()
+ mask_B = mask_B.cuda()
+ for item in instance_feature.keys():
+ instance_feature[item] = instance_feature[item].cuda()
+ box_info_box = [i.cuda() for i in box_info_box]
+ output = model(input_A, input_B, mask_B, instance_feature,
+ box_info_box)
+ assert torch.is_tensor(output)
diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py
new file mode 100644
index 0000000000..5d769b2134
--- /dev/null
+++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+import unittest
+
+import pytest
+import torch
+
+from mmedit.registry import BACKBONES
+from mmedit.structures import EditDataSample, PixelData
+from mmedit.utils import register_all_modules
+
+
+@pytest.mark.skipif(
+ 'win' in platform.system().lower() and 'cu' in torch.__version__,
+ reason='skip on windows-cuda due to limited RAM.')
+class TestInstColorization:
+
+ def test_inst_colorization(self):
+ if not torch.cuda.is_available():
+            # RoI pooling is only supported on GPU
+            raise unittest.SkipTest('test requires GPU and torch+cuda')
+
+ register_all_modules()
+ model_cfg = dict(
+ type='InstColorization',
+ data_preprocessor=dict(
+ type='EditDataPreprocessor',
+ mean=[127.5],
+ std=[127.5],
+ ),
+ image_model=dict(
+ type='ColorizationNet',
+ input_nc=4,
+ output_nc=2,
+ norm_type='batch'),
+ instance_model=dict(
+ type='ColorizationNet',
+ input_nc=4,
+ output_nc=2,
+ norm_type='batch'),
+ fusion_model=dict(
+ type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'),
+ color_data_opt=dict(
+ ab_thresh=0,
+ p=1.0,
+ sample_PS=[
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ ],
+ ab_norm=110,
+ ab_max=110.,
+ ab_quant=10.,
+ l_norm=100.,
+ l_cent=50.,
+ mask_cent=0.5),
+ which_direction='AtoB',
+ loss=dict(type='HuberLoss', delta=.01))
+
+ model = BACKBONES.build(model_cfg)
+
+ # test attributes
+ assert model.__class__.__name__ == 'InstColorization'
+
+ # prepare data
+ inputs = torch.rand(1, 3, 256, 256)
+ target_shape = (1, 3, 256, 256)
+
+ data_sample = EditDataSample(gt_img=PixelData(data=inputs))
+ metainfo = dict(
+ cropped_img=PixelData(data=torch.rand(9, 3, 256, 256)),
+ box_info=torch.tensor([[175, 29, 96, 54, 52, 106],
+ [14, 191, 84, 61, 51, 111],
+ [117, 64, 115, 46, 75, 95],
+ [41, 165, 121, 47, 50, 88],
+ [46, 136, 94, 45, 74, 117],
+ [79, 124, 62, 115, 53, 79],
+ [156, 64, 77, 138, 36, 41],
+ [200, 48, 114, 131, 8, 11],
+ [115, 78, 92, 81, 63, 83]]),
+ box_info_2x=torch.tensor([[87, 15, 48, 27, 26, 53],
+ [7, 96, 42, 31, 25, 55],
+ [58, 32, 57, 23, 38, 48],
+ [20, 83, 60, 24, 25, 44],
+ [23, 68, 47, 23, 37, 58],
+ [39, 62, 31, 58, 27, 39],
+ [78, 32, 38, 69, 18, 21],
+ [100, 24, 57, 66, 4, 5],
+ [57, 39, 46, 41, 32, 41]]),
+ box_info_4x=torch.tensor([[43, 8, 24, 14, 13, 26],
+ [3, 48, 21, 16, 13, 27],
+ [29, 16, 28, 12, 19, 24],
+ [10, 42, 30, 12, 12, 22],
+ [11, 34, 23, 12, 19, 29],
+ [19, 31, 15, 29, 14, 20],
+ [39, 16, 19, 35, 9, 10],
+ [50, 12, 28, 33, 2, 3],
+ [28, 20, 23, 21, 16, 20]]),
+ box_info_8x=torch.tensor([[21, 4, 12, 7, 7, 13],
+ [1, 24, 10, 8, 7, 14],
+ [14, 8, 14, 6, 10, 12],
+ [5, 21, 15, 6, 6, 11],
+ [5, 17, 11, 6, 10, 15],
+ [9, 16, 7, 15, 7, 10],
+ [19, 8, 9, 18, 5, 5],
+ [25, 6, 14, 17, 1, 1],
+ [14, 10, 11, 11, 8, 10]]),
+ empty_box=False)
+ data_sample.set_metainfo(metainfo=metainfo)
+
+ data = dict(inputs=inputs, data_samples=[data_sample])
+
+ res = model(mode='tensor', **data)
+
+ assert torch.is_tensor(res)
+ assert res.shape == target_shape
diff --git a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py
new file mode 100644
index 0000000000..8293eb116e
--- /dev/null
+++ b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmedit.models.editors.inst_colorization.weight_layer import WeightLayer
+
+
+@pytest.mark.skipif(
+ 'win' in platform.system().lower() and 'cu' in torch.__version__,
+ reason='skip on windows-cuda due to limited RAM.')
+def test_weight_layer():
+
+ weight_layer = WeightLayer(64)
+
+ instance_feature_conv1_2 = torch.rand(1, 64, 256, 256)
+ conv1_2 = torch.rand(1, 64, 256, 256)
+ box_info = torch.tensor([[175, 29, 96, 54, 52, 106],
+ [14, 191, 84, 61, 51, 111],
+ [117, 64, 115, 46, 75, 95],
+ [41, 165, 121, 47, 50, 88],
+ [46, 136, 94, 45, 74, 117],
+ [79, 124, 62, 115, 53, 79],
+ [156, 64, 77, 138, 36, 41],
+ [200, 48, 114, 131, 8, 11],
+ [115, 78, 92, 81, 63, 83]])
+ conv1_2 = weight_layer(instance_feature_conv1_2, conv1_2, box_info)
+
+ assert conv1_2.shape == instance_feature_conv1_2.shape
diff --git a/tests/test_models/test_editors/test_liif/test_liif_net.py b/tests/test_models/test_editors/test_liif/test_liif_net.py
index c1642a06f6..ab8f409adb 100644
--- a/tests/test_models/test_editors/test_liif/test_liif_net.py
+++ b/tests/test_models/test_editors/test_liif/test_liif_net.py
@@ -1,9 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
import torch
from mmedit.registry import BACKBONES
+@pytest.mark.skipif(
+ 'win' in platform.system().lower() and 'cu' in torch.__version__,
+ reason='skip on windows-cuda due to limited RAM.')
def test_liif_edsr_net():
model_cfg = dict(
diff --git a/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py b/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
index 6e14b7b5af..144baaf8fe 100644
--- a/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
+++ b/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
@@ -83,3 +83,13 @@ def test_singan_gen_pe(self):
res = gen(self.input_sample, self.fixed_noises, self.noise_weights,
'rand', 2)
assert res.shape == (1, 3, 12, 12)
+
+ gen = SinGANMSGeneratorPE(
+ interp_pad=True, noise_with_pad=True, **self.default_args)
+ res = gen(None, self.fixed_noises, self.noise_weights, 'rand', 2)
+ assert res.shape == (1, 3, 6, 6)
+
+ gen = SinGANMSGeneratorPE(
+ interp_pad=True, noise_with_pad=False, **self.default_args)
+ res = gen(None, self.fixed_noises, self.noise_weights, 'rand', 2)
+ assert res.shape == (1, 3, 12, 12)
diff --git a/tests/test_models/test_editors/test_pggan/test_pggan.py b/tests/test_models/test_editors/test_pggan/test_pggan.py
index 19a9af8bcb..ccec7eec7c 100644
--- a/tests/test_models/test_editors/test_pggan/test_pggan.py
+++ b/tests/test_models/test_editors/test_pggan/test_pggan.py
@@ -72,6 +72,10 @@ def test_pggan_cpu(self):
assert np.isclose(pggan._actual_nkimgs[-1], 0.012, atol=1e-8)
# test forward
+ outputs = pggan.forward(dict(img=torch.randn(3, 3, 16, 16)))
+ assert len(outputs) == 3
+ assert all(['gt_img' in out for out in outputs])
+
outputs = pggan.forward(dict(num_batches=2))
assert len(outputs) == 2
assert all([out.fake_img.data.shape == (3, 16, 16) for out in outputs])
diff --git a/tests/test_models/test_editors/test_rdn/test_rdn_net.py b/tests/test_models/test_editors/test_rdn/test_rdn_net.py
index ba84ed32ca..9e3b939be8 100644
--- a/tests/test_models/test_editors/test_rdn/test_rdn_net.py
+++ b/tests/test_models/test_editors/test_rdn/test_rdn_net.py
@@ -19,6 +19,7 @@ def test_rdn():
in_channels=3,
out_channels=3,
mid_channels=64,
+ channel_growth=32,
num_blocks=16,
upscale_factor=scale)
diff --git a/tests/test_models/test_editors/test_real_esrgan/test_real_esrgan.py b/tests/test_models/test_editors/test_real_esrgan/test_real_esrgan.py
index 3747e74fcd..7fa70cfa26 100644
--- a/tests/test_models/test_editors/test_real_esrgan/test_real_esrgan.py
+++ b/tests/test_models/test_editors/test_real_esrgan/test_real_esrgan.py
@@ -51,6 +51,7 @@ def test_real_esrgan(init_weights):
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
+ is_use_ema=False,
train_cfg=None,
test_cfg=None,
data_preprocessor=EditDataPreprocessor())
@@ -88,6 +89,12 @@ def test_real_esrgan(init_weights):
output = model.val_step(data)
assert output[0].output.pred_img.data.shape == (3, 128, 128)
+ # val_ema
+ model.generator_ema = model.generator
+ model.is_use_ema = True
+ output = model.val_step(data)
+ assert output[0].output.pred_img.data.shape == (3, 128, 128)
+
# feat
output = model(torch.rand(1, 3, 32, 32), mode='tensor')
assert output.shape == (1, 3, 128, 128)
diff --git a/tests/test_models/test_editors/test_singan/test_singan.py b/tests/test_models/test_editors/test_singan/test_singan.py
index a007d4f9e0..4b0bf73030 100644
--- a/tests/test_models/test_editors/test_singan/test_singan.py
+++ b/tests/test_models/test_editors/test_singan/test_singan.py
@@ -75,3 +75,53 @@ def test_singan_cpu(self):
elif i in [4, 5]:
assert singan.curr_stage == 2
assert img.shape[-2:] == (32, 32)
+
+ outputs = singan.forward(
+ dict(num_batches=1, get_prev_res=True), None)
+ assert all([hasattr(out, 'prev_res_list') for out in outputs])
+
+ # test forward singan with ema
+ singan = SinGAN(
+ self.generator,
+ self.disc,
+ num_scales=3,
+ data_preprocessor=self.data_preprocessor,
+ noise_weight_init=self.noise_weight_init,
+ iters_per_scale=self.iters_per_scale,
+ lr_scheduler_args=self.lr_scheduler_args,
+ ema_confg=dict(type='ExponentialMovingAverage'))
+ optim_wrapper_dict_builder = SinGANOptimWrapperConstructor(
+ self.optim_wrapper_cfg)
+ optim_wrapper_dict = optim_wrapper_dict_builder(singan)
+
+ for i in range(6):
+ singan.train_step(self.data_batch, optim_wrapper_dict)
+ message_hub.update_info('iter', message_hub.get_info('iter') + 1)
+
+ outputs = singan.forward(
+ dict(num_batches=1, sample_model='ema/orig'), None)
+
+ img = torch.stack([out.orig.fake_img.data for out in outputs],
+ dim=0)
+ img_ema = torch.stack([out.ema.fake_img.data for out in outputs],
+ dim=0)
+ if i in [0, 1]:
+ assert singan.curr_stage == 0
+ assert img.shape[-2:] == (25, 25)
+ assert img_ema.shape[-2:] == (25, 25)
+ elif i in [2, 3]:
+ assert singan.curr_stage == 1
+ assert img.shape[-2:] == (30, 30)
+ assert img_ema.shape[-2:] == (30, 30)
+ elif i in [4, 5]:
+ assert singan.curr_stage == 2
+ assert img.shape[-2:] == (32, 32)
+ assert img_ema.shape[-2:] == (32, 32)
+
+ outputs = singan.forward(
+ dict(
+ num_batches=1, sample_model='ema/orig', get_prev_res=True),
+ None)
+
+ assert all([hasattr(out.orig, 'prev_res_list') for out in outputs])
+ assert all([hasattr(out.ema, 'prev_res_list') for out in outputs])
diff --git a/tests/test_models/test_editors/test_stylegan3/test_stylegan3_utils.py b/tests/test_models/test_editors/test_stylegan3/test_stylegan3_utils.py
index 4e906daaf2..b995f0baae 100644
--- a/tests/test_models/test_editors/test_stylegan3/test_stylegan3_utils.py
+++ b/tests/test_models/test_editors/test_stylegan3/test_stylegan3_utils.py
@@ -16,6 +16,13 @@ def test_integer_transformation():
print(z.shape)
print(m.shape)
+ # cover more lines
+ t = torch.zeros(2)
+ z, m = apply_integer_translation(x, t[0], t[1])
+
+ t = torch.ones(2) * 2
+ z, m = apply_integer_translation(x, t[0], t[1])
+
def test_fractional_translation():
x = torch.randn(1, 3, 16, 16)
@@ -24,6 +31,13 @@ def test_fractional_translation():
print(z.shape)
print(m.shape)
+ # cover more lines
+ t = torch.zeros(2)
+ z, m = apply_fractional_translation(x, t[0], t[1])
+
+ t = torch.ones(2) * 2
+ z, m = apply_fractional_translation(x, t[0], t[1])
+
@pytest.mark.skipif(
digit_version(TORCH_VERSION) < digit_version('1.8.0'),
diff --git a/tests/test_utils/test_io_utils.py b/tests/test_utils/test_io_utils.py
index 1fc84dc4b7..ebd8ab18b4 100644
--- a/tests/test_utils/test_io_utils.py
+++ b/tests/test_utils/test_io_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from mmedit.utils import download_from_url
+from mmedit.utils.io_utils import download_from_url
def test_download_from_url():
diff --git a/tools/dataset_converters/image_translation/README.md b/tools/dataset_converters/image_translation/README.md
index 85cec47c16..51e5312cdd 100644
--- a/tools/dataset_converters/image_translation/README.md
+++ b/tools/dataset_converters/image_translation/README.md
@@ -67,7 +67,7 @@ test_dataloader = dict(
```
Here, we adopt `LoadPairedImageFromFile` to load a paired image as the common loader does and crops
-it into two images with the same shape in different domains. As shown in the example, `pipeline` provides important data pipeline to process images, including loading from file system, resizing, cropping, flipping, transferring to `torch.Tensor` and packing to `GenDataSample`. All of supported data pipelines can be found in `mmedit/datasets/transforms`.
+it into two images with the same shape in different domains. As shown in the example, `pipeline` defines the data pipeline used to process images, including loading from the file system, resizing, cropping, flipping, converting to `torch.Tensor` and packing into `EditDataSample`. All supported data pipelines can be found in `mmedit/datasets/transforms`.
For unpaired-data trained translation model like CycleGAN , `UnpairedImageDataset` is designed to train such translation models. Here is an example config for horse2zebra dataset:
@@ -99,17 +99,15 @@ train_pipeline = [
dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'),
dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'),
dict(
- type='PackGenInputs',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
+ type='PackEditInputs',
+ keys=[f'img_{domain_a}', f'img_{domain_b}'])
]
test_pipeline = [
dict(type='LoadImageFromFile', io_backend='disk', key='img', flag='color'),
dict(type='Resize', scale=(256, 256), interpolation='bicubic'),
dict(
- type='PackGenInputs',
- keys=[f'img_{domain_a}', f'img_{domain_b}'],
- meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
+ type='PackEditInputs',
+ keys=[f'img_{domain_a}', f'img_{domain_b}'])
]
data_root = './data/horse2zebra/'
# `batch_size` and `data_root` need to be set.
diff --git a/tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py b/tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py
index 0ecfd4c641..ad849d08e9 100644
--- a/tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py
+++ b/tools/dataset_converters/super-resolution/div2k/preprocess_div2k_dataset.py
@@ -9,6 +9,7 @@
import cv2
import lmdb
import mmcv
+import mmengine
import numpy as np
@@ -88,10 +89,10 @@ def extract_subimages(opt):
print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1)
- img_list = list(mmcv.scandir(input_folder))
+ img_list = list(mmengine.scandir(input_folder))
img_list = [osp.join(input_folder, v) for v in img_list]
- prog_bar = mmcv.ProgressBar(len(img_list))
+ prog_bar = mmengine.ProgressBar(len(img_list))
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(
@@ -197,7 +198,7 @@ def prepare_keys_div2k(folder_path):
"""
print('Reading image path list ...')
img_path_list = sorted(
- list(mmcv.scandir(folder_path, suffix='png', recursive=False)))
+ list(mmengine.scandir(folder_path, suffix='png', recursive=False)))
keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)]
return img_path_list, keys
@@ -381,7 +382,8 @@ def parse_args():
parser.add_argument(
'--n-thread',
nargs='?',
- default=20,
+ default=8,
+ type=int,
help='thread number when using multiprocessing')
parser.add_argument(
'--make-lmdb',
diff --git a/tools/dataset_converters/unconditional_gans/README.md b/tools/dataset_converters/unconditional_gans/README.md
index 70d5524f4f..a15b6204de 100644
--- a/tools/dataset_converters/unconditional_gans/README.md
+++ b/tools/dataset_converters/unconditional_gans/README.md
@@ -8,7 +8,7 @@ dataset_type = 'UnconditionalImageDataset'
train_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='Flip', keys=['img'], direction='horizontal'),
- dict(type='PackGenInputs', keys=['img'], meta_keys=['img_path'])
+ dict(type='PackEditInputs', keys=['img'], meta_keys=['img_path'])
]
# `batch_size` and `data_root` need to be set.
@@ -23,7 +23,7 @@ train_dataloader = dict(
pipeline=train_pipeline))
```
-Here, we adopt `InfinitySampler` to avoid frequent dataloader reloading, which will accelerate the training procedure. As shown in the example, `pipeline` provides important data pipeline to process images, including loading from file system, resizing, cropping, transferring to `torch.Tensor` and packing to `GenDataSample`. All of supported data pipelines can be found in `mmedit/datasets/transforms`.
+Here, we adopt `InfinitySampler` to avoid frequent dataloader reloading, which accelerates the training procedure. As shown in the example, `pipeline` defines the data pipeline used to process images, including loading from the file system, resizing, cropping, converting to `torch.Tensor` and packing into `EditDataSample`. All supported data pipelines can be found in `mmedit/datasets/transforms`.
For unconditional GANs with dynamic architectures like PGGAN and StyleGANv1, `GrowScaleImgDataset` is recommended to use for training. Since such dynamic architectures need real images in different scales, directly adopting `UnconditionalImageDataset` will bring heavy I/O cost for loading multiple high-resolution images. Here is an example we use for training PGGAN in CelebA-HQ dataset:
@@ -33,7 +33,7 @@ dataset_type = 'GrowScaleImgDataset'
pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='Flip', keys=['img'], direction='horizontal'),
- dict(type='PackGenInputs')
+ dict(type='PackEditInputs')
]
# `samples_per_gpu` and `imgs_root` need to be set.
diff --git a/tools/gui/README.md b/tools/gui/README.md
new file mode 100644
index 0000000000..6de87446a7
--- /dev/null
+++ b/tools/gui/README.md
@@ -0,0 +1,105 @@
+# MMEditing Viewer
+
+- [Introduction](#introduction)
+- [Major features](#major-features)
+- [Prerequisites](#prerequisites)
+- [Getting Started](#getting-started)
+- [Examples](#examples)
+- [Contributing](#contributing)
+
+## Introduction
+
+**MMEditing Viewer** is a qualitative comparison tool to facilitate your research.
+
+## Major features
+
+- **Patch-based comparison**
+  - Crop the same patch from multiple images for comparison
+  - Batch comparison
+  - Flexible settings for the number of columns and image sizes
+  - Save your comparison results
+- **Before/After slider comparison**
+  - Supports comparison of both videos and images
+  - Record and save comparison results as a video clip
+
+## Prerequisites
+
+MMEditing Viewer works on Linux, Windows and macOS. It requires:
+
+- Python >= 3.6
+- PyQt5
+- opencv-python (headless version)
+
+## Getting Started
+
+**Step 0.**
+Install PyQt5.
+
+```shell
+pip install PyQt5
+```
+
+**Step 1.**
+Install opencv-python and check its version.
+If you meet the following errors:
+
+```
+QObject::moveToThread: Current thread is not the object's thread.
+Available platform plugins are: xcb... .
+```
+
+Please install opencv-python-headless instead.
+
+```shell
+pip install opencv-python-headless
+```
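+
+Note that `opencv-python` and `opencv-python-headless` provide the same `cv2` module, so having both installed can still trigger the Qt plugin error above. As a rough, unofficial check, you can list the installed variants and remove the non-headless one first:
+
+```shell
+# show which opencv variants are currently installed
+pip list | grep opencv
+# remove the regular build before (re)installing the headless one
+pip uninstall opencv-python
+pip install opencv-python-headless
+```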
+
+**Step 2.**
+Install MMEditing.
+
+```shell
+git clone -b 1.x https://github.com/open-mmlab/mmediting.git
+```
+
+If you want to follow the newest features, you can clone the `dev-1.x` branch.
+
+```shell
+git clone -b dev-1.x https://github.com/open-mmlab/mmediting.git
+```
+
+**Step 3.**
+Run the viewer:
+
+```shell
+python tools/gui/gui.py
+```
+
+## Examples
+
+**1. Patch-based comparison: batch images**
+
+https://user-images.githubusercontent.com/49083766/199232588-7a07a3d9-725d-48be-89bf-1ffb45bd5d74.mp4
+
+**2. Patch-based comparison: single image**
+
+https://user-images.githubusercontent.com/49083766/199232606-f8539191-4bda-4b2c-975a-59020927abae.mp4
+
+**3. Before/After slider comparison: images**
+
+https://user-images.githubusercontent.com/49083766/199232615-2c56dcf1-0b42-41a5-884c-16a8f28a2647.mp4
+
+**4. Before/After slider comparison: input video frames**
+
+https://user-images.githubusercontent.com/49083766/199232617-e03a06dc-727b-43bb-8110-049d0fff28ba.mp4
+
+**5. Before/After slider comparison: input Mp4 video**
+
+https://user-images.githubusercontent.com/49083766/199232651-87d8064e-cbaf-4d30-b90b-94ee0af7d497.mp4
+
+**6. Before/After slider comparison: record**
+
+https://user-images.githubusercontent.com/49083766/199232634-eca70d28-8437-400a-8ab9-d2fe396b6ea9.mp4
+
+## Contributing
+
+We appreciate all contributions to improve MMEditing Viewer. You can create an issue to report bugs or request new features, and we welcome your suggestions and code contributions.
diff --git a/tools/gui/component.py b/tools/gui/component.py
new file mode 100644
index 0000000000..54f1c490e1
--- /dev/null
+++ b/tools/gui/component.py
@@ -0,0 +1,453 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import time
+
+import cv2
+import numpy as np
+from PyQt5 import QtCore, QtGui, QtWidgets
+from utils import layout2widget
+
+
+class QLabelClick(QtWidgets.QLabel):
+ clicked = QtCore.pyqtSignal(str)
+
+ def __init__(self):
+ super().__init__()
+
+ def mousePressEvent(self, event):
+ self.clicked.emit(self.text())
+
+
+class QLabelSlider(QtWidgets.QLabel):
+
+ def __init__(self, parent, scale, label_1, label_2, title):
+ super().__init__()
+ self.parent = parent
+ self.hSlider = -1
+ self.vSlider = -1
+ self.oldSlider = -1
+ self.scale = scale
+ self.label_1 = label_1
+ self.label_2 = label_2
+ self.title = title
+ self.images = self.parent.images
+ self.auto_mode = 0
+
+ def mousePressEvent(self, ev: QtGui.QMouseEvent) -> None:
+ if ev.button() == QtCore.Qt.LeftButton:
+ if self.hSlider > -1:
+ self.oldSlider = self.hSlider
+ elif self.vSlider > -1:
+ self.oldSlider = self.vSlider
+ return super().mousePressEvent(ev)
+
+ def mouseReleaseEvent(self, ev: QtGui.QMouseEvent) -> None:
+ if ev.button() == QtCore.Qt.LeftButton:
+ self.oldSlider = -1
+ self.update()
+ return super().mouseReleaseEvent(ev)
+
+ def mouseMoveEvent(self, ev: QtGui.QMouseEvent) -> None:
+ if self.oldSlider > -1:
+ if self.hSlider > -1:
+ self.hSlider = ev.pos().x()
+ elif self.vSlider > -1:
+ self.vSlider = ev.pos().y()
+ self.update()
+ return super().mouseMoveEvent(ev)
+
+ def paintEvent(self, ev: QtGui.QPaintEvent) -> None:
+ qp = QtGui.QPainter()
+ qp.begin(self)
+ qp.drawImage(0, 0, QtGui.QImage(self.getImage()))
+ pen = QtGui.QPen(QtCore.Qt.green, 3)
+ qp.setPen(pen)
+ qp.drawLine(self.hSlider, 0, self.hSlider, self.height())
+ length = 9
+ qp.drawText(self.hSlider - 10 - len(self.label_1) * length, 20,
+ self.label_1)
+ qp.drawText(self.hSlider + 10, 20, self.label_2)
+ qp.drawText(10, self.height() - 10, self.title)
+ qp.end()
+
+ def set_scale(self, scale):
+ self.hSlider = int(self.hSlider * scale / self.scale)
+ self.scale = scale
+ self.update()
+
+ def setImage(self, images):
+ self.images = images
+ self.update()
+
+ def set_auoMode(self, mode):
+ if mode != 3 or self.auto_mode < 3:
+ self.auto_mode = mode
+
+ def auto_slider(self):
+ try:
+ if self.auto_mode == 1:
+ self.hSlider += 1
+ if self.hSlider > self.w:
+ self.hSlider = 0
+ elif self.auto_mode == 2:
+ self.hSlider -= 1
+ if self.hSlider < 0:
+ self.hSlider = self.w
+ elif self.auto_mode == 3:
+ self.hSlider += 1
+ if self.hSlider >= self.w:
+ self.auto_mode = 4
+ elif self.auto_mode == 4:
+ self.hSlider -= 1
+ if self.hSlider <= 0:
+ self.auto_mode = 3
+ self.update()
+
+        except Exception as e:
+            print(e)
+
+ def getImage(self):
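+        # Build the before/after composite: pixels to the left of the slider
+        # come from the first image and pixels to the right from the second,
+        # after resizing both images to a common size.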
+ img1, img2 = self.images
+ if img1 is None or img2 is None:
+ return
+ h1, w1, c = img1.shape
+ h2, w2, c = img2.shape
+ if w2 > w1:
+ img1 = cv2.resize(img1, (w2, h2))
+ h, w = h2, w2
+ else:
+ img2 = cv2.resize(img2, (w1, h1))
+ h, w = h1, w1
+ self.h = int(h * self.scale)
+ self.w = int(w * self.scale)
+ self.setFixedHeight(self.h)
+ self.setFixedWidth(self.w)
+ if self.hSlider < 0:
+ self.hSlider = int(self.w / 2.0)
+
+ v = int(self.hSlider / self.scale)
+ img11 = img1[:, 0:v].copy()
+ img22 = img2[:, v:].copy()
+ img = np.hstack((img11, img22))
+ # img = cv2.line(img, (v, 0), (v, h2), (0, 222, 0), 4)
+ rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_dis = QtGui.QImage(rgb_img, w, h, w * c,
+ QtGui.QImage.Format_RGB888)
+ img = QtGui.QPixmap.fromImage(img_dis).scaled(self.width(),
+ self.height())
+ return img
+
+
+class QLabelPaint(QtWidgets.QLabel):
+
+ def __init__(self, parent, beginPoint=None, endPoint=None):
+ super().__init__()
+ self.beginPoint = beginPoint
+ self.endPoint = endPoint
+ self.parent = parent
+ self.statusBar = self.parent.statusBar
+ if self.beginPoint and self.endPoint:
+ self.isShow = True
+ else:
+ self.isShow = False
+
+ def mousePressEvent(self, ev: QtGui.QMouseEvent) -> None:
+ if ev.button() == QtCore.Qt.LeftButton:
+ self.endPoint = None
+ self.beginPoint = ev.pos()
+ self.isShow = False
+ return super().mousePressEvent(ev)
+
+ def mouseReleaseEvent(self, ev: QtGui.QMouseEvent) -> None:
+ if ev.button() == QtCore.Qt.LeftButton:
+ self.endPoint = ev.pos()
+ self.update()
+ self.isShow = True
+ self.parent.set_rect(self.beginPoint, self.endPoint)
+ return super().mouseReleaseEvent(ev)
+
+ def mouseMoveEvent(self, ev: QtGui.QMouseEvent) -> None:
+ self.endPoint = ev.pos()
+ self.update()
+ self.statusBar.showMessage(
+ f'Start: {self.beginPoint.x()},{self.beginPoint.y()}; \
+ End: {self.endPoint.x()},{self.endPoint.y()}')
+ return super().mouseMoveEvent(ev)
+
+ def paintEvent(self, ev: QtGui.QPaintEvent) -> None:
+ super().paintEvent(ev)
+ if self.beginPoint and self.endPoint:
+ qp = QtGui.QPainter()
+ qp.begin(self)
+ pen = QtGui.QPen(QtCore.Qt.red, 2)
+ qp.setPen(pen)
+ w = abs(self.beginPoint.x() - self.endPoint.x())
+ h = abs(self.beginPoint.y() - self.endPoint.y())
+ qp.drawRect(self.beginPoint.x(), self.beginPoint.y(), w, h)
+ qp.end()
+ if self.isShow and isinstance(self.parent, ConcatImageWidget):
+ self.parent.show_images(self.beginPoint.x(),
+ self.beginPoint.y(), w, h)
+ self.isShow = False
+
+
+class ConcatImageWidget(QtWidgets.QWidget):
+
+ def __init__(self, parent, mode, col_num=4):
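+        # mode 0 crops the selected rectangle out of every image before
+        # displaying it; other modes show the images as-is (see show_images).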
+ super(ConcatImageWidget, self).__init__(parent)
+ self.parent = parent
+ self.statusBar = self.parent.statusBar
+ self.hlayout = QtWidgets.QHBoxLayout()
+ self.setLayout(self.hlayout)
+ self.mode = mode
+ self.scale = 1
+ self.col_num = col_num
+ self.file_path = None
+ self.labels = None
+ self.gt = None
+ self.img_h = 0
+ self.rect = None
+
+ def show_images(self, x=0, y=0, w=0, h=0):
+ self.rect = [x, y, w, h]
+ vlayout = QtWidgets.QVBoxLayout()
+ vlayout.setContentsMargins(0, 0, 0, 0)
+ hlayout_img = QtWidgets.QHBoxLayout()
+ hlayout_img.setContentsMargins(0, 0, 0, 0)
+ hlayout_text = QtWidgets.QHBoxLayout()
+ hlayout_text.setContentsMargins(0, 0, 0, 0)
+
+ for i, (path, text) in enumerate(zip(self.file_path, self.labels)):
+ img = QtGui.QPixmap(path).scaled(self.gt_w, self.gt_h)
+ if self.mode == 0:
+ img = img.copy(QtCore.QRect(x, y, w, h))
+ if self.img_h > 0:
+ img_w = int(float(self.img_h) / img.height() * img.width())
+ img = img.scaled(img_w, self.img_h)
+
+ label = QtWidgets.QLabel()
+ label.setFixedWidth(img.width())
+ label.setFixedHeight(img.height())
+ label.setMargin(0)
+ label.setPixmap(img)
+ hlayout_img.addWidget(label)
+
+ label_text = QtWidgets.QLabel()
+ label_text.setMargin(0)
+ label_text.setAlignment(QtCore.Qt.AlignCenter)
+ label_text.setText(text)
+ label_text.adjustSize()
+ hlayout_text.addWidget(label_text)
+
+ if (i + 1) % self.col_num == 0:
+ vlayout.addWidget(layout2widget(hlayout_img))
+ vlayout.addWidget(layout2widget(hlayout_text))
+ hlayout_img = QtWidgets.QHBoxLayout()
+ hlayout_img.setContentsMargins(0, 0, 0, 0)
+ hlayout_text = QtWidgets.QHBoxLayout()
+ hlayout_text.setContentsMargins(0, 0, 0, 0)
+
+ if len(hlayout_img) > 0:
+ for i in range(0, self.col_num - len(hlayout_img)):
+ label = QtWidgets.QLabel()
+ label.setMargin(0)
+ label.setFixedWidth(img.width())
+ label.setFixedHeight(img.height())
+ hlayout_img.addWidget(label)
+ label = QtWidgets.QLabel()
+ label.setMargin(0)
+ hlayout_text.addWidget(label)
+ vlayout.addWidget(layout2widget(hlayout_img))
+ vlayout.addWidget(layout2widget(hlayout_text))
+
+ total_w = self.gt_w + (img.width() + 2) * self.col_num
+ self.setFixedWidth(total_w)
+ if self.hlayout.count() > 1:
+ item = self.hlayout.itemAt(1)
+ self.hlayout.removeItem(item)
+ if item.widget():
+ item.widget().deleteLater()
+ self.hlayout.addWidget(layout2widget(vlayout))
+ else:
+ self.hlayout.addWidget(layout2widget(vlayout))
+
+ def set_images(self, file_path, labels, gt=None, scale=1, rect=None):
+ self.file_path = file_path
+ self.labels = labels
+ self.gt = gt
+ self.scale = scale
+
+ for i in reversed(range(self.hlayout.count())):
+ self.hlayout.itemAt(i).widget().deleteLater()
+ self.hlayout.setContentsMargins(0, 0, 0, 0)
+
+ img = QtGui.QPixmap(self.gt)
+ self.gt_w, self.gt_h = img.width() * scale, img.height() * scale
+ img = img.scaled(self.gt_w, self.gt_h)
+ row = (len(self.file_path) + self.col_num - 1) // self.col_num
+ self.img_h = int(float(self.gt_h - (row - 1) * 29) / row)
+
+ beginPoint = None
+ endPoint = None
+ if rect:
+ beginPoint = QtCore.QPoint(rect[0], rect[1])
+ endPoint = QtCore.QPoint(rect[2], rect[3])
+ label = QLabelPaint(self, beginPoint, endPoint)
+ label.setMargin(0)
+ label.setAlignment(QtCore.Qt.AlignTop)
+ label.setPixmap(img)
+ self.hlayout.addWidget(label)
+ if rect:
+ self.show_images(rect[0], rect[1], rect[2] - rect[0],
+ rect[3] - rect[1])
+ else:
+ self.show_images(0, 0, self.gt_w, self.gt_h)
+
+ def set_rect(self, beginPoint, endPoint):
+ self.parent.rect = [
+ beginPoint.x(),
+ beginPoint.y(),
+ endPoint.x(),
+ endPoint.y()
+ ]
+ self.parent.old_scale = self.scale
+
+
+class VideoPlayer(QtCore.QThread):
+ sigout = QtCore.pyqtSignal(np.ndarray)
+ sigend = QtCore.pyqtSignal(bool)
+
+ def __init__(self, parent):
+ super(VideoPlayer, self).__init__(parent)
+ self.parent = parent
+
+ def set(self, path, fps=None):
+ self.path = path
+ self.video, self.fps, self.actual_frames = self.setVideo(path, fps)
+ self.time = self.actual_frames / self.fps
+ self.total_frames = self.actual_frames
+ self.num = 0
+ self.working = True
+ self.isPause = False
+ self.mutex = QtCore.QMutex()
+ self.cond = QtCore.QWaitCondition()
+
+ def setVideo(self, path, fps):
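+        # ``path`` may be a single video file (decoded via cv2.VideoCapture)
+        # or a directory of frame images (read one by one with cv2.imread).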
+ if os.path.isfile(path):
+ v = cv2.VideoCapture(path)
+ total_frames = v.get(cv2.CAP_PROP_FRAME_COUNT)
+ if fps is None:
+ fps = v.get(cv2.CAP_PROP_FPS)
+ else:
+ files = sorted(os.listdir(path))
+ v = ['/'.join([path, f]) for f in files]
+ total_frames = len(v)
+ if fps is None:
+ fps = 25
+ return v, fps, total_frames
+
+ def pause(self):
+ self.isPause = True
+
+ def resume(self):
+ self.isPause = False
+ self.cond.wakeAll()
+
+ def __del__(self):
+ self.working = False
+ self.wait()
+
+ def run(self):
+ while self.working:
+ self.mutex.lock()
+ if self.isPause:
+ self.cond.wait(self.mutex)
+ if isinstance(self.video, list):
+ img = cv2.imread(self.video[self.num])
+ self.num += 1
+ self.sigout.emit(img)
+ if self.num >= self.total_frames:
+ self.sigend.emit(True)
+ self.num = 0
+ time.sleep(1 / self.fps)
+ self.mutex.unlock()
+
+
+class VideoSlider(QtCore.QThread):
+ sigout = QtCore.pyqtSignal(list)
+ sigend = QtCore.pyqtSignal(bool)
+
+ def __init__(self, parent):
+ super(VideoSlider, self).__init__(parent)
+ self.parent = parent
+ self.mutex = QtCore.QMutex()
+ self.cond = QtCore.QWaitCondition()
+
+ def set(self, path1, path2, fps1=None, fps2=None):
+ self.path1 = path1
+ self.path2 = path2
+ self.v1, self.fps1, self.total_frames1 = self.setVideo(path1, fps1)
+ self.v2, self.fps2, self.total_frames2 = self.setVideo(path2, fps2)
+ if self.fps1 != self.fps2:
+ return False
+ self.fps = self.fps1
+ self.num = 0
+ self.working = True
+ self.isPause = False
+ return True
+
+ def setVideo(self, path, fps):
+ if os.path.isfile(path):
+ v = cv2.VideoCapture(path)
+ total_frames = v.get(cv2.CAP_PROP_FRAME_COUNT)
+ fps = v.get(cv2.CAP_PROP_FPS)
+ else:
+ files = sorted(os.listdir(path))
+ v = ['/'.join([path, f]) for f in files]
+ total_frames = len(v)
+ if fps is None:
+ fps = 25
+ return v, fps, total_frames
+
+ def pause(self):
+ self.isPause = True
+
+ def resume(self):
+ self.isPause = False
+ self.cond.wakeAll()
+
+ def __del__(self):
+ self.working = False
+ self.wait()
+
+ def run(self):
+ while self.working:
+ self.mutex.lock()
+ if self.isPause:
+ self.cond.wait(self.mutex)
+ if isinstance(self.v1, list):
+ num = self.num if self.num < self.total_frames1 \
+ else self.total_frames1 - 1
+ img1 = cv2.imread(self.v1[num])
+ elif isinstance(self.v1, cv2.VideoCapture):
+ r, img1 = self.v1.read()
+ if not r:
+ self.v1, self.fps1, self.total_frames1 = self.setVideo(
+ self.path1, self.fps1)
+ if isinstance(self.v2, list):
+ num = self.num if self.num < self.total_frames2 \
+ else self.total_frames2 - 1
+ img2 = cv2.imread(self.v2[num])
+            elif isinstance(self.v2, cv2.VideoCapture):
+                r, img2 = self.v2.read()
+ if not r:
+ self.v2, self.fps2, self.total_frames2 = self.setVideo(
+ self.path2, self.fps2)
+ self.num += 1
+ self.sigout.emit([img1, img2])
+ if self.num >= self.total_frames1 and \
+ self.num >= self.total_frames2:
+ self.num = 0
+ time.sleep(1 / self.fps)
+ self.mutex.unlock()
diff --git a/tools/gui/gui.py b/tools/gui/gui.py
new file mode 100644
index 0000000000..e95e8cd827
--- /dev/null
+++ b/tools/gui/gui.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+
+from component import QLabelClick
+from page_general import GeneralPage
+from page_sr import SRPage
+from PyQt5 import QtWidgets
+
+
+class Homepage(QtWidgets.QWidget):
+
+ def __init__(self, main_window):
+ super().__init__()
+ self.main_window = main_window
+ layout = QtWidgets.QGridLayout()
+ self.setLayout(layout)
+
+ t1 = QLabelClick()
+ t2 = QLabelClick()
+ t3 = QLabelClick()
+ t4 = QLabelClick()
+ t1.setText('general')
+ t2.setText('sr')
+ t3.setText('inpainting')
+ t4.setText('matting')
+ layout.addWidget(t1, 0, 0)
+ layout.addWidget(t2, 0, 1)
+ layout.addWidget(t3, 1, 0)
+ layout.addWidget(t4, 1, 1)
+
+ t1.clicked.connect(self.main_window.change_window)
+ t2.clicked.connect(self.main_window.change_window)
+ t3.clicked.connect(self.main_window.change_window)
+ t4.clicked.connect(self.main_window.change_window)
+
+
+class MainWindow(QtWidgets.QMainWindow):
+
+ def __init__(self):
+ super(MainWindow, self).__init__()
+
+ self.setWindowTitle('MMEditing Viewer')
+ # # MenuBar
+ # menubar_Aaa = self.menuBar().addMenu('Aaa')
+ # menubar_Bbb = self.menuBar().addMenu('Bbb')
+ # menubar_Ccc = self.menuBar().addMenu('Ccc')
+ # menubar_Aaa.addAction('New')
+ # save = QtWidgets.QAction('Save', self)
+ # save.setShortcut('Ctrl+S')
+ # menubar_Aaa.addAction(save)
+ # menubar_Bbb.addAction('New')
+ # menubar_Ccc.addAction('New')
+
+ # # ToolBar
+ # self.toolBar = QtWidgets.QToolBar('ToolBar')
+ # open = QtWidgets.QAction(QtGui.QIcon(), 'Open', self)
+ # save = QtWidgets.QAction(QtGui.QIcon(), 'Save', self)
+ # self.toolBar.addAction(open)
+ # self.toolBar.addAction(save)
+ # self.addToolBar(QtCore.Qt.ToolBarArea.LeftToolBarArea, self.toolBar)
+
+ # StatusBar
+ self.statusBar = QtWidgets.QStatusBar()
+ self.setStatusBar(self.statusBar)
+ self.statusBar.showMessage('')
+
+ self.homepage = Homepage(self)
+        self.general = GeneralPage()
+        self.sr = SRPage(self)
+ self.setCentralWidget(self.sr)
+
+ def change_window(self, wname):
+ if wname == 'sr':
+ self.setCentralWidget(self.sr)
+ elif wname == 'general':
+ self.setCentralWidget(self.general)
+
+
+if __name__ == '__main__':
+ app = QtWidgets.QApplication(sys.argv)
+ myWin = MainWindow()
+ myWin.showMaximized()
+ sys.exit(app.exec_())
diff --git a/tools/gui/page_general.py b/tools/gui/page_general.py
new file mode 100644
index 0000000000..002f7909dd
--- /dev/null
+++ b/tools/gui/page_general.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+from PyQt5 import QtCore, QtGui, QtWidgets
+from utils import layout2widget
+
+
+class SliderTab(QtWidgets.QWidget):
+
+ def __init__(self):
+ super().__init__()
+ self.images = [None, None]
+
+ self.cb_1 = QtWidgets.QComboBox()
+ self.cb_2 = QtWidgets.QComboBox()
+ self.cb_1.currentIndexChanged.connect(self.change_image)
+ self.cb_2.currentIndexChanged.connect(self.change_image)
+ self.btn_add = QtWidgets.QPushButton()
+ self.btn_add.setText('Add')
+ self.btn_add.clicked.connect(self.add)
+ left_grid = QtWidgets.QGridLayout()
+ left_grid.addWidget(self.cb_1, 0, 0)
+ left_grid.addWidget(self.cb_2, 1, 0)
+ left_grid.addWidget(self.btn_add, 2, 0)
+
+ self.imageArea = QtWidgets.QLabel()
+ self.imageArea.setFrameShape(QtWidgets.QFrame.Box)
+ self.imageArea.setLineWidth(2)
+ self.imageArea.setAlignment(QtCore.Qt.AlignBottom)
+ self.imageArea.setStyleSheet(
+ 'border-width: 0px; border-style: solid; border-color: rgb(100, 100, 100);background-color: rgb(255, 255, 255)' # noqa
+ )
+ self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
+ self.slider.setMaximum(800)
+ self.slider.valueChanged.connect(self.show_image)
+ right_grid = QtWidgets.QGridLayout()
+ right_grid.addWidget(self.imageArea, 0, 0)
+ right_grid.addWidget(self.slider, 1, 0)
+
+ # Splitter
+ hsplitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal)
+ hsplitter.addWidget(layout2widget(left_grid))
+ hsplitter.addWidget(layout2widget(right_grid))
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(hsplitter)
+ self.setLayout(hlayout)
+
+ def add(self):
+ path, _ = QtWidgets.QFileDialog.getOpenFileName(
+ self, 'Select gt file', '', 'Images (*.jpg *.png *.mp4 *.avi)')
+ if self.cb_1.count() > 0:
+ if self.cb_1.findText(path) > -1:
+ self.cb_1.removeItem(self.cb_1.findText(path))
+ if self.cb_2.findText(path) > -1:
+ self.cb_2.removeItem(self.cb_2.findText(path))
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_2.setCurrentIndex(self.cb_2.count() - 1)
+ else:
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_1.setCurrentIndex(0)
+ self.cb_2.setCurrentIndex(0)
+
+ def change_image(self):
+ self.images[0] = cv2.imread(self.cb_1.currentText())
+ self.images[1] = cv2.imread(self.cb_2.currentText())
+ if self.images[0] is None or self.images[1] is None:
+ return
+ self.show_image()
+
+ def show_image(self):
+ img1, img2 = self.images
+ h2, w2, c2 = img2.shape
+ img2 = cv2.resize(img2, (800, int(800 / w2 * h2)))
+ h2, w2, c2 = img2.shape
+ img1 = cv2.resize(img1, (w2, h2))
+ v = self.slider.value()
+ img11 = img1[:, 0:v].copy()
+ img22 = img2[:, v:].copy()
+ img = np.hstack((img11, img22))
+ img = cv2.line(img, (v, 0), (v, h2), (0, 222, 0), 4)
+ rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_dis = QtGui.QImage(rgb_img, w2, h2, w2 * c2,
+ QtGui.QImage.Format_RGB888)
+ jpg = QtGui.QPixmap.fromImage(img_dis).scaled(
+ self.imageArea.width(), int(self.imageArea.width() / w2 * h2))
+ self.imageArea.setPixmap(jpg)
+
+
+class GeneralPage(QtWidgets.QWidget):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.tab_slider = SliderTab()
+ self.tabs = QtWidgets.QTabWidget()
+ self.tabs.addTab(self.tab_slider, 'before/after slider')
+
+ layout = QtWidgets.QVBoxLayout()
+ self.setLayout(layout)
+ layout.addWidget(self.tabs)
diff --git a/tools/gui/page_sr.py b/tools/gui/page_sr.py
new file mode 100644
index 0000000000..6699c386a6
--- /dev/null
+++ b/tools/gui/page_sr.py
@@ -0,0 +1,854 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+
+import cv2
+from component import ConcatImageWidget, QLabelSlider, VideoSlider
+from PyQt5 import QtCore, QtWidgets
+from utils import layout2widget
+
+
+class PatchTab(QtWidgets.QWidget):
+
+ def __init__(self, parent):
+ super().__init__()
+
+ self.parent = parent
+ self.statusBar = self.parent.statusBar
+ self.file_paths = []
+ self.labels = []
+ self.rect = None
+ self.images = None
+ self.isShow = False
+
+ # Left Widget
+ self.btn_add_file = QtWidgets.QPushButton()
+ self.btn_add_file.setText('Add')
+ self.btn_add_file.clicked.connect(self.add_file)
+ self.btn_open_file = QtWidgets.QPushButton()
+ self.btn_open_file.setText('Open new')
+ self.btn_open_file.clicked.connect(self.open_file)
+
+ self.input_label = QtWidgets.QLineEdit()
+ self.input_file = QtWidgets.QLineEdit()
+
+ self.tb_files = QtWidgets.QTableWidget()
+ self.tb_files.setColumnCount(3)
+ self.tb_files.setHorizontalHeaderLabels(
+ ['File/Folder', 'Label', 'Other'])
+ self.tb_files.horizontalHeader().setSectionResizeMode(
+ QtWidgets.QHeaderView.Stretch)
+ self.tb_files.setSelectionBehavior(
+ QtWidgets.QAbstractItemView.SelectRows)
+ self.tb_files.setSelectionMode(
+ QtWidgets.QAbstractItemView.SingleSelection)
+ self.tb_files.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
+ self.tb_files.customContextMenuRequested.connect(self.tableMenu)
+
+ left_grid = QtWidgets.QGridLayout()
+ left_grid.addWidget(self.tb_files, 0, 0, 5, 10)
+ left_grid.addWidget(self.input_file, 6, 0, 1, 4)
+ left_grid.addWidget(self.input_label, 6, 4, 1, 2)
+ left_grid.addWidget(self.btn_open_file, 6, 6, 1, 2)
+ left_grid.addWidget(self.btn_add_file, 6, 8, 1, 2)
+
+ # Right Widget
+ styleSheet = '''
+ QGroupBox {
+ background-color: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
+ stop: 0 #E0E0E0, stop: 1 #FFFFFF);
+ border: 1px solid #999999;
+ border-radius: 5px;
+ margin-top: 2ex; /*leave space at the top for the title */
+ padding: 2ex 10ex;
+ font-size: 20px;
+ color: black;
+ }
+ QGroupBox::title {
+ subcontrol-origin: margin;
+ subcontrol-position: top left; /* position at the top center */
+ padding: 0 3px;
+ left: 30px;
+ font-size: 8px;
+ color: black;
+ }
+ ''' # noqa
+ # select mode
+ self.modeRect = QtWidgets.QGroupBox('Select Mode')
+ self.modeRect.setFlat(True)
+ # self.modeRect.setStyleSheet(styleSheet)
+ btn_mode1 = QtWidgets.QRadioButton('Input whole image', self.modeRect)
+ btn_mode2 = QtWidgets.QRadioButton('Input crop image', self.modeRect)
+ self.btnGroup_mode = QtWidgets.QButtonGroup()
+ self.btnGroup_mode.addButton(btn_mode1, 0)
+ self.btnGroup_mode.addButton(btn_mode2, 1)
+ self.btnGroup_mode.button(0).setChecked(True)
+ self.btnGroup_mode.idToggled.connect(self.reset)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(btn_mode1)
+ hlayout.addWidget(btn_mode2)
+ self.modeRect.setLayout(hlayout)
+
+ # select dirType
+ self.dirTypeRect = QtWidgets.QGroupBox('Select File or Directory')
+ self.dirTypeRect.setFlat(True)
+ # self.dirTypeRect.setStyleSheet(styleSheet)
+ btn_dirType1 = QtWidgets.QRadioButton('Directory', self.dirTypeRect)
+ btn_dirType2 = QtWidgets.QRadioButton('File', self.dirTypeRect)
+ self.btnGroup_dirType = QtWidgets.QButtonGroup()
+ self.btnGroup_dirType.addButton(btn_dirType1, 0)
+ self.btnGroup_dirType.addButton(btn_dirType2, 1)
+ self.btnGroup_dirType.button(0).setChecked(True)
+ self.btnGroup_dirType.idToggled.connect(self.reset)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(btn_dirType1)
+ hlayout.addWidget(btn_dirType2)
+ self.dirTypeRect.setLayout(hlayout)
+
+ # select gt
+ self.cb_gt = QtWidgets.QComboBox()
+ self.cb_gt.currentTextChanged.connect(self.change_gt)
+
+ # set column
+ self.spin_cols = QtWidgets.QSpinBox()
+ self.spin_cols.setMinimum(1)
+ self.spin_cols.setMaximum(10)
+ self.spin_cols.setValue(4)
+ self.spin_cols.valueChanged.connect(self.set_column)
+
+ # set scale
+ self.scale = 100
+ self.txt_scale = QtWidgets.QLabel('100 %')
+ self.slider_scale = QtWidgets.QSlider(QtCore.Qt.Horizontal)
+ self.slider_scale.setMinimum(1)
+ self.slider_scale.setMaximum(200)
+ self.slider_scale.setValue(100)
+ self.slider_scale.valueChanged.connect(self.set_scale)
+
+ # operation
+ self.btn_run = QtWidgets.QPushButton()
+ self.btn_run.setText('Run')
+ self.btn_run.clicked.connect(self.run)
+ self.btn_reset = QtWidgets.QPushButton()
+ self.btn_reset.setText('Reset')
+ self.btn_reset.clicked.connect(self.reset)
+ self.btn_save = QtWidgets.QPushButton()
+ self.btn_save.setText('Save')
+ self.btn_save.clicked.connect(self.save)
+
+ right_grid = QtWidgets.QGridLayout()
+ right_grid.addWidget(self.modeRect, 0, 0, 1, 20)
+ right_grid.addWidget(self.dirTypeRect, 1, 0, 1, 20)
+ right_grid.addWidget(self.cb_gt, 2, 0, 1, 15)
+ right_grid.addWidget(QtWidgets.QLabel('Reference'), 2, 16, 1, 3)
+ right_grid.addWidget(self.spin_cols, 3, 0, 1, 15)
+ right_grid.addWidget(QtWidgets.QLabel('Set columns'), 3, 16, 1, 3)
+ right_grid.addWidget(self.txt_scale, 4, 0, 1, 1)
+ right_grid.addWidget(self.slider_scale, 4, 1, 1, 14)
+ right_grid.addWidget(QtWidgets.QLabel('Set scale'), 4, 16, 1, 3)
+ right_grid.addWidget(self.btn_run, 5, 0, 1, 20)
+ right_grid.addWidget(self.btn_reset, 6, 0, 1, 20)
+ right_grid.addWidget(self.btn_save, 7, 0, 1, 20)
+
+ # Bottom Widget
+ self.image_scroll = QtWidgets.QScrollArea()
+ self.image_scroll.installEventFilter(self)
+
+ # Splitter
+ hsplitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal)
+ vsplitter = QtWidgets.QSplitter(QtCore.Qt.Vertical)
+ hsplitter.addWidget(layout2widget(left_grid))
+ hsplitter.addWidget(layout2widget(right_grid))
+ hsplitter.setStretchFactor(0, 1)
+ hsplitter.setStretchFactor(1, 2)
+ vsplitter.addWidget(hsplitter)
+ vsplitter.addWidget(self.image_scroll)
+ vsplitter.setStretchFactor(0, 1)
+ vsplitter.setStretchFactor(1, 5)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(vsplitter)
+ self.setLayout(hlayout)
+
+ def open_file(self):
+ """Open a file or directory from dialog."""
+ if self.btnGroup_dirType.checkedId() == 0:
+ path = QtWidgets.QFileDialog.getExistingDirectory()
+ if len(path) <= 0:
+ return
+ label = path.split('/')[-1]
+ self.input_file.setText(path)
+ self.input_label.setText(label)
+ else:
+ path, _ = QtWidgets.QFileDialog.getOpenFileName(
+ self, 'Select file', '', 'Images (*.jpg *.png)')
+ if len(path) <= 0:
+ return
+ label = path.split('/')[-1].split('.')[0]
+ self.input_file.setText(path)
+ self.input_label.setText(label)
+
+ def add_file(self):
+ """Add opened file or directory to table."""
+ if os.path.exists(self.input_file.text()):
+ row = self.tb_files.rowCount()
+ self.tb_files.setRowCount(row + 1)
+ self.tb_files.setItem(
+ row, 0, QtWidgets.QTableWidgetItem(self.input_file.text()))
+ self.tb_files.setItem(
+ row, 1, QtWidgets.QTableWidgetItem(self.input_label.text()))
+ # if self.cb_gt.count() == 0:
+ # self.set_gt(self.input_file.text())
+ else:
+ QtWidgets.QMessageBox.about(self, 'Message',
+ 'Please input available file/folder')
+
+ def tableMenu(self, pos):
+ """Set mouse right button menu of table."""
+ menu = QtWidgets.QMenu()
+ menu_up = menu.addAction('move up')
+ menu_down = menu.addAction('move down')
+ menu_delete = menu.addAction('delete')
+ menu_gt = menu.addAction('set reference')
+ menu_pos = self.tb_files.mapToGlobal(pos)
+ row = None
+ for i in self.tb_files.selectionModel().selection().indexes():
+ row = i.row()
+ if row is None or row >= self.tb_files.rowCount():
+ return
+ action = menu.exec(menu_pos)
+ if action == menu_up:
+ self.tb_swap(row, -1)
+ elif action == menu_down:
+ self.tb_swap(row, 1)
+ elif action == menu_delete:
+ self.tb_files.removeRow(row)
+ elif action == menu_gt:
+ for r in range(self.tb_files.rowCount()):
+ self.tb_files.setItem(r, 2, QtWidgets.QTableWidgetItem(''))
+ self.tb_files.setItem(row, 2,
+ QtWidgets.QTableWidgetItem('Reference'))
+ self.set_gt(self.tb_files.item(row, 0).text())
+ if self.isShow:
+ self.run()
+
+ def tb_swap(self, row, move):
+ """Move items of table."""
+ if move == -1:
+ if 0 == row:
+ return
+ target = row - 1
+ elif move == 1:
+ if self.tb_files.rowCount() - 1 == row:
+ return
+ target = row + 1
+ for col in range(self.tb_files.columnCount()):
+ tmp = self.tb_files.takeItem(row, col)
+ self.tb_files.setItem(row, col,
+ self.tb_files.takeItem(target, col))
+ self.tb_files.setItem(target, col, tmp)
+
+ def set_gt(self, path):
+ """Set GT ComboBox items."""
+ self.cb_gt.clear()
+ if os.path.isfile(path):
+ self.cb_gt.addItem(path)
+ else:
+ files = sorted(os.listdir(path))
+ for f in files:
+ self.cb_gt.addItem(path + '/' + f)
+
+ def change_gt(self):
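+ """Re-run the comparison when the selected reference changes."""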
+ if self.isShow:
+ self.run()
+
+ def set_scale(self):
+ """Set scale."""
+ scale = self.slider_scale.value()
+ self.txt_scale.setText(f'{scale} %')
+ self.scale = scale
+ if self.isShow:
+ rect = None
+ if self.rect:
+ rect = [
+ int(r * scale / 100.0 / self.old_scale) for r in self.rect
+ ]
+ self.run(rect)
+
+ def set_column(self):
+ """Set column."""
+ if self.isShow:
+ self.run()
+
+ def run(self, rect=None):
+ """Generate patch compare result."""
+ if self.cb_gt.currentText() == '':
+ QtWidgets.QMessageBox.about(self, 'Message',
+ 'Please set the reference first!')
+ return
+
+ rows = self.tb_files.rowCount()
+ if rows <= 0:
+ QtWidgets.QMessageBox.about(self, 'Message',
+ 'Please add at least one file!')
+ return
+
+ files = []
+ self.labels = []
+ for r in range(rows):
+ # column 2 may be None for rows not marked as reference
+ item = self.tb_files.item(r, 2)
+ if item is None or item.text() != 'Reference':
+ files.append(self.tb_files.item(r, 0).text())
+ self.labels.append(self.tb_files.item(r, 1).text())
+
+ file_name = os.path.basename(self.cb_gt.currentText())
+ self.file_paths = []
+ for file in files:
+ if self.btnGroup_dirType.checkedId() == 0:
+ self.file_paths.append(file + '/' + file_name)
+ else:
+ self.file_paths.append(file)
+
+ mode = self.btnGroup_mode.checkedId()
+ self.images = ConcatImageWidget(self, mode, self.spin_cols.value())
+ self.images.set_images(self.file_paths, self.labels,
+ self.cb_gt.currentText(), self.scale / 100.0,
+ rect)
+ self.image_scroll.setWidget(self.images)
+ self.isShow = True
+
+ def reset(self):
+ """Init window."""
+ self.file_paths = []
+ self.labels = []
+ self.isShow = False
+
+ self.input_label.clear()
+ self.input_file.clear()
+ self.tb_files.setRowCount(0)
+ self.cb_gt.clear()
+ self.spin_cols.setValue(4)
+ self.slider_scale.setValue(100)
+ self.images = ConcatImageWidget(self, 0, self.spin_cols.value())
+ self.image_scroll.setWidget(self.images)
+
+ def save(self):
+ """Save patch compare result."""
+ path, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'save')
+ if len(path) <= 0:
+ return
+ if self.images:
+ self.images.grab().save(path)
+ QtWidgets.QMessageBox.about(self, 'Message', 'Success!')
+ else:
+ QtWidgets.QMessageBox.about(self, 'Message', 'Nothing to save.')
+
+ def wheelEvent(self, ev) -> None:
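+ """Zoom the comparison with Ctrl + mouse wheel."""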
+ key = QtWidgets.QApplication.keyboardModifiers()
+ if key == QtCore.Qt.ControlModifier:
+ # one wheel notch corresponds to an angleDelta of 120
+ self.scale += int(ev.angleDelta().y() / 120)
+ if self.scale < 1:
+ self.scale = 1
+ self.txt_scale.setText(f'{self.scale} %')
+ # the slider and the label expect an integer percentage
+ self.slider_scale.setValue(self.scale)
+ if self.scale > 200 and self.isShow:
+ rect = None
+ if self.rect:
+ rect = [
+ int(r * self.scale / 100.0 / self.old_scale)
+ for r in self.rect
+ ]
+ self.run(rect)
+ return super().wheelEvent(ev)
+
+ def keyPressEvent(self, ev) -> None:
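+ """Switch the reference image with the Left/Right arrow keys."""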
+ if ev.key() == QtCore.Qt.Key_Left:
+ if self.cb_gt.currentIndex() > 0:
+ self.cb_gt.setCurrentIndex(self.cb_gt.currentIndex() - 1)
+ else:
+ self.cb_gt.setCurrentIndex(self.cb_gt.count() - 1)
+ elif ev.key() == QtCore.Qt.Key_Right:
+ if self.cb_gt.currentIndex() < self.cb_gt.count() - 1:
+ self.cb_gt.setCurrentIndex(self.cb_gt.currentIndex() + 1)
+ else:
+ self.cb_gt.setCurrentIndex(0)
+ return super().keyPressEvent(ev)
+
+ def eventFilter(self, object, event) -> bool:
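+ """Forward key presses from the scroll area to this widget."""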
+ if object == self.image_scroll:
+ if event.type() == QtCore.QEvent.KeyPress:
+ self.keyPressEvent(event)
+ return False
+ return super().eventFilter(object, event)
+
+
+class SliderTab(QtWidgets.QWidget):
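+ """Tab comparing two images or videos with a before/after slider."""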
+
+ def __init__(self, parent):
+ super().__init__()
+ self.parent = parent
+ self.images = [None, None]
+ self.imageArea = None
+
+ # Type setting
+ self.typeRect = QtWidgets.QGroupBox('Type')
+ self.typeRect.setFlat(True)
+ btn_type1 = QtWidgets.QRadioButton('Video', self.typeRect)
+ btn_type2 = QtWidgets.QRadioButton('Image', self.typeRect)
+ self.btnGroup_type = QtWidgets.QButtonGroup()
+ self.btnGroup_type.addButton(btn_type1, 0)
+ self.btnGroup_type.addButton(btn_type2, 1)
+ self.btnGroup_type.button(0).setChecked(True)
+ self.btnGroup_type.idToggled.connect(self.reset)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(btn_type1)
+ hlayout.addWidget(btn_type2)
+ self.typeRect.setLayout(hlayout)
+
+ # Mode setting
+ self.modeRect = QtWidgets.QGroupBox('Mode')
+ self.modeRect.setFlat(True)
+ btn_mode1 = QtWidgets.QRadioButton('Match', self.modeRect)
+ btn_mode2 = QtWidgets.QRadioButton('Single', self.modeRect)
+ self.btnGroup_mode = QtWidgets.QButtonGroup()
+ self.btnGroup_mode.addButton(btn_mode1, 0)
+ self.btnGroup_mode.addButton(btn_mode2, 1)
+ self.btnGroup_mode.button(0).setChecked(True)
+ self.btnGroup_mode.idToggled.connect(self.reset)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(btn_mode1)
+ hlayout.addWidget(btn_mode2)
+ self.modeRect.setLayout(hlayout)
+
+ # Settings
+ self.cb_1 = QtWidgets.QComboBox()
+ self.cb_2 = QtWidgets.QComboBox()
+ self.cb_1.currentIndexChanged.connect(self.change_image_1)
+ self.cb_2.currentIndexChanged.connect(self.change_image_2)
+ self.input_label_1 = QtWidgets.QLineEdit()
+ self.input_label_2 = QtWidgets.QLineEdit()
+ self.input_title = QtWidgets.QLineEdit()
+ self.input_label_1.textChanged.connect(self.set_label)
+ self.input_label_2.textChanged.connect(self.set_label)
+ self.input_title.textChanged.connect(self.set_label)
+
+ # Set scale
+ self.scale = 100
+ self.txt_scale = QtWidgets.QLabel('100 %')
+ self.slider_scale = QtWidgets.QSlider(QtCore.Qt.Horizontal)
+ self.slider_scale.setMinimum(1)
+ self.slider_scale.setMaximum(200)
+ self.slider_scale.setValue(100)
+ self.slider_scale.valueChanged.connect(self.set_scale)
+
+ # Auto slider setting
+ self.autoRect = QtWidgets.QGroupBox('Auto Slider')
+ self.autoRect.setFlat(True)
+ btn_autoMode1 = QtWidgets.QRadioButton('Right', self.autoRect)
+ btn_autoMode2 = QtWidgets.QRadioButton('Left', self.autoRect)
+ btn_autoMode3 = QtWidgets.QRadioButton('Alternate', self.autoRect)
+ self.btnGroup_auto = QtWidgets.QButtonGroup()
+ self.btnGroup_auto.addButton(btn_autoMode1, 0)
+ self.btnGroup_auto.addButton(btn_autoMode2, 1)
+ self.btnGroup_auto.addButton(btn_autoMode3, 2)
+ self.btnGroup_auto.button(2).setChecked(True)
+ self.btnGroup_auto.idToggled.connect(self.set_autoSlider)
+ self.slider_auto = QtWidgets.QSlider(QtCore.Qt.Horizontal)
+ self.slider_auto.setMinimum(0)
+ self.slider_auto.setMaximum(200)
+ self.slider_auto.setValue(0)
+ self.slider_auto.valueChanged.connect(self.set_autoSlider)
+ glayout = QtWidgets.QGridLayout()
+ glayout.addWidget(QtWidgets.QLabel('Direction:'), 0, 0, 1, 1)
+ glayout.addWidget(btn_autoMode3, 0, 2, 1, 1)
+ glayout.addWidget(btn_autoMode1, 0, 3, 1, 1)
+ glayout.addWidget(btn_autoMode2, 0, 4, 1, 1)
+ glayout.addWidget(QtWidgets.QLabel('Speed:'), 1, 0, 1, 1)
+ glayout.addWidget(self.slider_auto, 1, 1, 1, 4)
+ self.autoRect.setLayout(glayout)
+
+ # Add file
+ self.btn_add_1 = QtWidgets.QPushButton()
+ self.btn_add_2 = QtWidgets.QPushButton()
+ self.btn_add_1.setText('Add a video')
+ self.btn_add_2.setText('Add frames')
+ self.btn_add_1.clicked.connect(self.add_1)
+ self.btn_add_2.clicked.connect(self.add_2)
+
+ # Buttons
+ self.btn_pause = QtWidgets.QPushButton()
+ self.btn_pause.setText('Pause (Space)')
+ self.btn_pause.clicked.connect(self.pause)
+ self.btn_pause.setEnabled(False)
+
+ self.btn_reset = QtWidgets.QPushButton()
+ self.btn_reset.setText('Reset')
+ self.btn_reset.clicked.connect(self.reset)
+
+ self.btn_save = QtWidgets.QPushButton()
+ self.btn_save.setText('Save')
+ self.btn_save.clicked.connect(self.save)
+
+ self.btn_record = QtWidgets.QPushButton()
+ self.btn_record.setText('Record (Enter)')
+ self.btn_record.clicked.connect(self.record)
+
+ self.txt_prompt = QtWidgets.QLabel()
+ self.txt_prompt.setStyleSheet('color: red; font-size: 28px')
+
+ left_grid = QtWidgets.QGridLayout()
+ left_grid.addWidget(self.typeRect, 0, 0, 1, 10)
+ left_grid.addWidget(self.modeRect, 1, 0, 1, 10)
+ left_grid.addWidget(QtWidgets.QLabel('Left'), 2, 0, 1, 1)
+ left_grid.addWidget(self.cb_1, 2, 1, 1, 9)
+ left_grid.addWidget(QtWidgets.QLabel('Right'), 3, 0, 1, 1)
+ left_grid.addWidget(self.cb_2, 3, 1, 1, 9)
+ left_grid.addWidget(QtWidgets.QLabel('Set label 1'), 4, 0, 1, 1)
+ left_grid.addWidget(self.input_label_1, 4, 1, 1, 9)
+ left_grid.addWidget(QtWidgets.QLabel('Set label 2'), 5, 0, 1, 1)
+ left_grid.addWidget(self.input_label_2, 5, 1, 1, 9)
+ left_grid.addWidget(QtWidgets.QLabel('Set title'), 6, 0, 1, 1)
+ left_grid.addWidget(self.input_title, 6, 1, 1, 9)
+ left_grid.addWidget(QtWidgets.QLabel('Set scale'), 7, 0, 1, 1)
+ left_grid.addWidget(self.slider_scale, 7, 1, 1, 8)
+ left_grid.addWidget(self.txt_scale, 7, 9, 1, 1)
+ left_grid.addWidget(self.autoRect, 8, 0, 1, 10)
+ left_grid.addWidget(self.btn_add_1, 9, 0, 1, 5)
+ left_grid.addWidget(self.btn_add_2, 9, 5, 1, 5)
+ left_grid.addWidget(self.btn_pause, 10, 0, 1, 10)
+ left_grid.addWidget(self.btn_reset, 11, 0, 1, 10)
+ left_grid.addWidget(self.btn_save, 12, 0, 1, 10)
+ left_grid.addWidget(self.btn_record, 13, 0, 1, 10)
+ left_grid.addWidget(QtWidgets.QLabel(), 14, 0, 10, 10)
+ left_grid.addWidget(self.txt_prompt, 25, 0, 20, 10)
+
+ # Image area
+ self.image_scroll = QtWidgets.QScrollArea()
+ self.image_scroll.setAlignment(QtCore.Qt.AlignCenter)
+ self.image_scroll.installEventFilter(self)
+ right_grid = QtWidgets.QGridLayout()
+ right_grid.addWidget(self.image_scroll, 0, 0)
+
+ # Splitter
+ hsplitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal)
+ hsplitter.addWidget(layout2widget(left_grid))
+ hsplitter.addWidget(layout2widget(right_grid))
+ hsplitter.setStretchFactor(0, 1)
+ hsplitter.setStretchFactor(1, 5)
+ hlayout = QtWidgets.QHBoxLayout()
+ hlayout.addWidget(hsplitter)
+ self.setLayout(hlayout)
+
+ # Timer
+ self.timer_slider = QtCore.QTimer(self)
+ self.timer_slider.timeout.connect(self.auto_slider)
+ self.timer_record = QtCore.QTimer(self)
+ self.timer_record.timeout.connect(self.recording)
+
+ # Player
+ self.player = VideoSlider(self)
+ self.player.sigout.connect(self.setImg)
+ # self.player1_end = False
+ # self.player2_end = False
+ # self.player1 = VideoPlayer(self)
+ # self.player2 = VideoPlayer(self)
+ # self.player1.sigout.connect(self.setImg1)
+ # self.player2.sigout.connect(self.setImg2)
+ # self.player1.sigend.connect(self.set_player1)
+ # self.player2.sigend.connect(self.set_player2)
+ self.record_num = 0
+ self.show_image()
+
+ def set_autoSlider(self):
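+ """Configure auto-sliding from the direction and speed settings."""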
+ if self.imageArea:
+ self.timer_slider.stop()
+ if self.slider_auto.value() > 0:
+ self.imageArea.set_auoMode(self.btnGroup_auto.checkedId() + 1)
+ # QTimer.start expects an integer interval in milliseconds
+ self.timer_slider.start(int(1000 / self.slider_auto.value()))
+ else:
+ self.imageArea.set_auoMode(0)
+
+ def auto_slider(self):
+ self.imageArea.auto_slider()
+
+ def add_image(self, cb, btn):
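+ """Add image files to the combo boxes for the current mode."""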
+ if self.btnGroup_mode.checkedId() == 0:
+ path = QtWidgets.QFileDialog.getExistingDirectory(self)
+ if len(path) <= 0:
+ return
+ cb.clear()
+ files = sorted(os.listdir(path))
+ for f in files:
+ cb.addItem(path + '/' + f)
+ else:
+ if btn == 'add_1':
+ path, _ = QtWidgets.QFileDialog.getOpenFileName(
+ self, 'Select gt file', '', 'Images (*.jpg *.png)')
+ if len(path) <= 0:
+ return
+ files = [path]
+ elif btn == 'add_2':
+ paths = QtWidgets.QFileDialog.getExistingDirectory(self)
+ if len(paths) <= 0:
+ return
+ files = sorted(os.listdir(paths))
+ files = [paths + '/' + f for f in files]
+
+ for path in files:
+ if self.cb_1.count() > 0:
+ if self.cb_1.findText(path) > -1:
+ self.cb_1.removeItem(self.cb_1.findText(path))
+ if self.cb_2.findText(path) > -1:
+ self.cb_2.removeItem(self.cb_2.findText(path))
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_2.setCurrentIndex(self.cb_2.count() - 1)
+ else:
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_1.setCurrentIndex(0)
+ self.cb_2.setCurrentIndex(0)
+
+ def add_1(self):
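+ """Add a video file or images via the first Add button."""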
+ if self.btnGroup_type.checkedId() == 1:
+ self.add_image(self.cb_1, 'add_1')
+ else:
+ path, _ = QtWidgets.QFileDialog.getOpenFileName(
+ self, 'Select gt file', '', 'Images (*.mp4 *.avi)')
+ if len(path) <= 0:
+ return
+ if self.cb_1.count() > 0:
+ if self.cb_1.findText(path) > -1:
+ self.cb_1.removeItem(self.cb_1.findText(path))
+ if self.cb_2.findText(path) > -1:
+ self.cb_2.removeItem(self.cb_2.findText(path))
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_2.setCurrentIndex(self.cb_2.count() - 1)
+ else:
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_1.setCurrentIndex(0)
+ self.cb_2.setCurrentIndex(0)
+
+ def add_2(self):
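+ """Add a frame directory or images via the second Add button."""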
+ if self.btnGroup_type.checkedId() == 1:
+ self.add_image(self.cb_2, 'add_2')
+ else:
+ path = QtWidgets.QFileDialog.getExistingDirectory(self)
+ if len(path) <= 0:
+ return
+ if self.cb_1.count() > 0:
+ if self.cb_1.findText(path) > -1:
+ self.cb_1.removeItem(self.cb_1.findText(path))
+ if self.cb_2.findText(path) > -1:
+ self.cb_2.removeItem(self.cb_2.findText(path))
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_2.setCurrentIndex(self.cb_2.count() - 1)
+ else:
+ self.cb_1.addItem(path)
+ self.cb_2.addItem(path)
+ self.cb_1.setCurrentIndex(0)
+ self.cb_2.setCurrentIndex(0)
+
+ def set_label(self):
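+ """Apply the label and title inputs to the image area."""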
+ if self.imageArea is not None:
+ self.imageArea.label_1 = self.input_label_1.text()
+ self.imageArea.label_2 = self.input_label_2.text()
+ self.imageArea.title = self.input_title.text()
+ self.imageArea.update()
+
+ def set_scale(self):
+ """Set scale."""
+ self.scale = self.slider_scale.value()
+ self.txt_scale.setText(f'{self.scale} %')
+ if self.imageArea is not None:
+ self.imageArea.set_scale(self.scale / 100.0)
+
+ def change_image_1(self):
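+ """Update the display when the left selection changes."""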
+ if self.btnGroup_type.checkedId() == 0:
+ self.change_video()
+ else:
+ if self.btnGroup_mode.checkedId() == 0:
+ self.cb_2.setCurrentIndex(self.cb_1.currentIndex())
+ self.images[0] = cv2.imread(self.cb_1.currentText())
+ self.images[1] = cv2.imread(self.cb_2.currentText())
+ self.show_image()
+
+ def change_image_2(self):
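+ """Update the display when the right selection changes."""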
+ if self.btnGroup_type.checkedId() == 0:
+ self.change_video()
+ else:
+ if self.btnGroup_mode.checkedId() == 0:
+ self.cb_1.setCurrentIndex(self.cb_2.currentIndex())
+ self.images[0] = cv2.imread(self.cb_1.currentText())
+ self.images[1] = cv2.imread(self.cb_2.currentText())
+ self.show_image()
+
+ def change_video(self):
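+ """Start side-by-side playback of the two selected videos."""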
+ if self.cb_1.currentText() != '' and self.cb_2.currentText() != '':
+ self.player.set(self.cb_1.currentText(), self.cb_2.currentText())
+ self.player.start()
+ self.btn_pause.setEnabled(True)
+
+ # def setImg1(self, img):
+ # self.images[0] = img
+ # self.imageArea.setImage(self.images)
+ # def setImg2(self, img):
+ # self.images[1] = img
+ # self.imageArea.setImage(self.images)
+ def setImg(self, images):
+ self.images = images
+ self.imageArea.setImage(self.images)
+
+ def show_image(self):
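+ """Create the slider image area and place it in the scroll area."""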
+ self.imageArea = QLabelSlider(self,
+ self.slider_scale.value() / 100.0 + 1e-7,
+ self.input_label_1.text(),
+ self.input_label_2.text(),
+ self.input_title.text())
+ self.imageArea.setFrameShape(QtWidgets.QFrame.Box)
+ self.imageArea.setLineWidth(2)
+ self.imageArea.setAlignment(QtCore.Qt.AlignBottom)
+ self.imageArea.setStyleSheet(
+ 'border-width: 0px; border-style: solid; border-color: rgb(100, 100, 100);background-color: rgb(255, 255, 255)' # noqa
+ )
+ self.image_scroll.setWidget(self.imageArea)
+
+ def pause(self):
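+ """Toggle between pausing and resuming playback and auto-sliding."""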
+ if self.btn_pause.text() == 'Pause (Space)':
+ self.player.pause()
+ self.timer_slider.stop()
+ self.btn_pause.setText('Play (Space)')
+ else:
+ self.player.resume()
+ self.set_autoSlider()
+ self.btn_pause.setText('Pause (Space)')
+
+ def reset(self):
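+ """Clear loaded files and reset the buttons for the current mode."""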
+ self.images = [None, None]
+ self.cb_1.clear()
+ self.cb_2.clear()
+ self.show_image()
+ if self.btnGroup_type.checkedId() == 1:
+ if self.btnGroup_mode.checkedId() == 0:
+ self.btn_add_1.setText('Set image 1')
+ self.btn_add_2.setText('Set image 2')
+ else:
+ self.btn_add_1.setText('Add file')
+ self.btn_add_2.setText('Add directory')
+ else:
+ self.btn_add_1.setText('Add a video')
+ self.btn_add_2.setText('Add frames')
+
+ self.btn_pause.setText('Pause (Space)')
+ self.btn_pause.setEnabled(False)
+ self.btn_record.setText('Record (Enter)')
+
+ self.timer_slider.stop()
+ self.timer_record.stop()
+
+ def save(self):
+ """Save slider compare result."""
+ path, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'save')
+ if len(path) <= 0:
+ return
+ if self.imageArea:
+ self.imageArea.grab().save(path)
+ QtWidgets.QMessageBox.about(self, 'Message', 'Success!')
+ else:
+ QtWidgets.QMessageBox.about(self, 'Message', 'Nothing to save.')
+
+ def record(self):
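+ """Start recording or stop and save the frames to an mp4 file."""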
+ if self.btn_record.text() == 'Record (Enter)':
+ self.record_num = 0
+ self.timer_record.start(1000 // 25)  # 25 fps -> 40 ms per frame
+ self.txt_prompt.setText('Recording...')
+ self.btn_record.setText('End (Enter)')
+ elif self.btn_record.text() == 'End (Enter)':
+ paths = sorted(os.listdir('.tmp/'))
+ paths = ['.tmp/' + p for p in paths]
+ if len(paths) <= 0:
+ return
+ img = cv2.imread(paths[0])
+ h, w, _ = img.shape
+ cur = datetime.datetime.now()
+ # zero-padded timestamp avoids ambiguous file names
+ fname = cur.strftime('%Y%m%d%H%M%S') + '.mp4'
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ self.recorder = cv2.VideoWriter(fname, fourcc, 25, (w, h))
+ self.txt_prompt.setText('Saving...')
+ for i in range(self.record_num):
+ img = cv2.imread(paths[i])
+ img = cv2.resize(img, (w, h))
+ self.recorder.write(img)
+ os.remove(paths[i])
+ self.recorder.release()
+ self.timer_record.stop()
+ self.txt_prompt.setText('')
+ self.btn_record.setText('Record (Enter)')
+ QtWidgets.QMessageBox.about(self, 'Message',
+ f'Save {fname} success!')
+
+ def recording(self):
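+ """Save the current image area into .tmp/ as one numbered frame."""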
+ if not os.path.isdir('.tmp/'):
+ os.makedirs('.tmp/')
+ fname = '.tmp/' + str(self.record_num).zfill(8) + '.png'
+ self.record_num += 1
+ self.imageArea.grab().save(fname)
+
+ def wheelEvent(self, ev) -> None:
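+ """Zoom the comparison with Ctrl + mouse wheel."""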
+ key = QtWidgets.QApplication.keyboardModifiers()
+ if key == QtCore.Qt.ControlModifier:
+ # one wheel notch corresponds to an angleDelta of 120
+ self.scale += int(ev.angleDelta().y() / 120)
+ if self.scale < 1:
+ self.scale = 1
+ self.txt_scale.setText(f'{self.scale} %')
+ if self.scale > 200 and self.imageArea is not None:
+ self.imageArea.set_scale(self.scale / 100.0)
+ else:
+ self.slider_scale.setValue(self.scale)
+ return super().wheelEvent(ev)
+
+ def keyPressEvent(self, ev) -> None:
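+ """Left/Right switch images, Enter records, Space pauses."""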
+ if ev.key() == QtCore.Qt.Key_Left:
+ if self.cb_1.currentIndex() > 0:
+ self.cb_1.setCurrentIndex(self.cb_1.currentIndex() - 1)
+ else:
+ self.cb_1.setCurrentIndex(self.cb_1.count() - 1)
+ elif ev.key() == QtCore.Qt.Key_Right:
+ if self.cb_1.currentIndex() < self.cb_1.count() - 1:
+ self.cb_1.setCurrentIndex(self.cb_1.currentIndex() + 1)
+ else:
+ self.cb_1.setCurrentIndex(0)
+ elif ev.key() in (QtCore.Qt.Key_Enter, QtCore.Qt.Key_Return):
+ self.record()
+ elif ev.key() == QtCore.Qt.Key_Space:
+ self.pause()
+
+ def eventFilter(self, object, event) -> bool:
+ if object == self.image_scroll or object == self:
+ if event.type() == QtCore.QEvent.KeyPress:
+ self.keyPressEvent(event)
+ return True
+ return super().eventFilter(object, event)
+
+
+class SRPage(QtWidgets.QWidget):
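+ """Page hosting the patch-compare and before/after-slider tabs."""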
+
+ def __init__(self, parent) -> None:
+ super().__init__()
+
+ self.parent = parent
+ self.statusBar = self.parent.statusBar
+
+ self.tab_patch = PatchTab(self)
+ self.tab_slider = SliderTab(self)
+ self.tabs = QtWidgets.QTabWidget()
+ self.tabs.addTab(self.tab_patch, 'patch compare')
+ self.tabs.addTab(self.tab_slider, 'before/after slider')
+
+ layout = QtWidgets.QVBoxLayout()
+ self.setLayout(layout)
+ layout.addWidget(self.tabs)
+ # self.tabs.currentChanged.connect(self.tabsCurrentChanged)
diff --git a/tools/gui/utils.py b/tools/gui/utils.py
new file mode 100644
index 0000000000..a6a9fd9cd1
--- /dev/null
+++ b/tools/gui/utils.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+from PyQt5 import QtGui, QtWidgets
+
+
+def layout2widget(layout):
+ wg = QtWidgets.QWidget()
+ wg.setLayout(layout)
+ return wg
+
+
+def qimage2array(img):
+ w = img.width()
+ h = img.height()
+ img = img.convertToFormat(QtGui.QImage.Format.Format_RGBA8888)
+ img = img.bits()
+ img.setsize(w * h * 4)
+ img = np.frombuffer(img, np.uint8)
+ img = np.reshape(img, (h, w, 4))
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
+ return img
diff --git a/tools/model_converters/pytorch2onnx.py b/tools/model_converters/pytorch2onnx.py
index 09919c2b6f..f0dabb00ad 100644
--- a/tools/model_converters/pytorch2onnx.py
+++ b/tools/model_converters/pytorch2onnx.py
@@ -1,18 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
-import glob
import os.path as osp
-import re
import warnings
-from functools import reduce
import cv2
-import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
+from mmengine import Config
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
@@ -63,7 +60,7 @@ def pytorch2onnx(model,
img = input['inputs'].unsqueeze(0)
data = torch.cat((img, masks), dim=1)
elif model_type == 'video_restorer':
- data = input['inputs'].unsqueeze(0)
+ data = input['inputs'].unsqueeze(0).float()
data = data.to(device)
# pytorch has some bug in pytorch1.3, we have to fix it
@@ -158,6 +155,8 @@ def parse_args():
'--mask-path',
default=None,
help='path to input mask file, used in inpainting model')
+ parser.add_argument('--num-frames', type=int, default=None)
+ parser.add_argument('--sequence-length', type=int, default=None)
parser.add_argument('--device', type=int, default=0, help='CUDA device id')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
@@ -188,7 +187,7 @@ def parse_args():
else:
device = torch.device('cuda', args.device)
- config = mmcv.Config.fromfile(args.config)
+ config = Config.fromfile(args.config)
delete_cfg(config, key='init_cfg')
# ONNX does not support spectral norm
@@ -217,6 +216,8 @@ def parse_args():
keys_to_remove = ['alpha', 'ori_alpha']
elif model_type == 'image_restorer':
keys_to_remove = ['gt', 'gt_path']
+ elif model_type == 'video_restorer':
+ keys_to_remove = ['gt', 'gt_path']
else:
keys_to_remove = []
for key in keys_to_remove:
@@ -244,17 +245,15 @@ def parse_args():
f'"GenerateSegmentIndices", but got '
f'"{test_pipeline[0]["type"]}".')
# prepare data
- sequence_length = len(glob.glob(osp.join(args.img_path, '*')))
- img_dir_split = re.split(r'[\\/]', args.img_path)
- if img_dir_split[0] == '':
- img_dir_split[0] = '/'
- key = img_dir_split[-1]
- lq_folder = reduce(osp.join, img_dir_split[:-1])
+ # sequence_length = len(glob.glob(osp.join(args.img_path, '*')))
+ lq_folder = osp.dirname(args.img_path)
+ key = osp.basename(args.img_path)
data = dict(
img_path=lq_folder,
gt_path='',
key=key,
- sequence_length=sequence_length)
+ num_frames=args.num_frames,
+ sequence_length=args.sequence_length)
# build the data pipeline
test_pipeline = Compose(test_pipeline)
diff --git a/tools/test.py b/tools/test.py
index 58dcdb9d8c..f353544daa 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -5,6 +5,7 @@
import mmengine
from mmengine.config import Config, DictAction
+from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmedit.utils import print_colored_log, register_all_modules
@@ -73,7 +74,7 @@ def main():
if args.out:
- class SaveMetricHook(mmengine.Hook):
+ class SaveMetricHook(Hook):
def after_test_epoch(self, _, metrics=None):
if metrics is not None: