diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..ed4203e --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git + +# input data, saved log, checkpoints +data/ +input/ +saved/ +datasets/ + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9eb0f0f --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +*.deb + +# Jupyter Notebook +.ipynb_checkpoints + +# input data, saved log, checkpoints, notebooks,... +data/ +input/ +saved/ +datasets/ +paper_data/ +checkpoints/ +jupyter/ +experiments_train.sh +experiments_test.sh + +# editor, os cache directory +.vscode/ +.idea/ +__MACOSX/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..92d9233 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Dinesh Daultani + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..cc1c446 --- /dev/null +++ b/README.md @@ -0,0 +1,172 @@ +# Consolidating separate degradations model via weights fusion and distillation + +Authors: [Dinesh Daultani](https://dineshdaultani.github.io/), [Hugo Larochelle](https://mila.quebec/en/person/hugo-larochelle/) + +[[paper](https://openaccess.thecvf.com/content/WACV2024W/VAQ/papers/Daultani_Consolidating_Separate_Degradations_Model_via_Weights_Fusion_and_Distillation_WACVW_2024_paper.pdf)] + +This repository contains the source code associated with the paper "Consolidating separate degradations model via weights fusion and distillation", which was presented at WACV 2024 workshop. + +## Abstract +Real-world images prevalently contain different varieties of degradation, such as motion blur and luminance noise. Computer vision recognition models trained on clean images perform poorly on degraded images. Previously, several works have explored how to perform image classification of degraded images while training a single model for each degradation. Nevertheless, it becomes challenging to host several degradation models for each degradation on limited hardware applications and to estimate degradation parameters correctly at the run-time. This work proposes a method for effectively combining several models trained separately on different degradations into a single model to classify images with different types of degradations. Our proposed method is four-fold: (1) train a base model on clean images, (2) fine-tune the base model individually for all given image degradations, (3) perform a fusion of weights given the fine-tuned models for individual degradations, (4) perform fine-tuning on given task using distillation and cross-entropy loss. Our proposed method can outperform previous state-of-the-art methods of pretraining in out-of-distribution generalization based on degradations such as JPEG compression, salt-and-pepper noise, Gaussian blur, and additive white Gaussian noise by 2.5% on CIFAR-100 dataset and by 1.3% on CIFAR-10 dataset. Moreover, our proposed method can handle degradation used for training without any explicit information about degradation at the inference time. + +## Brief Introduction +Computer vision has been widely used in real-world applications nowadays. Considerable research has focused on the assumption that the images do not contain abnormalities and only ideal images. Real-world images frequently have different perturbations, like motion blur, noise (caused by low-light conditions), and compression, appearing in various digital versions of images/videos. Specifically, some computer vision domains face challenges with image degradation, leading to diminishing model performance or reliability. Hence, degradation in vision is often unavoidable, so it is crucial to handle degraded images properly. At the same time, often, the models trained are specific for a particular degradation. However, our study focuses on an essential aspect of this limitation: combining several models trained separately on individual degradations. To the best of our knowledge, this study is the first to investigate the method of combining separately trained degradation models into a single model for the classification of images with distinct types of degradation. + +Our proposed method is split into four steps as follows: +1. Train a base model $\mu_{clean}$ on clean images. +2. Fine-tune the base model $\mu_{clean}$ individually for each degradation $deg$, i.e., $\mu_{deg}$. +3. Perform fusion of weights given the fine-tuned models $\mu_{deg}$ for individual degradations as $\sigma$. +4. Perform fine-tuning on all degradation images using distillation and cross-entropy loss as $\sigma_{tuned}$. + +

+ + + Figure 1. Fine-tuning of Student $\sigma_{tuned}$ using fusion model initialized with weights $\sigma$ and knowledge transfer from Teachers $\mu_{deg_{1}}$, $\mu_{deg_{2}}$, ..., $\mu_{deg_{N}}$ where $N$ represents total individual degradations used for consolidation in the student network. +

+ + +## Quantitative evaluation results + +Performance evaluation for comparison approaches, applied to different datasets on ResNet56 backbones. These datasets undergo assessment under four distinct degradations, i.e., JPEG compression, Gaussian blur, additive white Gaussian noise, and salt-and-pepper noise, denoted as JPEG, Blur, AWGN, and SAPN, respectively. The "Avg" column contains the average for the above four degradations. Moreover, results in **bold** and underline represent the best performance for combined degradation and separate/combined models, respectively. + +### CIFAR-100 dataset +| Approach | JPEG | Blur | AWGN | SAPN | Avg | +|:-------------------------:|:----:|:----:|:----:|:----:|:----:| +| Base separate (Oracle) | 63.4 | 58.2 | 65.4 | 74.1 | 65.3 | +| Ensemble | 59.3 | 25.3 | 43.6 | 58.7 | 46.7 | +| Scratch | 59.1 | 54.3 | 61.0 | 66.0 | 60.1 | +| Vanilla fine-tuning | 62.2 | 56.0 | 63.8 | 70.2 | 63.1 | +| ModelSoups | 1.1 | 1.0 | 1.2 | 1.0 | 1.1 | +| Fusing | 62.3 | 56.4 | 64.1 | 70.5 | 63.3 | +| FusionDistill (Ours) | **64.5** | **58.6** | **66.4** | **73.7** | **65.8** | + +### Tiny Imagenet dataset +| Approach | JPEG | Blur | AWGN | SAPN | Avg | +|:-------------------------:|:----:|:----:|:----:|:----:|:----:| +| Base separate (Oracle) | 56.3 | 50.3 | 57.7 | 60.6 | 56.2 | +| Ensemble | 53.1 | 17.4 | 45.4 | 49.9 | 35.5 | +| Scratch | 54.0 | 48.2 | 55.3 | 58.2 | 53.9 | +| Vanilla fine-tuning | 55.2 | 48.1 | 56.6 | 59.7 | 54.9 | +| ModelSoups | 0.5 | 0.5 | 0.5 | 0.5 | 0.5 | +| Fusing | 54.2 | 47.1 | 55.2 | 58.8 | 53.8 | +| FusionDistill (Ours) | **56.7** | **48.5** | **58.2** | **61.9** | **56.3** | + +For more details of our proposed method and relevant details please refer to our [paper](https://openaccess.thecvf.com/content/WACV2024W/VAQ/papers/Daultani_Consolidating_Separate_Degradations_Model_via_Weights_Fusion_and_Distillation_WACVW_2024_paper.pdf). + +## Requirements +To install all the pre-requisite libraries install the docker container from `docker/Dockerfile`. Following python libraries are main pre-requisites: +* pytorch==1.12.0 +* imagedegrade + +To properly reproduce the results from the paper, please use the provided docker `Dockerfile` and `docker/requirements.txt`. Sample docker build and run commands are as follows: +* `docker build -f docker/Dockerfile -t pytorch_1.12 .` +* `docker run -v /:/data --shm-size 50G -p 8008:8008 -it --gpus '"device=0,1"' pytorch_1.12 /bin/bash` + +## Training / Evaluation +### Proposed method +Proposed method training is done in four steps: +1. Step-1: Train a base model on clean images. + - Training: + ``` + python train.py -c configs/ind//ResNet56_clean.yaml + ``` + By default the clean image models reside in `saved/jpeg/IndTrainer` directory. + +2. Step-2: Fine-tune the base model trained on clean images for each degradation individually. + - Training: + ``` + python train.py -c configs/sl//ResNet56-56.yaml --dt + ``` + +3. Step-3: Perform fusion of weights given the fine-tuned models for individual degradations + - Fusion (ModelSoups): + ``` + python utils/model_soups.py --dataset + ``` + + We need to specify the dataset and it will by default read the ResNet config file for model soups in the specific dataset folder. + +4. Step-4: Perform fine-tuning on all degradation images using distillation and cross-entropy loss + - Training: + ``` + python train_all_deg.py -c configs/deg_all//ResNet56-56_consistency.yaml --dt combined_deg + ``` + - Evaluation: + ``` + python test.py -r saved/combined_deg/SLDA_Trainer/ResNet56-56_CIFAR10/train//model_best.pth --dt + ``` + +Please refer to the notes at the end for some common details to run the experiments. + +### Running existing method experiments (Baselines) +#### Base separate (Oracle) + Oracle method is nothing but the Step-2 models trained on degradation individually. Hence, we need to run evaluation after Step-2: + - Evaluation: + ``` + python test.py -r saved//SLTrainer/ResNet56-56_/train//model_best.pth + ``` + +#### Ensemble + This method uses combination of all individually trained degradation models. Hence, we need to only run the evaluation out-of-the-box using separately prepared ensemble script: + - Evaluation: + ``` + python test_ensemble.py -c configs/deg_all//ResNet56_ensemble.yaml + ``` + +#### Scratch + - Training: + ``` + python train_all_deg.py -c configs/deg_all//ResNet56_base_scratch.yaml --dt combined_deg + ``` + - Evaluation: + ``` + python test.py -r saved/combined_deg/IndDATrainer/ResNet56__deg_scratch/train//model_best.pth --dt + ``` + +#### Vanilla Finetuning + - Training: + ``` + python train_all_deg.py -c configs/deg_all//ResNet56_base_vanilla.yaml --dt combined_deg + ``` + - Evaluation: + ``` + python test.py -r saved/combined_deg/IndDATrainer/ResNet56__deg_vanilla/train//model_best.pth --dt + ``` + +#### ModelSoups + ModelSoups method doesn't involve any training and running evaluation is after the Step-3 of the proposed method, i.e., the weight fusion process. + - Evaluation: + ``` + python test.py -r saved/combined_deg/SLTrainer/ResNet56-56__soups/train//model_best.pth --dt + ``` + +#### Fusing + - Training: + ``` + python train_all_deg.py -c configs/deg_all//ResNet56_fused.yaml --dt combined_deg + ``` + - Evaluation: + ``` + python test.py -r saved/combined_deg/IndDATrainer/ResNet56__deg_fused/train//model_best.pth --dt + ``` + +### Notes for running experiments: +- Replace with either "CIFAR10", "CIFAR100" or "TinyImagenet" to run the experiments for specific dataset for training or evaluation scripts. +- Replace with degradations such as `jpeg`, `blur`, `noise`, `saltpepper` when training for individual degradation, i.e., Step-2. Consequently, replace with `combined_deg` when training for all degradations, i.e., Step-4. +- If you train the model from scratch you need to change `pretrained_path` paths in config yaml files. +- Replace with appropriate run id after running the training. +- Evaluation script is run for specific degradation. + +## Citation +If you find our work or this repository helpful, please consider citing our work: + +```bibtex +@InProceedings{Daultani_2024_WACV, + author = {Daultani, Dinesh and Larochelle, Hugo}, + title = {Consolidating Separate Degradations Model via Weights Fusion and Distillation}, + booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) Workshops}, + month = {January}, + year = {2024}, + pages = {440-449} +} +``` \ No newline at end of file diff --git a/base/__init__.py b/base/__init__.py new file mode 100644 index 0000000..19c2224 --- /dev/null +++ b/base/__init__.py @@ -0,0 +1,3 @@ +from .base_data_loader import * +from .base_model import * +from .base_trainer import * diff --git a/base/base_data_loader.py b/base/base_data_loader.py new file mode 100644 index 0000000..2c62a3e --- /dev/null +++ b/base/base_data_loader.py @@ -0,0 +1,72 @@ +import numpy as np +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate +from torch.utils.data.sampler import SubsetRandomSampler +from utils.util import seed_worker +import torch + +class BaseDataLoader(DataLoader): + """ + Base class for all data loaders + """ + def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, + collate_fn=default_collate, pin_memory = True, persistent_workers = True, seed = 123): + self.validation_split = validation_split + self.shuffle = shuffle + + self.batch_idx = 0 + self.n_samples = len(dataset) + + self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) + + # Based on torch recommendation on data loader randomness + # https://pytorch.org/docs/stable/notes/randomness.html#dataloader + g = torch.Generator() + g.manual_seed(seed) + + self.init_kwargs = { + 'dataset': dataset, + 'batch_size': batch_size, + 'shuffle': self.shuffle, + 'collate_fn': collate_fn, + 'num_workers': num_workers, + 'pin_memory': pin_memory, + 'persistent_workers': persistent_workers, + 'worker_init_fn': seed_worker, + 'generator': g, + } + super().__init__(sampler=self.sampler, **self.init_kwargs) + + def _split_sampler(self, split): + if split == 0.0: + return None, None + + idx_full = np.arange(self.n_samples) + + np.random.seed(0) + np.random.shuffle(idx_full) + + if isinstance(split, int): + assert split > 0 + assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." + len_valid = split + else: + len_valid = int(self.n_samples * split) + + valid_idx = idx_full[0:len_valid] + train_idx = np.delete(idx_full, np.arange(0, len_valid)) + + train_sampler = SubsetRandomSampler(train_idx) + valid_sampler = SubsetRandomSampler(valid_idx) + + # turn off shuffle option which is mutually exclusive with sampler + self.shuffle = False + self.n_samples = len(train_idx) + + return train_sampler, valid_sampler + + def split_validation(self): + if self.valid_sampler is None: + return None + else: + return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) diff --git a/base/base_model.py b/base/base_model.py new file mode 100644 index 0000000..84c3c66 --- /dev/null +++ b/base/base_model.py @@ -0,0 +1,71 @@ +import torch.nn as nn +import numpy as np +from abc import abstractmethod + + +class BaseModel(nn.Module): + """ + Base class for all models + """ + def __init__(self, deg_flag = 'clean'): + """ + Constructor + Args: + num_of_features (int) : The number of features extracted by the feature extractor. + init_weights (bool) : True if you initialize weights. + deg_flag (int) : clean when training with clean images + deg when training with clean images + """ + super(BaseModel, self).__init__() + self.deg_flag = deg_flag + print('Creating network object with deg_flag: ', self.deg_flag) + + def define_input(self, *inputs): + """ + Forward pass logic + if deg_flag is True, run forward pass on degraded images + else run forward pass on clean images + :return: Model output + """ + clean_img, deg_img = inputs + if self.deg_flag == 'clean': + return clean_img + elif self.deg_flag == 'deg': + return deg_img + else: + raise NotImplementedError + + + def _init_weight_layers(self, layers): + """ + Initialize each layer depends on the layer type + """ + for layer in layers.modules(): + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') + if layer.bias is not None: + nn.init.constant_(layer.bias, 0) + elif isinstance(layer, nn.BatchNorm2d): + nn.init.constant_(layer.weight, 1) + nn.init.constant_(layer.bias, 0) + elif isinstance(layer, nn.Linear): + nn.init.normal_(layer.weight, 0, 0.01) + nn.init.constant_(layer.bias, 0) + + def __str__(self): + """ + Model prints with number of trainable parameters + """ + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return super().__str__() + '\nTrainable parameters: {}'.format(params) + + @abstractmethod + def forward(): + """ + forward function for the model + """ + msg = "forward functionhas not been implemeted." + raise NotImplementedError(msg) + + \ No newline at end of file diff --git a/base/base_trainer.py b/base/base_trainer.py new file mode 100644 index 0000000..008514e --- /dev/null +++ b/base/base_trainer.py @@ -0,0 +1,177 @@ +import torch +from abc import abstractmethod +from numpy import inf +from logger import TensorboardWriter +from utils import prepare_device +from utils import inf_loop + +class BaseTrainer: + """ + Base class for all trainers + """ + def __init__(self, metric_ftns, config, train_data_loader, valid_data_loader, + len_epoch): + self.config = config + self.logger = config.get_logger('trainer', config['trainer']['args']['verbosity']) + model_args = config['model']['args'] if 'model' in config else \ + config['student_model']['args'] + self.deg_flag = model_args['deg_flag'] + # prepare for (multi-device) GPU training + self.device, self.device_ids = prepare_device(config['n_gpu']) + self.metric_ftns = metric_ftns + self.train_data_loader = train_data_loader + if len_epoch is None: + # epoch-based training + self.len_epoch = len(self.train_data_loader) + else: + # iteration-based training + self.train_data_loader = inf_loop(train_data_loader) + self.len_epoch = len_epoch + self.valid_data_loader = valid_data_loader + self.do_validation = self.valid_data_loader is not None + cfg_trainer = config['trainer']['args'] + self.epochs = cfg_trainer['epochs'] + self.save_period = cfg_trainer['save_period'] + self.monitor = cfg_trainer.get('monitor', 'off') + + # configuration to monitor model performance and save best + if self.monitor == 'off': + self.mnt_mode = 'off' + self.mnt_best = 0 + else: + self.mnt_mode, self.mnt_metric = self.monitor.split() + assert self.mnt_mode in ['min', 'max'] + + self.mnt_best = inf if self.mnt_mode == 'min' else -inf + self.early_stop = cfg_trainer.get('early_stop', inf) + if self.early_stop <= 0: + self.early_stop = inf + + self.start_epoch = 1 + self.checkpoint_dir = config.save_dir + + # setup visualization writer instance + self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) + + if config.resume is not None: + self._resume_checkpoint(config) + + @abstractmethod + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Current epoch number + """ + raise NotImplementedError + + @abstractmethod + def _build_model(self, config): + """ + Build model from the configuration file + + :param config: config file + """ + raise NotImplementedError + + @abstractmethod + def _load_loss(self, config): + """ + Load loss from the configuration file + + :param config: config file + """ + raise NotImplementedError + + @abstractmethod + def _load_optimizer(self, model, config): + """ + Load optimizer from the configuration file + + :param config: config file + """ + raise NotImplementedError + + @abstractmethod + def _load_scheduler(self, optimizer, config): + """ + Load scheduler from the configuration file + + :param config: config file + """ + raise NotImplementedError + + @abstractmethod + def _save_checkpoint(self, epoch, save_best): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth' + """ + raise NotImplementedError + + @abstractmethod + def _resume_checkpoint(self, resume_path): + """ + Resume from saved checkpoints + + :param resume_path: Checkpoint path to be resumed + """ + raise NotImplementedError + + def train(self): + """ + Full training logic + """ + not_improved_count = 0 + for epoch in range(self.start_epoch, self.epochs + 1): + result = self._train_epoch(epoch) + + # save logged informations into log dict + log = {'epoch': epoch} + log.update(result) + + # print logged informations to the screen + for key, value in log.items(): + self.logger.info(' {:15s}: {}'.format(str(key), value)) + + # evaluate model performance according to configured metric, save best checkpoint as model_best + best = False + if self.mnt_mode != 'off': + try: + # check whether model performance improved or not, according to specified metric(mnt_metric) + improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) + except KeyError: + self.logger.warning("Warning: Metric '{}' is not found. " + "Model performance monitoring is disabled.".format(self.mnt_metric)) + self.mnt_mode = 'off' + improved = False + + if improved: + self.mnt_best = log[self.mnt_metric] + not_improved_count = 0 + best = True + else: + not_improved_count += 1 + + if not_improved_count > self.early_stop: + self.logger.info("Validation performance didn\'t improve for {} epochs. " + "Training stops.".format(self.early_stop)) + break + + if epoch % self.save_period == 0: + self._save_checkpoint(epoch, save_best=best) + + def _progress(self, batch_idx): + base = '[{}/{} ({:.0f}%)]' + if hasattr(self.train_data_loader, 'n_samples'): + current = batch_idx * self.train_data_loader.batch_size + total = self.train_data_loader.n_samples + else: + current = batch_idx + total = self.len_epoch + return base.format(current, total, 100.0 * current / total) + diff --git a/configs/deg_all/CIFAR10/ResNet56-56_fusiondistill.yaml b/configs/deg_all/CIFAR10/ResNet56-56_fusiondistill.yaml new file mode 100644 index 0000000..6a99189 --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56-56_fusiondistill.yaml @@ -0,0 +1,56 @@ +name: ResNet56-56_CIFAR10 +n_gpu: 1 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 10 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_CIFAR10_soups/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [0.1, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: SLDA_Trainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR10/ResNet56_base_scratch.yaml b/configs/deg_all/CIFAR10/ResNet56_base_scratch.yaml new file mode 100644 index 0000000..5e6a7d3 --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56_base_scratch.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR10_deg_scratch +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR10/ResNet56_base_vanilla.yaml b/configs/deg_all/CIFAR10/ResNet56_base_vanilla.yaml new file mode 100644 index 0000000..170c798 --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56_base_vanilla.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR10_deg_vanilla +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR10_clean/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR10/ResNet56_ensemble.yaml b/configs/deg_all/CIFAR10/ResNet56_ensemble.yaml new file mode 100644 index 0000000..77176df --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56_ensemble.yaml @@ -0,0 +1,35 @@ +name: ResNet56_CIFAR10_deg_ensemble +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy_classes +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR10/ResNet56_fused.yaml b/configs/deg_all/CIFAR10/ResNet56_fused.yaml new file mode 100644 index 0000000..6601af5 --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56_fused.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR10_deg_fused +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_CIFAR10_soups/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR10/ResNet56_soups.yaml b/configs/deg_all/CIFAR10/ResNet56_soups.yaml new file mode 100644 index 0000000..379a427 --- /dev/null +++ b/configs/deg_all/CIFAR10/ResNet56_soups.yaml @@ -0,0 +1,35 @@ +name: ResNet56_CIFAR10_deg +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR10/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56-56_fusiondistill.yaml b/configs/deg_all/CIFAR100/ResNet56-56_fusiondistill.yaml new file mode 100644 index 0000000..419c9bb --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56-56_fusiondistill.yaml @@ -0,0 +1,56 @@ +name: ResNet56-56_CIFAR100 +n_gpu: 1 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 100 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_CIFAR100_soups/train//model_best.pth # Model_soups: ILIAC - seed 1 +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [0.1, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: SLDA_Trainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56_base_scratch.yaml b/configs/deg_all/CIFAR100/ResNet56_base_scratch.yaml new file mode 100644 index 0000000..d65a16b --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56_base_scratch.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR100_deg_scratch +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56_base_vanilla.yaml b/configs/deg_all/CIFAR100/ResNet56_base_vanilla.yaml new file mode 100644 index 0000000..eab8866 --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56_base_vanilla.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR100_deg_vanilla +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR100_clean/train//model_best.pth +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56_ensemble.yaml b/configs/deg_all/CIFAR100/ResNet56_ensemble.yaml new file mode 100644 index 0000000..cbb38b9 --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56_ensemble.yaml @@ -0,0 +1,35 @@ +name: ResNet56_CIFAR100_deg_ensemble +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy_classes +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56_fused.yaml b/configs/deg_all/CIFAR100/ResNet56_fused.yaml new file mode 100644 index 0000000..fff1540 --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56_fused.yaml @@ -0,0 +1,46 @@ +name: ResNet56_CIFAR100_deg_fused +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_CIFAR100_soups/train//model_best.pth # Model_soups: ILIAC - seed 2 +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/CIFAR100/ResNet56_soups.yaml b/configs/deg_all/CIFAR100/ResNet56_soups.yaml new file mode 100644 index 0000000..0ca749a --- /dev/null +++ b/configs/deg_all/CIFAR100/ResNet56_soups.yaml @@ -0,0 +1,35 @@ +name: ResNet56_CIFAR100_deg +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_CIFAR100/train//model_best.pth +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56-56_fusiondistill.yaml b/configs/deg_all/TinyImagenet/ResNet56-56_fusiondistill.yaml new file mode 100644 index 0000000..3b9d0d1 --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56-56_fusiondistill.yaml @@ -0,0 +1,56 @@ +name: ResNet56-56_TinyImagenet +n_gpu: 1 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 200 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_TinyImagenet_soups/train//model_best.pth # Model_soups: ILIAC - seed 1 +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 1 + deg_type: combined_deg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [0.1, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: SLDA_Trainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56_base_scratch.yaml b/configs/deg_all/TinyImagenet/ResNet56_base_scratch.yaml new file mode 100644 index 0000000..76b448e --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56_base_scratch.yaml @@ -0,0 +1,46 @@ +name: ResNet56_TinyImagenet_deg_scratch +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 1 + deg_type: jpeg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 32 +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56_base_vanilla.yaml b/configs/deg_all/TinyImagenet/ResNet56_base_vanilla.yaml new file mode 100644 index 0000000..32426ea --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56_base_vanilla.yaml @@ -0,0 +1,46 @@ +name: ResNet56_TinyImagenet_deg_vanilla +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_TinyImagenet_clean/train//model_best.pth +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 1 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 32 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 140, 180] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56_ensemble.yaml b/configs/deg_all/TinyImagenet/ResNet56_ensemble.yaml new file mode 100644 index 0000000..523bef8 --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56_ensemble.yaml @@ -0,0 +1,35 @@ +name: ResNet56_TinyImagenet_deg_ensemble +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy_classes +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56_fused.yaml b/configs/deg_all/TinyImagenet/ResNet56_fused.yaml new file mode 100644 index 0000000..126154c --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56_fused.yaml @@ -0,0 +1,46 @@ +name: ResNet56_TinyImagenet_deg_fused +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path: ./saved/combined_deg/SLTrainer/ResNet56-56_TinyImagenet_soups/train//model_best.pth # Model_soups: ILIAC - seed 2 +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 1 + deg_type: combined_deg + cutout_apply_clean: true + cutout_apply_deg: true + cutout_method: Cutout + cutout_length: 16 +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [30, 70, 90] + gamma: 0.2 +trainer: + type: IndDATrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/deg_all/TinyImagenet/ResNet56_soups.yaml b/configs/deg_all/TinyImagenet/ResNet56_soups.yaml new file mode 100644 index 0000000..d20d718 --- /dev/null +++ b/configs/deg_all/TinyImagenet/ResNet56_soups.yaml @@ -0,0 +1,35 @@ +name: ResNet56_TinyImagenet_deg +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path_jpeg: saved/jpeg/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_blur: saved/blur/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_noise: saved/noise/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth + pretrained_path_saltpepper: saved/saltpepper/SLTrainer/ResNet56-56_TinyImagenet/train//model_best.pth +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/ind/CIFAR10/ResNet56_clean.yaml b/configs/ind/CIFAR10/ResNet56_clean.yaml new file mode 100644 index 0000000..a73a7f4 --- /dev/null +++ b/configs/ind/CIFAR10/ResNet56_clean.yaml @@ -0,0 +1,45 @@ +name: ResNet56_CIFAR10_clean +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: clean + num_class: 10 +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: True +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0001 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 120, 160] + gamma: 0.2 +trainer: + type: IndTrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_clean + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/ind/CIFAR100/ResNet56_clean.yaml b/configs/ind/CIFAR100/ResNet56_clean.yaml new file mode 100644 index 0000000..1041cef --- /dev/null +++ b/configs/ind/CIFAR100/ResNet56_clean.yaml @@ -0,0 +1,45 @@ +name: ResNet56_CIFAR100_clean +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: clean + num_class: 100 +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 3 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: True +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 120, 160] + gamma: 0.2 +trainer: + type: IndTrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_clean + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/ind/TinyImagenet/ResNet56_clean.yaml b/configs/ind/TinyImagenet/ResNet56_clean.yaml new file mode 100644 index 0000000..9731f88 --- /dev/null +++ b/configs/ind/TinyImagenet/ResNet56_clean.yaml @@ -0,0 +1,45 @@ +name: ResNet56_TinyImagenet_clean +n_gpu: 1 +model: + type: ResNet56 + args: + deg_flag: clean + num_class: 200 +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 5 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 32 + cutout_apply_clean: True +optimizer: + type: SGD + args: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 +loss: + - supervised_loss: CE +loss_weights: [1.0] +metrics: +- accuracy +lr_scheduler: + type: MultiStepLR + args: + milestones: [60, 120, 160] + gamma: 0.2 +trainer: + type: IndTrainer + args: + epochs: 200 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_clean + early_stop: 200 + tensorboard: true \ No newline at end of file diff --git a/configs/sl/CIFAR10/ResNet56-56.yaml b/configs/sl/CIFAR10/ResNet56-56.yaml new file mode 100644 index 0000000..d31be68 --- /dev/null +++ b/configs/sl/CIFAR10/ResNet56-56.yaml @@ -0,0 +1,53 @@ +name: ResNet56-56_CIFAR10 +n_gpu: 1 +# random_seed: 0 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 10 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR10_clean/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 10 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR10_clean/train//model_best.pth +data_loader: + type: DegCIFAR10DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [1.0, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: CosineAnnealingLR + args: + T_max: 100 +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/sl/CIFAR100/ResNet56-56.yaml b/configs/sl/CIFAR100/ResNet56-56.yaml new file mode 100644 index 0000000..81c2e60 --- /dev/null +++ b/configs/sl/CIFAR100/ResNet56-56.yaml @@ -0,0 +1,52 @@ +name: ResNet56-56_CIFAR100 +n_gpu: 1 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 100 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR100_clean/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 100 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_CIFAR100_clean/train//model_best.pth +data_loader: + type: DegCIFAR100DataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 16 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [1.0, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: CosineAnnealingLR + args: + T_max: 100 +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/configs/sl/TinyImagenet/ResNet56-56.yaml b/configs/sl/TinyImagenet/ResNet56-56.yaml new file mode 100644 index 0000000..fdf9fff --- /dev/null +++ b/configs/sl/TinyImagenet/ResNet56-56.yaml @@ -0,0 +1,52 @@ +name: ResNet56-56_TinyImagenet +n_gpu: 1 +teacher_model: + type: ResNet56 + args: + deg_flag: clean + num_class: 200 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_TinyImagenet_clean/train//model_best.pth +student_model: + type: ResNet56 + args: + deg_flag: deg + num_class: 200 + pretrained_path: ./saved/jpeg/IndTrainer/ResNet56_TinyImagenet_clean/train//model_best.pth +data_loader: + type: DegTinyImagenetDataLoader + args: + data_dir: data/ + batch_size: 128 + shuffle: true + validation_split: 0.0 + num_workers: 2 + deg_type: jpeg + cutout_method: Cutout + cutout_length: 32 + cutout_apply_clean: false + cutout_apply_deg: true +optimizer: + type: RAdam + args: + lr: 0.001 + weight_decay: 0.0001 +loss: + - supervised_loss: CE + - inheritance_loss: COS +loss_weights: [1.0, [0.0, 0.0, 1.0]] +metrics: +- accuracy +lr_scheduler: + type: CosineAnnealingLR + args: + T_max: 100 +trainer: + type: SLTrainer + args: + epochs: 100 + save_dir: saved/ + save_period: 1 + verbosity: 2 + monitor: max val_accuracy_deg + early_stop: 100 + tensorboard: true \ No newline at end of file diff --git a/data_loader/data_loaders.py b/data_loader/data_loaders.py new file mode 100644 index 0000000..27ad594 --- /dev/null +++ b/data_loader/data_loaders.py @@ -0,0 +1,117 @@ +from torchvision import transforms +from base import BaseDataLoader +import argparse +from parse_config import ConfigParser +import data_loader.data_loaders as module_data +from utils.data.datasets import DegCIFAR10Dataset, DegCIFAR100Dataset, DegTinyImagenetDataset + +class DegCIFAR10DataLoader(BaseDataLoader): + """ + CIFAR10 data loader + """ + def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, train=True, + deg_type = 'jpeg', deg_range = None, is_to_tensor = True, is_normalized = True, + transform = None, teacher_transform = None, student_transform = None, train_init_transform = None, + cutout_method = None, cutout_length = None, cutout_apply_clean = True, cutout_apply_deg = True, + cutout_independent = False): + self.data_dir = data_dir + self.cutout_method = cutout_method + if train: + train_init_transform = transforms.Compose([transforms.RandomHorizontalFlip()]) + if is_to_tensor: + if is_normalized: + normalize = transforms.Normalize(mean = (125.3/255.0, 123.0/255.0, 113.9/255.0), + std = (63.0/255.0, 62.1/255.0, 66.7/255.0)) + self.deg_to_tensor = transforms.Compose([transforms.ToTensor(), normalize]) + else: + self.deg_to_tensor = transforms.Compose([transforms.ToTensor()]) + + self.dataset = DegCIFAR10Dataset(data_dir, train, train_init_transform, teacher_transform, student_transform, + deg_type = deg_type, deg_range = deg_range, is_to_tensor = is_to_tensor, + deg_to_tensor = self.deg_to_tensor, cutout_method = cutout_method, + cutout_length = cutout_length, cutout_apply_clean = cutout_apply_clean, + cutout_apply_deg = cutout_apply_deg, cutout_independent = cutout_independent) + super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) + +class DegCIFAR100DataLoader(BaseDataLoader): + """ + CIFAR100 data loader + """ + def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, train=True, + deg_type = 'jpeg', deg_range = None, is_to_tensor = True, is_normalized = True, + transform = None, teacher_transform = None, student_transform = None, train_init_transform = None, + cutout_method = None, cutout_length = None, cutout_apply_clean = True, cutout_apply_deg = True, + cutout_independent = False): + self.data_dir = data_dir + self.cutout_method = cutout_method + if train: + train_init_transform = transforms.Compose([transforms.RandomHorizontalFlip()]) + if is_to_tensor: + if is_normalized: + normalize = transforms.Normalize(mean = [129.3/255.0, 124.1/255.0, 112.4/255.0], + std = [68.2/255.0, 65.4/255.0, 70.4/255.0]) + self.deg_to_tensor = transforms.Compose([transforms.ToTensor(), normalize]) + else: + self.deg_to_tensor = transforms.Compose([transforms.ToTensor()]) + + self.dataset = DegCIFAR100Dataset(data_dir, train, train_init_transform, teacher_transform, student_transform, + deg_type = deg_type, deg_range = deg_range, is_to_tensor = is_to_tensor, + deg_to_tensor = self.deg_to_tensor, cutout_method = cutout_method, + cutout_length = cutout_length, cutout_apply_clean = cutout_apply_clean, + cutout_apply_deg = cutout_apply_deg, cutout_independent = cutout_independent) + super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) + +class DegTinyImagenetDataLoader(BaseDataLoader): + """ + Tiny ImageNet data loader + """ + def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, train=True, + deg_type = 'jpeg', deg_range = None, is_to_tensor = True, is_normalized = True, + transform = None, teacher_transform = None, student_transform = None, train_init_transform = None, + cutout_method = None, cutout_length = None, cutout_apply_clean = True, cutout_apply_deg = True, + cutout_independent = False): + self.data_dir = data_dir + self.cutout_method = cutout_method + if train: + train_init_transform = transforms.Compose([transforms.RandomHorizontalFlip()]) + if is_to_tensor: + if is_normalized: + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.deg_to_tensor = transforms.Compose([transforms.ToTensor(), normalize]) + else: + self.deg_to_tensor = transforms.Compose([transforms.ToTensor()]) + + self.dataset = DegTinyImagenetDataset(data_dir, train, train_init_transform, teacher_transform, student_transform, + deg_type = deg_type, deg_range = deg_range, is_to_tensor = is_to_tensor, + deg_to_tensor = self.deg_to_tensor, cutout_method = cutout_method, + cutout_length = cutout_length, cutout_apply_clean = cutout_apply_clean, + cutout_apply_deg = cutout_apply_deg, cutout_independent = cutout_independent) + super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Testing KD data loaders') + args.add_argument('-c', '--config', default=None, type=str, + help='config file path (default: None)') + args.add_argument('-r', '--resume', default=None, type=str, + help='path to latest checkpoint (default: None)') + args.add_argument('-d', '--device', default=None, type=str, + help='indices of GPUs to enable (default: all)') + + config = ConfigParser.from_args(args) + # setup data_loader instances + data_loader = config.init_obj('data_loader', module_data) + + # (image_clean, image_deg), targets = data_loader.dataset.__getitem__(index=5) + # print("Save images into test_imgs from:") + # print(image_clean.cpu().detach().numpy().shape) + # img = image_clean.cpu().detach().numpy().transpose(1,2,0) + # img = (img - np.min(img)) / (np.max(img) - np.min(img)) + # plt.imsave('./test_img.png', img) + # print('targets: ', targets) + + for batch_idx, (images, targets) in enumerate(data_loader): + (image_clean, image_deg) = images + (labels, levels) = targets + print(labels) + exit() diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..faf9586 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,33 @@ +FROM nvidia/cuda:11.2.0-cudnn8-devel-ubuntu20.04 +MAINTAINER Dinesh Daultani + +# Temp fix for current nvidia key issue +COPY ./cuda-keyring_1.0-1_all.deb cuda-keyring_1.0-1_all.deb +RUN rm /etc/apt/sources.list.d/cuda.list \ + && rm /etc/apt/sources.list.d/nvidia-ml.list \ + && dpkg -i cuda-keyring_1.0-1_all.deb + +ARG DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get install -y wget vim git zip && \ + apt-get install -y sudo software-properties-common systemd-sysv + +# Installing Anaconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + /bin/bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \ + rm Miniconda3-latest-Linux-x86_64.sh + +ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64 + +# Install libraries +RUN conda create -n py38 python=3.8 +SHELL ["/bin/bash", "-c"] +RUN source activate py38 && \ + echo $(python -V) && \ + conda install -y scikit-learn && \ + conda install -y pytorch==1.12.0 torchvision==0.13.0 -c pytorch + +COPY requirements.txt /tmp/ +RUN source activate py38 && \ + pip install -r /tmp/requirements.txt \ No newline at end of file diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000..7d44322 --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,19 @@ +numpy +tqdm +tensorboard>=1.14 +sklearn +h5py==2.10.0 +pandas +matplotlib +argparse +logger +pyyaml +imagedegrade +torch-tb-profiler +ipykernel +imageio +torchinfo +openpyxl +seaborn +hiddenlayer +ptflops \ No newline at end of file diff --git a/figures/proposed-arch-diagram.png b/figures/proposed-arch-diagram.png new file mode 100644 index 0000000..d5550df Binary files /dev/null and b/figures/proposed-arch-diagram.png differ diff --git a/logger/__init__.py b/logger/__init__.py new file mode 100644 index 0000000..5f3763b --- /dev/null +++ b/logger/__init__.py @@ -0,0 +1,2 @@ +from .logger import * +from .visualization import * \ No newline at end of file diff --git a/logger/logger.py b/logger/logger.py new file mode 100644 index 0000000..4599fb0 --- /dev/null +++ b/logger/logger.py @@ -0,0 +1,22 @@ +import logging +import logging.config +from pathlib import Path +from utils import read_json + + +def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): + """ + Setup logging configuration + """ + log_config = Path(log_config) + if log_config.is_file(): + config = read_json(log_config) + # modify logging paths based on run config + for _, handler in config['handlers'].items(): + if 'filename' in handler: + handler['filename'] = str(save_dir / handler['filename']) + + logging.config.dictConfig(config) + else: + print("Warning: logging configuration file is not found in {}.".format(log_config)) + logging.basicConfig(level=default_level) diff --git a/logger/logger_config.json b/logger/logger_config.json new file mode 100644 index 0000000..c3e7e02 --- /dev/null +++ b/logger/logger_config.json @@ -0,0 +1,32 @@ + +{ + "version": 1, + "disable_existing_loggers": false, + "formatters": { + "simple": {"format": "%(message)s"}, + "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "stream": "ext://sys.stdout" + }, + "info_file_handler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "INFO", + "formatter": "datetime", + "filename": "info.log", + "maxBytes": 10485760, + "backupCount": 20, "encoding": "utf8" + } + }, + "root": { + "level": "INFO", + "handlers": [ + "console", + "info_file_handler" + ] + } +} \ No newline at end of file diff --git a/logger/visualization.py b/logger/visualization.py new file mode 100644 index 0000000..34ef64f --- /dev/null +++ b/logger/visualization.py @@ -0,0 +1,73 @@ +import importlib +from datetime import datetime + + +class TensorboardWriter(): + def __init__(self, log_dir, logger, enabled): + self.writer = None + self.selected_module = "" + + if enabled: + log_dir = str(log_dir) + + # Retrieve vizualization writer. + succeeded = False + for module in ["torch.utils.tensorboard", "tensorboardX"]: + try: + self.writer = importlib.import_module(module).SummaryWriter(log_dir) + succeeded = True + break + except ImportError: + succeeded = False + self.selected_module = module + + if not succeeded: + message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ + "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ + "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." + logger.warning(message) + + self.step = 0 + self.mode = '' + + self.tb_writer_ftns = { + 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', + 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' + } + self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} + self.timer = datetime.now() + + def set_step(self, step, mode='train'): + self.mode = mode + self.step = step + if step == 0: + self.timer = datetime.now() + else: + duration = datetime.now() - self.timer + self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) + self.timer = datetime.now() + + def __getattr__(self, name): + """ + If visualization is configured to use: + return add_data() methods of tensorboard with additional information (step, tag) added. + Otherwise: + return a blank function handle that does nothing + """ + if name in self.tb_writer_ftns: + add_data = getattr(self.writer, name, None) + + def wrapper(tag, data, *args, **kwargs): + if add_data is not None: + # add mode(train/valid) tag + if name not in self.tag_mode_exceptions: + tag = '{}/{}'.format(tag, self.mode) + add_data(tag, data, self.step, *args, **kwargs) + return wrapper + else: + # default action for returning methods defined in this class, set_step() for instance. + try: + attr = object.__getattr__(name) + except AttributeError: + raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) + return attr diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100644 index 0000000..1a9238a --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1 @@ +from .cosine_embedding import COS \ No newline at end of file diff --git a/losses/cosine_embedding.py b/losses/cosine_embedding.py new file mode 100644 index 0000000..d4b0a01 --- /dev/null +++ b/losses/cosine_embedding.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class COS(nn.Module): + ''' + Cosine Embedding loss + ''' + def __init__(self): + super(COS, self).__init__() + + def forward(self, fm_s, fm_t, dum): + loss = F.cosine_embedding_loss(fm_s.view(fm_s.size(0), -1), + fm_t.view(fm_t.size(0), -1), dum) + return loss \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..4d07374 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ +from model.backbones.resnet import ResNet20, ResNet56, ResNet110 \ No newline at end of file diff --git a/model/backbones/resnet.py b/model/backbones/resnet.py new file mode 100644 index 0000000..388af10 --- /dev/null +++ b/model/backbones/resnet.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from base.base_model import BaseModel + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, return_before_act): + super(ResBlock, self).__init__() + self.return_before_act = return_before_act + self.downsample = (in_channels != out_channels) + if self.downsample: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False) + self.ds = nn.Sequential(*[ + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False), + nn.BatchNorm2d(out_channels) + ]) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.ds = None + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + + def forward(self, x): + residual = x + + pout = self.conv1(x) # pout: pre out before activation + pout = self.bn1(pout) + pout = self.relu(pout) + + pout = self.conv2(pout) + pout = self.bn2(pout) + + if self.downsample: + residual = self.ds(x) + + pout += residual + out = self.relu(pout) + + if not self.return_before_act: + return out + else: + return pout, out + + +class ResNet_simple(BaseModel): + def __init__(self, block, num_blocks, num_class = 10, init_weights = True, deg_flag = None, fa = True): + super(ResNet_simple, self).__init__(deg_flag) + self.block = block + self.num_blocks = num_blocks + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(16) + self.relu = nn.ReLU() + + self.res1 = self.make_layer(self.block, self.num_blocks[0], 16, 16) + self.res2 = self.make_layer(self.block, self.num_blocks[1], 16, 32) + self.res3 = self.make_layer(self.block, self.num_blocks[2], 32, 64) + + self.avgpool = nn.AdaptiveAvgPool2d(2) + self.fc = nn.Linear(256, num_class) + + if init_weights: + self._init_weight_layers(self) + + self.num_class = num_class + self.fa = fa + + def make_layer(self, block, num, in_channels, out_channels): # num must >=2 + layers = [block(in_channels, out_channels, False)] + for i in range(num-2): + layers.append(block(out_channels, out_channels, False)) + layers.append(block(out_channels, out_channels, True)) + return nn.Sequential(*layers) + + def forward(self, *x): + x = self.define_input(*x) + pstem = self.conv1(x) # pstem: pre stem before activation + pstem = self.bn1(pstem) + stem = self.relu(pstem) + stem = (pstem, stem) + + rb1 = self.res1(stem[1]) + rb2 = self.res2(rb1[1]) + rb3 = self.res3(rb2[1]) + + feat = self.avgpool(rb3[1]) + feat = feat.view(feat.size(0), -1) + out = self.fc(feat) + + return stem, rb1, rb2, rb3, feat, out + +def ResNet20(**args): + return ResNet_simple(ResBlock, [3,3,3], **args) + +def ResNet56(**args): + return ResNet_simple(ResBlock, [9,9,9], **args) + +def ResNet110(**args): + return ResNet_simple(ResBlock, [18,18,18], **args) diff --git a/model/loss.py b/model/loss.py new file mode 100644 index 0000000..96e6654 --- /dev/null +++ b/model/loss.py @@ -0,0 +1,16 @@ +import torch.nn as nn +from losses import * + +def supervised_loss(method): + if method == 'CE': + loss = nn.CrossEntropyLoss() + else: + raise NotImplementedError + return loss + +def inheritance_loss(method): + if method == 'COS': + loss = COS() + else: + raise NotImplementedError + return loss diff --git a/model/metric.py b/model/metric.py new file mode 100644 index 0000000..448515a --- /dev/null +++ b/model/metric.py @@ -0,0 +1,22 @@ +import torch + +def accuracy(output, target): + """ + This function is used to calculate accuracy + """ + with torch.no_grad(): + pred = torch.argmax(output, dim=1) + assert pred.shape[0] == len(target) + correct = 0 + correct += torch.sum(pred == target).item() + return correct / len(target) + +def accuracy_classes(pred, target): + """ + This function is used to calculate accuracy specifically for ensemble model + """ + with torch.no_grad(): + assert pred.shape[0] == len(target) + correct = 0 + correct += torch.sum(pred == target).item() + return correct / len(target) diff --git a/parse_config.py b/parse_config.py new file mode 100644 index 0000000..88ad528 --- /dev/null +++ b/parse_config.py @@ -0,0 +1,192 @@ +import os +import logging +from pathlib import Path +from functools import reduce, partial +from operator import getitem +from datetime import datetime +from logger import setup_logging +from utils import read_yaml, write_yaml, print_dict +from utils import import_class + +class ConfigParser: + def __init__(self, config, resume=None, modification=None, mode = 'train', run_id=None, dry_run=False): + """ + class to parse configuration yaml file. Handles hyperparameters for training, initializations of modules, checkpoint saving + and logging module. + :param config: Dict containing configurations, hyperparameters for training. contents of `config.yaml` file for example. + :param resume: String, path to the checkpoint being loaded. + :param modification: Dict keychain:value, specifying position values to be replaced from config dict. + :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default + """ + # load config file and apply modification + self._config = _update_config(config, modification) + self.resume = resume + + # set save_dir where trained model and log will be saved. + save_dir = Path(self.config['trainer']['args']['save_dir']) + + exper_name = self.config['name'] + trainer = self.config['trainer']['type'] + deg_type = self.config['data_loader']['args']['deg_type'] + if run_id is None: # use timestamp as default run-id + run_id = datetime.now().strftime(r'%m%d_%H%M%S') + self._save_dir = save_dir / deg_type / trainer / exper_name / mode / run_id + self._log_dir = self._save_dir + + if not dry_run: + # make directory for saving checkpoints and log. + exist_ok = run_id == '' + self.save_dir.mkdir(parents=True, exist_ok=exist_ok) + + # save updated config file to the checkpoint dir + write_yaml(self.config, self.save_dir / 'config.yaml') + + # configure logging module + setup_logging(self.log_dir) + self.log_levels = { + 0: logging.WARNING, + 1: logging.INFO, + 2: logging.DEBUG + } + + @classmethod + def from_args(cls, args, options=''): + """ + Initialize this class from some cli arguments. Used in train, test. + """ + for opt in options: + args.add_argument(*opt.flags, default=None, type=opt.type) + if not isinstance(args, tuple): + args = args.parse_args() + + if args.device is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + if args.resume is not None: + resume = Path(args.resume) + cfg_fname = resume.parent / 'config.yaml' + else: + msg_no_cfg = "Configuration file need to be specified. Add '-c config.yaml', for example." + assert args.config is not None, msg_no_cfg + resume = None + cfg_fname = Path(args.config) + + config = read_yaml(cfg_fname) + if args.config and resume: + # update new config for fine-tuning + config.update(read_yaml(args.config)) + + # parse custom cli options into dictionary + modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} + + mode = 'train' if not hasattr(args, 'mode') else args.mode + return cls(config, resume, modification, mode) + + def init_obj(self, name, module, *args, **kwargs): + """ + Finds a function handle with the name given as 'type' in config, and returns the + instance initialized with corresponding arguments given. + + `object = config.init_obj('name', module, a, b=1)` + is equivalent to + `object = module.name(a, b=1)` + """ + module_name = self[name]['type'] + module_args = dict(self[name]['args']) + assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' + module_args.update(kwargs) + return getattr(module, module_name)(*args, **module_args) + + def get_class(self, name, init = True, _class = None): + """ + Finds a function handle with the name given as 'type' in config, and returns the + class with the corresponding arguments given. + + `Class = config.init_class('name', module, a, b=1)` + is equivalent to + `Class = name.module(a, b=1)` + This function can import model and loss classes + """ + module_name = self[name]['type'] + module_args = dict(self[name]['args']) + if _class is None: + Module = '.'.join([name, module_name]) + else: + Module = '.'.join([_class, module_name]) + Module = import_class(Module) + if init: + Module = Module(**module_args) + return Module + + def init_ftn(self, name, module, *args, **kwargs): + """ + Finds a function handle with the name given as 'type' in config, and returns the + function with given arguments fixed with functools.partial. + + `function = config.init_ftn('name', module, a, b=1)` + is equivalent to + `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. + """ + module_name = self[name]['type'] + module_args = dict(self[name]['args']) + assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' + module_args.update(kwargs) + return partial(getattr(module, module_name), *args, **module_args) + + def __getitem__(self, name): + """Access items like ordinary dict.""" + return self.config[name] + + def __contains__(self, name): + """Check item exists like ordinary dict.""" + if name in self.config: + return True + else: + return False + + def __str__(self): + return print_dict(self.config) + + def get_logger(self, name, verbosity=2): + msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) + assert verbosity in self.log_levels, msg_verbosity + logger = logging.getLogger(name) + logger.setLevel(self.log_levels[verbosity]) + return logger + + # setting read-only attributes + @property + def config(self): + return self._config + + @property + def save_dir(self): + return self._save_dir + + @property + def log_dir(self): + return self._log_dir + +# helper functions to update config dict with custom cli options +def _update_config(config, modification): + if modification is None: + return config + + for k, v in modification.items(): + if v is not None: + _set_by_path(config, k, v) + return config + +def _get_opt_name(flags): + for flg in flags: + if flg.startswith('--'): + return flg.replace('--', '') + return flags[0].replace('--', '') + +def _set_by_path(tree, keys, value): + """Set a value in a nested object in tree by sequence of keys.""" + keys = keys.split(';') + _get_by_path(tree, keys[:-1])[keys[-1]] = value + +def _get_by_path(tree, keys): + """Access a nested object in tree by sequence of keys.""" + return reduce(getitem, keys, tree) diff --git a/test.py b/test.py new file mode 100644 index 0000000..def444c --- /dev/null +++ b/test.py @@ -0,0 +1,99 @@ +import argparse +import collections +import torch +import data_loader.data_loaders as module_data +import model.metric as module_metric +from parse_config import ConfigParser +from utils.data import degradedimagedata as deg_data +from logger import TensorboardWriter +from utils.util import set_seeds +from utils import prepare_device + +# fix random seeds for reproducibility +set_seeds() + +def main(config): + logger = config.get_logger('test') + logger.info(config) + device, device_ids = prepare_device(config['n_gpu']) + + writer = TensorboardWriter(config.log_dir, logger, + config['trainer']['args']['tensorboard']) + deg_range = deg_data.get_type_range(config['data_loader']['args']['deg_type']) + + # build model architecture + if 'model' in config: + model = config.get_class('model') + else: + model = config.get_class('student_model', _class = 'model') + logger.info(model) + + metric_fns = [getattr(module_metric, met) for met in config['metrics']] + + logger.info('Loading checkpoint: {} ...'.format(config.resume)) + model = model.to(device) + + if len(device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=device_ids) + checkpoint = torch.load(config.resume) + state_dict = checkpoint['state_dict'] + model.load_state_dict(state_dict) + model.eval() + + for lev in range(deg_range[0],deg_range[1]+1): + # setup data_loader instances + data_loader = getattr(module_data, config['data_loader']['type'])( + config['data_loader']['args']['data_dir'], + batch_size=100, + shuffle=False, + validation_split=0.0, + num_workers=2, + train=False, + deg_type = config['data_loader']['args']['deg_type'], + deg_range = [lev, lev] + ) + total_loss = 0.0 + total_metrics = torch.zeros(len(metric_fns)) + + with torch.no_grad(): + for i, (images, targets) in enumerate(data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(device) + image_deg = image_deg.to(device) + target = labels.to(device) + + _, _, _, _, feat, output = model(image_deg, image_deg) + + batch_size = image_clean.shape[0] + for i, metric in enumerate(metric_fns): + total_metrics[i] += metric(output, target) * batch_size + + n_samples = len(data_loader.sampler) + log = {'deg_level': lev} + log.update({ + met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) + }) + writer.set_step(lev, mode = 'eval') + for met, val in log.items(): + writer.add_scalar(met, val) + logger.info(log) + + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Degraded Image Classification - KD') + args.add_argument('-c', '--config', default=None, type=str, + help='config file path (default: None)') + args.add_argument('-r', '--resume', default=None, type=str, + help='path to latest checkpoint (default: None)') + args.add_argument('-d', '--device', default=None, type=str, + help='indices of GPUs to enable (default: all)') + args.add_argument('-m', '--mode', default='eval', type=str, + help='Activate eval mode for config') + # custom cli options to modify configuration from default values given in json file. + CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') + options = [ + CustomArgs(['--dt', '--deg_type'], type=str, target='data_loader;args;deg_type') + ] + config = ConfigParser.from_args(args, options) + main(config) diff --git a/test_ensemble.py b/test_ensemble.py new file mode 100644 index 0000000..340d8bd --- /dev/null +++ b/test_ensemble.py @@ -0,0 +1,141 @@ +import argparse +import collections +import torch +import data_loader.data_loaders as module_data +import model.metric as module_metric +from parse_config import ConfigParser +from utils.data import degradedimagedata as deg_data +from logger import TensorboardWriter +from utils.util import set_seeds +from utils import prepare_device +import copy + +# fix random seeds for reproducibility +set_seeds() + +def main(config): + logger = config.get_logger('test') + logger.info(config) + device, device_ids = prepare_device(config['n_gpu']) + + writer = TensorboardWriter(config.log_dir, logger, + config['trainer']['args']['tensorboard']) + deg_range = deg_data.get_type_range(config['data_loader']['args']['deg_type']) + + # build model architecture + if 'model' in config: + model = config.get_class('model') + else: + model = config.get_class('student_model', _class = 'model') + logger.info(model) + + metric_fns = [getattr(module_metric, met) for met in config['metrics']] + + # logger.info('Loading checkpoint: {} ...'.format(config.resume)) + model = model.to(device) + + # Loading model paths for all deg models + logger.info('Loading checkpoints of below models:') + model_paths = [] + for key, value in config['model'].items(): + if key.startswith('pretrained_path'): + model_paths.append(value) + logger.info(value) + + if len(device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=device_ids) + checkpoints = [torch.load(path) for path in model_paths] + models_all = [copy.deepcopy(model) for _ in range(len(checkpoints))] + # Loading all models given the model paths for all degradations + for i, model in enumerate(models_all): + model.load_state_dict(checkpoints[i]['state_dict']) + model = model.to(device) + model.eval() + + for lev in range(deg_range[0],deg_range[1]+1): + # setup data_loader instances + data_loader = getattr(module_data, config['data_loader']['type'])( + config['data_loader']['args']['data_dir'], + batch_size=100, + shuffle=False, + validation_split=0.0, + num_workers=2, + train=False, + deg_type = config['data_loader']['args']['deg_type'], + deg_range = [lev, lev] + ) + total_loss = 0.0 + total_metrics = torch.zeros(len(metric_fns)) + + with torch.no_grad(): + for i, (images, targets) in enumerate(data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(device) + image_deg = image_deg.to(device) + target = labels.to(device) + + outputs_all, pred_labels_all = [], [] + for i, model in enumerate(models_all): + _, _, _, _, feat, output = model(image_deg, image_deg) + outputs_all.append(output) + + for output in outputs_all: + pred_labels_all.append(torch.argmax(output, dim=1)) + + # Stack all lists together as tensor + outputs_all = torch.stack(outputs_all) + pred_labels_all = torch.stack(pred_labels_all) + + # Transpose the tensors to apply single image-wise operations + outputs_all = torch.permute(outputs_all, (1, 0, 2)) + pred_labels_all = pred_labels_all.T + # Take the sum of prob and then max of all predictions + outputs_all_sum_max = torch.argmax(outputs_all.sum(dim=1), dim=1) + + ensemble_outputs = [] + # Iterate over each sub-tensor along the first dimension + for i, sub_tensor in enumerate(pred_labels_all): + values, counts = torch.unique(sub_tensor, return_counts=True) + max_count = counts.max() + mode_values = values[counts == max_count] + + # Breaking the pluraity ensemble tie here + if len(mode_values) > 1: + ensemble_outputs.append(outputs_all_sum_max[i]) + else: + ensemble_outputs.append(mode_values[0]) + ensemble_outputs = torch.stack(ensemble_outputs) + + batch_size = image_clean.shape[0] + for i, metric in enumerate(metric_fns): + total_metrics[i] += metric(ensemble_outputs, target) * batch_size + + n_samples = len(data_loader.sampler) + log = {'deg_level': lev} + log.update({ + met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) + }) + writer.set_step(lev, mode = 'eval') + for met, val in log.items(): + writer.add_scalar(met, val) + logger.info(log) + + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Degraded Image Classification - KD') + args.add_argument('-c', '--config', default=None, type=str, + help='config file path (default: None)') + args.add_argument('-r', '--resume', default=None, type=str, + help='path to latest checkpoint (default: None)') + args.add_argument('-d', '--device', default=None, type=str, + help='indices of GPUs to enable (default: all)') + args.add_argument('-m', '--mode', default='eval', type=str, + help='Activate eval mode for config') + # custom cli options to modify configuration from default values given in json file. + CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') + options = [ + CustomArgs(['--dt', '--deg_type'], type=str, target='data_loader;args;deg_type') + ] + config = ConfigParser.from_args(args, options) + main(config) diff --git a/train.py b/train.py new file mode 100644 index 0000000..2b81972 --- /dev/null +++ b/train.py @@ -0,0 +1,58 @@ +import argparse +import collections +import data_loader.data_loaders as module_data +import model.metric as module_metric +from parse_config import ConfigParser +from utils.util import set_seeds, set_seeds_prev + +def main(config): + logger = config.get_logger('train') + logger.info(config) + + # setup data_loader instances + train_data_loader = config.init_obj('data_loader', module_data) + valid_data_loader = getattr(module_data, config['data_loader']['type'])( + config['data_loader']['args']['data_dir'], + batch_size=128, + shuffle=False, + validation_split=0.0, + num_workers=2, + train=False, + deg_type = config['data_loader']['args']['deg_type'] + ) + + Trainer = config.get_class('trainer', init = False) + metrics = [getattr(module_metric, met) for met in config['metrics']] + trainer = Trainer(metrics, config=config, + train_data_loader=train_data_loader, + valid_data_loader=valid_data_loader) + + trainer.train() + + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Degraded Image Classification - KD') + args.add_argument('-c', '--config', default=None, type=str, + help='config file path (default: None)') + args.add_argument('-r', '--resume', default=None, type=str, + help='path to latest checkpoint (default: None)') + args.add_argument('-d', '--device', default=None, type=str, + help='indices of GPUs to enable (default: all)') + + # custom cli options to modify configuration from default values given in json file. + CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') + options = [ + CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), + CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'), + CustomArgs(['--dt', '--deg_type'], type=str, target='data_loader;args;deg_type'), + CustomArgs(['--rs', '--random_seed'], type=int, target='random_seed') + ] + config = ConfigParser.from_args(args, options) + + # fix random seeds for reproducibility + if 'random_seed' in config: + set_seeds(config['random_seed']) + else: + # Provides backward compability for previous experiments + set_seeds_prev() + main(config) diff --git a/train_all_deg.py b/train_all_deg.py new file mode 100644 index 0000000..fda38b7 --- /dev/null +++ b/train_all_deg.py @@ -0,0 +1,64 @@ +import argparse +import collections +import data_loader.data_loaders as module_data +import model.metric as module_metric +from parse_config import ConfigParser +from utils.util import set_seeds, set_seeds_prev + +def main(config): + logger = config.get_logger('train') + logger.info(config) + + # setup data_loader instances + degs_all = ['jpeg', 'blur', 'saltpepper', 'noise'] + train_loaders_all = [] + val_loaders_all = [] + prev_deg = config.config['data_loader']['args']['deg_type'] + for deg in degs_all: + config.config['data_loader']['args']['deg_type'] = deg + train_loaders_all.append(config.init_obj('data_loader', module_data)) + val_loaders_all.append(getattr(module_data, config['data_loader']['type'])( + config['data_loader']['args']['data_dir'], + batch_size=128, + shuffle=False, + validation_split=0.0, + num_workers=2, + train=False, + deg_type = config['data_loader']['args']['deg_type'] + )) + config.config['data_loader']['args']['deg_type'] = prev_deg + Trainer = config.get_class('trainer', init = False) + metrics = [getattr(module_metric, met) for met in config['metrics']] + trainer = Trainer(metrics, config=config, + train_loaders_all=train_loaders_all, + val_loaders_all=val_loaders_all) + + trainer.train() + + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Degraded Image Classification - KD') + args.add_argument('-c', '--config', default=None, type=str, + help='config file path (default: None)') + args.add_argument('-r', '--resume', default=None, type=str, + help='path to latest checkpoint (default: None)') + args.add_argument('-d', '--device', default=None, type=str, + help='indices of GPUs to enable (default: all)') + + # custom cli options to modify configuration from default values given in json file. + CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') + options = [ + CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), + CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'), + CustomArgs(['--dt', '--deg_type'], type=str, target='data_loader;args;deg_type'), + CustomArgs(['--rs', '--random_seed'], type=int, target='random_seed') + ] + config = ConfigParser.from_args(args, options) + + # fix random seeds for reproducibility + if 'random_seed' in config: + set_seeds(config['random_seed']) + else: + # Provides backward compability for previous experiments + set_seeds_prev() + main(config) diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000..aa54817 --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1,4 @@ +from .ind import IndTrainer +from .sl import SLTrainer +from .sl_all_deg import SLDA_Trainer +from .ind_all_deg import IndDATrainer \ No newline at end of file diff --git a/trainer/ind.py b/trainer/ind.py new file mode 100644 index 0000000..66b0d5a --- /dev/null +++ b/trainer/ind.py @@ -0,0 +1,211 @@ +import numpy as np +import torch +from base import BaseTrainer +from utils import inf_loop, MetricTracker +import model.loss as module_loss +import warmup_scheduler + +class IndTrainer(BaseTrainer): + """ + Trainer class for training base model on clean images, i.e., Step-1. + """ + def __init__(self, metric_ftns, config, train_data_loader, valid_data_loader=None, + len_epoch=None): + super().__init__(metric_ftns, config, train_data_loader, valid_data_loader, len_epoch) + self.model = self._build_model(config) + self.criterion = self._load_loss(config) + self.optimizer = self._load_optimizer(self.model, config) + self.lr_scheduler = self._load_scheduler(self.optimizer, config) + self.config = config + + self.log_step = int(np.sqrt(train_data_loader.batch_size)) + train_misc_metrics = ['loss', 'lr'] + valid_misc_metrics = ['loss'] + self.train_metrics = MetricTracker(*train_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.valid_metrics = MetricTracker(*valid_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + + def _build_model(self, config): + """ + Building model from the configuration file + + :param config: config file + :return: model with loaded state dict + """ + # build model architecture, then print to console + model = config.get_class('model') + self.logger.info(model) + model = model.to(self.device) + if len(self.device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=self.device_ids) + return model + + def _load_loss(self, config): + """ + Build model from the configuration file + + :param config: config file + :return: criterion dictionary in the format: {loss_type: loss} + """ + # criterion = getattr(module_loss, config['loss']) + criterion = {type: getattr(module_loss, type)(loss) for losses in config['loss'] \ + for type, loss in losses.items()} + return criterion + + def _load_optimizer(self, model, config): + """ + Load optimizer from the configuration file + + :param model: model for which optimizer is to be initialized + :param config: config file + :return: initialized optimizer + """ + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = config.init_obj('optimizer', torch.optim, trainable_params) + return optimizer + + def _load_scheduler(self, optimizer, config): + """ + Load scheduler from the configuration file + + :param optimizer: optimizer for which scheduler is to be initialized + :param config: config file + :return: initialized scheduler + """ + lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) + if 'lr_warmup' in config and config['lr_warmup'] is not None: + lr_scheduler = config.init_obj('lr_warmup', warmup_scheduler, optimizer, after_scheduler = lr_scheduler) + return lr_scheduler + + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains average loss and metric in this epoch. + """ + self.model.train() + self.train_metrics.reset() + for batch_idx, (images, targets) in enumerate(self.train_data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(self.device) + image_deg = image_deg.to(self.device) + target = labels.to(self.device) + + self.optimizer.zero_grad() + + outputs = self.model(image_clean, image_deg) + output = outputs[-1] + loss = self.criterion['supervised_loss'](output, target) + loss.backward() + self.optimizer.step() + + self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) + self.train_metrics.update('loss', loss.item()) + for met in self.metric_ftns: + self.train_metrics.update(met.__name__ + '_' + self.deg_flag, + met(output, target)) + self.train_metrics.update('lr', self.lr_scheduler.get_last_lr()[0]) + if batch_idx % self.log_step == 0: + self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( + epoch, + self._progress(batch_idx), + loss.item())) + + if batch_idx == self.len_epoch: + break + log = self.train_metrics.result() + + if self.do_validation: + self.logger.info('Testing on validation data') + val_log = self._valid_epoch(epoch) + log.update(**{'val_'+k : v for k, v in val_log.items()}) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + return log + + def _valid_epoch(self, epoch): + """ + Validate after training an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains information about validation + """ + self.model.eval() + self.valid_metrics.reset() + with torch.no_grad(): + for batch_idx, (images, targets) in enumerate(self.valid_data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(self.device) + image_deg = image_deg.to(self.device) + target = labels.to(self.device) + + outputs = self.model(image_clean, image_deg) + output = outputs[-1] + loss = self.criterion['supervised_loss'](output, target) + + self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') + self.valid_metrics.update('loss', loss.item()) + for met in self.metric_ftns: + self.valid_metrics.update(met.__name__ + '_' + self.deg_flag, + met(output, target)) + + return self.valid_metrics.result() + + def _save_checkpoint(self, epoch, save_best=False): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth' + """ + model_name = type(self.model).__name__ + state = { + 'model': model_name, + 'epoch': epoch, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best, + 'config': self.config + } + filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) + # torch.save(state, filename) + self.logger.info("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = str(self.checkpoint_dir / 'model_best.pth') + torch.save(state, best_path) + self.logger.info("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, config): + """ + Resume from saved checkpoints + + :param resume_path: Checkpoint path to be resumed + """ + resume_path = str(config.resume) + self.logger.info("Loading checkpoint: {} ...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + + # load architecture params from checkpoint. + if checkpoint['config']['model'] != self.config['model']: + self.logger.warning("Warning: Architecture configuration given in config file is different from that of " + "checkpoint. This may yield an exception while state_dict is being loaded.") + self.model.load_state_dict(checkpoint['state_dict']) + + # load optimizer state from checkpoint only when optimizer type is not changed. + if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: + self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " + "Optimizer parameters not being resumed.") + else: + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) diff --git a/trainer/ind_all_deg.py b/trainer/ind_all_deg.py new file mode 100644 index 0000000..17a83d8 --- /dev/null +++ b/trainer/ind_all_deg.py @@ -0,0 +1,237 @@ +import numpy as np +import torch +from base import BaseTrainer +from utils import inf_loop, MetricTracker +import model.loss as module_loss +import warmup_scheduler +import copy + +class IndDATrainer(BaseTrainer): + """ + Trainer class for training Individual model for combination of several degradation. + This trainer is used for training methods such as: Scratch, Vanilla, and Fused. + """ + def __init__(self, metric_ftns, config, train_loaders_all, val_loaders_all=None, + len_epoch=None): + super().__init__(metric_ftns, config, train_loaders_all[0], val_loaders_all[0], len_epoch) + self.model = self._build_model(config) + self.criterion = self._load_loss(config) + self.optimizer = self._load_optimizer(self.model, config) + self.lr_scheduler = self._load_scheduler(self.optimizer, config) + self.config = config + + self.log_step = int(np.sqrt(train_loaders_all[0].batch_size)) + train_misc_metrics = ['loss', 'lr'] + valid_misc_metrics = ['loss'] + self.train_metrics = MetricTracker(*train_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.valid_metrics = MetricTracker(*valid_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.train_loaders_all = train_loaders_all + self.val_loaders_all = val_loaders_all + + def _build_model(self, config): + """ + Building model from the configuration file + + :param config: config file + :return: model with loaded state dict + """ + # build model architecture, then print to console + model = config.get_class('model') + self.logger.info(model) + model = model.to(self.device) + if 'pretrained_path' in config['model']: + checkpoint = torch.load(config['model']['pretrained_path']) + model.load_state_dict(checkpoint['state_dict']) + if len(self.device_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=self.device_ids) + return model + + def _load_loss(self, config): + """ + Build model from the configuration file + + :param config: config file + :return: criterion dictionary in the format: {loss_type: loss} + """ + criterion = {type: getattr(module_loss, type)(loss) for losses in config['loss'] \ + for type, loss in losses.items()} + return criterion + + def _load_optimizer(self, model, config): + """ + Load optimizer from the configuration file + + :param model: model for which optimizer is to be initialized + :param config: config file + :return: initialized optimizer + """ + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = config.init_obj('optimizer', torch.optim, trainable_params) + return optimizer + + def _load_scheduler(self, optimizer, config): + """ + Load scheduler from the configuration file + + :param optimizer: optimizer for which scheduler is to be initialized + :param config: config file + :return: initialized scheduler + """ + lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) + if 'lr_warmup' in config and config['lr_warmup'] is not None: + lr_scheduler = config.init_obj('lr_warmup', warmup_scheduler, optimizer, after_scheduler = lr_scheduler) + return lr_scheduler + + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains average loss and metric in this epoch. + """ + self.model.train() + self.train_metrics.reset() + for batch_idx, loaders_all in enumerate(zip(*self.train_loaders_all)): + image_clean_all, image_deg_all, labels_all = None, None, None + for loader in loaders_all: + ((image_clean, image_deg), (labels, _)) = loader + if image_clean_all is None: + image_clean_all = copy.deepcopy(image_clean) + image_deg_all = copy.deepcopy(image_deg) + labels_all = copy.deepcopy(labels) + else: + image_clean_all = torch.cat((image_clean_all, image_clean)) + image_deg_all = torch.cat((image_deg_all, image_deg)) + labels_all = torch.cat((labels_all, labels)) + + image_clean_all = image_clean_all.to(self.device) + image_deg_all = image_deg_all.to(self.device) + target = labels_all.to(self.device) + + self.optimizer.zero_grad() + + outputs = self.model(image_clean_all, image_deg_all) + output = outputs[-1] + loss = self.criterion['supervised_loss'](output, target) + loss.backward() + self.optimizer.step() + + self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) + self.train_metrics.update('loss', loss.item()) + for met in self.metric_ftns: + self.train_metrics.update(met.__name__ + '_' + self.deg_flag, + met(output, target)) + self.train_metrics.update('lr', self.lr_scheduler.get_last_lr()[0]) + if batch_idx % self.log_step == 0: + self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( + epoch, + self._progress(batch_idx), + loss.item())) + + if batch_idx == self.len_epoch: + break + log = self.train_metrics.result() + + if self.do_validation: + self.logger.info('Testing on validation data') + val_log = self._valid_epoch(epoch) + log.update(**{'val_'+k : v for k, v in val_log.items()}) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + return log + + def _valid_epoch(self, epoch): + """ + Validate after training an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains information about validation + """ + self.model.eval() + self.valid_metrics.reset() + with torch.no_grad(): + for batch_idx, loaders_all in enumerate(zip(*self.val_loaders_all)): + image_clean_all, image_deg_all, labels_all = None, None, None + for loader in loaders_all: + ((image_clean, image_deg), (labels, _)) = loader + if image_clean_all is None: + image_clean_all = copy.deepcopy(image_clean) + image_deg_all = copy.deepcopy(image_deg) + labels_all = copy.deepcopy(labels) + else: + image_clean_all = torch.cat((image_clean_all, image_clean)) + image_deg_all = torch.cat((image_deg_all, image_deg)) + labels_all = torch.cat((labels_all, labels)) + + image_clean_all = image_clean_all.to(self.device) + image_deg_all = image_deg_all.to(self.device) + target = labels_all.to(self.device) + + outputs = self.model(image_clean_all, image_deg_all) + output = outputs[-1] + loss = self.criterion['supervised_loss'](output, target) + + self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') + self.valid_metrics.update('loss', loss.item()) + for met in self.metric_ftns: + self.valid_metrics.update(met.__name__ + '_' + self.deg_flag, + met(output, target)) + + return self.valid_metrics.result() + + def _save_checkpoint(self, epoch, save_best=False): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth' + """ + model_name = type(self.model).__name__ + state = { + 'model': model_name, + 'epoch': epoch, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best, + 'config': self.config + } + filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) + # torch.save(state, filename) + self.logger.info("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = str(self.checkpoint_dir / 'model_best.pth') + torch.save(state, best_path) + self.logger.info("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, config): + """ + Resume from saved checkpoints + + :param resume_path: Checkpoint path to be resumed + """ + resume_path = str(config.resume) + self.logger.info("Loading checkpoint: {} ...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + + # load architecture params from checkpoint. + if checkpoint['config']['model'] != self.config['model']: + self.logger.warning("Warning: Architecture configuration given in config file is different from that of " + "checkpoint. This may yield an exception while state_dict is being loaded.") + self.model.load_state_dict(checkpoint['state_dict']) + + # load optimizer state from checkpoint only when optimizer type is not changed. + if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: + self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " + "Optimizer parameters not being resumed.") + else: + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) diff --git a/trainer/sl.py b/trainer/sl.py new file mode 100644 index 0000000..0a8df30 --- /dev/null +++ b/trainer/sl.py @@ -0,0 +1,299 @@ +import numpy as np +import torch +from base import BaseTrainer +from utils import inf_loop, MetricTracker +import model.loss as module_loss +import torch.nn as nn +import torch.nn.functional as F + +class SLTrainer(BaseTrainer): + """ + Trainer class for fine-tuning the base model trained on clean images for specific degradation, i.e., step-2 + """ + def __init__(self, metric_ftns, config, train_data_loader, valid_data_loader=None, + len_epoch=None): + super().__init__(metric_ftns, config, train_data_loader, valid_data_loader, len_epoch) + self.teacher, self.student = self._build_model(config) + self.criterion = self._load_loss(config) + self.optimizer = self._load_optimizer(self.student, config) + self.lr_scheduler = self._load_scheduler(self.optimizer, config) + self.config = config + self.loss_names = {type: loss for losses in config['loss'] for type, loss in losses.items()} + self.log_step = int(np.sqrt(train_data_loader.batch_size)) + train_misc_metrics = ['loss', 'sup_loss', 'inh_loss', 'lr'] + valid_misc_metrics = ['loss', 'sup_loss', 'inh_loss'] + self.train_metrics = MetricTracker(*train_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.valid_metrics = MetricTracker(*valid_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + + def _build_model(self, config): + """ + Building model from the configuration file + + :param config: config file + :return: model with loaded state dict + """ + # build model architecture, then print to console + teacher = config.get_class('teacher_model', _class = 'model') + student = config.get_class('student_model', _class = 'model') + self.logger.info('Teacher Network: {} \n Student Network: {}'.format(teacher, student)) + self.logger.info("Loading checkpoint for teacher: {} ...".format( + config['teacher_model']['pretrained_path'])) + + teacher = teacher.to(self.device) + student = student.to(self.device) + if len(self.device_ids) > 1: + teacher = torch.nn.DataParallel(teacher, device_ids=self.device_ids) + student = torch.nn.DataParallel(student, device_ids=self.device_ids) + + checkpoint = torch.load(config['teacher_model']['pretrained_path']) + teacher.load_state_dict(checkpoint['state_dict']) + if 'pretrained_path' in config['student_model']: + checkpoint = torch.load(config['student_model']['pretrained_path']) + student.load_state_dict(checkpoint['state_dict']) + # Feezing parameters of teacher + if self.config['student_model']['type'].startswith('ShakePyramidNet'): + for param in teacher.parameters(): + param.requires_grad = False + + return teacher, student + + def _load_loss(self, config): + """ + Build model from the configuration file + + :param config: config file + :return: criterion dictionary in the format: {loss_type: loss} + """ + criterion = {type: getattr(module_loss, type)(loss) for losses in config['loss'] \ + for type, loss in losses.items()} + return criterion + + def _load_optimizer(self, model, config): + """ + Load optimizer from the configuration file + + :param model: model for which optimizer is to be initialized + :param config: config file + :return: initialized optimizer + """ + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = config.init_obj('optimizer', torch.optim, trainable_params) + return optimizer + + def _load_scheduler(self, optimizer, config): + """ + Load scheduler from the configuration file + + :param optimizer: optimizer for which scheduler is to be initialized + :param config: config file + :return: initialized scheduler + """ + lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) + return lr_scheduler + + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains average loss and metric in this epoch. + """ + self.teacher.eval() + self.student.train() + self.train_metrics.reset() + for batch_idx, (images, targets) in enumerate(self.train_data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(self.device) + image_deg = image_deg.to(self.device) + target = labels.to(self.device) + if self.loss_names['inheritance_loss'] == 'COS': + dum = torch.ones((image_clean.size(0),)) + dum = dum.to(self.device) + + self.optimizer.zero_grad() + stem_t, rb1_t, rb2_t, rb3_t, t_feat, t_out = self.teacher(image_clean, image_deg) + stem_s, rb1_s, rb2_s, rb3_s, s_feat, s_out = self.student(image_clean, image_deg) + + sup_loss = self.criterion['supervised_loss'](s_out, target) * \ + self.config['loss_weights'][0] + inh_loss = 0 + if self.loss_names['inheritance_loss'] == 'COS': + inh_loss += self.criterion['inheritance_loss'](rb1_s[1], rb1_t[1].detach(), dum) * \ + self.config['loss_weights'][1][0] + inh_loss += self.criterion['inheritance_loss'](rb2_s[1], rb2_t[1].detach(), dum) * \ + self.config['loss_weights'][1][1] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1], rb3_t[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + elif self.loss_names['inheritance_loss'] == 'AT': + inh_loss += self.criterion['inheritance_loss'](rb1_s[1], rb1_t[1].detach()) * \ + self.config['loss_weights'][1][0] + inh_loss += self.criterion['inheritance_loss'](rb2_s[1], rb2_t[1].detach()) * \ + self.config['loss_weights'][1][1] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1], rb3_t[1].detach()) * \ + self.config['loss_weights'][1][2] + elif self.loss_names['inheritance_loss'] == 'KLD': + inh_loss += self.criterion['inheritance_loss'](torch.log(torch.clamp(F.softmax(rb3_s[1],dim=1),1e-10,1.0)), + torch.clamp(F.softmax(rb3_t[1], dim=1),1e-10,1.0)) * \ + self.config['loss_weights'][1][2] + else: + raise NotImplementedError + # inh_loss = inh_loss / (self.config['loss_weights'][1][0] + + # self.config['loss_weights'][1][1] + + # self.config['loss_weights'][1][2]) + loss = sup_loss + inh_loss + loss.backward() + self.optimizer.step() + + self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) + self.train_metrics.update('loss', loss.item()) + self.train_metrics.update('inh_loss', inh_loss.item()) + self.train_metrics.update('sup_loss', sup_loss.item()) + for met in self.metric_ftns: + self.train_metrics.update(met.__name__ + '_' + self.deg_flag, + met(s_out, target)) + self.train_metrics.update('lr', self.lr_scheduler.get_last_lr()[0]) + if batch_idx % self.log_step == 0: + self.logger.debug( + 'Train Epoch: {} {} Loss: {:.6f} Sup Loss: {:.6f} Inh Loss: {:.6f}'.format( + epoch, self._progress(batch_idx), loss.item(), + sup_loss.item(), inh_loss.item())) + # self.writer.add_image('input_clean', make_grid(image_clean.cpu(), nrow=8, normalize=True)) + # self.writer.add_image('input_deg', make_grid(image_deg.cpu(), nrow=8, normalize=True)) + + if batch_idx == self.len_epoch: + break + log = self.train_metrics.result() + + if self.do_validation: + self.logger.info('Testing on validation data') + val_log = self._valid_epoch(epoch) + log.update(**{'val_'+k : v for k, v in val_log.items()}) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + return log + + def _valid_epoch(self, epoch): + """ + Validate after training an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains information about validation + """ + self.teacher.eval() + self.student.eval() + self.valid_metrics.reset() + with torch.no_grad(): + for batch_idx, (images, targets) in enumerate(self.valid_data_loader): + (image_clean, image_deg) = images + (labels, _) = targets + image_clean = image_clean.to(self.device) + image_deg = image_deg.to(self.device) + target = labels.to(self.device) + if self.loss_names['inheritance_loss'] == 'COS': + dum = torch.ones((image_clean.size(0),)) + dum = dum.to(self.device) + + self.optimizer.zero_grad() + stem_t, rb1_t, rb2_t, rb3_t, t_feat, t_out = self.teacher(image_clean, image_deg) + stem_s, rb1_s, rb2_s, rb3_s, s_feat, s_out = self.student(image_clean, image_deg) + + sup_loss = self.criterion['supervised_loss'](s_out, target) * \ + self.config['loss_weights'][0] + inh_loss = 0 + if self.loss_names['inheritance_loss'] == 'COS': + inh_loss += self.criterion['inheritance_loss'](rb1_s[1], rb1_t[1].detach(), dum) * \ + self.config['loss_weights'][1][0] + inh_loss += self.criterion['inheritance_loss'](rb2_s[1], rb2_t[1].detach(), dum) * \ + self.config['loss_weights'][1][1] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1], rb3_t[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + elif self.loss_names['inheritance_loss'] == 'AT': + inh_loss += self.criterion['inheritance_loss'](rb1_s[1], rb1_t[1].detach()) * \ + self.config['loss_weights'][1][0] + inh_loss += self.criterion['inheritance_loss'](rb2_s[1], rb2_t[1].detach()) * \ + self.config['loss_weights'][1][1] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1], rb3_t[1].detach()) * \ + self.config['loss_weights'][1][2] + elif self.loss_names['inheritance_loss'] == 'KLD': + inh_loss += self.criterion['inheritance_loss'](torch.log(torch.clamp(F.softmax(rb3_s[1],dim=1),1e-10,1.0)), + torch.clamp(F.softmax(rb3_t[1], dim=1),1e-10,1.0)) * \ + self.config['loss_weights'][1][2] + else: + raise NotImplementedError + # inh_loss = inh_loss / (self.config['loss_weights'][1][0] + + # self.config['loss_weights'][1][1] + + # self.config['loss_weights'][1][2]) + loss = sup_loss + inh_loss + + self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') + self.valid_metrics.update('loss', loss.item()) + self.valid_metrics.update('inh_loss', inh_loss.item()) + self.valid_metrics.update('sup_loss', sup_loss.item()) + for met in self.metric_ftns: + self.valid_metrics.update(met.__name__ + '_' + self.deg_flag, + met(s_out, target)) + # self.writer.add_image('input_clean', make_grid(image_clean.cpu(), nrow=8, normalize=True)) + # self.writer.add_image('input_deg', make_grid(image_deg.cpu(), nrow=8, normalize=True)) + + # add histogram of model parameters to the tensorboard + # for name, p in self.student.named_parameters(): + # self.writer.add_histogram(name, p, bins='auto') + return self.valid_metrics.result() + + def _save_checkpoint(self, epoch, save_best=False): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth' + """ + model_name = type(self.student).__name__ + state = { + 'model': model_name, + 'epoch': epoch, + 'state_dict': self.student.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best, + 'config': self.config + } + filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) + # torch.save(state, filename) + self.logger.info("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = str(self.checkpoint_dir / 'model_best.pth') + torch.save(state, best_path) + self.logger.info("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, config): + """ + Resume from saved checkpoints + + :param resume_path: Checkpoint path to be resumed + """ + resume_path = str(config.resume) + self.logger.info("Loading checkpoint: {} ...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + + # load architecture params from checkpoint. + if checkpoint['config']['student_model'] != self.config['student_model']: + self.logger.warning("Warning: Architecture configuration given in config file is different from that of " + "checkpoint. This may yield an exception while state_dict is being loaded.") + self.student.load_state_dict(checkpoint['state_dict']) + + # load optimizer state from checkpoint only when optimizer type is not changed. + if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: + self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " + "Optimizer parameters not being resumed.") + else: + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) diff --git a/trainer/sl_all_deg.py b/trainer/sl_all_deg.py new file mode 100644 index 0000000..fb0ec31 --- /dev/null +++ b/trainer/sl_all_deg.py @@ -0,0 +1,316 @@ +import numpy as np +import torch +from torchvision.utils import make_grid +from base import BaseTrainer +from utils import inf_loop, MetricTracker +import model.loss as module_loss +import torch.nn as nn +import torch.nn.functional as F +import copy + +class SLDA_Trainer(BaseTrainer): + """ + Trainer class for training our proposed method FusionDistill based on distillation and fusion, + i.e., Step-4 of our proposed method. + """ + def __init__(self, metric_ftns, config, train_loaders_all, val_loaders_all=None, + len_epoch=None): + super().__init__(metric_ftns, config, train_loaders_all[0], val_loaders_all[0], len_epoch) + self.teachers, self.student = self._build_model(config) + self.criterion = self._load_loss(config) + self.optimizer = self._load_optimizer(self.student, config) + self.lr_scheduler = self._load_scheduler(self.optimizer, config) + self.config = config + self.loss_names = {type: loss for losses in config['loss'] for type, loss in losses.items()} + self.log_step = int(np.sqrt(train_loaders_all[0].batch_size)) + train_misc_metrics = ['loss', 'sup_loss', 'inh_loss', 'lr'] + valid_misc_metrics = ['loss', 'sup_loss', 'inh_loss'] + self.train_metrics = MetricTracker(*train_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.valid_metrics = MetricTracker(*valid_misc_metrics, + *[m.__name__ + '_' + self.deg_flag for m in self.metric_ftns], + writer=self.writer) + self.train_loaders_all = train_loaders_all + self.val_loaders_all = val_loaders_all + + + def _build_model(self, config): + """ + Building model from the configuration file + + :param config: config file + :return: model with loaded state dict + """ + # build model architecture, then print to console + teacher = config.get_class('teacher_model', _class = 'model') + student = config.get_class('student_model', _class = 'model') + self.logger.info('Teacher Network: {} \n Student Network: {}'.format(teacher, student)) + + degs_all = ['jpeg', 'blur', 'saltpepper', 'noise'] + model_paths = [config['teacher_model']['pretrained_path_' + deg] for deg in degs_all] + checkpoints = [torch.load(path) for path in model_paths] + teachers = [copy.deepcopy(teacher) for _ in range(len(checkpoints))] + # Feezing parameters of teachers + for i, model in enumerate(teachers): + model.load_state_dict(checkpoints[i]['state_dict']) + for param in model.parameters(): + param.requires_grad = False + model = model.to(self.device) + + student = student.to(self.device) + if 'pretrained_path' in config['student_model']: + checkpoint = torch.load(config['student_model']['pretrained_path']) + student.load_state_dict(checkpoint['state_dict']) + + return teachers, student + + def _load_loss(self, config): + """ + Build model from the configuration file + + :param config: config file + :return: criterion dictionary in the format: {loss_type: loss} + """ + criterion = {type: getattr(module_loss, type)(loss) for losses in config['loss'] \ + for type, loss in losses.items()} + return criterion + + def _load_optimizer(self, model, config): + """ + Load optimizer from the configuration file + + :param model: model for which optimizer is to be initialized + :param config: config file + :return: initialized optimizer + """ + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = config.init_obj('optimizer', torch.optim, trainable_params) + return optimizer + + def _load_scheduler(self, optimizer, config): + """ + Load scheduler from the configuration file + + :param optimizer: optimizer for which scheduler is to be initialized + :param config: config file + :return: initialized scheduler + """ + lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) + return lr_scheduler + + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains average loss and metric in this epoch. + """ + for teacher in self.teachers: + teacher.eval() + self.student.train() + self.train_metrics.reset() + for batch_idx, loaders_all in enumerate(zip(*self.train_loaders_all)): + image_clean_all, image_deg_all, labels_all = None, None, None + for loader in loaders_all: + ((image_clean, image_deg), (labels, _)) = loader + if image_clean_all is None: + image_clean_all = copy.deepcopy(image_clean) + image_deg_all = copy.deepcopy(image_deg) + labels_all = copy.deepcopy(labels) + else: + image_clean_all = torch.cat((image_clean_all, image_clean)) + image_deg_all = torch.cat((image_deg_all, image_deg)) + labels_all = torch.cat((labels_all, labels)) + + image_clean_all = image_clean_all.to(self.device) + image_deg_all = image_deg_all.to(self.device) + target = labels_all.to(self.device) + + batch_size = int(image_clean_all.size(0)/4) + if self.loss_names['inheritance_loss'] == 'COS': + dum = torch.ones(batch_size,) + dum = dum.to(self.device) + + self.optimizer.zero_grad() + _, rb1_t_jpeg, rb2_t_jpeg, rb3_t_jpeg, _, _ = self.teachers[0](image_clean_all[:batch_size], + image_deg_all[:batch_size]) + _, rb1_t_blur, rb2_t_blur, rb3_t_blur, _, _ = self.teachers[1](image_clean_all[batch_size:batch_size*2], + image_deg_all[batch_size:batch_size*2]) + _, rb1_t_saltpepper, rb2_t_saltpepper, rb3_t_saltpepper, _, _ = self.teachers[2](image_clean_all[batch_size*2:batch_size*3], + image_deg_all[batch_size*2:batch_size*3]) + _, rb1_t_noise, rb2_t_noise, rb3_t_noise, _, _ = self.teachers[3](image_clean_all[batch_size*3:], + image_deg_all[batch_size*3:]) + + stem_s, rb1_s, rb2_s, rb3_s, s_feat, s_out = self.student(image_clean_all, image_deg_all) + + sup_loss = self.criterion['supervised_loss'](s_out, target) * \ + self.config['loss_weights'][0] + inh_loss = 0 + if self.loss_names['inheritance_loss'] == 'COS': + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][:batch_size], rb3_t_jpeg[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size:batch_size*2], rb3_t_blur[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size*2:batch_size*3], rb3_t_saltpepper[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size*3:], rb3_t_noise[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + else: + raise NotImplementedError + + loss = sup_loss + inh_loss + loss.backward() + self.optimizer.step() + + self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) + self.train_metrics.update('loss', loss.item()) + self.train_metrics.update('inh_loss', inh_loss.item()) + self.train_metrics.update('sup_loss', sup_loss.item()) + for met in self.metric_ftns: + self.train_metrics.update(met.__name__ + '_' + self.deg_flag, + met(s_out, target)) + self.train_metrics.update('lr', self.lr_scheduler.get_last_lr()[0]) + if batch_idx % self.log_step == 0: + self.logger.debug( + 'Train Epoch: {} {} Loss: {:.6f} Sup Loss: {:.6f} Inh Loss: {:.6f}'.format( + epoch, self._progress(batch_idx), loss.item(), + sup_loss.item(), inh_loss.item())) + + if batch_idx == self.len_epoch: + break + log = self.train_metrics.result() + + if self.do_validation: + self.logger.info('Testing on validation data') + val_log = self._valid_epoch(epoch) + log.update(**{'val_'+k : v for k, v in val_log.items()}) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + return log + + def _valid_epoch(self, epoch): + """ + Validate after training an epoch + + :param epoch: Integer, current training epoch. + :return: A log that contains information about validation + """ + for teacher in self.teachers: + teacher.eval() + self.student.eval() + self.valid_metrics.reset() + with torch.no_grad(): + for batch_idx, loaders_all in enumerate(zip(*self.val_loaders_all)): + image_clean_all, image_deg_all, labels_all = None, None, None + for loader in loaders_all: + ((image_clean, image_deg), (labels, _)) = loader + if image_clean_all is None: + image_clean_all = copy.deepcopy(image_clean) + image_deg_all = copy.deepcopy(image_deg) + labels_all = copy.deepcopy(labels) + else: + image_clean_all = torch.cat((image_clean_all, image_clean)) + image_deg_all = torch.cat((image_deg_all, image_deg)) + labels_all = torch.cat((labels_all, labels)) + + image_clean_all = image_clean_all.to(self.device) + image_deg_all = image_deg_all.to(self.device) + target = labels_all.to(self.device) + + batch_size = int(image_clean_all.size(0)/4) + if self.loss_names['inheritance_loss'] == 'COS': + dum = torch.ones(batch_size,) + dum = dum.to(self.device) + + self.optimizer.zero_grad() + _, rb1_t_jpeg, rb2_t_jpeg, rb3_t_jpeg, _, _ = self.teachers[0](image_clean_all[:batch_size], + image_deg_all[:batch_size]) + _, rb1_t_blur, rb2_t_blur, rb3_t_blur, _, _ = self.teachers[1](image_clean_all[batch_size:batch_size*2], + image_deg_all[batch_size:batch_size*2]) + _, rb1_t_saltpepper, rb2_t_saltpepper, rb3_t_saltpepper, _, _ = self.teachers[2](image_clean_all[batch_size*2:batch_size*3], + image_deg_all[batch_size*2:batch_size*3]) + _, rb1_t_noise, rb2_t_noise, rb3_t_noise, _, _ = self.teachers[3](image_clean_all[batch_size*3:], + image_deg_all[batch_size*3:]) + + stem_s, rb1_s, rb2_s, rb3_s, s_feat, s_out = self.student(image_clean_all, image_deg_all) + + sup_loss = self.criterion['supervised_loss'](s_out, target) * \ + self.config['loss_weights'][0] + inh_loss = 0 + if self.loss_names['inheritance_loss'] == 'COS': + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][:batch_size], rb3_t_jpeg[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size:batch_size*2], rb3_t_blur[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size*2:batch_size*3], rb3_t_saltpepper[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + inh_loss += self.criterion['inheritance_loss'](rb3_s[1][batch_size*3:], rb3_t_noise[1].detach(), dum) * \ + self.config['loss_weights'][1][2] + else: + raise NotImplementedError + + loss = sup_loss + inh_loss + + self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') + self.valid_metrics.update('loss', loss.item()) + self.valid_metrics.update('inh_loss', inh_loss.item()) + self.valid_metrics.update('sup_loss', sup_loss.item()) + for met in self.metric_ftns: + self.valid_metrics.update(met.__name__ + '_' + self.deg_flag, + met(s_out, target)) + + return self.valid_metrics.result() + + def _save_checkpoint(self, epoch, save_best=False): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth' + """ + model_name = type(self.student).__name__ + state = { + 'model': model_name, + 'epoch': epoch, + 'state_dict': self.student.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best, + 'config': self.config + } + filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) + # torch.save(state, filename) + self.logger.info("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = str(self.checkpoint_dir / 'model_best.pth') + torch.save(state, best_path) + self.logger.info("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, config): + """ + Resume from saved checkpoints + + :param resume_path: Checkpoint path to be resumed + """ + resume_path = str(config.resume) + self.logger.info("Loading checkpoint: {} ...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + + # load architecture params from checkpoint. + if checkpoint['config']['student_model'] != self.config['student_model']: + self.logger.warning("Warning: Architecture configuration given in config file is different from that of " + "checkpoint. This may yield an exception while state_dict is being loaded.") + self.student.load_state_dict(checkpoint['state_dict']) + + # load optimizer state from checkpoint only when optimizer type is not changed. + if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: + self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " + "Optimizer parameters not being resumed.") + else: + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..46d3a15 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .util import * diff --git a/utils/data/cutout.py b/utils/data/cutout.py new file mode 100644 index 0000000..e9e4d10 --- /dev/null +++ b/utils/data/cutout.py @@ -0,0 +1,52 @@ +import torch +import numpy as np + +class Cutout(object): + """Randomly mask out one or more patches from an image. + Implementation based on the implementation from: + https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + def __init__(self, n_holes = 1, length = 8): + self.n_holes = n_holes + self.length = length + + def __call__(self, img, mask): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + mask (Tensor): Mask with n_holes of dimension length x length cut out of it. + Returns: + Tensor: Image with applied mask that contains n_holes of dimension length + x length cut out of it. + """ + img = img * mask + return img + + def get_mask(self, img): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + Returns: + Tensor: Mask with n_holes of dimension length x length cut out of it. + """ + h = img.size(1) + w = img.size(2) + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + return mask diff --git a/utils/data/datasets.py b/utils/data/datasets.py new file mode 100644 index 0000000..0319973 --- /dev/null +++ b/utils/data/datasets.py @@ -0,0 +1,290 @@ +from torchvision.datasets import CIFAR10, CIFAR100 +from utils.data.tiny_imagenet.dataset import TinyImageNetDataset +from torchvision.transforms import Compose, RandomCrop, functional as tvtf +from utils.data import degtransforms, degradedimagedata as deg_data +import torch +from utils.data.cutout import Cutout +import numpy as np +from PIL import Image + +class DegCIFAR10Dataset(CIFAR10): + def __init__(self, data_dir, train = True, train_init_transform = None, teacher_transform = None, + student_transform = None, val_transform = None, download = False, deg_type = 'jpeg', + deg_range = None, deg_list = None, is_to_tensor = True, is_target_to_tensor = True, + deg_to_tensor = None, cutout_method = None, cutout_length = None, + cutout_apply_clean = True, cutout_apply_deg = True, cutout_independent = False): + super().__init__(data_dir, train, train_init_transform, download = download) + self.train = train + self.teacher_transform = teacher_transform + self.student_transform = student_transform + self.deg_type = deg_type + self.deg_range = deg_range + if self.deg_range is None: + self.deg_range = deg_data.get_type_range(self.deg_type) + self.deg_list = deg_list + self.is_to_tensor = is_to_tensor + self.is_target_to_tensor = is_target_to_tensor + self.deg_to_tensor = deg_to_tensor + self.cutout_method = cutout_method + self.cutout_length = cutout_length + self.cutout_apply_clean = cutout_apply_clean + self.cutout_apply_deg = cutout_apply_deg + self.cutout_independent = cutout_independent + if cutout_method == 'Cutout' and self.cutout_length is not None: + self.cutout = Cutout(length = cutout_length) + self.epoch = 0 + self.deg_transform = Compose([degtransforms.DegApplyWithLevel(self.deg_type, self.deg_range, self.deg_list)]) + + def __getitem__(self, index): + """ + degradation & tensor are applied. + """ + clean_img, target = super().__getitem__(index) + orig_clean_img = clean_img.copy() + + if self.train: + if self.teacher_transform: + clean_img, _ = self.teacher_transform(clean_img) + if self.student_transform: + clean_img, _ = self.student_transform(orig_clean_img) + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + if self.deg_type != 'jpeg': + deg_img = Image.fromarray(np.uint8(deg_img.clip(0, 255))) + + if self.train: + # self.transform does not have the RandomCrop in the training process + # Applying RandomCrop + clean_img = tvtf.pad(clean_img, 4, 0, "constant") + deg_img = tvtf.pad(deg_img, 4, 0, "constant") + i, j, h, w = RandomCrop.get_params(clean_img, output_size=(32, 32)) + clean_img = tvtf.crop(clean_img, i, j, h, w) + deg_img = tvtf.crop(deg_img, i, j, h, w) + + if self.is_to_tensor: + tensor_clean_img = self.deg_to_tensor(clean_img) + tensor_deg_img = self.deg_to_tensor(deg_img) + imgs = (tensor_clean_img, tensor_deg_img) + else: + imgs = (clean_img, deg_img) + + # Applying cutout + if self.train and self.cutout_method is not None: + clean_img, deg_img = imgs + if self.cutout_method == 'Cutout': + if self.cutout_independent: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = self.cutout.get_mask(deg_img) + else: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = clean_mask + if self.cutout_apply_clean: + clean_img = self.cutout(clean_img, clean_mask) + if self.cutout_apply_deg: + deg_img = self.cutout(deg_img, deg_mask) + + imgs = (clean_img, deg_img) + + if self.is_target_to_tensor: + deg_lev = degtransforms.normalize_level(self.deg_type, deg_lev) + tensor_target = torch.tensor(target) + tensor_deg_lev = deg_lev + targets = (tensor_target, tensor_deg_lev) + else: + targets = (target, deg_lev) + + return imgs, targets + +class DegCIFAR100Dataset(CIFAR100): + def __init__(self, data_dir, train = True, train_init_transform = None, teacher_transform = None, + student_transform = None, val_transform = None, download = False, deg_type = 'jpeg', + deg_range = None, deg_list = None, is_to_tensor = True, is_target_to_tensor = True, + deg_to_tensor = None, cutout_method = None, cutout_length = None, + cutout_apply_clean = True, cutout_apply_deg = True, cutout_independent = False): + super().__init__(data_dir, train, train_init_transform, download = download) + self.train = train + self.teacher_transform = teacher_transform + self.student_transform = student_transform + self.deg_type = deg_type + self.deg_range = deg_range + if self.deg_range is None: + self.deg_range = deg_data.get_type_range(self.deg_type) + self.deg_list = deg_list + self.is_to_tensor = is_to_tensor + self.is_target_to_tensor = is_target_to_tensor + self.deg_to_tensor = deg_to_tensor + self.cutout_method = cutout_method + self.cutout_length = cutout_length + self.cutout_apply_clean = cutout_apply_clean + self.cutout_apply_deg = cutout_apply_deg + self.cutout_independent = cutout_independent + if cutout_method == 'Cutout' and self.cutout_length is not None: + self.cutout = Cutout(length = cutout_length) + self.epoch = 0 + self.deg_transform = Compose([degtransforms.DegApplyWithLevel(self.deg_type, self.deg_range, self.deg_list)]) + + def __getitem__(self, index): + """ + degradation & tensor are applied. + """ + clean_img, target = super().__getitem__(index) + orig_clean_img = clean_img.copy() + if self.train: + if self.teacher_transform: + clean_img, _ = self.teacher_transform(clean_img) + if self.student_transform: + clean_img, _ = self.student_transform(orig_clean_img) + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + if self.deg_type != 'jpeg': + deg_img = Image.fromarray(np.uint8(deg_img.clip(0, 255))) + + if self.train: + # self.transform does not have the RandomCrop in the training process + # Applying RandomCrop + clean_img = tvtf.pad(clean_img, 4, 0, "constant") + deg_img = tvtf.pad(deg_img, 4, 0, "constant") + i, j, h, w = RandomCrop.get_params(clean_img, output_size=(32, 32)) + clean_img = tvtf.crop(clean_img, i, j, h, w) + deg_img = tvtf.crop(deg_img, i, j, h, w) + + if self.is_to_tensor: + tensor_clean_img = self.deg_to_tensor(clean_img) + tensor_deg_img = self.deg_to_tensor(deg_img) + imgs = (tensor_clean_img, tensor_deg_img) + else: + imgs = (clean_img, deg_img) + + # Applying cutout + if self.train and self.cutout_method is not None: + clean_img, deg_img = imgs + if self.cutout_method == 'Cutout' or self.cutout_method == 'GridMask': + if self.cutout_independent: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = self.cutout.get_mask(deg_img) + else: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = clean_mask + if self.cutout_apply_clean: + clean_img = self.cutout(clean_img, clean_mask) + if self.cutout_apply_deg: + deg_img = self.cutout(deg_img, deg_mask) + + imgs = (clean_img, deg_img) + + if self.is_target_to_tensor: + deg_lev = degtransforms.normalize_level(self.deg_type, deg_lev) + tensor_target = torch.tensor(target) + tensor_deg_lev = deg_lev + targets = (tensor_target, tensor_deg_lev) + else: + targets = (target, deg_lev) + + return imgs, targets + +class DegTinyImagenetDataset(TinyImageNetDataset): + def __init__(self, data_dir, train = True, train_init_transform = None, teacher_transform = None, + student_transform = None, val_transform = None, download = False, deg_type = 'jpeg', + deg_range = None, deg_list = None, is_to_tensor = True, is_target_to_tensor = True, + deg_to_tensor = None, cutout_method = None, cutout_length = None, + cutout_apply_clean = True, cutout_apply_deg = True, cutout_independent = False): + super().__init__(data_dir + 'tiny-imagenet-200/', train, train_init_transform, + download = download) + self.train = train + self.teacher_transform = teacher_transform + self.student_transform = student_transform + self.deg_type = deg_type + self.deg_range = deg_range + if self.deg_range is None: + self.deg_range = deg_data.get_type_range(self.deg_type) + self.deg_list = deg_list + self.is_to_tensor = is_to_tensor + self.is_target_to_tensor = is_target_to_tensor + self.deg_to_tensor = deg_to_tensor + self.cutout_method = cutout_method + self.cutout_length = cutout_length + self.cutout_apply_clean = cutout_apply_clean + self.cutout_apply_deg = cutout_apply_deg + self.cutout_independent = cutout_independent + if cutout_method == 'Cutout' and self.cutout_length is not None: + self.cutout = Cutout(length = cutout_length) + self.epoch = 0 + self.deg_transform = Compose([degtransforms.DegApplyWithLevel(self.deg_type, self.deg_range, self.deg_list)]) + + def __getitem__(self, index): + """ + degradation & tensor are applied. + """ + clean_img, target = super().__getitem__(index) + + orig_clean_img = clean_img.copy() + if self.train: + if self.teacher_transform: + clean_img, _ = self.teacher_transform(clean_img) + if self.student_transform: + clean_img, _ = self.student_transform(orig_clean_img) + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + else: + deg_img, deg_lev = self.deg_transform(clean_img) if self.deg_type == 'jpeg' \ + else self.deg_transform(np.asarray(clean_img)) + if self.deg_type != 'jpeg': + deg_img = Image.fromarray(np.uint8(deg_img.clip(0, 255))) + + if self.train: + # self.transform does not have the RandomCrop in the training process + # Applying RandomCrop + clean_img = tvtf.pad(clean_img, 4, 0, "constant") + deg_img = tvtf.pad(deg_img, 4, 0, "constant") + i, j, h, w = RandomCrop.get_params(clean_img, output_size=(64, 64)) + clean_img = tvtf.crop(clean_img, i, j, h, w) + deg_img = tvtf.crop(deg_img, i, j, h, w) + + if self.is_to_tensor: + tensor_clean_img = self.deg_to_tensor(clean_img) + tensor_deg_img = self.deg_to_tensor(deg_img) + imgs = (tensor_clean_img, tensor_deg_img) + else: + imgs = (clean_img, deg_img) + + # Applying cutout + if self.train and self.cutout_method is not None: + clean_img, deg_img = imgs + if self.cutout_method == 'Cutout' or self.cutout_method == 'GridMask': + if self.cutout_independent: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = self.cutout.get_mask(deg_img) + else: + clean_mask = self.cutout.get_mask(clean_img) + deg_mask = clean_mask + if self.cutout_apply_clean: + clean_img = self.cutout(clean_img, clean_mask) + if self.cutout_apply_deg: + deg_img = self.cutout(deg_img, deg_mask) + + imgs = (clean_img, deg_img) + + if self.is_target_to_tensor: + deg_lev = degtransforms.normalize_level(self.deg_type, deg_lev) + tensor_target = torch.tensor(target) + tensor_deg_lev = deg_lev + targets = (tensor_target, tensor_deg_lev) + else: + targets = (target, deg_lev) + + return imgs, targets diff --git a/utils/data/degradedimagedata.py b/utils/data/degradedimagedata.py new file mode 100644 index 0000000..911ca1d --- /dev/null +++ b/utils/data/degradedimagedata.py @@ -0,0 +1,75 @@ +import numpy as np + +def get_type_range(degtype): + """ + Get degradation type and range + Args: + degtype (string) : jpeg, noise, blur, saltpepper + """ + if degtype == "jpeg": + deg_range = [1, 101] + elif degtype == "noise": + deg_range = [0, 50] + elif degtype == "blur": + deg_range = [0, 50] + elif degtype == 'saltpepper': + deg_range = [0, 25] + else: + raise NotImplementedError + return deg_range + +def get_minmax_normalizedlevel(deg_type): + """ + Min and Max of normalized degradation levels + Args: + deg_type (string) : degradation type + Returns: + normalized degradation level (tuple) + """ + if deg_type == 'jpeg': + ret_adj, max_l, min_l = 100.0, 101.0, 1.0 + elif deg_type == 'noise': + ret_adj, max_l, min_l = 255.0, 50.0, 0.0 + elif deg_type == 'blur': + ret_adj, max_l, min_l = 100.0, 50.0, 0.0 + elif deg_type == 'saltpepper': + ret_adj, max_l, min_l = 100.0, 5.0, 0.0 + else: + ret_adj, max_l, min_l = 1.0, 1.0, 0.0 + + return min_l/ret_adj, max_l/ret_adj + +def fix_seed_noise_sl(is_fixed): + """ + Fix the seed of Gaussian and Binomial distributions + This is only used for evalution purpose. + If you fix the seed, please do not forget to unfix the seed. + Args: + is_fixed (bool) : True if the seed is fixed + """ + if is_fixed: + np.random.seed(seed=301) + else: + np.random.seed(seed=None) + +def get_type_list(degtype): + """ + Get degradation type and range + Args: + degtype (string) : jpeg, noise, blur, saltpepper + """ + if degtype == "jpeg": + deg_type = "jpeg" + deg_list = [10, 30, 50, 70, 90] + elif degtype == "noise": + deg_type = "noise" + deg_list = [np.sqrt(0.05)*255, np.sqrt(0.1)*255, + np.sqrt(0.15)*255, np.sqrt(0.2)*255, np.sqrt(0.25)*255] + elif degtype == "blur": + deg_type = "blur" + deg_list = [10., 20., 30., 40., 50.] + else: + deg_type = "saltpepper" + deg_list = [5., 10., 15., 20., 25.] + + return deg_type, deg_list \ No newline at end of file diff --git a/utils/data/degtransforms.py b/utils/data/degtransforms.py new file mode 100644 index 0000000..9dd0e92 --- /dev/null +++ b/utils/data/degtransforms.py @@ -0,0 +1,168 @@ +import torch +import numpy as np +import random +from PIL import Image, ImageStat, ImageChops +import imagedegrade.im as degrade # https://github.com/mastnk/imagedegrade +import imagedegrade.np as np_degrade # https://github.com/mastnk/imagedegrade + +random.seed(0) + +def jpegcompresswithclean(img, jpeg_quality): + """ + Apply JPEG distortion to clean images + If JPEG quality factor is in [1, 100], JPEG distortion is applied. + If it is not in [1, 100], clean image will be returned. + Args: + img (PIL image) : clean image + jpeg_quality (int) : JPEG quality factor + Returns: + img (PIL image) : JPEG image or clean image + """ + if (jpeg_quality >= 1) and (jpeg_quality <= 100): + ret_img = degrade.jpeg(img, jpeg_quality) + else: + ret_img = img + + return ret_img + +def degradation_function(deg_type): + """ + Get the pointer of a degradation function from imagedegrade.im + Args: + deg_type (string) : degradtion type + Returns: + ret_func (pinter) : the poiter of your selected degradation function + ret_adj (folat) : the adjsutment of degradation level + """ + if deg_type == 'jpeg': + ret_func = jpegcompresswithclean + ret_adj = 1.0 + elif deg_type == 'noise': + ret_func = np_degrade.noise + ret_adj = 1.0 + elif deg_type == 'blur': + ret_func = np_degrade.blur + ret_adj = 10.0 + elif deg_type == 'saltpepper': + ret_func = np_degrade.saltpepper + ret_adj = 100.0 + else: + msg = 'This degradation is not supported.' + raise LookupError(msg) + + return ret_func, ret_adj + +def normalize_level(deg_type, level): + """ + Normaliza degradation levels + Args: + deg_type (string) : degradation type + level (int or float) : degradation level + Returns: + normalized degradation level (float) + """ + if deg_type == 'jpeg': + ret_adj = 100.0 + elif deg_type == 'noise': + ret_adj = 255.0 + elif deg_type == 'blur': + ret_adj = 10.0 + elif deg_type == 'saltpepper': + ret_adj = 1.0 + else: + ret_adj = 1.0 + ret = np.array([float(level)/ret_adj]) + ret.astype(np.float32) + return ret + +def calc_deglev_rescaledimg(deg_type, img_org, resized_img_org, inp_resized_img_org, inp_resized_img_deg, deg_param): + """ + True degration level for rescaled degraded images + Args: + deg_type (string) : degradation type + img_org (PIL images) : clean image + resized_img_org (PIL images) : resized clean image + inp_resized_img_org (PIL images) : resized clean image input into CNN + inp_resized_img_deg (PIL images) : resized degraded image input into CNN + deg_param (float) : degradation parameter used by degradation operator + Returns: + true degradation level (JPEG and Gaussian, S&P noise: RMSE, Gaussina blur: rescaled std) + """ + if deg_type == 'blur': + true_deglev = deg_param / float(img_org.width) * float(resized_img_org.width) # Rescaling the standard deviation + else: + img_diff = ImageChops.difference(inp_resized_img_org, inp_resized_img_deg) + stat_diff = ImageStat.Stat(img_diff) + mse = sum(stat_diff.sum2) / float(len(stat_diff.count)) / float(stat_diff.count[0]) + true_deglev = np.sqrt(mse) # RMSE based on the maximum intensity 255.0 + if deg_type == 'jpeg': # This operation is necessary for the normalize_level function. + true_deglev *= (100.0 / 255.0) + elif deg_type == 'saltpepper': + true_deglev /= 255.0 + + return true_deglev + +class DegApplyWithLevel(torch.nn.Module): + """ + Data augmentation of degradations + This transform returns not only a degraded image but also a degradation level. + """ + def __init__(self, deg_type, deg_range, deg_list): + """ + deg_range or deg_list are not input at the same time. + Args: + deg_type (string) : degradtion type + deg_range (int, int) : range of degradation levels + deg_list (list) : list of degradation levels + """ + super().__init__() + self.deg_type = deg_type + if deg_range is None and deg_list is None: + msg = 'Both deg_range and deg_list do not have values.' + raise TypeError(msg) + elif (deg_range is not None) and (deg_list is not None): + msg = 'deg_range or deg_list have values.' + raise TypeError(msg) + else: + self.deg_range = deg_range + self.deg_list = deg_list + + self.deg_func, self.deg_adj = degradation_function(deg_type) + + def forward(self, img): + """ + Get a degraded image and a degradation level + Args: + img : clean image (PIL image) + Returns: + degraded image (PIL image) : degraded image + deg_lev (float) : degradation level + """ + if self.deg_range is not None: + deg_lev = random.randint(self.deg_range[0], self.deg_range[1]) + if self.deg_adj > 1.0: + deg_lev = deg_lev / self.deg_adj + else: + deg_lev = random.choice(self.deg_list) + if self.deg_adj > 1.0: + deg_lev = deg_lev / self.deg_adj + + return self.deg_func(img, deg_lev), deg_lev + + def __repr__(self): + return self.__class__.__name__ + '(deg_type={}, deg_range=({},{}))'.format(self.deg_type, self.deg_range[0], self.deg_range[1]) + +class DegradationApply(DegApplyWithLevel): + """ + Data augmentation of degradations + """ + def forward(self, img): + """ + Get a degraded image + Args: + img : clean image (PIL image) + Returns: + degraded image (PIL image) : degraded image + """ + deg_img, deg_lev = super().forward(img) + return deg_img diff --git a/utils/data/tiny_imagenet/dataset.py b/utils/data/tiny_imagenet/dataset.py new file mode 100644 index 0000000..28b0e12 --- /dev/null +++ b/utils/data/tiny_imagenet/dataset.py @@ -0,0 +1,210 @@ +import imageio +from PIL import Image +import numpy as np +import os + +from collections import defaultdict +from torch.utils.data import Dataset + +from tqdm.autonotebook import tqdm + +dir_structure_help = r""" +TinyImageNetPath +├── test +│ └── images +│ ├── test_0.JPEG +│ ├── t... +│ └── ... +├── train +│ ├── n01443537 +│ │ ├── images +│ │ │ ├── n01443537_0.JPEG +│ │ │ ├── n... +│ │ │ └── ... +│ │ └── n01443537_boxes.txt +│ ├── n01629819 +│ │ ├── images +│ │ │ ├── n01629819_0.JPEG +│ │ │ ├── n... +│ │ │ └── ... +│ │ └── n01629819_boxes.txt +│ ├── n... +│ │ ├── images +│ │ │ ├── ... +│ │ │ └── ... +├── val +│ ├── images +│ │ ├── val_0.JPEG +│ │ ├── v... +│ │ └── ... +│ └── val_annotations.txt +├── wnids.txt +└── words.txt +""" + +def download_and_unzip(URL, root_dir): + """ + Please download the dataset from here: http://cs231n.stanford.edu/tiny-imagenet-200.zip + """ + error_message = "Download is not yet implemented. Please, go to {URL} urself." + raise NotImplementedError(error_message.format(URL)) + +def _add_channels(img, total_channels=3): + while len(img.shape) < 3: # third axis is the channels + img = np.expand_dims(img, axis=-1) + while(img.shape[-1]) < 3: + img = np.concatenate([img, img[:, :, -1:]], axis=-1) + return img + +"""Creates a paths datastructure for the tiny imagenet. + +Args: + root_dir: Where the data is located + download: Download if the data is not there + +Members: + label_id: + ids: + nit_to_words: + data_dict: + +""" +class TinyImageNetPaths: + def __init__(self, root_dir, download=False): + if download: + download_and_unzip('http://cs231n.stanford.edu/tiny-imagenet-200.zip', + root_dir) + train_path = os.path.join(root_dir, 'train') + val_path = os.path.join(root_dir, 'val') + test_path = os.path.join(root_dir, 'test') + + wnids_path = os.path.join(root_dir, 'wnids.txt') + words_path = os.path.join(root_dir, 'words.txt') + + self._make_paths(train_path, val_path, test_path, + wnids_path, words_path) + + def _make_paths(self, train_path, val_path, test_path, + wnids_path, words_path): + self.ids = [] + with open(wnids_path, 'r') as idf: + for nid in idf: + nid = nid.strip() + self.ids.append(nid) + self.nid_to_words = defaultdict(list) + with open(words_path, 'r') as wf: + for line in wf: + nid, labels = line.split('\t') + labels = list(map(lambda x: x.strip(), labels.split(','))) + self.nid_to_words[nid].extend(labels) + + self.paths = { + 'train': [], # [img_path, id, nid, box] + 'val': [], # [img_path, id, nid, box] + 'test': [] # img_path + } + + # Get the test paths + self.paths['test'] = list(map(lambda x: os.path.join(test_path, x), + os.listdir(test_path))) + # Get the validation paths and labels + with open(os.path.join(val_path, 'val_annotations.txt')) as valf: + for line in valf: + fname, nid, x0, y0, x1, y1 = line.split() + fname = os.path.join(val_path, 'images', fname) + bbox = int(x0), int(y0), int(x1), int(y1) + label_id = self.ids.index(nid) + self.paths['val'].append((fname, label_id, nid, bbox)) + + # Get the training paths + train_nids = os.listdir(train_path) + for nid in train_nids: + anno_path = os.path.join(train_path, nid, nid+'_boxes.txt') + imgs_path = os.path.join(train_path, nid, 'images') + label_id = self.ids.index(nid) + with open(anno_path, 'r') as annof: + for line in annof: + fname, x0, y0, x1, y1 = line.split() + fname = os.path.join(imgs_path, fname) + bbox = int(x0), int(y0), int(x1), int(y1) + self.paths['train'].append((fname, label_id, nid, bbox)) + +"""Datastructure for the tiny image dataset. + +Args: + root_dir: Root directory for the data + mode: One of "train", "test", or "val" + preload: Preload into memory + load_transform: Transformation to use at the preload time + transform: Transformation to use at the retrieval time + download: Download the dataset + +Members: + tinp: Instance of the TinyImageNetPaths + img_data: Image data + label_data: Label data +""" +class TinyImageNetDataset(Dataset): + def __init__(self, root_dir, train = True, transform=None, download=False, + preload=True, load_transform=None, max_samples=None): + tinp = TinyImageNetPaths(root_dir, download) + if train: + mode = 'train' + else: + mode = 'val' + self.mode = mode + self.label_idx = 1 # from [image, id, nid, box] + self.preload = preload + self.transform = transform + self.transform_results = dict() + + self.IMAGE_SHAPE = (64, 64, 3) + + self.img_data = [] + self.label_data = [] + + self.samples = tinp.paths[mode] + self.samples_num = len(self.samples) + + if self.preload: + load_desc = "Preloading {} data...".format(mode) + self.img_data = [] + self.label_data = [] + for idx in range(self.samples_num): + s = self.samples[idx] + img = Image.open(s[0]) + # Convert to RGB from L, values doesn't change + if img.mode == 'L': + img = img.convert('RGB') + + self.img_data.append(img) + self.label_data.append(s[self.label_idx]) + + if load_transform: + for lt in load_transform: + result = lt(self.img_data, self.label_data) + self.img_data, self.label_data = result[:2] + if len(result) > 2: + self.transform_results.update(result[2]) + + def __len__(self): + return self.samples_num + + def __getitem__(self, idx): + if self.preload: + img = self.img_data[idx] + lbl = self.label_data[idx] + + else: + s = self.samples[idx] + img = Image.open(s[0]) + # Convert to RGB from L, values doesn't change + if img.mode == 'L': + img = img.convert('RGB') + + lbl = (s[self.label_idx]) + + if self.transform: + img = self.transform(img) + + return img, lbl \ No newline at end of file diff --git a/utils/model_soups.py b/utils/model_soups.py new file mode 100644 index 0000000..9b2d638 --- /dev/null +++ b/utils/model_soups.py @@ -0,0 +1,80 @@ +import os +import sys +# Adding the parent directory to the sys path, fix so that this file can be run from utils dir. +script_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.abspath(os.path.join(script_dir, '..')) +sys.path.insert(1, parent_dir) + +import argparse +import torch +import copy +from parse_config import ConfigParser +from utils import read_yaml, write_yaml +from datetime import datetime + +def model_fusion(dataset='CIFAR10'): + """ + This function is used to perform the fusion of several fine-tuned individual degradation models, + i.e., Step-3 of our proposed method. + + Example usage: python utils/model_soups.py --dataset CIFAR10 + It will generate a combined model in the saved/combined_deg/SLTrainer/ResNet56-56_CIFAR10_soups/train/ directory. + """ + # Load the configuration file + config_file = 'configs/deg_all/{}/ResNet56_soups.yaml'.format(dataset.lower()) + config = ConfigParser(read_yaml(config_file), dry_run = True) + + # Get the model from the configuration + model = config.get_class('model') + + # Extract the pretrained model paths from the configuration + model_paths = [] + for key, value in config['model'].items(): + if key.startswith('pretrained_path'): + model_paths.append(value) + + # Load the checkpoints of the pretrained individual degradation models + checkpoints = [torch.load(path) for path in model_paths] + + # Initialize the models and load their state + models_all = [copy.deepcopy(model) for _ in range(len(checkpoints))] + for i, model in enumerate(models_all): + model.load_state_dict(checkpoints[i]['state_dict']) + + # Combine the weights of the model + combined_model = None + global_count = 0 + for model in models_all: + if combined_model is None: + combined_model = copy.deepcopy(model) + else: + for param_q, param_k in zip(model.parameters(), combined_model.parameters()): + param_k.data = (param_k.data * global_count + param_q.data) / (1. + global_count) + global_count += 1 + + # Prepare the checkpoint directory and model state + run_id = datetime.now().strftime(r'%m%d_%H%M%S') + checkpoint_dir = 'saved/combined_deg/SLTrainer/ResNet56-56_{}_soups/train/{}/'.format(dataset, run_id) + model_name = type(combined_model).__name__ + config.config['name'] = config.config['name'] + '_soups' + state = { + 'model': model_name, + 'state_dict': combined_model.state_dict(), + 'config': config + } + + # Save the combined model and the configuration + model_path = checkpoint_dir + 'model_best.pth' + os.makedirs(checkpoint_dir, exist_ok=True) + write_yaml(config.config, checkpoint_dir + 'config.yaml') + torch.save(state, model_path) + print('saved combined model:', model_path) + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser(description='Combine Models') + parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset to use such as CIFAR100 or TinyImagenet') + args = parser.parse_args() + + # Call the model fusion function + model_fusion(args.dataset) diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..d307eec --- /dev/null +++ b/utils/util.py @@ -0,0 +1,111 @@ +import json +import yaml +import io +import torch +import pandas as pd +from pathlib import Path +from itertools import repeat +from collections import OrderedDict +import numpy as np +import random + +def ensure_dir(dirname): + dirname = Path(dirname) + if not dirname.is_dir(): + dirname.mkdir(parents=True, exist_ok=False) + +def read_json(fname): + fname = Path(fname) + with fname.open('rt') as handle: + return json.load(handle, object_hook=OrderedDict) + +def write_json(content, fname): + fname = Path(fname) + with fname.open('wt') as handle: + json.dump(content, handle, indent=4, sort_keys=False) + +def read_yaml(fname): + fname = Path(fname) + with open(fname, 'r') as stream: + return yaml.safe_load(stream) + +def write_yaml(content, fname): + fname = Path(fname) + with io.open(fname, 'w', encoding='utf8') as outfile: + yaml.dump(content, outfile, default_flow_style=False, sort_keys=False, + allow_unicode=True, indent=2) + +def print_dict(content): + return yaml.dump(content, indent=2, sort_keys=False) + +def inf_loop(data_loader): + ''' wrapper function for endless data loader. ''' + for loader in repeat(data_loader): + yield from loader + +def import_class(name): + components = name.split('.') + mod = __import__(components[0]) # import return model + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + +def prepare_device(n_gpu_use): + """ + setup GPU device if available. get gpu device indices which are used for DataParallel + """ + n_gpu = torch.cuda.device_count() + if n_gpu_use > 0 and n_gpu == 0: + print("Warning: There\'s no GPU available on this machine," + "training will be performed on CPU.") + n_gpu_use = 0 + if n_gpu_use > n_gpu: + print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are " + "available on this machine.") + n_gpu_use = n_gpu + device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') + list_ids = list(range(n_gpu_use)) + return device, list_ids + +def set_seeds(SEED = 123): + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) # New here <-- + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + np.random.seed(SEED) + random.seed(SEED) + +def set_seeds_prev(SEED = 123): + torch.manual_seed(SEED) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(SEED) + random.seed(SEED) + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +class MetricTracker: + def __init__(self, *keys, writer=None): + self.writer = writer + self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) + self.reset() + + def reset(self): + for col in self._data.columns: + self._data[col].values[:] = 0 + + def update(self, key, value, n=1): + if self.writer is not None: + self.writer.add_scalar(key, value) + self._data.total[key] += value * n + self._data.counts[key] += n + self._data.average[key] = self._data.total[key] / self._data.counts[key] + + def avg(self, key): + return self._data.average[key] + + def result(self): + return dict(self._data.average)