diff --git a/.github/workflows/deploy_pages.yml b/.github/workflows/deploy_pages.yml
index 4101931a..d311f831 100644
--- a/.github/workflows/deploy_pages.yml
+++ b/.github/workflows/deploy_pages.yml
@@ -39,7 +39,7 @@ jobs:
python-version: "3.10"
- name: Install dependencies
run: |
- pip install -r docs/requirements.txt -r requirements.txt
+ pip install -r docs/requirements.txt -r requirements.txt -r requirements-optional.txt
pip install quadprog==0.1.11
- name: Sphinx build
run: |
diff --git a/.gitignore b/.gitignore
index e8b4020f..a1cb24d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -95,4 +95,9 @@ logs
**/_build/
_autosummary
generated
-val_permutations
\ No newline at end of file
+val_permutations
+
+# Other prepare grid scripts except the example one
+scripts/prepare_grid*
+
+docs/models/*_args.rst
\ No newline at end of file
diff --git a/README.md b/README.md
index 401e133d..affe4d7c 100644
--- a/README.md
+++ b/README.md
@@ -4,10 +4,12 @@
# Mammoth - An Extendible (General) Continual Learning Framework for Pytorch
-Official repository of [Class-Incremental Continual Learning into the eXtended DER-verse](https://arxiv.org/abs/2201.00766) and [Dark Experience for General Continual Learning: a Strong, Simple Baseline](https://papers.nips.cc/paper/2020/hash/b704ea2c39778f07c617f6b7ce480e9e-Abstract.html)
+Official repository of [Class-Incremental Continual Learning into the eXtended DER-verse](https://arxiv.org/abs/2201.00766), [Dark Experience for General Continual Learning: a Strong, Simple Baseline](https://papers.nips.cc/paper/2020/hash/b704ea2c39778f07c617f6b7ce480e9e-Abstract.html), and [Semantic Residual Prompts for Continual Learning](https://arxiv.org/abs/2403.06870)
-Mammoth is a framework for continual learning research. It is designed to be modular, easy to extend, and - most importantly - _easy to debug_.
-Idelly, all the code necessary to run the experiments is included _in the repository_, without needing to check out other repositories or install additional packages.
+Mammoth is a framework for continual learning research. With **40 methods and 21 datasets**, it includes the most complete list competitors and benchmarks for research purposes.
+
+The core idea of Mammoth is that it is designed to be modular, easy to extend, and - most importantly - _easy to debug_.
+Ideally, all the code necessary to run the experiments is included _in the repository_, without needing to check out other repositories or install additional packages.
With Mammoth, nothing is set in stone. You can easily add new models, datasets, training strategies, or functionalities.
@@ -28,13 +30,15 @@ Join our Discord Server for all your Mammoth-related questions → ![Discord Shi
## Setup
+- Install with `pip install -r requirements.txt`. NOTE: Pytorch version >= 2.1.0 is required for scaled_dot_product_attention (see: https://github.com/Lightning-AI/litgpt/issues/763). If you cannot support this requirement, uncomment the lines 136-139 under `scaled_dot_product_attention` in `backbone/vit.py`.
- Use `./utils/main.py` to run experiments.
-- Use argument `--load_best_args` to use the best hyperparameters from the paper.
- New models can be added to the `models/` folder.
- New datasets can be added to the `datasets/` folder.
## Models
+Mammoth currently supports **42** models, with new releases covering the main competitors in literature.
+
- Efficient Lifelong Learning with A-GEM (A-GEM, A-GEM-R - A-GEM with reservoir buffer): `agem`, `agem_r`
- Bias Correction (BiC): `bic`.
- Continual Contrastive Interpolation Consistency (CCIC) - _Requires_ `pip install kornia`: `ccic`.
@@ -62,26 +66,40 @@ Join our Discord Server for all your Mammoth-related questions → ![Discord Shi
- SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model (SLCA) - _Requires_ `pip install timm==0.9.8`: `slca`.
- Transfer without Forgetting (TwF): `twf`.
- eXtended-DER (X-DER): `xder` (full version), `xder_ce` (X-DER with CE), `xder_rpc` (X-DER with RPC).
+- AttriCLIP: `attriclip`.
+- Slow Learner with Classifier Alignment (SLCA): `slca`.
+- Continual Generative training for Incremental prompt-Learning (CGIL): `cgil`
+- Semantic Two-level Additive Residual Prompt (STAR-Prompt): `starprompt`. Also includes the first-stage only (`first_stage_starprompt`) and second-stage only (`second_stage_starprompt`) versions.
## Datasets
-**NOTE**: Datasets are automatically downloaded in the `data/`.
+**NOTE**: Datasets are automatically downloaded in `data/`.
+- This can be changes by changing the `base_path` function in `utils/conf.py` or using the `--base_path` argument.
+- The `data/` folder should not tracked by git and is craeted automatically if missing.
-- This can be changes by changing the `base_path` function in `utils/conf.py`.
-- The `data/` folder is not tracked by git and is craeted automatically if missing.
+Mammoth includes **21** datasets, covering *toy classification problems* (different versions of MNIST), *standard domains* (CIFAR, Imagenet-R, TinyImagenet, MIT-67), *fine-grained classification domains* (Cars-196, CUB-200), *aerial domains* (EuroSAT-RGB, Resisc45), *medical domains* (CropDisease, ISIC, ChestX).
- Sequential MNIST (_Class-Il / Task-IL_): `seq-mnist`.
+- Permuted MNIST (_Domain-IL_): `perm-mnist`.
+- Rotated MNIST (_Domain-IL_): `rot-mnist`.
+- MNIST-360 (_General Continual Learning_): `mnist-360`.
- Sequential CIFAR-10 (_Class-Il / Task-IL_): `seq-cifar10`.
+- Sequential CIFAR-10 resized 224x224 (ViT version) (_Class-Il / Task-IL_): `seq-cifar10-224`.
+- Sequential CIFAR-10 resized 224x224 (ResNet50 version) (_Class-Il / Task-IL_): `seq-cifar10-224-rs`.
- Sequential Tiny ImageNet (_Class-Il / Task-IL_): `seq-tinyimg`.
- Sequential Tiny ImageNet resized 32x32 (_Class-Il / Task-IL_): `seq-tinyimg-r`.
- Sequential CIFAR-100 (_Class-Il / Task-IL_): `seq-cifar100`.
- Sequential CIFAR-100 resized 224x224 (ViT version) (_Class-Il / Task-IL_): `seq-cifar100-224`.
- Sequential CIFAR-100 resized 224x224 (ResNet50 version) (_Class-Il / Task-IL_): `seq-cifar100-224-rs`.
-- Permuted MNIST (_Domain-IL_): `perm-mnist`.
-- Rotated MNIST (_Domain-IL_): `rot-mnist`.
-- MNIST-360 (_General Continual Learning_): `mnist-360`.
- Sequential CUB-200 (_Class-Il / Task-IL_): `seq-cub200`.
- Sequential ImageNet-R (_Class-Il / Task-IL_): `seq-imagenet-r`.
+- Sequential Cars-196 (_Class-Il / Task-IL_): `seq-cars196`.
+- Sequential RESISC45 (_Class-Il / Task-IL_): `seq-resisc45`.
+- Sequential EuroSAT-RGB (_Class-Il / Task-IL_): `seq-eurosat-rgb`.
+- Sequential ISIC (_Class-Il / Task-IL_): `seq-isic`.
+- Sequential ChestX (_Class-Il / Task-IL_): `seq-chestx`.
+- Sequential MIT-67 (_Class-Il / Task-IL_): `seq-mit67`.
+- Sequential CropDisease (_Class-Il / Task-IL_): `seq-cropdisease`.
## Pretrained backbones
@@ -117,44 +135,161 @@ Join our Discord Server for all your Mammoth-related questions → ![Discord Shi
### Our Papers
-- Dark Experience for General Continual Learning: a Strong, Simple Baseline (**NeurIPS 2020**) [[paper](https://arxiv.org/abs/2004.07211)]
-- Rethinking Experience Replay: a Bag of Tricks for Continual Learning (**ICPR 2020**) [[paper](https://arxiv.org/abs/2010.05595)] [[code](https://github.com/hastings24/rethinking_er)]
-- Class-Incremental Continual Learning into the eXtended DER-verse (**TPAMI 2022**) [[paper](https://arxiv.org/abs/2201.00766)]
-- Effects of Auxiliary Knowledge on Continual Learning (**ICPR 2022**) [[paper](https://arxiv.org/abs/2206.02577)]
-- Transfer without Forgetting (**ECCV 2022**) [[paper](https://arxiv.org/abs/2206.00388)][[code](https://github.com/mbosc/twf)]
-- Continual semi-supervised learning through contrastive interpolation consistency (**PRL 2022**) [[paper](https://arxiv.org/abs/2108.06552)][[code](https://github.com/aimagelab/CSSL)]
-- On the Effectiveness of Lipschitz-Driven Rehearsal in Continual Learning (**NeurIPS 2022**) [[paper](https://arxiv.org/abs/2210.06443)] [[code](https://github.com/aimagelab/lider)]
+
+Dark Experience for General Continual Learning: a Strong, Simple Baseline (NeurIPS 2020) paper
+
+@inproceedings{buzzega2020dark,
+ author = {Buzzega, Pietro and Boschini, Matteo and Porrello, Angelo and Abati, Davide and Calderara, Simone},
+ booktitle = {Advances in Neural Information Processing Systems},
+ editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
+ pages = {15920--15930},
+ publisher = {Curran Associates, Inc.},
+ title = {Dark Experience for General Continual Learning: a Strong, Simple Baseline},
+ volume = {33},
+ year = {2020}
+}
+
+
+
+Rethinking Experience Replay: a Bag of Tricks for Continual Learning (ICPR 2020) paper code
+
+@inproceedings{buzzega2021rethinking,
+ title={Rethinking experience replay: a bag of tricks for continual learning},
+ author={Buzzega, Pietro and Boschini, Matteo and Porrello, Angelo and Calderara, Simone},
+ booktitle={25th International Conference on Pattern Recognition},
+ pages={2180--2187},
+ year={2021},
+ organization={IEEE}
+}
+
+
+Class-Incremental Continual Learning into the eXtended DER-verse (TPAMI 2022) paper
+
+@article{boschini2022class,
+ title={Class-Incremental Continual Learning into the eXtended DER-verse},
+ author={Boschini, Matteo and Bonicelli, Lorenzo and Buzzega, Pietro and Porrello, Angelo and Calderara, Simone},
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
+ year={2022},
+ publisher={IEEE}
+}
+
+
+Effects of Auxiliary Knowledge on Continual Learning (ICPR 2022) paper
+
+@inproceedings{bellitto2022effects,
+ title={Effects of auxiliary knowledge on continual learning},
+ author={Bellitto, Giovanni and Pennisi, Matteo and Palazzo, Simone and Bonicelli, Lorenzo and Boschini, Matteo and Calderara, Simone},
+ booktitle={26th International Conference on Pattern Recognition},
+ pages={1357--1363},
+ year={2022},
+ organization={IEEE}
+}
+
+
+Transfer without Forgetting (ECCV 2022) paper code (Also available here)
+
+@inproceedings{boschini2022transfer,
+ title={Transfer without forgetting},
+ author={Boschini, Matteo and Bonicelli, Lorenzo and Porrello, Angelo and Bellitto, Giovanni and Pennisi, Matteo and Palazzo, Simone and Spampinato, Concetto and Calderara, Simone},
+ booktitle={17th European Conference on Computer Vision},
+ pages={692--709},
+ year={2022},
+ organization={Springer}
+}
+
+
+Continual semi-supervised learning through contrastive interpolation consistency (PRL 2022) paper code (Also available here)
+
+@article{boschini2022continual,
+ title={Continual semi-supervised learning through contrastive interpolation consistency},
+ author={Boschini, Matteo and Buzzega, Pietro and Bonicelli, Lorenzo and Porrello, Angelo and Calderara, Simone},
+ journal={Pattern Recognition Letters},
+ volume={162},
+ pages={9--14},
+ year={2022},
+ publisher={Elsevier}
+}
+
+
+On the Effectiveness of Lipschitz-Driven Rehearsal in Continual Learning (NeurIPS 2022) paper code (Also available here)
+
+@article{bonicelli2022effectiveness,
+ title={On the effectiveness of lipschitz-driven rehearsal in continual learning},
+ author={Bonicelli, Lorenzo and Boschini, Matteo and Porrello, Angelo and Spampinato, Concetto and Calderara, Simone},
+ journal={Advances in Neural Information Processing Systems},
+ volume={35},
+ pages={31886--31901},
+ year={2022}
+}
+
+
+Mask and Compress: Efficient Skeleton-based Action Recognition in Continual Learning (ICPR 2024) paper code
+
+@inproceedings{mosconi2024mask,
+ title={Mask and Compress: Efficient Skeleton-based Action Recognition in Continual Learning},
+ author={Mosconi, Matteo and Sorokin, Andriy and Panariello, Aniello and Porrello, Angelo and Bonato, Jacopo and Cotogni, Marco and Sabetta, Luigi and Calderara, Simone and Cucchiara, Rita},
+ booktitle={International Conference on Pattern Recognition},
+ year={2024}
+}
+
+
+Semantic Residual Prompts for Continual Learning (ECCV 2024) paper
+
+@inproceedings{menabue2024semantic,
+ title={Semantic Residual Prompts for Continual Learning},
+ author={Menabue, Martin and Frascaroli, Emanuele and Boschini, Matteo and Sangineto, Enver and Bonicelli, Lorenzo and Porrello, Angelo and Calderara, Simone},
+ booktitle={18th European Conference on Computer Vision},
+ year={202},
+ organization={Springer}
+}
+
+
+CLIP with Generative Latent Replay: a Strong Baseline for Incremental Learning (BMVC 2024) paper
+
+@inproceedings{heng2022enhancing,
+ title={CLIP with Generative Latent Replay: a Strong Baseline for Incremental Learning},
+ author={Frascaroli, Emanuele and Panariello, Aniello and Buzzega, Pietro and Bonicelli, Lorenzo and Porrello, Angelo and Calderara, Simone},
+ booktitle={35th British Machine Vision Conference},
+ year={2024}
+}
+
+
+
+
### Other Awesome CL works using Mammoth
**_Get in touch if we missed your awesome work!_**
-- Prediction Error-based Classification for Class-Incremental Learning (**ICLR2024**) [[paper](https://arxiv.org/pdf/2305.18806)] [[code](https://github.com/michalzajac-ml/pec)]
-- TriRE: A Multi-Mechanism Learning Paradigm for Continual Knowledge Retention and Promotion (**NeurIPS2023**) [[paper](https://arxiv.org/pdf/2310.08217.pdf)] [[code](https://github.com/NeurAI-Lab/TriRE)]
-- Overcoming Recency Bias of Normalization Statistics in Continual Learning: Balance and Adaptation (**NeurIPS2023**) [[paper](https://arxiv.org/pdf/2310.08855.pdf)] [[code](https://github.com/lvyilin/AdaB2N)]
-- A Unified and General Framework for Continual Learning (**ICLR2024**) [[paper](https://arxiv.org/pdf/2403.13249.pdf)] [[code](https://github.com/joey-wang123/CL-refresh-learning)]
-- Decoupling Learning and Remembering: a Bilevel Memory Framework with Knowledge Projection for Task-Incremental Learning (**CVPR2023**) [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Decoupling_Learning_and_Remembering_A_Bilevel_Memory_Framework_With_Knowledge_CVPR_2023_paper.pdf)] [[code](https://github.com/SunWenJu123/BMKP)]
-- Regularizing Second-Order Influences for Continual Learning (**CVPR2023**) [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Regularizing_Second-Order_Influences_for_Continual_Learning_CVPR_2023_paper.pdf)] [[code](https://github.com/feifeiobama/InfluenceCL)]
-- Sparse Coding in a Dual Memory System for Lifelong Learning (**CVPR2023**) [[paper](https://arxiv.org/pdf/2301.05058.pdf)] [[code](https://github.com/NeurAI-Lab/SCoMMER)]
-- A Unified Approach to Domain Incremental Learning with Memory: Theory and Algorithm (**CVPR2023**) [[paper](https://arxiv.org/pdf/2310.12244.pdf)] [[code](https://github.com/Wang-ML-Lab/unified-continual-learning)]
-- A Multi-Head Model for Continual Learning via Out-of-Distribution Replay (**CVPR2023**) [[paper](https://arxiv.org/pdf/2208.09734.pdf)] [[code](https://github.com/k-gyuhak/MORE)]
-- Preserving Linear Separability in Continual Learning by Backward Feature Projection (**CVPR2023**) [[paper](https://arxiv.org/pdf/2303.14595.pdf)] [[code](https://github.com/rvl-lab-utoronto/BFP)]
-- Complementary Calibration: Boosting General Continual Learning With Collaborative Distillation and Self-Supervision (**TIP2023**) [[paper](https://ieeexplore.ieee.org/document/10002397)] [[code](https://github.com/lijincm/CoCa)]
-- Continual Learning by Modeling Intra-Class Variation (**TMLR2023**) [[paper](https://arxiv.org/abs/2210.05398)] [[code](https://github.com/yulonghui/MOCA)]
-- ConSlide: Asynchronous Hierarchical Interaction Transformer with Breakup-Reorganize Rehearsal for Continual Whole Slide Image Analysis (**ICCV2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Huang_ConSlide_Asynchronous_Hierarchical_Interaction_Transformer_with_Breakup-Reorganize_Rehearsal_for_Continual_ICCV_2023_paper.pdf)] [[code](https://github.com/HKU-MedAI/ConSlide)]
-- CBA: Improving Online Continual Learning via Continual Bias Adaptor (**ICCV2023**) [[paper](https://arxiv.org/pdf/2308.06925.pdf)] [[code](https://github.com/wqza/CBA-online-CL)]
-- Neuro-Symbolic Continual Learning: Knowledge, Reasoning Shortcuts and Concept Rehearsal (**ICML2023**) [[paper](https://arxiv.org/pdf/2302.01242.pdf)] [[code](https://github.com/ema-marconato/NeSy-CL)]
-- Learnability and Algorithm for Continual Learning (**ICML2023**) [[paper](https://arxiv.org/pdf/2306.12646.pdf)] [[code](https://github.com/k-gyuhak/CLOOD)]
-- Pretrained Language Model in Continual Learning: a Comparative Study (**ICLR2022**) [[paper](https://openreview.net/pdf?id=figzpGMrdD)] [[code](https://github.com/wutong8023/PLM4CL)]
-- Representational continuity for unsupervised continual learning (**ICLR2022**) [[paper](https://openreview.net/pdf?id=9Hrka5PA7LW)] [[code](https://github.com/divyam3897/UCL)]
-- Continual Normalization: Rethinking Batch Normalization for Online Continual Learning (**ICLR2022**) [[paper](https://arxiv.org/abs/2203.16102)] [[code](https://github.com/phquang/Continual-Normalization)]
-- Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System (**ICLR2022**) [[paper](https://arxiv.org/pdf/2201.12604.pdf)] [[code](https://github.com/NeurAI-Lab/CLS-ER)]
-- New Insights on Reducing Abrupt Representation Change in Online Continual Learning (**ICLR2022**) [[paper](https://openreview.net/pdf?id=N8MaByOzUfb)] [[code](https://github.com/pclucas14/AML)]
-- Looking Back on Learned Experiences for Class/Task Incremental Learning (**ICLR2022**) [[paper](https://openreview.net/pdf?id=RxplU3vmBx)] [[code](https://github.com/MozhganPourKeshavarz/Cost-Free-Incremental-Learning)]
-- Task Agnostic Representation Consolidation: a Self-supervised based Continual Learning Approach (**CoLLAs2022**) [[paper](https://arxiv.org/pdf/2207.06267.pdf)] [[code](https://github.com/NeurAI-Lab/TARC)]
-- Consistency is the key to further Mitigating Catastrophic Forgetting in Continual Learning (**CoLLAs2022**) [[paper](https://arxiv.org/pdf/2207.04998.pdf)] [[code](https://github.com/NeurAI-Lab/ConsistencyCL)]
-- Self-supervised models are continual learners (**CVPR2022**) [[paper](https://arxiv.org/abs/2112.04215)] [[code](https://github.com/DonkeyShot21/cassle)]
-- Learning from Students: Online Contrastive Distillation Network for General Continual Learning (**IJCAI2022**) [[paper](https://www.ijcai.org/proceedings/2022/0446.pdf)] [[code](https://github.com/lijincm/OCD-Net)]
+- Gradual Divergence for Seamless Adaptation: A Novel Domain Incremental Learning Method (**ICML 2024**) [[paper](https://arxiv.org/abs/2305.04769)] [[code](https://github.com/NeurAI-Lab/DARE)]
+- AGILE - Mitigating Interference in Incremental Learning through Attention-Guided Rehearsal (**CoLLAs 2024**) [[paper](https://arxiv.org/abs/2405.13978)] [[code](https://github.com/NeurAI-Lab/AGILE)]
+- Interactive Continual Learning (ICL) (**CVPR 2024**) [[paper](https://arxiv.org/abs/2403.02628)] [[code](https://github.com/Biqing-Qi/Interactive-continual-Learning-Fast-and-Slow-Thinking)]
+- Prediction Error-based Classification for Class-Incremental Learning (**ICLR 2024**) [[paper](https://arxiv.org/abs/2305.18806)] [[code](https://github.com/michalzajac-ml/pec)]
+- TriRE: A Multi-Mechanism Learning Paradigm for Continual Knowledge Retention and Promotion (**NeurIPS 2023**) [[paper](https://arxiv.org/abs/2310.08217)] [[code](https://github.com/NeurAI-Lab/TriRE)]
+- Overcoming Recency Bias of Normalization Statistics in Continual Learning: Balance and Adaptation (**NeurIPS 2023**) [[paper](https://arxiv.org/abs/2310.08855)] [[code](https://github.com/lvyilin/AdaB2N)]
+- A Unified and General Framework for Continual Learning (**ICLR 2024**) [[paper](https://arxiv.org/abs/2403.13249)] [[code](https://github.com/joey-wang123/CL-refresh-learning)]
+- Decoupling Learning and Remembering: a Bilevel Memory Framework with Knowledge Projection for Task-Incremental Learning (**CVPR 2023**) [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Decoupling_Learning_and_Remembering_A_Bilevel_Memory_Framework_With_Knowledge_CVPR_2023_paper.pdf)] [[code](https://github.com/SunWenJu123/BMKP)]
+- Regularizing Second-Order Influences for Continual Learning (**CVPR 2023**) [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Regularizing_Second-Order_Influences_for_Continual_Learning_CVPR_2023_paper.pdf)] [[code](https://github.com/feifeiobama/InfluenceCL)]
+- Sparse Coding in a Dual Memory System for Lifelong Learning (**CVPR 2023**) [[paper](https://arxiv.org/abs/2301.05058)] [[code](https://github.com/NeurAI-Lab/SCoMMER)]
+- A Unified Approach to Domain Incremental Learning with Memory: Theory and Algorithm (**CVPR 2023**) [[paper](https://arxiv.org/abs/2310.12244)] [[code](https://github.com/Wang-ML-Lab/unified-continual-learning)]
+- A Multi-Head Model for Continual Learning via Out-of-Distribution Replay (**CVPR 2023**) [[paper](https://arxiv.org/abs/2208.09734)] [[code](https://github.com/k-gyuhak/MORE)]
+- Preserving Linear Separability in Continual Learning by Backward Feature Projection (**CVPR 2023**) [[paper](https://arxiv.org/abs/2303.14595)] [[code](https://github.com/rvl-lab-utoronto/BFP)]
+- Complementary Calibration: Boosting General Continual Learning With Collaborative Distillation and Self-Supervision (**TIP 2023**) [[paper](https://ieeexplore.ieee.org/document/10002397)] [[code](https://github.com/lijincm/CoCa)]
+- Continual Learning by Modeling Intra-Class Variation (**TMLR 2023**) [[paper](https://arxiv.org/abs/2210.05398)] [[code](https://github.com/yulonghui/MOCA)]
+- ConSlide: Asynchronous Hierarchical Interaction Transformer with Breakup-Reorganize Rehearsal for Continual Whole Slide Image Analysis (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Huang_ConSlide_Asynchronous_Hierarchical_Interaction_Transformer_with_Breakup-Reorganize_Rehearsal_for_Continual_ICCV_2023_paper.pdf)] [[code](https://github.com/HKU-MedAI/ConSlide)]
+- CBA: Improving Online Continual Learning via Continual Bias Adaptor (**ICCV 2023**) [[paper](https://arxiv.org/abs/2308.06925)] [[code](https://github.com/wqza/CBA-online-CL)]
+- Neuro-Symbolic Continual Learning: Knowledge, Reasoning Shortcuts and Concept Rehearsal (**ICML 2023**) [[paper](https://arxiv.org/abs/2302.01242)] [[code](https://github.com/ema-marconato/NeSy-CL)]
+- Learnability and Algorithm for Continual Learning (**ICML 2023**) [[paper](https://arxiv.org/abs/2306.12646)] [[code](https://github.com/k-gyuhak/CLOOD)]
+- Pretrained Language Model in Continual Learning: a Comparative Study (**ICLR 2022**) [[paper](https://openreview.net/pdf?id=figzpGMrdD)] [[code](https://github.com/wutong8023/PLM4CL)]
+- Representational continuity for unsupervised continual learning (**ICLR 2022**) [[paper](https://openreview.net/pdf?id=9Hrka5PA7LW)] [[code](https://github.com/divyam3897/UCL)]
+- Continual Normalization: Rethinking Batch Normalization for Online Continual Learning (**ICLR 2022**) [[paper](https://arxiv.org/abs/2203.16102)] [[code](https://github.com/phquang/Continual-Normalization)]
+- Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System (**ICLR 2022**) [[paper](https://arxiv.org/abs/2201.12604)] [[code](https://github.com/NeurAI-Lab/CLS-ER)]
+- New Insights on Reducing Abrupt Representation Change in Online Continual Learning (**ICLR 2022**) [[paper](https://openreview.net/pdf?id=N8MaByOzUfb)] [[code](https://github.com/pclucas14/AML)]
+- Looking Back on Learned Experiences for Class/Task Incremental Learning (**ICLR 2022**) [[paper](https://openreview.net/pdf?id=RxplU3vmBx)] [[code](https://github.com/MozhganPourKeshavarz/Cost-Free-Incremental-Learning)]
+- Task Agnostic Representation Consolidation: a Self-supervised based Continual Learning Approach (**CoLLAs 2022**) [[paper](https://arxiv.org/abs/2207.06267)] [[code](https://github.com/NeurAI-Lab/TARC)]
+- Consistency is the key to further Mitigating Catastrophic Forgetting in Continual Learning (**CoLLAs 2022**) [[paper](https://arxiv.org/abs/2207.04998)] [[code](https://github.com/NeurAI-Lab/ConsistencyCL)]
+- Self-supervised models are continual learners (**CVPR 2022**) [[paper](https://arxiv.org/abs/2112.04215)] [[code](https://github.com/DonkeyShot21/cassle)]
+- Learning from Students: Online Contrastive Distillation Network for General Continual Learning (**IJCAI 2022**) [[paper](https://www.ijcai.org/proceedings/2022/0446)] [[code](https://github.com/lijincm/OCD-Net)]
### Contributing
diff --git a/backbone/MNISTMLP.py b/backbone/MNISTMLP.py
index d39bfb3b..a9ac9da3 100644
--- a/backbone/MNISTMLP.py
+++ b/backbone/MNISTMLP.py
@@ -72,8 +72,3 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
return (out, feats)
raise NotImplementedError("Unknown return type")
-
- def to(self, device):
- super().to(device)
- self.device = device
- return self
diff --git a/backbone/ResNetBlock.py b/backbone/ResNetBlock.py
index 56d8f65c..1da7c010 100644
--- a/backbone/ResNetBlock.py
+++ b/backbone/ResNetBlock.py
@@ -112,10 +112,6 @@ def __init__(self, block: BasicBlock, num_blocks: List[int],
self.feature_dim = nf * 8 * block.expansion
- def to(self, device, **kwargs):
- self.device = device
- return super().to(device, **kwargs)
-
def set_return_prerelu(self, enable=True):
self.return_prerelu = enable
for c in self.modules():
diff --git a/backbone/__init__.py b/backbone/__init__.py
index 09467286..3cdd4ec0 100644
--- a/backbone/__init__.py
+++ b/backbone/__init__.py
@@ -65,6 +65,12 @@ class MammothBackbone(nn.Module):
def __init__(self, **kwargs) -> None:
super(MammothBackbone, self).__init__()
+ self.device = torch.device('cpu') if 'device' not in kwargs else kwargs['device']
+
+ def to(self, device, *args, **kwargs):
+ super(MammothBackbone, self).to(device, *args, **kwargs)
+ self.device = device
+ return self
def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
"""
diff --git a/backbone/utils/layers.py b/backbone/utils/layers.py
index 55547581..9e7bd998 100644
--- a/backbone/utils/layers.py
+++ b/backbone/utils/layers.py
@@ -6,61 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
-
-class LoRALayer():
- def __init__(
- self,
- lora_dropout: float,
- ):
- # Optional dropout
- if lora_dropout > 0.:
- self.lora_dropout = nn.Dropout(p=lora_dropout)
- else:
- self.lora_dropout = lambda x: x
-
-
-class LoRALinear(nn.Linear, LoRALayer):
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False,
- **kwargs
- ):
- nn.Linear.__init__(self, in_features, out_features, **kwargs)
- LoRALayer.__init__(self, lora_dropout=lora_dropout)
-
- self.fan_in_fan_out = fan_in_fan_out
- self.weight.requires_grad = False
- self.reset_parameters()
-
- if fan_in_fan_out:
- self.weight.data = self.weight.data.transpose(0, 1)
-
- def reset_parameters(self):
- nn.Linear.reset_parameters(self)
-
- def forward(self, x: torch.Tensor, AB: dict = None):
-
- def T(w):
- return w.transpose(1, 2) if self.fan_in_fan_out else w
-
- result = F.linear(x, T(self.weight), bias=self.bias)
-
- if AB is not None:
- A = None
- if isinstance(AB, dict):
- B = AB['B']
- A = AB.get('A')
- else:
- B = AB
- if A is not None:
- return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
- return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)
-
- return result
+from backbone.utils.lora_utils import LoRALayer
class ClipLinear(nn.Linear, LoRALayer):
diff --git a/backbone/utils/lora_utils.py b/backbone/utils/lora_utils.py
new file mode 100644
index 00000000..18e06dc5
--- /dev/null
+++ b/backbone/utils/lora_utils.py
@@ -0,0 +1,208 @@
+import collections.abc
+from itertools import repeat
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_2tuple = _ntuple(2)
+
+
+class LoRALayer():
+ def __init__(
+ self,
+ lora_dropout: float,
+ ):
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+
+
+class LoRALinear(nn.Linear, LoRALayer):
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, lora_dropout=lora_dropout)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ self.weight.requires_grad = False
+ self.reset_parameters()
+
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+
+ def forward(self, x: torch.Tensor, AB: dict = None):
+
+ def T(w):
+ return w.transpose(1, 2) if self.fan_in_fan_out else w
+
+ result = F.linear(x, T(self.weight), bias=self.bias)
+
+ if AB is not None:
+ A = None
+ if isinstance(AB, dict):
+ B = AB['B']
+ A = AB.get('A')
+ else:
+ B = AB
+ if A is not None:
+ return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
+ return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)
+
+ return result
+
+
+class LoRAAttention(nn.Module):
+ """
+ Attention layer as used in Vision Transformer.
+ Adapted to support LoRA-style parameters.
+
+ Args:
+ dim: Number of input channels
+ num_heads: Number of attention heads
+ qkv_bias: If True, add a learnable bias to q, k, v
+ attn_drop: Dropout rate for attention weights
+ proj_drop: Dropout rate after the final projection
+ """
+
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = LoRALinear(dim, dim * 3, 0., bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = LoRALinear(dim, dim, 0.)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, AB: dict = None, **kwargs):
+ """
+ Forward pass of the attention layer.
+ Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).
+
+ Args:
+ x: Input tensor
+ AB: Dictionary containing LoRA-style parameters for the layer
+ """
+
+ B, N, C = x.shape
+
+ AB_qkv = None
+
+ if AB is not None:
+ AB_qkv = AB.get("qkv")
+
+ qkv = self.qkv(x, AB_qkv)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ # NOTE: flash attention is less debuggable than the original. Use the commented code below if in trouble.
+ if torch.__version__ >= '2.1.0':
+ x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
+ else:
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+
+ AB_proj = None
+
+ if AB is not None:
+ AB_proj = AB.get("proj")
+
+ x = self.proj(x, AB_proj)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class LoRAMlp(nn.Module):
+ """
+ MLP as used in Vision Transformer, MLP-Mixer and related networks.
+ Adapted to support LoRA-style parameters.
+ """
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ assert use_conv is False
+
+ self.fc1 = LoRALinear(in_features, hidden_features, bias=bias[0], lora_dropout=0.)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = LoRALinear(hidden_features, out_features, bias=bias[1], lora_dropout=0.)
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x: torch.Tensor, AB: dict = None, **kwargs):
+ """
+ Forward pass of the MLP layer.
+ Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).
+
+ Args:
+ x: Input tensor
+ AB: Dictionary containing LoRA-style parameters for the layer
+ """
+ AB_fc1 = None
+ AB_fc2 = None
+
+ if AB is not None:
+ AB_fc1 = AB.get("fc1")
+ AB_fc2 = AB.get("fc2")
+
+ x = self.fc1(x, AB_fc1)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x, AB_fc2)
+ x = self.drop2(x)
+
+ return x
diff --git a/backbone/vit.py b/backbone/vit.py
index 9e53fde9..87e48c4e 100644
--- a/backbone/vit.py
+++ b/backbone/vit.py
@@ -49,7 +49,6 @@
import logging
import math
-from collections import OrderedDict
from functools import partial
import torch
@@ -62,33 +61,19 @@
from timm.models._builder import build_model_with_cfg
from timm.models._manipulate import named_apply
-from backbone.utils.layers import LoRALinear, IncrementalClassifier
-
-from itertools import repeat
-import collections.abc
-
+from backbone.utils.layers import IncrementalClassifier
from backbone import MammothBackbone
+from backbone.utils.lora_utils import LoRAAttention, LoRAMlp
+from utils.conf import warn_once
__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
_logger = logging.getLogger(__name__)
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return tuple(x)
- return tuple(repeat(x, n))
- return parse
-
-
-to_2tuple = _ntuple(2)
-
-
class Attention(nn.Module):
"""
Attention layer as used in Vision Transformer.
- Adapted to support LoRA-style parameters.
Args:
dim: Number of input channels
@@ -105,47 +90,39 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.)
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
- self.qkv = LoRALinear(dim, dim * 3, 0., bias=qkv_bias)
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
- self.proj = LoRALinear(dim, dim, 0.)
+ self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
- def forward(self, x, AB: dict = None, **kwargs):
+ def forward(self, x, **kwargs):
"""
Forward pass of the attention layer.
- Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).
Args:
x: Input tensor
- AB: Dictionary containing LoRA-style parameters for the layer
"""
B, N, C = x.shape
- AB_qkv = None
-
- if AB is not None:
- AB_qkv = AB.get("qkv")
-
- qkv = self.qkv(x, AB_qkv)
+ qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# NOTE: flash attention is less debuggable than the original. Use the commented code below if in trouble.
- x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
- # attn = (q @ k.transpose(-2, -1)) * self.scale
- # attn = attn.softmax(dim=-1)
- # attn = self.attn_drop(attn)
- # x = (attn @ v)
+ # check torch version
+ if torch.__version__ >= '2.1.0':
+ x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
+ else:
+ warn_once("Torch verison < 2.1.0 detected. Using the original attention code.")
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v)
x = x.transpose(1, 2).reshape(B, N, C)
- AB_proj = None
-
- if AB is not None:
- AB_proj = AB.get("proj")
-
- x = self.proj(x, AB_proj)
+ x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -161,64 +138,6 @@ def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
-class Mlp(nn.Module):
- """
- MLP as used in Vision Transformer, MLP-Mixer and related networks.
- Adapted to support LoRA-style parameters.
- """
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- norm_layer=None,
- bias=True,
- drop=0.,
- use_conv=False,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
-
- assert use_conv is False
-
- self.fc1 = LoRALinear(in_features, hidden_features, bias=bias[0], lora_dropout=0.)
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
- self.fc2 = LoRALinear(hidden_features, out_features, bias=bias[1], lora_dropout=0.)
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x: torch.Tensor, AB: dict = None, **kwargs):
- """
- Forward pass of the MLP layer.
- Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).
-
- Args:
- x: Input tensor
- AB: Dictionary containing LoRA-style parameters for the layer
- """
- AB_fc1 = None
- AB_fc2 = None
-
- if AB is not None:
- AB_fc1 = AB.get("fc1")
- AB_fc2 = AB.get("fc2")
-
- x = self.fc1(x, AB_fc1)
- x = self.act(x)
- x = self.drop1(x)
- x = self.norm(x)
- x = self.fc2(x, AB_fc2)
- x = self.drop2(x)
-
- return x
-
-
class Block(nn.Module):
def __init__(
@@ -234,6 +153,7 @@ def __init__(
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_layer=Attention,
+ mlp_layer=Mlp
):
super().__init__()
self.norm1 = norm_layer(dim)
@@ -243,7 +163,7 @@ def __init__(
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
- self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+ self.mlp = mlp_layer(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -254,7 +174,8 @@ def forward(self, x, **kwargs):
class VisionTransformer(MammothBackbone):
- """ Vision Transformer
+ """ Vision Transformer.
+ This implementation supports LoRA (Layer-wise Relevance Adaptation) parameters if `use_lora=True`.
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
@@ -285,7 +206,9 @@ def __init__(
norm_layer=None,
act_layer=None,
block_fn=Block,
- attn_layer=Attention,
+ attn_layer=None,
+ mlp_layer=None,
+ use_lora=False,
args=None
):
"""
@@ -321,6 +244,8 @@ def __init__(
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.act_layer = act_layer or nn.GELU
+ attn_layer = attn_layer if attn_layer is not None else (Attention if not use_lora else LoRAAttention)
+ mlp_layer = mlp_layer if mlp_layer is not None else (Mlp if not use_lora else LoRAMlp)
self.attn_layer = attn_layer
self.norm_layer = norm_layer
self.num_heads = num_heads
@@ -366,7 +291,8 @@ def __init__(
drop_path=self.dpr[i],
norm_layer=norm_layer,
act_layer=self.act_layer,
- attn_layer=attn_layer
+ attn_layer=attn_layer,
+ mlp_layer=mlp_layer
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -743,7 +669,7 @@ def vit_base_patch16_224_prompt_prototype(pretrained=False, pretrain_type='in21k
"""
assert pretrain_type in ['in21k', 'in21k_old', 'in21k-ft-in1k'], f"Invalid pretrain_type: {pretrain_type}"
if not pretrained:
- print("WARNING: creating a ViT without pre-trained weights. This is not recommended.")
+ logging.warning("creating a ViT without pre-trained weights. This is not recommended.")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
if kwargs is None:
diff --git a/datasets/perm_mnist.py b/datasets/perm_mnist.py
index 2187d5ba..702f378b 100644
--- a/datasets/perm_mnist.py
+++ b/datasets/perm_mnist.py
@@ -14,7 +14,7 @@
from backbone.MNISTMLP import MNISTMLP
from datasets.transforms.permutation import Permutation
-from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
from utils.conf import base_path
from datasets.utils import set_default_from_args
@@ -112,3 +112,12 @@ def get_batch_size(self) -> int:
@set_default_from_args('n_epochs')
def get_epochs(self):
return 1
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = MNIST(base_path() + 'MNIST', train=True, download=True).classes
+ classes = [c.split('-')[1].strip() for c in classes]
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/rot_mnist.py b/datasets/rot_mnist.py
index a346a3bc..1b8c43ec 100644
--- a/datasets/rot_mnist.py
+++ b/datasets/rot_mnist.py
@@ -9,9 +9,10 @@
from backbone.MNISTMLP import MNISTMLP
from datasets.perm_mnist import MyMNIST, MNIST
from datasets.transforms.rotation import Rotation
-from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
from utils.conf import base_path
from datasets.utils import set_default_from_args
+from torchvision.datasets import MNIST
class RotatedMNIST(ContinualDataset):
@@ -72,3 +73,12 @@ def get_batch_size(self) -> int:
@set_default_from_args('n_epochs')
def get_epochs(self):
return 1
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = MNIST(base_path() + 'MNIST', train=True, download=True).classes
+ classes = [c.split('-')[1].strip() for c in classes]
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_cars196.py b/datasets/seq_cars196.py
new file mode 100644
index 00000000..66b2daae
--- /dev/null
+++ b/datasets/seq_cars196.py
@@ -0,0 +1,228 @@
+import os
+import sys
+import torch
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from PIL import Image
+from typing import Tuple
+from tqdm import tqdm
+import json
+try:
+ import deeplake
+except ImportError:
+ raise NotImplementedError("Deeplake not installed. Please install with `pip install deeplake` to use this dataset.")
+
+from datasets.utils import set_default_from_args
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from utils.conf import base_path
+from torch.utils.data import Dataset
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+def load_and_preprocess_cars196(train_str='train', names_only=False) -> Tuple[torch.Tensor, torch.Tensor, dict] | dict:
+ """
+ Loads data from deeplake and preprocesses it to be stored locally.
+
+ Args:
+ train_str (str): 'train' or 'test'.
+ names_only (bool): If True, returns the class names only.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, dict] | dict: If names_only is False, returns a tuple of data, targets, and class_idx_to_name
+ """
+ assert train_str in ['train', 'test'], "train_str must be 'train' or 'test'"
+ ds = deeplake.load(f"hub://activeloop/stanford-cars-{train_str}")
+ loader = ds.pytorch()
+ class_names = ds['car_models'].info['class_names']
+ class_idx_to_name = {i: class_names[i] for i in range(len(class_names))}
+ if names_only:
+ return class_idx_to_name
+
+ # Pre-process dataset
+ data = []
+ targets = []
+ for x in tqdm(loader, desc=f'Pre-processing {train_str} dataset'):
+ img = x['images'][0].permute(2, 0, 1) # load one image at a time
+ if len(img) < 3:
+ img = img.repeat(3, 1, 1) # fix rgb
+ img = MyCars196.PREPROCESSING_TRANSFORM(img) # resize
+ data.append(img)
+ label = x['car_models'][0].item() # get label
+ targets.append(label)
+
+ data = torch.stack(data) # stack all images
+ targets = torch.tensor(targets)
+
+ return data, targets, class_idx_to_name
+
+
+class MyCars196(Dataset):
+ N_CLASSES = 196
+
+ """
+ Overrides the CIFAR100 dataset to change the getitem function.
+ """
+
+ PREPROCESSING_TRANSFORM = transforms.Compose([
+ transforms.Resize(224, interpolation=InterpolationMode.BICUBIC, antialias=True),
+ transforms.CenterCrop(224),
+ ])
+
+ def __init__(self, root, train=True, transform=None,
+ target_transform=None) -> None:
+
+ self.root = root
+ self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
+ self.not_aug_transform = transforms.ToTensor()
+
+ train_str = 'train' if train else 'test'
+ if not os.path.exists(f'{root}/{train_str}_images.pt'):
+ print(f'Preparing {train_str} dataset...', file=sys.stderr)
+ self.load_and_preprocess_dataset(root, train_str)
+ else:
+ print(f"Loading pre-processed {train_str} dataset...", file=sys.stderr)
+ self.data = torch.load(f'{root}/{train_str}_images.pt')
+ self.targets = torch.load(f'{root}/{train_str}_labels.pt')
+
+ self.class_names = MyCars196.get_class_names()
+
+ def load_and_preprocess_dataset(self, root, train_str='train'):
+ self.data, self.targets, class_idx_to_name = load_and_preprocess_cars196(train_str)
+
+ print(f"Saving pre-processed dataset in {root} ({train_str}_images.pt and {train_str}_labels.py)...", file=sys.stderr)
+ if not os.path.exists(root):
+ os.makedirs(root)
+ torch.save(self.data, f'{root}/{train_str}_images.pt')
+ torch.save(self.targets, f'{root}/{train_str}_labels.pt')
+
+ with open(f'{root}/class_names.json', 'wt') as f:
+ json.dump(class_idx_to_name, f, indent=4)
+ print('Done', file=sys.stderr)
+
+ @staticmethod
+ def get_class_names():
+ if not os.path.exists(base_path() + f'cars196/class_names.json'):
+ print("Class names not found, performing pre-processing...")
+ class_idx_to_name = load_and_preprocess_cars196(names_only=True)
+ print('Done', file=sys.stderr)
+ else:
+ with open(base_path() + f'cars196/class_names.json', 'rt') as f:
+ class_idx_to_name = json.load(f)
+ class_names = list(class_idx_to_name.values())
+ return class_names
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+
+ Args:
+ index: index of the element to be returned
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ img = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
+
+ not_aug_img = self.not_aug_transform(img.copy())
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+class SequentialCars196(ContinualDataset):
+ """
+ Sequential CARS196 Dataset. The images are loaded from deeplake, resized to 224x224, and store locally.
+ """
+
+ NAME = 'seq-cars196'
+ SETTING = 'class-il'
+ N_TASKS = 10
+ N_CLASSES = 196
+ N_CLASSES_PER_TASK = [20] * 9 + [16]
+ MEAN, STD = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
+ SIZE = (224, 224)
+
+ TRANSFORM = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+ TEST_TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD)]) # no transform for test
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+
+ def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
+ train_dataset = MyCars196(base_path() + 'cars196', train=True,
+ transform=self.TRANSFORM)
+ test_dataset = MyCars196(base_path() + 'cars196', train=False,
+ transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cars196']
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = MyCars196.get_class_names()
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose(
+ [transforms.ToPILImage(), SequentialCars196.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=sum(SequentialCars196.N_CLASSES_PER_TASK))
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialCars196.MEAN, std=SequentialCars196.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(SequentialCars196.MEAN, SequentialCars196.STD)
+ return transform
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 50
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
diff --git a/datasets/seq_chestx.py b/datasets/seq_chestx.py
new file mode 100644
index 00000000..c7297197
--- /dev/null
+++ b/datasets/seq_chestx.py
@@ -0,0 +1,178 @@
+import os
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+import pickle
+from PIL import Image
+from typing import Tuple
+
+from datasets.utils import set_default_from_args
+from utils import smart_joint
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+class ChestX(Dataset):
+ N_CLASSES = 6
+
+ """
+ To reduce the effect of the severe imbalance in the dataset, we drop the two classes with the smallest and largest amount of samples.
+ """
+ LABELS = [
+ "Cardiomegaly",
+ "Consolidation",
+ "Edema",
+ "Fibrosis",
+ "Pleural Thickening",
+ "Pneumothorax"
+ ]
+
+ """
+ Overrides the ChestX dataset to change the getitem function.
+ """
+
+ def __init__(self, root, train=True, transform=None,
+ target_transform=None, download=False) -> None:
+
+ self.root = root
+ self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if not os.path.exists(f'{root}/train_images.pkl'):
+ if download:
+ from onedrivedownloader import download
+
+ print('Downloading dataset')
+ ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EfmFCiLaGlpFgtAuv0YLpeYBeR54I7YHK75bu_Ex78mADA?e=K8rHpZ"
+ download(ln, filename=smart_joint(root, 'chestx.zip'), unzip=True, unzip_path=root.rstrip('chestx'), clean=True)
+ else:
+ raise FileNotFoundError(f'File not found: {root}/train_images.pkl')
+
+ if train:
+ filename_labels = f'{self.root}/train_labels.pkl'
+ filename_images = f'{self.root}/train_images.pkl'
+ else:
+ filename_labels = f'{self.root}/test_labels.pkl'
+ filename_images = f'{self.root}/test_images.pkl'
+
+ self.not_aug_transform = transforms.ToTensor()
+
+ with open(filename_images, 'rb') as f:
+ self.data = pickle.load(f)
+
+ with open(filename_labels, 'rb') as f:
+ self.targets = pickle.load(f)
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+ :param index: index of the element to be returned
+ :returns: tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+ img = np.repeat(img[np.newaxis, :, :], 3, axis=0)
+ img = Image.fromarray((img * 255).astype(np.int8).transpose(1, 2, 0), mode='RGB')
+
+ original_img = img.copy()
+
+ not_aug_img = self.not_aug_transform(original_img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+class SequentialChestX(ContinualDataset):
+
+ NAME = 'seq-chestx'
+ SETTING = 'class-il'
+ N_TASKS = 2
+ N_CLASSES = 6
+ N_CLASSES_PER_TASK = 3
+ SIZE = (224, 224)
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ normalize = transforms.Normalize(mean=MEAN,
+ std=STD)
+
+ TRANSFORM = transforms.Compose([
+ transforms.Resize(size=SIZE, interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ normalize,
+ ])
+
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(size=SIZE, interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ normalize,
+ ])
+
+ def get_data_loaders(self):
+ train_dataset = ChestX(base_path() + 'chestx', train=True,
+ download=True, transform=self.TRANSFORM)
+
+ test_dataset = ChestX(base_path() + 'chestx', train=False, download=True,
+ transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = fix_class_names_order(ChestX.LABELS, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
+
+ @staticmethod
+ def get_transform():
+ return transforms.Compose([transforms.ToPILImage(),
+ SequentialChestX.TRANSFORM])
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialChestX.N_CLASSES)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialChestX.MEAN, std=SequentialChestX.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(mean=SequentialChestX.MEAN, std=SequentialChestX.STD)
+ return transform
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 30
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
diff --git a/datasets/seq_cifar10.py b/datasets/seq_cifar10.py
index 5705cd92..e6253c71 100644
--- a/datasets/seq_cifar10.py
+++ b/datasets/seq_cifar10.py
@@ -12,9 +12,9 @@
from torchvision.datasets import CIFAR10
from backbone.ResNetBlock import resnet18
-from datasets.seq_tinyimagenet import base_path
+from utils.conf import base_path
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from datasets.utils import set_default_from_args
@@ -143,3 +143,11 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
return 32
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR10(base_path() + 'CIFAR10', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_cifar100.py b/datasets/seq_cifar100.py
index 451f9838..695398e1 100644
--- a/datasets/seq_cifar100.py
+++ b/datasets/seq_cifar100.py
@@ -14,7 +14,7 @@
from backbone.ResNetBlock import resnet18
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from utils.conf import base_path
from datasets.utils import set_default_from_args
@@ -158,3 +158,11 @@ def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_s
model.opt = model.get_optimizer()
scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1, verbose=False)
return scheduler
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR100(base_path() + 'CIFAR100', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_cifar100_224.py b/datasets/seq_cifar100_224.py
index b47758ac..8ef9d4b5 100644
--- a/datasets/seq_cifar100_224.py
+++ b/datasets/seq_cifar100_224.py
@@ -5,14 +5,17 @@
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
+from torchvision.transforms.functional import InterpolationMode
+from torchvision.datasets import CIFAR100
from backbone.vit import vit_base_patch16_224_prompt_prototype
from datasets.seq_cifar100 import TCIFAR100, MyCIFAR100
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from utils.conf import base_path
from datasets.utils import set_default_from_args
+from utils.prompt_templates import templates
class SequentialCIFAR100224(ContinualDataset):
@@ -40,14 +43,16 @@ class SequentialCIFAR100224(ContinualDataset):
SIZE = (224, 224)
MEAN, STD = (0, 0, 0), (1, 1, 1) # Normalized in [0,1] as in L2P paper
TRANSFORM = transforms.Compose(
- [transforms.Resize(224),
- transforms.RandomCrop(224, padding=28),
- transforms.RandomHorizontalFlip(),
+ [transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)]
)
- TEST_TRANSFORM = transforms.Compose(
- [transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ])
def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
transform = self.TRANSFORM
@@ -70,7 +75,7 @@ def get_transform():
return transform
@staticmethod
- def get_backbone(hookme=False):
+ def get_backbone():
return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialCIFAR100224.N_CLASSES_PER_TASK * SequentialCIFAR100224.N_TASKS)
@staticmethod
@@ -89,8 +94,20 @@ def get_denormalization_transform():
@set_default_from_args('n_epochs')
def get_epochs(self):
- return 5
+ return 20
@set_default_from_args('batch_size')
def get_batch_size(self):
return 128
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR100(base_path() + 'CIFAR100', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
diff --git a/datasets/seq_cifar100_224_rs.py b/datasets/seq_cifar100_224_rs.py
index 366ba966..213abfdf 100644
--- a/datasets/seq_cifar100_224_rs.py
+++ b/datasets/seq_cifar100_224_rs.py
@@ -3,11 +3,12 @@
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
+from torchvision.datasets import CIFAR100
from backbone.ResNetBottleneck import resnet50
from datasets.seq_cifar100 import TCIFAR100, MyCIFAR100
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from utils.conf import base_path
from datasets.utils import set_default_from_args
@@ -93,3 +94,11 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
return 32
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR100(base_path() + 'CIFAR100', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_cifar10_224.py b/datasets/seq_cifar10_224.py
index 51748e8f..a93f8d66 100644
--- a/datasets/seq_cifar10_224.py
+++ b/datasets/seq_cifar10_224.py
@@ -8,14 +8,15 @@
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
+from torchvision.datasets import CIFAR10
from backbone.vit import vit_base_patch16_224_prompt_prototype
-from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10
-from datasets.seq_tinyimagenet import base_path
+from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10, base_path
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from datasets.utils import set_default_from_args
+from utils.prompt_templates import templates
class SequentialCIFAR10224(ContinualDataset):
@@ -93,3 +94,15 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
return 32
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR10(base_path() + 'CIFAR10', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
diff --git a/datasets/seq_cifar10_224_rs.py b/datasets/seq_cifar10_224_rs.py
index 77c2e817..39d86fa1 100644
--- a/datasets/seq_cifar10_224_rs.py
+++ b/datasets/seq_cifar10_224_rs.py
@@ -8,12 +8,12 @@
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
+from torchvision.datasets import CIFAR10
from backbone.ResNetBottleneck import resnet50
-from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10
-from datasets.seq_tinyimagenet import base_path
+from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10, base_path
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from datasets.utils import set_default_from_args
@@ -94,3 +94,11 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
return 32
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = CIFAR10(base_path() + 'CIFAR10', train=True, download=True).classes
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_cropdisease.py b/datasets/seq_cropdisease.py
new file mode 100644
index 00000000..c089fb54
--- /dev/null
+++ b/datasets/seq_cropdisease.py
@@ -0,0 +1,198 @@
+import json
+import os
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+from PIL import Image
+from typing import Tuple
+
+from datasets.utils import set_default_from_args
+from utils import smart_joint
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+class CropDisease(Dataset):
+
+ LABELS = [
+ "Apple___Apple_scab",
+ "Apple___Black_rot",
+ "Apple___healthy",
+ "Blueberry___healthy",
+ "Cherry___Powdery_mildew",
+ "Cherry___healthy",
+ "Corn___Cercospora_leaf_spot Gray_leaf_spot",
+ "Corn___Common_rust",
+ "Corn___Northern_Leaf_Blight",
+ "Corn___healthy",
+ "Grape___Black_rot",
+ "Grape___Esca_(Black_Measles)",
+ "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
+ "Grape___healthy",
+ "Orange___Haunglongbing_(Citrus_greening)",
+ "Peach___Bacterial_spot",
+ "Pepper,_bell___Bacterial_spot",
+ "Pepper,_bell___healthy",
+ "Potato___Early_blight",
+ "Potato___Late_blight",
+ "Raspberry___healthy",
+ "Soybean___healthy",
+ "Squash___Powdery_mildew",
+ "Strawberry___Leaf_scorch",
+ "Strawberry___healthy",
+ "Tomato___Bacterial_spot",
+ "Tomato___Early_blight",
+ "Tomato___Late_blight",
+ "Tomato___Leaf_Mold",
+ "Tomato___Septoria_leaf_spot",
+ "Tomato___Spider_mites Two-spotted_spider_mite",
+ "Tomato___Target_Spot",
+ "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
+ "Tomato___Tomato_mosaic_virus",
+ "Tomato___healthy",
+ ]
+
+ def __init__(self, root, train=True, transform=None,
+ target_transform=None, download=False) -> None:
+
+ self.root = root
+ self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.not_aug_transform = transforms.Compose([
+ transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor()]
+ )
+
+ if download:
+ if os.path.isdir(root) and len(os.listdir(root)) > 0:
+ print('Download not needed, files already on disk.')
+ else:
+ from onedrivedownloader import download
+ ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EZUaXKQUAVBPrhjHTUdflDEBNu0YiPWrdpAdDhnEU4nD2A?e=GPrCYF"
+ print('Downloading dataset')
+ parent_dir = os.path.dirname(root)
+ download(ln, filename=os.path.join(root, 'cropdisease.tar.gz'), unzip=True, unzip_path=parent_dir, clean=True)
+
+ filename = smart_joint(root, ('train' if train else 'test') + '.json')
+ with open(filename) as f:
+ data_config = json.load(f)
+
+ self.data = np.array([smart_joint(root, 'images', d) for d in data_config['data']])
+ self.targets = np.array(data_config['labels']).astype(np.int16)
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+ :param index: index of the element to be returned
+ :returns: tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ img = Image.open(img).convert('RGB')
+
+ original_img = img.copy()
+
+ not_aug_img = self.not_aug_transform(original_img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+class SequentialCropDisease(ContinualDataset):
+
+ NAME = 'seq-cropdisease'
+ SETTING = 'class-il'
+ N_TASKS = 7
+ N_CLASSES = 35
+ N_CLASSES_PER_TASK = N_CLASSES // N_TASKS
+ SIZE = (224, 224)
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+
+ TRANSFORM = transforms.Compose([
+ transforms.RandomResizedCrop(SIZE, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(size=SIZE, interpolation=InterpolationMode.BICUBIC),
+ transforms.CenterCrop(SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ def get_data_loaders(self):
+ train_dataset = CropDisease(base_path() + 'cropdisease', train=True,
+ download=True, transform=self.TRANSFORM)
+ test_dataset = CropDisease(base_path() + 'cropdisease', train=False,
+ download=True, transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = [x.replace('_', ' ') for x in CropDisease.LABELS] # .split('___')[-1]
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose(
+ [transforms.ToPILImage(), SequentialCropDisease.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ num_classes = SequentialCropDisease.N_CLASSES_PER_TASK * SequentialCropDisease.N_TASKS
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=num_classes)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialCropDisease.MEAN, std=SequentialCropDisease.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(SequentialCropDisease.MEAN, SequentialCropDisease.STD)
+ return transform
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 5
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
diff --git a/datasets/seq_cub200.py b/datasets/seq_cub200.py
index 07ba9ceb..ce4cb4e0 100644
--- a/datasets/seq_cub200.py
+++ b/datasets/seq_cub200.py
@@ -1,21 +1,19 @@
import os
-from typing import Tuple
-
import numpy as np
import torch
-import torch.nn.functional as F
import torchvision.transforms as transforms
+import torch.nn.functional as F
from PIL import Image
-from torch.utils.data.dataset import Dataset
-
+from typing import Tuple
-from backbone.ResNetBottleneck import resnet50
+from datasets.utils import set_default_from_args
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
- store_masked_loaders)
from utils import smart_joint
from utils.conf import base_path
-from datasets.utils import set_default_from_args
+from torch.utils.data import Dataset
+from torchvision.transforms.functional import InterpolationMode
+from backbone.vit import vit_base_patch16_224_prompt_prototype
class MyCUB200(Dataset):
@@ -24,12 +22,12 @@ class MyCUB200(Dataset):
"""
IMG_SIZE = 224
N_CLASSES = 200
- MEAN, STD = (0.4856, 0.4994, 0.4324), (0.2272, 0.2226, 0.2613)
- TEST_TRANSFORM = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
def __init__(self, root, train=True, transform=None,
target_transform=None, download=True) -> None:
- self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
+ self.not_aug_transform = transforms.Compose([
+ transforms.Resize(MyCUB200.IMG_SIZE, interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor()])
self.root = root
self.train = train
self.transform = transform
@@ -53,7 +51,7 @@ def __init__(self, root, train=True, transform=None,
self.segs = data_file['segs']
self._return_segmask = False
- def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
"""
Gets the requested element from the dataset.
@@ -67,9 +65,8 @@ def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
# to return a PIL Image
img = Image.fromarray(img, mode='RGB')
- original_img = img.copy()
- not_aug_img = self.not_aug_transform(original_img)
+ not_aug_img = self.not_aug_transform(img.copy())
if self.transform is not None:
img = self.transform(img)
@@ -81,6 +78,7 @@ def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
img, target, not_aug_img]
if self._return_segmask:
+ # TODO: add to the return tuple
raise "Unsupported segmentation output in training set!"
return ret_tuple
@@ -96,7 +94,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
super().__init__(root, train=train, transform=transform,
target_transform=target_transform, download=download)
- def __getitem__(self, index: int, ret_segmask=False) -> Tuple[type(Image), int, type(Image)]:
+ def __getitem__(self, index: int, ret_segmask=False) -> Tuple[Image.Image, int, Image.Image]:
"""
Gets the requested element from the dataset.
@@ -120,6 +118,7 @@ def __getitem__(self, index: int, ret_segmask=False) -> Tuple[type(Image), int,
ret_tuple = [img, target, self.logits[index]] if hasattr(self, 'logits') else [img, target]
if ret_segmask or self._return_segmask:
+ # TODO: does not work with the current implementation
seg = self.segs[index]
seg = Image.fromarray(seg, mode='L')
seg = transforms.ToTensor()(transforms.CenterCrop((MyCUB200.IMG_SIZE, MyCUB200.IMG_SIZE))(seg))[0]
@@ -147,25 +146,23 @@ class SequentialCUB200(ContinualDataset):
N_CLASSES_PER_TASK = 20
N_TASKS = 10
SIZE = (MyCUB200.IMG_SIZE, MyCUB200.IMG_SIZE)
- MEAN, STD = (0.4856, 0.4994, 0.4324), (0.2272, 0.2226, 0.2613)
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
TRANSFORM = transforms.Compose([
- transforms.Resize(MyCUB200.IMG_SIZE),
- transforms.RandomCrop(MyCUB200.IMG_SIZE, padding=4),
+ transforms.Resize((300, 300), interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomCrop(SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)])
- TEST_TRANSFORM = MyCUB200.TEST_TRANSFORM
+ TEST_TRANSFORM = transforms.Compose([transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
+ transforms.CenterCrop(MyCUB200.IMG_SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)])
def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
- transform = self.TRANSFORM
-
- test_transform = transforms.Compose(
- [transforms.Resize((MyCUB200.IMG_SIZE, MyCUB200.IMG_SIZE)), transforms.ToTensor(), self.get_normalization_transform()])
-
train_dataset = MyCUB200(base_path() + 'CUB200', train=True,
- download=True, transform=transform)
+ download=True, transform=SequentialCUB200.TRANSFORM)
test_dataset = CUB200(base_path() + 'CUB200', train=False,
- download=True, transform=test_transform)
+ download=True, transform=SequentialCUB200.TEST_TRANSFORM)
train, test = store_masked_loaders(
train_dataset, test_dataset, self)
@@ -179,9 +176,9 @@ def get_transform():
return transform
@staticmethod
- def get_backbone(hookme=False):
+ def get_backbone():
num_classes = SequentialCUB200.N_CLASSES_PER_TASK * SequentialCUB200.N_TASKS
- return resnet50(num_classes, pretrained=True)
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=num_classes)
@staticmethod
def get_loss():
@@ -189,9 +186,7 @@ def get_loss():
@staticmethod
def get_normalization_transform():
- transform = transforms.Normalize(
- SequentialCUB200.MEAN, SequentialCUB200.STD)
- return transform
+ return transforms.Normalize(SequentialCUB200.MEAN, SequentialCUB200.STD)
@staticmethod
def get_denormalization_transform():
@@ -200,8 +195,219 @@ def get_denormalization_transform():
@set_default_from_args('batch_size')
def get_batch_size(self):
- return 16
+ return 128
@set_default_from_args('n_epochs')
def get_epochs(self):
- return 30
+ return 50
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = fix_class_names_order(CLASS_NAMES, self.args)
+ self.class_names = classes
+ return self.class_names
+
+
+CLASS_NAMES = [
+ 'black footed albatross',
+ 'laysan albatross',
+ 'sooty albatross',
+ 'groove billed ani',
+ 'crested auklet',
+ 'least auklet',
+ 'parakeet auklet',
+ 'rhinoceros auklet',
+ 'brewer blackbird',
+ 'red winged blackbird',
+ 'rusty blackbird',
+ 'yellow headed blackbird',
+ 'bobolink',
+ 'indigo bunting',
+ 'lazuli bunting',
+ 'painted bunting',
+ 'cardinal',
+ 'spotted catbird',
+ 'gray catbird',
+ 'yellow breasted chat',
+ 'eastern towhee',
+ 'chuck will widow',
+ 'brandt cormorant',
+ 'red faced cormorant',
+ 'pelagic cormorant',
+ 'bronzed cowbird',
+ 'shiny cowbird',
+ 'brown creeper',
+ 'american crow',
+ 'fish crow',
+ 'black billed cuckoo',
+ 'mangrove cuckoo',
+ 'yellow billed cuckoo',
+ 'gray crowned rosy finch',
+ 'purple finch',
+ 'northern flicker',
+ 'acadian flycatcher',
+ 'great crested flycatcher',
+ 'least flycatcher',
+ 'olive sided flycatcher',
+ 'scissor tailed flycatcher',
+ 'vermilion flycatcher',
+ 'yellow bellied flycatcher',
+ 'frigatebird',
+ 'northern fulmar',
+ 'gadwall',
+ 'american goldfinch',
+ 'european goldfinch',
+ 'boat tailed grackle',
+ 'eared grebe',
+ 'horned grebe',
+ 'pied billed grebe',
+ 'western grebe',
+ 'blue grosbeak',
+ 'evening grosbeak',
+ 'pine grosbeak',
+ 'rose breasted grosbeak',
+ 'pigeon guillemot',
+ 'california gull',
+ 'glaucous winged gull',
+ 'heermann gull',
+ 'herring gull',
+ 'ivory gull',
+ 'ring billed gull',
+ 'slaty backed gull',
+ 'western gull',
+ 'anna hummingbird',
+ 'ruby throated hummingbird',
+ 'rufous hummingbird',
+ 'green violetear',
+ 'long tailed jaeger',
+ 'pomarine jaeger',
+ 'blue jay',
+ 'florida jay',
+ 'green jay',
+ 'dark eyed junco',
+ 'tropical kingbird',
+ 'gray kingbird',
+ 'belted kingfisher',
+ 'green kingfisher',
+ 'pied kingfisher',
+ 'ringed kingfisher',
+ 'white breasted kingfisher',
+ 'red legged kittiwake',
+ 'horned lark',
+ 'pacific loon',
+ 'mallard',
+ 'western meadowlark',
+ 'hooded merganser',
+ 'red breasted merganser',
+ 'mockingbird',
+ 'nighthawk',
+ 'clark nutcracker',
+ 'white breasted nuthatch',
+ 'baltimore oriole',
+ 'hooded oriole',
+ 'orchard oriole',
+ 'scott oriole',
+ 'ovenbird',
+ 'brown pelican',
+ 'white pelican',
+ 'western wood pewee',
+ 'sayornis',
+ 'american pipit',
+ 'whip poor will',
+ 'horned puffin',
+ 'common raven',
+ 'white necked raven',
+ 'american redstart',
+ 'geococcyx',
+ 'loggerhead shrike',
+ 'great grey shrike',
+ 'baird sparrow',
+ 'black throated sparrow',
+ 'brewer sparrow',
+ 'chipping sparrow',
+ 'clay colored sparrow',
+ 'house sparrow',
+ 'field sparrow',
+ 'fox sparrow',
+ 'grasshopper sparrow',
+ 'harris sparrow',
+ 'henslow sparrow',
+ 'le conte sparrow',
+ 'lincoln sparrow',
+ 'nelson sharp tailed sparrow',
+ 'savannah sparrow',
+ 'seaside sparrow',
+ 'song sparrow',
+ 'tree sparrow',
+ 'vesper sparrow',
+ 'white crowned sparrow',
+ 'white throated sparrow',
+ 'cape glossy starling',
+ 'bank swallow',
+ 'barn swallow',
+ 'cliff swallow',
+ 'tree swallow',
+ 'scarlet tanager',
+ 'summer tanager',
+ 'artic tern',
+ 'black tern',
+ 'caspian tern',
+ 'common tern',
+ 'elegant tern',
+ 'forsters tern',
+ 'least tern',
+ 'green tailed towhee',
+ 'brown thrasher',
+ 'sage thrasher',
+ 'black capped vireo',
+ 'blue headed vireo',
+ 'philadelphia vireo',
+ 'red eyed vireo',
+ 'warbling vireo',
+ 'white eyed vireo',
+ 'yellow throated vireo',
+ 'bay breasted warbler',
+ 'black and white warbler',
+ 'black throated blue warbler',
+ 'blue winged warbler',
+ 'canada warbler',
+ 'cape may warbler',
+ 'cerulean warbler',
+ 'chestnut sided warbler',
+ 'golden winged warbler',
+ 'hooded warbler',
+ 'kentucky warbler',
+ 'magnolia warbler',
+ 'mourning warbler',
+ 'myrtle warbler',
+ 'nashville warbler',
+ 'orange crowned warbler',
+ 'palm warbler',
+ 'pine warbler',
+ 'prairie warbler',
+ 'prothonotary warbler',
+ 'swainson warbler',
+ 'tennessee warbler',
+ 'wilson warbler',
+ 'worm eating warbler',
+ 'yellow warbler',
+ 'northern waterthrush',
+ 'louisiana waterthrush',
+ 'bohemian waxwing',
+ 'cedar waxwing',
+ 'american three toed woodpecker',
+ 'pileated woodpecker',
+ 'red bellied woodpecker',
+ 'red cockaded woodpecker',
+ 'red headed woodpecker',
+ 'downy woodpecker',
+ 'bewick wren',
+ 'cactus wren',
+ 'carolina wren',
+ 'house wren',
+ 'marsh wren',
+ 'rock wren',
+ 'winter wren',
+ 'common yellowthroat'
+]
diff --git a/datasets/seq_cub200_rs.py b/datasets/seq_cub200_rs.py
new file mode 100644
index 00000000..afaac1ce
--- /dev/null
+++ b/datasets/seq_cub200_rs.py
@@ -0,0 +1,87 @@
+"""
+Implements the Sequential CUB200 Dataset, as used in `Transfer without Forgetting `_ (Version with ResNet50 as backbone).
+"""
+
+import torch
+import torchvision.transforms as transforms
+from typing import Tuple
+
+from backbone.ResNetBottleneck import resnet50
+from datasets.seq_cub200 import SequentialCUB200, MyCUB200, CUB200
+from datasets.transforms.denormalization import DeNormalize
+from datasets.utils import set_default_from_args
+from datasets.utils.continual_dataset import store_masked_loaders
+from utils.conf import base_path
+
+
+class MyCUB200RS(MyCUB200):
+ MEAN, STD = (0.4856, 0.4994, 0.4324), (0.2272, 0.2226, 0.2613)
+ TEST_TRANSFORM = transforms.Compose([transforms.Resize(MyCUB200.IMG_SIZE), transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
+
+
+class SequentialCUB200RS(SequentialCUB200):
+ """Sequential CUB200 Dataset. Version with ResNet50 (as in `Transfer without Forgetting`)
+
+ Args:
+ NAME (str): name of the dataset.
+ SETTING (str): setting of the dataset.
+ N_CLASSES_PER_TASK (int): number of classes per task.
+ N_TASKS (int): number of tasks.
+ SIZE (tuple): size of the images.
+ MEAN (tuple): mean of the dataset.
+ STD (tuple): standard deviation of the dataset.
+ TRANSFORM (torchvision.transforms): transformation to apply to the data.
+ TEST_TRANSFORM (torchvision.transforms): transformation to apply to the test data.
+ """
+ NAME = 'seq-cub200-rs'
+ SETTING = 'class-il'
+ N_CLASSES_PER_TASK = 20
+ N_TASKS = 10
+ SIZE = (MyCUB200RS.IMG_SIZE, MyCUB200RS.IMG_SIZE)
+ MEAN, STD = (0.4856, 0.4994, 0.4324), (0.2272, 0.2226, 0.2613)
+ TRANSFORM = transforms.Compose([
+ transforms.Resize(MyCUB200RS.IMG_SIZE),
+ transforms.RandomCrop(MyCUB200RS.IMG_SIZE, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)])
+ TEST_TRANSFORM = MyCUB200RS.TEST_TRANSFORM
+
+ def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
+ train_dataset = MyCUB200RS(base_path() + 'CUB200', train=True,
+ download=True, transform=SequentialCUB200RS.TRANSFORM)
+ test_dataset = CUB200(base_path() + 'CUB200', train=False,
+ download=True, transform=SequentialCUB200RS.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(
+ train_dataset, test_dataset, self)
+
+ return train, test
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose(
+ [transforms.ToPILImage(), SequentialCUB200RS.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ num_classes = SequentialCUB200RS.N_CLASSES_PER_TASK * SequentialCUB200RS.N_TASKS
+ return resnet50(num_classes, pretrained=True)
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(SequentialCUB200RS.MEAN, SequentialCUB200RS.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(SequentialCUB200RS.MEAN, SequentialCUB200RS.STD)
+ return transform
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 16
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 30
diff --git a/datasets/seq_eurosat_rgb.py b/datasets/seq_eurosat_rgb.py
new file mode 100644
index 00000000..c93d8928
--- /dev/null
+++ b/datasets/seq_eurosat_rgb.py
@@ -0,0 +1,204 @@
+import io
+import json
+import logging
+import os
+import sys
+import zipfile
+import pandas as pd
+import requests
+import torch
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+from PIL import Image
+from typing import Tuple
+try:
+ from google_drive_downloader import GoogleDriveDownloader as gdd
+except ImportError:
+ raise ImportError("Please install the google_drive_downloader package by running: `pip install googledrivedownloader`")
+
+from datasets.utils import set_default_from_args
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+class MyEuroSat(Dataset):
+
+ def __init__(self, root, split='train', transform=None,
+ target_transform=None) -> None:
+
+ self.root = root
+ self.split = split
+ assert split in ['train', 'test', 'val'], 'Split must be either train, test or val'
+ self.transform = transform
+ self.target_transform = target_transform
+ self.totensor = transforms.ToTensor()
+
+ if not os.path.exists(root + '/DONE'):
+ print('Preparing dataset...', file=sys.stderr)
+ r = requests.get('https://zenodo.org/records/7711810/files/EuroSAT_RGB.zip?download=1')
+ z = zipfile.ZipFile(io.BytesIO(r.content))
+ z.extractall(root)
+ os.system(f'mv {root}/EuroSAT_RGB/* {root}')
+ os.system(f'rmdir {root}/EuroSAT_RGB')
+
+ # create DONE file
+ with open(self.root + '/DONE', 'w') as f:
+ f.write('')
+
+ # downlaod split file form https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/
+ gdd.download_file_from_google_drive(file_id='1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o',
+ dest_path=self.root + '/split.json')
+
+ print('Done', file=sys.stderr)
+
+ self.data_split = pd.DataFrame(json.load(open(self.root + '/split.json', 'r'))[split])
+ self.class_names = self.get_class_names()
+
+ self.data = self.data_split[0].values
+ self.targets = self.data_split[1].values
+
+ @staticmethod
+ def get_class_names():
+ if not os.path.exists(base_path() + f'eurosat/DONE'):
+ gdd.download_file_from_google_drive(file_id='1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o',
+ dest_path=base_path() + 'eurosat/split.json')
+ return pd.DataFrame(json.load(open(base_path() + 'eurosat/split.json', 'r'))['train'])[2].unique()
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+ :param index: index of the element to be returned
+ :returns: tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ img = Image.open(self.root + '/' + img).convert('RGB')
+
+ not_aug_img = self.totensor(img.copy())
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if self.split != 'train':
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+def my_collate_fn(batch):
+ tmp = list(zip(*batch))
+ imgs = torch.stack(tmp[0], dim=0)
+ labels = torch.tensor(tmp[1])
+ if len(tmp) == 2:
+ return imgs, labels
+ not_aug_imgs = tmp[2]
+ not_aug_imgs = torch.stack(not_aug_imgs, dim=0)
+ if len(tmp) == 4:
+ logits = torch.stack(tmp[3], dim=0)
+ return imgs, labels, not_aug_imgs, logits
+ return imgs, labels, not_aug_imgs
+
+
+class SequentialEuroSatRgb(ContinualDataset):
+
+ NAME = 'seq-eurosat-rgb'
+ SETTING = 'class-il'
+ N_TASKS = 5
+ N_CLASSES = 10
+ N_CLASSES_PER_TASK = 2
+ SIZE = (224, 224)
+ MEAN, STD = [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]
+
+ TRANSFORM = transforms.Compose([
+ transforms.RandomResizedCrop(SIZE[0], scale=(0.08, 1.0), interpolation=InterpolationMode.BICUBIC), # from https://github.dev/KaiyangZhou/Dassl.pytorch defaults
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(SIZE[0], interpolation=InterpolationMode.BICUBIC), # bicubic
+ transforms.CenterCrop(SIZE[0]),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+
+ try:
+ classes = MyEuroSat.get_class_names()
+ except BaseException:
+ logging.warning("dataset not loaded yet -- loading dataset...")
+ MyEuroSat(base_path() + 'eurosat', train=True,
+ transform=None)
+ classes = MyEuroSat.get_class_names()
+
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.args = args
+
+ def get_data_loaders(self):
+ train_dataset = MyEuroSat(base_path() + 'eurosat', split='train',
+ transform=self.TRANSFORM)
+ test_dataset = MyEuroSat(base_path() + 'eurosat', split='test',
+ transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose([transforms.ToPILImage(),
+ SequentialEuroSatRgb.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialEuroSatRgb.N_CLASSES)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialEuroSatRgb.MEAN, std=SequentialEuroSatRgb.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(SequentialEuroSatRgb.MEAN, SequentialEuroSatRgb.STD)
+ return transform
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 5
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['eurosat']
diff --git a/datasets/seq_imagenet_r.py b/datasets/seq_imagenet_r.py
index 21765f2a..87142ee5 100644
--- a/datasets/seq_imagenet_r.py
+++ b/datasets/seq_imagenet_r.py
@@ -1,21 +1,28 @@
+import logging
+try:
+ import requests
+except ImportError as e:
+ logging.error("Please install requests using 'pip install requests'")
+ raise e
+
import os
-from requests import request
import torchvision.transforms as transforms
-from torchvision.models import resnet18
import torch.nn.functional as F
+from torch.utils.data import Dataset
import numpy as np
-from utils.conf import base_path
+import pickle
from PIL import Image
-from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
from typing import Tuple
-from datasets.transforms.denormalization import DeNormalize
-from torch.utils.data import Dataset
-import torch.nn as nn
+
import yaml
-import pickle
-from torchvision.transforms.functional import InterpolationMode
-from utils.prompt_templates import templates
+
from datasets.utils import set_default_from_args
+from utils import smart_joint
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from backbone.vit import vit_base_patch16_224_prompt_prototype
class MyImagenetR(Dataset):
@@ -40,10 +47,10 @@ def __init__(self, root, train=True, transform=None,
# download from https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
print("Downloading imagenet-r dataset...")
url = 'https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar'
- r = request('GET', url, allow_redirects=True)
+ r = requests.get(url, allow_redirects=True)
if not os.path.exists(self.root):
os.makedirs(self.root)
- print("Saving tar...")
+ print("Writing tar on disk...")
open(self.root + 'imagenet-r.tar', 'wb').write(r.content)
print("Extracting tar...")
os.system('tar -xf ' + self.root + 'imagenet-r.tar -C ' + self.root.rstrip('imagenet-r'))
@@ -74,7 +81,7 @@ def __init__(self, root, train=True, transform=None,
def __len__(self):
return len(self.targets)
- def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
"""
Gets the requested element from the dataset.
:param index: index of the element to be returned
@@ -110,54 +117,44 @@ class SequentialImagenetR(ContinualDataset):
N_TASKS = 10
N_CLASSES = 200
N_CLASSES_PER_TASK = N_CLASSES // N_TASKS
- normalize = transforms.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))
+ MEAN, STD = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
SIZE = (224, 224)
TRANSFORM = transforms.Compose([
- transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomResizedCrop(SIZE[0], interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
- normalize,
+ transforms.Normalize(mean=MEAN, std=STD),
])
- TEST_TRANSFORM = transforms.Compose([
- transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- normalize,
- ])
-
- def __init__(self, args):
- super().__init__(args)
- self.args = args
- self.label_to_class_name = self.get_class_names()
+ TEST_TRANSFORM = transforms.Compose([transforms.Resize(size=(256, 256),
+ interpolation=InterpolationMode.BICUBIC),
+ transforms.CenterCrop(SIZE[0]),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD)])
def get_data_loaders(self):
- transform = self.TRANSFORM
-
- test_transform = transforms.Compose(
- [transforms.Resize(size=(256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), self.normalize])
-
train_dataset = MyImagenetR(base_path() + 'imagenet-r/', train=True,
- download=True, transform=transform)
+ download=True, transform=self.TRANSFORM)
test_dataset = MyImagenetR(base_path() + 'imagenet-r/', train=False,
- download=True, transform=test_transform)
+ download=True, transform=self.TEST_TRANSFORM)
train, test = store_masked_loaders(train_dataset, test_dataset, self)
return train, test
def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+
pwd = os.path.dirname(os.path.abspath(__file__))
with open(pwd + '/imagenet_r_utils/label_to_class_name.pkl', 'rb') as f:
label_to_class_name = pickle.load(f)
class_names = label_to_class_name.values()
class_names = [x.replace('_', ' ') for x in class_names]
- if hasattr(self.args, 'class_order'):
- class_names = [class_names[i] for i in self.class_order]
- return class_names
- @staticmethod
- def get_prompt_templates():
- return templates['imagenet']
+ class_names = fix_class_names_order(class_names, self.args)
+ self.class_names = class_names
+ return self.class_names
@staticmethod
def get_transform():
@@ -166,11 +163,8 @@ def get_transform():
return transform
@staticmethod
- def get_backbone(hookme=False):
- backbone = resnet18()
- num_classes = SequentialImagenetR.N_CLASSES_PER_TASK * SequentialImagenetR.N_TASKS
- backbone.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
- return backbone
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialImagenetR.N_CLASSES)
@staticmethod
def get_loss():
@@ -178,12 +172,11 @@ def get_loss():
@staticmethod
def get_normalization_transform():
- return transforms.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))
+ return transforms.Normalize(mean=SequentialImagenetR.MEAN, std=SequentialImagenetR.STD)
@staticmethod
def get_denormalization_transform():
- transform = DeNormalize((0, 0, 0),
- (1, 1, 1))
+ transform = DeNormalize(SequentialImagenetR.MEAN, SequentialImagenetR.STD)
return transform
@set_default_from_args('n_epochs')
@@ -192,12 +185,4 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
- return 32
-
- @staticmethod
- def get_virtual_bn_num():
- return 4
-
- @staticmethod
- def get_n_epochs_first_stage():
- return 50
+ return 128
diff --git a/datasets/seq_isic.py b/datasets/seq_isic.py
new file mode 100644
index 00000000..75433020
--- /dev/null
+++ b/datasets/seq_isic.py
@@ -0,0 +1,170 @@
+import os
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+import pickle
+from PIL import Image
+from typing import Tuple
+
+from datasets.utils import set_default_from_args
+from utils import smart_joint
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+class Isic(Dataset):
+ N_CLASSES = 6
+
+ LABELS = ['melanoma',
+ 'basal cell carcinoma',
+ 'actinic keratosis or intraepithelial carcinoma',
+ 'benign keratosis',
+ 'dermatofibroma',
+ 'vascular skin lesion']
+
+ """
+ Overrides the ChestX dataset to change the getitem function.
+ """
+
+ def __init__(self, root, train=True, transform=None,
+ target_transform=None, download=False) -> None:
+
+ self.root = root
+ self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
+
+ split = 'train' if train else 'test'
+ if not os.path.exists(f'{root}/{split}_images.pkl'):
+ if download:
+ ln = 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/ERM64PkPkFtJhmiUQkVvE64BR900MbIHtJVA_CR4KKhy8A?e=OsrQr5'
+ from onedrivedownloader import download
+ download(ln, filename=smart_joint(root, 'isic.tar.gz'), unzip=True, unzip_path=root.rstrip('isic'), clean=True)
+ else:
+ raise FileNotFoundError(f'File not found: {root}/{split}_images.pkl')
+
+ filename_labels = f'{self.root}/{split}_labels.pkl'
+ filename_images = f'{self.root}/{split}_images.pkl'
+
+ self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
+
+ with open(filename_images, 'rb') as f:
+ self.data = pickle.load(f)
+
+ with open(filename_labels, 'rb') as f:
+ self.targets = pickle.load(f)
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+ :param index: index of the element to be returned
+ :returns: tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+ img = Image.fromarray((img * 255).astype(np.int8), mode='RGB')
+
+ original_img = img.copy()
+
+ not_aug_img = self.not_aug_transform(original_img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+class SequentialIsic(ContinualDataset):
+
+ NAME = 'seq-isic'
+ SETTING = 'class-il'
+ N_TASKS = 3
+ N_CLASSES_PER_TASK = 2
+ N_CLASSES = 6
+ SIZE = (224, 224)
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+
+ TRANSFORM = transforms.Compose([
+ transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomCrop(SIZE[0]),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(size=(256, 256), interpolation=InterpolationMode.BICUBIC),
+ transforms.CenterCrop(SIZE[0]),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=MEAN, std=STD),
+ ])
+
+ def get_data_loaders(self):
+ train_dataset = Isic(base_path() + 'isic', train=True,
+ download=True, transform=self.TRANSFORM)
+
+ test_dataset = Isic(base_path() + 'isic', train=False, download=True,
+ transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = fix_class_names_order(Isic.LABELS, self.args)
+ self.class_names = classes
+ return self.class_names
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose([
+ transforms.ToPILImage(),
+ SequentialIsic.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialIsic.N_CLASSES)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialIsic.MEAN, std=SequentialIsic.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ transform = DeNormalize(mean=SequentialIsic.MEAN, std=SequentialIsic.STD)
+ return transform
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 30
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
diff --git a/datasets/seq_mit67.py b/datasets/seq_mit67.py
new file mode 100644
index 00000000..328f4ed2
--- /dev/null
+++ b/datasets/seq_mit67.py
@@ -0,0 +1,244 @@
+import glob
+import io
+import os
+import tarfile
+import requests
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+from PIL import Image
+
+from datasets.utils import set_default_from_args
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+idx_to_class_names = {
+ 0: 'airport_inside',
+ 1: 'artstudio',
+ 2: 'auditorium',
+ 3: 'bakery',
+ 4: 'bar',
+ 5: 'bathroom',
+ 6: 'bedroom',
+ 7: 'bookstore',
+ 8: 'bowling',
+ 9: 'buffet',
+ 10: 'casino',
+ 11: 'children_room',
+ 12: 'church_inside',
+ 13: 'classroom',
+ 14: 'cloister',
+ 15: 'closet',
+ 16: 'clothingstore',
+ 17: 'computerroom',
+ 18: 'concert_hall',
+ 19: 'corridor',
+ 20: 'deli',
+ 21: 'dentaloffice',
+ 22: 'dining_room',
+ 23: 'elevator',
+ 24: 'fastfood_restaurant',
+ 25: 'florist',
+ 26: 'gameroom',
+ 27: 'garage',
+ 28: 'greenhouse',
+ 29: 'grocerystore',
+ 30: 'gym',
+ 31: 'hairsalon',
+ 32: 'hospitalroom',
+ 33: 'inside_bus',
+ 34: 'inside_subway',
+ 35: 'jewelleryshop',
+ 36: 'kindergarden',
+ 37: 'kitchen',
+ 38: 'laboratorywet',
+ 39: 'laundromat',
+ 40: 'library',
+ 41: 'livingroom',
+ 42: 'lobby',
+ 43: 'locker_room',
+ 44: 'mall',
+ 45: 'meeting_room',
+ 46: 'movietheater',
+ 47: 'museum',
+ 48: 'nursery',
+ 49: 'office',
+ 50: 'operating_room',
+ 51: 'pantry',
+ 52: 'poolinside',
+ 53: 'prisoncell',
+ 54: 'restaurant',
+ 55: 'restaurant_kitchen',
+ 56: 'shoeshop',
+ 57: 'stairscase',
+ 58: 'studiomusic',
+ 59: 'subway',
+ 60: 'toystore',
+ 61: 'trainstation',
+ 62: 'tv_studio',
+ 63: 'videostore',
+ 64: 'waitingroom',
+ 65: 'warehouse',
+ 66: 'winecellar'
+}
+
+
+class MyMIT67(Dataset):
+ NUM_CLASSES = 67
+
+ def __init__(self, root, train=True, download=True, transform=None,
+ target_transform=None) -> None:
+ self.root = os.path.join(base_path(), 'MIT67')
+ self.transform = transform
+ self.train = train
+ self.target_transform = target_transform
+ self.not_aug_transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
+
+ if not os.path.exists(self.root) and download:
+ print('Downloading MIT67 dataset...')
+ if not os.path.exists(self.root):
+ os.makedirs(self.root)
+ train_images_link = 'http://groups.csail.mit.edu/vision/LabelMe/NewImages/indoorCVPR_09.tar'
+ train_labels_link = 'https://web.mit.edu/torralba/www/TrainImages.txt'
+ test_images_link = 'https://web.mit.edu/torralba/www/TestImages.txt'
+ r = requests.get(train_images_link)
+ z = tarfile.open(fileobj=io.BytesIO(r.content))
+ z.extractall(root)
+
+ r = requests.get(train_labels_link)
+ with open(os.path.join(self.root, 'TrainImages.txt'), 'wb') as f:
+ f.write(r.content)
+
+ r = requests.get(test_images_link)
+ with open(os.path.join(self.root, 'TestImages.txt'), 'wb') as f:
+ f.write(r.content)
+ print('MIT67 dataset downloaded')
+ else:
+ print('MIT67 dataset already downloaded')
+
+ folder_targets = {os.path.basename(f[:-1]): i for i, f in enumerate(sorted(glob.glob(os.path.join(self.root, 'Images/*/'))))}
+
+ train_images_path = os.path.join(self.root, 'TrainImages.txt')
+ test_images_path = os.path.join(self.root, 'TestImages.txt')
+
+ if self.train:
+ with open(train_images_path) as f:
+ paths = f.readlines()
+ else:
+ with open(test_images_path) as f:
+ paths = f.readlines()
+ paths = [p.strip() for p in paths]
+ self.data = [os.path.join(self.root, 'Images', p) for p in paths]
+ self.data = np.array(self.data)
+ self.targets = [folder_targets[p.split('/')[0]] for p in paths]
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ target = self.targets[index]
+ img = Image.open(self.data[index])
+ img = img.convert('RGB')
+
+ original_img = img.copy()
+ not_aug_img = self.not_aug_transform(original_img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ return img, target, not_aug_img
+
+
+class SequentialMIT67(ContinualDataset):
+
+ NAME = 'seq-mit67'
+ SETTING = 'class-il'
+ N_TASKS = 10
+ N_CLASSES = 67
+ N_CLASSES_PER_TASK = [7] * 7 + [6] * 3
+ SIZE = (224, 224)
+ MEAN = [0.485, 0.456, 0.406]
+ STD = [0.229, 0.224, 0.225]
+ TRANSFORM = transforms.Compose([
+ transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomCrop(SIZE),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ])
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ])
+
+ def get_data_loaders(self):
+ train_dataset = MyMIT67(base_path() + 'MIT67', train=True,
+ download=True, transform=self.TRANSFORM)
+ test_dataset = MyMIT67(base_path() + 'MIT76', train=False,
+ download=True, transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = list(idx_to_class_names.values())
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return classes
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['cifar100']
+
+ @staticmethod
+ def get_transform():
+ transform = transforms.Compose([transforms.ToPILImage(),
+ SequentialMIT67.TRANSFORM])
+ return transform
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialMIT67.N_CLASSES)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(SequentialMIT67.MEAN, SequentialMIT67.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ return DeNormalize(SequentialMIT67.MEAN, SequentialMIT67.STD)
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 50
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 32
diff --git a/datasets/seq_mnist.py b/datasets/seq_mnist.py
index 9bb5ed2c..6d91a0dc 100644
--- a/datasets/seq_mnist.py
+++ b/datasets/seq_mnist.py
@@ -12,7 +12,7 @@
from torchvision.datasets import MNIST
from backbone.MNISTMLP import MNISTMLP
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from utils.conf import base_path
from datasets.utils import set_default_from_args
@@ -116,3 +116,12 @@ def get_batch_size(self):
@set_default_from_args('n_epochs')
def get_epochs(self):
return 1
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = MNIST(base_path() + 'MNIST', train=True, download=True).classes
+ classes = [c.split('-')[1].strip() for c in classes]
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return self.class_names
diff --git a/datasets/seq_resisc45.py b/datasets/seq_resisc45.py
new file mode 100644
index 00000000..4ae38dc5
--- /dev/null
+++ b/datasets/seq_resisc45.py
@@ -0,0 +1,207 @@
+import os
+from typing import Tuple
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import numpy as np
+from PIL import Image
+import yaml
+
+from datasets.utils import set_default_from_args
+from utils import smart_joint
+from utils.conf import base_path
+from datasets.utils.continual_dataset import ContinualDataset, fix_class_names_order, store_masked_loaders
+from datasets.transforms.denormalization import DeNormalize
+from torchvision.transforms.functional import InterpolationMode
+from utils.prompt_templates import templates
+from backbone.vit import vit_base_patch16_224_prompt_prototype
+
+
+class Resisc45(Dataset):
+
+ N_CLASSES = 45
+ LABELS = [
+ 'airplane',
+ 'airport',
+ 'baseball_diamond',
+ 'basketball_court',
+ 'beach',
+ 'bridge',
+ 'chaparral',
+ 'church',
+ 'circular_farmland',
+ 'cloud',
+ 'commercial_area',
+ 'dense_residential',
+ 'desert',
+ 'forest',
+ 'freeway',
+ 'golf_course',
+ 'ground_track_field',
+ 'harbor',
+ 'industrial_area',
+ 'intersection',
+ 'island',
+ 'lake',
+ 'meadow',
+ 'medium_residential',
+ 'mobile_home_park',
+ 'mountain',
+ 'overpass',
+ 'palace',
+ 'parking_lot',
+ 'railway',
+ 'railway_station',
+ 'rectangular_farmland',
+ 'river',
+ 'roundabout',
+ 'runway',
+ 'sea_ice',
+ 'ship',
+ 'snowberg',
+ 'sparse_residential',
+ 'stadium',
+ 'storage_tank',
+ 'tennis_court',
+ 'terrace',
+ 'thermal_power_station',
+ 'wetland',
+ ]
+
+ def __init__(self, root, train=True, transform=None,
+ target_transform=None, download=False) -> None:
+
+ self.root = root
+ self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.not_aug_transform = transforms.Compose([
+ transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor()]
+ )
+
+ if download:
+ if os.path.isdir(root) and len(os.listdir(root)) > 0:
+ print('Download not needed, files already on disk.')
+ else:
+ # download from https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
+ print("Downloading resisc45 dataset...")
+ ln = 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EbxMu5z5HbVIkG9qFCGbg7ABDRZvpBEA8uqVC-Em9HYVug?e=Cfc4Yc'
+ from onedrivedownloader import download
+ download(ln, filename=os.path.join(root, 'resisc45.tar.gz'), unzip=True, unzip_path=root, clean=True)
+ print("Done!")
+
+ if self.train:
+ data_config = yaml.load(open(smart_joint(root, 'resisc45_train.yaml')), Loader=yaml.Loader)
+ else:
+ data_config = yaml.load(open(smart_joint(root, 'resisc45_test.yaml')), Loader=yaml.Loader)
+
+ self.data = np.array([smart_joint(root, d) for d in data_config['data']])
+ self.targets = np.array(data_config['targets']).astype(np.int16)
+
+ def __len__(self):
+ return len(self.targets)
+
+ def __getitem__(self, index: int) -> Tuple[Image.Image, int, Image.Image]:
+ """
+ Gets the requested element from the dataset.
+ :param index: index of the element to be returned
+ :returns: tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ img = Image.open(img).convert('RGB')
+
+ original_img = img.copy()
+
+ not_aug_img = self.not_aug_transform(original_img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if not self.train:
+ return img, target
+
+ if hasattr(self, 'logits'):
+ return img, target, not_aug_img, self.logits[index]
+
+ return img, target, not_aug_img
+
+
+class SequentialResisc45(ContinualDataset):
+
+ NAME = 'seq-resisc45'
+ SETTING = 'class-il'
+ N_TASKS = 9
+ N_CLASSES_PER_TASK = 45 // N_TASKS
+ N_CLASSES = 45
+ SIZE = (224, 224)
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ TRANSFORM = transforms.Compose([
+ transforms.RandomResizedCrop(SIZE[0], interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD),
+ ])
+
+ TEST_TRANSFORM = transforms.Compose([
+ transforms.Resize(size=(256, 256), interpolation=InterpolationMode.BICUBIC),
+ transforms.CenterCrop(SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN, STD)
+ ])
+
+ def get_data_loaders(self):
+ train_dataset = Resisc45(base_path() + 'NWPU-RESISC45', train=True,
+ download=True, transform=self.TRANSFORM)
+ test_dataset = Resisc45(base_path() + 'NWPU-RESISC45', train=False,
+ download=True, transform=self.TEST_TRANSFORM)
+
+ train, test = store_masked_loaders(train_dataset, test_dataset, self)
+
+ return train, test
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = [x.replace('_', ' ') for x in Resisc45.LABELS]
+ classes = fix_class_names_order(classes, self.args)
+ self.class_names = classes
+ return classes
+
+ @staticmethod
+ def get_prompt_templates():
+ return templates['eurosat']
+
+ @staticmethod
+ def get_transform():
+ return transforms.Compose([transforms.ToPILImage(),
+ SequentialResisc45.TRANSFORM])
+
+ @staticmethod
+ def get_backbone():
+ return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialResisc45.N_CLASSES)
+
+ @staticmethod
+ def get_loss():
+ return F.cross_entropy
+
+ @staticmethod
+ def get_normalization_transform():
+ return transforms.Normalize(mean=SequentialResisc45.MEAN, std=SequentialResisc45.STD)
+
+ @staticmethod
+ def get_denormalization_transform():
+ return DeNormalize(mean=SequentialResisc45.MEAN, std=SequentialResisc45.STD)
+
+ @set_default_from_args('n_epochs')
+ def get_epochs(self):
+ return 30
+
+ @set_default_from_args('batch_size')
+ def get_batch_size(self):
+ return 128
diff --git a/datasets/seq_tinyimagenet.py b/datasets/seq_tinyimagenet.py
index ac36d554..89588f93 100644
--- a/datasets/seq_tinyimagenet.py
+++ b/datasets/seq_tinyimagenet.py
@@ -16,7 +16,7 @@
from backbone.ResNetBlock import resnet18
from datasets.transforms.denormalization import DeNormalize
-from datasets.utils.continual_dataset import (ContinualDataset,
+from datasets.utils.continual_dataset import (ContinualDataset, fix_class_names_order,
store_masked_loaders)
from utils import smart_joint
from utils.conf import base_path
@@ -185,3 +185,214 @@ def get_epochs(self):
@set_default_from_args('batch_size')
def get_batch_size(self):
return 32
+
+ def get_class_names(self):
+ if self.class_names is not None:
+ return self.class_names
+ classes = fix_class_names_order(CLASS_NAMES, self.args)
+ self.class_names = classes
+ return self.class_names
+
+
+CLASS_NAMES = [
+ 'egyptian_cat',
+ 'reel',
+ 'volleyball',
+ 'rocking_chair',
+ 'lemon',
+ 'bullfrog',
+ 'basketball',
+ 'cliff',
+ 'espresso',
+ 'plunger',
+ 'parking_meter',
+ 'german_shepherd',
+ 'dining_table',
+ 'monarch',
+ 'brown_bear',
+ 'school_bus',
+ 'pizza',
+ 'guinea_pig',
+ 'umbrella',
+ 'organ',
+ 'oboe',
+ 'maypole',
+ 'goldfish',
+ 'potpie',
+ 'hourglass',
+ 'seashore',
+ 'computer_keyboard',
+ 'arabian_camel',
+ 'ice_cream',
+ 'nail',
+ 'space_heater',
+ 'cardigan',
+ 'baboon',
+ 'snail',
+ 'coral_reef',
+ 'albatross',
+ 'spider_web',
+ 'sea_cucumber',
+ 'backpack',
+ 'labrador_retriever',
+ 'pretzel',
+ 'king_penguin',
+ 'sulphur_butterfly',
+ 'tarantula',
+ 'lesser_panda',
+ 'pop_bottle',
+ 'banana',
+ 'sock',
+ 'cockroach',
+ 'projectile',
+ 'beer_bottle',
+ 'mantis',
+ 'freight_car',
+ 'guacamole',
+ 'remote_control',
+ 'european_fire_salamander',
+ 'lakeside',
+ 'chimpanzee',
+ 'pay-phone',
+ 'fur_coat',
+ 'alp',
+ 'lampshade',
+ 'torch',
+ 'abacus',
+ 'moving_van',
+ 'barrel',
+ 'tabby',
+ 'goose',
+ 'koala',
+ 'bullet_train',
+ 'cd_player',
+ 'teapot',
+ 'birdhouse',
+ 'gazelle',
+ 'academic_gown',
+ 'tractor',
+ 'ladybug',
+ 'miniskirt',
+ 'golden_retriever',
+ 'triumphal_arch',
+ 'cannon',
+ 'neck_brace',
+ 'sombrero',
+ 'gasmask',
+ 'candle',
+ 'desk',
+ 'frying_pan',
+ 'bee',
+ 'dam',
+ 'spiny_lobster',
+ 'police_van',
+ 'ipod',
+ 'punching_bag',
+ 'beacon',
+ 'jellyfish',
+ 'wok',
+ "potter's_wheel",
+ 'sandal',
+ 'pill_bottle',
+ 'butcher_shop',
+ 'slug',
+ 'hog',
+ 'cougar',
+ 'crane',
+ 'vestment',
+ 'dragonfly',
+ 'cash_machine',
+ 'mushroom',
+ 'jinrikisha',
+ 'water_tower',
+ 'chest',
+ 'snorkel',
+ 'sunglasses',
+ 'fly',
+ 'limousine',
+ 'black_stork',
+ 'dugong',
+ 'sports_car',
+ 'water_jug',
+ 'suspension_bridge',
+ 'ox',
+ 'ice_lolly',
+ 'turnstile',
+ 'christmas_stocking',
+ 'broom',
+ 'scorpion',
+ 'wooden_spoon',
+ 'picket_fence',
+ 'rugby_ball',
+ 'sewing_machine',
+ 'steel_arch_bridge',
+ 'persian_cat',
+ 'refrigerator',
+ 'barn',
+ 'apron',
+ 'yorkshire_terrier',
+ 'swimming_trunks',
+ 'stopwatch',
+ 'lawn_mower',
+ 'thatch',
+ 'fountain',
+ 'black_widow',
+ 'bikini',
+ 'plate',
+ 'teddy',
+ 'barbershop',
+ 'confectionery',
+ 'beach_wagon',
+ 'scoreboard',
+ 'orange',
+ 'flagpole',
+ 'american_lobster',
+ 'trolleybus',
+ 'drumstick',
+ 'dumbbell',
+ 'brass',
+ 'bow_tie',
+ 'convertible',
+ 'bighorn',
+ 'orangutan',
+ 'american_alligator',
+ 'centipede',
+ 'syringe',
+ 'go-kart',
+ 'brain_coral',
+ 'sea_slug',
+ 'cliff_dwelling',
+ 'mashed_potato',
+ 'viaduct',
+ 'military_uniform',
+ 'pomegranate',
+ 'chain',
+ 'kimono',
+ 'comic_book',
+ 'trilobite',
+ 'bison',
+ 'pole',
+ 'boa_constrictor',
+ 'poncho',
+ 'bathtub',
+ 'grasshopper',
+ 'walking_stick',
+ 'chihuahua',
+ 'tailed_frog',
+ 'lion',
+ 'altar',
+ 'obelisk',
+ 'beaker',
+ 'bell_pepper',
+ 'bannister',
+ 'bucket',
+ 'magnetic_compass',
+ 'meat_loaf',
+ 'gondola',
+ 'standard_poodle',
+ 'acorn',
+ 'lifeboat',
+ 'binoculars',
+ 'cauliflower',
+ 'african_elephant'
+]
diff --git a/datasets/transforms/denormalization.py b/datasets/transforms/denormalization.py
index c5cea852..edd4e951 100644
--- a/datasets/transforms/denormalization.py
+++ b/datasets/transforms/denormalization.py
@@ -4,6 +4,11 @@
# LICENSE file in the root directory of this source tree.
+import PIL
+import numpy as np
+import torch
+
+
class DeNormalize(object):
def __init__(self, mean, std):
"""
@@ -13,19 +18,36 @@ def __init__(self, mean, std):
mean (list): List of mean values for each channel.
std (list): List of standard deviation values for each channel.
"""
+ if isinstance(mean, (list, tuple)):
+ mean = torch.tensor(mean)
+ elif isinstance(mean, np.ndarray):
+ mean = torch.from_numpy(mean)
+ if isinstance(std, (list, tuple)):
+ std = torch.tensor(std)
+ elif isinstance(std, np.ndarray):
+ std = torch.from_numpy(std)
+
self.mean = mean
self.std = std
- def __call__(self, tensor):
+ def __call__(self, tensor: torch.Tensor | PIL.Image.Image):
"""
Applies denormalization to the input tensor.
Args:
- tensor (Tensor): Tensor image of size (C, H, W) to be denormalized.
+ tensor (Tensor): Tensor of images of size ([B,] C, H, W) to be denormalized.
Returns:
- Tensor: Denormalized image.
+ Tensor: Denormalized tensor.
"""
- for t, m, s in zip(tensor, self.mean, self.std):
- t.mul_(s).add_(m)
- return tensor
+ if isinstance(tensor, PIL.Image.Image):
+ tensor = torch.tensor(np.array(tensor).transpose(2, 0, 1)).float()
+
+ if tensor.ndimension() == 3:
+ tensor = tensor.unsqueeze(0)
+
+ if tensor.device != self.mean.device:
+ self.mean = self.mean.to(tensor.device)
+ self.std = self.std.to(tensor.device)
+
+ return (tensor * self.std[:, None, None]) + self.mean[:, None, None]
diff --git a/datasets/utils/__init__.py b/datasets/utils/__init__.py
index dbf74218..10f2e367 100644
--- a/datasets/utils/__init__.py
+++ b/datasets/utils/__init__.py
@@ -37,7 +37,13 @@ def set_default_from_args(arg_name: str):
DEFAULT_ARGS[caller_name] = {}
def decorator_set_default_from_args(func):
- DEFAULT_ARGS[caller_name][arg_name] = func(None)
+ n_args = len(inspect.signature(func).parameters)
+ if arg_name in DEFAULT_ARGS[caller_name]:
+ raise ValueError(f"Argument `{arg_name}` already has a default value in `{caller_name}`")
+ if n_args == 1: # has self
+ DEFAULT_ARGS[caller_name][arg_name] = func(None)
+ else:
+ DEFAULT_ARGS[caller_name][arg_name] = func()
@functools.wraps(func)
def wrapper(*args):
diff --git a/datasets/utils/continual_dataset.py b/datasets/utils/continual_dataset.py
index 7289ac89..dcaa4f43 100644
--- a/datasets/utils/continual_dataset.py
+++ b/datasets/utils/continual_dataset.py
@@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
+import logging
import sys
-from typing import Tuple
+from typing import List, Tuple
import torch
import numpy as np
@@ -16,6 +17,7 @@
from datasets.utils.validation import get_validation_indexes
from utils.conf import create_seeded_dataloader
from datasets.utils import DEFAULT_ARGS
+from utils.prompt_templates import templates
class ContinualDataset(object):
@@ -29,6 +31,8 @@ class ContinualDataset(object):
N_TASKS (int): the number of tasks
N_CLASSES (int): the number of classes
SIZE (Tuple[int]): the size of the dataset
+ AVAIL_SCHEDS (List[str]): the available schedulers
+ class_names (List[str]): list of the class names of the dataset (should be populated by `get_class_names`)
train_loader (DataLoader): the training loader
test_loaders (List[DataLoader]): the test loaders
i (int): the current task
@@ -43,6 +47,7 @@ class ContinualDataset(object):
N_CLASSES: int
SIZE: Tuple[int]
AVAIL_SCHEDS = ['multisteplr']
+ class_names: List[str] = None
def __init__(self, args: Namespace) -> None:
"""
@@ -53,7 +58,6 @@ def __init__(self, args: Namespace) -> None:
"""
self.train_loader = None
self.test_loaders = []
- self.i = 0
self.c_task = -1
self.args = args
if self.SETTING == 'class-il':
@@ -66,15 +70,10 @@ def __init__(self, args: Namespace) -> None:
if not hasattr(self.args, 'class_order'): # set only once
if self.args.seed is not None:
np.random.seed(self.args.seed)
- if isinstance(self.N_CLASSES_PER_TASK, int):
- self.args.class_order = np.random.permutation(self.N_CLASSES_PER_TASK * self.N_TASKS)
- else:
- self.args.class_order = np.random.permutation(sum(self.N_CLASSES_PER_TASK))
-
- if self.args.validation:
- self._c_seed = self.args.seed if self.args.seed is not None else torch.initial_seed()
+ self.args.class_order = np.random.permutation(self.N_CLASSES)
if args.joint:
+ assert self.SETTING in ['class-il', 'task-il'], 'Joint training is only supported for class-il and task'
self.N_CLASSES_PER_TASK = self.N_CLASSES
self.N_TASKS = 1
@@ -100,7 +99,7 @@ def update_default_args(self):
setattr(self.args, k, v)
else:
if getattr(self.args, k) != v:
- print('Warning: {} set to {} instead of {}.'.format(k, getattr(self.args, k), v), file=sys.stderr)
+ logging.warning('{} set to {} instead of {}.'.format(k, getattr(self.args, k), v))
return self.args
@@ -122,6 +121,8 @@ def get_offsets(self, task_idx: int = None):
start_c = self.N_CLASSES_PER_TASK * task_idx if isinstance(self.N_CLASSES_PER_TASK, int) else sum(self.N_CLASSES_PER_TASK[:task_idx])
end_c = self.N_CLASSES_PER_TASK * (task_idx + 1) if isinstance(self.N_CLASSES_PER_TASK, int) else sum(self.N_CLASSES_PER_TASK[:task_idx + 1])
+ assert end_c > start_c, 'End class index must be greater than start class index.'
+
return start_c, end_c
def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]:
@@ -160,7 +161,7 @@ def get_denormalization_transform() -> nn.Module:
def get_scheduler(model, args: Namespace, reload_optim=True) -> torch.optim.lr_scheduler._LRScheduler:
"""
Returns the scheduler to be used for the current dataset.
- If `reload_optim` is True, the optimizer is reloaded from the model. This should be done at least ONCE every task
+ If `reload_optim` is True, the optimizer is reloaded from the model. This should be done at least ONCE every task
to ensure that the learning rate is reset to the initial value.
"""
if args.lr_scheduler is not None:
@@ -197,6 +198,17 @@ def get_minibatch_size(self):
"""Returns the minibatch size to be used for the current dataset."""
return self.get_batch_size()
+ def get_class_names(self) -> List[str]:
+ """Returns the class names for the current dataset."""
+ raise NotImplementedError('The dataset does not implement the method `get_class_names` to get the class names.')
+
+ def get_prompt_templates(self) -> List[str]:
+ """
+ Returns the prompt templates for the current dataset.
+ By default, it returns the ImageNet prompt templates.
+ """
+ return templates['imagenet']
+
def _get_mask_unlabeled(train_dataset, setting: ContinualDataset):
if setting.args.label_perc == 1:
@@ -242,6 +254,9 @@ def store_masked_loaders(train_dataset: Dataset, test_dataset: Dataset,
Returns:
the training and test loaders
"""
+ if setting.SETTING == 'task-il' or setting.SETTING == 'class-il':
+ setting.c_task += 1
+
if not isinstance(train_dataset.targets, np.ndarray):
train_dataset.targets = np.array(train_dataset.targets)
if not isinstance(test_dataset.targets, np.ndarray):
@@ -260,16 +275,18 @@ def store_masked_loaders(train_dataset: Dataset, test_dataset: Dataset,
train_dataset.data = train_dataset.data[train_idxs]
train_dataset.targets = train_dataset.targets[train_idxs]
+ start_c, end_c = setting.get_offsets()
+
if setting.SETTING == 'class-il' or setting.SETTING == 'task-il':
- train_mask = np.logical_and(train_dataset.targets >= setting.i,
- train_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK)
+ train_mask = np.logical_and(train_dataset.targets >= start_c,
+ train_dataset.targets < end_c)
if setting.args.validation_mode == 'current':
- test_mask = np.logical_and(test_dataset.targets >= setting.i,
- test_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK)
+ test_mask = np.logical_and(test_dataset.targets >= start_c,
+ test_dataset.targets < end_c)
elif setting.args.validation_mode == 'complete':
test_mask = np.logical_and(test_dataset.targets >= 0,
- test_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK)
+ test_dataset.targets < end_c)
else:
raise ValueError('Unknown validation mode: {}'.format(setting.args.validation_mode))
@@ -288,7 +305,21 @@ def store_masked_loaders(train_dataset: Dataset, test_dataset: Dataset,
setting.test_loaders.append(test_loader)
setting.train_loader = train_loader
- if setting.SETTING == 'task-il' or setting.SETTING == 'class-il':
- setting.i += setting.N_CLASSES_PER_TASK
- setting.c_task += 1
return train_loader, test_loader
+
+
+def fix_class_names_order(class_names: List[str], args: Namespace) -> List[str]:
+ """
+ Permutes the order of the class names according to the class order specified in the arguments.
+ The order reflects that of `store_masked_loaders`.
+
+ Args:
+ class_names: the list of class names. This should contain all classes in the dataset (not just the current task's ones).
+ args: the command line arguments
+
+ Returns:
+ List[str]: the class names in the correct order
+ """
+ if args.permute_classes:
+ class_names = [class_names[np.where(args.class_order == i)[0][0]] for i in range(len(class_names))]
+ return class_names
diff --git a/datasets/utils/validation.py b/datasets/utils/validation.py
index 3f34ba3c..52ec6e60 100644
--- a/datasets/utils/validation.py
+++ b/datasets/utils/validation.py
@@ -48,10 +48,11 @@ def __getitem__(self, index):
return img, target
+
def get_validation_indexes(validation_size: float, dataset: Dataset, seed=None) -> Tuple[Dataset, Dataset]:
"""
Returns the indexes of train and validation datasets from the given dataset, according to the validation size.
-
+
Args:
validation_size (float): percentage of samples for each class to be used for validation (between 0 and 100)
dataset (Dataset): the dataset to split
@@ -74,12 +75,13 @@ def get_validation_indexes(validation_size: float, dataset: Dataset, seed=None)
idxs = torch.randperm(n_samples, generator=torch.Generator().manual_seed(seed)).numpy()
val_idxs.append(cls_idxs[idxs[:n_samples_val]])
train_idxs.append(cls_idxs[idxs[n_samples_val:]])
-
+
train_idxs = np.concatenate(train_idxs)
val_idxs = np.concatenate(val_idxs)
return train_idxs, val_idxs
+
def get_train_val(train: Dataset, test_transform: nn.Module,
dataset: str, val_perc: float = 0.1):
"""
diff --git a/docs/Makefile b/docs/Makefile
index 7a34cc2c..faeed095 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -15,6 +15,7 @@ help:
.PHONY: help Makefile
clean:
+ rm models/*_args.rst
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# rm -rf generated
# rm -rf **/generated
diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst
index bdd36329..6b212bac 100644
--- a/docs/_templates/custom-module-template.rst
+++ b/docs/_templates/custom-module-template.rst
@@ -2,11 +2,14 @@
.. currentmodule:: {{ fullname }}
+.. include:: ../models/{{ name }}_args.rst
+
.. automodule:: {{ fullname }}
{% block attributes %}
{% if attributes %}
- .. rubric:: {{ _('Module Attributes') }}
+ Module Attributes
+ ~~~~~~~~~~~~~~~~~~
{% for item in attributes %}
.. autoattribute:: {{ item }}
@@ -17,7 +20,8 @@
{% block classes %}
{% if classes %}
- .. rubric:: {{ _('Classes') }}
+ Classes
+ ~~~~~~~~
{% for item in classes %}
.. autoclass:: {{ item }}
@@ -30,8 +34,9 @@
{% block functions %}
{% if functions %}
- .. rubric:: {{ _('Functions') }}
-
+ Functions
+ ~~~~~~~~~~
+
{% for item in functions %}
.. autofunction:: {{ item }}
:members:
@@ -47,7 +52,8 @@
{% block exceptions %}
{% if exceptions %}
- .. rubric:: {{ _('Exceptions') }}
+ Exceptions
+ ~~~~~~~~~~
{% for item in exceptions %}
.. autoexception:: {{ item }}
@@ -70,7 +76,7 @@
.. toctree::
:hidden:
{% for item in modules | reorder_modules %}
- {{ item }}
+ {{ item | get_item_name }} <{{ item }}>
{%- endfor %}
{% endif %}
diff --git a/docs/conf.py b/docs/conf.py
index 9fb72e22..92f218bb 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -135,11 +135,12 @@ def get_headling_module(fullname):
fs = open(f'./docs/{ module }/{ name }.rst', 'r').read()
return fs
else:
- if os.path.isdir('./' + '/'.join(paths)):
- name = name.capitalize()
+ ref_name = name
+ if os.path.isdir('./' + '/'.join(paths)): # or os.path.dirname(module).lower() in ['models', 'datasets', 'utils', 'backbone']:
+ name = name.replace('_', ' ')
else:
- name = name
- return f".. _module-{name}:\n{name}\n" + "=" * (len(name) + 1)
+ name = name.upper().replace('_', ' ')
+ return f".. _module-{ref_name}:\n{name}\n" + "=" * (len(name) + 1)
def reorder_modules(modules):
@@ -178,3 +179,4 @@ def parse_toctree_name(item):
FILTERS["has_items"] = has_items
FILTERS["drop_torch_items"] = drop_torch_items
FILTERS["get_attributes"] = get_attributes
+FILTERS["get_item_name"] = lambda x: x.split('.')[-1].lower()
diff --git a/docs/datasets/index.rst b/docs/datasets/index.rst
index 412126f7..ef46f955 100644
--- a/docs/datasets/index.rst
+++ b/docs/datasets/index.rst
@@ -40,11 +40,16 @@ each dataset **must statically define** all the necessary information to run a c
- **get_scheduler** static method (``callable``): returns the learning rate scheduler to use during train. *By default*, it also initializes the optimizer. This prevents errors due to the learning rate being continouosly reduced task after task. This behavior can be changed setting the argument ``reload_optim=False``.
-See :ref:`continual_dataset` for more details or **SequentialCIFAR10** in :ref:`seq_cifar10` for an example.
+.. admonition:: Optional methods to implement:
+ - **get_prompt_templates** (``callable``): returns the prompt templates for the dataset. This method is expected for some methods (e.g., `clip`). *By default*, it returns the ImageNet prompt templates.
+
+ - **get_class_names** (``callable``): returns the class names for the dataset. This method is not implemented by default, but is expected for some methods (e.g., `clip`). The method *should* populate the **class_names** attribute of the dataset to cache the result and call the ``fix_class_names_order`` method to ensure that the class names are in the correct order.
+
+See :ref:`module-continual_dataset` for more details or **SequentialCIFAR10** in :ref:`module-seq_cifar10` for an example.
.. note::
Datasets are downloaded by default in the **data** folder. You can change this
- default location by setting the **base_path** function in :ref:`conf`.
+ default location by setting the **base_path** function in :ref:`module-conf`.
.. _settings:
@@ -90,9 +95,9 @@ This is done with the **set_default_from_args** decorator, which takes the name
Steps to create a new dataset:
------------------------------
-All datasets must inherit from the **ContinualDataset** class, which is defined in :ref:`continual_dataset`. The only
-exception are datasets that follow the `general-continual` setting, which inherit from the **GCLDataset** class, (defined in :ref:`gcl_dataset`).
-These classes provide some useful methods to create data loaders and store masked data loaders for continual learning experiments. See more in section :ref:`Utils`.
+All datasets must inherit from the **ContinualDataset** class, which is defined in :ref:`module-continual_dataset`. The only
+exception are datasets that follow the `general-continual` setting, which inherit from the **GCLDataset** class, (defined in :ref:`module-gcl_dataset`).
+These classes provide some useful methods to create data loaders and store masked data loaders for continual learning experiments. See more in section :ref:`module-utils`.
1. Create a new file in the `datasets` folder, e.g. ``my_dataset.py``.
@@ -102,7 +107,7 @@ These classes provide some useful methods to create data loaders and store maske
.. tip::
For convenience, most datasets are initially created with all classes and then masked appropriately by the **store_masked_loaders** function.
- For example, in :ref:`seq_cifar10` the **get_data_loaders** function of **SequentialCIFAR10** dataset first inizializes the **MyCIFAR10** and **TCIFAR10**
+ For example, in :ref:`module-seq_cifar10` the **get_data_loaders** function of **SequentialCIFAR10** dataset first inizializes the **MyCIFAR10** and **TCIFAR10**
datasets with train and test data for all classes respectively, and then masks the data loaders to return only the data for the current task.
.. important::
@@ -118,7 +123,7 @@ Utils
- **get_data_loaders**: This function should take care of downloading the dataset if necessary, make sure that it contains samples and labels for
**only** the current task (you can use the **store_masked_loaders** function), and create the data loaders.
-- **store_masked_loaders**: This function is defined in :ref:`continual_dataset` and takes care of masking the data loaders to return only the data for the current task.
+- **store_masked_loaders**: This function is defined in :ref:`module-continual_dataset` and takes care of masking the data loaders to return only the data for the current task.
It is used by most datasets to create the data loaders for each task.
- If the ``--permute_classes`` flag is set to ``1``, it also applies the appropriate permutation to the classes before splitting the data.
diff --git a/docs/getting_started/checkpoints.rst b/docs/getting_started/checkpoints.rst
index dc38b9a6..bdca76b6 100644
--- a/docs/getting_started/checkpoints.rst
+++ b/docs/getting_started/checkpoints.rst
@@ -3,10 +3,19 @@ Load and save checkpoints
Loading and saving checkpoints is handeled automatically in :ref:`module-training` by supplying the ``--savecheck`` and ``--loadcheck`` arguments.
-For example, to save a checkpoint after training, simply run the following command:
+For example, to save a checkpoint after the end of the last task, simply run the following command:
.. code-block:: python
- python utils/main.py --savecheck=1 --model=sgd --dataset=seq-cifar10 --lr=0.1
+ python utils/main.py --savecheck=last --model=sgd --dataset=seq-cifar10 --lr=0.1
+
+Other options for ``--savecheck`` are:
+
+- ``last``: save the checkpoint after **the last task**.
+- ``task``: save the checkpoint after **each task**.
+
+.. note::
+
+ The ``last`` and ``task`` options have the same effect when training with ``--joint``.
This will save the checkpoint in the ``checkpoints`` folder. To load the checkpoint, simply run the following command:
.. code-block:: python
diff --git a/docs/getting_started/fast_training.rst b/docs/getting_started/fast_training.rst
index 95d2f05b..56c5c17e 100644
--- a/docs/getting_started/fast_training.rst
+++ b/docs/getting_started/fast_training.rst
@@ -17,7 +17,7 @@ Mammoth provides a number of optimizations to speed up training. These are **dis
- It may not give a significant speedup for small models.
Distributed training
-====================
+--------------------
Mammoth supports distributed training via `DataParallel `_. To use it, simply pass the `--distributed=dp` argument to ``utils/main.py``. This will automatically use all available GPUs on the machine using the **make_dp** function in :ref:`module-distributed`.
diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst
index d9ba6c76..ce9fc110 100644
--- a/docs/getting_started/index.rst
+++ b/docs/getting_started/index.rst
@@ -35,3 +35,4 @@ Mammoth includes a few tests to ensure that the code is working as expected for
pytest --verbose tests
+The tests are quite long, as they evaluate most of the functionality of Mammoth. The estimated runtime is around 1 hour on a RTX 4080 GPU.
diff --git a/docs/how_to_run/starprompt.rst b/docs/how_to_run/starprompt.rst
new file mode 100644
index 00000000..787fa355
--- /dev/null
+++ b/docs/how_to_run/starprompt.rst
@@ -0,0 +1,98 @@
+How to run STAR-Prompt
+======================
+
+.. important::
+
+ You can find the complete paper at this `link `_. The hyperparameters reported in Tab. D and E are the ones used to obtain the results in the paper. Here we report the best hyperparameters we found for each dataset after a more thorough search. The results are very similar to the ones reported in the paper.
+
+STAR-Prompt
+-----------
+
+The most important hyperparameters for STAR-Prompt are a combination of those of the first and second stage (detailed below). The most important ones are:
+
+- ``lambda_ortho_first_stage``: the weight of the orthogonality loss for the first stage.
+- ``lambda_ortho_second_stage``: the weight of the orthogonality loss for the second stage.
+- ``learning_rate_gr_first_stage``: the learning rate of the Generative Replay for the first stage.
+- ``learning_rate_gr_second_stage``: the learning rate of the Generative Replay for the second stage.
+- ``num_epochs_gr_first_stage``: the number of epochs for the Generative Replay for the first stage.
+- ``num_epochs_gr_second_stage``: the number of epochs for the Generative Replay for the second stage.
+
+The best configurations can be found in the tables below by merging the tables of the first and second stage. The only difference is that the number of epochs for the first stage is set as ``--first_stage_epochs`` (by default, is set as ``--n_epochs``).
+
+.. note::
+
+ In the paper we report the results with 3 different choices of random seeds: ``1993``, ``1996``, and ``1997``. We to not report the seed in the commands below for brevity but the seed can be set with the ``--seed`` argument.
+
+First stage only
+~~~~~~~~~~~~~~~~
+
+In the following we report the commands to run the *first stage* of STAR-Prompt on the different datasets.
+
+The most important Hyperparameters are:
+
+* ``lambda_ortho_first_stage``: the weight of the orthogonality loss. :math:`\lambda` in the main paper (Alg 1, Tab D, E).
+* ``learning_rate_gr_first_stage``: the learning rate of the Generative Replay. :math:`lr` in the main paper (Alg 1, Tab D, E).
+* ``num_epochs_gr_first_stage``: the number of epochs for the Generative Replay. :math:`E_1` in the main paper (Alg 1, Tab D, E).
+
+Other hyperparameters such as ``gr_mog_n_iters`` and ``num_monte_carlo_gr`` have a much smaller impact. Here are reported the best configurations, but the default ones already give pretty much the same results.
+
+.. list-table:: Hyperparameter table
+ :header-rows: 1
+
+ * - Dataset
+ - Command
+ * - EuroSAT-RGB
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=5 --gr_mog_n_iters_first_stage=200 --lambda_ortho_first_stage=30 --dataset=seq-eurosat-rgb``
+ * - CropDisease
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=5 --lambda_ortho_first_stage=30 --learning_rate_gr_first_stage=0.01 --dataset=seq-cropdisease``
+ * - Resisc45
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=30 --lambda_ortho_first_stage=10 --dataset=seq-resisc45``
+ * - CIFAR-100
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=20 --lambda_ortho_first_stage=10 --num_monte_carlo_gr_first_stage=1 --dataset=seq-cifar100-224``
+ * - Imagenet-R
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=20 --gr_mog_n_iters_first_stage=200 --lambda_ortho_first_stage=30 --dataset=seq-imagenet-r``
+ * - ISIC
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=30 --lambda_ortho_first_stage=5 --num_epochs_gr_first_stage=50 --learning_rate_gr_first_stage=0.01 --dataset=seq-isic``
+ * - ChestX
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=10 --lambda_ortho_first_stage=30 --dataset=seq-chestx``
+ * - CUB-200
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=50 --lambda_ortho_first_stage=30 --num_epochs_gr_first_stage=50 --num_monte_carlo_gr_first_stage=5 --dataset=seq-cub200``
+ * - Cars-196
+ - ``--model=first_stage_starprompt --lr=0.002 --n_epochs=50 --lambda_ortho_first_stage=30 --learning_rate_gr_first_stage=0.01 --dataset=seq-cars196``
+
+Second stage only
+~~~~~~~~~~~~~~~~~
+
+The *second stage* of STAR-Prompt can take either the class-specific embeddings learned during the first stage or the pre-existing templates of CLIP. This is controlled by the ``--keys_ckpt_path`` argument. If supplied (see :ref:`module-second_stage_starprompt`), it will load the pre-trained embeddings from the first stage. If not supplied, it will use the pre-existing templates of CLIP. The most important Hyperparameters are:
+
+* ``lambda_ortho_second_stage``: the weight of the orthogonality loss. :math:`\lambda` in the main paper (Alg 1, Tab D, E).
+* ``learning_rate_gr_first_stage``: the learning rate of the Generative Replay. :math:`lr` in the main paper (Alg 1, Tab D, E).
+* ``num_epochs_gr_second_stage``: the number of epochs for the Generative Replay. :math:`E_2` in the main paper (Alg 1, Tab D, E).
+
+.. important::
+
+ Remember to set the ``--keys_ckpt_path`` argument to the path of the checkpoint of the first stage. Otherwise, the second stage will not be able to load the class-specific embeddings and will use the pre-existing templates of CLIP.
+
+.. list-table:: Hyperparameter table
+ :header-rows: 1
+
+ * - Dataset
+ - Command
+ * - ISIC
+ - ``--model=second_stage_starprompt --lr=0.001 --optimizer=adam --n_epochs=30 --num_epochs_gr_second_stage=50 --num_monte_carlo_gr_second_stage=5 --learning_rate_gr_second_stage=0.01 --dataset=seq-isic --lambda_ortho_second_stage=50 --keys_ckpt_path=``
+ * - CUB-200
+ - ``--model=second_stage_starprompt --dataset=seq-cub200 --n_epochs=50 --lr=0.001 --optimizer=adam --lambda_ortho_second_stage=30 --learning_rate_gr_second_stage=0.01 --num_monte_carlo_gr_second_stage=5``
+ * - Imagenet-R
+ - ``--model=second_stage_starprompt --optimizer=adam --dataset=seq-imagenet-r --batch_size=16 --n_epochs=5 --lr=0.001 --lambda_ortho_second_stage=10 --learning_rate_gr_second_stage=0.001``
+ * - CIFAR-100
+ - ``--model=second_stage_starprompt --dataset=seq-cifar100-224 --n_epochs=20 --lr=0.001 --optimizer=adam --lambda_ortho_second_stage=2 --learning_rate_gr_second_stage=0.001``
+ * - ChestX
+ - ``--model=second_stage_starprompt --dataset=seq-chestx --n_epochs=30 --lr=0.001 --optimizer=adam --lambda_ortho_second_stage=5 --learning_rate_gr_second_stage=0.05 --num_monte_carlo_gr_second_stage=1``
+ * - CropDisease
+ - ``--model=second_stage_starprompt --optimizer=adam --dataset=seq-cropdisease --lr=0.001 --lambda_ortho_second_stage=5 --learning_rate_gr_second_stage=0.001 --num_monte_carlo_gr_second_stage=5 --num_epochs_gr_second_stage=10``
+ * - Cars-196
+ - ``--model=second_stage_starprompt --dataset=seq-cars196 --n_epochs=50 --lr=0.001 --optimizer=adam --lambda_ortho_second_stage=10 --learning_rate_gr_second_stage=0.01``
+ * - Resisc45
+ - ``--model=second_stage_starprompt --lr=0.001 --optimizer=adam --dataset=seq-resisc45 --n_epochs=30 --lambda_ortho_second_stage=5 --learning_rate_gr_second_stage=0.01 --num_monte_carlo_gr_second_stage=1 --num_epochs_gr_second_stage=50``
+ * - Cars-196
+ - ``--model=second_stage_starprompt --num_monte_carlo_gr_second_stage=2 --optimizer=adam --dataset=seq-eurosat-rgb --lr=0.001 --lambda_ortho_second_stage=5.0 --learning_rate_gr_second_stage=0.1``
\ No newline at end of file
diff --git a/docs/how_to_upgrade/index.rst b/docs/how_to_upgrade/index.rst
new file mode 100644
index 00000000..ef341da8
--- /dev/null
+++ b/docs/how_to_upgrade/index.rst
@@ -0,0 +1,29 @@
+Upgrading to the new Mammoth
+============================
+
+The new Mammoth is almost a complete rewrite of the old Mammoth. The new Mammoth is faster, more efficient (thanks to the ``--code_optimization``), more stable (thanks to tests), supports more validations strategies and settings, and includes more methods and datasets.
+
+Models
+------
+To upgrade your model to the new Mammoth, you need to take some care:
+
+- The *Continual Model* already supports widely used properties such as `current_task`, `n_tasks`, `num_classes`. Check the documentation in :ref:`module-continual_model` for more information.
+- The *get_parser* has been moved **inside** the model. This is to make it easier to automatically load the arguments of a model in the case of automated parsing. This is easy to fix, just move the `get_parser` function inside the model class and make it a static method. *NOTE*: you do not need to add `add_experiment_args` and `add_management_args` to the get_parser function. These are automatically added.
+- The *observe* function should follow the new signature: `def observe(inputs, labels, not_aug_inputs, epoch=None) -> dict|float`. If a `dict` is returned, it should contain at least the `loss` key. All other values will be logged in WandB (if available).
+
+Datasets
+--------
+The datasets had only some minor changes. Just ensure to defined for each dataset the following properties:
+
+- `NAME`: the name of the dataset. This will be used to load the dataset with `--dataset=`.
+- `SETTING`: the setting supported by the dataset. See :ref:`module-datasets` for more information.
+- `N_CLASSES_PER_TASK`: the number of classes per task. This can be either a single value or a list of values (one for each task).
+- `N_TASKS`: the number of tasks.
+- `N_CLASSES`: if missing, it will be computed from `N_CLASSES_PER_TASK` and `N_TASKS`.
+- `SIZE`: the size of each input dimension (*i.e.*, height and width as a tuple for images).
+- `MEAN` and `STD` for normalization.
+- `TRANSFORM`: the train transform.
+- `TEST_TRANSFORM`: the test transform.
+
+Take a look at :ref:`module-seq_cifar10` for more information on how to define a dataset.
+
diff --git a/docs/index.rst b/docs/index.rst
index dbdeee1a..335f85a0 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -39,4 +39,12 @@
generated/backbone.rst
generated/utils.rst
+.. toctree::
+ :glob:
+ :maxdepth: 1
+ :hidden:
+ :caption: How to run:
+
+ how_to_run/starprompt.rst
+
.. include:: readme.rst
diff --git a/docs/models/index.rst b/docs/models/index.rst
index 6ba50f33..7a8dd413 100644
--- a/docs/models/index.rst
+++ b/docs/models/index.rst
@@ -69,7 +69,7 @@ Besides the **observe** and **forward** methods, the **ContinualModel** provides
Automatic attributes
~~~~~~~~~~~~~~~~~~~~
-The base class **ContinualModel** provides a few properties that are automatically set during the incremental training (see :ref:`continual_model` for more details). The most important attributes are:
+The base class **ContinualModel** provides a few properties that are automatically set during the incremental training (see :ref:`module-continual_model` for more details). The most important attributes are:
.. admonition:: Task-related attributes:
@@ -112,7 +112,7 @@ The base class **ContinualModel** provides a few properties that are automatical
- **args**: the arguments passed to the framework.
.. note::
- The automatic conversion between `PIL `_ and `kornia `_ is handeled by the **to_kornia_transform** function in :ref:`kornia_utils`, which converts (*most*) PIL transforms to kornia transforms. However, not all the transforms are supported, and thus this function *may not be always available*. If you want to use a custom transform, you have to extend the **to_kornia_transform** function.
+ The automatic conversion between `PIL `_ and `kornia `_ is handeled by the **to_kornia_transform** function in :ref:`module-kornia_utils`, which converts (*most*) PIL transforms to kornia transforms. However, not all the transforms are supported, and thus this function *may not be always available*. If you want to use a custom transform, you have to extend the **to_kornia_transform** function.
Model parameters
~~~~~~~~~~~~~~~~~
diff --git a/docs/readme.rst b/docs/readme.rst
index 8bad0be3..8773ac8d 100644
--- a/docs/readme.rst
+++ b/docs/readme.rst
@@ -9,10 +9,12 @@ Welcome to Mammoth's documentation!
Mammoth - An Extendible (General) Continual Learning Framework for Pytorch
==========================================================================
-Official repository of `Class-Incremental Continual Learning into the eXtended DER-verse `_ and `Dark Experience for General Continual Learning: a Strong, Simple Baseline `_
+Official repository of `Class-Incremental Continual Learning into the eXtended DER-verse `_, `Dark Experience for General Continual Learning: a Strong, Simple Baseline `_, and `Semantic Residual Prompts for Continual Learning `_.
-Mammoth is a framework for continual learning research. It is designed to be modular, easy to extend, and - most importantly - *easy to debug*.
-Idelly, all the code necessary to run the experiments is included *in the repository*, without needing to check out other repositories or install additional packages.
+Mammoth is a framework for continual learning research. With **40 methods and 21 datasets**, it includes the most complete list competitors and benchmarks for research purposes.
+
+The core idea of Mammoth is that it is designed to be modular, easy to extend, and - most importantly - *easy to debug*.
+Ideally, all the code necessary to run the experiments is included *in the repository*, without needing to check out other repositories or install additional packages.
With Mammoth, nothing is set in stone. You can easily add new models, datasets, training strategies, or functionalities.
@@ -58,61 +60,27 @@ With Mammoth, nothing is set in stone. You can easily add new models, datasets,
Setup
-----
+- Install with ``pip install -r requirements.txt``.
- Use ``./utils/main.py`` to run experiments.
-- Use argument ``--load_best_args`` to use the best hyperparameters from the paper.
- New models can be added to the ``models/`` folder.
- New datasets can be added to the ``datasets/`` folder.
+.. note::
+ **Pytorch version >=2.1.0 is required for scaled_dot_product_attention** (see: https://github.com/Lightning-AI/litgpt/issues/763). If you cannot support this version, the slower base version (see `backbone/vit.py`).
+
Models
------
-- Efficient Lifelong Learning with A-GEM: (A-GEM), and A-GEM with Reservoir buffer (A-GEM-R)
-- Bias Correction (BiC)
-- Continual Contrastive Interpolation Consistency (CCIC) - *Requires* ``pip install kornia``
-- CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning (CODA-Prompt) - *Requires* ``pip install timm==0.9.8``
-- Dark Experience Replay (DER)
-- Dark Experience Replay++ (DER++)
-- DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning (DualPrompt) - *Requires* ``pip install timm==0.9.8``
-- Experience Replay (ER)
-- online Elastic Weight Consolidation (oEWC)
-- Function Distance Regularization (FDR)
-- Greedy Sampler and Dumb Learner (GDumb)
-- Gradient Episodic Memory (GEM) - *Unavailable on windows*
-- Greedy gradient-based Sample Selection (GSS)
-- Hindsight Anchor Learning (HAL)
-- Incremental Classifier and Representation Learning (iCaRL)
-- Joint for `General Continual`` setting (JointGCL)
-- Learning to Prompt (L2P) - *Requires* ``pip install timm==0.9.8``
-- LiDER (on DER++, iCaRL, GDumb, and ER-ACE)
-- Learning a Unified Classifier Incrementally via Rebalancing (LUCIR)
-- Learning without Forgetting (LwF)
-- Meta-Experience Replay (MER)
-- Progressive Neural Networks (PNN)
-- Regular Polytope Classifier (RPC)
-- Synaptic Intelligence (SI)
-- SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model (SLCA) - *Requires* ``pip install timm==0.9.8``
-- Transfer without Forgetting (TwF)
-- eXtended-DER (X-DER)
+Mammoth currently supports **more than 40 models**, with new releases covering the main competitors in literature.
Datasets
--------
-**NOTE**: Datasets are automatically downloaded in the ``data/``.
-- This can be changed by changing the ``base_path`` function in ``utils/conf.py``.
-- The ``data/`` folder is not tracked by git and is created automatically if missing.
-
-- Sequential MNIST (*Class-Il / Task-IL*)
-- Sequential CIFAR-10 (*Class-Il / Task-IL*)
-- Sequential Tiny ImageNet (*Class-Il / Task-IL*)
-- Sequential Tiny ImageNet resized 32x32 (*Class-Il / Task-IL*)
-- Sequential CIFAR-100 (*Class-Il / Task-IL*)
-- Sequential CIFAR-100 resized 224x224 (ViT version) (*Class-Il / Task-IL*)
-- Sequential CIFAR-100 resized 224x224 (ResNet50 version) (*Class-Il / Task-IL*)
-- Permuted MNIST (*Domain-IL*)
-- Rotated MNIST (*Domain-IL*)
-- MNIST-360 (*General Continual Learning*)
-- Sequential CUB-200 (*Class-Il / Task-IL*)
-- Sequential ImageNet-R (*Class-Il / Task-IL*)
+**NOTE**: Datasets are automatically downloaded in ``data/``.
+- This can be changes by changing the ``base_path`` function in ``utils/conf.py`` or using the ``--base_path`` argument.
+- The ``data/`` folder should not tracked by git and is craeted automatically if missing.
+
+Mammoth includes **21** datasets, covering *toy classification problems* (different versions of MNIST), *standard domains* (CIFAR, Imagenet-R, TinyImagenet, MIT-67), *fine-grained classification domains* (Cars-196, CUB-200), *aerial domains* (EuroSAT-RGB, Resisc45), *medical domains* (CropDisease, ISIC, ChestX).
Pretrained backbones
--------------------
@@ -120,96 +88,4 @@ Pretrained backbones
- `ResNet18 on cifar100 `_
- `ResNet18 on TinyImagenet resized (seq-tinyimg-r) `_
- `ResNet50 on ImageNet (pytorch version) `_
-- `ResNet18 on SVHN `_
-
-Citing these works
-------------------
-
-.. code-block:: bibtex
-
- @article{boschini2022class,
- title={Class-Incremental Continual Learning into the eXtended DER-verse},
- author={Boschini, Matteo and Bonicelli, Lorenzo and Buzzega, Pietro and Porrello, Angelo and Calderara, Simone},
- journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
- year={2022},
- publisher={IEEE}
- }
-
- @inproceedings{buzzega2020dark,
- author = {Buzzega, Pietro and Boschini, Matteo and Porrello, Angelo and Abati, Davide and Calderara, Simone},
- booktitle = {Advances in Neural Information Processing Systems},
- editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
- pages = {15920--15930},
- publisher = {Curran Associates, Inc.},
- title = {Dark Experience for General Continual Learning: a Strong, Simple Baseline},
- volume = {33},
- year = {2020}
- }
-
-Awesome Papers using Mammoth
-----------------------------
-
-Our Papers
-~~~~~~~~~~~
-
-- `Dark Experience for General Continual Learning: a Strong, Simple Baseline (NeurIPS 2020) `_
-- `Rethinking Experience Replay: a Bag of Tricks for Continual Learning (ICPR 2020) `_ (`code `_)
-- `Class-Incremental Continual Learning into the eXtended DER-verse (TPAMI 2022) `_
-- `Effects of Auxiliary Knowledge on Continual Learning (ICPR 2022) `_
-- `Transfer without Forgetting (ECCV 2022) `_ (`code `_)
-- `Continual semi-supervised learning through contrastive interpolation consistency (PRL 2022) `_ (`code `_)
-- `On the Effectiveness of Lipschitz-Driven Rehearsal in Continual Learning (NeurIPS 2022) `_ (`code `_)
-
-Other Awesome CL works using Mammoth
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. important::
-
- **Get in touch if we missed your awesome work!**
-
-`- Prediction Error-based Classification for Class-Incremental Learning (ICLR2024) <(https://arxiv.org/pdf/2305.18806)>`_ (`code <(https://github.com/michalzajac-ml/pec)>`_)
-`- TriRE: A Multi-Mechanism Learning Paradigm for Continual Knowledge Retention and Promotion (NeurIPS2023) <(https://arxiv.org/pdf/2310.08217.pdf)>`_ (`code <(https://github.com/NeurAI-Lab/TriRE)>`_)
-`- Overcoming Recency Bias of Normalization Statistics in Continual Learning: Balance and Adaptation (NeurIPS2023) <(https://arxiv.org/pdf/2310.08855.pdf)>`_ (`code <(https://github.com/lvyilin/AdaB2N)>`_)
-`- A Unified and General Framework for Continual Learning (ICLR2024) <(https://arxiv.org/pdf/2403.13249.pdf)>`_ (`code <(https://github.com/joey-wang123/CL-refresh-learning)>`_)
-`- Decoupling Learning and Remembering: a Bilevel Memory Framework with Knowledge Projection for Task-Incremental Learning (CVPR2023) <(https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Decoupling_Learning_and_Remembering_A_Bilevel_Memory_Framework_With_Knowledge_CVPR_2023_paper.pdf)>`_ (`code <(https://github.com/SunWenJu123/BMKP)>`_)
-`- Regularizing Second-Order Influences for Continual Learning (CVPR2023) <(https://openaccess.thecvf.com/content/CVPR2023/papers/Sun_Regularizing_Second-Order_Influences_for_Continual_Learning_CVPR_2023_paper.pdf)>`_ (`code <(https://github.com/feifeiobama/InfluenceCL)>`_)
-`- Sparse Coding in a Dual Memory System for Lifelong Learning (CVPR2023) <(https://arxiv.org/pdf/2301.05058.pdf)>`_ (`code <(https://github.com/NeurAI-Lab/SCoMMER)>`_)
-`- A Unified Approach to Domain Incremental Learning with Memory: Theory and Algorithm (CVPR2023) <(https://arxiv.org/pdf/2310.12244.pdf)>`_ (`code <(https://github.com/Wang-ML-Lab/unified-continual-learning)>`_)
-`- A Multi-Head Model for Continual Learning via Out-of-Distribution Replay (CVPR2023) <(https://arxiv.org/pdf/2208.09734.pdf)>`_ (`code <(https://github.com/k-gyuhak/MORE)>`_)
-`- Preserving Linear Separability in Continual Learning by Backward Feature Projection (CVPR2023) <(https://arxiv.org/pdf/2303.14595.pdf)>`_ (`code <(https://github.com/rvl-lab-utoronto/BFP)>`_)
-`- Complementary Calibration: Boosting General Continual Learning With Collaborative Distillation and Self-Supervision (TIP2023) <(https://ieeexplore.ieee.org/document/10002397)>`_ (`code <(https://github.com/lijincm/CoCa)>`_)
-`- Continual Learning by Modeling Intra-Class Variation (TMLR2023) <(https://arxiv.org/abs/2210.05398)>`_ (`code <(https://github.com/yulonghui/MOCA)>`_)
-`- ConSlide: Asynchronous Hierarchical Interaction Transformer with Breakup-Reorganize Rehearsal for Continual Whole Slide Image Analysis (ICCV2023) <(https://openaccess.thecvf.com/content/ICCV2023/papers/Huang_ConSlide_Asynchronous_Hierarchical_Interaction_Transformer_with_Breakup-Reorganize_Rehearsal_for_Continual_ICCV_2023_paper.pdf)>`_ (`code <(https://github.com/HKU-MedAI/ConSlide)>`_)
-`- CBA: Improving Online Continual Learning via Continual Bias Adaptor (ICCV2023) <(https://arxiv.org/pdf/2308.06925.pdf)>`_ (`code <(https://github.com/wqza/CBA-online-CL)>`_)
-`- Neuro-Symbolic Continual Learning: Knowledge, Reasoning Shortcuts and Concept Rehearsal (ICML2023) <(https://arxiv.org/pdf/2302.01242.pdf)>`_ (`code <(https://github.com/ema-marconato/NeSy-CL)>`_)
-`- Pretrained Language Model in Continual Learning: a Comparative Study (ICLR2022) <(https://openreview.net/pdf?id=figzpGMrdD)>`_ (`code <(https://github.com/wutong8023/PLM4CL)>`_)
-`- Representational continuity for unsupervised continual learning (ICLR2022) <(https://openreview.net/pdf?id=9Hrka5PA7LW)>`_ (`code <(https://github.com/divyam3897/UCL)>`_)
-`- Continual Normalization: Rethinking Batch Normalization for Online Continual Learning (ICLR2022) <(https://arxiv.org/abs/2203.16102)>`_ (`code <(https://github.com/phquang/Continual-Normalization)>`_)
-`- Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System (ICLR2022) <(https://arxiv.org/pdf/2201.12604.pdf)>`_ (`code <(https://github.com/NeurAI-Lab/CLS-ER)>`_)
-`- New Insights on Reducing Abrupt Representation Change in Online Continual Learning (ICLR2022) <(https://openreview.net/pdf?id=N8MaByOzUfb)>`_ (`code <(https://github.com/pclucas14/AML)>`_)
-`- Looking Back on Learned Experiences for Class/Task Incremental Learning (ICLR2022) <(https://openreview.net/pdf?id=RxplU3vmBx)>`_ (`code <(https://github.com/MozhganPourKeshavarz/Cost-Free-Incremental-Learning)>`_)
-`- Task Agnostic Representation Consolidation: a Self-supervised based Continual Learning Approach (CoLLAs2022) <(https://arxiv.org/pdf/2207.06267.pdf)>`_ (`code <(https://github.com/NeurAI-Lab/TARC)>`_)
-`- Consistency is the key to further Mitigating Catastrophic Forgetting in Continual Learning (CoLLAs2022) <(https://arxiv.org/pdf/2207.04998.pdf)>`_ (`code <(https://github.com/NeurAI-Lab/ConsistencyCL)>`_)
-`- Self-supervised models are continual learners (CVPR2022) <(https://arxiv.org/abs/2112.04215)>`_ (`code <(https://github.com/DonkeyShot21/cassle)>`_)
-`- Learning from Students: Online Contrastive Distillation Network for General Continual Learning (IJCAI2022) <(https://www.ijcai.org/proceedings/2022/0446.pdf)>`_ (`code <(https://github.com/lijincm/OCD-Net)>`_)
-
-Contributing
-------------
-
-Pull requests welcome!
-
-Please use `autopep8` with parameters:
-
-- `--aggressive`
-- `--max-line-length=200`
-- `--ignore=E402`
-
-Previous versions
------------------
-
-If you're interested in a version of this repo that only includes the original code for `Dark Experience for General Continual Learning: a Strong, Simple Baseline `_ or `Class-Incremental Continual Learning into the eXtended DER-verse `_, please use the following tags:
-
-- `neurips2020 `_ for DER (NeurIPS 2020).
-
-- `tpami2023 `_ for X-DER (TPAMI 2023).
-
+- `ResNet18 on SVHN `_
\ No newline at end of file
diff --git a/docs/utils/args.rst b/docs/utils/args.rst
index 528dd912..30c9d7e3 100644
--- a/docs/utils/args.rst
+++ b/docs/utils/args.rst
@@ -9,148 +9,116 @@ Arguments
*Arguments used to define the experiment settings.*
-**\-\-dataset** :
- *Help*: Which dataset to perform experiments on.
+**\-\-dataset** : str
+ *Help*: Which dataset to perform experiments on.
- - Default: None
+ - *Default*: ``None``
+ - *Choices*: ``seq-tinyimg, seq-mit67, seq-cars196, seq-cifar100-224-rs, seq-cifar100-224, seq-chestx, seq-cifar10-224-rs, mnist-360, seq-cropdisease, seq-eurosat-rgb, seq-imagenet-r, seq-cifar100, seq-cifar10-224, perm-mnist, seq-cub200, seq-cifar10, rot-mnist, seq-resisc45, seq-mnist, seq-isic, seq-tinyimg-r``
+**\-\-model** : custom_str_underscore
+ *Help*: Model name.
- - Choices: mnist-360, perm-mnist, rot-mnist, seq-cifar10, seq-cifar100, seq-cifar100-224, seq-cifar100-224-rs, seq-cifar10-224, seq-cifar10-224-rs, seq-cub200, seq-imagenet-r, seq-mnist, seq-tinyimg, seq-tinyimg-r
-**\-\-model** :
- *Help*: Model name.
+ - *Default*: ``None``
+ - *Choices*: ``joint-gcl, second-stage-starprompt, lwf-mc, gdumb-lider, ewc-on, xder, hal, sgd, si, first-stage-starprompt, icarl, lucir, fdr, icarl-lider, derpp, der, derpp-lider, gem, bic, attriclip, starprompt, coda-prompt, clip, pnn, er-ace, xder-ce, dualprompt, twf, mer, er-ace-lider, gdumb, l2p, ccic, slca, agem-r, rpc, xder-rpc, gss, lwf, cgil, er, agem``
+**\-\-lr** : float
+ *Help*: Learning rate.
- - Default: None
+ - *Default*: ``None``
+**\-\-batch_size** : int
+ *Help*: Batch size.
- - Choices: agem, agem-r, bic, ccic, coda-prompt, der, derpp, derpp-lider, dualprompt, er, er-ace, er-ace-lider, ewc-on, fdr, gdumb, gdumb-lider, gem, gss, hal, icarl, icarl-lider, joint-gcl, l2p, lucir, lwf, lwf-mc, mer, pnn, rpc, sgd, si, slca, twf, xder, xder-ce, xder-rpc
-**\-\-lr** :
- *Help*: Learning rate.
+ - *Default*: ``None``
+**\-\-label_perc** : float
+ *Help*: Percentage in (0-1] of labeled examples per task.
- - Default: None
+ - *Default*: ``1``
+**\-\-joint** : int
+ *Help*: Train model on Joint (single task)?
- - Choices:
-**\-\-batch_size** :
- *Help*: Batch size.
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-eval_future** : int
+ *Help*: Evaluate future tasks?
- - Default: None
-
- - Choices:
-**\-\-label_perc** :
- *Help*: Percentage in (0-1] of labeled examples per task.
-
- - Default: 1
-
- - Choices:
-**\-\-joint** :
- *Help*: Train model on Joint (single task)?
-
- - Default: 0
-
- - Choices: 0, 1
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
.. rubric:: Validation and fitting arguments
*Arguments used to define the validation strategy and the method used to fit the model.*
-**\-\-validation** :
- *Help*: Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set.
-
- - Default: None
-
- - Choices:
-**\-\-validation_mode** :
- *Help*: Mode used for validation. Must be used in combination with `validation` argument. Possible values: - `current`: uses only the current task for validation (default). - `complete`: uses data from both current and past tasks for validation.
-
- - Default: current
-
- - Choices: complete, current
-**\-\-fitting_mode** :
- *Help*: Strategy used for fitting the model. Possible values: - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument. - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument. - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument). The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops.
-
- - Default: epochs
-
- - Choices: epochs, iters, time, early_stopping
-**\-\-early_stopping_patience** :
- *Help*: Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`.
+**\-\-validation** : float
+ *Help*: Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set.
- - Default: 5
+ - *Default*: ``None``
+**\-\-validation_mode** : str
+ *Help*: Mode used for validation. Must be used in combination with `validation` argument. Possible values: - `current`: uses only the current task for validation (default). - `complete`: uses data from both current and past tasks for validation.
- - Choices:
-**\-\-early_stopping_metric** :
- *Help*: Metric used for early stopping. Used only if `fitting_mode=early_stopping`.
+ - *Default*: ``current``
+ - *Choices*: ``complete, current``
+**\-\-fitting_mode** : str
+ *Help*: Strategy used for fitting the model. Possible values: - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument. - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument. - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument). The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops.
- - Default: loss
+ - *Default*: ``epochs``
+ - *Choices*: ``epochs, iters, time, early_stopping``
+**\-\-early_stopping_patience** : int
+ *Help*: Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`.
- - Choices: loss, accuracy
-**\-\-early_stopping_freq** :
- *Help*: Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`.
+ - *Default*: ``5``
+**\-\-early_stopping_metric** : str
+ *Help*: Metric used for early stopping. Used only if `fitting_mode=early_stopping`.
- - Default: 1
+ - *Default*: ``loss``
+ - *Choices*: ``loss, accuracy``
+**\-\-early_stopping_freq** : int
+ *Help*: Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`.
- - Choices:
-**\-\-early_stopping_epsilon** :
- *Help*: Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`.
+ - *Default*: ``1``
+**\-\-early_stopping_epsilon** : float
+ *Help*: Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`.
- - Default: 1e-06
+ - *Default*: ``1e-06``
+**\-\-n_epochs** : int
+ *Help*: Number of epochs. Used only if `fitting_mode=epochs`.
- - Choices:
-**\-\-n_epochs** :
- *Help*: Number of epochs. Used only if `fitting_mode=epochs`.
+ - *Default*: ``None``
+**\-\-n_iters** : int
+ *Help*: Number of iterations. Used only if `fitting_mode=iters`.
- - Default: None
-
- - Choices:
-**\-\-n_iters** :
- *Help*: Number of iterations. Used only if `fitting_mode=iters`.
-
- - Default: None
-
- - Choices:
+ - *Default*: ``None``
.. rubric:: Optimizer and learning rate scheduler arguments
*Arguments used to define the optimizer and the learning rate scheduler.*
-**\-\-optimizer** :
- *Help*: Optimizer.
-
- - Default: sgd
-
- - Choices: sgd, adam, adamw
-**\-\-optim_wd** :
- *Help*: optimizer weight decay.
-
- - Default: 0.0
+**\-\-optimizer** : str
+ *Help*: Optimizer.
- - Choices:
-**\-\-optim_mom** :
- *Help*: optimizer momentum.
+ - *Default*: ``sgd``
+ - *Choices*: ``sgd, adam, adamw``
+**\-\-optim_wd** : float
+ *Help*: optimizer weight decay.
- - Default: 0.0
+ - *Default*: ``0.0``
+**\-\-optim_mom** : float
+ *Help*: optimizer momentum.
- - Choices:
-**\-\-optim_nesterov** :
- *Help*: optimizer nesterov momentum.
+ - *Default*: ``0.0``
+**\-\-optim_nesterov** : int
+ *Help*: optimizer nesterov momentum.
- - Default: 0
+ - *Default*: ``0``
+**\-\-lr_scheduler** : str
+ *Help*: Learning rate scheduler.
- - Choices:
-**\-\-lr_scheduler** :
- *Help*: Learning rate scheduler.
+ - *Default*: ``None``
+**\-\-lr_milestones** : int
+ *Help*: Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).
- - Default: None
+ - *Default*: ``[]``
+**\-\-sched_multistep_lr_gamma** : float
+ *Help*: Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).
- - Choices:
-**\-\-lr_milestones** :
- *Help*: Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).
-
- - Default: []
-
- - Choices:
-**\-\-sched_multistep_lr_gamma** :
- *Help*: Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).
-
- - Default: 0.1
-
- - Choices:
+ - *Default*: ``0.1``
.. rubric:: MANAGEMENT ARGS
@@ -158,151 +126,118 @@ Arguments
*Generic arguments to manage the experiment reproducibility, logging, debugging, etc.*
-**\-\-seed** :
- *Help*: The random seed. If not provided, a random seed will be used.
-
- - Default: None
+**\-\-seed** : int
+ *Help*: The random seed. If not provided, a random seed will be used.
- - Choices:
-**\-\-permute_classes** :
- *Help*: Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.
+ - *Default*: ``None``
+**\-\-permute_classes** : int
+ *Help*: Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.
- - Default: 0
+ - *Default*: ``1``
+ - *Choices*: ``0, 1``
+**\-\-base_path** : str
+ *Help*: The base path where to save datasets, logs, results.
- - Choices: 0, 1
-**\-\-base_path** :
- *Help*: The base path where to save datasets, logs, results.
+ - *Default*: ``./data/``
+**\-\-device** : str
+ *Help*: The device (or devices) available to use for training. More than one device can be specified by separating them with a comma. If not provided, the code will use the least used GPU available (if there are any), otherwise the CPU. MPS is supported and is automatically used if no GPU is available and MPS is supported. If more than one GPU is available, Mammoth will use the least used one if `--distributed=no`.
- - Default: ./data/
+ - *Default*: ``None``
+**\-\-notes** : str
+ *Help*: Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results
- - Choices:
-**\-\-notes** :
- *Help*: Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results
+ - *Default*: ``None``
+**\-\-eval_epochs** : int
+ *Help*: Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task.
- - Default: None
+ - *Default*: ``None``
+**\-\-non_verbose** : int
+ *Help*: Make progress bars non verbose
- - Choices:
-**\-\-eval_epochs** :
- *Help*: Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task.
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-disable_log** : int
+ *Help*: Disable logging?
- - Default: None
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-num_workers** : int
+ *Help*: Number of workers for the dataloaders (default=infer from number of cpus).
- - Choices:
-**\-\-non_verbose** :
- *Help*: Make progress bars non verbose
+ - *Default*: ``None``
+**\-\-enable_other_metrics** : int
+ *Help*: Enable computing additional metrics: forward and backward transfer.
- - Default: 0
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-debug_mode** : int
+ *Help*: Run only a few training steps per epoch. This also disables logging on wandb.
- - Choices: 0, 1
-**\-\-disable_log** :
- *Help*: Disable logging?
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-inference_only** : int
+ *Help*: Perform inference only for each task (no training).
- - Default: 0
+ - *Default*: ``0``
+ - *Choices*: ``0, 1``
+**\-\-code_optimization** : int
+ *Help*: Optimization level for the code.0: no optimization.1: Use TF32, if available.2: Use BF16, if available.3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.
- - Choices: 0, 1
-**\-\-num_workers** :
- *Help*: Number of workers for the dataloaders (default=infer from number of cpus).
+ - *Default*: ``0``
+ - *Choices*: ``0, 1, 2, 3``
+**\-\-distributed** : str
+ *Help*: Enable distributed training?
- - Default: None
+ - *Default*: ``no``
+ - *Choices*: ``no, dp, ddp``
+**\-\-savecheck** : str
+ *Help*: Save checkpoint every `task` or at the end of the training (`last`).
- - Choices:
-**\-\-enable_other_metrics** :
- *Help*: Enable computing additional metrics: forward and backward transfer.
+ - *Default*: ``None``
+ - *Choices*: ``last, task``
+**\-\-loadcheck** : str
+ *Help*: Path of the checkpoint to load (.pt file for the specific task)
- - Default: 0
+ - *Default*: ``None``
+**\-\-ckpt_name** : str
+ *Help*: (optional) checkpoint save name.
- - Choices: 0, 1
-**\-\-debug_mode** :
- *Help*: Run only a few training steps per epoch. This also disables logging on wandb.
+ - *Default*: ``None``
+**\-\-start_from** : int
+ *Help*: Task to start from
- - Default: 0
+ - *Default*: ``None``
+**\-\-stop_after** : int
+ *Help*: Task limit
- - Choices: 0, 1
-**\-\-inference_only** :
- *Help*: Perform inference only for each task (no training).
-
- - Default: 0
-
- - Choices: 0, 1
-**\-\-code_optimization** :
- *Help*: Optimization level for the code.0: no optimization.1: Use TF32, if available.2: Use BF16, if available.3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.
-
- - Default: 0
-
- - Choices: 0, 1, 2, 3
-**\-\-distributed** :
- *Help*: Enable distributed training?
-
- - Default: no
-
- - Choices: no, dp, ddp
-**\-\-savecheck** :
- *Help*: Save checkpoint?
-
- - Default: 0
-
- - Choices: 0, 1
-**\-\-loadcheck** :
- *Help*: Path of the checkpoint to load (.pt file for the specific task)
-
- - Default: None
-
- - Choices:
-**\-\-ckpt_name** :
- *Help*: (optional) checkpoint save name.
-
- - Default: None
-
- - Choices:
-**\-\-start_from** :
- *Help*: Task to start from
-
- - Default: None
-
- - Choices:
-**\-\-stop_after** :
- *Help*: Task limit
-
- - Default: None
-
- - Choices:
+ - *Default*: ``None``
.. rubric:: Wandb arguments
*Arguments to manage logging on Wandb.*
-**\-\-wandb_name** :
- *Help*: Wandb name for this run. Overrides the default name (`args.model`).
+**\-\-wandb_name** : str
+ *Help*: Wandb name for this run. Overrides the default name (`args.model`).
- - Default: None
+ - *Default*: ``None``
+**\-\-wandb_entity** : str
+ *Help*: Wandb entity
- - Choices:
-**\-\-wandb_entity** :
- *Help*: Wandb entity
+ - *Default*: ``None``
+**\-\-wandb_project** : str
+ *Help*: Wandb project name
- - Default: None
-
- - Choices:
-**\-\-wandb_project** :
- *Help*: Wandb project name
-
- - Default: mammoth
-
- - Choices:
+ - *Default*: ``None``
.. rubric:: REEHARSAL-ONLY ARGS
-**\-\-buffer_size** :
- *Help*: The size of the memory buffer.
-
- - Default: None
-
- - Choices:
+**\-\-buffer_size** : int
+ *Help*: The size of the memory buffer.
-**\-\-minibatch_size** :
- *Help*: The batch size of the memory buffer.
+ - *Default*: ``None``
- - Default: None
+**\-\-minibatch_size** : int
+ *Help*: The batch size of the memory buffer.
- - Choices:
+ - *Default*: ``None``
diff --git a/docs/utils/index.rst b/docs/utils/index.rst
index cafeab82..9bc706f0 100644
--- a/docs/utils/index.rst
+++ b/docs/utils/index.rst
@@ -40,12 +40,36 @@ Other arguments such as the size of the training batch and the number of epochs
.. note::
To ease hyper-parameter tuning, all boolean arguments follow the convention: ``--=1`` for ``True`` and ``--=0`` for ``False``.
+Reproducibility
+~~~~~~~~~~~~~~~~
+
+By default, the library does not guarantee reproducibility and seeds are set randomly. However, this can be changed by setting the seed manually.
+
+For example, to run the `er` model on the `seq-cifar10` dataset with a seed of `42`, run the following command:
+
+.. code-block:: bash
+
+ python utils/main.py --dataset seq-cifar10 --model der --buffer_size 500 --lr 0.03 --seed 42
+
+Setting the seed affects:
+
+- The random number generators in `numpy`, `torch`, and `random`.
+- The seed for all GPUs (if available). See `PyTorch's docs `_ for more informations.
+- The order of the classes in each task (and the order of the tasks themselves).
+- The random number generators in the data loaders.
+
+We do not set ``torch.use_deterministic_algorithms(True)`` by default, as it can slow down the training process and in our tests does not seem to affect results too much. However, it can be set manually in the `main.py` script if desired.
+
+.. important::
+
+ Setting the seed also influences **the order and the classes** present in each task. While this is desired in most cases, it can be disabled by setting the `permute_classes` argument to `0`.
+
Other useful arguments
~~~~~~~~~~~~~~~~~~~~~~
* ``--debug_mode``: If set to ``1``, the model will run for only a few iterations per each epoch and will disable WandB logging. This is useful for debugging.
-* ``--num_workers**: The number of workers to use for the data loaders. If set to ``0``, the data loaders will run in the main process. This is useful for debugging.
+* ``--num_workers``: The number of workers to use for the data loaders. If set to ``0``, the data loaders will run in the main process. This is useful for debugging.
* ``--seed``: The seed to use for the random number generators. If this is not set, the seed will be randomly generated.
diff --git a/models/__init__.py b/models/__init__.py
index a100ba9d..d297f396 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -6,7 +6,7 @@
import os
import sys
from argparse import Namespace
-from typing import List
+from typing import Dict, List
from torch import nn
import importlib
import inspect
@@ -69,16 +69,16 @@ def get_model_class(args: Namespace) -> ContinualModel:
return names[model_name]
-def get_model_names() -> List[str]:
+def get_model_names() -> Dict[str, ContinualModel]:
"""
- Return the list of the available continual model names.
+ Return the available continual model names and classes.
Returns:
- the list of the available continual model names
+ A dictionary containing the names of the available continual models and their classes.
"""
def _get_names():
- names = {}
+ names: Dict[str, ContinualModel] = {}
for model_name, model in get_all_models().items():
try:
mod = importlib.import_module('models.' + model)
@@ -88,7 +88,6 @@ def _get_names():
names[c.NAME.replace('_', '-')] = c
except Exception as e:
warn_once("Error in model", model)
- warn_once("\t-", e)
names[model.replace('_', '-')] = e
return names
diff --git a/models/attriclip.py b/models/attriclip.py
new file mode 100644
index 00000000..852b5d98
--- /dev/null
+++ b/models/attriclip.py
@@ -0,0 +1,81 @@
+
+
+from utils.args import *
+from models.utils.continual_model import ContinualModel
+
+from datasets import get_dataset
+import wandb
+from models.attriclip_utils.model import CoOp
+from models.attriclip_utils.utils import cosine_loss
+from utils.conf import get_device
+
+
+class Attriclip(ContinualModel):
+ NAME = 'attriclip'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser(description='Continual Learning via'
+ ' Progressive Neural Networks.')
+ parser.add_argument("--num_prompt", type=int, default=10, help='num_prompt')
+ parser.add_argument("--text_prompt", type=int, default=3, help='text_prompt')
+ parser.add_argument('--freeze_clip', type=int, default=1, help='freeze_clip')
+ parser.add_argument("--virtual_bs_n", type=int, default=1, help="virtual batch size iterations")
+ return parser
+
+ def __init__(self, backbone, loss, args, transform):
+ self.seq_dataset = get_dataset(args)
+ self.device = get_device()
+ self.class_names = self.seq_dataset.get_class_names()
+ backbone = CoOp(self.device, False, False, args)
+ offset_1, offset_2 = self.seq_dataset.get_offsets(0)
+ cur_class_names = self.class_names[offset_1:offset_2]
+ backbone.init_model(class_names=cur_class_names, text_key=backbone.text_key, text_prompt=backbone.text_prompt)
+ super().__init__(backbone, loss, args, transform)
+
+ def begin_task(self, dataset):
+ self.offset_1, self.offset_2 = self.seq_dataset.get_offsets(self.current_task)
+ self.per_epoch_steps = len(dataset.train_loader)
+ cur_class_names = self.class_names[self.offset_1:self.offset_2]
+ self.net.init_model(class_names=cur_class_names, text_key=self.net.text_key, text_prompt=self.net.text_prompt)
+ self.opt, self.custom_scheduler = self.net.get_optimizer(self.per_epoch_steps)
+ self.net.model.eval()
+ self.old_epoch = 0
+ self.idx = 0
+ self.iteration = 0
+ self.opt.zero_grad()
+
+ def observe(self, inputs, labels, not_aug_inputs, epoch=0):
+ if self.old_epoch != epoch:
+ self.idx = 0
+ self.old_epoch = epoch
+ labels = labels.long()
+
+ log_dict = {}
+ log_dict['lr'] = self.opt.param_groups[0]['lr']
+
+ cur_iter_idx = epoch * self.per_epoch_steps + self.idx
+ self.custom_scheduler.step(cur_iter_idx)
+
+ output, ima_feat, key_choose, loss_m = self.net.model(inputs)
+ loss_main = self.loss(output, labels - self.offset_1)
+ loss_k = cosine_loss(ima_feat, key_choose)
+ loss = loss_main + 0.7 * loss_k + 0.3 * loss_m
+
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+
+ self.idx += 1
+ self.iteration += 1
+
+ if not self.args.nowand:
+ wandb.log(log_dict)
+
+ return loss.item()
+
+ def forward(self, x):
+ test_classes = self.class_names[:self.offset_2]
+ logits = self.net.model(x, test_classes, test=True)
+ return logits[:, :self.offset_2]
diff --git a/models/attriclip_utils/clip/bpe_simple_vocab_16e6.txt.gz b/models/attriclip_utils/clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 00000000..7b5088a5
Binary files /dev/null and b/models/attriclip_utils/clip/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/models/attriclip_utils/clip/clip.py b/models/attriclip_utils/clip/clip.py
new file mode 100644
index 00000000..ba6d7add
--- /dev/null
+++ b/models/attriclip_utils/clip/clip.py
@@ -0,0 +1,226 @@
+import os
+import hashlib
+import urllib
+import warnings
+from typing import Any, Union, List
+
+from PIL import Image
+import torch
+from tqdm import tqdm
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+if torch.__version__.split(".") < ["1", "7", "1"]:
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name:str
+ A model name listed by `clip.available_models()", or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit: bool
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+ download_root: str
+ path to download the model files; by default, it uses "~/.cache/clip"
+
+ Returns
+ -------
+ model: torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if name in _MODELS:
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim: :Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ---------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ truncate:bool
+ whether to truncate the text in case its encoding is longer than the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/models/attriclip_utils/clip/clip_2.py b/models/attriclip_utils/clip/clip_2.py
new file mode 100644
index 00000000..ac97085a
--- /dev/null
+++ b/models/attriclip_utils/clip/clip_2.py
@@ -0,0 +1,226 @@
+import os
+import hashlib
+import urllib
+import warnings
+from typing import Any, Union, List
+
+from PIL import Image
+import torch
+from tqdm import tqdm
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+from .model_2 import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+import pdb
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+if torch.__version__.split(".") < ["1", "7", "1"]:
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name:str
+ A model name listed by `clip.available_models()", or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit: bool
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+ download_root: str
+ path to download the model files; by default, it uses "~/.cache/clip"
+
+ Returns
+ -------
+ model: torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ # pdb.set_trace()
+ if name in _MODELS:
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim: :Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ---------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ truncate:bool
+ whether to truncate the text in case its encoding is longer than the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/models/attriclip_utils/clip/model.py b/models/attriclip_utils/clip/model.py
new file mode 100644
index 00000000..d5636f7a
--- /dev/null
+++ b/models/attriclip_utils/clip/model.py
@@ -0,0 +1,444 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+import pdb
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW ->(HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+ self.use_gradient_checkpoint = False
+
+ def forward(self, x: torch.Tensor):
+
+ if self.use_gradient_checkpoint:
+ for layer_module in self.resblocks:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), x)
+ return x
+ else:
+ return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape =[*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape =:[*,grid **2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape=[*,grid**2+1,width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisionTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ # pdb.set_trace()
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
diff --git a/models/attriclip_utils/clip/model_2.py b/models/attriclip_utils/clip/model_2.py
new file mode 100644
index 00000000..9a85092c
--- /dev/null
+++ b/models/attriclip_utils/clip/model_2.py
@@ -0,0 +1,450 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+import pdb
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+import pdb
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW ->(HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ # pdb.set_trace()
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ # pdb.set_trace()
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+ self.use_gradient_checkpoint = False
+
+ def forward(self, x: torch.Tensor):
+ # pdb.set_trace()
+ if self.use_gradient_checkpoint:
+ for layer_module in self.resblocks:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), x)
+ return x
+ else:
+ return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) # [rid**2+1, width]
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ # pdb.set_trace()
+ x = self.conv1(x) # shape =[*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape =:[*,grid **2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape=[*,grid**2+1,width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+ # pdb.set_trace()
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisionTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+
+ else:
+ # pdb.set_trace()
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ # pdb.set_trace()
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+ # pdb.set_trace()
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ ) # 768, 224, 24, 1024, 14, 77, vocab_size 49408, transformer_width 768, 12, 12
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
diff --git a/models/attriclip_utils/clip/simple_tokenizer.py b/models/attriclip_utils/clip/simple_tokenizer.py
new file mode 100644
index 00000000..513b98d4
--- /dev/null
+++ b/models/attriclip_utils/clip/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-s byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on....
+ """
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord(";"), ord("-") + 1)) + list(range(ord("@"), ord("y") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v + '' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + ''
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except BaseException:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
diff --git a/models/attriclip_utils/model.py b/models/attriclip_utils/model.py
new file mode 100644
index 00000000..519b3ae2
--- /dev/null
+++ b/models/attriclip_utils/model.py
@@ -0,0 +1,280 @@
+import torch
+import torch.nn as nn
+
+from copy import deepcopy
+
+from models.attriclip_utils.clip.clip_2 import load, tokenize
+from models.attriclip_utils.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+_tokenizer = _Tokenizer()
+import time
+from models.attriclip_utils.utils import build_cosine_scheduler
+
+
+class PromptLearner(nn.Module):
+ def __init__(self, device, args, class_names, clip_model, text_prompt, n_ctx=12, prompt_pos=2):
+ super().__init__()
+ self.device = device
+ ctx_dim = clip_model.ln_final.weight.shape[0]
+ dtype = clip_model.dtype
+ self.clip_model = clip_model
+ self.args = args
+ n_cls = len(class_names)
+ self.dtype = dtype
+
+ prompt_prefix = ' '.join(['x'] * n_ctx * self.args.text_prompt)
+ prompts = [prompt_prefix + ' ' + name + '.' for name in class_names] # xxxxxx classe
+ classnames = [name.replace('_', ' ') for name in class_names]
+ self.name_lens = [len(_tokenizer.encode(name)) for name in class_names]
+ self.prompt_pos = prompt_pos
+
+ self.text_prompt = text_prompt
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts]) # conversione frase testuale in numeri
+ self.tokenized_prompts = tokenized_prompts # token
+ with torch.no_grad():
+ embedding = clip_model.token_embedding(tokenized_prompts.to(self.device)).type(self.dtype)
+ self.register_buffer('token_prefix', embedding[:, :1, :]) # prende token del SOS (start of sequence)
+ self.register_buffer('token_suffix', embedding[:, 1 + (n_ctx * self.args.text_prompt):, :]) # prende token CLS, EOS
+
+ nc_prompts = [prompt_prefix + '.'] # xxxxxxxxxxxxxxxxxxxxx.
+ nc_tokenized_prompts = torch.cat([tokenize(p) for p in nc_prompts]) # conversione della frase senza la classe
+ self.nc_tokenized_prompts = nc_tokenized_prompts
+ with torch.no_grad():
+ embedding = clip_model.token_embedding(nc_tokenized_prompts.to(self.device)).type(self.dtype)
+ self.register_buffer('nc_token_prefix', embedding[:, :1, :])
+ self.register_buffer('nc_token_suffix', embedding[:, 1 + n_ctx:, :])
+
+ self.n_cls = n_cls
+ self.n_ctx = n_ctx
+ self.ctx_dim = ctx_dim
+
+ def forward(self, indices, test_class=None, infer=False):
+ if test_class is not None:
+ prompt_prefix = ' '.join(['x'] * self.n_ctx * self.args.text_prompt)
+ prompts = [prompt_prefix + ' ' + name + '.' for name in test_class]
+ self.name_lens = [len(_tokenizer.encode(name)) for name in test_class]
+
+ self.prompt_pos = self.prompt_pos
+
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts])
+ self.tokenized_prompts = tokenized_prompts
+ with torch.no_grad():
+ embedding = self.clip_model.token_embedding(tokenized_prompts.to(self.device)).type(self.dtype)
+ self.register_buffer('token_prefix', embedding[:, :1, :]) # SOS, [n_cls, 1, ctx_dim]
+ self.register_buffer('token_suffix', embedding[:, 1 + (self.n_ctx * self.args.text_prompt):, :]) # CLS, EOS, [n_cls, -1, ctx_dim]
+ self.n_cls = len(test_class)
+ batch = indices.shape[0]
+ ctx = self.text_prompt[indices].view(batch, self.n_ctx * self.args.text_prompt, self.ctx_dim)
+ tokenized_prompts = self.tokenized_prompts.view(self.n_cls, -1)
+ n_cls = self.n_cls
+
+ if self.prompt_pos == 2:
+ prefix = self.token_prefix.unsqueeze(0).repeat(batch, 1, 1, 1)
+ suffix = self.token_suffix.unsqueeze(0).repeat(batch, 1, 1, 1)
+ ctx = ctx.unsqueeze(1).repeat(1, n_cls, 1, 1)
+ prompts = torch.cat([prefix, ctx, suffix], dim=2)
+ elif self.prompt_pos == 1:
+ prompts = []
+ half_n_ctx = self.n_ctx // 2
+ for i in range(n_cls):
+ name_len = self.name_lens[i]
+ prefix_i = self.token_prefix[i:i + 1, :, :].unsqueeze(1)
+ class_i = self.token_suffix[i:i + 1, :name_len, :].unsqueeze(1)
+ suffix_i = self.token_suffix[i:i + 1, name_len:, :].unsqueeze(1)
+ ctx_i_half1 = ctx[:, :half_n_ctx, :].unsqueeze(0)
+ ctx_i_half2 = ctx[:, half_n_ctx:, :].unsqueeze(0)
+ prompt = torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=2)
+ prompts.append(prompt)
+ prompts = torch.cat(prompts, dim=0)
+ elif self.prompt_pos == 0:
+ prompts = []
+ for i in range(self.n_cls):
+ name_len = self.name_lens[i]
+ prefix_i = self.token_prefix[i:i + 1, :, :].unsqueeze(1)
+ class_i = self.token_suffix[i:i + 1, :name_len, :].unsqueeze(1)
+ suffix_i = self.token_suffix[i:i + 1, name_len:, :].unsqueeze(1)
+ ctx_i = ctx.unsqueeze(0)
+ prompt = torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=2)
+ prompts.append(prompt)
+ prompts = torch.cat(prompts, dim=0)
+
+ prompts = prompts.squeeze(2).view(batch * self.n_cls, -1, self.ctx_dim)
+ tokenized_prompts = tokenized_prompts.unsqueeze(0).repeat(batch, 1, 1).view(batch * self.n_cls, -1)
+ self.prompts = prompts
+ self.prompts_token = tokenized_prompts
+ if infer:
+ return prompts, tokenized_prompts
+ else:
+ nc_prompts, nc_tokenized_prompts = self.only_prefix()
+ return prompts, tokenized_prompts, nc_prompts, nc_tokenized_prompts
+
+ def only_prefix(self):
+ ctx = self.text_prompt
+ prompt_size = ctx.shape[0]
+ nc_tokenized_prompts = self.nc_tokenized_prompts.repeat(prompt_size, 1)
+ prefix = self.nc_token_prefix.repeat(prompt_size, 1, 1)
+ suffix = self.nc_token_suffix.repeat(prompt_size, 1, 1)
+ nc_prompts = torch.cat([prefix, ctx, suffix], dim=1)
+ return nc_prompts, nc_tokenized_prompts
+
+
+class TextEncoder(nn.Module):
+ def __init__(self, clip_model):
+ super().__init__()
+ self.transformer = clip_model.transformer
+ self.positional_embedding = clip_model.positional_embedding
+ self.ln_final = clip_model.ln_final
+ self.text_projection = clip_model.text_projection
+ self.dtype = clip_model.dtype
+
+ def forward(self, x, tokenized_prompts):
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2)
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2)
+ x = self.ln_final(x).type(self.dtype)
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self, device, args, class_names, clip_model, text_key, text_prompt, n_ctx=12):
+ super().__init__()
+ self.n_class = len(class_names)
+ self.device = device
+ self.args = args
+
+ # text enoder
+ self.text_encoder = TextEncoder(clip_model)
+ # if torch.cuda.device_count() > 1:
+ # self.text_encoder = nn.DataParallel(self.text_encoder)
+
+ self.prompt_learner = PromptLearner(self.device, self.args, class_names, clip_model, text_prompt, n_ctx=n_ctx)
+ self.text_key = text_key
+ # image encoder
+ self.image_encoder = clip_model.visual
+ self.logit_scale = clip_model.logit_scale
+
+ def forward(self, image, test_class=None, test=False):
+
+ with torch.no_grad():
+ image_features = self.image_encoder(image.type(self.dtype))
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ image_features = image_features.detach()
+
+ if test:
+ n_test = len(test_class)
+ probability = image_features @ self.text_key.t()
+ _, indices = probability.topk(k=min(self.args.text_prompt, probability.shape[1]), dim=1, largest=True)
+ text_prompt, tokenized_prompts = self.prompt_learner(indices, test_class, test)
+ text_features = self.text_encoder(text_prompt, tokenized_prompts)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ logit_scale = self.logit_scale.exp()
+ text_features = text_features.view(image_features.shape[0], n_test, -1)
+ image_features = image_features.unsqueeze(1)
+ logit_scale = self.logit_scale.exp()
+ logits = logit_scale * (image_features * text_features).sum(-1)
+ return logits
+
+ else:
+ n_class = self.n_class
+ probability = image_features @ self.text_key.t()
+ _, indices = probability.topk(k=min(self.args.text_prompt, probability.shape[1]), dim=1, largest=True)
+ key_choose = self.text_key[indices]
+ text_prompt, tokenized_prompts, nc_prompts, nc_tokenized_prompts = self.prompt_learner(indices)
+ text_features = self.text_encoder(text_prompt, tokenized_prompts)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ text_features = text_features.view(image_features.shape[0], n_class, -1)
+ image_features = image_features.unsqueeze(1)
+ logit_scale = self.logit_scale.exp()
+ logits = logit_scale * (image_features * text_features).sum(-1)
+
+ nc_text_features = self.text_encoder(nc_prompts, nc_tokenized_prompts)
+ nc_text_features = nc_text_features / nc_text_features.norm(dim=-1, keepdim=True)
+ dis = nc_text_features @ nc_text_features.permute(1, 0)
+ loss_m = dis[~torch.eye(self.args.num_prompt, dtype=torch.bool, device=self.device)].abs().mean()
+
+ return logits, image_features, key_choose, loss_m
+
+ @property
+ def dtype(self):
+ return self.image_encoder.conv1.weight.dtype
+
+
+class CoOp:
+ def __init__(self, device, prev_key, prev_prompt, args, n_ctx=12, use_float32=False, use_grad_checkpoint=False, keep=False):
+ super().__init__()
+ self.device = device
+ self.args = args
+ clip_model, _ = load('ViT-L/14')
+ clip_model.eval()
+ if use_float32:
+ clip_model.float()
+
+ if self.args.freeze_clip:
+ for param in clip_model.parameters():
+ param.requires_grad = False
+
+ self.clip_model = clip_model
+ self.use_grad_checkpoint = use_grad_checkpoint
+ self.num_prompt = args.num_prompt
+ self.n_ctx = n_ctx
+ self.lr = args.lr * args.batch_size / 20
+ self.wd = args.optim_wd
+ self.epochs = args.n_epochs
+ self.train_batch = args.batch_size
+ self.args = args
+ dtype = clip_model.dtype
+ self.dtype = dtype
+ # prompt learner
+ ctx_dim = clip_model.ln_final.weight.shape[0]
+ text_key = torch.empty(self.num_prompt, ctx_dim, dtype=self.dtype).to(self.device)
+ nn.init.normal_(text_key, std=0.02)
+ text_prompt = torch.empty(self.num_prompt, n_ctx, ctx_dim, dtype=self.dtype).to(self.device)
+ nn.init.normal_(text_prompt, std=0.02)
+ if keep == True:
+ self.text_key = nn.Parameter(prev_key)
+ self.text_prompt = nn.Parameter(prev_prompt)
+ else:
+ self.text_key = nn.Parameter(text_key)
+ self.text_prompt = nn.Parameter(text_prompt)
+
+ def init_model(self, class_names, text_key, text_prompt):
+
+ self.n_class = len(class_names)
+ clip_model = deepcopy(self.clip_model)
+
+ self.model = CLIP(self.device, self.args, class_names, clip_model, text_key, text_prompt, self.n_ctx)
+ if self.use_grad_checkpoint:
+ try:
+ self.model.text_encoder.transformer.use_gradient_checkpoint = True
+ except BaseException:
+ self.model.text_encoder.module.transformer.use_gradient_checkpoint = True
+
+ def get_optimizer(self, per_epoch_steps):
+ Other_params = [param for name, param in self.model.named_parameters() if 'text_key' in name]
+ param_dict = [{'params': [p for p in self.model.prompt_learner.parameters() if p.requires_grad]},
+ {'params': Other_params}]
+
+ optimizer = torch.optim.SGD(param_dict, lr=self.lr, weight_decay=self.wd)
+ scheduler = build_cosine_scheduler(
+ optimizer,
+ lr=self.lr,
+ total_step=self.epochs * per_epoch_steps)
+
+ return optimizer, scheduler
+
+ @property
+ def training(self):
+ return self.model.training
+
+ def train(self, mode=True):
+ self.model.train(mode)
+
+ def eval(self):
+ self.model.eval()
+
+ def to(self, device):
+ self.model.to(device)
+
+ def parameters(self):
+ return self.model.parameters()
diff --git a/models/attriclip_utils/utils.py b/models/attriclip_utils/utils.py
new file mode 100644
index 00000000..57291d9c
--- /dev/null
+++ b/models/attriclip_utils/utils.py
@@ -0,0 +1,64 @@
+import numpy as np
+
+
+def cosine_schedule_warmup(total_step, value, final_value=0, warmup_step=0, warmup_value=0):
+ if warmup_step > 0:
+ warmup_schedule = np.linspace(warmup_value, value, warmup_step + 2)[1:-1]
+ else:
+ warmup_schedule = np.array([])
+ steps = np.arange(total_step - warmup_step)
+ schedule = final_value + 0.5 * (value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))
+ schedule = np.concatenate((warmup_schedule, schedule))
+ assert len(schedule) == total_step
+ return schedule
+
+
+class build_cosine_scheduler:
+ def __init__(self, optimizer, lr, total_step, lr_warmup_step=0):
+ init_lr = 0
+ final_lr = lr * 1e-3
+ self.lrs = cosine_schedule_warmup(total_step, lr, final_lr, lr_warmup_step, init_lr)
+ self.optimizer = optimizer
+
+ def step(self, idx):
+ lr = self.lrs[idx]
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ param_group["lr"] = lr
+ self.lr = lr
+
+
+class build_bicosine_scheduler:
+ def __init__(self, optimizer, lr, total_step, lr_warmup_step=0):
+ lr_promt = lr[0]
+ lr_conv = lr[1]
+ init_lr = 0
+ final_lr_promt = lr_promt * 1e-3
+ final_lr_conv = lr_conv * 1e-3
+ self.lrs_prompt = cosine_schedule_warmup(total_step, lr_promt, final_lr_promt, lr_warmup_step, init_lr)
+ self.lrs_conv = cosine_schedule_warmup(total_step, lr_conv, final_lr_conv, lr_warmup_step, init_lr)
+ self.optimizer = optimizer
+
+ def step(self, idx):
+ lr_promt = self.lrs_prompt[idx]
+ lr_conv = self.lrs_conv[idx]
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ # pdb.set_trace()
+ if i == 0:
+ param_group["lr"] = lr_conv
+ else:
+ param_group["lr"] = lr_promt
+ self.lr_conv = lr_conv
+ self.lr_prompt = lr_promt
+
+
+def cosine_loss(q, k):
+ # pdb.set_trace()
+ q = q.repeat(1, k.shape[1], 1)
+ # k = k.squeeze(1)
+ # q = q/q.norm(dim=-1)
+ k_norm = k.norm(dim=-1, keepdim=True)
+ # pdb.set_trace()
+ # k_norm = k.norm(dim=-1).unsqueeze(1).repeat(1,k.shape[1])
+ k = k / k_norm
+ cos = ((q * k) / (k.shape[0] * k.shape[1])).sum()
+ return 1 - cos
diff --git a/models/cgil.py b/models/cgil.py
new file mode 100644
index 00000000..85fa76ef
--- /dev/null
+++ b/models/cgil.py
@@ -0,0 +1,88 @@
+from argparse import ArgumentParser
+
+import torch
+
+from datasets import get_dataset
+from datasets.utils.continual_dataset import ContinualDataset
+from models.cgil_utils.cgil_utils import Model
+from models.utils.future_model import FutureModel
+
+
+class IncrementalCoopVAE(FutureModel):
+ NAME = 'cgil'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser()
+
+ parser.add_argument("--backbone", type=str, default='ViT-L/14', help="Clip backbone")
+ parser.add_argument("--learning_rate_alignment", type=float, default=0.05, help="Learning rate for GR.")
+ parser.add_argument("--optim_alignment", type=str, default='adamw', choices=('sgd', 'adam', 'adamw'), help="Optimizer for GR.")
+ parser.add_argument("--optim_alignment_wd", type=float, default=0, help="Weight decay for GR.")
+ parser.add_argument("--lambda_ortho_first_stage", type=float, default=1, help="Orthogonality loss coefficient for coop")
+ parser.add_argument("--num_epochs_alignment", type=int, default=30, help="Num. of epochs for GR.")
+ parser.add_argument("--batch_size_alignment", type=int, default=128, help="Batch size for alignment.")
+ parser.add_argument('--gr_mog_n_components', type=int, default=5, help="Number of components for GR with MOG.")
+ parser.add_argument('--gr_mog_n_iters', type=int, default=500, help="Number of EM iterations during fit for GR with MOG.")
+ parser.add_argument('--gr_vae_hidden_dim', type=int, default=512, help="Hidden dimension for GR with VAE.")
+ parser.add_argument('--gr_vae_latent_dim', type=int, default=256, help="Latent dimension for GR with VAE.")
+ parser.add_argument('--gr_vae_n_iters', type=int, default=500, help="Number of iterations for GR with VAE.")
+ parser.add_argument('--train_only_current_prompts', type=int, default=0, choices=(0, 1), help="Train only current prompts.")
+ parser.add_argument('--align_with_ortholoss', type=int, default=0, choices=(0, 1), help="Align with orthogonality loss.")
+ parser.add_argument('--lr_vae', type=float, default=2e-4, help="Learning rate for VAE.")
+ parser.add_argument('--general_context', type=int, default=0, help="Use general context (number of contexts created).")
+ parser.add_argument('--generated_context', type=int, default=0, help="Use generated context.")
+ parser.add_argument('--cocoop', type=int, default=0, help="Use image embedding to generate context.")
+ parser.add_argument('--combo_context', type=int, default=1, help="Use both generated and prompt context.")
+ parser.add_argument('--n_context', type=int, default=1, help="Use both generated and prompt context.")
+ parser.add_argument("--g_models", type=str, default='vae', choices=('vae', 'mog', 'gauss', "diffusion"), help="Generative model to use for alignment")
+
+ return parser
+
+ def __init__(self, backbone, loss, args, transform):
+ args.n_epochs = 0
+
+ if args.debug_mode:
+ args.num_epochs_alignment = 1
+ args.gr_mog_n_iters = 1
+ args.gr_vae_n_iters = 10
+
+ backbone = Model(args, num_classes=get_dataset(args).N_CLASSES)
+ super().__init__(backbone, loss, args, transform)
+
+ # REMOVE ALL TRACK RUNNING STATS FROM CLIP
+ for m in self.net.modules():
+ if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
+ m.track_running_stats = False
+
+ def end_task(self, dataset: ContinualDataset) -> None:
+
+ self.net.prompter.update_statistics(dataset)
+
+ self.net.prompter.align()
+
+ self.net.prompter.current_task += 1
+
+ def begin_task(self, dataset: ContinualDataset) -> None:
+
+ self.change_transform(dataset)
+
+ self.old_epoch = 0
+ self.iteration = 0
+
+ torch.cuda.empty_cache()
+
+ def change_transform(self, dataset: ContinualDataset) -> None:
+ dataset.train_loader.dataset.transform = self.net.prompter.clip_preprocess
+ dataset.test_loaders[-1].dataset.transform = self.net.prompter.clip_preprocess
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ logits = self.net(x, train=False)
+ return logits[:, :self.n_seen_classes]
+
+ def future_forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.net.future_forward(x)
+
+ def observe(self, *args, **kwargs):
+ return 0
diff --git a/models/cgil_utils/README.md b/models/cgil_utils/README.md
new file mode 100644
index 00000000..12f67b9f
--- /dev/null
+++ b/models/cgil_utils/README.md
@@ -0,0 +1,39 @@
+# How to Run
+`cd` into the root directory of mammoth.
+
+It may be required to add the root directory to the python path. if you are using bash, you can do this by running the following command:
+```bash
+export PYTHONPATH=$PYTHONPATH:/path/to/mammoth
+```
+
+In the [[paper](https://arxiv.org/abs/2407.15793)] we employ the following hyperparameters for the different datasets accross the seeds `1992`, `1996` and `1997` (we report only the seed `1992` for brevity). The hyperparameters are the same for the other seeds:
+
+- **Imagenet-R**
+
+```bash
+python utils/main.py --dataset=seq-imagenet-r --model=cgil --backbone=ViT-L/14 --lr=0.01 --g_models=vae --optim_alignment=adamw --learning_rate_alignment=0.05 --eval_future=1 --seed=1992 --combo_context=1 --gr_vae_n_iters=500 --num_epochs_alignment=60
+```
+
+- **Cars-196**
+
+```bash
+python utils/main.py --dataset=seq-cars196 --model=cgil --backbone=ViT-L/14 --lr=0.01 --g_models=vae --optim_alignment=adamw --learning_rate_alignment=0.03 --eval_future=1 --seed=1992 --combo_context=1 --gr_vae_n_iters=500 --num_epochs_alignment=60
+```
+
+- **CUB-200**
+
+```bash
+python utils/main.py --dataset=seq-cub200 --model=cgil --backbone=ViT-L/14 --lr=0.01 --g_models=vae --optim_alignment=adamw --learning_rate_alignment=0.01 --eval_future=1 --seed=1992 --combo_context=1 --gr_vae_n_iters=500 --num_epochs_alignment=60
+```
+
+- **EuroSAT-RGB**
+
+```bash
+python utils/main.py --dataset=seq-eurosat-rgb --model=cgil --backbone=ViT-L/14 --lr=0.01 --g_models=vae --optim_alignment=adamw --learning_rate_alignment=0.03 --eval_future=1 --seed=1992 --combo_context=1 --gr_vae_n_iters=500 --num_epochs_alignment=150
+```
+
+- **ISIC**
+
+```bash
+python utils/main.py --dataset=seq-isic --model=cgil --backbone=ViT-L/14 --lr=0.01 --g_models=vae --optim_alignment=adamw --learning_rate_alignment=0.05 --eval_future=1 --seed=1992 --combo_context=1 --gr_vae_n_iters=750 --num_epochs_alignment=150
+```
diff --git a/models/cgil_utils/cgil_utils.py b/models/cgil_utils/cgil_utils.py
new file mode 100644
index 00000000..270fe869
--- /dev/null
+++ b/models/cgil_utils/cgil_utils.py
@@ -0,0 +1,395 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm, trange
+
+from datasets import get_dataset
+from datasets.utils.continual_dataset import ContinualDataset
+from models.cgil_utils.diffusion import DiffusionCA
+from models.cgil_utils.generative_replay import (FeaturesDataset, Gaussian,
+ MixtureOfGaussiansModel)
+from models.cgil_utils.vae import VariationalAutoEncoderModel
+from utils.conf import get_device
+
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git")
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+
+class TextEncoder(torch.nn.Module):
+ def __init__(self, clip_model):
+ super().__init__()
+ self.transformer = clip_model.transformer
+ self.positional_embedding = clip_model.positional_embedding
+ self.ln_final = clip_model.ln_final
+ self.text_projection = clip_model.text_projection
+ self.dtype = clip_model.dtype
+
+ def forward(self, x: torch.Tensor, tokenized_prompts: torch.Tensor) -> torch.Tensor:
+ x = x + self.positional_embedding.type(self.dtype)
+ x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
+ x = self.ln_final(x).type(self.dtype)
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
+ return x
+
+
+class Model(torch.nn.Module):
+ def __init__(self, args, num_classes: int):
+ super().__init__()
+ self.args = args
+ self.num_classes = num_classes
+ self.device = get_device()
+
+ self.prompter = Prompter(args)
+
+ def train(self, mode=True):
+ super().train(False)
+ self.prompter.train(False)
+
+ return self
+
+ def forward(self, x: torch.Tensor, train: bool = True) -> torch.Tensor:
+ clip_out = self.prompter.get_query(x)
+
+ keys = self.prompter.get_keys(train=train, image_embeds=clip_out)
+
+ return self.prompter.get_clip_logits(clip_out, keys)
+
+ def future_forward(self, x: torch.Tensor) -> torch.Tensor:
+ clip_out = self.prompter.get_query(x)
+ trained_keys = self.prompter.get_keys(train=False, image_embeds=clip_out)
+ untrained_keys = self.prompter.just_text_features[len(trained_keys):]
+ keys = torch.cat((trained_keys, untrained_keys), dim=0)
+
+ return self.prompter.get_clip_logits(clip_out, keys)
+
+
+class Prompter(torch.nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.device = get_device()
+
+ self.seq_dataset = get_dataset(self.args)
+ self.num_classes = self.seq_dataset.N_CLASSES
+
+ self.clip_model, self.clip_preprocess = clip.load(args.backbone, self.device)
+ self.clip_model = self.clip_model.float()
+
+ for p in self.clip_model.parameters():
+ p.requires_grad = False
+
+ self.current_task = 0
+ self.class_names = self.seq_dataset.get_class_names()
+ self.setup_text_prompting()
+ self.clip_logit_scale = self.clip_model.logit_scale
+
+ embed_dim = self.clip_model.visual.output_dim
+
+ if self.args.g_models == 'gauss':
+ self.distributions = torch.nn.ModuleList([Gaussian(embed_dim) for _ in range(self.num_classes)]).to(self.device)
+ elif self.args.g_models == 'mog':
+ self.distributions = torch.nn.ModuleList([MixtureOfGaussiansModel(embed_dim, n_components=self.args.gr_mog_n_components,
+ n_iters=self.args.gr_mog_n_iters)
+ for _ in range(self.num_classes)]).to(self.device)
+ elif self.args.g_models == 'vae':
+ self.distributions = torch.nn.ModuleList([VariationalAutoEncoderModel(input_dim=embed_dim,
+ hidden_dim=self.args.gr_vae_hidden_dim,
+ latent_dim=self.args.gr_vae_latent_dim,
+ lr=self.args.lr_vae,
+ n_iters=self.args.gr_vae_n_iters,
+ class_idx=i)
+ for i in range(self.num_classes)])
+ elif self.args.g_models == 'diffusion':
+ self.distributions = torch.nn.ModuleList([DiffusionCA(embed_dim,
+ self.device,
+ target="img",
+ num_hidden=5,
+ hidden_dim=self.args.gr_vae_latent_dim,
+ n_iters=self.args.gr_vae_n_iters,
+ class_idx=i) for i in range(self.num_classes)])
+
+ def compute_ortho_loss(self) -> torch.Tensor:
+ """Computes the orthogonality loss between the prompt parameters.
+
+ Returns:
+ torch.Tensor: The orthogonality loss.
+ """
+ offset_1, offset_2 = self.seq_dataset.get_offsets(self.current_task)
+
+ if not self.args.train_only_current_prompts:
+ coop_p = torch.cat([getattr(self, f'prompt_parameters_{i}') for i in range(0, offset_2)], dim=0)
+ I = torch.eye(coop_p.shape[0], device=self.device, dtype=coop_p.dtype)
+ ortho_loss_coop = (coop_p @ coop_p.t() - I).pow(2).mean()
+ else:
+ cur_coop_p = torch.cat([getattr(self, f'prompt_parameters_{i}') for i in range(offset_1, offset_2)], dim=0).unsqueeze(1)
+ if self.current_task > 0:
+ past_coop_p = torch.cat([getattr(self, f'prompt_parameters_{i}').detach() for i in range(offset_1)], dim=0)
+ ortho_loss_coop = (torch.matmul(cur_coop_p.permute(1, 0, 2), past_coop_p.permute(1, 2, 0))**2).mean()
+
+ return ortho_loss_coop
+
+ @torch.no_grad()
+ def build_features_dataset(self) -> torch.utils.data.DataLoader:
+ """Builds a dataset of features and labels for the alignment task using the current distributions.
+
+ Returns:
+ torch.utils.data.DataLoader: The dataloader for the alignment task.
+ """
+
+ labels, features = [], []
+
+ for _ti in range(self.current_task + 1):
+
+ prev_t_size, cur_t_size = self.seq_dataset.get_offsets(_ti)
+
+ for class_idx in range(prev_t_size, cur_t_size):
+
+ curr_dist = self.distributions[class_idx]
+ prev_train = curr_dist.training
+ curr_dist.eval()
+ curr_dist = curr_dist.to(self.device)
+ current_samples = curr_dist(256)
+ curr_dist.train(prev_train)
+ curr_dist = curr_dist.to('cpu')
+ features.append(current_samples)
+ labels.append(torch.ones((256)) * class_idx)
+
+ features = torch.cat(features, dim=0).detach()
+ labels = torch.cat(labels, dim=0).long()
+
+ return torch.utils.data.DataLoader(FeaturesDataset(features, labels),
+ batch_size=self.args.batch_size_alignment, shuffle=True, num_workers=0)
+
+ def train_alignment_epoch(self, optim: torch.optim.Optimizer) -> None:
+ """Trains the alignment task for one epoch.
+
+ Args:
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ """
+ offset_1, offset_2 = self.seq_dataset.get_offsets(self.current_task)
+
+ data_loader = self.build_features_dataset()
+
+ for i, (image_features, labels) in enumerate(data_loader):
+ if self.args.debug_mode and i > 3:
+ break
+ optim.zero_grad()
+
+ image_features, labels = image_features.to(self.device, dtype=self.clip_model.dtype), labels.to(self.device)
+ image_features = F.normalize(image_features, dim=-1)
+ if self.args.train_only_current_prompts and self.current_task > 0:
+ with torch.no_grad():
+ past_keys = self.compute_keys(0, offset_1, image_features)
+ cur_keys = self.compute_keys(offset_1, offset_2, image_features)
+ text_features = torch.cat((past_keys.detach(), cur_keys), dim=0)
+ else:
+ text_features = self.compute_keys(0, offset_2, image_features)
+
+ text_features = F.normalize(text_features, dim=-1)
+
+ if self.args.generated_context and self.args.cocoop:
+ text_features = text_features.reshape(image_features.shape[0], -1, text_features.shape[-1])
+ image_features = image_features.unsqueeze(1)
+ clip_logits = (text_features * image_features).sum(-1)
+ else:
+ clip_logits = torch.einsum('bd,cd->bc', image_features, text_features)
+ clip_logits = clip_logits * self.clip_logit_scale.exp()
+ loss = F.cross_entropy(clip_logits, labels)
+
+ wandb_log = {'alignment_loss_ce': loss.item()}
+
+ if self.args.align_with_ortholoss and not self.args.generated_context:
+ ortho_loss = self.compute_ortho_loss()
+ loss += self.args.lambda_ortho_first_stage * ortho_loss
+ wandb_log['alignment_loss_ortho'] = ortho_loss.item()
+
+ wandb_log['alignment_loss'] = loss.item()
+ if wandb.run:
+ wandb.log(wandb_log)
+
+ loss.backward()
+ optim.step()
+
+ def align(self) -> None:
+ """Trains the alignment task for the current task."""
+ offset_1, offset_2 = self.seq_dataset.get_offsets(self.current_task)
+ if not self.args.train_only_current_prompts:
+ offset_1 = 0
+
+ if self.args.generated_context:
+ parameters = self.context_generator.parameters()
+ elif self.args.combo_context:
+ parameters = [getattr(self, f'prompt_parameters_{i}') for i in range(offset_1, offset_2)]
+ parameters += self.context_generator.parameters()
+ elif self.args.general_context == 0:
+ parameters = [getattr(self, f'prompt_parameters_{i}') for i in range(offset_1, offset_2)]
+ else:
+ parameters = [self.prompt_parameters]
+
+ if self.args.optim_alignment == 'sgd':
+ optim = torch.optim.SGD(lr=self.args.learning_rate_alignment, params=parameters, momentum=0.0, weight_decay=0.0)
+ elif self.args.optim_alignment == 'adam':
+ optim = torch.optim.Adam(lr=self.args.learning_rate_alignment, params=parameters, weight_decay=0.0)
+ elif self.args.optim_alignment == 'adamw':
+ optim = torch.optim.AdamW(lr=self.args.learning_rate_alignment, params=parameters, weight_decay=self.args.optim_alignment_wd)
+ else:
+ raise ValueError(f'Invalid optimizer: {self.args.optim_alignment}')
+
+ for _ in trange(self.args.num_epochs_alignment, desc=f'Alignment Task {self.current_task}', unit='epoch'):
+ self.train_alignment_epoch(optim)
+
+ @torch.no_grad()
+ def update_statistics(self, dataset: ContinualDataset) -> None:
+ """Fit the distributions to the features of the current task.
+
+ Args:
+ dataset (ContinualDataset): The dataset to use for updating the statistics.
+ """
+ offset_1, offset_2 = dataset.get_offsets(self.current_task)
+
+ features_dict = {i: [] for i in range(offset_1, offset_2)}
+
+ was_training = self.training
+ self.eval()
+
+ Path('./cache').mkdir(parents=True, exist_ok=True)
+ backbone = self.args.backbone.replace('/', '_')
+ cache_path = Path(f'./cache/{dataset.NAME}_{self.current_task}_{backbone}_features.pt')
+ if dataset.args.seed is not None:
+ cache_path = Path(f'./cache/{dataset.NAME}_{self.current_task}_seed_{dataset.args.seed}_{backbone}_features.pt')
+
+ if cache_path.exists():
+ features_dict = torch.load(cache_path)
+ print(f'Loaded cached features from {cache_path}')
+ else:
+ with tqdm(total=len(dataset.train_loader), desc='Updating statistics for first stage Generative Replay') as pbar:
+ for i, data in enumerate(dataset.train_loader):
+ if self.args.debug_mode == 1 and i > 3 and min([len(v) for v in features_dict.values()]) > self.args.gr_mog_n_components:
+ break
+ inputs, labels = data[0], data[1]
+ inputs, labels = inputs.to(self.device), labels.to(self.device).long()
+
+ clip_query = self.get_query(inputs)
+
+ for class_idx in labels.unique():
+ features_dict[int(class_idx)].append(clip_query[labels == class_idx])
+
+ pbar.update(1)
+ if not self.args.debug_mode:
+ torch.save(features_dict, cache_path)
+
+ for class_idx in range(offset_1, offset_2):
+ features_class_idx = torch.cat(features_dict[class_idx], dim=0)
+ self.distributions[class_idx].fit(features_class_idx.to(self.device))
+
+ if was_training:
+ self.train()
+
+ def compute_keys(self, start: int, end: int, image_embeds=None):
+ prefix = self.token_prefix[start:end]
+ suffix = self.token_suffix[start:end]
+ tokenized_prompts = self.tokenized_prompts[start:end]
+ if self.args.generated_context:
+ if self.args.cocoop:
+ ctx = self.context_generator(image_embeds).unsqueeze(1).unsqueeze(1).expand(-1, end - start, -1, -1).reshape(-1, 1, image_embeds.shape[-1])
+ prefix = prefix.unsqueeze(0).expand(image_embeds.shape[0], -1, -1, -1).reshape(-1, prefix.shape[1], prefix.shape[2])
+ suffix = suffix.unsqueeze(0).expand(image_embeds.shape[0], -1, -1, -1).reshape(-1, suffix.shape[1], suffix.shape[2])
+ tokenized_prompts = tokenized_prompts.unsqueeze(0).expand(image_embeds.shape[0], -1, -1).reshape(-1, self.tokenized_prompts.shape[-1])
+ else:
+ ctx = self.context_generator(self.just_text_features[start:end]).unsqueeze(1)
+ elif self.args.combo_context:
+ ctx = torch.cat([getattr(self, f'prompt_parameters_{i}') for i in range(start, end)], dim=0)
+ ctx = torch.cat([ctx, self.context_generator(self.just_text_features[start:end]).unsqueeze(1)], dim=1)
+ elif self.args.general_context == 0:
+ ctx = torch.cat([getattr(self, f'prompt_parameters_{i}') for i in range(start, end)], dim=0)
+ else:
+ ctx = self.prompt_parameters.unsqueeze(0).expand(end - start, -1, -1)
+ prompts = torch.cat((prefix, ctx, suffix), dim=1)
+ keys = self.text_encoder(prompts.to(self.clip_model.dtype), tokenized_prompts)
+ keys = F.normalize(keys, dim=-1)
+ return keys
+
+ def get_keys(self, train: bool = True, image_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
+ task_id = self.current_task if train else self.current_task - 1
+ offset_1, offset_2 = self.seq_dataset.get_offsets(task_id)
+ if train and self.current_task > 0:
+ with torch.no_grad():
+ past_keys = self.compute_keys(0, offset_1, image_embeds)
+ cur_keys = self.compute_keys(offset_1, offset_2, image_embeds)
+ keys = torch.cat((past_keys.detach(), cur_keys), dim=0)
+ else:
+ keys = self.compute_keys(0, offset_2, image_embeds)
+ return keys
+
+ def setup_text_prompting(self) -> None:
+ """Setup the text prompting for the model."""
+ self.text_encoder = TextEncoder(self.clip_model)
+
+ text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in self.class_names]).to(self.device)
+ with torch.no_grad():
+ self.just_text_tokens = self.clip_model.token_embedding(text_inputs)
+ self.just_text_features = self.clip_model.encode_text(text_inputs)
+
+ if self.args.generated_context or self.args.combo_context:
+ in_dim = self.just_text_features.shape[-1]
+ out_dim = self.clip_model.token_embedding.weight.shape[1]
+ self.context_generator = torch.nn.Sequential(
+ torch.nn.Linear(in_dim, in_dim),
+ torch.nn.BatchNorm1d(in_dim),
+ torch.nn.SELU(True),
+ torch.nn.Linear(in_dim, in_dim),
+ torch.nn.BatchNorm1d(in_dim),
+ torch.nn.SELU(True),
+ torch.nn.Linear(in_dim, out_dim),
+ torch.nn.BatchNorm1d(out_dim),
+ torch.nn.SELU(True),
+ torch.nn.Linear(out_dim, out_dim),
+ ).to(self.device)
+
+ n_ctx = max(self.args.n_context, 1)
+ if self.args.combo_context:
+ n_ctx += 1
+ prefix = " ".join(["X"] * n_ctx)
+ text_prompts = [prefix + " " + name + "." for name in self.class_names]
+ tokenized_prompts = torch.cat([clip.tokenize(p) for p in text_prompts], dim=0).to(self.device)
+ self.tokenized_prompts = tokenized_prompts
+
+ with torch.no_grad():
+ embedding = self.clip_model.token_embedding(tokenized_prompts).type(self.clip_model.dtype)
+ self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
+
+ if not self.args.generated_context:
+ if not self.args.general_context:
+ for i in range(self.num_classes):
+ prompt_parameter = torch.empty(1, self.args.n_context, self.clip_model.token_embedding.weight.shape[1], device=self.device, dtype=torch.float32)
+ torch.nn.init.normal_(prompt_parameter, std=0.02)
+ self.register_parameter(f"prompt_parameters_{i}", torch.nn.Parameter(prompt_parameter))
+ else:
+ prompt_parameter = torch.empty(self.args.n_context, self.clip_model.token_embedding.weight.shape[1], device=self.device, dtype=torch.float32)
+ torch.nn.init.normal_(prompt_parameter, std=0.02)
+ self.prompt_parameters = torch.nn.Parameter(prompt_parameter)
+
+ @torch.no_grad()
+ def get_query(self, x: torch.Tensor) -> torch.Tensor:
+ return self.clip_model.encode_image(x)
+
+ def get_clip_logits(self, clip_out: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ image_features = F.normalize(clip_out, dim=-1)
+ if self.args.generated_context and self.args.cocoop:
+ keys = keys.reshape(image_features.shape[0], -1, keys.shape[-1])
+ image_features = image_features.unsqueeze(1)
+ clip_logits = (keys * image_features).sum(-1)
+ else:
+ clip_logits = torch.einsum('bd,cd->bc', image_features, keys)
+ clip_logits = clip_logits * self.clip_logit_scale.exp()
+ return clip_logits
diff --git a/models/cgil_utils/diffusion.py b/models/cgil_utils/diffusion.py
new file mode 100644
index 00000000..2aed193e
--- /dev/null
+++ b/models/cgil_utils/diffusion.py
@@ -0,0 +1,195 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import tqdm
+
+
+@dataclass
+class NoiseSchedule:
+ num_steps: int
+ beta_t: torch.Tensor
+ alpha_t: torch.Tensor
+ alpha_bar_t: torch.Tensor
+ img_weight: torch.Tensor
+ noise_weight: torch.Tensor
+
+ def init(
+ self,
+ num_steps: int,
+ beta_t: torch.Tensor,
+ alpha_t: torch.Tensor,
+ alpha_bar_t: torch.Tensor,
+ img_weight: torch.Tensor,
+ noise_weight: torch.Tensor,
+ ) -> None:
+ self.num_steps = num_steps
+ self.beta_t = beta_t
+ self.alpha_t = alpha_t
+ self.alpha_bar_t = alpha_bar_t
+ self.img_weight = img_weight
+ self.noise_weight = noise_weight
+
+ def to(self, device: torch.device):
+ self.beta_t = self.beta_t.to(device)
+ self.alpha_t = self.alpha_t.to(device)
+ self.alpha_bar_t = self.alpha_bar_t.to(device)
+ self.img_weight = self.img_weight.to(device)
+ self.noise_weight = self.noise_weight.to(device)
+ return self
+
+
+def get_cosine_schedule(num_steps: int, s: float = 0, exp: int = 2):
+ alpha_bar_t = torch.cos(
+ (torch.linspace(1, num_steps, steps=num_steps) / num_steps + s)
+ / (1 + s)
+ * math.pi
+ / 2
+ ).pow(exp)
+ beta_t = 1 - alpha_bar_t / (alpha_bar_t.roll(1))
+ beta_t[0] = max(2 * beta_t[1] - beta_t[2], 0)
+ alpha_t = 1 - beta_t
+ img_weight = torch.sqrt(alpha_bar_t)
+ noise_weight = torch.sqrt(1 - alpha_bar_t)
+ return NoiseSchedule(
+ num_steps, beta_t, alpha_t, alpha_bar_t, img_weight, noise_weight
+ )
+
+
+def sinusoidal_embedding(
+ index: torch.Tensor,
+ embedding_dim: int,
+ num_training_steps: int,
+ device: torch.device,
+) -> torch.Tensor:
+ assert len(index.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(num_training_steps) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=device) * -emb)
+ emb = index.unsqueeze(1) * emb.unsqueeze(0).to(device)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ return emb
+
+
+class MLPDiffusion(torch.nn.Module):
+ def __init__(self, embed_dim, hidden_dim, num_hidden, num_steps, device) -> None:
+ super().__init__()
+ self.silu = torch.nn.SiLU()
+
+ self.fc_input = torch.nn.Linear(embed_dim, hidden_dim)
+ self.hidden_layers = torch.nn.ModuleList(
+ [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden)]
+ )
+ self.fc_output = torch.nn.Linear(hidden_dim, embed_dim)
+ self.fc_emb1 = torch.nn.Linear(hidden_dim, hidden_dim)
+ self.fc_emb2 = torch.nn.Linear(hidden_dim, hidden_dim)
+
+ self.num_steps = num_steps
+ self.device = device
+ self.hidden_dim = hidden_dim
+
+ def forward(self, x, timestep):
+ t_emb = sinusoidal_embedding(timestep, self.hidden_dim, self.num_steps, self.device)
+ t_emb = self.silu(self.fc_emb1(t_emb))
+ t_emb = self.fc_emb2(t_emb)
+
+ x = self.silu(self.fc_input(x))
+ for layer in self.hidden_layers:
+ x = self.silu(layer(x) + t_emb)
+ return self.fc_output(x)
+
+
+class DiffusionCA(torch.nn.Module):
+
+ @staticmethod
+ def q_function(x, schedule, timestep=None):
+ if timestep is None:
+ timestep = torch.randint(low=0, high=schedule.num_steps, size=(x.shape[0],))
+
+ noise = torch.randn_like(x) * torch.randn_like(x)
+
+ noise_weight = schedule.noise_weight[timestep].reshape(-1, 1)
+ img_weight = schedule.img_weight[timestep].reshape(-1, 1)
+
+ return x * img_weight + noise * noise_weight, noise, timestep
+
+ def __init__(self, embed_dim, device, hidden_dim=256, num_hidden=1, diffusion_steps=32, greediness=2e-3, target="img", lr=1e-3, n_iters=1000, class_idx=0):
+ super().__init__()
+ self.n_iters = n_iters
+ self.lr = lr
+ self.class_idx = class_idx
+ self.embed_dim = embed_dim
+ self.hidden_dim = hidden_dim
+ self.device = device
+ self.schedule = get_cosine_schedule(diffusion_steps).to(device)
+ self.net = MLPDiffusion(embed_dim, hidden_dim, num_hidden, self.schedule.num_steps, device)
+ self.net = self.net.to(device)
+ self.signal_to_noise_ratio = torch.log(self.schedule.img_weight / self.schedule.noise_weight).clamp(1)
+ self.pred_weight = torch.linspace(1, 1 + greediness, steps=self.schedule.num_steps).to(device)
+ self.target = target
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=lr, weight_decay=1e-2)
+
+ @torch.enable_grad()
+ def fit(self, x):
+ self.min = torch.min(x, dim=0).values
+ self.max = torch.max(x, dim=0).values
+ x = (x - self.min) / (self.max - self.min)
+
+ self.mu = torch.mean(x, dim=0)
+ self.std = torch.std(x, dim=0)
+ x = (x - self.mu) / self.std
+
+ ds = torch.utils.data.TensorDataset(x)
+ self.train()
+ loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
+ iters = 0
+ with tqdm.trange(self.n_iters, desc=f"Training Diffusion [{self.class_idx}]") as pbar:
+ for epoch in pbar:
+ for data in loader:
+ self.optimizer.zero_grad()
+ inputs = data[0]
+ noisy_inputs, noise, timestep = self.q_function(inputs, self.schedule)
+ outputs = self.net(noisy_inputs, timestep.to(self.device) + 1)
+ if self.target == "noise":
+ loss = torch.mean((outputs - noise) ** 2)
+ else:
+ loss = (self.signal_to_noise_ratio[timestep] * torch.mean((outputs - inputs) ** 2, -1)).mean()
+
+ loss.backward()
+ self.optimizer.step()
+ pbar.set_postfix(loss=loss.item())
+ iters += 1
+ if iters >= self.n_iters:
+ break
+ if iters >= self.n_iters:
+ break
+ pbar.update(1)
+ self.eval()
+
+ def sample(self, n_sample, resample_period=1):
+ x = torch.randn(n_sample, self.embed_dim).to(self.device)
+ noise = torch.randn_like(x)
+ if self.target == "noise":
+ for i in reversed(range(self.schedule.num_steps)):
+ out = self.net(x, torch.tensor([i + 1]).to(self.device))
+ x -= self.schedule.beta_t[i] / torch.sqrt(1 - self.schedule.alpha_bar_t[i]) * out
+ if i > 0:
+ x += noise * torch.sqrt(self.schedule.beta_t[i])
+ else:
+ for i in reversed(range(self.schedule.num_steps)):
+ if i % resample_period == 0:
+ noise = torch.randn_like(x)
+ x = self.net(x, torch.tensor([i + 1]).to(self.device))
+ if i > 0:
+ x = x * self.schedule.img_weight[i - 1] + noise * self.schedule.noise_weight[i - 1]
+
+ x = x * self.std + self.mu
+ x = x * (self.max - self.min) + self.min
+ return x
+
+ def forward(self, n_sample):
+ return self.sample(n_sample)
diff --git a/models/cgil_utils/generative_replay.py b/models/cgil_utils/generative_replay.py
new file mode 100644
index 00000000..c4ca8420
--- /dev/null
+++ b/models/cgil_utils/generative_replay.py
@@ -0,0 +1,560 @@
+from math import pi
+
+import numpy as np
+import torch
+
+
+class FeaturesDataset(torch.utils.data.Dataset):
+
+ def __init__(self, X, y):
+ self.X, self.y = X, y
+
+ def __getitem__(self, idx):
+ return self.X[idx], self.y[idx]
+
+ def __len__(self):
+ return len(self.X)
+
+
+class Gaussian(torch.nn.Module):
+
+ def __init__(self, embed_dim):
+ super(Gaussian, self).__init__()
+ self.embed_dim = embed_dim
+ self.register_buffer("mean", torch.zeros(embed_dim))
+ self.register_buffer("std", torch.ones(embed_dim))
+
+ def fit(self, x):
+ self.std, self.mean = torch.std_mean(x, dim=0)
+
+ def sample(self, n_sample, scale_mean):
+ return torch.distributions.normal.Normal(scale_mean * self.mean, self.std).sample((n_sample,))
+
+ def forward(self, n_sample, scale_mean: float = 1.0):
+ return self.sample(n_sample, scale_mean)
+
+
+class MixtureOfGaussiansModel(torch.nn.Module):
+
+ def __init__(self, embed_dim, n_components: int = 3, n_iters: int = 100):
+ super().__init__()
+ self.n_iters = n_iters
+ self.gm = GaussianMixture(n_components, embed_dim, covariance_type='diag')
+
+ def fit(self, x):
+ self.gm.fit(x, n_iter=self.n_iters)
+
+ def sample(self, n_sample):
+ return self.gm.sample(n_sample)[0]
+
+ def forward(self, n_sample):
+ return self.sample(n_sample)
+
+
+def calculate_matmul_n_times(n_components, mat_a, mat_b):
+ """
+ Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].
+ Bypasses torch.matmul to reduce memory footprint.
+ args:
+ mat_a: torch.Tensor (n, k, 1, d)
+ mat_b: torch.Tensor (1, k, d, d)
+ """
+ res = torch.zeros(mat_a.shape).to(mat_a.device)
+
+ for i in range(n_components):
+ mat_a_i = mat_a[:, i, :, :].squeeze(-2)
+ mat_b_i = mat_b[0, i, :, :].squeeze()
+ res[:, i, :, :] = mat_a_i.mm(mat_b_i).unsqueeze(1)
+
+ return res
+
+
+def calculate_matmul(mat_a, mat_b):
+ """
+ Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].
+ Bypasses torch.matmul to reduce memory footprint.
+ args:
+ mat_a: torch.Tensor (n, k, 1, d)
+ mat_b: torch.Tensor (n, k, d, 1)
+ """
+ assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1
+ return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=2, keepdim=True)
+
+
+class GaussianMixture(torch.nn.Module):
+ """
+ Fits a mixture of k=1,..,K Gaussians to the input data (K is supplied via n_components).
+ Input tensors are expected to be flat with dimensions (n: number of samples, d: number of features).
+ The model then extends them to (n, 1, d).
+ The model parametrization (mu, sigma) is stored as (1, k, d),
+ probabilities are shaped (n, k, 1) if they relate to an individual sample,
+ or (1, k, 1) if they assign membership probabilities to one of the mixture components.
+ """
+
+ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None,
+ var_init=None):
+ """
+ Initializes the model and brings all tensors into their required shape.
+ The class expects data to be fed as a flat tensor in (n, d).
+ The class owns:
+ x: torch.Tensor (n, 1, d)
+ mu: torch.Tensor (1, k, d)
+ var: torch.Tensor (1, k, d) or (1, k, d, d)
+ pi: torch.Tensor (1, k, 1)
+ covariance_type: str
+ eps: float
+ init_params: str
+ log_likelihood: float
+ n_components: int
+ n_features: int
+ args:
+ n_components: int
+ n_features: int
+ options:
+ mu_init: torch.Tensor (1, k, d)
+ var_init: torch.Tensor (1, k, d) or (1, k, d, d)
+ covariance_type: str
+ eps: float
+ init_params: str
+ """
+ super(GaussianMixture, self).__init__()
+
+ self.n_components = n_components
+ self.n_features = n_features
+
+ self.mu_init = mu_init
+ self.var_init = var_init
+ self.eps = eps
+
+ self.log_likelihood = -np.inf
+
+ self.covariance_type = covariance_type
+ self.init_params = init_params
+
+ assert self.covariance_type in ("full", "diag")
+ assert self.init_params in ("kmeans", "random")
+
+ self._init_params()
+
+ def _init_params(self):
+ if self.mu_init is not None:
+ assert self.mu_init.size() == (1, self.n_components,
+ self.n_features), "Input mu_init does not have required tensor dimensions (1, %i, %i)" % (
+ self.n_components, self.n_features)
+ # (1, k, d)
+ self.mu = torch.nn.Parameter(self.mu_init, requires_grad=False)
+ else:
+ self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)
+
+ if self.covariance_type == "diag":
+ if self.var_init is not None:
+ # (1, k, d)
+ assert self.var_init.size() == (1, self.n_components,
+ self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (
+ self.n_components, self.n_features)
+ self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+ else:
+ self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)
+ elif self.covariance_type == "full":
+ if self.var_init is not None:
+ # (1, k, d, d)
+ assert self.var_init.size() == (1, self.n_components, self.n_features,
+ self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i, %i)" % (
+ self.n_components, self.n_features, self.n_features)
+ self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+ else:
+ self.var = torch.nn.Parameter(
+ torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1,
+ self.n_components,
+ 1, 1),
+ requires_grad=False
+ )
+
+ # (1, k, 1)
+ self.pi = torch.nn.Parameter(torch.Tensor(1, self.n_components, 1), requires_grad=False).fill_(
+ 1. / self.n_components)
+ self.params_fitted = False
+
+ def check_size(self, x):
+ if len(x.size()) == 2:
+ # (n, d) --> (n, 1, d)
+ x = x.unsqueeze(1)
+
+ return x
+
+ def bic(self, x):
+ """
+ Bayesian information criterion for a batch of samples.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ bic: float
+ """
+ x = self.check_size(x)
+ n = x.shape[0]
+
+ # Free parameters for covariance, means and mixture components
+ free_params = self.n_features * self.n_components + self.n_features + self.n_components - 1
+
+ bic = -2. * self.__score(x, as_average=False).mean() * n + free_params * np.log(n)
+
+ return bic
+
+ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):
+ """
+ Fits model to the data.
+ args:
+ x: torch.Tensor (n, d) or (n, k, d)
+ options:
+ delta: float
+ n_iter: int
+ warm_start: bool
+ """
+ if not warm_start and self.params_fitted:
+ self._init_params()
+
+ x = self.check_size(x)
+
+ if self.init_params == "kmeans" and self.mu_init is None:
+ mu = self.get_kmeans_mu(x, n_centers=self.n_components)
+ self.mu.data = mu
+
+ i = 0
+ j = np.inf
+
+ while (i <= n_iter) and (j >= delta):
+
+ log_likelihood_old = self.log_likelihood
+ mu_old = self.mu
+ var_old = self.var
+
+ self.__em(x)
+ self.log_likelihood = self.__score(x)
+
+ if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood):
+ device = self.mu.device
+ # When the log-likelihood assumes unbound values, reinitialize model
+ self.__init__(self.n_components,
+ self.n_features,
+ covariance_type=self.covariance_type,
+ mu_init=self.mu_init,
+ var_init=self.var_init,
+ eps=self.eps)
+ for p in self.parameters():
+ p.data = p.data.to(device)
+ if self.init_params == "kmeans":
+ self.mu.data, = self.get_kmeans_mu(x, n_centers=self.n_components)
+
+ i += 1
+ j = self.log_likelihood - log_likelihood_old
+
+ if j <= delta:
+ # When score decreases, revert to old parameters
+ self.__update_mu(mu_old)
+ self.__update_var(var_old)
+
+ self.params_fitted = True
+
+ def predict(self, x, probs=False):
+ """
+ Assigns input data to one of the mixture components by evaluating the likelihood under each.
+ If probs=True returns normalized probabilities of class membership.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ probs: bool
+ returns:
+ p_k: torch.Tensor (n, k)
+ (or)
+ y: torch.LongTensor (n)
+ """
+ x = self.check_size(x)
+
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+
+ if probs:
+ p_k = torch.exp(weighted_log_prob)
+ return torch.squeeze(p_k / (p_k.sum(1, keepdim=True)))
+ else:
+ return torch.squeeze(torch.max(weighted_log_prob, 1)[1].type(torch.LongTensor))
+
+ def predict_proba(self, x):
+ """
+ Returns normalized probabilities of class membership.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ y: torch.LongTensor (n)
+ """
+ return self.predict(x, probs=True)
+
+ def sample(self, n):
+ """
+ Samples from the model.
+ args:
+ n: int
+ returns:
+ x: torch.Tensor (n, d)
+ y: torch.Tensor (n)
+ """
+ counts = torch.distributions.multinomial.Multinomial(total_count=n, probs=self.pi.squeeze()).sample()
+ x = torch.empty(0, device=counts.device)
+ y = torch.cat([torch.full([int(sample)], j, device=counts.device) for j, sample in enumerate(counts)])
+
+ # Only iterate over components with non-zero counts
+ for k in torch.arange(self.n_components, device=counts.device)[counts > 0]:
+ if self.covariance_type == "diag":
+ x_k = self.mu[0, k] + torch.randn(int(counts[k]), self.n_features, device=x.device) * torch.sqrt(
+ self.var[0, k])
+ elif self.covariance_type == "full":
+ d_k = torch.distributions.multivariate_normal.MultivariateNormal(self.mu[0, k], self.var[0, k])
+ x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))])
+
+ x = torch.cat((x, x_k), dim=0)
+
+ return x, y
+
+ def score_samples(self, x):
+ """
+ Computes log-likelihood of samples under the current model.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ score: torch.LongTensor (n)
+ """
+ x = self.check_size(x)
+
+ score = self.__score(x, as_average=False)
+ return score
+
+ def _estimate_log_prob(self, x):
+ """
+ Returns a tensor with dimensions (n, k, 1), which indicates the log-likelihood that samples belong to the k-th Gaussian.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ log_prob: torch.Tensor (n, k, 1)
+ """
+ x = self.check_size(x)
+
+ if self.covariance_type == "full":
+ mu = self.mu
+ var = self.var
+
+ precision = torch.inverse(var)
+ d = x.shape[-1]
+
+ log_2pi = d * np.log(2. * pi)
+
+ log_det = self._calculate_log_det(precision)
+
+ x_mu_T = (x - mu).unsqueeze(-2)
+ x_mu = (x - mu).unsqueeze(-1)
+
+ x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision)
+ x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu)
+
+ return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu)
+
+ elif self.covariance_type == "diag":
+ mu = self.mu
+ prec = torch.rsqrt(self.var)
+
+ log_p = torch.sum((mu * mu + x * x - 2 * x * mu) * prec, dim=2, keepdim=True)
+ log_det = torch.sum(torch.log(prec), dim=2, keepdim=True)
+
+ return -.5 * (self.n_features * np.log(2. * pi) + log_p - log_det)
+
+ def _calculate_log_det(self, var):
+ """
+ Calculate log determinant in log space, to prevent overflow errors.
+ args:
+ var: torch.Tensor (1, k, d, d)
+ """
+ log_det = torch.empty(size=(self.n_components,)).to(var.device)
+
+ for k in range(self.n_components):
+ log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0, k]))).sum()
+
+ return log_det.unsqueeze(-1)
+
+ def _e_step(self, x):
+ """
+ Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components.
+ Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn).
+ This is the so-called expectation step of the EM-algorithm.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ log_prob_norm: torch.Tensor (1)
+ log_resp: torch.Tensor (n, k, 1)
+ """
+ x = self.check_size(x)
+
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+
+ log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
+ log_resp = weighted_log_prob - log_prob_norm
+
+ return torch.mean(log_prob_norm), log_resp
+
+ def _m_step(self, x, log_resp):
+ """
+ From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ log_resp: torch.Tensor (n, k, 1)
+ returns:
+ pi: torch.Tensor (1, k, 1)
+ mu: torch.Tensor (1, k, d)
+ var: torch.Tensor (1, k, d)
+ """
+ x = self.check_size(x)
+
+ resp = torch.exp(log_resp)
+
+ pi = torch.sum(resp, dim=0, keepdim=True) + self.eps
+ mu = torch.sum(resp * x, dim=0, keepdim=True) / pi
+
+ if self.covariance_type == "full":
+ eps = (torch.eye(self.n_features) * self.eps).to(x.device)
+ var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
+ keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
+
+ elif self.covariance_type == "diag":
+ x2 = (resp * x * x).sum(0, keepdim=True) / pi
+ mu2 = mu * mu
+ xmu = (resp * mu * x).sum(0, keepdim=True) / pi
+ var = x2 - 2 * xmu + mu2 + self.eps
+
+ pi = pi / x.shape[0]
+
+ return pi, mu, var
+
+ def __em(self, x):
+ """
+ Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
+ args:
+ x: torch.Tensor (n, 1, d)
+ """
+ _, log_resp = self._e_step(x)
+ pi, mu, var = self._m_step(x, log_resp)
+
+ self.__update_pi(pi)
+ self.__update_mu(mu)
+ self.__update_var(var)
+
+ def __score(self, x, as_average=True):
+ """
+ Computes the log-likelihood of the data under the model.
+ args:
+ x: torch.Tensor (n, 1, d)
+ sum_data: bool
+ returns:
+ score: torch.Tensor (1)
+ (or)
+ per_sample_score: torch.Tensor (n)
+
+ """
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+ per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
+
+ if as_average:
+ return per_sample_score.mean()
+ else:
+ return torch.squeeze(per_sample_score)
+
+ def __update_mu(self, mu):
+ """
+ Updates mean to the provided value.
+ args:
+ mu: torch.FloatTensor
+ """
+ assert mu.size() in ((self.n_components, self.n_features), (1, self.n_components,
+ self.n_features)), "Input mu does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (
+ self.n_components, self.n_features, self.n_components, self.n_features)
+
+ if mu.size() == (self.n_components, self.n_features):
+ self.mu = mu.unsqueeze(0)
+ elif mu.size() == (1, self.n_components, self.n_features):
+ self.mu.data = mu
+
+ def __update_var(self, var):
+ """
+ Updates variance to the provided value.
+ args:
+ var: torch.FloatTensor
+ """
+ if self.covariance_type == "full":
+ assert var.size() in ((self.n_components, self.n_features, self.n_features), (
+ 1, self.n_components, self.n_features,
+ self.n_features)), "Input var does not have required tensor dimensions (%i, %i, %i) or (1, %i, %i, %i)" % (
+ self.n_components, self.n_features, self.n_features, self.n_components, self.n_features, self.n_features)
+
+ if var.size() == (self.n_components, self.n_features, self.n_features):
+ self.var = var.unsqueeze(0)
+ elif var.size() == (1, self.n_components, self.n_features, self.n_features):
+ self.var.data = var
+
+ elif self.covariance_type == "diag":
+ assert var.size() in ((self.n_components, self.n_features), (1, self.n_components,
+ self.n_features)), "Input var does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (
+ self.n_components, self.n_features, self.n_components, self.n_features)
+
+ if var.size() == (self.n_components, self.n_features):
+ self.var = var.unsqueeze(0)
+ elif var.size() == (1, self.n_components, self.n_features):
+ self.var.data = var
+
+ def __update_pi(self, pi):
+ """
+ Updates pi to the provided value.
+ args:
+ pi: torch.FloatTensor
+ """
+ assert pi.size() == (1, self.n_components, 1), "Input pi does not have required tensor dimensions (%i, %i, %i)" % (1, self.n_components, 1)
+
+ self.pi.data = pi
+
+ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
+ """
+ Find an initial value for the mean. Requires a threshold min_delta for the k-means algorithm to stop iterating.
+ The algorithm is repeated init_times often, after which the best centerpoint is returned.
+ args:
+ x: torch.FloatTensor (n, d) or (n, 1, d)
+ init_times: init
+ min_delta: int
+ """
+ if len(x.size()) == 3:
+ x = x.squeeze(1)
+ x_min, x_max = x.min(), x.max()
+ x = (x - x_min) / (x_max - x_min)
+
+ min_cost = np.inf
+
+ for i in range(init_times):
+ center_idxs = torch.from_numpy(np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False)).to(x.device)
+ tmp_center = x[center_idxs, ...]
+ l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - tmp_center), p=2, dim=2)
+ l2_cls = torch.argmin(l2_dis, dim=1)
+
+ cost = 0
+ for c in range(n_centers):
+ cost += torch.norm(x[l2_cls == c] - tmp_center[c], p=2, dim=1).mean()
+
+ if cost < min_cost:
+ min_cost = cost
+ center = tmp_center
+
+ delta = np.inf
+
+ while delta > min_delta:
+ l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - center), p=2, dim=2)
+ l2_cls = torch.argmin(l2_dis, dim=1)
+ center_old = center.clone()
+
+ for c in range(n_centers):
+ center[c] = x[l2_cls == c].mean(dim=0)
+
+ delta = torch.norm((center_old - center), dim=1).max()
+
+ return (center.unsqueeze(0) * (x_max - x_min) + x_min)
diff --git a/models/cgil_utils/vae.py b/models/cgil_utils/vae.py
new file mode 100644
index 00000000..a9662820
--- /dev/null
+++ b/models/cgil_utils/vae.py
@@ -0,0 +1,178 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import tqdm
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+
+class Encoder(nn.Module):
+ def __init__(self, input_dim, hidden_dim, latent_dim):
+ '''
+ Args:
+ input_dim: A integer indicating the size of input dimension.
+ hidden_dim: A integer indicating the size of hidden dimension.
+ latent_dim: A integer indicating the latent dimension.
+ '''
+ super(Encoder, self).__init__()
+ self.latent_dim = latent_dim
+ self.encoder = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.BatchNorm1d(hidden_dim),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.BatchNorm1d(hidden_dim),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Linear(hidden_dim, 2 * latent_dim),
+ )
+
+ def forward(self, x: torch.Tensor):
+ hidden = self.encoder(x)
+ z_mu, z_logvar = hidden[:, :self.latent_dim], hidden[:, self.latent_dim:]
+ return z_mu, z_logvar
+
+
+class Decoder(nn.Module):
+ def __init__(self, output_dim: int, hidden_dim: int, latent_dim: int):
+ '''
+ Args:
+ latent_dim: A integer indicating the latent size.
+ hidden_dim: A integer indicating the size of hidden dimension.
+ output_dim: A integer indicating the output dimension.
+ '''
+ super(Decoder, self).__init__()
+ self.decoder = nn.Sequential(
+ nn.Linear(latent_dim, hidden_dim),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Linear(hidden_dim, output_dim),
+ )
+
+ def forward(self, x: torch.Tensor):
+ return self.decoder(x)
+
+
+class VariationalAutoEncoderModel(torch.nn.Module):
+
+ def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int,
+ lr: float, class_idx: int, n_iters: int = 100) -> None:
+ super().__init__()
+ self.n_iters = n_iters
+ self.lr = lr
+ self.class_idx = class_idx
+ self.vae = VariationalAutoEncoder(input_dim, hidden_dim, latent_dim, self.n_iters)
+
+ def fit(self, x: torch.Tensor) -> None:
+ self.vae.fit(x, n_iters=self.n_iters, lr=self.lr, class_idx=self.class_idx)
+
+ def sample(self, n_sample: int) -> torch.Tensor:
+ return self.vae.sample(n_sample)[0]
+
+ def forward(self, n_sample: int) -> torch.Tensor:
+ return self.sample(n_sample)
+
+ def reconstruct(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ x = self.vae.normalize(x)
+ x_rec, z_mu, z_logvar = self.vae(x)
+ x_rec = self.vae.denormalize(x_rec)
+ return x_rec, z_mu, z_logvar
+
+
+class VariationalAutoEncoder(nn.Module):
+ def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, n_iters: int) -> None:
+ super().__init__()
+ self.n_iters = n_iters
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.latent_dim = latent_dim
+ self.elbo = ELBO()
+ self.register_buffer("mu", torch.zeros(input_dim))
+ self.register_buffer("std", torch.ones(input_dim))
+ self.register_buffer("min", torch.zeros(input_dim))
+ self.register_buffer("max", torch.ones(input_dim))
+ self.enc = Encoder(input_dim, hidden_dim, latent_dim)
+ self.dec = Decoder(input_dim, hidden_dim, latent_dim)
+
+ def reparameterization_trick(self, z_mu: torch.Tensor, z_logvar: torch.Tensor) -> torch.Tensor:
+ return torch.randn_like(z_mu) * torch.exp(z_logvar) + z_mu
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ z_mu, z_logvar = self.enc(x)
+ z_post = self.reparameterization_trick(z_mu, z_logvar)
+
+ return self.dec(z_post), z_mu, z_logvar
+
+ def sample(self, num_samples: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = next(self.parameters()).device
+ z = torch.randn(num_samples, self.enc.latent_dim, device=device)
+
+ x = self.dec(z)
+ x = self.denormalize(x)
+ return x, z
+
+ @torch.enable_grad()
+ def fit(self, x: torch.Tensor, n_iters: int, lr: float, class_idx: int) -> None:
+ self.min = torch.min(x, dim=0).values
+ self.max = torch.max(x, dim=0).values
+ x = (x - self.min) / (self.max - self.min)
+
+ self.mu = torch.mean(x, dim=0)
+ self.std = torch.std(x, dim=0)
+ x = (x - self.mu) / self.std
+ optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+ sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=n_iters // 10, T_mult=2)
+ self.train()
+ loader = torch.utils.data.DataLoader(x, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
+ with tqdm.trange(n_iters, desc=f"Training VAE [{class_idx}]") as t:
+ for _ in t:
+ for batch in loader:
+ if len(batch) == 1:
+ continue
+ optimizer.zero_grad()
+ predicted, z_mu, z_logvar = self.forward(batch)
+ loss = self.elbo(batch, predicted, z_mu, z_logvar) / len(batch)
+ loss.backward()
+ optimizer.step()
+ t.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
+ sched.step()
+ self.eval()
+
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
+ x = (x - self.min) / (self.max - self.min)
+ return (x - self.mu) / self.std
+
+ def denormalize(self, x: torch.Tensor) -> torch.Tensor:
+ x = x * self.std + self.mu
+ return x * (self.max - self.min) + self.min
+
+
+def gaussian_nll(mu: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gaussian_nll_loss(x, mu, torch.ones_like(mu), full=True, reduction='sum')
+
+
+class ELBO(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def compute_rec_error(self, x: torch.Tensor, x_rec: torch.Tensor):
+ return gaussian_nll(x_rec, x)
+
+ def compute_kl(self, z_mu: torch.Tensor, z_logvar: torch.Tensor):
+ return 0.5 * torch.sum(torch.exp(z_logvar) + z_mu**2 - 1.0 - z_logvar)
+
+ def forward(self, x: torch.Tensor, x_rec: torch.Tensor,
+ z_mu: torch.Tensor, z_logvar: torch.Tensor) -> torch.Tensor:
+
+ recon_loss = self.compute_rec_error(x, x_rec)
+
+ kl_loss = self.compute_kl(z_mu, z_logvar)
+
+ if wandb.run:
+ wandb.log({"reconstruction loss": recon_loss, "kl loss": kl_loss})
+
+ return recon_loss + kl_loss
diff --git a/models/clip.py b/models/clip.py
new file mode 100644
index 00000000..b82423f4
--- /dev/null
+++ b/models/clip.py
@@ -0,0 +1,127 @@
+"""
+Adaptation of OpenAI's CLIP.
+Requires:
+- pip install git+https://github.com/openai/CLIP.git
+
+.. note::
+ Checkpoints are loaded from the OpenAI repository.
+ * RN50: "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"
+ * RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
+ * RN50x4: "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"
+ * RN50x16: "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"
+ * RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"
+ * ViT-B/32: "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
+ * ViT-B/16: "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
+ * ViT-L/14: "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
+ * ViT-L/14@336px: "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
+"""
+
+import torch
+import torch.nn as nn
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git")
+
+from datasets.utils.continual_dataset import ContinualDataset
+from models.utils.continual_model import ContinualModel
+from utils.args import ArgumentParser
+from utils.conf import get_device
+
+
+class FinalModel(nn.Module):
+ @torch.no_grad()
+ def __init__(self, clip_model, dataset: ContinualDataset, args) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.clip_model = clip_model
+ self.args = args
+
+ self.classes = self.dataset.get_class_names()
+ if args.use_templates:
+ templates = self.dataset.get_prompt_templates()
+ text_inputs = []
+ for t in templates:
+ t_inputs = torch.cat([clip.tokenize(t.format(c)) for c in self.classes]).to(get_device())
+ t_inputs = self.clip_model.encode_text(t_inputs)
+ t_inputs /= t_inputs.norm(dim=-1, keepdim=True) # double normalization if use templates is expected (see https://github.dev/KaiyangZhou/CoOp)
+ text_inputs.append(t_inputs)
+ self.text_features = torch.stack(text_inputs).mean(0)
+ else:
+ text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in self.classes]).to(get_device())
+ self.text_features = self.clip_model.encode_text(text_inputs)
+
+ self.text_features /= self.text_features.norm(dim=-1, keepdim=True) # double normalization if use templates is expected
+ self.task_id = 0
+
+ @torch.no_grad()
+ def forward(self, x):
+ image_features = self.clip_model.encode_image(x)
+ text_features = self.text_features
+
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+ similarity = (100.0 * (image_features @ text_features.T)).softmax(dim=-1)
+
+ return similarity
+
+
+class CLIP(ContinualModel):
+ NAME = 'clip'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser(description='STATIC Continual Learning with CLIP')
+ parser.add_argument('--clip_backbone', type=str, default='ViT-L/14',
+ choices=list(clip.available_models()),
+ help='Backbone architecture for CLIP')
+ parser.add_argument('--save_predictions', type=int, choices=[0, 1], default=0,
+ help='Whether to save predictions of the TRAINING set after each task')
+ parser.add_argument('--use_templates', type=int, choices=[0, 1], default=0,
+ help='Whether to use prompt templates for CLIP. NOTE: Datasets NEED to have a `get_prompt_templates` method implemented.')
+ return parser
+
+ def __init__(self, backbone, loss, args, transform):
+ backbone, clip_transform = clip.load(args.clip_backbone, device=get_device())
+ n_epochs = 1 if args.save_predictions else 0
+ if args.n_epochs != n_epochs:
+ print(f"CLIP is a STATIC model, setting n_epochs to {n_epochs}")
+ args.n_epochs = n_epochs
+ super().__init__(backbone, loss, args, transform)
+
+ self.net = FinalModel(self.net, self.dataset, args)
+ self.clip_transform = clip_transform
+
+ self.predictions = []
+ self.original_labels = []
+
+ def begin_task(self, dataset):
+ dataset.test_loaders[-1].dataset.transform = self.clip_transform
+ if self.args.save_predictions:
+ dataset.train_loader.dataset.transform = self.clip_transform
+
+ if self.current_task != 0:
+ self.net.task_id += 1
+
+ self.eval()
+
+ def end_task(self, dataset: ContinualDataset) -> None:
+ if self.args.save_predictions:
+ self.predictions = torch.cat(self.predictions, dim=0).cpu()
+ self.original_labels = torch.cat(self.original_labels, dim=0).cpu()
+ torch.save((self.predictions, self.original_labels), f'predictions_{self.args.dataset}_{self.current_task}.pt')
+ print(f"Predictions saved for task {self.current_task} in 'predictions_{self.args.dataset}_{self.current_task}.pt'")
+ self.predictions = []
+ self.original_labels = []
+ return super().end_task(dataset)
+
+ def observe(self, inputs, labels, not_aug_inputs, epoch=None):
+ if self.args.save_predictions:
+ with torch.no_grad():
+ self.predictions.append(self.net(inputs))
+ self.original_labels.append(labels)
+ return 0
+
+ @torch.no_grad()
+ def forward(self, x):
+ return self.net(x)[:, :self.n_seen_classes]
diff --git a/models/coda_prompt.py b/models/coda_prompt.py
index e43ac038..2a0a35c5 100644
--- a/models/coda_prompt.py
+++ b/models/coda_prompt.py
@@ -6,7 +6,7 @@
The backbone is a ViT-B/16 pretrained on Imagenet 21k and finetuned on ImageNet 1k.
"""
-import timm
+import logging
from utils.args import *
from models.utils.continual_model import ContinualModel
import torch
@@ -32,7 +32,7 @@ def get_parser() -> ArgumentParser:
def __init__(self, backbone, loss, args, transform):
del backbone
print("-" * 20)
- print(f"WARNING: CODA-Prompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
+ logging.warning(f"CODA-Prompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
print("Pretrained on Imagenet 21k and finetuned on ImageNet 1k.")
print("-" * 20)
diff --git a/models/coda_prompt_utils/__init__.py b/models/coda_prompt_utils/__init__.py
index 702830f6..d820c886 100644
--- a/models/coda_prompt_utils/__init__.py
+++ b/models/coda_prompt_utils/__init__.py
@@ -1,3 +1,72 @@
"""
This package contains utility functions for the CoDA Prompt model. Implements a custom version of ViT to add prompt parameters.
"""
+
+import copy
+
+import torch
+
+
+def gram_schmidt(vv, start_c, end_c, return_in_parameter=True):
+ """
+ Code for this function is modified from:
+ https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
+
+ Perform Gram-Schmidt orthogonalization on the input matrix vv.
+ """
+
+ def projection(u, v):
+ denominator = (u * u).sum()
+
+ if denominator < 1e-8:
+ return None
+ else:
+ return (v * u).sum() / denominator * u
+
+ # check if the tensor is 3D and flatten the last two dimensions if necessary
+ is_3d = len(vv.shape) == 3
+ if is_3d:
+ shape_2d = copy.deepcopy(vv.shape)
+ vv = vv.view(vv.shape[0], -1)
+
+ # swap rows and columns
+ vv = vv.T
+
+ # process matrix size
+ uu = torch.zeros_like(vv, device=vv.device)
+
+ if start_c > 0:
+ uu[:, 0:start_c] = vv[:, 0:start_c].clone()
+
+ for k in range(start_c, end_c):
+ redo = True
+ while redo:
+ redo = False
+ vk = torch.randn_like(vv[:, k]).to(vv.device)
+ uk = 0
+ for j in range(0, k):
+ if not redo:
+ uj = uu[:, j].clone()
+ proj = projection(uj, vk)
+ if proj is None:
+ redo = True
+ print('restarting!!!')
+ else:
+ uk = uk + proj
+ if not redo:
+ uu[:, k] = vk - uk
+ for k in range(start_c, end_c):
+ uk = uu[:, k].clone()
+ uu[:, k] = uk / (uk.norm())
+
+ # undo swapping of rows and columns
+ uu = uu.T
+
+ # return from 2D
+ if is_3d:
+ uu = uu.view(shape_2d)
+
+ if return_in_parameter:
+ return torch.nn.Parameter(uu)
+
+ return uu
diff --git a/models/coda_prompt_utils/model.py b/models/coda_prompt_utils/model.py
index 676c1fa7..4c54a512 100644
--- a/models/coda_prompt_utils/model.py
+++ b/models/coda_prompt_utils/model.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from backbone.vit import create_vision_transformer
+from models.coda_prompt_utils import gram_schmidt
from models.coda_prompt_utils.vit import VisionTransformer
-import copy
class CodaPrompt(nn.Module):
@@ -14,6 +14,8 @@ def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
self.n_tasks = n_tasks
self._init_smart(emb_d, prompt_param)
+ pt = int(self.e_pool_size / (self.n_tasks))
+
# e prompt init
for e in self.e_layers:
# for model saving/loading simplicity, we init the full paramaters here
@@ -27,9 +29,9 @@ def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
p = tensor_prompt(self.e_pool_size, e_l, emb_d)
k = tensor_prompt(self.e_pool_size, self.key_d)
a = tensor_prompt(self.e_pool_size, self.key_d)
- p = self.gram_schmidt(p)
- k = self.gram_schmidt(k)
- a = self.gram_schmidt(a)
+ p = gram_schmidt(p, start_c=0, end_c=pt)
+ k = gram_schmidt(k, start_c=0, end_c=pt)
+ a = gram_schmidt(a, start_c=0, end_c=pt)
setattr(self, f'e_p_{e}', p)
setattr(self, f'e_k_{e}', k)
setattr(self, f'e_a_{e}', a)
@@ -52,81 +54,20 @@ def process_task_count(self):
#
# in the original paper, we used ortho init at the start - this modification is more
# fair in the spirit of continual learning and has little affect on performance
- #
- # code for this function is modified from:
- # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
+ pt = int(self.e_pool_size / (self.n_tasks))
+ s = int(self.task_count * pt)
+ f = int((self.task_count + 1) * pt)
for e in self.e_layers:
K = getattr(self, f'e_k_{e}')
A = getattr(self, f'e_a_{e}')
P = getattr(self, f'e_p_{e}')
- k = self.gram_schmidt(K)
- a = self.gram_schmidt(A)
- p = self.gram_schmidt(P)
+ k = gram_schmidt(K, s, f)
+ a = gram_schmidt(A, s, f)
+ p = gram_schmidt(P, s, f)
setattr(self, f'e_p_{e}', p)
setattr(self, f'e_k_{e}', k)
setattr(self, f'e_a_{e}', a)
- # code for this function is modified from:
- # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
- def gram_schmidt(self, vv):
-
- def projection(u, v):
- denominator = (u * u).sum()
-
- if denominator < 1e-8:
- return None
- else:
- return (v * u).sum() / denominator * u
-
- # check if the tensor is 3D and flatten the last two dimensions if necessary
- is_3d = len(vv.shape) == 3
- if is_3d:
- shape_2d = copy.deepcopy(vv.shape)
- vv = vv.view(vv.shape[0], -1)
-
- # swap rows and columns
- vv = vv.T
-
- # process matrix size
- nk = vv.size(1)
- uu = torch.zeros_like(vv, device=vv.device)
-
- # get starting point
- pt = int(self.e_pool_size / (self.n_tasks))
- s = int(self.task_count * pt)
- f = int((self.task_count + 1) * pt)
- if s > 0:
- uu[:, 0:s] = vv[:, 0:s].clone()
- for k in range(s, f):
- redo = True
- while redo:
- redo = False
- vk = torch.randn_like(vv[:, k]).to(vv.device)
- uk = 0
- for j in range(0, k):
- if not redo:
- uj = uu[:, j].clone()
- proj = projection(uj, vk)
- if proj is None:
- redo = True
- print('restarting!!!')
- else:
- uk = uk + proj
- if not redo:
- uu[:, k] = vk - uk
- for k in range(s, f):
- uk = uu[:, k].clone()
- uu[:, k] = uk / (uk.norm())
-
- # undo swapping of rows and columns
- uu = uu.T
-
- # return from 2D
- if is_3d:
- uu = uu.view(shape_2d)
-
- return torch.nn.Parameter(uu)
-
def forward(self, x_querry, l, x_block, train=False, task_id=None):
# e prompts
diff --git a/models/coda_prompt_utils/vit.py b/models/coda_prompt_utils/vit.py
index 412c93cd..f038ed4f 100644
--- a/models/coda_prompt_utils/vit.py
+++ b/models/coda_prompt_utils/vit.py
@@ -6,7 +6,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from functools import partial
from timm.models.layers import trunc_normal_, DropPath
@@ -51,11 +50,13 @@ def forward(self, x, prompt=None):
k = torch.cat((pk, k), dim=2)
v = torch.cat((pv, v), dim=2)
- x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
- # attn = (q @ k.transpose(-2, -1)) * self.scale
- # attn = attn.softmax(dim=-1)
- # attn = self.attn_drop(attn)
- # x = (attn @ v)
+ if torch.__version__ >= '2.1.0':
+ x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
+ else:
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
diff --git a/models/dualprompt.py b/models/dualprompt.py
index 7e711fb2..e840ace4 100644
--- a/models/dualprompt.py
+++ b/models/dualprompt.py
@@ -6,6 +6,7 @@
The backbone is a ViT-B/16 pretrained on Imagenet 21k and finetuned on ImageNet 1k.
"""
+import logging
import torch
from models.dualprompt_utils.model import Model
@@ -70,7 +71,7 @@ def get_parser() -> ArgumentParser:
def __init__(self, backbone, loss, args, transform):
del backbone
print("-" * 20)
- print(f"WARNING: DualPrompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
+ logging.warning(f"DualPrompt USES A CUSTOM BACKBONE: `vit_base_patch16_224`.")
print("Pretrained on Imagenet 21k and finetuned on ImageNet 1k.")
print("-" * 20)
diff --git a/models/dualprompt_utils/vision_transformer.py b/models/dualprompt_utils/vision_transformer.py
index 9dd510b9..f485fd6b 100644
--- a/models/dualprompt_utils/vision_transformer.py
+++ b/models/dualprompt_utils/vision_transformer.py
@@ -9,7 +9,7 @@
from timm.models.helpers import named_apply
from timm.models.layers import trunc_normal_
-from backbone.vit import Attention, create_vision_transformer, VisionTransformer as MammothVP, get_init_weights_vit
+from backbone.vit import LoRAAttention, create_vision_transformer, VisionTransformer as MammothVP, get_init_weights_vit
from models.dualprompt_utils.prompt import EPrompt
from models.dualprompt_utils.attention import PreT_Attention
@@ -25,10 +25,10 @@ def __init__(
use_e_prompt=False, e_prompt_layer_idx=None, use_prefix_tune_for_e_prompt=False, same_key_value=False, args=None, **kwargs):
if not (use_g_prompt or use_e_prompt):
- attn_layer = Attention
+ attn_layer = LoRAAttention
elif not (use_prefix_tune_for_g_prompt or use_prefix_tune_for_e_prompt):
# Prompt tunning
- attn_layer = Attention
+ attn_layer = LoRAAttention
else:
# Prefix tunning
attn_layer = PreT_Attention
diff --git a/models/first_stage_starprompt.py b/models/first_stage_starprompt.py
new file mode 100644
index 00000000..014e01cd
--- /dev/null
+++ b/models/first_stage_starprompt.py
@@ -0,0 +1,138 @@
+import logging
+import os
+import sys
+import torch
+from argparse import ArgumentParser
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git (requires also `huggingface-hub`)")
+
+from models.utils.continual_model import ContinualModel
+from models.star_prompt_utils.first_stage_model import Model
+
+
+class FirstStageStarprompt(ContinualModel):
+ NAME = 'first_stage_starprompt'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+
+ net: Model
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser()
+
+ frozen_group = parser.add_argument_group('Frozen hyperparameters')
+ frozen_group.add_argument("--virtual_bs_n", type=int, default=1, help="Virtual batch size iterations")
+ frozen_group.add_argument('--gr_mog_n_iters', '--gr_mog_n_iters_first_stage', dest='gr_mog_n_iters_first_stage',
+ type=int, default=500, help="Number of EM iterations during fit for GR with MOG.")
+ frozen_group.add_argument('--gr_mog_n_components', type=int, default=5,
+ help="Number of components for Generative Replay with MOG.")
+ frozen_group.add_argument("--enable_gr", type=int, default=1, choices=[0, 1],
+ help="Enable Generative Replay.")
+ frozen_group.add_argument('--batch_size_gr', type=int, default=128,
+ help="Batch size for Generative Replay.")
+ frozen_group.add_argument('--num_samples_gr', type=int, default=256,
+ help="Number of samples for Generative Replay.")
+
+ # Tunable hyperparameters
+ tunable_group = parser.add_argument_group('Tunable hyperparameters')
+ tunable_group.add_argument("--num_monte_carlo_gr", "--num_monte_carlo_gr_first_stage", dest="num_monte_carlo_gr_first_stage",
+ type=int, default=2, help="How many times to sample from the dataset for Generative Replay")
+ tunable_group.add_argument("--learning_rate_gr", "--learning_rate_gr_first_stage", dest="learning_rate_gr_first_stage",
+ type=float, default=0.05, help="Learning rate for Generative Replay.")
+ tunable_group.add_argument("--lambda_ortho_first_stage", type=float, default=30,
+ help="Orthogonality loss coefficient for coop")
+ tunable_group.add_argument("--num_epochs_gr", "--num_epochs_gr_first_stage", dest="num_epochs_gr_first_stage",
+ type=int, default=10, help="Num. of epochs for Generative Replay.")
+
+ # Useful flags
+ parser.add_argument("--save_first_stage_keys", type=int, default=1,
+ choices=[0, 1], help="save text encoder outputs")
+
+ # Backbone arguments
+ parser.add_argument("--clip_backbone", type=str, default='ViT-L/14', help="CLIP backbone architecture",
+ choices=clip.available_models())
+
+ return parser
+
+ def __init__(self, backbone, loss, args, transform):
+ logging.warning("The first stage of STAR-Prompt ignores the backbone as it uses CLIP")
+ del backbone
+
+ super().__init__(None, loss, args, transform)
+ self.net = Model(args, num_classes=self.num_classes, dataset=self.dataset, device=self.device)
+ self.opt = self.get_optimizer()
+
+ # REMOVE ALL TRACK RUNNING STATS FROM CLIP
+ for m in self.net.modules():
+ if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
+ m.track_running_stats = False
+
+ self.eye = torch.eye(self.num_classes).to(self.device)
+
+ def end_task(self, dataset):
+ if hasattr(self, 'opt'):
+ self.opt.zero_grad(set_to_none=True)
+ delattr(self, 'opt')
+
+ # Generative replay
+ if self.args.enable_gr:
+ self.net.prompter.update_statistics(dataset, self.current_task)
+ self.net.prompter.align(self.current_task)
+
+ if self.current_task == (self.n_tasks - 1) and self.args.save_first_stage_keys:
+ print('Saving text encoder outputs... ', end='', file=sys.stderr)
+ te_outputs = self.net.prompter.compute_keys(0, self.num_classes)
+ os.makedirs('./coop_keys', exist_ok=True)
+ st = {
+ 'keys': te_outputs,
+ 'args': self.args,
+ }
+ fname = f'./coop_keys/coop_keys_{self.current_task}_{self.args.conf_jobnum}.pt'
+ torch.save(st, fname)
+ print('Saved text-encoder keys in:', fname, file=sys.stderr)
+
+ def get_parameters(self):
+ return [v for k, v in self.net.named_parameters() if 'prompt_parameters' in k]
+
+ def begin_task(self, dataset):
+ # Disable transforms and set normalization as CLIP's preprocessing
+ dataset.train_loader.dataset.transform = self.net.prompter.clip_preprocess
+ dataset.test_loaders[-1].dataset.transform = self.net.prompter.clip_preprocess
+
+ if hasattr(self, 'opt'):
+ self.opt.zero_grad(set_to_none=True)
+ delattr(self, 'opt')
+
+ self.opt = self.get_optimizer()
+
+ torch.cuda.empty_cache()
+
+ def forward(self, x):
+ logits = self.net(x, cur_classes=self.n_seen_classes)
+ return logits[:, :self.n_seen_classes]
+
+ def observe(self, inputs, labels, not_aug_inputs, epoch=None):
+ loss = torch.tensor(0.).to(self.device)
+
+ stream_inputs, stream_labels = inputs, labels.long()
+ clip_logits = self.net(stream_inputs, frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
+
+ # compute clip loss
+ clip_logits[:, :self.n_past_classes] = -float('inf')
+ loss_clip = self.loss(clip_logits[:, :self.n_seen_classes], stream_labels)
+
+ loss += loss_clip
+
+ loss_ortho_coop = self.net.prompter.compute_ortho_loss(frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
+ loss += self.args.lambda_ortho_first_stage * loss_ortho_coop
+
+ if self.epoch_iteration == 0:
+ self.opt.zero_grad()
+ (loss / self.args.virtual_bs_n).backward()
+ if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and self.epoch_iteration % self.args.virtual_bs_n == 0:
+ self.opt.step()
+ self.opt.zero_grad()
+
+ return loss.item()
diff --git a/models/l2p.py b/models/l2p.py
index 026976b2..3f50d208 100644
--- a/models/l2p.py
+++ b/models/l2p.py
@@ -84,10 +84,9 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None):
logits = outputs['logits']
# here is the trick to mask out classes of non-current tasks
- offset_1, offset_2 = self._compute_offsets(self.current_task)
- logits[:, :offset_1] = -float('inf')
+ logits[:, :self.n_past_classes] = -float('inf')
- loss = self.loss(logits[:, :offset_2], labels)
+ loss = self.loss(logits[:, :self.n_seen_classes], labels)
if self.args.pull_constraint and 'reduce_sim' in outputs:
loss = loss - self.args.pull_constraint_coeff * outputs['reduce_sim']
@@ -102,8 +101,4 @@ def get_parameters(self):
return [p for n, p in self.net.model.named_parameters() if 'prompt' in n or 'head' in n]
def forward(self, x):
- if self.current_task > 0:
- _, offset_2 = self._compute_offsets(self.current_task - 1)
- else:
- offset_2 = self.N_CLASSES
- return self.net(x)[:, :offset_2]
+ return self.net(x)[:, :self.n_seen_classes]
diff --git a/models/second_stage_starprompt.py b/models/second_stage_starprompt.py
new file mode 100644
index 00000000..4e927d69
--- /dev/null
+++ b/models/second_stage_starprompt.py
@@ -0,0 +1,321 @@
+import torch
+from copy import deepcopy
+from torch.utils.data import TensorDataset
+from tqdm import tqdm
+from argparse import ArgumentParser
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from utils.augmentations import RepeatedTransform
+from utils.conf import create_seeded_dataloader
+from utils.schedulers import CosineSchedule
+from models.utils.continual_model import ContinualModel
+from models.star_prompt_utils.second_stage_model import Model
+from models.star_prompt_utils.generative_replay import Gaussian, MixtureOfGaussiansModel
+
+
+class SecondStageStarprompt(ContinualModel):
+ NAME = 'second_stage_starprompt'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser(description='Second-stage of StarPrompt. Requires the keys saved from the first stage.')
+
+ frozen_group = parser.add_argument_group('Frozen hyperparameters')
+ frozen_group.add_argument("--virtual_bs_n", type=int, default=1,
+ help="virtual batch size iterations")
+ frozen_group.add_argument("--enable_data_aug_query", type=int, default=1, choices=[0, 1],
+ help="Use default transform with data aug to generate the CLIP's response?")
+ frozen_group.add_argument("--use_clip_preprocess_eval", type=int, default=0, choices=[0, 1],
+ help="Use CLIP's transform during eval instead of the default test transform?")
+ frozen_group.add_argument("--ortho_split_val", type=int, default=0)
+ frozen_group.add_argument('--gr_mog_n_iters', '--gr_mog_n_iters_second_stage', dest='gr_mog_n_iters_second_stage',
+ type=int, default=500, help="Number of EM iterations during fit for GR with MOG.")
+ frozen_group.add_argument('--gr_mog_n_components', type=int, default=5,
+ help="Number of components for GR with MOG.")
+ frozen_group.add_argument('--batch_size_gr', type=int, default=128,
+ help="Batch size for Generative Replay.")
+ frozen_group.add_argument('--num_samples_gr', type=int, default=256,
+ help="Number of samples for Generative Replay.")
+ frozen_group.add_argument('--prefix_tuning_prompt_len', type=int, default=5,
+ help="Prompt length for prefix tuning. Used only if `--prompt_mode==concat`.")
+
+ ablation_group = parser.add_argument_group('Ablations hyperparameters')
+ ablation_group.add_argument('--gr_model', type=str, default='mog', choices=['mog', 'gaussian'],
+ help="Type of distribution model for Generative Replay. "
+ "- `mog`: Mixture of Gaussian. "
+ "- `gaussian`: Single Gaussian distribution.")
+ ablation_group.add_argument("--enable_gr", type=int, default=1, choices=[0, 1],
+ help="Enable Generative Replay.")
+ ablation_group.add_argument('--statc_keys_use_templates', type=int, default=1, choices=[0, 1],
+ help="Use templates for the second stage if no keys are loaded.")
+ ablation_group.add_argument('--prompt_mode', type=str, default='residual', choices=['residual', 'concat'],
+ help="Prompt type for the second stage. "
+ "- `residual`: STAR-Prompt style prompting. "
+ "- `concat`: Prefix-Tuning style prompting.")
+ ablation_group.add_argument("--enable_confidence_modulation", type=int, default=1, choices=[0, 1],
+ help="Enable confidence modulation with CLIP similarities (Eq. 5 of the main paper)?")
+
+ # Tunable hyperparameters
+ tunable_group = parser.add_argument_group('Tunable hyperparameters')
+ tunable_group.add_argument("--lambda_ortho_second_stage", type=float, default=10,
+ help="orthogonality loss coefficient")
+ tunable_group.add_argument("--num_monte_carlo_gr", "--num_monte_carlo_gr_second_stage", dest="num_monte_carlo_gr_second_stage",
+ type=int, default=1, help="how many times to sample from the dataset for alignment")
+ tunable_group.add_argument("--num_epochs_gr", "--num_epochs_gr_second_stage", dest="num_epochs_gr_second_stage",
+ type=int, default=10, help="Num. of epochs for GR.")
+ tunable_group.add_argument("--learning_rate_gr", "--learning_rate_gr_second_stage", dest="learning_rate_gr_second_stage",
+ type=float, default=0.001, help="Learning rate for GR.")
+
+ # Very important parameter
+ parser.add_argument('--keys_ckpt_path', type=str,
+ help="Path for first-stage keys. "
+ "The keys can be saved by runninng `first_stage_starprompt` with `--save_first_stage_keys=1`."
+ "This can be:"
+ "- A path to a checkpoint file (.pt) containing ONLY THE FIRST STAGE KEYS."
+ "- A path to the checkpoint made by `first_stage_starprompt`"
+ "- The job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys."
+ "- A JSON file containing the job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys."
+ "The JSON is expected to contain an entry for each dataset and seed: `{dataset: {seed: job-id}}`.")
+
+ return parser
+
+ net: Model
+
+ def __init__(self, backbone, loss, args, transform):
+ super().__init__(backbone, loss, args, transform)
+
+ self.net = Model(args,
+ backbone=self.net,
+ dataset=self.dataset,
+ num_classes=self.num_classes,
+ device=self.device)
+
+ # REMOVE ALL TRACK RUNNING STATS FROM CLIP
+ for m in self.net.modules():
+ if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
+ m.track_running_stats = False
+
+ embed_dim = self.net.vit.embed_dim
+
+ self.distributions = torch.nn.ModuleList([self._get_dist(embed_dim)
+ for _ in range(self.num_classes)]).to(self.device)
+ self.classifier_state_dict = None
+
+ def _get_dist(self, embed_dim):
+ assert self.args.gr_model in ['mog', 'gaussian'], f"Invalid GR model: {self.args.gr_model}"
+
+ if self.args.gr_model == 'mog':
+ return MixtureOfGaussiansModel(embed_dim, n_components=self.args.gr_mog_n_components,
+ n_iters=self.args.gr_mog_n_iters_second_stage)
+ else:
+ return Gaussian(embed_dim)
+
+ def norm(self, t):
+ return torch.norm(t, p=2, dim=-1, keepdim=True) + 1e-7
+
+ @torch.no_grad()
+ def create_features_dataset(self):
+
+ labels, features = [], []
+
+ for _ti in range(self.current_task + 1):
+
+ prev_t_size, cur_t_size = self.compute_offsets(_ti)
+
+ for class_idx in range(prev_t_size, cur_t_size):
+ current_samples = self.distributions[class_idx](self.args.num_samples_gr)
+ features.append(current_samples)
+ labels.append(torch.ones(self.args.num_samples_gr) * class_idx)
+
+ features = torch.cat(features, dim=0)
+ labels = torch.cat(labels, dim=0).long()
+
+ return create_seeded_dataloader(self.args, TensorDataset(features, labels),
+ batch_size=self.args.batch_size_gr,
+ shuffle=True, num_workers=0)
+
+ def train_alignment_epoch(self, classifier: torch.nn.Module, optim: torch.optim.Optimizer, epoch: int):
+
+ dl = self.create_features_dataset()
+
+ with tqdm(enumerate(dl), total=len(dl), desc=f'GR second stage epoch {epoch + 1}/{self.args.num_epochs_gr_second_stage}', leave=False) as pbar:
+ for i, (x, labels) in pbar:
+ optim.zero_grad()
+ x, labels = x.to(self.device, dtype=torch.float32), labels.to(self.device)
+
+ logits = classifier(x)
+
+ logits = logits[:, :self.n_seen_classes]
+
+ norm = self.norm(logits)
+ logits = logits / (0.1 * norm)
+
+ loss = self.loss(logits, labels)
+ loss.backward()
+ optim.step()
+
+ if not self.args.nowand:
+ assert wandb is not None, "wandb is not installed."
+ wandb.log({'ca_loss': loss.item(), 'ca_lr': optim.param_groups[0]['lr']})
+ pbar.set_postfix({'loss': loss.item()})
+
+ def align(self):
+
+ classifier = deepcopy(self.net.vit.head)
+
+ optim = torch.optim.SGD(lr=self.args.learning_rate_gr_second_stage,
+ params=classifier.parameters(),
+ momentum=0.0,
+ weight_decay=0.0)
+
+ num_epochs = self.args.num_epochs_gr_second_stage + (5 * self.current_task)
+
+ for e in range(num_epochs):
+ self.train_alignment_epoch(classifier, optim, e)
+
+ self.net.vit.head.weight.data.copy_(classifier.weight.data)
+ self.net.vit.head.bias.data.copy_(classifier.bias.data)
+
+ @torch.no_grad()
+ def update_statistics(self, dataset):
+
+ features_dict = {i: [] for i in range(self.n_past_classes, self.n_seen_classes)}
+
+ self.net.eval()
+
+ with tqdm(total=self.args.num_monte_carlo_gr_second_stage * len(dataset.train_loader), desc='GR update statistics') as pbar:
+ for _ in range(self.args.num_monte_carlo_gr_second_stage):
+ for i, data in enumerate(dataset.train_loader):
+ if self.args.debug_mode and i > 3 and min([len(v) for k, v in features_dict.items()]) > self.args.gr_mog_n_components:
+ break
+
+ x, labels = data[0], data[1]
+ x, labels = x.to(self.device), labels.to(self.device, dtype=torch.long)
+ x, query_x = x[:, 0], x[:, 1]
+ if self.args.enable_data_aug_query:
+ query_x = None
+ features = self.net(x, query_x=query_x, return_features=True, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)
+ features = features[:, 0]
+
+ for class_idx in labels.unique():
+ features_dict[int(class_idx)].append(features[labels == class_idx])
+
+ pbar.update(1)
+
+ for class_idx in range(self.n_past_classes, self.n_seen_classes):
+ features_class_idx = torch.cat(features_dict[class_idx], dim=0)
+ self.distributions[class_idx].fit(features_class_idx.to(self.device))
+
+ def backup(self):
+ print(f"BACKUP: Task - {self.current_task} - classes from "
+ f"{self.n_past_classes} - to {self.n_seen_classes}")
+ self.classifier_state_dict = deepcopy(self.net.vit.head.state_dict())
+
+ def recall(self):
+ print(f"RECALL: Task - {self.current_task} - classes from "
+ f"{self.n_past_classes} - to {self.n_seen_classes}")
+
+ if self.current_task == 0 or self.args.enable_gr == 0:
+ return
+
+ assert self.classifier_state_dict
+
+ self.net.vit.head.weight.data.copy_(self.classifier_state_dict['weight'].data)
+ self.net.vit.head.bias.data.copy_(self.classifier_state_dict['bias'].data)
+
+ def end_task(self, dataset):
+ if hasattr(self, 'opt'):
+ del self.opt # free up some vram
+
+ if self.args.enable_gr:
+ self.update_statistics(dataset)
+ self.backup()
+
+ if self.current_task > 0:
+ self.align()
+
+ def get_parameters(self):
+ return [p for p in self.net.parameters() if p.requires_grad]
+
+ def get_scheduler(self):
+ return CosineSchedule(self.opt, K=self.args.n_epochs)
+
+ def begin_task(self, dataset):
+ if self.args.permute_classes:
+ if hasattr(self.net.prompter, 'old_args') and self.net.prompter.old_args is not None:
+ assert self.args.seed == self.net.prompter.old_args.seed
+ assert (self.args.class_order == self.net.prompter.old_args.class_order).all()
+
+ dataset.train_loader.dataset.transform = RepeatedTransform([dataset.train_loader.dataset.transform, self.net.prompter.clip_preprocess])
+ dataset.test_loaders[-1].dataset.transform = RepeatedTransform([dataset.test_loaders[-1].dataset.transform, self.net.prompter.clip_preprocess])
+
+ # NOTE: Remove these comments if you want to check if the keys are loaded correctly and results are the same as the first stage
+ # tot_data, tot_corr = 0, 0
+ # for i, ts in enumerate(dataset.test_loaders):
+ # task_tot, task_corr = 0, 0
+ # for data in ts:
+ # inputs, labels = data[0], data[1]
+ # inputs, labels = inputs[:, 1].to(self.device), labels.to(self.device) # only clip-preprocessed input
+ # queries = self.net.prompter.get_query(inputs)
+ # queries = torch.nn.functional.normalize(queries, dim=-1)
+ # logits = torch.einsum('bd,cd->bc', queries, self.net.prompter.keys.type(self.net.prompter.clip_model.dtype))
+ # task_corr += (logits.argmax(dim=-1) == labels).sum().item()
+ # task_tot += labels.shape[0]
+ # print(f"CLIP on TASK {i+1}: {task_corr / task_tot}")
+ # tot_corr += task_corr
+ # tot_data += task_tot
+ # print(f"AVG CLIP ON TASKS: {tot_corr / tot_data}") # the avg of the avg != the avg of the total
+
+ # For later GR
+ self.recall()
+
+ if hasattr(self, 'opt'):
+ del self.opt
+
+ self.opt = self.get_optimizer()
+ self.scheduler = self.get_scheduler()
+
+ def forward(self, x):
+ x, query_x = x[:, 0], x[:, 1] # from repeated transform
+ if self.args.use_clip_preprocess_eval == 0:
+ query_x = None
+ logits = self.net(x, query_x=query_x, cur_classes=self.n_seen_classes)
+ logits = logits[:, :self.n_seen_classes]
+ return logits
+
+ def observe(self, inputs, labels, not_aug_inputs, epoch=None):
+ stream_inputs, stream_labels = inputs, labels
+ stream_inputs, query_stream_inputs = stream_inputs[:, 0], stream_inputs[:, 1]
+ if self.args.enable_data_aug_query:
+ query_stream_inputs = None
+ stream_logits = self.net(stream_inputs, query_x=query_stream_inputs, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)
+
+ # Compute accuracy on current training batch for logging
+ with torch.no_grad():
+ stream_preds = stream_logits[:, :self.n_seen_classes].argmax(dim=1)
+ stream_acc = (stream_preds == stream_labels).sum().item() / stream_labels.shape[0]
+
+ # mask old classes
+ stream_logits[:, :self.n_past_classes] = -float('inf')
+ loss = self.loss(stream_logits[:, :self.n_seen_classes], stream_labels)
+
+ loss_ortho = self.net.prompter.compute_ortho_loss(frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
+ loss += self.args.lambda_ortho_second_stage * loss_ortho
+
+ if self.epoch_iteration == 0:
+ self.opt.zero_grad()
+
+ (loss / self.args.virtual_bs_n).backward()
+ # loss.backward()
+ if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and \
+ self.epoch_iteration % self.args.virtual_bs_n == 0:
+ self.opt.step()
+ self.opt.zero_grad()
+
+ return {'loss': loss.item(),
+ 'stream_accuracy': stream_acc}
diff --git a/models/slca.py b/models/slca.py
index 9647e187..5d9fec9c 100644
--- a/models/slca.py
+++ b/models/slca.py
@@ -11,7 +11,6 @@
from utils.args import *
from models.utils.continual_model import ContinualModel
-import timm
import torch
from utils.conf import get_device
from models.slca_utils.slca import SLCA_Model
diff --git a/models/star_prompt_utils/end_to_end_model.py b/models/star_prompt_utils/end_to_end_model.py
new file mode 100644
index 00000000..6707fd4e
--- /dev/null
+++ b/models/star_prompt_utils/end_to_end_model.py
@@ -0,0 +1,327 @@
+from copy import deepcopy
+from typing import Tuple
+import torch
+from torch import nn
+from torch.utils.data import TensorDataset
+
+from tqdm import tqdm
+try:
+ import wandb
+except ImportError:
+ wandb = None
+try:
+ from clip.model import convert_weights
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git (requires also `huggingface-hub`)")
+
+
+from utils.conf import create_seeded_dataloader
+from datasets.utils.continual_dataset import ContinualDataset
+from models.star_prompt_utils.first_stage_model import Model as FirstStageModel
+from models.star_prompt_utils.second_stage_model import Model as SecondStageModel
+from models.star_prompt_utils.generative_replay import Gaussian, MixtureOfGaussiansModel
+
+
+class STARPromptModel(nn.Module):
+ first_stage: FirstStageModel
+ second_stage: SecondStageModel
+
+ def __init__(self, args, backbone: nn.Module, num_classes: int, dataset: ContinualDataset, device='cpu'):
+ super().__init__()
+ self.args = args
+ self.num_classes = num_classes
+ self.device = device
+ self.dataset = dataset
+ self.first_stage = FirstStageModel(args=args, num_classes=num_classes, dataset=dataset, device=device)
+
+ # REMOVE ALL TRACK RUNNING STATS FROM CLIP
+ for m in self.first_stage.modules():
+ if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
+ m.track_running_stats = False
+
+ self.second_stage = SecondStageModel(args=args, num_classes=num_classes,
+ dataset=dataset, backbone=backbone,
+ clip_model=self.first_stage.prompter.clip_model,
+ clip_preprocess=self.first_stage.prompter.clip_preprocess,
+ device=device)
+
+ embed_dim = self.second_stage.vit.embed_dim
+
+ self.second_stage_distributions = torch.nn.ModuleList([self._get_dist(embed_dim)
+ for _ in range(self.num_classes)]).to(self.device)
+ self.classifier_state_dict = None
+
+ def _get_dist(self, embed_dim):
+ assert self.args.gr_model in ['mog', 'gaussian'], f"Invalid GR model: {self.args.gr_model}"
+
+ if self.args.gr_model == 'mog':
+ return MixtureOfGaussiansModel(embed_dim, n_components=self.args.gr_mog_n_components,
+ n_iters=self.args.gr_mog_n_iters_second_stage)
+ else:
+ return Gaussian(embed_dim)
+
+ @torch.no_grad()
+ def update_keys(self, start_c: int, end_c: int):
+ print('Updating keys for second stage...')
+ first_stage_keys = self.first_stage.prompter.compute_keys(start_c, end_c)
+ self.second_stage.prompter.set_keys(first_stage_keys, start_c, end_c)
+
+ def forward(self, x: torch.Tensor, cur_classes: int, frozen_past_classes=0, return_query=False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the complete forward pass of STAR-Prompt.
+ This assumes that the keys are already pre-computed.
+
+ Args:
+ x: The input tensor.
+ cur_classes: The number of current classes.
+ frozen_past_classes: The number of past classes.
+ return_query: Whether to return the query tensor with the output.
+ """
+ return self.second_stage(x, cur_classes=cur_classes, frozen_past_classes=frozen_past_classes, return_query=return_query)
+
+ def train(self, mode: bool = True):
+ self.first_stage.train(mode)
+ self.second_stage.train(mode)
+
+ def to(self, device, *args, **kwargs):
+ super().to(device, *args, **kwargs)
+ self.first_stage.to(device, *args, **kwargs)
+ self.second_stage.to(device, *args, **kwargs)
+ self.device = device
+
+ return self
+
+ @torch.no_grad()
+ def eval_first_stage_on_task(self, dataset: ContinualDataset, n_seen_classes: int) -> torch.Tensor:
+ """
+ Compute and return the accuracy on each task so far.
+ """
+ was_training = self.first_stage.training
+ self.first_stage.eval()
+ all_accs = []
+ with tqdm(total=sum([len(test_loader) for test_loader in dataset.test_loaders]), desc='Eval first stage on seen tasks') as pbar:
+ for t, test_loader in enumerate(dataset.test_loaders):
+ total = 0
+ correct = 0
+ for inputs, labels in test_loader:
+ inputs, labels = inputs.to(self.device), labels.to(self.device, dtype=torch.long)
+ logits = self.first_stage(inputs, cur_classes=n_seen_classes)[:, :n_seen_classes]
+ _, predicted = torch.max(logits, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+ pbar.update(1)
+ all_accs.append(correct / total)
+ self.first_stage.train(was_training)
+ return torch.tensor(all_accs)
+
+ def norm(self, t):
+ return torch.norm(t, p=2, dim=-1, keepdim=True) + 1e-7
+
+ @torch.no_grad()
+ def create_features_dataset(self, current_task: int):
+
+ labels, features = [], []
+
+ for _ti in range(current_task + 1):
+
+ prev_t_size, cur_t_size = self.dataset.get_offsets(_ti)
+
+ for class_idx in range(prev_t_size, cur_t_size):
+ current_samples = self.second_stage_distributions[class_idx](self.args.num_samples_gr)
+ features.append(current_samples)
+ labels.append(torch.ones(self.args.num_samples_gr) * class_idx)
+
+ features = torch.cat(features, dim=0)
+ labels = torch.cat(labels, dim=0).long()
+
+ return create_seeded_dataloader(self.args, TensorDataset(features, labels),
+ batch_size=self.args.batch_size_gr,
+ shuffle=True, num_workers=0)
+
+ def train_alignment_epoch(self, classifier: torch.nn.Module, optim: torch.optim.Optimizer, n_seen_classes: int, current_task: int, loss_fn):
+
+ dl = self.create_features_dataset(current_task)
+
+ with tqdm(enumerate(dl), total=len(dl), desc='GR epoch') as pbar:
+ for i, (x, labels) in pbar:
+ optim.zero_grad()
+ x, labels = x.to(self.device, dtype=torch.float32), labels.to(self.device)
+
+ logits = classifier(x)
+
+ logits = logits[:, :n_seen_classes]
+
+ norm = self.norm(logits)
+ logits = logits / (0.1 * norm)
+
+ loss = loss_fn(logits, labels)
+ loss.backward()
+ optim.step()
+
+ if not self.args.nowand:
+ assert wandb is not None, "wandb is not installed."
+ wandb.log({'ca_loss': loss.item(), 'ca_lr': optim.param_groups[0]['lr']})
+ pbar.set_postfix({'loss': loss.item()})
+
+ def align(self, current_task: int, n_seen_classes: int, loss_fn):
+
+ classifier = deepcopy(self.second_stage.vit.head)
+
+ optim = torch.optim.SGD(lr=self.args.learning_rate_gr_second_stage,
+ params=classifier.parameters(),
+ momentum=0.0,
+ weight_decay=0.0)
+
+ num_epochs = self.args.num_epochs_gr_second_stage + (5 * current_task)
+
+ for e in range(num_epochs):
+ self.train_alignment_epoch(classifier, optim, n_seen_classes=n_seen_classes, current_task=current_task, loss_fn=loss_fn)
+
+ self.second_stage.vit.head.weight.data.copy_(classifier.weight.data)
+ self.second_stage.vit.head.bias.data.copy_(classifier.bias.data)
+
+ @torch.no_grad()
+ def update_statistics(self, dataset: ContinualDataset, n_past_classes: int, n_seen_classes: int):
+
+ features_dict = {i: [] for i in range(n_past_classes, n_seen_classes)}
+
+ self.second_stage.eval()
+
+ with tqdm(total=self.args.num_monte_carlo_gr_second_stage * len(dataset.train_loader), desc='GR update statistics') as pbar:
+ for _ in range(self.args.num_monte_carlo_gr_second_stage):
+ for i, data in enumerate(dataset.train_loader):
+ if self.args.debug_mode and i > 3 and min([len(v) for k, v in features_dict.items()]) > self.args.gr_mog_n_components:
+ break
+
+ x, labels = data[0], data[1]
+ x, labels = x.to(self.device), labels.to(self.device, dtype=torch.long)
+ features = self.second_stage(x, return_features=True, cur_classes=n_seen_classes, frozen_past_classes=n_past_classes)
+ features = features[:, 0]
+
+ for class_idx in labels.unique():
+ features_dict[int(class_idx)].append(features[labels == class_idx])
+
+ pbar.update(1)
+
+ for class_idx in range(n_past_classes, n_seen_classes):
+ features_class_idx = torch.cat(features_dict[class_idx], dim=0)
+ self.second_stage_distributions[class_idx].fit(features_class_idx.to(self.device))
+
+ def backup(self, current_task: int, n_past_classes: int, n_seen_classes: int):
+ print(f"BACKUP: Task - {current_task} - classes from "
+ f"{n_past_classes} - to {n_seen_classes}")
+ self.classifier_state_dict = deepcopy(self.second_stage.vit.head.state_dict())
+
+ def recall_classifier_second_stage(self, current_task: int, n_past_classes: int, n_seen_classes: int):
+ print(f"RECALL: Task - {current_task} - classes from "
+ f"{n_past_classes} - to {n_seen_classes}")
+
+ if current_task == 0 or self.args.enable_gr == 0:
+ return
+
+ assert self.classifier_state_dict
+
+ self.second_stage.vit.head.weight.data.copy_(self.classifier_state_dict['weight'].data)
+ self.second_stage.vit.head.bias.data.copy_(self.classifier_state_dict['bias'].data)
+
+ @torch.enable_grad()
+ def train_first_stage_on_task(self, dataset: ContinualDataset, current_task: int, n_past_classes: int, n_seen_classes: int, loss_fn):
+ """
+ Train the first stage on the current task.
+
+ Args:
+ dataset: The continual dataset for the current task, containing both train and test (validation) set.
+ current_task: The current task index.
+ n_past_classes: The number of past classes.
+ n_seen_classes: The number of seen classes.
+ loss_fn: The loss function.
+ """
+ print("Starting training of first stage on task", current_task)
+ # BEGIN-TASK
+ old_train_transform = dataset.train_loader.dataset.transform
+ old_test_transform = dataset.test_loaders[-1].dataset.transform
+
+ # use CLIP's preprocessing
+ dataset.train_loader.dataset.transform = self.first_stage.prompter.clip_preprocess
+ dataset.test_loaders[-1].dataset.transform = self.first_stage.prompter.clip_preprocess
+
+ convert_weights(self.first_stage.prompter.clip_model) # convert weights to float16 during training for speedup
+ self.first_stage.prompter.text_encoder.dtype = torch.float16
+ was_training = self.first_stage.training
+ self.first_stage.train()
+
+ first_stage_params = [v for k, v in self.first_stage.named_parameters() if 'prompt_parameters' in k]
+ if self.args.first_stage_optim == 'sgd':
+ opt = torch.optim.SGD(first_stage_params, lr=self.args.first_stage_lr, momentum=self.args.first_stage_momentum,
+ weight_decay=self.args.first_stage_weight_decay)
+ else:
+ opt = torch.optim.Adam(first_stage_params, lr=self.args.first_stage_lr,
+ weight_decay=self.args.first_stage_weight_decay)
+
+ # MINI TRAINING LOOP FOR CURRENT TASK
+ with tqdm(total=self.args.first_stage_epochs * len(dataset.train_loader), desc='First stage training') as pbar:
+ for epoch in range(self.args.first_stage_epochs):
+ for i, data in enumerate(dataset.train_loader):
+ if self.args.debug_mode and i > 3:
+ break
+ inputs, labels = data[0].to(self.device), data[1].to(self.device, dtype=torch.long)
+ loss = torch.tensor(0.).to(self.device)
+
+ opt.zero_grad()
+ # Check cur and past classes
+ clip_logits = self.first_stage(inputs, frozen_past_classes=n_past_classes, cur_classes=n_seen_classes)
+
+ # compute clip loss
+ clip_logits[:, :n_past_classes] = -float('inf')
+ loss_clip = loss_fn(clip_logits[:, :n_seen_classes], labels)
+
+ loss += loss_clip
+
+ loss_ortho_coop = self.first_stage.prompter.compute_ortho_loss(frozen_past_classes=n_past_classes, cur_classes=n_seen_classes)
+ loss += self.args.lambda_ortho_first_stage * loss_ortho_coop
+
+ if i == 0:
+ opt.zero_grad()
+ (loss / self.args.virtual_bs_n).backward()
+ if (i > 0 or self.args.virtual_bs_n == 1) and i % self.args.virtual_bs_n == 0:
+ opt.step()
+ opt.zero_grad()
+
+ if not self.args.nowand:
+ assert wandb is not None, "wandb is not installed."
+ wandb.log({'first_stage_loss': loss.item(),
+ 'first_stage_lr': opt.param_groups[0]['lr'],
+ 'first_stage_epoch': epoch,
+ 'first_stage_loss_clip': loss_clip.item(),
+ 'first_stage_loss_ortho': loss_ortho_coop.item(),
+ 'first_stage_iteration': i})
+
+ pbar.update(1)
+ pbar.set_postfix({'loss': loss.item()})
+
+ # END-TASK
+ opt.zero_grad(set_to_none=True)
+ del opt
+ torch.cuda.empty_cache()
+
+ # Generative replay after end of task
+ if self.args.enable_gr:
+ self.first_stage.prompter.update_statistics(dataset, current_task)
+ self.first_stage.prompter.align(current_task)
+
+ cur_acc = self.eval_first_stage_on_task(dataset, n_seen_classes)
+ print(f'First stage accuracy: {[acc.item() for acc in cur_acc]}')
+ print(f'\tAverage: {cur_acc.mean().item():.4f}')
+ if not self.args.nowand:
+ assert wandb is not None, "wandb is not installed."
+ log_dict = {f'first_stage_acc_{i}': acc.item() for i, acc in enumerate(cur_acc)}
+ log_dict['first_stage_acc'] = cur_acc.mean().item()
+ wandb.log(log_dict)
+
+ # restore original transforms
+ dataset.train_loader.dataset.transform = old_train_transform
+ dataset.test_loaders[-1].dataset.transform = old_test_transform
+
+ self.first_stage.prompter.clip_model.float() # convert back to float32
+ self.first_stage.prompter.text_encoder.dtype = torch.float32
+ self.first_stage.train(was_training)
diff --git a/models/star_prompt_utils/first_stage_model.py b/models/star_prompt_utils/first_stage_model.py
new file mode 100644
index 00000000..6489a949
--- /dev/null
+++ b/models/star_prompt_utils/first_stage_model.py
@@ -0,0 +1,269 @@
+import math
+from typing import List
+import torch.nn.functional as F
+import torch
+from torch.utils.data import TensorDataset
+from tqdm import tqdm
+
+from utils.conf import create_seeded_dataloader
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git")
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from datasets.utils.continual_dataset import ContinualDataset
+from models.star_prompt_utils.generative_replay import MixtureOfGaussiansModel
+
+
+class TextEncoder(torch.nn.Module):
+ def __init__(self, clip_model):
+ super().__init__()
+ self.transformer = clip_model.transformer
+ self.positional_embedding = clip_model.positional_embedding
+ self.ln_final = clip_model.ln_final
+ self.text_projection = clip_model.text_projection
+ self.dtype = clip_model.dtype
+
+ def forward(self, x, tokenized_prompts):
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2)
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2)
+ x = self.ln_final(x).type(self.dtype)
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
+ return x
+
+
+class Prompter(torch.nn.Module):
+
+ distributions: List[MixtureOfGaussiansModel]
+ token_suffix: torch.Tensor
+ token_prefix: torch.Tensor
+
+ def __init__(self, args, num_classes: int, dataset: ContinualDataset, device='cpu'):
+ super().__init__()
+ self.args = args
+ self.num_classes = num_classes
+ self.dataset = dataset
+ self.device = device
+
+ self.clip_model, self.clip_preprocess = clip.load(args.clip_backbone, self.device)
+
+ for p in self.clip_model.parameters():
+ p.requires_grad = False
+
+ self.class_names = dataset.get_class_names()
+ self.setup_text_prompting()
+ self.clip_logit_scale = self.clip_model.logit_scale
+
+ embed_dim = self.clip_model.visual.output_dim
+ self.distributions = torch.nn.ModuleList([MixtureOfGaussiansModel(embed_dim, n_components=self.args.gr_mog_n_components,
+ n_iters=self.args.gr_mog_n_iters_first_stage)
+ for _ in range(self.num_classes)]).to(self.device)
+
+ def compute_ortho_loss(self, cur_classes: int, frozen_past_classes=0) -> torch.Tensor:
+
+ # (num_classes, 1, clip_size)
+ cur_coop_p = self.prompt_parameters[frozen_past_classes:cur_classes]
+ ortho_loss_coop = torch.tensor(0.0, device=self.device)
+ if frozen_past_classes > 0:
+ past_coop_p = self.prompt_parameters[:frozen_past_classes].detach()
+ ortho_loss_coop = (torch.matmul(cur_coop_p.permute(1, 0, 2), past_coop_p.permute(1, 2, 0))**2).mean()
+
+ return ortho_loss_coop
+
+ @torch.no_grad()
+ def create_features_dataset(self, current_task: int):
+
+ labels, features = [], []
+
+ for _ti in range(current_task + 1):
+
+ prev_t_size, cur_t_size = self.dataset.get_offsets(_ti)
+
+ for class_idx in range(prev_t_size, cur_t_size):
+
+ current_samples = self.distributions[class_idx](self.args.num_samples_gr)
+ features.append(current_samples)
+ labels.append(torch.ones((self.args.num_samples_gr)) * class_idx)
+
+ features = torch.cat(features, dim=0)
+ labels = torch.cat(labels, dim=0).long()
+ return create_seeded_dataloader(self.args, TensorDataset(features, labels), num_workers=0, batch_size=self.args.batch_size_gr, shuffle=True)
+
+ def train_alignment_epoch(self, optim: torch.optim.Optimizer, current_task: int, epoch: int = 0):
+ offset_1, offset_2 = self.dataset.get_offsets(current_task)
+
+ dl = self.create_features_dataset(current_task)
+
+ with tqdm(enumerate(dl), total=len(dl), desc=f'GR first stage epoch {epoch + 1}/{self.args.num_epochs_gr_first_stage}', leave=False) as pbar:
+ for i, (image_features, labels) in pbar:
+ if self.args.debug_mode and i > 3:
+ break
+ optim.zero_grad()
+
+ image_features, labels = image_features.to(self.device, dtype=self.clip_model.dtype), labels.to(self.device)
+ image_features = torch.nn.functional.normalize(image_features, dim=-1)
+
+ text_features = self.compute_keys(0, offset_2)
+
+ text_features = torch.cat((text_features[:offset_1].detach(), text_features[offset_1:offset_2]), dim=0)
+ text_features = torch.nn.functional.normalize(text_features, dim=-1)
+
+ clip_logits = torch.einsum('bd,cd->bc', image_features, text_features)
+ clip_logits = clip_logits * self.clip_logit_scale.exp()
+ loss = F.cross_entropy(clip_logits, labels)
+
+ assert not math.isnan(loss.item())
+
+ loss.backward()
+ optim.step()
+
+ pbar.set_postfix({'loss': loss.item()})
+
+ if not self.args.nowand:
+ assert wandb is not None, "wandb is not installed."
+ wandb.log({'ca_loss': loss.item(), 'ca_lr': optim.param_groups[0]['lr']})
+
+ def align(self, current_task: int):
+ optim = torch.optim.SGD(lr=self.args.learning_rate_gr_first_stage, params=[self.prompt_parameters],
+ momentum=0.0, weight_decay=0.0)
+
+ for e in range(self.args.num_epochs_gr_first_stage):
+ self.train_alignment_epoch(optim, current_task=current_task, epoch=e)
+
+ @torch.no_grad()
+ def update_statistics(self, dataset: ContinualDataset, current_task: int):
+ offset_1, offset_2 = dataset.get_offsets(current_task)
+
+ features_dict = {i: [] for i in range(offset_1, offset_2)}
+
+ was_training = self.training
+ self.eval()
+
+ with tqdm(total=self.args.num_monte_carlo_gr_first_stage * len(dataset.train_loader),
+ desc='Updating statistics for first stage Generative Replay') as pbar:
+ for _ in range(self.args.num_monte_carlo_gr_first_stage):
+ for i, data in enumerate(dataset.train_loader):
+ if self.args.debug_mode == 1 and i > 3 and min([len(v) for k, v in features_dict.items()]) > self.args.gr_mog_n_components:
+ break
+ inputs, labels = data[0].to(self.device), data[1].to(self.device, dtype=torch.long)
+
+ if len(inputs.shape) == 5:
+ inputs = inputs[:, 1]
+ clip_query = self.get_query(inputs)
+
+ for class_idx in labels.unique():
+ features_dict[int(class_idx)].append(clip_query[labels == class_idx])
+
+ pbar.update(1)
+
+ for class_idx in range(offset_1, offset_2):
+ features_class_idx = torch.cat(features_dict[class_idx], dim=0)
+ self.distributions[class_idx].fit(features_class_idx.to(self.device))
+
+ if was_training:
+ self.train()
+
+ def compute_keys(self, start: int, end: int):
+ """
+ Compute the text-encoder features the CoOp way, but separately for each class.
+ """
+ ctx = self.prompt_parameters[start:end]
+ prefix = self.token_prefix[start:end]
+ suffix = self.token_suffix[start:end]
+ prompts = torch.cat((prefix, ctx, suffix), dim=1)
+ tokenized_prompts = self.tokenized_prompts[start:end]
+ keys = self.text_encoder(prompts.to(self.clip_model.dtype), tokenized_prompts)
+ keys = torch.nn.functional.normalize(keys, dim=-1)
+ return keys
+
+ def get_keys(self, cur_classes: int, frozen_past_classes=0) -> torch.Tensor:
+ """
+ Compute the text-encoder features for classes from 0 to `cur_classes`.
+ Features of classes before `frozen_past_classes` are frozen.
+ """
+ if frozen_past_classes > 0:
+ with torch.no_grad():
+ past_keys = self.compute_keys(0, frozen_past_classes)
+ cur_keys = self.compute_keys(frozen_past_classes, cur_classes)
+ keys = torch.cat((past_keys.detach(), cur_keys), dim=0)
+ else:
+ keys = self.compute_keys(0, cur_classes)
+ return keys
+
+ def setup_text_prompting(self):
+ """
+ Initialize a singly prompt (length 1) for each class.
+ """
+ self.text_encoder = TextEncoder(self.clip_model)
+
+ text_prompts = ["X " + name + "." for name in self.class_names]
+ tokenized_prompts = torch.cat([clip.tokenize(p) for p in text_prompts], dim=0).to(self.device)
+ self.tokenized_prompts = tokenized_prompts
+
+ with torch.no_grad():
+ embedding = self.clip_model.token_embedding(tokenized_prompts).type(self.clip_model.dtype)
+ self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
+ self.register_buffer("token_suffix", embedding[:, 2:, :]) # CLS, EOS
+
+ prompt_parameters = torch.empty(self.num_classes, 1, self.clip_model.token_embedding.weight.shape[1], device=self.device, dtype=torch.float32)
+ torch.nn.init.normal_(prompt_parameters, std=0.02)
+ self.prompt_parameters = torch.nn.Parameter(prompt_parameters)
+
+ @torch.no_grad()
+ def get_query(self, x):
+ clip_out = self.clip_model.encode_image(x)
+ assert not torch.isnan(clip_out).any()
+ return clip_out
+
+ def get_clip_logits(self, clip_out, keys):
+ image_features = torch.nn.functional.normalize(clip_out, dim=-1)
+ clip_logits = torch.einsum('bd,cd->bc', image_features, keys)
+ clip_logits = clip_logits * self.clip_logit_scale.exp()
+ return clip_logits
+
+
+class Model(torch.nn.Module):
+ prompter: Prompter
+
+ def __init__(self, args, num_classes: int, dataset: ContinualDataset, device='cpu'):
+ super().__init__()
+ self.args = args
+ self.num_classes = num_classes
+ self.device = device
+
+ self.prompter = Prompter(args, num_classes=num_classes, dataset=dataset, device=device)
+
+ def to(self, device, *args, **kwargs):
+ super().to(device, *args, **kwargs)
+ self.prompter.to(device, *args, **kwargs)
+ self.device = device
+
+ return self
+
+ def train(self, mode=True):
+ super().train(False)
+ self.prompter.train(False)
+
+ return self
+
+ def forward(self, x: torch.Tensor, cur_classes: int, return_query=False, frozen_past_classes=0) -> torch.Tensor:
+ """
+ Compute the logits for the current task.
+ Logits of classes before `frozen_past_classes` are frozen.
+
+ If `return_query` is True, return the CLIP's visual encoder output instead of the logits.
+ """
+ clip_out = self.prompter.get_query(x)
+ if return_query:
+ return clip_out
+
+ keys = self.prompter.get_keys(frozen_past_classes=frozen_past_classes, cur_classes=cur_classes)
+ clip_logits = self.prompter.get_clip_logits(clip_out, keys)
+
+ return clip_logits
diff --git a/models/star_prompt_utils/generative_replay.py b/models/star_prompt_utils/generative_replay.py
new file mode 100644
index 00000000..d9c93aa3
--- /dev/null
+++ b/models/star_prompt_utils/generative_replay.py
@@ -0,0 +1,554 @@
+"""
+Adaptation of https://github.com/ldeecke/gmm-torch.
+Copyright (c) 2019 Lucas Deecke.
+
+Licensed under the MIT License.
+"""
+
+import math
+import torch
+import numpy as np
+
+from math import pi
+
+
+class MixtureOfGaussiansModel(torch.nn.Module):
+
+ def __init__(self, embed_dim: int, n_components: int = 3, n_iters: int = 100):
+ super().__init__()
+ self.n_iters = n_iters
+ self.gm = GaussianMixture(n_components, embed_dim, covariance_type='diag')
+
+ def fit(self, x):
+ x = x.type(torch.float64)
+ tries = 0
+ while tries < 10:
+ self.gm.fit(x, n_iter=self.n_iters)
+ if self.gm.log_likelihood > -np.inf:
+ break
+ self.gm.to(x.device)
+ tries += 1
+
+ assert (self.gm.var > 0).all(), "Variance is not positive"
+ assert self.gm.log_likelihood > -np.inf, "Log-likelihood is not finite"
+
+ def sample(self, n_sample):
+ return self.gm.sample(n_sample)[0]
+
+ def forward(self, n_sample, *args, **kwargs):
+ return self.sample(n_sample)
+
+
+class Gaussian(torch.nn.Module):
+
+ def __init__(self, embed_dim):
+ super(Gaussian, self).__init__()
+ self.embed_dim = embed_dim
+ self.register_buffer("mean", torch.zeros(embed_dim))
+ self.register_buffer("std", torch.ones(embed_dim))
+
+ def fit(self, x):
+ self.std, self.mean = torch.std_mean(x, dim=0)
+
+ def sample(self, n_sample, scale_mean):
+ return torch.distributions.normal.Normal(scale_mean * self.mean, self.std).sample((n_sample,))
+
+ def forward(self, n_sample, scale_mean: float = 1.0):
+ return self.sample(n_sample, scale_mean)
+
+
+def calculate_matmul_n_times(n_components, mat_a, mat_b):
+ """
+ Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].
+ Bypasses torch.matmul to reduce memory footprint.
+ args:
+ mat_a: torch.Tensor (n, k, 1, d)
+ mat_b: torch.Tensor (1, k, d, d)
+ """
+ res = torch.zeros(mat_a.shape).to(mat_a.device)
+
+ for i in range(n_components):
+ mat_a_i = mat_a[:, i, :, :].squeeze(-2)
+ mat_b_i = mat_b[0, i, :, :].squeeze()
+ res[:, i, :, :] = mat_a_i.mm(mat_b_i).unsqueeze(1)
+
+ return res
+
+
+def calculate_matmul(mat_a, mat_b):
+ """
+ Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].
+ Bypasses torch.matmul to reduce memory footprint.
+ args:
+ mat_a: torch.Tensor (n, k, 1, d)
+ mat_b: torch.Tensor (n, k, d, 1)
+ """
+ assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1
+ return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=2, keepdim=True)
+
+
+class GaussianMixture(torch.nn.Module):
+ """
+ Fits a mixture of k=1,..,K Gaussians to the input data (K is supplied via n_components).
+ Input tensors are expected to be flat with dimensions (n: number of samples, d: number of features).
+ The model then extends them to (n, 1, d).
+ The model parametrization (mu, sigma) is stored as (1, k, d),
+ probabilities are shaped (n, k, 1) if they relate to an individual sample,
+ or (1, k, 1) if they assign membership probabilities to one of the mixture components.
+ """
+
+ def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None,
+ var_init=None):
+ """
+ Initializes the model and brings all tensors into their required shape.
+ The class expects data to be fed as a flat tensor in (n, d).
+ The class owns:
+ x: torch.Tensor (n, 1, d)
+ mu: torch.Tensor (1, k, d)
+ var: torch.Tensor (1, k, d) or (1, k, d, d)
+ pi: torch.Tensor (1, k, 1)
+ covariance_type: str
+ eps: float
+ init_params: str
+ log_likelihood: float
+ n_components: int
+ n_features: int
+ args:
+ n_components: int
+ n_features: int
+ options:
+ mu_init: torch.Tensor (1, k, d)
+ var_init: torch.Tensor (1, k, d) or (1, k, d, d)
+ covariance_type: str
+ eps: float
+ init_params: str
+ """
+ super(GaussianMixture, self).__init__()
+
+ self.n_components = n_components
+ self.n_features = n_features
+
+ self.mu_init = mu_init
+ self.var_init = var_init
+ self.eps = eps
+
+ self.log_likelihood = -np.inf
+
+ self.covariance_type = covariance_type
+ self.init_params = init_params
+
+ assert self.covariance_type in ["full", "diag"]
+ assert self.init_params in ["kmeans", "random"]
+
+ self._init_params()
+
+ def _init_params(self):
+ if self.mu_init is not None:
+ assert self.mu_init.size() == (1, self.n_components,
+ self.n_features), "Input mu_init does not have required tensor dimensions (1, %i, %i)" % (
+ self.n_components, self.n_features)
+ # (1, k, d)
+ self.mu = torch.nn.Parameter(self.mu_init, requires_grad=False)
+ else:
+ self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)
+
+ if self.covariance_type == "diag":
+ if self.var_init is not None:
+ # (1, k, d)
+ assert self.var_init.size() == (1, self.n_components,
+ self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i)" % (
+ self.n_components, self.n_features)
+ self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+ else:
+ self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)
+ elif self.covariance_type == "full":
+ if self.var_init is not None:
+ # (1, k, d, d)
+ assert self.var_init.size() == (1, self.n_components, self.n_features,
+ self.n_features), "Input var_init does not have required tensor dimensions (1, %i, %i, %i)" % (
+ self.n_components, self.n_features, self.n_features)
+ self.var = torch.nn.Parameter(self.var_init, requires_grad=False)
+ else:
+ self.var = torch.nn.Parameter(
+ torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1,
+ self.n_components,
+ 1, 1),
+ requires_grad=False
+ )
+
+ # (1, k, 1)
+ self.pi = torch.nn.Parameter(torch.Tensor(1, self.n_components, 1), requires_grad=False).fill_(
+ 1. / self.n_components)
+ self.params_fitted = False
+
+ def check_size(self, x):
+ if len(x.size()) == 2:
+ # (n, d) --> (n, 1, d)
+ x = x.unsqueeze(1)
+
+ return x
+
+ def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):
+ """
+ Fits model to the data.
+ args:
+ x: torch.Tensor (n, d) or (n, k, d)
+ options:
+ delta: float
+ n_iter: int
+ warm_start: bool
+ """
+ if not warm_start and self.params_fitted:
+ self._init_params()
+
+ x = self.check_size(x)
+
+ if self.init_params == "kmeans" and self.mu_init is None:
+ mu = self.get_kmeans_mu(x, n_centers=self.n_components)
+ self.mu.data = mu
+
+ for p in self.parameters():
+ p.data = p.data.to(x.device)
+
+ i = 0
+ j = np.inf
+
+ while (i <= n_iter) and (j >= delta):
+
+ log_likelihood_old = self.log_likelihood
+ mu_old = self.mu
+ var_old = self.var
+
+ self.__em(x)
+ self.log_likelihood = self.__score(x)
+
+ if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood):
+ device = self.mu.device
+ # When the log-likelihood assumes unbound values, reinitialize model
+ self.__init__(self.n_components,
+ self.n_features,
+ covariance_type=self.covariance_type,
+ mu_init=self.mu_init,
+ var_init=self.var_init,
+ eps=self.eps)
+ for p in self.parameters():
+ p.data = p.data.to(device)
+ if self.init_params == "kmeans":
+ self.mu.data = self.get_kmeans_mu(x, n_centers=self.n_components)[0]
+
+ i += 1
+ j = self.log_likelihood - log_likelihood_old
+ j = np.inf if math.isnan(j) else j
+
+ if j <= delta:
+ # When score decreases, revert to old parameters
+ self.__update_mu(mu_old)
+ self.__update_var(var_old)
+
+ self.params_fitted = True
+
+ def predict(self, x, probs=False):
+ """
+ Assigns input data to one of the mixture components by evaluating the likelihood under each.
+ If probs=True returns normalized probabilities of class membership.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ probs: bool
+ returns:
+ p_k: torch.Tensor (n, k)
+ (or)
+ y: torch.LongTensor (n)
+ """
+ x = self.check_size(x)
+
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+
+ if probs:
+ p_k = torch.exp(weighted_log_prob)
+ return torch.squeeze(p_k / (p_k.sum(1, keepdim=True)))
+ else:
+ return torch.squeeze(torch.max(weighted_log_prob, 1)[1].type(torch.LongTensor))
+
+ def predict_proba(self, x):
+ """
+ Returns normalized probabilities of class membership.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ y: torch.LongTensor (n)
+ """
+ return self.predict(x, probs=True)
+
+ def sample(self, n):
+ """
+ Samples from the model.
+ args:
+ n: int
+ returns:
+ x: torch.Tensor (n, d)
+ y: torch.Tensor (n)
+ """
+ counts = torch.distributions.multinomial.Multinomial(total_count=n, probs=self.pi.squeeze()).sample()
+ x = torch.empty(0, device=counts.device)
+ y = torch.cat([torch.full([int(sample)], j, device=counts.device) for j, sample in enumerate(counts)])
+
+ # Only iterate over components with non-zero counts
+ for k in torch.arange(self.n_components, device=counts.device)[counts > 0]:
+ if self.covariance_type == "diag":
+ x_k = self.mu[0, k] + torch.randn(int(counts[k]), self.n_features, device=x.device) * torch.sqrt(
+ self.var[0, k])
+ elif self.covariance_type == "full":
+ d_k = torch.distributions.multivariate_normal.MultivariateNormal(self.mu[0, k], self.var[0, k])
+ x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))])
+
+ x = torch.cat((x, x_k), dim=0)
+
+ return x, y
+
+ def score_samples(self, x):
+ """
+ Computes log-likelihood of samples under the current model.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ score: torch.LongTensor (n)
+ """
+ x = self.check_size(x)
+
+ score = self.__score(x, as_average=False)
+ return score
+
+ def _estimate_log_prob(self, x):
+ """
+ Returns a tensor with dimensions (n, k, 1), which indicates the log-likelihood that samples belong to the k-th Gaussian.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ log_prob: torch.Tensor (n, k, 1)
+ """
+ x = self.check_size(x)
+
+ if self.covariance_type == "full":
+ mu = self.mu
+ var = self.var
+
+ precision = torch.inverse(var)
+ d = x.shape[-1]
+
+ log_2pi = d * np.log(2. * pi)
+
+ log_det = self._calculate_log_det(precision)
+
+ x_mu_T = (x - mu).unsqueeze(-2)
+ x_mu = (x - mu).unsqueeze(-1)
+
+ x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision)
+ x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu)
+
+ return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu)
+
+ elif self.covariance_type == "diag":
+ mu = self.mu
+ prec = torch.rsqrt(self.var)
+
+ log_p = torch.sum((mu * mu + x * x - 2 * x * mu) * prec, dim=2, keepdim=True)
+ log_det = torch.sum(torch.log(prec), dim=2, keepdim=True)
+
+ return -.5 * (self.n_features * np.log(2. * pi) + log_p - log_det)
+
+ def _calculate_log_det(self, var):
+ """
+ Calculate log determinant in log space, to prevent overflow errors.
+ args:
+ var: torch.Tensor (1, k, d, d)
+ """
+ log_det = torch.empty(size=(self.n_components,)).to(var.device)
+
+ for k in range(self.n_components):
+ log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0, k]))).sum()
+
+ return log_det.unsqueeze(-1)
+
+ def _e_step(self, x):
+ """
+ Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components.
+ Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn).
+ This is the so-called expectation step of the EM-algorithm.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ returns:
+ log_prob_norm: torch.Tensor (1)
+ log_resp: torch.Tensor (n, k, 1)
+ """
+ x = self.check_size(x)
+
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+
+ log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
+ log_resp = weighted_log_prob - log_prob_norm
+
+ return torch.mean(log_prob_norm), log_resp
+
+ def _m_step(self, x, log_resp):
+ """
+ From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm.
+ args:
+ x: torch.Tensor (n, d) or (n, 1, d)
+ log_resp: torch.Tensor (n, k, 1)
+ returns:
+ pi: torch.Tensor (1, k, 1)
+ mu: torch.Tensor (1, k, d)
+ var: torch.Tensor (1, k, d)
+ """
+ x = self.check_size(x)
+
+ resp = torch.exp(log_resp)
+
+ pi = torch.sum(resp, dim=0, keepdim=True) + self.eps
+ mu = torch.sum(resp * x, dim=0, keepdim=True) / pi
+
+ if self.covariance_type == "full":
+ eps = (torch.eye(self.n_features) * self.eps).to(x.device)
+ var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
+ keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
+
+ elif self.covariance_type == "diag":
+ x2 = (resp * x * x).sum(0, keepdim=True) / pi
+ mu2 = mu * mu
+ xmu = (resp * mu * x).sum(0, keepdim=True) / pi
+ var = x2 - 2 * xmu + mu2 + self.eps
+
+ pi = pi / x.shape[0]
+
+ return pi, mu, var
+
+ def __em(self, x):
+ """
+ Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.
+ args:
+ x: torch.Tensor (n, 1, d)
+ """
+ _, log_resp = self._e_step(x)
+ pi, mu, var = self._m_step(x, log_resp)
+
+ self.__update_pi(pi)
+ self.__update_mu(mu)
+ self.__update_var(var)
+
+ def __score(self, x, as_average=True):
+ """
+ Computes the log-likelihood of the data under the model.
+ args:
+ x: torch.Tensor (n, 1, d)
+ sum_data: bool
+ returns:
+ score: torch.Tensor (1)
+ (or)
+ per_sample_score: torch.Tensor (n)
+
+ """
+ weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+ per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)
+
+ if as_average:
+ return per_sample_score.mean()
+ else:
+ return torch.squeeze(per_sample_score)
+
+ def __update_mu(self, mu):
+ """
+ Updates mean to the provided value.
+ args:
+ mu: torch.FloatTensor
+ """
+ assert mu.size() in [(self.n_components, self.n_features), (1, self.n_components,
+ self.n_features)], "Input mu does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (
+ self.n_components, self.n_features, self.n_components, self.n_features)
+
+ if mu.size() == (self.n_components, self.n_features):
+ self.mu = mu.unsqueeze(0)
+ elif mu.size() == (1, self.n_components, self.n_features):
+ self.mu.data = mu
+
+ def __update_var(self, var):
+ """
+ Updates variance to the provided value.
+ args:
+ var: torch.FloatTensor
+ """
+ if self.covariance_type == "full":
+ assert var.size() in [(self.n_components, self.n_features, self.n_features), (
+ 1, self.n_components, self.n_features,
+ self.n_features)], "Input var does not have required tensor dimensions (%i, %i, %i) or (1, %i, %i, %i)" % (
+ self.n_components, self.n_features, self.n_features, self.n_components, self.n_features, self.n_features)
+
+ if var.size() == (self.n_components, self.n_features, self.n_features):
+ self.var = var.unsqueeze(0)
+ elif var.size() == (1, self.n_components, self.n_features, self.n_features):
+ self.var.data = var
+
+ elif self.covariance_type == "diag":
+ assert var.size() in [(self.n_components, self.n_features), (1, self.n_components,
+ self.n_features)], "Input var does not have required tensor dimensions (%i, %i) or (1, %i, %i)" % (
+ self.n_components, self.n_features, self.n_components, self.n_features)
+
+ if var.size() == (self.n_components, self.n_features):
+ self.var = var.unsqueeze(0)
+ elif var.size() == (1, self.n_components, self.n_features):
+ self.var.data = var
+
+ def __update_pi(self, pi):
+ """
+ Updates pi to the provided value.
+ args:
+ pi: torch.FloatTensor
+ """
+ assert pi.size() in [
+ (1, self.n_components, 1)], "Input pi does not have required tensor dimensions (%i, %i, %i)" % (
+ 1, self.n_components, 1)
+
+ self.pi.data = pi
+
+ def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):
+ """
+ Find an initial value for the mean. Requires a threshold min_delta for the k-means algorithm to stop iterating.
+ The algorithm is repeated init_times often, after which the best centerpoint is returned.
+ args:
+ x: torch.FloatTensor (n, d) or (n, 1, d)
+ init_times: init
+ min_delta: int
+ """
+ if len(x.size()) == 3:
+ x = x.squeeze(1)
+ x_min, x_max = x.min(), x.max()
+ x = (x - x_min) / (x_max - x_min)
+
+ min_cost = np.inf
+
+ for i in range(init_times):
+ center_idxs = torch.from_numpy(np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False)).to(x.device)
+ tmp_center = x[center_idxs, ...]
+ l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - tmp_center), p=2, dim=2)
+ l2_cls = torch.argmin(l2_dis, dim=1)
+
+ cost = 0
+ for c in range(n_centers):
+ cost += torch.norm(x[l2_cls == c] - tmp_center[c], p=2, dim=1).mean()
+
+ if cost < min_cost:
+ min_cost = cost
+ center = tmp_center
+
+ delta = np.inf
+
+ while delta > min_delta:
+ l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - center), p=2, dim=2)
+ l2_cls = torch.argmin(l2_dis, dim=1)
+ center_old = center.clone()
+
+ for c in range(n_centers):
+ center[c] = x[l2_cls == c].mean(dim=0)
+
+ delta = torch.norm((center_old - center), dim=1).max()
+
+ return (center.unsqueeze(0) * (x_max - x_min) + x_min)
diff --git a/models/star_prompt_utils/second_stage_model.py b/models/star_prompt_utils/second_stage_model.py
new file mode 100644
index 00000000..ceb92650
--- /dev/null
+++ b/models/star_prompt_utils/second_stage_model.py
@@ -0,0 +1,439 @@
+import os
+import sys
+import json
+import torch
+import torch.nn as nn
+from typing import List
+from kornia.augmentation import Normalize
+
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git")
+
+from datasets.utils.continual_dataset import ContinualDataset
+from models.star_prompt_utils.vision_transformer import VisionTransformer
+
+
+class Prompter(torch.nn.Module):
+ keys: torch.Tensor
+
+ def __init__(self, args, dataset: ContinualDataset,
+ num_classes: int, target_embed_len: int,
+ target_embed_dim: int, prompt_layers: List[int],
+ clip_model: clip.model.CLIP = None, clip_preprocess=None,
+ device='cpu'):
+ super().__init__()
+ assert args.prompt_mode in ['residual', 'concat'], 'This prompter supports only STAR-Prompt residual-style prompts (`residual`) or Prefix tuning-style prompts (`concat`).'
+ self.args = args
+ self.prompt_layers = prompt_layers
+ self.target_embed_len = target_embed_len
+ self.target_embed_dim = target_embed_dim
+ self.device = device
+ self.num_classes = num_classes
+ self.prompt_mode = args.prompt_mode
+
+ if clip_model is not None:
+ assert clip_preprocess is not None, 'Preprocess must be provided if the model is provided'
+
+ print("Loading CLIP visual encoder and the pre-computed text features...")
+ clip_backbone = 'ViT-L/14' if not hasattr(args, 'clip_backbone') else args.clip_backbone
+ if hasattr(args, 'keys_ckpt_path') and args.keys_ckpt_path is not None:
+ if args.keys_ckpt_path.endswith('.json'):
+ try:
+ key_jobnum = json.load(open(args.keys_ckpt_path, 'r'))[args.dataset][str(args.seed)]
+ except BaseException:
+ print("key missing", args.dataset, args.seed, file=sys.stderr)
+ raise ValueError
+
+ t = dataset.N_TASKS - 1
+ self.keys_ckpt_path = f"coop_keys/coop_keys_{t}_{key_jobnum}.pt"
+ elif args.keys_ckpt_path.endswith('.pt'):
+ self.keys_ckpt_path = args.keys_ckpt_path
+ else:
+ t = dataset.N_TASKS - 1
+ self.keys_ckpt_path = f"coop_keys/coop_keys_{t}_{args.keys_ckpt_path}.pt"
+
+ if not os.path.exists(self.keys_ckpt_path):
+ raise ValueError(f'Keys checkpoint `{self.keys_ckpt_path}` does not exist')
+
+ self.keys, first_stage_args = self.load_keys()
+ if first_stage_args is not None:
+ print("Keys loaded. Loading CLIP version:", first_stage_args.clip_backbone)
+ clip_backbone = first_stage_args.clip_backbone
+ if clip_model is None:
+ self.clip_model, self.clip_preprocess = clip.load(clip_backbone, self.device)
+ self.clip_model = self.clip_model.float() # force fp32 when used for eval
+ else:
+ self.clip_model = clip_model
+ self.clip_preprocess = clip_preprocess
+ else: # use prompt templates
+ self.keys_ckpt_path = None
+ print("No keys loaded. Using default CLIP version:", clip_backbone)
+ if clip_model is None:
+ self.clip_model, self.clip_preprocess = clip.load(clip_backbone, self.device)
+ self.clip_model = self.clip_model.float() # force fp32 when used for eval
+ else:
+ self.clip_model = clip_model
+ self.clip_preprocess = clip_preprocess
+ self.keys = self.load_default_prompt_templates(dataset.get_prompt_templates(), dataset.get_class_names())
+
+ self.clip_normalization = Normalize(self.clip_preprocess.transforms[-1].mean,
+ self.clip_preprocess.transforms[-1].std).to(self.device)
+ self.denorm_transform = dataset.get_denormalization_transform()
+
+ for p in self.clip_model.parameters():
+ p.requires_grad = False
+
+ for l in self.prompt_layers:
+ if args.prompt_mode == 'residual':
+ # NOTE: this initialization follows that of CODA-Prompt.
+ # We originally initialize a prompt for key, query, and value of the MHA layer.
+ tmp = self.get_parameter((self.num_classes, 3, self.target_embed_dim))
+ # We only use value at the end, so we keep only a single tensor.
+ tmp.data = tmp.data[:, 0]
+ # HOWEVER: Since the orthogonal_ of pytorch flattens the tensor, the value prompt is not orthogonal anymore.
+ # orthogonal_ made (C, 3, D) -> (C, 3*D) -> orthogonal -> (C, 3, D), thus each 3*D is orthogonal, but not each D.
+ # This is intended and maked the orthogonalization loss being optimized at the beginning.
+ setattr(self, f'p_{l}', tmp)
+ else:
+ setattr(self, f'p_concat_{l}', self.get_parameter((self.num_classes, 2 * self.args.prefix_tuning_prompt_len,
+ self.target_embed_dim)))
+
+ setattr(self, f'a_{l}', self.get_parameter((self.num_classes, self.clip_model.visual.output_dim)))
+
+ def set_keys(self, keys: torch.Tensor, start_class: int, end_class: int):
+ """
+ Set the keys for the classes in the range `[start_class, end_class)`.
+ """
+ assert end_class - start_class == keys.shape[0], 'Number of classes in the keys tensor does not match the range'
+
+ self.keys[start_class:end_class] = keys
+
+ def get_parameter(self, shape, type_init: str = 'orto') -> torch.nn.Parameter:
+ """
+ Create and initialize a parameter tensor. Code courtesy from CODA-Prompt.
+ """
+ param = torch.nn.Parameter(torch.zeros(*shape, dtype=torch.float32, device=self.device))
+ if type_init == 'orto':
+ torch.nn.init.orthogonal_(param)
+ if type_init == 'gaussian':
+ torch.nn.init.normal_(param, mean=0.0, std=0.1)
+ return param
+
+ @torch.no_grad()
+ def load_default_prompt_templates(self, templates: List[str], dataset_classes: List[str]) -> torch.Tensor:
+ """
+ Pre-computes the CLIP's text-encoder features if the keys are not loaded from a checkpoint.
+ """
+ if hasattr(self.args, 'statc_keys_use_templates') and self.args.statc_keys_use_templates:
+ all_features = []
+ for t in templates:
+ text_inputs = torch.cat([clip.tokenize(t.format(c)) for c in dataset_classes]).to(self.device)
+ text_features = self.clip_model.encode_text(text_inputs)
+ all_features.append(text_features)
+ text_features = torch.stack(all_features).mean(dim=0)
+ else:
+ text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in dataset_classes]).to(self.device)
+ text_features = self.clip_model.encode_text(text_inputs)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ return text_features.float()
+
+ @torch.no_grad()
+ def load_keys(self):
+ """
+ Load the keys from a `first_stage_starprompt` checkpoint file (run with `--save_first_stage_keys=1`).
+ The checkpoint file can be either:
+ - A path to a checkpoint file (.pt) containing ONLY THE FIRST STAGE KEYS.
+ The number of classes and the dataset must match the current run, but we cannot check if the seed was the same.
+ - A path to the checkpoint made by `first_stage_starprompt` or the job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys.
+ Checks will prevent loading keys with a different order of the classes or dataset.
+ - A JSON file containing the job-id (`conf_jobnum`) of the `first_stage_starprompt` run that made the keys.
+ The JSON is expected to contain an entry for each dataset and seed: `{dataset: {seed: job-id}}`.
+
+ Returns:
+ The keys tensor
+ The arguments used in the first stage
+ """
+ print(f'Loading keys from {self.keys_ckpt_path}', file=sys.stderr)
+ st = torch.load(self.keys_ckpt_path)
+ if isinstance(st, dict):
+ keys = st['keys'].to(self.device)
+ self.old_args = st['args']
+ assert self.num_classes == keys.shape[0]
+ assert self.args.dataset == self.old_args.dataset
+ assert self.args.permute_classes == self.old_args.permute_classes
+ if self.args.permute_classes:
+ assert self.args.seed == self.old_args.seed
+ else:
+ keys = st.to(self.device)
+ self.old_args = None
+ assert self.num_classes == keys.shape[0]
+ print('Keys loaded successfully', file=sys.stderr)
+ return keys.float(), self.old_args
+
+ @torch.no_grad()
+ def get_query(self, x, disable_renorm=True):
+ """
+ Compute the CLIP features for the input image `x`.
+
+ Args:
+ x: the input image tensor
+ disable_renorm: if False, the final normalization applied to `x` will be swapped with the CLIP's one.
+ """
+ if not disable_renorm:
+ x = self.denorm_transform(x)
+ x = self.clip_normalization(x)
+ clip_out = self.clip_model.encode_image(x)
+ return clip_out
+
+ def compute_maps(self, clip_query: torch.Tensor, modulation_coeffs: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the CLIP output given the `clip_query` and the `keys`. The queries are modulated by the `modulation_coeffs`.
+ """
+ filter_values = torch.softmax(modulation_coeffs, dim=-1)
+
+ clip_query = clip_query.unsqueeze(1).expand(clip_query.shape[0], modulation_coeffs.shape[0], clip_query.shape[-1])
+ clip_out_a = clip_query * filter_values[None, :, :]
+ clip_out_a_norm = torch.nn.functional.normalize(clip_out_a, dim=-1)
+
+ clip_query = torch.einsum('bcd,cd->bc', clip_out_a_norm, keys) * 5
+
+ return clip_query
+
+ def get_masked_clip_out(self, sim_act_map):
+ """
+ We only need the output of the CLIP model for the most similar class, so we mask the rest.
+ """
+ with torch.no_grad():
+ mask = torch.ones_like(sim_act_map, dtype=torch.bool)
+ mask.scatter_(1, sim_act_map.argmax(dim=1, keepdim=True), False)
+ sim_act_map[mask] = 0.0
+
+ return sim_act_map
+
+ def compute_super_prompts(self, class_prompts: torch.Tensor, masked_clip_out: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
+ """
+ Compute the actual super-prompt by merging the individual prompts for the classes in the range `[start_idx, end_idx)`.
+ The merge is made according to the similarity map `sim_act_map` and scaled by it if `enable_confidence_modulation` is set.
+
+ Args:
+ class_prompts: the prompt parameters for each class in the range `[start_idx, end_idx)`
+ masked_clip_out: the masked CLIP output for the classes in the range `[start_idx, end_idx)`,
+containing the similarity value for the most similar class for each image.
+ start_idx: the start index of the classes to consider
+ end_idx: the end index of the classes to consider
+
+ Returns:
+ The super-prompt for the classes in the range `[start_idx, end_idx)`.
+ """
+ masked_clip_out = masked_clip_out[:, start_idx:end_idx]
+ class_prompts = class_prompts[start_idx:end_idx]
+
+ if self.args.enable_confidence_modulation == 0:
+ masked_clip_out = (masked_clip_out != 0).float() # make it binary if not using confidence modulation
+
+ if self.args.prompt_mode == 'residual':
+ sp = torch.einsum('bc,cd->bd', masked_clip_out, class_prompts)
+ else:
+ sp = torch.einsum('bc,cmd->bmd', masked_clip_out, class_prompts)
+ return sp
+
+ def get_prompts(self, layer_idx: int, clip_query: torch.Tensor, cur_classes: int, frozen_past_classes=0):
+ """
+ Compute the prompts for the `layer_idx`-th layer for `cur_classes` classes.
+ The prompts until `frozen_past_classes` are detached to prevent gradients from flowing back.
+ By default, all the layers require prompting. This can be changed by adjusting the `prompt_layers` attribute.
+
+ Returns:
+ The computed prompt, if the layer requires prompting. Else, returns None.
+ """
+
+ if layer_idx in self.prompt_layers:
+
+ a: torch.Tensor = getattr(self, f'a_{layer_idx}')
+ if self.prompt_mode == 'residual':
+ pv: torch.Tensor = getattr(self, f'p_{layer_idx}')
+ else:
+ p_concat: torch.Tensor = getattr(self, f'p_concat_{layer_idx}')
+ p_concat_k, p_concat_v = torch.split(p_concat, self.args.prefix_tuning_prompt_len, dim=1)
+
+ if frozen_past_classes > 0:
+ with torch.no_grad(): # detach the past prompts to prevent gradients from flowing back
+ clip_out_prev = self.compute_maps(clip_query, a[:frozen_past_classes].detach(), self.keys[:frozen_past_classes].detach())
+ clip_out_curr = self.compute_maps(clip_query, a[frozen_past_classes:cur_classes], self.keys[frozen_past_classes:cur_classes])
+ clip_out = torch.cat((clip_out_prev.detach(), clip_out_curr), dim=1)
+ clip_out = self.get_masked_clip_out(clip_out)
+
+ with torch.no_grad():
+ if self.prompt_mode == 'residual':
+ sp_past = self.compute_super_prompts(pv, clip_out, 0, frozen_past_classes)
+ else:
+ sp_concat_k_past = self.compute_super_prompts(p_concat_k, clip_out, 0, frozen_past_classes).squeeze(2)
+ sp_concat_v_past = self.compute_super_prompts(p_concat_v, clip_out, 0, frozen_past_classes).squeeze(2)
+
+ if self.prompt_mode == 'residual':
+ sp_curr = self.compute_super_prompts(pv, clip_out, frozen_past_classes, cur_classes)
+ super_prompt = sp_past.detach() + sp_curr
+ else:
+ sp_concat_k_curr = self.compute_super_prompts(p_concat_k, clip_out, frozen_past_classes, cur_classes).squeeze(2)
+ sp_concat_v_curr = self.compute_super_prompts(p_concat_v, clip_out, frozen_past_classes, cur_classes).squeeze(2)
+ super_prompt = (sp_concat_k_past.detach() + sp_concat_k_curr, sp_concat_v_past.detach() + sp_concat_v_curr)
+ else:
+ clip_out = self.compute_maps(clip_query, a[:cur_classes], self.keys[:cur_classes])
+ clip_out = self.get_masked_clip_out(clip_out)
+
+ if self.prompt_mode == 'residual':
+ super_prompt = self.compute_super_prompts(pv, clip_out, 0, cur_classes)
+ else:
+ sp_concat_k = self.compute_super_prompts(p_concat_k, clip_out, 0, cur_classes).squeeze(2)
+ sp_concat_v = self.compute_super_prompts(p_concat_v, clip_out, 0, cur_classes).squeeze(2)
+ super_prompt = (sp_concat_k, sp_concat_v)
+
+ return super_prompt
+ else:
+ return None
+
+ def compute_ortho_loss(self, cur_classes: int, frozen_past_classes=0) -> torch.Tensor:
+ """
+ Compute the orthogonality loss for the prompts of the layers in `prompt_layers`.
+ The loss is computed in two parts:
+ - The intra-orthogonality loss between the prompts of the current classes (between `frozen_past_classes` and `cur_classes`).
+ - The inter-orthogonality loss between the prompts of the past classes (before `frozen_past_classes`).
+ If `frozen_past_classes` is 0, the loss is skipped and not computed.
+
+ The argument `ortho_split_val` is used to manage the orthogonality loss computation between the layers.
+ The layers before `ortho_split_val` will have a weight of 0, thus the loss will not have any effect on them.
+ """
+
+ if frozen_past_classes == 0: # No ortho to compute between present and past
+ return 0.
+
+ ortho_loss_list = []
+ weight_loss_list = []
+
+ def _compute_loss(p: torch.Tensor, frozen_past_classes: int, cur_classes: int) -> torch.Tensor:
+ past_pv = p[:frozen_past_classes].detach()
+ cur_pv = p[frozen_past_classes:cur_classes]
+
+ eye_intra = torch.eye(cur_classes - frozen_past_classes).bool()
+
+ intra_ortho_loss = (torch.matmul(cur_pv, cur_pv.T)[eye_intra] - 1).pow(2).mean()
+ inter_ortho_loss = (torch.matmul(cur_pv, past_pv.T)).pow(2).mean()
+ return intra_ortho_loss + inter_ortho_loss
+
+ for layer_idx in self.prompt_layers:
+
+ if self.prompt_mode == 'residual':
+ p = getattr(self, f'p_{layer_idx}')
+ current_loss = _compute_loss(p, frozen_past_classes, cur_classes)
+ else:
+ p_concat = getattr(self, f'p_concat_{layer_idx}')
+ p_concat_k, p_concat_v = torch.split(p_concat, self.args.prefix_tuning_prompt_len, dim=1)
+
+ p_concat_k = p_concat_k.view(p_concat_k.shape[0], -1)
+ p_concat_v = p_concat_v.view(p_concat_v.shape[0], -1)
+
+ current_loss_k = _compute_loss(p_concat_k, frozen_past_classes, cur_classes)
+ current_loss_v = _compute_loss(p_concat_v, frozen_past_classes, cur_classes)
+
+ current_loss = current_loss_k + current_loss_v
+
+ current_weight = 1.
+ if layer_idx < self.args.ortho_split_val:
+ current_weight = 0.
+
+ current_loss = current_weight * current_loss
+
+ weight_loss_list.append(current_weight)
+ ortho_loss_list.append(current_loss)
+
+ total_ortho_loss = sum(ortho_loss_list) / sum(weight_loss_list)
+
+ return total_ortho_loss
+
+
+class Model(nn.Module):
+ prompter: Prompter
+
+ def __init__(self, args, backbone: nn.Module, dataset: ContinualDataset, num_classes, device='cpu',
+ clip_model: clip.model.CLIP = None, clip_preprocess=None):
+ super().__init__()
+
+ assert 'resnet' not in str(type(backbone)).lower(), "ResNet not supported"
+
+ self.args = args
+ self.num_classes = num_classes
+ self.device = device
+
+ # get feature encoder
+ vit_model = VisionTransformer(embed_dim=768,
+ depth=12,
+ num_heads=12,
+ drop_path_rate=0,
+ num_classes=num_classes,
+ prompt_mode=args.prompt_mode).to(device)
+
+ print("Loading the Vision Transformer backbone...")
+ load_dict = backbone.state_dict()
+ for k in list(load_dict.keys()):
+ if 'head' in k:
+ del load_dict[k]
+ missing, unexpected = vit_model.load_state_dict(load_dict, strict=False)
+ assert len([m for m in missing if 'head' not in m]) == 0, f"Missing keys: {missing}"
+ assert len(unexpected) == 0, f"Unexpected keys: {unexpected}"
+
+ self.vit = vit_model
+
+ self.prompt_layers = list(range(len(self.vit.blocks)))
+
+ print("Initializing the prompter and prompt parameters...")
+ self.prompter = Prompter(args,
+ dataset,
+ num_classes=num_classes,
+ target_embed_len=self.vit.patch_embed.num_patches,
+ target_embed_dim=self.vit.embed_dim,
+ prompt_layers=self.prompt_layers,
+ clip_model=clip_model,
+ clip_preprocess=clip_preprocess,
+ device=device)
+
+ # freeze the backbone
+ for n, p in self.vit.named_parameters():
+ if n != 'head.weight' and n != 'head.bias':
+ p.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(False)
+ self.prompter.train(False)
+ self.vit.train(mode)
+
+ return self
+
+ def forward(self, x: torch.Tensor, cur_classes: int, frozen_past_classes=0, query_x=None, return_features=False, return_query=False) -> torch.Tensor:
+ """
+ Compute the forward of the second-stage of STAR-Prompt.
+ Classes from `frozen_past_classes` to `cur_classes` will have a gradient, while all those before `frozen_past_classes` will be detached.
+
+ If `query_x` is provided, it will be used as the query for the CLIP's visual encoder.
+ Otherwise, the input image `x` will be used as the query. Note that the CLIP's pre-processing will applied to `x` in this case.
+
+ Args:
+ x: the input image tensor
+ cur_classes: the number of classes up to the current task
+ frozen_past_classes: the number of classes from the past tasks that will be frozen
+ query_x: (optional) the query tensor for the CLIP's visual encoder
+ return_features: if True, the features from the Vision Transformer will be returned instead of the classification output
+ return_query: if True, the query tensor will be returned with the output
+ """
+ enable_renorm = query_x is None
+ query_x = x if query_x is None else query_x
+ clip_query = self.prompter.get_query(query_x, disable_renorm=not enable_renorm)
+ features = self.vit.forward_features(x, first_stage_query=clip_query, prompter=self.prompter, cur_classes=cur_classes, frozen_past_classes=frozen_past_classes)
+ if return_features:
+ return features
+
+ out = self.vit.forward_head(features)
+ if return_query:
+ return out, clip_query
+ return out
diff --git a/models/star_prompt_utils/vision_transformer.py b/models/star_prompt_utils/vision_transformer.py
new file mode 100644
index 00000000..45190d80
--- /dev/null
+++ b/models/star_prompt_utils/vision_transformer.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from backbone.vit import VisionTransformer as MammothVP, Block as MammothViTBlock
+from models.coda_prompt_utils.vit import Attention as PrefixTuningAttention
+
+
+class ResidualPromptAttention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, prompts=None):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if prompts is not None:
+ prompts = prompts.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ v = v + prompts
+
+ if torch.__version__ >= '2.1.0':
+ x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
+ else:
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(MammothViTBlock):
+ def forward(self, x, prompts=None):
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), prompts)))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class VisionTransformer(MammothVP):
+
+ def __init__(self, *args, prompt_mode='residual', **kwargs):
+ super().__init__(*args, **kwargs)
+ assert prompt_mode in ['residual', 'concat'], 'prompt_mode should be either residual or concat'
+
+ attn_layer = ResidualPromptAttention if prompt_mode == 'residual' else PrefixTuningAttention
+
+ self.blocks = nn.Sequential(*[
+ Block(
+ dim=self.embed_dim,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ init_values=self.init_values,
+ drop=self.pos_drop.p,
+ attn_drop=self.attn_drop_rate,
+ attn_layer=attn_layer,
+ drop_path=self.dpr[i],
+ norm_layer=self.norm_layer,
+ act_layer=self.act_layer
+ )
+ for i in range(self.depth)])
+
+ self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+ if self.weight_init != 'skip':
+ self.init_weights(self.weight_init)
+
+ def forward_features(self, x, first_stage_query, prompter, cur_classes: int, frozen_past_classes=0):
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.norm_pre(x)
+ for idx, blk in enumerate(self.blocks):
+ prompts = prompter.get_prompts(idx, first_stage_query, frozen_past_classes=frozen_past_classes, cur_classes=cur_classes)
+ if prompts is not None:
+ x = blk(x, prompts)
+ else:
+ x = blk(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x: torch.Tensor, first_stage_query: torch.Tensor, prompter, cur_classes: int, frozen_past_classes=0) -> torch.Tensor:
+ """
+ Compute the forward of STAR-Prompt.
+
+ Args:
+ x: input image
+ query: the output of the visual encoder of CLIP, to be used as query for the second stage's prompter
+ prompter: the prompter of the second stage
+ train: whether the model is in training mode. If True, the prompts of the past tasks will be frozen and only the current task's prompts will be updated. Else, all prompts will be frozen.
+ """
+ x = self.forward_features(x, first_stage_query, prompter, cur_classes, frozen_past_classes)
+ x = self.forward_head(x)
+ return x
diff --git a/models/starprompt.py b/models/starprompt.py
new file mode 100644
index 00000000..3e075727
--- /dev/null
+++ b/models/starprompt.py
@@ -0,0 +1,175 @@
+import logging
+import torch
+from argparse import ArgumentParser
+
+import torch
+
+from models.star_prompt_utils.end_to_end_model import STARPromptModel
+from models.utils.continual_model import ContinualModel
+from utils.schedulers import CosineSchedule
+
+try:
+ import clip
+except ImportError:
+ raise ImportError("Please install the CLIP package by running: pip install git+https://github.com/openai/CLIP.git (requires also `huggingface-hub`)")
+
+
+class STARPrompt(ContinualModel):
+ NAME = 'starprompt'
+ COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
+ net: STARPromptModel
+
+ @staticmethod
+ def get_parser() -> ArgumentParser:
+ parser = ArgumentParser(description='Second-stage of StarPrompt. Requires the keys saved from the first stage.')
+
+ frozen_group = parser.add_argument_group('Frozen hyperparameters')
+ frozen_group.add_argument("--virtual_bs_n", type=int, default=1,
+ help="virtual batch size iterations")
+ frozen_group.add_argument("--ortho_split_val", type=int, default=0)
+ frozen_group.add_argument('--gr_mog_n_iters_second_stage', type=int, default=500,
+ help="Number of EM iterations during fit for GR with MOG on the second stage.")
+ frozen_group.add_argument('--gr_mog_n_iters_first_stage', type=int, default=200,
+ help="Number of EM iterations during fit for GR with MOG on the first stage.")
+ frozen_group.add_argument('--gr_mog_n_components', type=int, default=5,
+ help="Number of components for GR with MOG (both first and second stage).")
+ frozen_group.add_argument('--batch_size_gr', type=int, default=128,
+ help="Batch size for Generative Replay (both first and second stage).")
+ frozen_group.add_argument('--num_samples_gr', type=int, default=256,
+ help="Number of samples for Generative Replay (both first and second stage).")
+ frozen_group.add_argument('--prefix_tuning_prompt_len', type=int, default=5,
+ help="Prompt length for prefix tuning. Used only if `--prompt_mode==concat`.")
+
+ ablation_group = parser.add_argument_group('Ablations hyperparameters')
+ ablation_group.add_argument('--gr_model', type=str, default='mog', choices=['mog', 'gaussian'],
+ help="Type of distribution model for Generative Replay (both first and second stage). "
+ "- `mog`: Mixture of Gaussian. "
+ "- `gaussian`: Single Gaussian distribution.")
+ ablation_group.add_argument("--enable_gr", type=int, default=1, choices=[0, 1],
+ help="Enable Generative Replay (both first and second stage).")
+ ablation_group.add_argument('--prompt_mode', type=str, default='residual', choices=['residual', 'concat'],
+ help="Prompt type for the second stage. "
+ "- `residual`: STAR-Prompt style prompting. "
+ "- `concat`: Prefix-Tuning style prompting.")
+ ablation_group.add_argument("--enable_confidence_modulation", type=int, default=1, choices=[0, 1],
+ help="Enable confidence modulation with CLIP similarities (Eq. 5 of the main paper)?")
+
+ tunable_group = parser.add_argument_group('Tunable hyperparameters')
+ # second stage
+ tunable_group.add_argument("--lambda_ortho_second_stage", type=float, default=10,
+ help="orthogonality loss coefficient")
+ tunable_group.add_argument("--num_monte_carlo_gr_second_stage", type=int, default=1,
+ help="how many times to sample from the dataset for alignment")
+ tunable_group.add_argument("--num_epochs_gr_second_stage", type=int, default=10,
+ help="Num. of epochs for GR.")
+ tunable_group.add_argument("--learning_rate_gr_second_stage", type=float, default=0.001,
+ help="Learning rate for GR.")
+ # first stage
+ tunable_group.add_argument("--num_monte_carlo_gr_first_stage", type=int, default=1,
+ help="how many times to sample from the dataset for alignment")
+ tunable_group.add_argument("--learning_rate_gr_first_stage", type=float, default=0.05,
+ help="Learning rate for Generative Replay.")
+ tunable_group.add_argument("--lambda_ortho_first_stage", type=float, default=30,
+ help="Orthogonality loss coefficient for coop")
+ tunable_group.add_argument("--num_epochs_gr_first_stage", type=int, default=10,
+ help="Num. of epochs for Generative Replay.")
+
+ parser.add_argument("--clip_backbone", type=str, default='ViT-L/14', help="CLIP backbone architecture",
+ choices=clip.available_models())
+
+ first_stage_optim_group = parser.add_argument_group('First stage optimization hyperparameters')
+ first_stage_optim_group.add_argument("--first_stage_optim", type=str, default='sgd', choices=['sgd', 'adam'],
+ help="First stage optimizer")
+ first_stage_optim_group.add_argument("--first_stage_lr", type=float, default=0.002, help="First stage learning rate")
+ first_stage_optim_group.add_argument("--first_stage_momentum", type=float, default=0, help="First stage momentum")
+ first_stage_optim_group.add_argument("--first_stage_weight_decay", type=float, default=0, help="First stage weight decay")
+ first_stage_optim_group.add_argument("--first_stage_epochs", type=int, help="First stage epochs. If not set, it will be the same as `n_epochs`.")
+
+ return parser
+
+ def __init__(self, backbone, loss, args, transform):
+ if not hasattr(args, 'first_stage_epochs') or args.first_stage_epochs is None:
+ logging.info("`first_stage_epochs` not set. Setting it to `n_epochs`.")
+ args.first_stage_epochs = args.n_epochs
+
+ super().__init__(backbone, loss, args, transform)
+
+ self.net = STARPromptModel(args,
+ backbone=self.net,
+ dataset=self.dataset,
+ num_classes=self.num_classes,
+ device=self.device)
+
+ def end_task(self, dataset):
+ if hasattr(self, 'opt'):
+ del self.opt # free up some vram
+
+ if self.args.enable_gr:
+ self.net.update_statistics(dataset, self.n_past_classes, self.n_seen_classes)
+ self.net.backup(self.current_task, self.n_past_classes, self.n_seen_classes)
+
+ if self.current_task > 0:
+ if self.args.seed is not None:
+ torch.manual_seed(self.args.seed)
+ self.net.align(self.current_task, self.n_seen_classes, self.loss)
+
+ def get_parameters(self):
+ if not isinstance(self.net, STARPromptModel): # during initialization
+ return super().get_parameters()
+ return [p for p in self.net.second_stage.parameters() if p.requires_grad]
+
+ def get_scheduler(self):
+ return CosineSchedule(self.opt, K=self.args.n_epochs)
+
+ def begin_task(self, dataset):
+ # clean junk on GPU
+ if hasattr(self, 'opt'):
+ del self.opt
+
+ torch.cuda.empty_cache()
+
+ # adapt CLIP on current task
+ self.net.train_first_stage_on_task(dataset, self.current_task, self.n_past_classes, self.n_seen_classes, self.loss)
+ self.net.update_keys(self.n_past_classes, self.n_seen_classes)
+ self.net.second_stage.train()
+
+ # initialize second stage
+
+ # For later GR
+ self.net.recall_classifier_second_stage(self.current_task, self.n_past_classes, self.n_seen_classes)
+
+ self.opt = self.get_optimizer()
+ self.scheduler = self.get_scheduler()
+
+ def forward(self, x):
+ logits = self.net(x, cur_classes=self.n_seen_classes)
+ logits = logits[:, :self.n_seen_classes]
+ return logits
+
+ def observe(self, inputs, labels, not_aug_inputs, epoch=None): # second stage only
+ stream_inputs, stream_labels = inputs, labels
+ stream_logits = self.net(stream_inputs, cur_classes=self.n_seen_classes, frozen_past_classes=self.n_past_classes)
+
+ # Compute accuracy on current training batch for logging
+ with torch.no_grad():
+ stream_preds = stream_logits[:, :self.n_seen_classes].argmax(dim=1)
+ stream_acc = (stream_preds == stream_labels).sum().item() / stream_labels.shape[0]
+
+ # mask old classes
+ stream_logits[:, :self.n_past_classes] = -float('inf')
+ loss = self.loss(stream_logits[:, :self.n_seen_classes], stream_labels)
+
+ loss_ortho = self.net.second_stage.prompter.compute_ortho_loss(frozen_past_classes=self.n_past_classes, cur_classes=self.n_seen_classes)
+ loss += self.args.lambda_ortho_second_stage * loss_ortho
+
+ if self.epoch_iteration == 0:
+ self.opt.zero_grad()
+
+ (loss / self.args.virtual_bs_n).backward()
+ if (self.epoch_iteration > 0 or self.args.virtual_bs_n == 1) and \
+ self.epoch_iteration % self.args.virtual_bs_n == 0:
+ self.opt.step()
+ self.opt.zero_grad()
+
+ return {'loss': loss.item(),
+ 'stream_accuracy': stream_acc}
diff --git a/models/twf.py b/models/twf.py
index 14648535..2bdd55bb 100644
--- a/models/twf.py
+++ b/models/twf.py
@@ -1,3 +1,4 @@
+import logging
import torch
from models.twf_utils.utils import init_twf
from utils.augmentations import CustomRandomCrop, CustomRandomHorizontalFlip, DoubleCompose, DoubleTransform, apply_transform
@@ -56,13 +57,13 @@ def __init__(self, backbone, loss, args, transform):
self.buf_transform = self.get_custom_double_transform(self.original_transform.transforms)
if self.args.loadcheck is None:
- print("Warning: no checkpoint loaded!")
+ logging.warning("no checkpoint loaded!")
if self.args.lambda_fp_replay == 0:
- print('Warning: lambda_fp_replay is 0, so no replay of attention masks will be used')
+ logging.warning('lambda_fp_replay is 0, so no replay of attention masks will be used')
if self.args.lambda_diverse_loss == 0:
- print('Warning: lambda_diverse_loss is 0, so no diverse loss will be used')
+ logging.warning('lambda_diverse_loss is 0, so no diverse loss will be used')
def get_custom_double_transform(self, transform):
tfs = []
diff --git a/models/twf_utils/afd.py b/models/twf_utils/afd.py
index 4112d4f1..0bd928f6 100644
--- a/models/twf_utils/afd.py
+++ b/models/twf_utils/afd.py
@@ -7,6 +7,7 @@
from utils.conditional_bn import ConditionalBatchNorm1d
from utils.conditional_bn import ConditionalBatchNorm2d
+from utils.conf import warn_once
def get_rnd_weight(num_tasks, fin, fout=None, nonlinearity='relu'):
@@ -110,9 +111,7 @@ def forward(self, logits):
if self.training:
if str(logits.device) == 'cpu':
- if not hasattr(self, 'warned') or not self.warned:
- print('Warning: GumbelSoftmax may be unstable in CPU (see https://github.com/pytorch/pytorch/issues/101620)')
- self.warned = True
+ warn_once('GumbelSoftmax may be unstable in CPU (see https://github.com/pytorch/pytorch/issues/101620)')
h = nn.functional.gumbel_softmax(logits, tau=self.tau, hard=True)
h = h[..., 0]
return h
diff --git a/models/utils/continual_model.py b/models/utils/continual_model.py
index 97fcfff8..a1791bd2 100644
--- a/models/utils/continual_model.py
+++ b/models/utils/continual_model.py
@@ -24,11 +24,13 @@
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
+import logging
import sys
from argparse import ArgumentParser, Namespace
from contextlib import suppress
-from typing import List
+from typing import List, Tuple
+import kornia
import torch
import torch.nn as nn
import torch.optim as optim
@@ -52,8 +54,23 @@ class ContinualModel(nn.Module):
COMPATIBILITY: List[str]
AVAIL_OPTIMS = ['sgd', 'adam', 'adamw']
+ args: Namespace # The command line arguments
+ device: torch.device # The device to be used for training
+ net: nn.Module # The backbone of the model (defined by the `dataset`)
+ loss: nn.Module # The loss function to be used (defined by the `dataset`)
+ opt: optim.Optimizer # The optimizer to be used for training
+ scheduler: optim.lr_scheduler._LRScheduler # (optional) The scheduler for the optimizer. If defined, it will overwrite the one defined in the `dataset`
+ # The transformation to be applied to the input data. The model will try to convert it to a kornia transform to be applicable to a batch of samples at once
+ transform: transforms.Compose | kornia.augmentation.AugmentationSequential
+ original_transform: transforms.Compose # The original transformation to be applied to the input data. This is the one defined by the `dataset`
+ task_iteration: int # Number of iterations in the current task
+ epoch_iteration: int # Number of iterations in the current epoch. Updated if `epoch` is passed to observe
+ dataset: ContinualDataset # The instance of the dataset. Used to update the number of classes in the current task
+ num_classes: int # Total number of classes in the dataset
+ n_tasks: int # Total number of tasks in the dataset
+
@staticmethod
- def get_parser() -> Namespace:
+ def get_parser() -> ArgumentParser:
"""
Returns the parser of the model.
@@ -65,6 +82,20 @@ def get_parser() -> Namespace:
parser = ArgumentParser(description='Base CL model')
return parser
+ @property
+ def task_iteration(self):
+ """
+ Returns the number of iterations in the current task.
+ """
+ return self._task_iteration
+
+ @property
+ def epoch_iteration(self):
+ """
+ Returns the number of iterations in the current epoch.
+ """
+ return self._epoch_iteration
+
@property
def current_task(self):
"""
@@ -153,14 +184,14 @@ def __init__(self, backbone: nn.Module, loss: nn.Module,
self.transform = to_kornia_transform(transform.transforms[-1].transforms)
self.normalization_transform = to_kornia_transform(self.dataset.get_normalization_transform())
except BaseException:
- print("Warning: could not initialize kornia transforms.")
+ logging.error("could not initialize kornia transforms.")
self.normalization_transform = transforms.Compose([transforms.ToPILImage(), self.dataset.TEST_TRANSFORM]) if hasattr(
self.dataset, 'TEST_TRANSFORM') else transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), self.dataset.get_normalization_transform()])
if self.net is not None:
self.opt = self.get_optimizer()
else:
- print("Warning: no default model for this dataset. You will have to specify the optimizer yourself.")
+ logging.warning("no default model for this dataset. You will have to specify the optimizer yourself.")
self.opt = None
self.device = get_device()
@@ -168,7 +199,7 @@ def __init__(self, backbone: nn.Module, loss: nn.Module,
raise NotImplementedError('Please specify the name and the compatibility of the model.')
if self.args.label_perc != 1 and 'cssl' not in self.COMPATIBILITY:
- print('WARNING: label_perc is not explicitly supported by this model -> training may break')
+ logging.info('label_perc is not explicitly supported by this model -> training may break')
def to(self, device):
"""
@@ -191,7 +222,7 @@ def get_parameters(self):
"""
return self.net.parameters()
- def get_optimizer(self):
+ def get_optimizer(self) -> optim.Optimizer:
# check if optimizer is in torch.optim
supported_optims = {optim_name.lower(): optim_name for optim_name in dir(optim) if optim_name.lower() in self.AVAIL_OPTIMS}
opt = None
@@ -209,11 +240,17 @@ def get_optimizer(self):
raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
return opt
- def _compute_offsets(self, task):
- cpt = self.N_CLASSES // self.N_TASKS
- offset1 = task * cpt
- offset2 = (task + 1) * cpt
- return offset1, offset2
+ def compute_offsets(self, task: int) -> Tuple[int, int]:
+ """
+ Compute the start and end offset given the task.
+
+ Args:
+ task: the task index
+
+ Returns:
+ the start and end offset
+ """
+ return self.dataset.get_offsets(task)
def get_debug_iters(self):
"""
@@ -264,19 +301,34 @@ def meta_observe(self, *args, **kwargs):
Returns:
the value of the loss function
"""
+ if 'epoch' in kwargs and kwargs['epoch'] is not None:
+ epoch = kwargs['epoch']
+ if self._past_epoch != epoch:
+ self._past_epoch = epoch
+ self._epoch_iteration = 0
if 'cssl' not in self.COMPATIBILITY: # drop unlabeled data if not supported
labeled_mask = args[1] != -1
- if labeled_mask.sum() == 0:
- return 0
- args = [arg[labeled_mask] if isinstance(arg, torch.Tensor) and arg.shape[0] == args[0].shape[0] else arg for arg in args]
+ if (~labeled_mask).any(): # if there are any unlabeled samples
+ if labeled_mask.sum() == 0: # if all samples are unlabeled
+ return 0
+ args = [arg[labeled_mask] if isinstance(arg, torch.Tensor) and arg.shape[0] == args[0].shape[0] else arg for arg in args]
if 'wandb' in sys.modules and not self.args.nowand:
pl = persistent_locals(self.observe)
ret = pl(*args, **kwargs)
- self.autolog_wandb(pl.locals)
+ extra = {}
+ if isinstance(ret, dict):
+ assert 'loss' in ret, "Loss not found in return dict"
+ extra = {k: v for k, v in ret.items() if k != 'loss'}
+ ret = ret['loss']
+ self.autolog_wandb(pl.locals, extra=extra)
else:
ret = self.observe(*args, **kwargs)
- self.task_iteration += 1
+ if isinstance(ret, dict):
+ assert 'loss' in ret, "Loss not found in return dict"
+ ret = ret['loss']
+ self._task_iteration += 1
+ self._epoch_iteration += 1
return ret
def meta_begin_task(self, dataset):
@@ -288,11 +340,12 @@ def meta_begin_task(self, dataset):
Args:
dataset: the current task's dataset
"""
- self.task_iteration = 0
+ self._task_iteration = 0
+ self._epoch_iteration = 0
+ self._past_epoch = 0
self._n_classes_current_task = self._cpt if isinstance(self._cpt, int) else self._cpt[self._current_task]
- self._n_seen_classes = self._cpt * (self._current_task + 1) if isinstance(self._cpt, int) else sum(self._cpt[:self._current_task + 1])
+ self._n_past_classes, self._n_seen_classes = self.compute_offsets(self._current_task)
self._n_remaining_classes = self.N_CLASSES - self._n_seen_classes
- self._n_past_classes = self._cpt * self._current_task if isinstance(self._cpt, int) else sum(self._cpt[:self._current_task])
self.begin_task(dataset)
def meta_end_task(self, dataset):
diff --git a/models/utils/future_model.py b/models/utils/future_model.py
new file mode 100644
index 00000000..395d17b0
--- /dev/null
+++ b/models/utils/future_model.py
@@ -0,0 +1,17 @@
+""" This is the base class for all models that support future prediction, i.e., zero-shot prediction.
+
+ It extends the ContinualModel class and adds the future_forward method, which should be implemented by all models that inherit from this class.
+ Such a method should take an input tensor and return a tensor representing the future prediction. This method is used by the future prediction evaluation protocol.
+
+ The change_transform method is used to update the transformation applied to the input data. This is useful when the model is trained on a dataset and then evaluated on a different dataset. In this case, the transformation should be updated to match the new dataset.
+"""
+
+from .continual_model import ContinualModel
+
+
+class FutureModel(ContinualModel):
+ def future_forward(self, x):
+ raise NotImplementedError
+
+ def change_transform(self, dataset):
+ pass
diff --git a/models/utils/lider_model.py b/models/utils/lider_model.py
index 77b55b49..0ee0ae79 100644
--- a/models/utils/lider_model.py
+++ b/models/utils/lider_model.py
@@ -2,6 +2,7 @@
Base class for all models that use the Lipschitz regularization in LiDER (https://arxiv.org/pdf/2210.06443.pdf).
"""
+import logging
import torch
import torch.nn.functional as F
from tqdm import tqdm
@@ -33,7 +34,7 @@ def __init__(self, backbone, loss, args, transform):
super().__init__(backbone, loss, args, transform)
if self.args.alpha_lip_lambda == 0 and self.args.beta_lip_lambda == 0:
- print("WARNING: LiDER is enabled but both `alpha_lip_lambda` and `beta_lip_lambda` are 0. LiDER will not be used.")
+ logging.error("LiDER is enabled but both `alpha_lip_lambda` and `beta_lip_lambda` are 0. LiDER will not be used.")
def transmitting_matrix(self, fm1: torch.Tensor, fm2: torch.Tensor):
if fm1.size(2) > fm2.size(2):
diff --git a/requirements-optional.txt b/requirements-optional.txt
index 0e9c2c08..66f9fd7c 100644
--- a/requirements-optional.txt
+++ b/requirements-optional.txt
@@ -3,4 +3,10 @@ onedrivedownloader==1.1.3
pytest==7.4.2
quadprog==0.1.11
setproctitle==1.3.2
-wandb
\ No newline at end of file
+wandb
+deeplake
+pandas
+timm==0.9.8
+clip @ git+https://github.com/openai/CLIP.git
+scikit-learn
+decorator
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 5b4f4683..51910471 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,10 @@
-torch
+torch>=2.1.0
numpy
torchvision
kornia>=0.7.0
Pillow
timm==0.9.8
tqdm
-onedrivedownloader
\ No newline at end of file
+onedrivedownloader
+ftfy
+regex
\ No newline at end of file
diff --git a/scripts/local_launcher.py b/scripts/local_launcher.py
index a858a9cd..cc4f9524 100644
--- a/scripts/local_launcher.py
+++ b/scripts/local_launcher.py
@@ -1,6 +1,14 @@
-import functools
import os
-import random
+import sys
+
+if 'scripts' in os.path.dirname(os.path.abspath(__file__)):
+ mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+else:
+ mammoth_path = os.getcwd()
+os.chdir(mammoth_path)
+sys.path.append(mammoth_path)
+
+import functools
import subprocess
import sys
import time
@@ -31,7 +39,7 @@ def parse_args():
assert args.redundancy >= 1, "redundancy must be at least 1"
assert args.start_from >= 0, "start_from must be at least 0"
- jobs_list = [l for l in open(args.file, "r").read().splitlines() if l.strip() != "" and not l.startswith("#")][args.start_from:] * args.redundancy
+ jobs_list = [l for l in open(args.file, "r").read().splitlines() if l.strip() != "" and not l.strip().startswith("#")][args.start_from:] * args.redundancy
if args.reverse:
jobs_list = list(reversed(jobs_list))
jobname = args.file.strip().split("/")[-1].split("\\")[-1].split(".")[0]
@@ -95,9 +103,11 @@ def main():
def signal_handler(sig, frame):
print('Killing all processes')
if os.name == 'nt':
- os.system("taskkill /F /T /PID {}".format(os.getpid()))
+ for job_index, (jobname, pid) in active_jobs.items():
+ os.system("taskkill /F /PID {}".format(pid))
else:
- os.system("kill -9 -1")
+ for job_index, (jobname, pid) in active_jobs.items():
+ os.system("kill -9 {}".format(pid))
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
diff --git a/scripts/prepare_grid.py b/scripts/prepare_grid.py
index 2707c85b..e8594489 100644
--- a/scripts/prepare_grid.py
+++ b/scripts/prepare_grid.py
@@ -1,6 +1,8 @@
import os
+if os.getcwd().split('/')[-1] == 'scripts':
+ os.chdir('..')
+
import itertools
-import numpy as np
import argparse
parser = argparse.ArgumentParser(description='Prepare grid')
@@ -41,11 +43,12 @@
for k, v in zip(combos.keys(), c):
if v is None:
continue
-if isinstance(k, if) for i in range(len(k)):
+ if isinstance(k, (list, tuple)):
+ for i in range(len(k)):
ll += f" --{k[i]}={v[i]}"
else:
ll += f" --{k}={v}"
- f.write(ll +'\n')
+ f.write(ll + '\n')
all_configs.append(ll)
clines += 1
diff --git a/scripts/slurm_sbatcher.py b/scripts/slurm_sbatcher.py
index 7d7c9a57..b7420282 100644
--- a/scripts/slurm_sbatcher.py
+++ b/scripts/slurm_sbatcher.py
@@ -1,7 +1,9 @@
-import argparse
+import logging
import os
-import socket
-import time
+if os.getcwd().split('/')[-1] == 'scripts':
+ os.chdir('..')
+
+import argparse
import math
if __name__ == '__main__':
@@ -32,7 +34,7 @@
args = parser.parse_args()
if args.ddp:
- print("Warning: distributed stuff not yet supported in mammoth (problems with buffer synchronization). Use at your own risk!")
+ logging.error("distributed stuff not yet supported in mammoth (problems with buffer synchronization). Use at your own risk!")
with open(args.file, 'r') as f:
all_com = f.read().splitlines()
diff --git a/scripts/wandb_sync.py b/scripts/wandb_sync.py
index e443416f..15eb06b0 100644
--- a/scripts/wandb_sync.py
+++ b/scripts/wandb_sync.py
@@ -1,9 +1,11 @@
import argparse
+import logging
import os
+from functools import partial
+from multiprocessing.pool import ThreadPool
from pathlib import Path
from tqdm import tqdm
-from multiprocessing.pool import ThreadPool
if 'scripts' in os.path.dirname(os.path.abspath(__file__)):
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -13,9 +15,10 @@
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--n_workers", type=int, help="Number of workers to use. If not specified, will use all available cores. (Recommended: n_cpus*3)")
- parser.add_argument("--limit", type=int, help="Limit the number of runs to sync")
- parser.add_argument("--reverse", action="store_true", help="Reverse the order of runs to sync")
+ parser.add_argument("-w", "-n", "--n_workers", type=int, help="Number of workers to use. If not specified, will use all available cores. (Recommended: n_cpus*3)")
+ parser.add_argument("-l", "--limit", type=int, help="Limit the number of runs to sync")
+ parser.add_argument("-r", "--reverse", action="store_true", help="Reverse the order of runs to sync?")
+ parser.add_argument("-c", "--clean_after", action="store_true", help="Clean run after syncing?")
args = parser.parse_args()
if args.n_workers is None:
@@ -32,9 +35,11 @@ def check_offline():
return len([f for f in os.listdir() if 'offline' in f]) > 0
-def sync_run(run):
+def sync_run(run, clean_after=False):
"""Syncs a single run"""
- os.system(f"wandb sync {run} >>synced.log 2>>err.log")
+ ret_code = os.system(f"wandb sync {run} >>synced.log 2>>err.log")
+ if ret_code == 0 and clean_after:
+ os.system(f"rm -rf {run}")
if __name__ == "__main__":
@@ -53,6 +58,9 @@ def sync_run(run):
runlist = runlist[:args.limit]
print("Limiting to", args.limit, "runs")
+ if args.clean_after:
+ logging.info("Cleaning after syncing")
+
print(len(runlist), "runs to sync")
# delete file synced.log if exists
@@ -64,8 +72,9 @@ def sync_run(run):
Path("err.log").unlink()
# sync all runs in multiple threads and log tqdm
+ sync_fn = partial(sync_run, clean_after=args.clean_after)
with ThreadPool(args.n_workers) as p:
- r = list(tqdm(p.imap(sync_run, runlist), total=len(runlist)))
+ r = list(tqdm(p.imap(sync_fn, runlist), total=len(runlist)))
# check if there are any errors in err.log
if Path("err.log").exists():
diff --git a/tests/test_bic.py b/tests/test_bic.py
index 8a20cd00..91b929b3 100644
--- a/tests/test_bic.py
+++ b/tests/test_bic.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('distill_after_bic', [0, 1])
def test_bic(distill_after_bic):
sys.argv = ['mammoth',
diff --git a/tests/test_ccic.py b/tests/test_ccic.py
index 9b6ef49a..07898bc7 100644
--- a/tests/test_ccic.py
+++ b/tests/test_ccic.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10'])
@pytest.mark.parametrize('label_perc', ['0.1', '0.08'])
def test_ccic(dataset, label_perc):
diff --git a/tests/test_cgil.py b/tests/test_cgil.py
new file mode 100644
index 00000000..47d30dc9
--- /dev/null
+++ b/tests/test_cgil.py
@@ -0,0 +1,37 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.main import main
+from utils.test_utils import init_test_environ
+import pytest
+
+
+@init_test_environ
+def test_dualprompt():
+ sys.argv = ['mammoth',
+ '--model',
+ 'cgil',
+ '--dataset',
+ 'seq-cifar100-224',
+ '--lr',
+ '1e-4',
+ '--n_epochs',
+ '1',
+ '--batch_size',
+ '2',
+ '--non_verbose',
+ '1',
+ '--num_workers',
+ '0',
+ '--seed',
+ '0',
+ '--debug_mode',
+ '1']
+
+ # log all outputs to file
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
+ os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
+ sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_cgil.log'), 'w', encoding='utf-8')
+ sys.stderr = sys.stdout
+
+ main()
diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py
index 318e7a05..2524cc0f 100644
--- a/tests/test_checkpointing.py
+++ b/tests/test_checkpointing.py
@@ -2,11 +2,15 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('model', ['sgd', 'slca', 'l2p'])
-def test_checkpointing_bufferfree(model):
+@pytest.mark.parametrize('savecheck', ['last', 'task'])
+@pytest.mark.parametrize('joint', ['0', '1'])
+def test_checkpointing_bufferfree(model, savecheck, joint):
N_TASKS = 5 # cifar10
# TEST CHECKPOINT SAVE
@@ -19,8 +23,10 @@ def test_checkpointing_bufferfree(model):
'1e-4',
'--n_epochs',
'1',
+ '--joint',
+ joint,
'--savecheck',
- '1',
+ savecheck,
'--batch_size',
'4',
'--non_verbose',
@@ -35,17 +41,24 @@ def test_checkpointing_bufferfree(model):
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
- sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.log'), 'w', encoding='utf-8')
+ sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()
# read output file and search for the string 'Saving checkpoint into'
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.log'), 'r', encoding='utf-8') as f:
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'r', encoding='utf-8') as f:
lines = f.readlines()
ckpt_name = [line for line in lines if 'Saving checkpoint into' in line]
assert any(ckpt_name), f'Checkpoint not saved for model {model}'
- ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt'
+ if joint == '0':
+ if savecheck == 'last':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_last.pt'
+ elif savecheck == 'task':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt'
+ elif joint == '1':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_joint.pt'
+
ckpt_path = os.path.join('checkpoints', ckpt_name)
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'
@@ -60,6 +73,8 @@ def test_checkpointing_bufferfree(model):
'1e-4',
'--n_epochs',
'1',
+ '--joint',
+ joint,
'--loadcheck',
ckpt_path,
'--batch_size',
@@ -76,17 +91,26 @@ def test_checkpointing_bufferfree(model):
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
- sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.{model}.log'), 'w', encoding='utf-8')
+ sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.{model}.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()
# REMOVE CHECKPOINT FILE
- for i in range(N_TASKS):
- c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt'
- os.remove(c_path)
-
-
-def test_checkpointing_replay():
+ if joint == '0':
+ if savecheck == 'task':
+ for i in range(N_TASKS):
+ c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt'
+ os.remove(c_path)
+ elif savecheck == 'last':
+ os.remove(ckpt_path)
+ elif joint == '1':
+ os.remove(ckpt_path)
+
+
+@init_test_environ
+@pytest.mark.parametrize('savecheck', ['last', 'task'])
+@pytest.mark.parametrize('joint', ['0', '1'])
+def test_checkpointing_replay(savecheck, joint):
N_TASKS = 5 # cifar10
# TEST CHECKPOINT SAVE
@@ -101,12 +125,14 @@ def test_checkpointing_replay():
'0.1',
'--lr',
'1e-4',
+ '--joint',
+ joint,
'--n_epochs',
'1',
'--buffer_size',
'50',
'--savecheck',
- '1',
+ savecheck,
'--batch_size',
'4',
'--non_verbose',
@@ -121,17 +147,24 @@ def test_checkpointing_replay():
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
- sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.log'), 'w', encoding='utf-8')
+ sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()
# read output file and search for the string 'Saving checkpoint into'
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.log'), 'r', encoding='utf-8') as f:
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'r', encoding='utf-8') as f:
lines = f.readlines()
ckpt_name = [line for line in lines if 'Saving checkpoint into' in line]
assert any(ckpt_name), f'Checkpoint not saved for derpp'
- ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt'
+ if joint == '0':
+ if savecheck == 'last':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_last.pt'
+ elif savecheck == 'task':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt'
+ elif joint == '1':
+ ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_joint.pt'
+
ckpt_path = os.path.join('checkpoints', ckpt_name)
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'
@@ -150,6 +183,8 @@ def test_checkpointing_replay():
'1e-4',
'--n_epochs',
'1',
+ '--joint',
+ joint,
'--buffer_size',
'50',
'--loadcheck',
@@ -168,11 +203,17 @@ def test_checkpointing_replay():
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
- sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.derpp.log'), 'w', encoding='utf-8')
+ sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.derpp.{savecheck}.{"joint" if joint=="1" else "cl"}.log'), 'w', encoding='utf-8')
sys.stderr = sys.stdout
main()
# REMOVE CHECKPOINT FILE
- for i in range(N_TASKS):
- c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt'
- os.remove(c_path)
+ if joint == '0':
+ if savecheck == 'task':
+ for i in range(N_TASKS):
+ c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt'
+ os.remove(c_path)
+ elif savecheck == 'last':
+ os.remove(ckpt_path)
+ elif joint == '1':
+ os.remove(ckpt_path)
diff --git a/tests/test_coda.py b/tests/test_coda.py
index 1066e869..7789ce93 100644
--- a/tests/test_coda.py
+++ b/tests/test_coda.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar100-224', 'seq-imagenet-r'])
def test_coda(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_codaprompt.py b/tests/test_codaprompt.py
index fef858e4..fcc8b4a5 100644
--- a/tests/test_codaprompt.py
+++ b/tests/test_codaprompt.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10-224', 'seq-imagenet-r'])
@pytest.mark.parametrize('code_optimization', [0, 1])
def test_codaprompt(dataset, code_optimization):
diff --git a/tests/test_code_optimization.py b/tests/test_code_optimization.py
index eb03a552..42a0ba87 100644
--- a/tests/test_code_optimization.py
+++ b/tests/test_code_optimization.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('code_optimization', [0, 1, 2, 3])
def test_code_optim_erace(code_optimization):
sys.argv = ['mammoth',
@@ -39,6 +41,7 @@ def test_code_optim_erace(code_optimization):
main()
+@init_test_environ
@pytest.mark.parametrize('code_optimization', [0, 1, 2, 3])
def test_code_optimization_slca(code_optimization):
sys.argv = ['mammoth',
diff --git a/tests/test_cssl_support.py b/tests/test_cssl_support.py
index 47ea91b6..b0a6eb9e 100644
--- a/tests/test_cssl_support.py
+++ b/tests/test_cssl_support.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-tinyimg'])
@pytest.mark.parametrize('label_perc', ['0.1', '0.08', '0.5', '1'])
def test_cssl_support(dataset, label_perc):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index adc6bfb1..fa2db994 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,14 +1,18 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-mnist', 'seq-cifar10', 'seq-cifar100', 'seq-tinyimg',
'rot-mnist', 'perm-mnist', 'mnist-360', 'seq-cifar100-224',
- 'seq-cifar10-224', 'seq-cifar100-224-rs',
- 'seq-cifar100-224-rs', 'seq-tinyimg-r', 'seq-cub200', 'seq-imagenet-r'])
+ 'seq-cifar10-224', 'seq-cifar100-224-rs', 'seq-cub200-rs',
+ 'seq-cifar100-224-rs', 'seq-tinyimg-r', 'seq-cub200', 'seq-imagenet-r',
+ 'seq-cars196', 'seq-chestx', 'seq-cropdisease', 'seq-eurosat-rgb',
+ 'seq-isic', 'seq-mit67', 'seq-resisc45'])
def test_datasets(dataset):
sys.argv = ['mammoth',
'--model',
@@ -31,7 +35,10 @@ def test_datasets(dataset):
'1']
# clean all downloaded datasets
- dataset_paths = ['CUB200', 'CIFAR10', 'CIFAR100', 'MNIST', 'TINYIMG', 'imagenet-r']
+ dataset_paths = ['CUB200', 'CIFAR10', 'CIFAR100', 'MNIST',
+ 'TINYIMG', 'imagenet-r', 'cars196', 'chestx',
+ 'cropdisease', 'eurosat', 'isic', 'MIT67',
+ 'NWPU-RESISC45']
basepath = os.path.dirname(os.path.abspath(__file__))
dt_dir = os.path.join(os.path.dirname(basepath), 'data')
for path in dataset_paths:
@@ -47,6 +54,7 @@ def test_datasets(dataset):
main()
+@init_test_environ
def test_dataset_workers():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_der_example.py b/tests/test_der_example.py
index 048b3f0c..bbd60cb5 100644
--- a/tests/test_der_example.py
+++ b/tests/test_der_example.py
@@ -1,11 +1,13 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
@pytest.mark.parametrize('dataset', ['seq-mnist', 'seq-cifar10', 'rot-mnist', 'perm-mnist', 'mnist-360', 'seq-cifar100-224'])
+@init_test_environ
def test_der(dataset):
sys.argv = ['mammoth',
'--model',
@@ -38,6 +40,7 @@ def test_der(dataset):
main()
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-mnist', 'seq-cifar10', 'rot-mnist', 'perm-mnist', 'mnist-360', 'seq-cifar100-224'])
def test_derpp(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_dualprompt.py b/tests/test_dualprompt.py
index a9963c70..e045d04f 100644
--- a/tests/test_dualprompt.py
+++ b/tests/test_dualprompt.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
def test_dualprompt():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_er_example.py b/tests/test_er_example.py
index c09a8ab1..62db0a2a 100644
--- a/tests/test_er_example.py
+++ b/tests/test_er_example.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-mnist', 'seq-cifar10', 'rot-mnist', 'perm-mnist', 'mnist-360'])
def test_er(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_fdr.py b/tests/test_fdr.py
index 5950d36f..14a1bba6 100644
--- a/tests/test_fdr.py
+++ b/tests/test_fdr.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
def test_fdr():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_gdumb.py b/tests/test_gdumb.py
index fabbf495..0e6afabf 100644
--- a/tests/test_gdumb.py
+++ b/tests/test_gdumb.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
def test_gdumb_cutmix():
sys.argv = ['mammoth',
'--model',
@@ -40,6 +42,7 @@ def test_gdumb_cutmix():
main()
+@init_test_environ
def test_gdumb():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_gem.py b/tests/test_gem.py
index 78a7d076..f5ac47c3 100644
--- a/tests/test_gem.py
+++ b/tests/test_gem.py
@@ -2,6 +2,7 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
@@ -9,6 +10,7 @@ def unsupport_quadprog():
return os.name == 'nt'
+@init_test_environ
@pytest.mark.skipif(unsupport_quadprog(), reason='`quadprog` not supported on Windows. Good luck.')
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
@pytest.mark.parametrize('model', ['gem', 'agem', 'agem_r'])
diff --git a/tests/test_hal.py b/tests/test_hal.py
index c94185cc..ddc21d87 100644
--- a/tests/test_hal.py
+++ b/tests/test_hal.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
def test_hal(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_icarl.py b/tests/test_icarl.py
index 051f3528..3f22dd54 100644
--- a/tests/test_icarl.py
+++ b/tests/test_icarl.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
def test_icarl(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_l2p.py b/tests/test_l2p.py
index 706081aa..d15ca098 100644
--- a/tests/test_l2p.py
+++ b/tests/test_l2p.py
@@ -1,10 +1,12 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar100-224', 'seq-imagenet-r'])
def test_l2p(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_lider.py b/tests/test_lider.py
index 2e590d97..85f6fa64 100644
--- a/tests/test_lider.py
+++ b/tests/test_lider.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
def test_gdumb_lider():
sys.argv = ['mammoth',
'--model',
@@ -40,6 +42,7 @@ def test_gdumb_lider():
main()
+@init_test_environ
def test_icarl_lider():
sys.argv = ['mammoth',
'--model',
@@ -75,6 +78,7 @@ def test_icarl_lider():
main()
+@init_test_environ
def test_erace_lider():
sys.argv = ['mammoth',
'--model',
@@ -110,6 +114,7 @@ def test_erace_lider():
main()
+@init_test_environ
def test_derpp_lider():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_lucir.py b/tests/test_lucir.py
index d524ec49..791bf591 100644
--- a/tests/test_lucir.py
+++ b/tests/test_lucir.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
@pytest.mark.parametrize('imprint_weights', [0, 1])
def test_lucir(dataset, imprint_weights):
diff --git a/tests/test_pnn.py b/tests/test_pnn.py
index e8d050b3..0a6b8b4d 100644
--- a/tests/test_pnn.py
+++ b/tests/test_pnn.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
def test_pnn(dataset):
sys.argv = ['mammoth',
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index df2b5edd..7fe28836 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
@pytest.mark.parametrize('model', ['ewc_on'])
def test_ewc(dataset, model):
@@ -40,6 +42,7 @@ def test_ewc(dataset, model):
main()
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
@pytest.mark.parametrize('model', ['si'])
def test_si(dataset, model):
@@ -75,6 +78,7 @@ def test_si(dataset, model):
main()
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar10', 'seq-mnist'])
@pytest.mark.parametrize('model', ['lwf_mc', 'lwf'])
def test_lwf(dataset, model):
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 80af6160..8d25a266 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -29,10 +29,10 @@ def test_der_cifar100_defaultscheduler():
'--debug_mode',
'1',
'--savecheck',
- '1',
+ 'task',
'--seed',
'0']
-
+
log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_defaultscheduler.log')
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
@@ -50,17 +50,16 @@ def test_der_cifar100_defaultscheduler():
ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip()
ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)]
-
for ckpt_path in ckpt_paths:
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'
ckpt = torch.load(ckpt_path)
opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler']
assert opt['initial_lr'] == 0.03, f'Learning rate not updated correctly in {ckpt_path}'
- assert opt['lr']==opt['initial_lr']*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}'
+ assert opt['lr'] == opt['initial_lr'] * 0.1 * 0.1, f'Learning rate not updated correctly in {ckpt_path}'
assert list(sched['milestones'].keys()) == [35, 45], f'Milestones not updated correctly in {ckpt_path}'
- assert sched['base_lrs']==[0.03], f'Base learning rate not updated correctly in {ckpt_path}'
-
+ assert sched['base_lrs'] == [0.03], f'Base learning rate not updated correctly in {ckpt_path}'
+
def test_der_cifar100_customscheduler():
N_TASKS = 10
@@ -84,14 +83,14 @@ def test_der_cifar100_customscheduler():
'--debug_mode',
'1',
'--savecheck',
- '1',
+ 'task',
'--lr_scheduler',
'multisteplr',
'--lr_milestones',
- '2','4','6','8',
+ '2', '4', '6', '8',
'--seed',
'0']
-
+
log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_der_cifar100_customscheduler.der.cifar100.log')
# log all outputs to file
if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
@@ -109,13 +108,12 @@ def test_der_cifar100_customscheduler():
ckpt_base_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip()
ckpt_paths = [os.path.join('checkpoints', ckpt_base_name + f'_{i}.pt') for i in range(N_TASKS)]
-
for ckpt_path in ckpt_paths:
assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'
-
+
ckpt = torch.load(ckpt_path)
opt, sched = ckpt['optimizer']['param_groups'][0], ckpt['scheduler']
assert opt['initial_lr'] == 0.1, f'Learning rate not updated correctly in {ckpt_path}'
- assert opt['lr']==opt['initial_lr']*0.1*0.1*0.1*0.1, f'Learning rate not updated correctly in {ckpt_path}'
- assert list(sched['milestones'].keys()) == [2,4,6,8], f'Milestones not updated correctly in {ckpt_path}'
- assert sched['base_lrs']==[0.1], f'Base learning rate not updated correctly in {ckpt_path}'
\ No newline at end of file
+ assert opt['lr'] == opt['initial_lr'] * 0.1 * 0.1 * 0.1 * 0.1, f'Learning rate not updated correctly in {ckpt_path}'
+ assert list(sched['milestones'].keys()) == [2, 4, 6, 8], f'Milestones not updated correctly in {ckpt_path}'
+ assert sched['base_lrs'] == [0.1], f'Base learning rate not updated correctly in {ckpt_path}'
diff --git a/tests/test_slca.py b/tests/test_slca.py
index 456985dc..130f245e 100644
--- a/tests/test_slca.py
+++ b/tests/test_slca.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
def test_slca():
sys.argv = ['mammoth',
'--model',
diff --git a/tests/test_starprompt.py b/tests/test_starprompt.py
new file mode 100644
index 00000000..de7e53cd
--- /dev/null
+++ b/tests/test_starprompt.py
@@ -0,0 +1,119 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.main import main
+from utils.test_utils import init_test_environ
+import pytest
+
+
+@init_test_environ
+def test_first_and_second_stage():
+ sys.argv = ['mammoth',
+ '--model',
+ 'first_stage_starprompt',
+ '--dataset',
+ 'seq-cifar10-224',
+ '--lr',
+ '0.002',
+ '--n_epochs',
+ '1',
+ '--batch_size',
+ '2',
+ '--non_verbose',
+ '1',
+ '--num_workers',
+ '0',
+ '--seed',
+ '1993',
+ '--debug_mode',
+ '1']
+ fn = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_first_stage_starprompt.log')
+
+ # log all outputs to file
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
+ os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
+ sys.stdout = open(fn, 'w', encoding='utf-8')
+ sys.stderr = sys.stdout
+
+ main()
+
+ # read output file and search for the string 'Saved text-encoder keys in:'
+ with open(fn, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ ckpt_name = [line for line in lines if 'Saved text-encoder keys in:' in line]
+ assert any(ckpt_name), f'Keys not found in {fn}'
+
+ ckpt_path = ckpt_name[0].split('Saved text-encoder keys in:')[1].strip()
+
+ assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found'
+
+ # TEST CHECKPOINT LOAD
+ sys.argv = ['mammoth',
+ '--model',
+ 'second_stage_starprompt',
+ '--dataset',
+ 'seq-cifar10-224',
+ '--lr',
+ '1e-4',
+ '--optimizer',
+ 'adam',
+ '--n_epochs',
+ '1',
+ '--batch_size',
+ '4',
+ '--non_verbose',
+ '1',
+ '--num_workers',
+ '0',
+ '--seed',
+ '1993',
+ '--keys_ckpt_path',
+ ckpt_path,
+ '--debug_mode',
+ '1']
+
+ fn = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_second_stage_starprompt.log')
+
+ # log all outputs to file
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
+ os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
+ sys.stdout = open(fn, 'w', encoding='utf-8')
+ sys.stderr = sys.stdout
+ main()
+
+ # REMOVE CHECKPOINT FILE
+ os.remove(ckpt_path)
+
+
+@init_test_environ
+def test_full_starprompt():
+ sys.argv = ['mammoth',
+ '--model',
+ 'starprompt',
+ '--dataset',
+ 'seq-cifar10-224',
+ '--lr',
+ '1e-4',
+ '--optimizer',
+ 'adam',
+ '--n_epochs',
+ '1',
+ '--batch_size',
+ '4',
+ '--non_verbose',
+ '1',
+ '--num_workers',
+ '0',
+ '--seed',
+ '1993',
+ '--debug_mode',
+ '1']
+
+ fn = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_full_starprompt.log')
+
+ # log all outputs to file
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
+ os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
+ sys.stdout = open(fn, 'w', encoding='utf-8')
+ sys.stderr = sys.stdout
+ main()
diff --git a/tests/test_twf.py b/tests/test_twf.py
index 04df890c..c413d07c 100644
--- a/tests/test_twf.py
+++ b/tests/test_twf.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('dataset', ['seq-cifar100', 'seq-tinyimg']) # , 'seq-cub200'
@pytest.mark.parametrize('resize_maps', ['0', '1'])
def test_twf_random_init(dataset, resize_maps):
@@ -54,9 +56,10 @@ def test_twf_random_init(dataset, resize_maps):
main()
+@init_test_environ
@pytest.mark.parametrize(('dataset', 'loadcheck'),
[('seq-cifar100', 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EeWEOSls505AsMCTXAxWoLUBmeIjCiplFl40zDOCmB_lEw?e=Izv0jh'),
- ('seq-cub200', 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EV7I5BpJvURIhMMk95r3x5YBAZKch-NPFEJ9hhPQghcWCw?e=dt8wp3'),
+ ('seq-cub200-rs', 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EV7I5BpJvURIhMMk95r3x5YBAZKch-NPFEJ9hhPQghcWCw?e=dt8wp3'),
('seq-cifar10', 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EWttSkmKfkNEpEWNiPoS3zUB6uzZydc0irOW0Xbu3jtr3Q?e=JQ6Fay')])
@pytest.mark.parametrize('resize_maps', ['0', '1'])
def test_twf_with_checkpoint(dataset, loadcheck, resize_maps):
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 3e2b055c..87a2ba74 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1,12 +1,15 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from utils.main import main, parse_args
+from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
-@pytest.mark.parametrize('validation', ['0.2','0','20'])
-@pytest.mark.parametrize('validation_mode', ['complete','current'])
-def test_validation_classil( validation, validation_mode):
+
+@init_test_environ
+@pytest.mark.parametrize('validation', ['0.2', '0', '20'])
+@pytest.mark.parametrize('validation_mode', ['complete', 'current'])
+def test_validation_classil(validation, validation_mode):
sys.argv = ['mammoth',
'--model',
'sgd',
@@ -40,8 +43,9 @@ def test_validation_classil( validation, validation_mode):
main()
-@pytest.mark.parametrize('dataset', ['mnist-360','perm-mnist'])
-@pytest.mark.parametrize('validation', ['0.2','0','20'])
+@init_test_environ
+@pytest.mark.parametrize('dataset', ['mnist-360', 'perm-mnist'])
+@pytest.mark.parametrize('validation', ['0.2', '0', '20'])
@pytest.mark.parametrize('validation_mode', ['complete'])
def test_validation_domainil(dataset, validation, validation_mode):
sys.argv = ['mammoth',
diff --git a/tests/test_wandb.py b/tests/test_wandb.py
new file mode 100644
index 00000000..fdc2a6dd
--- /dev/null
+++ b/tests/test_wandb.py
@@ -0,0 +1,46 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.main import main
+from utils.test_utils import init_test_environ
+import pytest
+
+
+@init_test_environ
+def test_wandb_log_erace():
+ sys.argv = ['mammoth',
+ '--model',
+ 'er-ace',
+ '--buffer_size',
+ '50',
+ '--dataset',
+ 'seq-cifar10',
+ '--lr',
+ '1e-3',
+ '--n_epochs',
+ '1',
+ '--batch_size',
+ '4',
+ '--non_verbose',
+ '1',
+ '--num_workers',
+ '0',
+ '--seed',
+ '0',
+ '--debug_mode',
+ '1',
+ '--wandb_project',
+ 'mammoth-test',
+ '--wandb_entity',
+ 'mammoth-test']
+
+ os.environ['WANDB_MODE'] = 'disabled'
+
+ log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_wandb_log_erace.log')
+
+ # log all outputs to file
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')):
+ os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
+ sys.stdout = open(log_path, 'w', encoding='utf-8')
+ sys.stderr = sys.stdout
+ main()
diff --git a/tests/test_xder.py b/tests/test_xder.py
index f01d06bb..bc120e42 100644
--- a/tests/test_xder.py
+++ b/tests/test_xder.py
@@ -2,9 +2,11 @@
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main
+from utils.test_utils import init_test_environ
import pytest
+@init_test_environ
@pytest.mark.parametrize('model', ['xder', 'xder_rpc', 'xder_ce'])
def test_xder(model):
sys.argv = ['mammoth',
diff --git a/utils/args.py b/utils/args.py
index c102f8eb..ce757da6 100644
--- a/utils/args.py
+++ b/utils/args.py
@@ -2,7 +2,6 @@
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-
if __name__ == '__main__':
import os
import sys
@@ -36,7 +35,8 @@ def add_experiment_args(parser: ArgumentParser) -> None:
exp_group.add_argument('--lr', type=float, required=True, help='Learning rate.')
exp_group.add_argument('--batch_size', type=int, help='Batch size.')
exp_group.add_argument('--label_perc', type=float, default=1, help='Percentage in (0-1] of labeled examples per task.')
- exp_group.add_argument('--joint', type=int, choices=[0, 1], default=0, help='Train model on Joint (single task)?')
+ exp_group.add_argument('--joint', type=int, choices=(0, 1), default=0, help='Train model on Joint (single task)?')
+ exp_group.add_argument('--eval_future', type=int, choices=(0, 1), default=0, help='Evaluate future tasks?')
validation_group = parser.add_argument_group('Validation and fitting arguments', 'Arguments used to define the validation strategy and the method used to fit the model.')
@@ -96,10 +96,16 @@ def add_management_args(parser: ArgumentParser) -> None:
mng_group.add_argument('--seed', type=int, default=None,
help='The random seed. If not provided, a random seed will be used.')
- mng_group.add_argument('--permute_classes', type=int, choices=[0, 1], default=0,
+ mng_group.add_argument('--permute_classes', type=int, choices=[0, 1], default=1,
help='Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.')
mng_group.add_argument('--base_path', type=str, default="./data/",
help='The base path where to save datasets, logs, results.')
+ mng_group.add_argument('--device', type=str,
+ help='The device (or devices) available to use for training. '
+ 'More than one device can be specified by separating them with a comma. '
+ 'If not provided, the code will use the least used GPU available (if there are any), otherwise the CPU. '
+ 'MPS is supported and is automatically used if no GPU is available and MPS is supported. '
+ 'If more than one GPU is available, Mammoth will use the least used one if `--distributed=no`.')
mng_group.add_argument('--notes', type=str, default=None,
help='Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results')
mng_group.add_argument('--eval_epochs', type=int, default=None,
@@ -119,7 +125,7 @@ def add_management_args(parser: ArgumentParser) -> None:
'2: Use BF16, if available.'
'3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.')
mng_group.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'], help='Enable distributed training?')
- mng_group.add_argument('--savecheck', default=0, choices=[0, 1], type=int, help='Save checkpoint?')
+ mng_group.add_argument('--savecheck', choices=['last', 'task'], type=str, help='Save checkpoint every `task` or at the end of the training (`last`).')
mng_group.add_argument('--loadcheck', type=str, default=None, help='Path of the checkpoint to load (.pt file for the specific task)')
mng_group.add_argument('--ckpt_name', type=str, required=False, help='(optional) checkpoint save name.')
mng_group.add_argument('--start_from', type=int, default=None, help="Task to start from")
@@ -130,7 +136,7 @@ def add_management_args(parser: ArgumentParser) -> None:
wandb_group.add_argument('--wandb_name', type=str, default=None,
help='Wandb name for this run. Overrides the default name (`args.model`).')
wandb_group.add_argument('--wandb_entity', type=str, help='Wandb entity')
- wandb_group.add_argument('--wandb_project', type=str, default='mammoth', help='Wandb project name')
+ wandb_group.add_argument('--wandb_project', type=str, help='Wandb project name')
def add_rehearsal_args(parser: ArgumentParser) -> None:
@@ -169,11 +175,12 @@ def parse_choices(self) -> str:
return ', '.join([c.keys() if isinstance(c, dict) else str(c) for c in self.choices])
def __str__(self):
- tb = '\t'
- return f"""**\\-\\-{self.name}** : {self.type}
- *Help*: {self.help}\n
- - Default: {self.default}\n
- - Choices: {self.parse_choices() if self.choices is not None else ''}"""
+ tb = f"""**\\-\\-{self.name}** : {self.type.__name__ if self.type is not None else 'unknown'}
+\t*Help*: {self.help}\n
+\t- *Default*: ``{self.default}``"""
+ if self.choices is not None:
+ tb += f"\n\t- *Choices*: ``{self.parse_choices()}``"
+ return tb
class _DocArgsGroup:
@@ -188,7 +195,11 @@ def __init__(self, group_name: str, group_desc: str, doc_args: _DocsArgs):
def __str__(self):
args_str = '\n'.join([arg.__str__() for arg in self.doc_args])
- return f""".. rubric:: {self.group_name.capitalize()}\n\n*{self.group_desc}*\n\n{args_str}"""
+ s = f""".. rubric:: {self.group_name.capitalize()}\n\n"""
+ if self.group_desc:
+ s += f"*{self.group_desc}*\n\n"
+ s += args_str
+ return s
def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsGroup:
@@ -218,7 +229,9 @@ def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsG
add_experiment_args(parser)
docs_args = []
- for group in parser._action_groups[2:]: # first two groups are the positional and optional arguments
+ for group in parser._action_groups:
+ if len([a for a in group._group_actions if a.dest != 'help']) == 0:
+ continue
docs_args.append(_parse_actions(group._group_actions, group.title, group.description))
with open('docs/utils/args.rst', 'w') as f:
@@ -232,7 +245,9 @@ def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsG
parser = ArgumentParser()
add_management_args(parser)
docs_args = []
- for group in parser._action_groups[2:]: # first two groups are the positional and optional arguments
+ for group in parser._action_groups:
+ if len([a for a in group._group_actions if a.dest != 'help']) == 0:
+ continue
docs_args.append(_parse_actions(group._group_actions, group.title, group.description))
with open('docs/utils/args.rst', 'a') as f:
@@ -255,3 +270,21 @@ def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsG
print("Saving documentation in docs/utils/args.rst")
print("Done!")
+
+ from models import get_model_names
+
+ for model_name, model_class in get_model_names().items():
+ parser = model_class.get_parser()
+
+ model_args_groups = []
+ for group in parser._action_groups:
+ if len([a for a in group._group_actions if a.dest != 'help']) == 0:
+ continue
+ model_args_groups.append(_parse_actions(group._group_actions, group.title, group.description))
+ model_filename = model_name.replace("-", "_")
+ with open(f'docs/models/{model_filename}_args.rst', 'w') as f:
+ f.write(f'Arguments\n')
+ f.write(f'~~~~~~~~~~~\n\n')
+ for arg in model_args_groups:
+ f.write(str(arg) + '\n\n')
+ print(f"Saving documentation in docs/models/{model_filename}_args.rst")
diff --git a/utils/augmentations.py b/utils/augmentations.py
index efa0c763..0cd7231c 100644
--- a/utils/augmentations.py
+++ b/utils/augmentations.py
@@ -35,6 +35,8 @@ def apply_transform(x: torch.Tensor, transform) -> torch.Tensor:
x = torch.as_tensor(np.array(x, copy=True)).permute((2, 0, 1))
return transform(x)
else:
+ if isinstance(x, PIL.Image.Image):
+ return transform(x)
return torch.stack([transform(xi) for xi in x.cpu()], dim=0).to(x.device)
@@ -236,6 +238,24 @@ def __call__(self, x):
)), self.mean, self.std)
+class RepeatedTransform(object):
+ """
+ This class applies a series of transforms to the same input.
+
+ Args:
+ transform_list: The list of transformations to be applied.
+ """
+
+ def __init__(self, transform_list: list):
+ self.transform_list = transform_list
+
+ assert len(self.transform_list) > 0, 'The list of transformations must not be empty.'
+
+ @torch.no_grad()
+ def __call__(self, input):
+ return torch.stack([apply_transform(input, t) for t in self.transform_list])
+
+
class DoubleTransform(object):
"""
This class applies a given transformation to the first image and leaves the second input unchanged.
diff --git a/utils/buffer.py b/utils/buffer.py
index 211691bc..8e1f1b7c 100644
--- a/utils/buffer.py
+++ b/utils/buffer.py
@@ -13,7 +13,7 @@
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.augmentations import apply_transform
-from utils.conf import get_device
+from utils.conf import create_seeded_dataloader, get_device
def icarl_replay(self: ContinualModel, dataset, val_set_split=0):
@@ -67,8 +67,7 @@ def refold_transform(x): return (x.cpu() * 255).squeeze(1).type(torch.uint8)
refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask])
])
- self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=self.args.batch_size, shuffle=True,
- num_workers=self.args.num_workers)
+ self.val_loader = create_seeded_dataloader(self.args, self.val_dataset, batch_size=self.args.batch_size, shuffle=True)
def reservoir(num_seen_examples: int, buffer_size: int) -> int:
diff --git a/utils/checkpoints.py b/utils/checkpoints.py
index c4b8ce07..6bb1f9fb 100644
--- a/utils/checkpoints.py
+++ b/utils/checkpoints.py
@@ -1,6 +1,7 @@
import random
import string
+import numpy as np
import torch
from torch import distributed as dist
import os
@@ -170,10 +171,20 @@ def mammoth_load_checkpoint(args, model: torch.nn.Module, ignore_classifier=Fals
def _check_loaded_args(args, loaded_args):
+ def _check_arg(arg, loaded_arg):
+ if isinstance(arg, (list, tuple)):
+ return any([a != la for a, la in zip(arg, loaded_arg)])
+ elif isinstance(arg, dict):
+ return any([k not in loaded_arg or _check_arg(v, loaded_arg[k]) for k, v in arg.items()])
+ elif isinstance(arg, (torch.Tensor, np.ndarray)):
+ return (arg != loaded_arg).any()
+ return arg != loaded_arg
+
ignored_args = ['loadcheck', 'start_from', 'stop_after', 'conf_jobnum', 'conf_host', 'conf_timestamp', 'distributed', 'examples_log', 'examples_full_log',
- 'intensive_savecheck', 'job_number', 'conf_git_commit', 'loss_log', 'tensorboard', 'seed', 'savecheck', 'notes', 'non_verbose', 'autorelaunch', 'force_compat', 'conf_external_path']
+ 'intensive_savecheck', 'job_number', 'conf_git_commit', 'loss_log', 'tensorboard', 'seed', 'savecheck', 'notes', 'non_verbose', 'autorelaunch',
+ 'force_compat', 'conf_external_path', 'ckpt_name']
mismatched_args = [x for x in vars(args) if x not in ignored_args and (
- x not in vars(loaded_args) or getattr(args, x) != getattr(loaded_args, x))]
+ x not in vars(loaded_args) or _check_arg(getattr(args, x), getattr(loaded_args, x)))]
if len(mismatched_args):
if 'force_compat' not in vars(args) or args.force_compat:
diff --git a/utils/conf.py b/utils/conf.py
index 0f2fb681..45229084 100644
--- a/utils/conf.py
+++ b/utils/conf.py
@@ -7,11 +7,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+import logging
import os
-import sys
import random
-import torch
+import sys
+from functools import partial
+
+from typing import List
import numpy as np
+import torch
+from torch.utils.data import DataLoader
def warn_once(*msg):
@@ -26,49 +31,84 @@ def warn_once(*msg):
warn_once.warned = set()
if msg not in warn_once.warned:
warn_once.warned.add(msg)
- print(msg, file=sys.stderr)
+ logging.warning(msg)
+
+
+def _get_gpu_memory_pynvml_all_processes(device_id: int = 0) -> int:
+ """
+ Use pynvml to get the memory allocated on the GPU.
+ Returns the memory allocated on the GPU in Bytes.
+ """
+ if not hasattr(_get_gpu_memory_pynvml_all_processes, f'handle_{device_id}'):
+ torch.cuda.pynvml.nvmlInit() # only once
+ handle = torch.cuda.pynvml.nvmlDeviceGetHandleByIndex(device_id)
+ setattr(_get_gpu_memory_pynvml_all_processes, f'handle_{device_id}', handle)
+ handle = getattr(_get_gpu_memory_pynvml_all_processes, f'handle_{device_id}')
-def get_alloc_memory_all_devices() -> list[int]:
+ procs = torch.cuda.pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+ return sum([proc.usedGpuMemory for proc in procs])
+
+
+def get_alloc_memory_all_devices(return_all=False) -> list[int]:
"""
Returns the memory allocated on all the available devices.
+ By default, tries to return the memory read from pynvml, if available.
+ Else, it returns the memory `reserved` by torch.
+
+ If `return_all` is set to True, it returns a tuple with the memory reserved, allocated and from pynvml.
+
+ Values are in Bytes.
"""
- gpu_memory = []
+ gpu_memory_reserved = []
+ gpu_memory_allocated = []
+ gpu_memory_nvidiasmi = []
for i in range(torch.cuda.device_count()):
- _ = torch.tensor([1]).to(i)
- gpu_memory.append(torch.cuda.memory_allocated(i))
- if all(memory == 0 for memory in gpu_memory):
- print("WARNING: some weird GPU memory issue. "
- "Using trick from https://discuss.pytorch.org/t/torch-cuda-memory-allocated-returns-0-if-pytorch-no-cuda-memory-caching-1/188796")
- for i in range(torch.cuda.device_count()):
- torch.zeros(1).to(i)
- free_memory, total_memory = torch.cuda.mem_get_info(i)
- gpu_memory[i] = total_memory - free_memory
- return gpu_memory
+ _ = torch.tensor([1]).to(i) # allocate memory to get more accurate reading from torch
+ gpu_memory_reserved.append(torch.cuda.max_memory_reserved(i))
+ gpu_memory_allocated.append(torch.cuda.max_memory_allocated(i))
+ try:
+ gpu_memory_nvidiasmi.append(_get_gpu_memory_pynvml_all_processes(i))
+ except BaseException as e:
+ warn_once('Could not get memory from pynvml. Maybe try `pip install --force-reinstall gpustat`.', str(e))
+ gpu_memory_nvidiasmi.append(-1)
-def get_device() -> torch.device:
+ if return_all:
+ return gpu_memory_reserved, gpu_memory_allocated, gpu_memory_nvidiasmi
+ else:
+ if any([g > 0 for g in gpu_memory_nvidiasmi]):
+ return gpu_memory_nvidiasmi
+ return gpu_memory_allocated
+
+
+def get_device(avail_devices: str = None) -> torch.device:
"""
Returns the least used GPU device if available else MPS or CPU.
"""
- def _get_device():
+ def _get_device(avail_devices: List[int] = None) -> torch.device:
# get least used gpu by used memory
- if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+ if torch.cuda.is_available() and torch.cuda.device_count() > 0 and len(avail_devices) > 0:
gpu_memory = get_alloc_memory_all_devices()
- device = torch.device(f'cuda:{np.argmin(gpu_memory)}')
+ gpu_memory = [gpu_memory[i] for i in avail_devices]
+ device = torch.device(f'cuda:{avail_devices[np.argmin(gpu_memory)]}')
return device
try:
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
- print("WARNING: MSP support is still experimental. Use at your own risk!")
+ logging.warning("MSP support is still experimental. Use at your own risk!")
return torch.device("mps")
except BaseException:
- print("WARNING: Something went wrong with MPS. Using CPU.")
+ logging.error("Something went wrong with MPS. Using CPU.")
return torch.device("cpu")
# Permanently store the chosen device
if not hasattr(get_device, 'device'):
- get_device.device = _get_device()
+ if avail_devices is not None:
+ avail_devices = [int(d) for d in avail_devices.split(',')]
+ else:
+ avail_devices = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
+ get_device.device = _get_device(avail_devices=avail_devices)
print(f'Using device {get_device.device}')
return get_device.device
@@ -112,22 +152,24 @@ def set_random_seed(seed: int) -> None:
print('Could not set cuda seed.')
-def set_random_seed_worker(worker_id) -> None:
+def worker_init_fn(worker_id, num_workers, seed, rank=1):
"""
Sets the seeds for a worker of a dataloader.
+ The seed of each worker is set to: `num_worker * rank + worker_id + seed`
"""
- worker_seed = torch.initial_seed() % 2**32
+ worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
-def create_seeded_dataloader(args, dataset, **dataloader_args) -> torch.utils.data.DataLoader:
+def create_seeded_dataloader(args, dataset, **dataloader_args) -> DataLoader:
"""
Creates a dataloader object from a dataset, setting the seeds for the workers (if `--seed` is set).
Args:
args: the arguments of the program
dataset: the dataset to be loaded
+ verbose: whether to print the number of workers
dataloader_args: external arguments of the dataloader
Returns:
@@ -137,11 +179,14 @@ def create_seeded_dataloader(args, dataset, **dataloader_args) -> torch.utils.da
n_cpus = 4 if not hasattr(os, 'sched_getaffinity') else len(os.sched_getaffinity(0))
num_workers = n_cpus if args.num_workers is None else args.num_workers
dataloader_args['num_workers'] = num_workers if 'num_workers' not in dataloader_args else dataloader_args['num_workers']
+ logging.info(f'Using {dataloader_args["num_workers"]} workers for the dataloader.')
if args.seed is not None:
worker_generator = torch.Generator()
worker_generator.manual_seed(args.seed)
else:
worker_generator = None
dataloader_args['generator'] = worker_generator if 'generator' not in dataloader_args else dataloader_args['generator']
- dataloader_args['worker_init_fn'] = set_random_seed_worker if 'worker_init_fn' not in dataloader_args else dataloader_args['worker_init_fn']
- return torch.utils.data.DataLoader(dataset, **dataloader_args)
+ init_fn = partial(worker_init_fn, num_workers=num_workers, seed=args.seed) if args.seed is not None else None
+ dataloader_args['worker_init_fn'] = init_fn if 'worker_init_fn' not in dataloader_args else dataloader_args['worker_init_fn']
+
+ return DataLoader(dataset, **dataloader_args)
diff --git a/utils/kornia_utils.py b/utils/kornia_utils.py
index 6c85a420..4e209cd6 100644
--- a/utils/kornia_utils.py
+++ b/utils/kornia_utils.py
@@ -4,6 +4,7 @@
import torch
from torchvision import transforms
from kornia.augmentation.container.params import ParamItem
+from kornia.constants import Resample
class KorniaMultiAug(kornia.augmentation.AugmentationSequential):
@@ -87,6 +88,14 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
return self._do_transform(*args, **kwargs)
+def _convert_interpolation_to_resample(interpolation: int) -> int:
+ interpolation_name = transforms.InterpolationMode(interpolation).name
+ if hasattr(Resample, interpolation_name):
+ return getattr(Resample, interpolation_name)
+ else:
+ raise NotImplementedError(f"Interpolation mode {interpolation_name} not supported by Kornia.")
+
+
def to_kornia_transform(transform: transforms.Compose, apply: bool = True) -> Union[List[kornia.augmentation.AugmentationBase2D], KorniaAugNoGrad]:
"""
Converts PIL transforms to Kornia transforms.
@@ -144,6 +153,8 @@ def to_kornia_transform(transform: transforms.Compose, apply: bool = True) -> Un
pass
elif isinstance(t, transforms.Normalize):
ts.append(kornia.augmentation.Normalize(mean=t.mean, std=t.std, p=1))
+ elif isinstance(t, transforms.Resize):
+ ts.append(kornia.augmentation.Resize(size=t.size, antialias=t.antialias, resample=_convert_interpolation_to_resample(t.interpolation)))
else:
raise NotImplementedError
diff --git a/utils/loggers.py b/utils/loggers.py
index d048bd75..ac0e7723 100644
--- a/utils/loggers.py
+++ b/utils/loggers.py
@@ -20,7 +20,7 @@
import wandb
-def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT"):
+def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT", future=False):
"""
Logs the accuracy values and other metrics.
@@ -35,7 +35,10 @@ def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT"):
epoch: The epoch number (optional).
prefix: The prefix for the metrics (default="RESULT").
"""
- mean_acc = print_mean_accuracy(accs, t + 1 if isinstance(t, (float, int)) else t, setting, joint=args.joint, epoch=epoch)
+
+ mean_acc = print_mean_accuracy(accs, t + 1 if isinstance(t, (float, int)) else t,
+ setting, joint=args.joint,
+ epoch=epoch, future=future)
if not args.disable_log:
logger.log(mean_acc)
@@ -43,16 +46,23 @@ def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT"):
if not args.nowand:
postfix = "" if epoch is None else f"_epoch_{epoch}"
- d2 = {f'{prefix}_class_mean_accs{postfix}': mean_acc[0], f'{prefix}_task_mean_accs{postfix}': mean_acc[1],
- **{f'{prefix}_class_acc_{i}{postfix}': a for i, a in enumerate(accs[0])},
- **{f'{prefix}_task_acc_{i}{postfix}': a for i, a in enumerate(accs[1])},
- 'Task': t}
+ if future:
+ prefix += "_transf"
+ if isinstance(mean_acc, float): # domain or gcl
+ d2 = {f'{prefix}_domain_mean_accs{postfix}': mean_acc,
+ **{f'{prefix}_domain_acc_{i}{postfix}': a for i, a in enumerate(accs[0])},
+ 'Task': t}
+ else:
+ d2 = {f'{prefix}_class_mean_accs{postfix}': mean_acc[0], f'{prefix}_task_mean_accs{postfix}': mean_acc[1],
+ **{f'{prefix}_class_acc_{i}{postfix}': a for i, a in enumerate(accs[0])},
+ **{f'{prefix}_task_acc_{i}{postfix}': a for i, a in enumerate(accs[1])},
+ 'Task': t}
wandb.log(d2)
def print_mean_accuracy(accs: np.ndarray, task_number: int,
- setting: str, joint=False, epoch=None) -> None:
+ setting: str, joint=False, epoch=None, future=False) -> None:
"""
Prints the mean accuracy on stderr.
@@ -81,6 +91,7 @@ def print_mean_accuracy(accs: np.ndarray, task_number: int,
print('\tRaw accuracy values: Class-IL {} | Task-IL {}'.format(accs[0], accs[1]), file=sys.stderr)
else:
prefix = "Accuracy" if epoch is None else f"Accuracy (epoch {epoch})"
+ prefix = "Future " + prefix if future else prefix
if setting == 'domain-il' or setting == 'general-continual':
mean_acc, _ = mean_acc
print('{} for {} task(s): [Domain-IL]: {} %'.format(prefix,
@@ -91,7 +102,7 @@ def print_mean_accuracy(accs: np.ndarray, task_number: int,
print('{} for {} task(s): \t [Class-IL]: {} % \t [Task-IL]: {} %'.format(prefix, task_number, round(
mean_acc_class_il, 2), round(mean_acc_task_il, 2)), file=sys.stderr)
print('\tRaw accuracy values: Class-IL {} | Task-IL {}'.format(accs[0], accs[1]), file=sys.stderr)
-
+ print('\n', file=sys.stderr)
return mean_acc
@@ -265,9 +276,12 @@ def log_system_stats(self, cpu_res, gpu_res):
self.cpu_res.append(cpu_res)
if gpu_res is not None:
self.gpu_res.append(gpu_res)
+ gpu_res = {f'GPU_{i}_memory_usage': r for i, r in gpu_res.items()}
+ else:
+ gpu_res = {}
if not self.args.nowand:
- wandb.log({'CPU_memory_usage': cpu_res, **{f'GPU_{i}_memory_usage': r for i, r in gpu_res.items()}})
+ wandb.log({'CPU_memory_usage': cpu_res, **gpu_res})
def write(self, args: Dict[str, Any]) -> None:
"""
diff --git a/utils/main.py b/utils/main.py
index 33415d7b..29d018ff 100644
--- a/utils/main.py
+++ b/utils/main.py
@@ -16,12 +16,13 @@
# LICENSE file in the root directory of this source tree.
# needed (don't change it)
+import logging
import numpy # noqa
+import os
+import sys
import time
import importlib
-import os
import socket
-import sys
import datetime
import uuid
from argparse import ArgumentParser
@@ -34,6 +35,18 @@
sys.path.append(mammoth_path + '/models')
from utils import create_if_not_exists, custom_str_underscore
+from utils.conf import warn_once
+
+if __name__ == '__main__':
+ try:
+ if os.getenv('MAMMOTH_TEST', '0') == '0':
+ from dotenv import load_dotenv
+ load_dotenv()
+ else:
+ warn_once("Running in test mode. Ignoring .env file.")
+ except ImportError:
+ warn_once("Warning: python-dotenv not installed. Ignoring .env file.")
+
from utils.args import add_management_args, add_experiment_args
from utils.conf import base_path, get_device
from utils.distributed import make_dp
@@ -63,8 +76,8 @@ def parse_args():
parser = ArgumentParser(description='mammoth', allow_abbrev=False, add_help=False)
parser.add_argument('--model', type=custom_str_underscore, help='Model name.', choices=list(get_all_models().keys()))
parser.add_argument('--load_best_args', action='store_true',
- help='Loads the best arguments for each method, '
- 'dataset and memory buffer.')
+ help='(deprecated) Loads the best arguments for each method, dataset and memory buffer. '
+ 'NOTE: This option is deprecated and not up to date.')
args = parser.parse_known_args()[0]
models_dict = get_all_models()
@@ -110,20 +123,34 @@ def parse_args():
args.model = models_dict[args.model]
if args.lr_scheduler is not None:
- print('Warning: lr_scheduler set to {}, overrides default from dataset.'.format(args.lr_scheduler), file=sys.stderr)
+ logging.info('`lr_scheduler` set to {}, overrides default from dataset.'.format(args.lr_scheduler))
if args.seed is not None:
set_random_seed(args.seed)
+ # Add uuid, timestamp and hostname for logging
+ args.conf_jobnum = str(uuid.uuid4())
+ args.conf_timestamp = str(datetime.datetime.now())
+ args.conf_host = socket.gethostname()
+
+ # Add the current git commit hash to the arguments if available
+ try:
+ import git
+ repo = git.Repo(path=os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ args.conf_git_hash = repo.head.object.hexsha
+ except Exception:
+ logging.error("Could not retrieve git hash.")
+ args.conf_git_hash = None
+
if args.savecheck:
assert args.inference_only == 0, "Should not save checkpoint in inference only mode"
if not os.path.isdir('checkpoints'):
create_if_not_exists("checkpoints")
now = time.strftime("%Y%m%d-%H%M%S")
+ uid = args.conf_jobnum.split('-')[0]
extra_ckpt_name = "" if args.ckpt_name is None else f"{args.ckpt_name}_"
- args.ckpt_name = f"{extra_ckpt_name}{args.model}_{args.dataset}_{args.buffer_size if hasattr(args, 'buffer_size') else 0}_{args.n_epochs}_{str(now)}"
- args.ckpt_name_replace = f"{extra_ckpt_name}{args.model}_{args.dataset}_{'{}'}_{args.buffer_size if hasattr(args, 'buffer_size') else 0}__{args.n_epochs}_{str(now)}"
+ args.ckpt_name = f"{extra_ckpt_name}{args.model}_{args.dataset}_{args.buffer_size if hasattr(args, 'buffer_size') else 0}_{args.n_epochs}_{str(now)}_{uid}"
print("Saving checkpoint into", args.ckpt_name, file=sys.stderr)
if args.joint:
@@ -133,8 +160,8 @@ def parse_args():
assert 0 < args.label_perc <= 1, "label_perc must be in (0, 1]"
if args.validation is not None:
- print(f"INFO: Using {args.validation}% of the training set as validation set.", file=sys.stderr)
- print(f"INFO: Validation will be computed with mode `{args.validation_mode}`.", file=sys.stderr)
+ logging.info(f"Using {args.validation}% of the training set as validation set.")
+ logging.info(f"Validation will be computed with mode `{args.validation_mode}`.")
return args
@@ -143,12 +170,13 @@ def main(args=None):
from models import get_model
from datasets import ContinualDataset, get_dataset
from utils.training import train
+ from models.utils.future_model import FutureModel
lecun_fix()
if args is None:
args = parse_args()
- device = get_device()
+ device = get_device(avail_devices=args.device)
args.device = device
# set base path
@@ -156,17 +184,13 @@ def main(args=None):
if args.code_optimization != 0:
torch.set_float32_matmul_precision('high' if args.code_optimization == 1 else 'medium')
- print("INFO: code_optimization is set to", args.code_optimization, file=sys.stderr)
- print(f"Using {torch.get_float32_matmul_precision()} precision for matmul.", file=sys.stderr)
+ logging.info("Code_optimization is set to", args.code_optimization)
+ logging.info(f"Using {torch.get_float32_matmul_precision()} precision for matmul.")
if args.code_optimization == 2:
if not torch.cuda.is_bf16_supported():
raise NotImplementedError('BF16 is not supported on this machine.')
- # Add uuid, timestamp and hostname for logging
- args.conf_jobnum = str(uuid.uuid4())
- args.conf_timestamp = str(datetime.datetime.now())
- args.conf_host = socket.gethostname()
dataset = get_dataset(args)
if args.fitting_mode == 'epochs' and args.n_epochs is None and isinstance(dataset, ContinualDataset):
@@ -191,7 +215,7 @@ def main(args=None):
# from https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
if torch.cuda.get_device_capability()[0] >= 7 and os.name != 'nt':
print("================ Compiling model with torch.compile ================")
- print("WARNING: `torch.compile` may break your code if you change the model after the first run!")
+ logging.warning("`torch.compile` may break your code if you change the model after the first run!")
print("This includes adding classifiers for new tasks, changing the backbone, etc.")
print("ALSO: some models CHANGE the backbone during initialization. Remember to call `torch.compile` again after that.")
print("====================================================================")
@@ -205,6 +229,7 @@ def main(args=None):
loss = dataset.get_loss()
model = get_model(args, backbone, loss, dataset.get_transform())
# model = torch.compile(model)
+ assert isinstance(model, FutureModel) or not args.eval_future, "Model does not support future_forward."
if args.distributed == 'dp':
if args.batch_size < torch.cuda.device_count():
@@ -221,8 +246,13 @@ def main(args=None):
print('Debug mode enabled: running only a few forward steps per epoch with W&B disabled.')
args.nowand = 1
+ if args.wandb_entity is None:
+ args.wandb_entity = os.getenv('WANDB_ENTITY', None)
+ if args.wandb_project is None:
+ args.wandb_project = os.getenv('WANDB_PROJECT', None)
+
if args.wandb_entity is None or args.wandb_project is None:
- print('Warning: wandb_entity and wandb_project not set. Disabling wandb.')
+ logging.warning('`wandb_entity` and `wandb_project` not set. Disabling wandb.')
args.nowand = 1
else:
print('Logging to wandb: {}/{}'.format(args.wandb_entity, args.wandb_project))
diff --git a/utils/simclrloss.py b/utils/simclrloss.py
index 5d74e178..148c6bd0 100644
--- a/utils/simclrloss.py
+++ b/utils/simclrloss.py
@@ -21,14 +21,15 @@ def __init__(self, temperature=0.07, contrast_mode='all',
self.reduction = reduction
def forward(self, features, labels=None, mask=None):
- """Compute loss for model. If both `labels` and `mask` are None,
- it degenerates to SimCLR unsupervised loss:
- https://arxiv.org/pdf/2002.05709.pdf
+ """
+ Compute loss for model. If both `labels` and `mask` are None,
+ it degenerates to SimCLR unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf .
+
Args:
features: hidden vector of shape [bsz, n_views, ...].
labels: ground truth of shape [bsz].
- mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
- has the same class as sample i. Can be asymmetric.
+ mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j has the same class as sample i. Can be asymmetric.
+
Returns:
A loss scalar.
"""
diff --git a/utils/stats.py b/utils/stats.py
index 924e5629..9a91a460 100644
--- a/utils/stats.py
+++ b/utils/stats.py
@@ -8,8 +8,7 @@ def get_memory_mb():
Returns:
dict: A dictionary containing the memory usage of the current process and its children.
- The dictionary has the following
- keys:
+ The dictionary has the following keys:
- self: The memory usage of the current process.
- children: The memory usage of the children of the current process.
- total: The total memory usage of the current process and its children.
@@ -34,7 +33,7 @@ def get_memory_gpu_mb():
Get the memory usage of all GPUs in MB.
"""
- return [d / 1024 for d in get_alloc_memory_all_devices()]
+ return [d / 1024 / 1024 for d in get_alloc_memory_all_devices()]
else:
get_memory_gpu_mb = None
except BaseException:
@@ -49,12 +48,15 @@ class track_system_stats:
Tracks both CPU and GPU memory usage if available.
Usage:
- with track_system_stats() as t:
- for i in range(100):
- ... # Do something
- t()
- cpu_res, gpu_res = t.cpu_res, t.gpu_res
+ .. code-block:: python
+
+ with track_system_stats() as t:
+ for i in range(100):
+ ... # Do something
+ t()
+
+ cpu_res, gpu_res = t.cpu_res, t.gpu_res
Args:
logger (Logger): external logger.
@@ -87,19 +89,20 @@ def __enter__(self):
if self.disabled:
return self
self.initial_cpu_res, self.initial_gpu_res = self.get_stats()
- self.initial_gpu_res = {g: g_res for g, g_res in enumerate(self.initial_gpu_res)}
-
- self.avg_gpu_res = self.initial_gpu_res
- self.avg_cpu_res = self.initial_cpu_res
-
- self.max_cpu_res = self.initial_cpu_res
- self.max_gpu_res = self.initial_gpu_res
-
if self.initial_cpu_res is None and self.initial_gpu_res is None:
self.disabled = True
+ else:
+ if self.initial_gpu_res is not None:
+ self.initial_gpu_res = {g: g_res for g, g_res in enumerate(self.initial_gpu_res)}
- if self.logger is not None:
- self.logger.log_system_stats(self.initial_cpu_res, self.initial_gpu_res)
+ self.avg_gpu_res = self.initial_gpu_res
+ self.avg_cpu_res = self.initial_cpu_res
+
+ self.max_cpu_res = self.initial_cpu_res
+ self.max_gpu_res = self.initial_gpu_res
+
+ if self.logger is not None:
+ self.logger.log_system_stats(self.initial_cpu_res, self.initial_gpu_res)
return self
@@ -114,6 +117,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if self.disabled:
return
+ torch.cuda.synchronize() # this allows to raise errors triggered previously by the GPU
+
cpu_res, gpu_res = self.get_stats()
self.update_stats(cpu_res, gpu_res)
@@ -138,6 +143,7 @@ def update_stats(self, cpu_res, gpu_res):
if self.initial_gpu_res is not None:
self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)}
self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}
+ gpu_res = {g: g_res for g, g_res in enumerate(gpu_res)}
if self.logger is not None:
self.logger.log_system_stats(cpu_res, gpu_res)
diff --git a/utils/test_utils.py b/utils/test_utils.py
new file mode 100644
index 00000000..02fa250d
--- /dev/null
+++ b/utils/test_utils.py
@@ -0,0 +1,9 @@
+import os
+import decorator
+
+
+def init_test_environ(func):
+ def wrapper(func, *args, **kwargs):
+ os.environ['MAMMOTH_TEST'] = '1'
+ return func(*args, **kwargs)
+ return decorator.decorator(wrapper, func)
diff --git a/utils/training.py b/utils/training.py
index 640efd6c..cd08d1ca 100644
--- a/utils/training.py
+++ b/utils/training.py
@@ -5,22 +5,22 @@
from copy import deepcopy
import math
+import os
import sys
from argparse import Namespace
+from time import time
from typing import Iterable, Tuple
import torch
+from tqdm import tqdm
from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from datasets.utils.gcl_dataset import GCLDataset
from models.utils.continual_model import ContinualModel
-from utils import random_id
from utils.checkpoints import mammoth_load_checkpoint
from utils.loggers import *
from utils.stats import track_system_stats
-from utils.status import ProgressBar
-import time
try:
import wandb
@@ -39,9 +39,10 @@ def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> No
dataset: the continual dataset
k: the task index
"""
- outputs[:, 0:k * dataset.N_CLASSES_PER_TASK] = -float('inf')
- outputs[:, (k + 1) * dataset.N_CLASSES_PER_TASK:
- dataset.N_TASKS * dataset.N_CLASSES_PER_TASK] = -float('inf')
+ num_classes = dataset.N_CLASSES
+ start_c, end_c = dataset.get_offsets(k)
+ outputs[:, :start_c] = -float('inf')
+ outputs[:, end_c:num_classes] = -float('inf')
@torch.no_grad()
@@ -66,6 +67,8 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, retur
n_classes = dataset.get_offsets()[1]
loss_fn = dataset.get_loss()
avg_loss = 0
+ total_len = sum(len(x) for x in dataset.test_loaders) if hasattr(dataset.test_loaders[0], '__len__') else None
+ pbar = tqdm(dataset.test_loaders, total=total_len, desc='Evaluating')
for k, test_loader in enumerate(dataset.test_loaders):
if last and k < len(dataset.test_loaders) - 1:
continue
@@ -84,7 +87,10 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, retur
if 'class-il' not in model.COMPATIBILITY and 'general-continual' not in model.COMPATIBILITY:
outputs = model(inputs, k)
else:
- outputs = model(inputs)
+ if model.args.eval_future and k >= model.current_task:
+ outputs = model.future_forward(inputs)
+ else:
+ outputs = model(inputs)
if return_loss:
loss = loss_fn(outputs, labels)
@@ -94,6 +100,9 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, retur
correct += torch.sum(pred == labels).item()
total += labels.shape[0]
i += 1
+ pbar.set_postfix({f'acc_task_{k+1}': max(0, correct / total * 100)})
+ pbar.set_description(f"Evaluating Task {k+1}")
+ pbar.update(1)
if dataset.SETTING == 'class-il':
mask_classes(outputs, dataset, k)
@@ -103,6 +112,7 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, retur
accs.append(correct / total * 100
if 'class-il' in model.COMPATIBILITY or 'general-continual' in model.COMPATIBILITY else 0)
accs_mask_classes.append(correct_mask_classes / total * 100)
+ pbar.close()
model.net.train(status)
if return_loss:
@@ -120,15 +130,15 @@ def initialize_wandb(args: Namespace) -> None:
assert wandb is not None, "Wandb not installed, please install it or run without wandb"
run_name = args.wandb_name if args.wandb_name is not None else args.model
- run_id = random_id(5)
+ run_id = args.conf_jobnum.split('-')[0]
name = f'{run_name}_{run_id}'
- wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=name)
+ mode = 'disabled' if os.getenv('MAMMOTH_TEST', '0') == '1' else os.getenv('WANDB_MODE', 'online')
+ wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=name, mode=mode)
args.wandb_url = wandb.run.get_url()
def train_single_epoch(model: ContinualModel,
train_loader: Iterable,
- progress_bar: ProgressBar,
args: Namespace,
epoch: int,
current_task: int,
@@ -141,7 +151,6 @@ def train_single_epoch(model: ContinualModel,
Args:
model: the model to be trained
train_loader: the data loader for the training set
- progress_bar: the progress bar for the current epoch
args: the arguments from the command line
epoch: the current epoch
current_task: the current task index
@@ -155,6 +164,9 @@ def train_single_epoch(model: ContinualModel,
train_iter = iter(train_loader)
i = 0
+ previous_time = time()
+
+ pbar = tqdm(train_iter, total=data_len, desc=f"Task {current_task + 1} - Epoch {epoch + 1}")
while True:
try:
data = next(train_iter)
@@ -162,7 +174,7 @@ def train_single_epoch(model: ContinualModel,
break
if args.debug_mode and i > model.get_debug_iters():
break
- if args.fitting_mode == 'iters' and progress_bar.current_task_iter >= model.args.n_iters:
+ if args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
break
if hasattr(train_loader.dataset, 'logits'):
@@ -179,17 +191,20 @@ def train_single_epoch(model: ContinualModel,
loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch)
assert not math.isnan(loss)
- if args.code_optimization == 0:
+ if args.code_optimization == 0 and 'cuda' in str(args.device):
torch.cuda.synchronize()
- progress_bar.prog(i, data_len, epoch, current_task, loss)
system_tracker()
i += 1
+ time_diff = time() - previous_time
+ previous_time = time()
+ ep_h = 3600 / (data_len * time_diff) if data_len else 'N/A'
+ pbar.set_postfix({'loss': loss} if ep_h == 'N/A' else {'loss': loss, 'ep/h': ep_h})
+ pbar.update()
+
if scheduler is not None:
scheduler.step()
- return i
-
def train(model: ContinualModel, dataset: ContinualDataset,
args: Namespace) -> None:
@@ -215,6 +230,9 @@ def train(model: ContinualModel, dataset: ContinualDataset,
with track_system_stats(logger) as system_tracker:
results, results_mask_classes = [], []
+ if args.eval_future:
+ results_transf, results_mask_classes_transf = [], []
+
if args.start_from is not None:
for i in range(args.start_from):
train_loader, _ = dataset.get_data_loaders()
@@ -230,8 +248,6 @@ def train(model: ContinualModel, dataset: ContinualDataset,
print('Checkpoint Loaded!')
- progress_bar = ProgressBar(joint=args.joint, verbose=not args.non_verbose)
-
if args.enable_other_metrics:
dataset_copy = get_dataset(args)
for t in range(dataset.N_TASKS):
@@ -244,15 +260,25 @@ def train(model: ContinualModel, dataset: ContinualDataset,
start_task = 0 if args.start_from is None else args.start_from
end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after
+ if args.eval_future:
+ eval_dataset = get_dataset(args)
+ for _ in range(dataset.N_TASKS):
+ eval_dataset.get_data_loaders()
+ model.change_transform(eval_dataset)
+ del eval_dataset.train_loader
+ else:
+ eval_dataset = dataset
+
torch.cuda.empty_cache()
for t in range(start_task, end_task):
model.net.train()
- train_loader, test_loader = dataset.get_data_loaders()
+ train_loader, _ = dataset.get_data_loaders()
+
model.meta_begin_task(dataset)
- if not args.inference_only:
+ if not args.inference_only and args.n_epochs > 0:
if t and args.enable_other_metrics:
- accs = evaluate(model, dataset, last=True)
+ accs = evaluate(model, eval_dataset, last=True)
results[t - 1] = results[t - 1] + accs[0]
if dataset.SETTING == 'class-il':
results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]
@@ -268,16 +294,16 @@ def train(model: ContinualModel, dataset: ContinualDataset,
if not isinstance(dataset, GCLDataset):
data_len = len(train_loader)
- train_single_epoch(model, train_loader, progress_bar, args, current_task=t, epoch=epoch,
+ train_single_epoch(model, train_loader, args, current_task=t, epoch=epoch,
system_tracker=system_tracker, data_len=data_len, scheduler=scheduler)
epoch += 1
if args.fitting_mode == 'epochs' and epoch >= model.args.n_epochs:
break
- elif args.fitting_mode == 'iters' and progress_bar.current_task_iter >= model.args.n_iters:
+ elif args.fitting_mode == 'iters' and model.task_iteration >= model.args.n_iters:
break
elif args.fitting_mode == 'early_stopping' and epoch % args.early_stopping_freq == 0 and epoch > 0:
- epoch_accs, _, epoch_loss = evaluate(model, dataset, return_loss=True, last=True)
+ epoch_accs, _, epoch_loss = evaluate(model, eval_dataset, return_loss=True, last=True)
if args.early_stopping_metric == 'accuracy':
ea_metric = np.mean(epoch_accs) # Higher accuracy is better
@@ -303,20 +329,31 @@ def train(model: ContinualModel, dataset: ContinualDataset,
cur_stopping_patience = args.early_stopping_patience
if args.eval_epochs is not None and (epoch > 0 or args.eval_epochs == 1) and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs:
- epoch_accs = evaluate(model, dataset)
+ epoch_accs = evaluate(model, eval_dataset)
log_accs(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch)
- progress_bar.reset()
-
model.meta_end_task(dataset)
- accs = evaluate(model, dataset)
+ accs = evaluate(model, eval_dataset)
+
+ if args.eval_future and t < dataset.N_TASKS - 1:
+ transf_accs = accs[0][t + 1:], accs[1][t + 1:]
+ accs = accs[0][:t + 1], accs[1][:t + 1]
+ results_transf.append(transf_accs[0])
+ results_mask_classes_transf.append(transf_accs[1])
+
results.append(accs[0])
results_mask_classes.append(accs[1])
log_accs(args, logger, accs, t, dataset.SETTING)
+ if args.eval_future:
+ avg_transf = np.mean([np.mean(task_) for task_ in results_transf])
+ print(f"Transfer Metrics - AVG Transfer {avg_transf:.2f}")
+ if t < dataset.N_TASKS - 1:
+ log_accs(args, logger, transf_accs, t, dataset.SETTING, future=True)
+
if args.savecheck:
save_obj = {
'model': model.state_dict(),
@@ -328,11 +365,14 @@ def train(model: ContinualModel, dataset: ContinualDataset,
if 'buffer_size' in model.args:
save_obj['buffer'] = deepcopy(model.buffer).to('cpu')
- # Saving model checkpoint
- checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt'
- torch.save(save_obj, checkpoint_name)
-
- del progress_bar
+ # Saving model checkpoint for the current task
+ checkpoint_name = None
+ if args.savecheck == 'task':
+ checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt'
+ elif args.savecheck == 'last' and t == end_task - 1:
+ checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_last.pt'
+ if checkpoint_name is not None:
+ torch.save(save_obj, checkpoint_name)
if args.validation:
# Final evaluation on the real test set